PyTorch CUDA Graph Capture

Introduction

CUDA Graphs are a useful feature for optimizing GPU system performance: by recording a sequence of GPU kernels once and replaying it with a single launch, they reduce the CPU overhead of launching GPU kernels. They are especially useful when the GPU workload is small and the CPU overhead of launching GPU kernels becomes the system performance bottleneck. The native NVIDIA CUDA Graph APIs cannot be used directly for PyTorch programs, because PyTorch has its own dynamic memory management and execution model. PyTorch therefore provides two main APIs for capturing and replaying CUDA graphs, torch.cuda.graph and torch.cuda.make_graphed_callables, which convert the dynamic memory management and execution model of a PyTorch program into static ones.
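
To make the workflow concrete, here is a minimal sketch of the raw capture/replay primitives, following the pattern in the PyTorch documentation (the tensor shapes and operations are arbitrary illustrations):

import torch

# Static placeholder tensors: a replayed graph always reads from and writes
# to the same memory addresses, so new inputs are fed by copying into them.
static_x = torch.randn(1024, 1024, device="cuda")

# Warmup on a side stream so that lazy one-time initialization happens
# before capture.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    static_y = static_x.sin() + static_x.cos()
torch.cuda.current_stream().wait_stream(s)

# Capture: the kernels are recorded into the graph instead of being executed.
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    static_y = static_x.sin() + static_x.cos()

# Replay: copy new data into the static input, then launch the entire
# recorded kernel sequence with a single CPU-side call.
static_x.copy_(torch.randn(1024, 1024, device="cuda"))
g.replay()  # static_y now holds the result for the new static_x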

In this blog post, I would like to discuss how to use these two APIs to capture and replay CUDA graphs in PyTorch, how the two APIs differ, and how they can help improve the performance of PyTorch models in different scenarios.

PyTorch CUDA Graph Capture

PyTorch exposes graphs via a raw torch.cuda.CUDAGraph class and two convenience wrappers, torch.cuda.graph and torch.cuda.make_graphed_callables. They are useful for capturing and replaying CUDA graphs in slightly different scenarios. The examples below demonstrate how to use these two APIs to capture and replay CUDA graphs for training a simple MLP model.
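
Both example scripts import a handful of helpers from a common module that is not listed in this post. The sketch below is a hypothetical reconstruction, not the actual common.py: the tensor sizes come from the printed configuration, while the block1/block2 split and the profiler schedule (7 ProfilerStep* entries out of 10 iterations) are inferred from the profiler summaries shown below.

# common.py -- hypothetical reconstruction of the shared helpers.
import os

import torch
import torch.nn as nn
from torch.profiler import ProfilerActivity, profile, record_function, schedule


class MLPModel(nn.Module):
    """MLP split into two submodules so block2 can be graphed on its own."""

    def __init__(self, d_in, h1, h2, h3, d_out):
        super().__init__()
        self.block1 = nn.Sequential(
            nn.Linear(d_in, h1), nn.ReLU(),
            nn.Linear(h1, h2), nn.ReLU(),
            nn.Linear(h2, h3), nn.ReLU())
        self.block2 = nn.Sequential(nn.Linear(h3, d_out))

    def forward(self, x):
        return self.block2(self.block1(x))


def setup_model_and_data(device, num_iterations=10):
    """Return the model configuration and random training data."""
    config = {'N': 640, 'D_in': 4096, 'H1': 2048, 'H2': 1024, 'H3': 512,
              'D_out': 256}
    print("Model configuration:")
    print(f"  Batch size: {config['N']}")
    print(f"  Input dim: {config['D_in']}")
    print(f"  Hidden dims: {config['H1']} -> {config['H2']} -> {config['H3']}")
    print(f"  Output dim: {config['D_out']}")
    print()
    inputs = [torch.randn(config['N'], config['D_in'], device=device)
              for _ in range(num_iterations)]
    targets = [torch.randn(config['N'], config['D_out'], device=device)
               for _ in range(num_iterations)]
    return config, inputs, targets


def create_model(config, device):
    return MLPModel(config['D_in'], config['H1'], config['H2'], config['H3'],
                    config['D_out']).to(device)


def create_profiler():
    # Skip the first iterations so that one-time warmup noise does not
    # pollute the summary: 1 wait + 2 warmup + 7 active = 10 iterations.
    return profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
                   schedule=schedule(wait=1, warmup=2, active=7),
                   profile_memory=True)


def train_without_cuda_graph(model, loss_fn, optimizer, inputs, targets,
                             profiler=None):
    """Baseline eager-mode training loop."""
    print("Training WITHOUT CUDA graph...")
    for data, target in zip(inputs, targets):
        optimizer.zero_grad(set_to_none=True)
        with record_function("## forward_pass ##"):
            y_pred = model(data)
        loss = loss_fn(y_pred, target)
        loss.backward()
        optimizer.step()
        if profiler is not None:
            profiler.step()
    print(f"  Completed {len(inputs)} iterations.")
    print()


def save_and_print_profile(prof, trace_file, description):
    """Export a Chrome trace and print a per-operator summary."""
    os.makedirs(os.path.dirname(trace_file), exist_ok=True)
    prof.export_chrome_trace(trace_file)
    print(f"Profiling trace saved to: {trace_file}")
    print()
    print(f"Top 10 operations by CUDA time ({description}):")
    print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))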

torch.cuda.graph

With the torch.cuda.graph API, we have to manage the warmup, static buffers, graph capture, and replay manually. In return, this provides full control over which operations are included in the graph, even allowing us to capture the complete training step, including loss computation and optimizer updates.

In the following example, each training iteration is invoked as a single graph replay, and there is no synchronization between the host and the device during the entire training process.

torch_cuda_graph_manual_capture.py
#!/usr/bin/env python3
"""
CUDA Graph Manual Capture Example

This script demonstrates how to manually capture and replay CUDA graphs for an
entire training iteration (forward pass, loss computation, backward pass, and
optimizer step). It profiles training with and without CUDA graphs for comparison.

Manual capture using torch.cuda.graph() provides full control over what operations
are included in the graph, allowing you to capture the complete training step
including loss computation and optimizer updates.
"""

import torch
import torch.nn as nn
from torch.profiler import record_function
from common import (train_without_cuda_graph, setup_model_and_data,
                    create_model, create_profiler, save_and_print_profile)


def prepare_cuda_graph(model, loss_fn, optimizer, static_input, static_target):
    """Warmup and capture CUDA graph (not profiled)."""
    print(" Performing warmup iterations...")
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for i in range(3):
            optimizer.zero_grad(set_to_none=True)
            y_pred = model(static_input)
            loss = loss_fn(y_pred, static_target)
            loss.backward()
            optimizer.step()
    torch.cuda.current_stream().wait_stream(s)

    # Capture
    print(" Capturing CUDA graph...")
    g = torch.cuda.CUDAGraph()
    optimizer.zero_grad(set_to_none=True)
    with torch.cuda.graph(g):
        static_y_pred = model(static_input)
        static_loss = loss_fn(static_y_pred, static_target)
        static_loss.backward()
        optimizer.step()

    return g, static_loss


def train_with_cuda_graph(graph,
                          inputs,
                          targets,
                          static_input,
                          static_target,
                          static_loss,
                          profiler=None):
    """Train using CUDA graph for optimized replay (profiled part only)."""
    print(" Training with graph replay...")

    for i, (data, target) in enumerate(zip(inputs, targets)):
        with record_function("## copy_input_data ##"):
            static_input.copy_(data)
            static_target.copy_(target)

        with record_function("## graph.replay ##"):
            graph.replay()

        if profiler is not None:
            profiler.step()

        # NOTE: Avoid calling .item() in the training loop as it triggers
        # device-to-host memory copy and CPU-GPU synchronization, which
        # damages performance.
        # if i % 2 == 0:
        #     print(f" Iteration {i+1:2d}: Loss = {static_loss.item():.4f}")

    print(f" Completed {len(inputs)} iterations.")
    print()


def main():
    print("CUDA Graph Whole Network Capture Example")
    print("=" * 70)

    # Check CUDA availability
    if not torch.cuda.is_available():
        print(
            "Error: CUDA is not available. This example requires a CUDA-capable GPU."
        )
        return

    device = torch.device('cuda')
    print(f"Using device: {torch.cuda.get_device_name(0)}")
    print()

    # Configuration
    trace_dir = "traces"  # Directory for trace files

    # Model setup and data generation
    config, real_inputs, real_targets = setup_model_and_data(device)

    # Placeholders for graph capture
    static_input = torch.randn(config['N'], config['D_in'], device=device)
    static_target = torch.randn(config['N'], config['D_out'], device=device)

    # ========================================================================
    # Training WITHOUT CUDA Graph
    # ========================================================================
    print("=" * 70)
    print("SCENARIO 1: Training WITHOUT CUDA Graph")
    print("=" * 70)

    model_no_graph = create_model(config, device)
    loss_fn_no_graph = torch.nn.MSELoss()
    optimizer_no_graph = torch.optim.SGD(model_no_graph.parameters(), lr=0.1)

    with create_profiler() as prof_no_graph:
        train_without_cuda_graph(model_no_graph,
                                 loss_fn_no_graph,
                                 optimizer_no_graph,
                                 real_inputs,
                                 real_targets,
                                 profiler=prof_no_graph)

    # Save profiling trace and print summary
    trace_file_no_graph = trace_dir + "/" + "trace_without_manual_capture.json"
    save_and_print_profile(prof_no_graph, trace_file_no_graph,
                           "without CUDA graph")

    # ========================================================================
    # Training WITH CUDA Graph
    # ========================================================================
    print("=" * 70)
    print("SCENARIO 2: Training WITH CUDA Graph")
    print("=" * 70)

    model_with_graph = create_model(config, device)
    loss_fn_with_graph = torch.nn.MSELoss()
    optimizer_with_graph = torch.optim.SGD(model_with_graph.parameters(),
                                           lr=0.1)

    # Prepare graph (warmup + capture) - NOT profiled
    print("Preparing CUDA graph (warmup + capture)...")
    graph, static_loss = prepare_cuda_graph(model_with_graph,
                                            loss_fn_with_graph,
                                            optimizer_with_graph, static_input,
                                            static_target)
    print("CUDA graph ready.")
    print()

    # Profile only the training iterations
    with create_profiler() as prof_with_graph:
        train_with_cuda_graph(graph,
                              real_inputs,
                              real_targets,
                              static_input,
                              static_target,
                              static_loss,
                              profiler=prof_with_graph)

    # Save profiling trace and print summary
    trace_file_with_graph = trace_dir + "/" + "trace_with_manual_capture.json"
    save_and_print_profile(prof_with_graph, trace_file_with_graph,
                           "with CUDA graph")

    print("=" * 70)
    print("Profiling completed successfully!")
    print(f"View traces in Chrome: chrome://tracing")
    print(f" - {trace_file_no_graph}")
    print(f" - {trace_file_with_graph}")
    print("=" * 70)


if __name__ == "__main__":
    main()
$ python torch_cuda_graph_manual_capture.py
CUDA Graph Whole Network Capture Example
======================================================================
Using device: NVIDIA GeForce RTX 5080

Model configuration:
Batch size: 640
Input dim: 4096
Hidden dims: 2048 -> 1024 -> 512
Output dim: 256

======================================================================
SCENARIO 1: Training WITHOUT CUDA Graph
======================================================================
Training WITHOUT CUDA graph...
Completed 10 iterations.

Profiling trace saved to: traces/trace_without_manual_capture.json

Top 10 operations by CUDA time (without CUDA graph):
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Name Self CPU % Self CPU CPU total % CPU total CPU time avg Self CUDA Self CUDA % CUDA total CUDA time avg CPU Mem Self CPU Mem CUDA Mem Self CUDA Mem # of Calls
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
## forward_pass ## 0.00% 0.000us 0.00% 0.000us 0.000us 3.624ms 59.56% 3.624ms 517.737us 0 B 0 B 0 B 0 B 7
ProfilerStep* 3.22% 778.095us 74.15% 17.902ms 2.557ms 0.000us 0.00% 2.978ms 425.435us 0 B 0 B 0 B 0 B 7
autograd::engine::evaluate_function: AddmmBackward0 1.91% 461.482us 14.53% 3.509ms 125.313us 0.000us 0.00% 2.971ms 106.111us 0 B 0 B 231.98 MB -126.88 MB 28
AddmmBackward0 1.31% 317.453us 9.84% 2.375ms 84.836us 0.000us 0.00% 2.717ms 97.051us 0 B 0 B 358.75 MB 0 B 28
aten::mm 4.52% 1.091ms 6.09% 1.470ms 29.999us 2.717ms 44.66% 2.717ms 55.458us 0 B 0 B 358.75 MB 358.75 MB 49
## forward_pass ## 7.00% 1.691ms 18.60% 4.491ms 641.548us 0.000us 0.00% 2.069ms 295.579us 0 B 0 B 137.81 MB -65.62 MB 7
aten::linear 0.40% 96.651us 5.85% 1.413ms 50.447us 0.000us 0.00% 1.951ms 69.673us 0 B 0 B 65.62 MB 0 B 28
aten::addmm 3.45% 832.270us 4.48% 1.080ms 38.587us 1.951ms 32.06% 1.951ms 69.673us 0 B 0 B 65.62 MB 65.62 MB 28
void cutlass::Kernel2<cutlass_80_tensorop_s1688gemm_... 0.00% 0.000us 0.00% 0.000us 0.000us 1.791ms 29.43% 1.791ms 127.895us 0 B 0 B 0 B 0 B 14
void cutlass::Kernel2<cutlass_80_tensorop_s1688gemm_... 0.00% 0.000us 0.00% 0.000us 0.000us 1.593ms 26.19% 1.593ms 227.621us 0 B 0 B 0 B 0 B 7
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Self CPU time total: 24.142ms
Self CUDA time total: 6.085ms


======================================================================
SCENARIO 2: Training WITH CUDA Graph
======================================================================
Preparing CUDA graph (warmup + capture)...
Performing warmup iterations...
Capturing CUDA graph...
CUDA graph ready.

Training with graph replay...
Completed 10 iterations.

Profiling trace saved to: traces/trace_with_manual_capture.json

Top 10 operations by CUDA time (with CUDA graph):
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Name Self CPU % Self CPU CPU total % CPU total CPU time avg Self CUDA Self CUDA % CUDA total CUDA time avg CPU Mem Self CPU Mem # of Calls
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
## graph.replay ## 0.00% 0.000us 0.00% 0.000us 0.000us 6.564ms 76.16% 6.564ms 937.756us 0 B 0 B 7
void cutlass::Kernel2<cutlass_80_tensorop_s1688gemm_... 0.00% 0.000us 0.00% 0.000us 0.000us 2.663ms 30.89% 2.663ms 295.836us 0 B 0 B 9
void cutlass::Kernel2<cutlass_80_tensorop_s1688gemm_... 0.00% 0.000us 0.00% 0.000us 0.000us 2.061ms 23.91% 2.061ms 121.242us 0 B 0 B 17
void at::native::(anonymous namespace)::multi_tensor... 0.00% 0.000us 0.00% 0.000us 0.000us 1.172ms 13.59% 1.172ms 130.176us 0 B 0 B 9
void cutlass::Kernel2<cutlass_80_tensorop_s1688gemm_... 0.00% 0.000us 0.00% 0.000us 0.000us 573.472us 6.65% 573.472us 63.719us 0 B 0 B 9
void cutlass::Kernel2<cutlass_80_tensorop_s1688gemm_... 0.00% 0.000us 0.00% 0.000us 0.000us 456.896us 5.30% 456.896us 50.766us 0 B 0 B 9
void at::native::reduce_kernel<128, 4, at::native::R... 0.00% 0.000us 0.00% 0.000us 0.000us 333.825us 3.87% 333.825us 9.273us 0 B 0 B 36
void cutlass::Kernel2<cutlass_80_tensorop_s1688gemm_... 0.00% 0.000us 0.00% 0.000us 0.000us 226.624us 2.63% 226.624us 25.180us 0 B 0 B 9
void cutlass::Kernel2<cutlass_80_tensorop_s1688gemm_... 0.00% 0.000us 0.00% 0.000us 0.000us 205.058us 2.38% 205.058us 11.392us 0 B 0 B 18
Memcpy DtoD (Device -> Device) 0.00% 0.000us 0.00% 0.000us 0.000us 169.183us 1.96% 169.183us 10.574us 0 B 0 B 16
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Self CPU time total: 8.690ms
Self CUDA time total: 8.619ms


======================================================================
Profiling completed successfully!
View traces in Chrome: chrome://tracing
- traces/trace_without_manual_capture.json
- traces/trace_with_manual_capture.json
======================================================================

torch.cuda.make_graphed_callables

The torch.cuda.make_graphed_callables API simplifies CUDA graph usage by automatically handling warmup, static buffers, graph capture, and replay. It also allows finer-grained control over which operations are included in a graph: it wraps individual callables (such as models or submodules) and graphs their forward and backward passes. Unlike the torch.cuda.graph API, it leaves loss computation and optimizer steps outside the graph. Consequently, its CPU overhead is slightly higher, because more CUDA operations and graph replays are submitted in each training iteration.
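
Before the full script, a minimal sketch of the API shape may help (the toy modules and shapes are my own illustrations, following the pattern in the PyTorch documentation):

import torch

# Two toy submodules to be graphed separately.
module1 = torch.nn.Linear(64, 64).cuda()
module2 = torch.nn.Linear(64, 64).cuda()

# Sample inputs must match the real inputs in shape, dtype, and
# requires_grad state. module2's real input is module1's output, which
# requires grad, so its sample input must require grad as well.
x = torch.randn(8, 64, device="cuda")
h = torch.randn(8, 64, device="cuda", requires_grad=True)

# Passing a tuple of callables returns a tuple of graphed callables;
# passing a single callable returns the graphed callable directly.
module1, module2 = torch.cuda.make_graphed_callables(
    (module1, module2), ((x,), (h,)))

y = module2(module1(x))  # each forward pass replays a captured CUDA graph
y.sum().backward()       # each backward pass replays a captured CUDA graph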

torch_cuda_graph_make_graphed_callables.py
#!/usr/bin/env python3
"""
CUDA Graph with make_graphed_callables Example

This script demonstrates how to use make_graphed_callables to automatically capture
and replay CUDA graphs for model forward and backward passes.
It profiles training with and without CUDA graphs for comparison.

make_graphed_callables simplifies CUDA graph usage by automatically handling warmup
and capture. It wraps individual callables (like models) and graphs their forward
and backward operations, while leaving loss computation and optimizer steps outside
the graph.

Three scenarios are demonstrated:
1. Training WITHOUT CUDA graph (baseline)
2. Training WITH full CUDA graph (entire model)
3. Training WITH partial CUDA graph (only block2 submodule)
"""

import torch
import torch.nn as nn
from torch.profiler import record_function
from torch.cuda import make_graphed_callables
from common import (MLPModel, train_without_cuda_graph, setup_model_and_data,
                    create_model, create_profiler, save_and_print_profile)


def prepare_cuda_graph(model, static_input):
    """Prepare CUDA graph using make_graphed_callables (not profiled)."""
    print(" Creating graphed model...")

    # Wrap the model with make_graphed_callables
    # This will graph the forward and backward passes of the model
    graphed_model = make_graphed_callables(model, (static_input, ))

    print(" CUDA graph model ready.")

    return graphed_model


def prepare_partial_cuda_graph(model, static_input):
    """Prepare CUDA graph for only block2 submodule (not profiled)."""
    print(" Creating partially graphed model (only block2)...")

    # First, do a forward pass to determine the input shape for block2
    with torch.no_grad():
        block1_output = model.block1(static_input)

    # Wrap only block2 with make_graphed_callables
    # This will graph only the forward and backward passes of block2
    # When passing a single callable, it returns the graphed callable directly
    graphed_block2 = make_graphed_callables(model.block2, (block1_output, ))
    model.block2 = graphed_block2

    print(" CUDA graph for block2 ready.")

    return model


def train_with_cuda_graph(graphed_model,
                          loss_fn,
                          optimizer,
                          inputs,
                          targets,
                          profiler=None):
    """Train using CUDA graph model for optimized replay (profiled part only)."""
    print(" Training with graph replay...")

    for i, (data, target) in enumerate(zip(inputs, targets)):
        with record_function("## optimizer.zero_grad ##"):
            optimizer.zero_grad(set_to_none=True)

        # Forward pass runs as a graph
        with record_function("## forward_pass_graphed ##"):
            y_pred = graphed_model(data)

        # NOTE: Loss computation is NOT part of the CUDA graph
        # Only the model's forward/backward passes are graphed
        with record_function("## loss_computation ##"):
            loss = loss_fn(y_pred, target)

        # Backward pass runs as a graph
        with record_function("## backward_pass_graphed ##"):
            loss.backward()

        # NOTE: Optimizer step is NOT part of the CUDA graph
        with record_function("## optimizer.step ##"):
            optimizer.step()

        if profiler is not None:
            profiler.step()

        # NOTE: Avoid calling .item() in the training loop as it triggers
        # device-to-host memory copy and CPU-GPU synchronization, which
        # damages performance.
        # if i % 2 == 0:
        #     print(f" Iteration {i+1:2d}: Loss = {loss.item():.4f}")

    print(f" Completed {len(inputs)} iterations.")
    print()


def main():
    print("CUDA Graph Whole Network Capture Example")
    print("=" * 70)

    # Check CUDA availability
    if not torch.cuda.is_available():
        print(
            "Error: CUDA is not available. This example requires a CUDA-capable GPU."
        )
        return

    device = torch.device('cuda')
    print(f"Using device: {torch.cuda.get_device_name(0)}")
    print()

    # Configuration
    trace_dir = "traces"  # Directory for trace files

    # Model setup and data generation
    config, real_inputs, real_targets = setup_model_and_data(device)

    # Placeholders for graph capture
    static_input = torch.randn(config['N'], config['D_in'], device=device)
    static_target = torch.randn(config['N'], config['D_out'], device=device)

    # ========================================================================
    # Training WITHOUT CUDA Graph
    # ========================================================================
    print("=" * 70)
    print("SCENARIO 1: Training WITHOUT CUDA Graph")
    print("=" * 70)

    model_no_graph = create_model(config, device)
    loss_fn_no_graph = torch.nn.MSELoss()
    optimizer_no_graph = torch.optim.SGD(model_no_graph.parameters(), lr=0.1)

    with create_profiler() as prof_no_graph:
        train_without_cuda_graph(model_no_graph,
                                 loss_fn_no_graph,
                                 optimizer_no_graph,
                                 real_inputs,
                                 real_targets,
                                 profiler=prof_no_graph)

    # Save profiling trace and print summary
    trace_file_no_graph = trace_dir + "/" + "trace_without_make_graphed_callables.json"
    save_and_print_profile(prof_no_graph, trace_file_no_graph,
                           "without CUDA graph")

    # ========================================================================
    # Training WITH CUDA Graph
    # ========================================================================
    print("=" * 70)
    print("SCENARIO 2: Training WITH CUDA Graph")
    print("=" * 70)

    model_with_graph = create_model(config, device)
    loss_fn_with_graph = torch.nn.MSELoss()
    optimizer_with_graph = torch.optim.SGD(model_with_graph.parameters(),
                                           lr=0.1)

    # Prepare graph (warmup + capture) - NOT profiled
    print("Preparing CUDA graph (warmup + capture)...")
    graphed_model = prepare_cuda_graph(model_with_graph, static_input)
    print("CUDA graph ready.")
    print()

    # Profile only the training iterations
    with create_profiler() as prof_with_graph:
        train_with_cuda_graph(graphed_model,
                              loss_fn_with_graph,
                              optimizer_with_graph,
                              real_inputs,
                              real_targets,
                              profiler=prof_with_graph)

    # Save profiling trace and print summary
    trace_file_with_graph = trace_dir + "/" + "trace_with_make_graphed_callables.json"
    save_and_print_profile(prof_with_graph, trace_file_with_graph,
                           "with CUDA graph")

    print("=" * 70)
    print("Profiling completed successfully!")
    print(f"View traces in Chrome: chrome://tracing")
    print(f" - {trace_file_no_graph}")
    print(f" - {trace_file_with_graph}")
    print("=" * 70)

    # ========================================================================
    # Training WITH PARTIAL CUDA Graph (only block2)
    # ========================================================================
    print()
    print("=" * 70)
    print("SCENARIO 3: Training WITH PARTIAL CUDA Graph (only block2)")
    print("=" * 70)

    model_partial_graph = create_model(config, device)
    loss_fn_partial_graph = torch.nn.MSELoss()
    optimizer_partial_graph = torch.optim.SGD(model_partial_graph.parameters(),
                                              lr=0.1)

    # Prepare partial graph (only block2) - NOT profiled
    print("Preparing CUDA graph for block2 only (warmup + capture)...")
    model_partial_graph = prepare_partial_cuda_graph(model_partial_graph,
                                                     static_input)
    print("CUDA graph for block2 ready.")
    print()

    # Profile only the training iterations
    with create_profiler() as prof_partial_graph:
        train_with_cuda_graph(model_partial_graph,
                              loss_fn_partial_graph,
                              optimizer_partial_graph,
                              real_inputs,
                              real_targets,
                              profiler=prof_partial_graph)

    # Save profiling trace and print summary
    trace_file_partial_graph = trace_dir + "/" + "trace_with_partial_make_graphed_callables.json"
    save_and_print_profile(prof_partial_graph, trace_file_partial_graph,
                           "with partial CUDA graph - block2 only")

    print("=" * 70)
    print("All profiling completed successfully!")
    print(f"View traces in Chrome: chrome://tracing")
    print(f" - {trace_file_no_graph}")
    print(f" - {trace_file_with_graph}")
    print(f" - {trace_file_partial_graph}")
    print("=" * 70)


if __name__ == "__main__":
    main()
$ python torch_cuda_graph_make_graphed_callables.py
CUDA Graph Whole Network Capture Example
======================================================================
Using device: NVIDIA GeForce RTX 5080

Model configuration:
Batch size: 640
Input dim: 4096
Hidden dims: 2048 -> 1024 -> 512
Output dim: 256

======================================================================
SCENARIO 1: Training WITHOUT CUDA Graph
======================================================================
Training WITHOUT CUDA graph...
Completed 10 iterations.

Profiling trace saved to: traces/trace_without_make_graphed_callables.json

Top 10 operations by CUDA time (without CUDA graph):
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Name Self CPU % Self CPU CPU total % CPU total CPU time avg Self CUDA Self CUDA % CUDA total CUDA time avg CPU Mem Self CPU Mem CUDA Mem Self CUDA Mem # of Calls
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
## forward_pass ## 0.00% 0.000us 0.00% 0.000us 0.000us 3.786ms 62.30% 3.786ms 540.851us 0 B 0 B 0 B 0 B 7
ProfilerStep* 3.20% 725.602us 75.51% 17.097ms 2.442ms 0.000us 0.00% 2.974ms 424.860us 0 B 0 B 0 B 0 B 7
autograd::engine::evaluate_function: AddmmBackward0 1.84% 416.249us 13.71% 3.104ms 110.858us 0.000us 0.00% 2.973ms 106.185us 0 B 0 B 231.98 MB -126.88 MB 28
AddmmBackward0 1.25% 282.658us 9.21% 2.085ms 74.461us 0.000us 0.00% 2.722ms 97.223us 0 B 0 B 358.75 MB 0 B 28
aten::mm 4.27% 967.644us 5.73% 1.297ms 26.465us 2.722ms 44.79% 2.722ms 55.556us 0 B 0 B 358.75 MB 358.75 MB 49
## forward_pass ## 7.58% 1.717ms 20.62% 4.669ms 667.061us 0.000us 0.00% 2.069ms 295.561us 0 B 0 B 137.81 MB -65.62 MB 7
aten::linear 0.46% 103.960us 6.63% 1.502ms 53.640us 0.000us 0.00% 1.956ms 69.849us 0 B 0 B 65.62 MB 0 B 28
aten::addmm 3.89% 879.966us 5.10% 1.155ms 41.243us 1.956ms 32.18% 1.956ms 69.849us 0 B 0 B 65.62 MB 65.62 MB 28
void cutlass::Kernel2<cutlass_80_tensorop_s1688gemm_... 0.00% 0.000us 0.00% 0.000us 0.000us 1.795ms 29.53% 1.795ms 128.192us 0 B 0 B 0 B 0 B 14
void cutlass::Kernel2<cutlass_80_tensorop_s1688gemm_... 0.00% 0.000us 0.00% 0.000us 0.000us 1.597ms 26.28% 1.597ms 228.169us 0 B 0 B 0 B 0 B 7
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Self CPU time total: 22.641ms
Self CUDA time total: 6.077ms


======================================================================
SCENARIO 2: Training WITH CUDA Graph
======================================================================
Preparing CUDA graph (warmup + capture)...
Creating graphed model...
CUDA graph model ready.
CUDA graph ready.

Training with graph replay...
Completed 10 iterations.

Profiling trace saved to: traces/trace_with_make_graphed_callables.json

Top 10 operations by CUDA time (with CUDA graph):
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Name Self CPU % Self CPU CPU total % CPU total CPU time avg Self CUDA Self CUDA % CUDA total CUDA time avg CPU Mem Self CPU Mem CUDA Mem Self CUDA Mem # of Calls
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
ProfilerStep* 5.42% 620.940us 81.15% 9.292ms 1.327ms 0.000us 0.00% 3.128ms 446.854us 0 B 0 B 0 B 0 B 7
autograd::engine::evaluate_function: GraphedBackward... 1.73% 197.721us 9.66% 1.106ms 157.971us 0.000us 0.00% 3.063ms 437.572us 0 B 0 B -4.38 MB -4.38 MB 7
GraphedBackward 4.24% 484.930us 7.65% 876.284us 125.183us 3.053ms 48.09% 3.063ms 437.572us 0 B 0 B 0 B 0 B 7
## forward_pass_graphed ## 0.00% 0.000us 0.00% 0.000us 0.000us 2.316ms 36.47% 2.316ms 330.794us 0 B 0 B 0 B 0 B 7
## forward_pass_graphed ## 4.92% 563.566us 11.65% 1.333ms 190.490us 0.000us 0.00% 2.192ms 313.103us 0 B 0 B 0 B 0 B 7
Graphed 2.76% 315.584us 6.72% 769.863us 109.980us 2.045ms 32.21% 2.192ms 313.103us 0 B 0 B 0 B 0 B 7
void cutlass::Kernel2<cutlass_80_tensorop_s1688gemm_... 0.00% 0.000us 0.00% 0.000us 0.000us 1.759ms 27.70% 1.759ms 125.641us 0 B 0 B 0 B 0 B 14
void cutlass::Kernel2<cutlass_80_tensorop_s1688gemm_... 0.00% 0.000us 0.00% 0.000us 0.000us 1.556ms 24.51% 1.556ms 222.281us 0 B 0 B 0 B 0 B 7
void at::native::(anonymous namespace)::multi_tensor... 0.00% 0.000us 0.00% 0.000us 0.000us 1.017ms 16.02% 1.017ms 127.164us 0 B 0 B 0 B 0 B 8
## optimizer.step ## 2.78% 318.836us 13.08% 1.498ms 214.023us 0.000us 0.00% 883.811us 126.259us 0 B 0 B 0 B 0 B 7
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Self CPU time total: 11.450ms
Self CUDA time total: 6.349ms


======================================================================
Profiling completed successfully!
View traces in Chrome: chrome://tracing
- traces/trace_without_make_graphed_callables.json
- traces/trace_with_make_graphed_callables.json
======================================================================

======================================================================
SCENARIO 3: Training WITH PARTIAL CUDA Graph (only block2)
======================================================================
Preparing CUDA graph for block2 only (warmup + capture)...
Creating partially graphed model (only block2)...
CUDA graph for block2 ready.
CUDA graph for block2 ready.

Training with graph replay...
Completed 10 iterations.

Profiling trace saved to: traces/trace_with_partial_make_graphed_callables.json

Top 10 operations by CUDA time (with partial CUDA graph - block2 only):
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Name Self CPU % Self CPU CPU total % CPU total CPU time avg Self CUDA Self CUDA % CUDA total CUDA time avg CPU Mem Self CPU Mem CUDA Mem Self CUDA Mem # of Calls
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
## forward_pass_graphed ## 0.00% 0.000us 0.00% 0.000us 0.000us 3.626ms 109.36% 3.626ms 517.934us 0 B 0 B 0 B 0 B 7
ProfilerStep* 3.68% 730.731us 78.39% 15.555ms 2.222ms 0.000us 0.00% 2.259ms 322.729us 0 B 0 B 0 B 0 B 7
## forward_pass_graphed ## 8.58% 1.703ms 22.93% 4.549ms 649.839us 0.000us 0.00% 2.106ms 300.809us 0 B 0 B 69.56 MB -88.38 MB 7
void cutlass::Kernel2<cutlass_80_tensorop_s1688gemm_... 0.00% 0.000us 0.00% 0.000us 0.000us 1.793ms 54.08% 1.793ms 128.066us 0 B 0 B 0 B 0 B 14
aten::linear 0.38% 74.894us 5.74% 1.138ms 54.190us 0.000us 0.00% 1.545ms 73.556us 0 B 0 B 53.38 MB 0 B 21
aten::addmm 3.50% 694.782us 4.42% 876.082us 41.718us 1.545ms 46.59% 1.545ms 73.556us 0 B 0 B 53.38 MB 53.38 MB 21
autograd::engine::evaluate_function: GraphedBackward... 0.58% 115.669us 4.10% 813.511us 116.216us 0.000us 0.00% 561.567us 80.224us 0 B 0 B -17.50 MB -17.50 MB 7
GraphedBackward 2.12% 420.204us 3.44% 681.689us 97.384us 546.366us 16.48% 561.567us 80.224us 0 B 0 B 0 B 0 B 7
Graphed 1.65% 327.433us 3.81% 756.580us 108.083us 444.992us 13.42% 481.505us 68.786us 0 B 0 B 0 B 0 B 7
autograd::engine::evaluate_function: AddmmBackward0 1.19% 235.448us 9.29% 1.844ms 87.803us 0.000us 0.00% 453.055us 21.574us 0 B 0 B 16.65 MB -27.12 MB 21
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Self CPU time total: 19.842ms
Self CUDA time total: 3.315ms


======================================================================
All profiling completed successfully!
View traces in Chrome: chrome://tracing
- traces/trace_without_make_graphed_callables.json
- traces/trace_with_make_graphed_callables.json
- traces/trace_with_partial_make_graphed_callables.json
======================================================================

Summary

The following table summarizes the self CPU wall-clock time for training the model with different levels of CUDA graph integration using the different PyTorch APIs. The profiling traces were collected using torch.profiler and can be downloaded and viewed in Perfetto.

CUDA Graph                                                          | API                               | Self CPU Time Total | Profiling Trace
--------------------------------------------------------------------|-----------------------------------|---------------------|----------------
No CUDA Graph                                                       | N/A                               | 24.142 ms           | Trace
Graph: Full Model Forward + Loss + Full Model Backward + Optimizer  | torch.cuda.graph                  | 8.690 ms            | Trace
No CUDA Graph                                                       | N/A                               | 22.641 ms           | Trace
Graph 0: Full Model Forward, Graph 1: Full Model Backward           | torch.cuda.make_graphed_callables | 11.450 ms           | Trace
Graph 0: Submodule Forward, Graph 1: Submodule Backward             | torch.cuda.make_graphed_callables | 19.842 ms           | Trace

We can see that the torch.cuda.graph API provides the best performance, since it captures the entire training iteration (forward, loss, backward, optimizer) into a single CUDA graph, minimizing CPU overhead. However, in practice, model complexity that introduces host-device synchronizations, as well as dynamically shaped tensors, may make it infeasible to capture the entire training step into a single graph. The torch.cuda.make_graphed_callables API offers a more flexible approach by allowing partial graph captures of individual model components, balancing performance gains with ease of use and adaptability to dynamic workloads.
