PyTorch CUDA Graph Capture
Introduction
CUDA Graphs are a useful feature for optimizing GPU system performance: they reduce the CPU overhead of launching GPU kernels. They are especially useful when the GPU workload is small and the CPU kernel-launch overhead becomes the system performance bottleneck. The native NVIDIA CUDA Graph APIs cannot be used directly in PyTorch programs, because PyTorch has its own dynamic memory management and execution model. PyTorch therefore provides two main APIs for capturing and replaying CUDA graphs, torch.cuda.graph and torch.cuda.make_graphed_callables, which convert the dynamic memory management and execution model of PyTorch programs into static ones.
In this blog post, I would like to discuss how to use these two APIs to capture and replay CUDA graphs in PyTorch, the differences between them, and how they can help improve the performance of PyTorch models in different scenarios.
PyTorch CUDA Graph Capture
PyTorch exposes CUDA graphs via the raw torch.cuda.CUDAGraph class and two convenience wrappers, torch.cuda.graph and torch.cuda.make_graphed_callables, which are useful for capturing and replaying CUDA graphs in slightly different scenarios. The examples below demonstrate how to use the two wrappers to capture and replay CUDA graphs for training a simple MLP model.
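Before the full examples, the snippet below is a minimal sketch of the underlying mechanics, following the side-stream warmup pattern from the PyTorch documentation (the op and tensor sizes here are arbitrary assumptions): kernels launched during capture are recorded rather than executed, and a later replay() reruns them against the same static buffers.

```python
import torch

g = torch.cuda.CUDAGraph()
static_input = torch.randn(64, 256, device="cuda")

# Warm up on a side stream so one-time initialization work
# (e.g., lazy allocations, library handles) is not captured.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    static_output = torch.relu(static_input)
torch.cuda.current_stream().wait_stream(s)

# Capture: the kernel launched inside this context is recorded into g,
# not executed.
with torch.cuda.graph(g):
    static_output = torch.relu(static_input)

# Replay: refill the static input buffer in place, then launch the
# whole recorded graph with a single CPU-side call.
static_input.copy_(torch.full((64, 256), -1.0, device="cuda"))
g.replay()
torch.cuda.synchronize()
print(static_output.abs().sum().item())  # 0.0: ReLU of an all-negative input
```

Because the captured kernels always read from and write to the same addresses, new data must be copied into the static buffers in place before each replay.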
torch.cuda.graph
Using the torch.cuda.graph API, we have to manually manage the warmup, static buffers, graph capture, and replay. This provides full control over which operations are included in the graph, even allowing us to capture the complete training step, including loss computation and optimizer updates.
In the following example, built with torch.cuda.graph, each training iteration is executed as a single graph replay, and there is no synchronization between the host and the device during the entire training process.
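The script below is a hedged reconstruction of torch_cuda_graph_manual_capture.py (the original listing is not preserved here, so the MLP architecture, batch size, and hyperparameters are assumptions); the warmup, capture, and replay pattern follows the PyTorch CUDA Graphs documentation.

```python
#!/usr/bin/env python3
import torch

device = torch.device("cuda")
model = torch.nn.Sequential(
    torch.nn.Linear(256, 1024),
    torch.nn.ReLU(),
    torch.nn.Linear(1024, 256),
).to(device)
criterion = torch.nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Static buffers: graph replays always read from these fixed addresses.
static_input = torch.randn(64, 256, device=device)
static_target = torch.randn(64, 256, device=device)

# Warm up on a side stream so one-time setup is not recorded into the graph.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):
        optimizer.zero_grad(set_to_none=True)
        loss = criterion(model(static_input), static_target)
        loss.backward()
        optimizer.step()
torch.cuda.current_stream().wait_stream(s)

# Set grads to None before capture so backward() allocates the .grad
# tensors from the graph's private memory pool.
optimizer.zero_grad(set_to_none=True)
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    static_loss = criterion(model(static_input), static_target)
    static_loss.backward()
    optimizer.step()

# Training loop: refill the static buffers, then run the entire iteration
# (forward + loss + backward + optimizer step) as a single graph launch.
for _ in range(10):
    static_input.copy_(torch.randn(64, 256, device=device))
    static_target.copy_(torch.randn(64, 256, device=device))
    g.replay()

torch.cuda.synchronize()
print(f"loss: {static_loss.item():.6f}")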
```
$ python torch_cuda_graph_manual_capture.py
```
torch.cuda.make_graphed_callables
The torch.cuda.make_graphed_callables API simplifies CUDA graph usage by automatically handling warmup, static buffers, graph capture, and replay. It also allows finer-grained control over which operations are included in a graph, by wrapping individual callables (such as models or submodules) and graphing their forward and backward passes. Compared to the torch.cuda.graph API, it leaves loss computation and the optimizer step outside the graph. Consequently, its CPU overhead is slightly higher, because more CUDA operations and graph replays are submitted in each training iteration.
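The script below is a hedged reconstruction of torch_cuda_graph_make_graphed_callables.py (the original listing is not preserved here, so the MLP architecture, batch size, and hyperparameters are assumptions):

```python
#!/usr/bin/env python3
import torch

device = torch.device("cuda")
model = torch.nn.Sequential(
    torch.nn.Linear(256, 1024),
    torch.nn.ReLU(),
    torch.nn.Linear(1024, 256),
).to(device)
criterion = torch.nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# make_graphed_callables performs warmup and capture internally. The sample
# input fixes the static shape that the graphed forward/backward will accept.
sample_input = torch.randn(64, 256, device=device)
model = torch.cuda.make_graphed_callables(model, (sample_input,))

for _ in range(10):
    optimizer.zero_grad(set_to_none=True)
    x = torch.randn(64, 256, device=device)
    y = torch.randn(64, 256, device=device)
    pred = model(x)            # replays the captured forward graph
    loss = criterion(pred, y)  # eager: outside any graph
    loss.backward()            # replays the captured backward graph
    optimizer.step()           # eager: outside any graph

torch.cuda.synchronize()
print(f"loss: {loss.item():.6f}")
```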
```
$ python torch_cuda_graph_make_graphed_callables.py
```
Summary
The following table summarizes the CPU wall clock time for training the model with different levels of CUDA graph integration using the different PyTorch APIs. The profiling traces were collected using torch.profiler and can be downloaded and viewed in Perfetto.
| CUDA Graph | API | Self CPU Time Total | Profiling Trace |
|---|---|---|---|
| No CUDA Graph | N/A | 24.142 ms | Trace |
| Graph: Full Model Forward + Loss + Full Model Backward + Optimizer | torch.cuda.graph | 8.690 ms | Trace |
| No CUDA Graph | N/A | 22.641 ms | Trace |
| Graph 0: Full Model Forward, Graph 1: Full Model Backward | torch.cuda.make_graphed_callables | 11.450 ms | Trace |
| Graph 0: Submodule Forward, Graph 1: Submodule Backward | torch.cuda.make_graphed_callables | 19.842 ms | Trace |
We can see that the torch.cuda.graph API provides the best performance, since it captures the entire training iteration (forward, loss, backward, optimizer) into a single CUDA graph, minimizing CPU overhead. In practice, however, model complexity that introduces host-device synchronization, as well as dynamically shaped tensors, may make it infeasible to capture the entire training step into a single graph. The torch.cuda.make_graphed_callables API offers a more flexible approach by allowing partial graph capture of individual model components, balancing performance gains with ease of use and adaptability to dynamic workloads.
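As a hedged illustration of such partial capture (the submodule split and sizes are assumptions, not code from this post), torch.cuda.make_graphed_callables also accepts a tuple of callables and graphs each one separately, which corresponds to the per-submodule rows in the table above:

```python
import torch

fc1 = torch.nn.Linear(256, 1024).cuda()
fc2 = torch.nn.Linear(1024, 256).cuda()

# Sample inputs: the intermediate activation must require grad so that fc2's
# backward graph is captured with respect to its input.
x = torch.randn(64, 256, device="cuda")
h = torch.randn(64, 1024, device="cuda", requires_grad=True)

# Graph each submodule individually; callables passed together in one call
# share the same memory pool.
fc1, fc2 = torch.cuda.make_graphed_callables((fc1, fc2), ((x,), (h,)))
```

In the training loop, fc2(torch.relu(fc1(x))) would then replay two forward graphs per iteration, while the intermediate activation, loss computation, and optimizer step run eagerly, which matches the higher CPU time of the per-submodule rows in the table.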