PyTorch Triton Kernel Transparent Tracing and Compilation

Introduction

PyTorch lets users define custom operations with torch.library.custom_op or with C++/CUDA custom functions and classes. These custom operations are treated as opaque operators during tracing and compilation: their internals are not visible to the PyTorch compiler, so the compiler cannot optimize them. Other deep learning inference frameworks, such as TensorRT, usually treat custom operations the same way.

PyTorch also lets users write Triton kernel functions decorated with @triton.jit, Just-In-Time (JIT) compile them, and use them in models for training and inference, in eager execution as well as in the torch.compile and torch.export compilation workflows. A Triton kernel can of course be wrapped in an opaque custom operation, with @torch.library.register_fake supplying the FakeTensor implementation that symbolic tracing needs. The disadvantages are that the compiler cannot optimize the Triton kernel and that Triton JIT compilation is only available in a Python environment.

If the user wants the compiler to have a chance to optimize the Triton kernel, or wants to use pre-compiled Triton kernels in a C++/CUDA environment, the Triton kernel implementation must be visible to the compiler. In this blog post, I will discuss how to make Triton kernels visible to tracing and compilation by torch.compile, torch.export, and AOTInductor.

PyTorch Triton Kernel Transparent Tracing and Compilation

In the following example, I created a simple SiLU Triton kernel, triton_silu_kernel, and wrapped it in two different Python functions. The first function, triton_silu_triton_op, is registered as a custom operation with @triton_op and launches the kernel through wrap_triton, which means its Triton implementation remains visible (transparent) to the compiler during tracing and compilation. The second function, triton_silu_pytorch_op, is a plain Python function: it is not registered with @triton_op and launches the kernel directly without wrap_triton.

It turns out that both triton_silu_triton_op and triton_silu_pytorch_op can be traced and compiled by torch.compile(fullgraph=True). However, with torch.export, the Triton kernel triton_silu_kernel can only be exported in the following three cases:

  • triton_silu_triton_op (registered with @triton_op and wrapped with wrap_triton) and torch.export with strict=False.
  • triton_silu_triton_op (registered with @triton_op and wrapped with wrap_triton) and torch.export with strict=True.
  • triton_silu_pytorch_op (not registered with @triton_op and no wrap_triton) and torch.export with strict=True.

In the case of triton_silu_pytorch_op (not registered with @triton_op and no wrap_triton) and torch.export with strict=True, it works because torch.export with strict=True uses TorchDynamo-based tracing, which can see through the Python function and trace into the Triton kernel, even though the Triton kernel is not registered as a transparent custom operation.
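To reproduce all four combinations of wrapper and export mode, a small driver can build the command lines from the two flags that the script in this post defines (--use_registered_triton_op and --strict_export). This is only a sketch: the helper names build_command and run_all are hypothetical, and it assumes the script is saved as torch_triton_export_aotinductor.py in the current directory.

```python
import itertools
import subprocess
import sys

# Assumed location of the script from this post.
SCRIPT = "torch_triton_export_aotinductor.py"


def build_command(use_registered_triton_op: bool, strict_export: bool) -> list[str]:
    """Build the command line for one wrapper / export-mode combination."""
    cmd = [sys.executable, SCRIPT]
    if use_registered_triton_op:
        cmd.append("--use_registered_triton_op")
    if strict_export:
        cmd.append("--strict_export")
    return cmd


def run_all() -> dict[tuple[bool, bool], int]:
    """Run every combination and collect the process return codes."""
    results = {}
    for registered, strict in itertools.product([False, True], repeat=2):
        completed = subprocess.run(build_command(registered, strict), check=False)
        results[(registered, strict)] = completed.returncode
    return results


# Print the four command lines without running them.
commands = [
    build_command(registered, strict)
    for registered, strict in itertools.product([False, True], repeat=2)
]
for cmd in commands:
    print(" ".join(cmd))
```

Calling run_all() on a machine with a CUDA GPU and Triton installed should show a non-zero return code only for the combination without either flag, matching the three working cases listed above.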

Triton Kernel, torch.compile, torch.export, and AOTInductor

torch_triton_export_aotinductor.py
import argparse
import copy
import os
import random
import shutil

import torch
import torch.profiler
import triton
import triton.language as tl
from torch.export import export, Dim
from torch.library import triton_op, wrap_triton


# 1. Define the Pure Triton Kernel
@triton.jit
def triton_silu_kernel(in_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    # Each program instance processes one BLOCK_SIZE-wide slice of the flattened tensor
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # Mask out the lanes that fall beyond the end of the tensor
    mask = offsets < n_elements

    x = tl.load(in_ptr + offsets, mask=mask)

    # Keep everything in float32 for the math
    x_f32 = x.to(tl.float32)
    sigmoid_x = 1.0 / (1.0 + tl.exp(-x_f32))

    # Multiply in float32, THEN cast back to the input dtype (bfloat16 in this example)
    out = (x_f32 * sigmoid_x).to(x.dtype)

    tl.store(out_ptr + offsets, out, mask=mask)


# 2. Register with triton_op to ensure transparent tracing by torch.export
@triton_op("custom_ops::triton_silu_triton_op", mutates_args={})
def triton_silu_triton_op(x: torch.Tensor) -> torch.Tensor:
    # Enforce contiguous memory to ensure 1D pointer arithmetic is safe
    assert x.is_contiguous(), "Input tensor must be contiguous"
    out = torch.empty_like(x)
    n_elements = x.numel()

    def grid(meta):
        return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]), )

    wrap_triton(triton_silu_kernel)[grid](x, out, n_elements, BLOCK_SIZE=1024)
    return out


# Alternative PyTorch operator that uses the same Triton kernel but is not registered with triton_op.
def triton_silu_pytorch_op(x: torch.Tensor) -> torch.Tensor:
    # Enforce contiguous memory to ensure 1D pointer arithmetic is safe
    assert x.is_contiguous(), "Input tensor must be contiguous"
    out = torch.empty_like(x)
    n_elements = x.numel()

    def grid(meta):
        return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]), )

    triton_silu_kernel[grid](x, out, n_elements, BLOCK_SIZE=1024)
    return out


def pytorch_silu_eager(x: torch.Tensor) -> torch.Tensor:
    return torch.nn.functional.silu(x)


# 3. Model Architecture
class CustomModel(torch.nn.Module):

    def __init__(self,
                 in_features=128,
                 out_features=64,
                 silu_op="triton_silu_triton_op"):
        super().__init__()
        self.linear1 = torch.nn.Linear(in_features, out_features)
        self.linear2 = torch.nn.Linear(out_features, in_features)
        self.silu_op = silu_op

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.linear1(x)
        # Call the transparent operator
        if self.silu_op == "triton_silu_triton_op":
            x = torch.ops.custom_ops.triton_silu_triton_op.default(x)
        elif self.silu_op == "triton_silu_pytorch_op":
            x = triton_silu_pytorch_op(x)
        elif self.silu_op == "pytorch_silu_eager":
            x = pytorch_silu_eager(x)
        else:
            raise ValueError(f"Invalid silu_op: {self.silu_op}")
        x = self.linear2(x)
        return x


def main():
    # Set random seed for reproducibility
    random_seed = 42
    torch.manual_seed(random_seed)
    torch.cuda.manual_seed_all(random_seed)
    random.seed(random_seed)
    # For deterministic behavior (optional, may affect performance)
    torch.use_deterministic_algorithms(True, warn_only=True)

    # Remove TorchInductor and Triton kernel cache if present
    inductor_cache_dir = "/tmp/torchinductor_root/"
    if os.path.exists(inductor_cache_dir):
        print(f"Removing TorchInductor cache directory: {inductor_cache_dir}")
        shutil.rmtree(inductor_cache_dir)

    parser = argparse.ArgumentParser(
        description="Torch Triton Export with AOTInductor")
    parser.add_argument(
        "--use_registered_triton_op",
        action="store_true",
        default=False,
        help="Use registered custom Triton op (default: False)")
    parser.add_argument(
        "--strict_export",
        action="store_true",
        default=False,
        help="Enable strict export (TorchDynamo tracing, default: False)")
    # torch.compile always uses fullgraph=True, so there is no flag for it
    args = parser.parse_args()

    use_registered_triton_op = args.use_registered_triton_op
    strict_export = args.strict_export

    silu_op = "triton_silu_triton_op" if use_registered_triton_op else "triton_silu_pytorch_op"
    print(f"Using silu_op: {silu_op}")

    in_features = 128
    out_features = 64

    device = "cuda"
    dtype = torch.bfloat16

    atol = 1e-5
    rtol = 1e-5

    # 4. Instantiate model and weights in BF16
    model = CustomModel(in_features=in_features,
                        out_features=out_features,
                        silu_op=silu_op).to(device=device, dtype=dtype)
    # Randomly initialize all parameters
    for param in model.parameters():
        if param.requires_grad:
            torch.nn.init.uniform_(param, -1.0, 1.0)
    model.eval()

    # 5. Define dynamic batch dimension constraints (batch size can vary from 1 to 1024)
    batch_dim = Dim("batch", min=1, max=1024)
    dynamic_shapes = {"x": {0: batch_dim}}

    # 6. Prepare sample input
    sample_input = torch.randn(8, in_features, device=device, dtype=dtype)

    # 7. Always TorchDynamo fullgraph compile
    print("--- Cleaning torch.compile cache ---")
    torch._dynamo.reset()  # Clean torch.compile cache

    print("--- Compiling model with torch.compile(fullgraph=True) ---")
    torch_compiled_model = torch.compile(model, fullgraph=True)

    # 8. Trace the model using torch.export (AOTInductor path)
    print(f"--- Exporting model via torch.export (strict={strict_export}) ---")
    exported_program = export(model,
                              args=(sample_input, ),
                              dynamic_shapes=dynamic_shapes,
                              strict=strict_export)
    print("Model successfully traced and exported!")

    print("\nGraph Nodes Extracted:")
    exported_program.graph.print_tabular()

    # 9. Compile and Package via AOTInductor
    print("\n--- Compiling and Packaging via AOTInductor ---")
    output_package = "/tmp/compiled_model.pt2"

    # Clean up previous artifacts if they exist
    if os.path.exists(output_package):
        os.remove(output_package)

    # Instruct Inductor to accept user-defined Triton kernels natively
    torch._inductor.config.static_launch_user_defined_triton_kernels = True

    # Use the unified package compiler. This wraps weights, metadata,
    # and the compiled binary artifact natively inside a single zipped .pt2 container file.
    package_path = torch._inductor.aoti_compile_and_package(
        exported_program,
        package_path=output_package,
    )
    print(
        f"Compilation finished! Self-contained package saved to: {package_path}"
    )

    # 10. Correctness Verification
    # Prepare inference input for correctness and profiling
    inference_input = torch.randn(16, in_features, device=device, dtype=dtype)

    # Run both models to get outputs for correctness check
    with torch.no_grad():
        torch_compiled_output = torch_compiled_model(inference_input)
    compiled_runner = torch._inductor.aoti_load_package(package_path)
    with torch.no_grad():
        aotinductor_compiled_output = compiled_runner(inference_input)

    # Reference output: copy model and set silu_op to pytorch_silu_eager
    reference_model = copy.deepcopy(model)
    reference_model.silu_op = "pytorch_silu_eager"
    reference_model.eval()
    with torch.no_grad():
        eager_output = reference_model(inference_input)

    is_torch_compile_correct = torch.allclose(torch_compiled_output,
                                              eager_output,
                                              atol=atol,
                                              rtol=rtol)
    is_aotinductor_correct = torch.allclose(aotinductor_compiled_output,
                                            eager_output,
                                            atol=atol,
                                            rtol=rtol)
    is_outputs_match = torch.allclose(torch_compiled_output,
                                      aotinductor_compiled_output,
                                      atol=atol,
                                      rtol=rtol)

    print(f"torch_compiled_model output shape: {torch_compiled_output.shape}")
    print(
        f"aotinductor_compiled_model output shape: {aotinductor_compiled_output.shape}"
    )
    print(f"eager output shape: {eager_output.shape}")
    print(
        f"torch.compile correctness vs eager? -> **{is_torch_compile_correct}**"
    )
    print(
        f"aotinductor correctness vs eager? -> **{is_aotinductor_correct}**"
    )
    print(f"torch.compile vs aotinductor match? -> **{is_outputs_match}**")

    # --- Run profiling at the end of the program ---
    # Define file paths for profiling traces
    torch_compile_profiler_path = "./torch_compile_profiler.json"
    aotinductor_profiler_path = "./aotinductor_profiler.json"

    print(
        "\n--- Running torch_compiled_model (fullgraph=True) with profiler ---"
    )
    activities = [
        torch.profiler.ProfilerActivity.CPU,
        torch.profiler.ProfilerActivity.CUDA,
    ]
    warmup = 3
    steps = 5
    schedule = torch.profiler.schedule(wait=0,
                                       warmup=warmup,
                                       active=steps,
                                       repeat=1)
    with torch.profiler.profile(
            activities=activities,
            schedule=schedule,
            record_shapes=True,
            with_flops=True,
    ) as prof:
        for step in range(warmup + steps):
            with torch.profiler.record_function(f"step_{step}"):
                with torch.no_grad():
                    torch_compiled_output = torch_compiled_model(
                        inference_input)
            prof.step()
    prof.export_chrome_trace(torch_compile_profiler_path)
    print(
        f"Profiling trace for torch_compiled_model saved to {torch_compile_profiler_path}"
    )

    print(
        "\n--- Loading AOTInductor Compiled Model Package & Running Inference with profiler ---"
    )
    compiled_runner = torch._inductor.aoti_load_package(package_path)
    with torch.profiler.profile(
            activities=activities,
            schedule=schedule,
            record_shapes=True,
            with_flops=True,
    ) as prof:
        for step in range(warmup + steps):
            with torch.profiler.record_function(f"step_{step}"):
                with torch.no_grad():
                    aotinductor_compiled_output = compiled_runner(
                        inference_input)
            prof.step()
    prof.export_chrome_trace(aotinductor_profiler_path)
    print(
        f"Profiling trace for aotinductor_compiled_output saved to {aotinductor_profiler_path}"
    )


if __name__ == "__main__":
    main()
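To make the indexing inside triton_silu_kernel concrete, the following pure-Python sketch emulates what each program instance does: compute its block of offsets, mask out the lanes past the end of the tensor, and apply SiLU. It is an illustration only. The function names are hypothetical, it runs sequentially on the CPU, and it uses Python's double-precision floats rather than the float32 math and bfloat16 storage of the real kernel.

```python
import math


def cdiv(a: int, b: int) -> int:
    """Ceiling division, the same arithmetic as triton.cdiv."""
    return (a + b - 1) // b


def silu_blockwise(values: list[float], block_size: int) -> list[float]:
    """Emulate the kernel's grid of program instances with plain loops."""
    n_elements = len(values)
    out = [0.0] * n_elements
    # One loop iteration per program instance (pid) in the 1D launch grid.
    for pid in range(cdiv(n_elements, block_size)):
        block_start = pid * block_size
        # Equivalent of offsets = block_start + tl.arange(0, BLOCK_SIZE).
        for lane in range(block_size):
            offset = block_start + lane
            # Equivalent of mask = offsets < n_elements.
            if offset < n_elements:
                x = values[offset]
                out[offset] = x * (1.0 / (1.0 + math.exp(-x)))
    return out


values = [-2.0, -1.0, 0.0, 1.0, 2.0, 3.0, 4.0]
blockwise = silu_blockwise(values, block_size=4)
reference = [x / (1.0 + math.exp(-x)) for x in values]
print(cdiv(len(values), 4), blockwise)
```

With seven elements and a block size of four, two program instances are launched, and the mask disables the last lane of the second one. That is the same role the mask plays in the real kernel whenever the element count is not a multiple of BLOCK_SIZE.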

triton_silu_pytorch_op and strict=False

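In this case, torch.export fails. With strict=False, torch.export traces the model by running the Python code with FakeTensors, so the Triton launcher is called directly and tries to read the data pointer of a FakeTensor, which has none. The error message advises wrapping the Triton kernel in an opaque custom operator, which is confusing because it is the opposite of what we want here.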

$ python torch_triton_export_aotinductor.py
Removing TorchInductor cache directory: /tmp/torchinductor_root/
Using silu_op: triton_silu_pytorch_op
--- Cleaning torch.compile cache ---
--- Compiling model with torch.compile(fullgraph=True) ---
--- Exporting model via torch.export (strict=False) ---
Traceback (most recent call last):
File "/mnt/torch_triton_export_aotinductor.py", line 249, in <module>
main()
File "/mnt/torch_triton_export_aotinductor.py", line 169, in main
exported_program = export(model,
^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/export/__init__.py", line 205, in export
raise e
File "/usr/local/lib/python3.12/dist-packages/torch/export/__init__.py", line 171, in export
return _export(
^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/export/_trace.py", line 1343, in wrapper
raise e
File "/usr/local/lib/python3.12/dist-packages/torch/export/_trace.py", line 1309, in wrapper
ep = fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/export/exported_program.py", line 124, in wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/_utils_internal.py", line 96, in wrapper_function
return function(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/export/_trace.py", line 2508, in _export
ep = _export_for_training(
^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/export/_trace.py", line 1343, in wrapper
raise e
File "/usr/local/lib/python3.12/dist-packages/torch/export/_trace.py", line 1309, in wrapper
ep = fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/export/exported_program.py", line 124, in wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/export/_trace.py", line 2296, in _export_for_training
export_artifact = export_func(
^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/export/_trace.py", line 2225, in _non_strict_export
aten_export_artifact = _to_aten_func(
^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/export/_trace.py", line 2002, in _export_to_aten_ir_make_fx
gm, graph_signature = transform(_make_fx_helper)(
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/export/_trace.py", line 2132, in _aot_export_non_strict
gm, sig = aot_export(stack, wrapped_mod, args, kwargs=kwargs, **flags)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/export/_trace.py", line 1910, in _make_fx_helper
gm = make_fx(
^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/fx/experimental/proxy_tensor.py", line 2826, in wrapped
return make_fx_tracer.trace(f, *args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/fx/experimental/proxy_tensor.py", line 2727, in trace
return self._trace_inner(f, *args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/fx/experimental/proxy_tensor.py", line 2688, in _trace_inner
t = dispatch_trace(
^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/_compile.py", line 54, in inner
return disable_fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/eval_frame.py", line 1255, in _fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/fx/experimental/proxy_tensor.py", line 1533, in dispatch_trace
graph = tracer.trace(root, concrete_args) # type: ignore[arg-type]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/fx/experimental/proxy_tensor.py", line 2264, in trace
res = super().trace(root, concrete_args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/eval_frame.py", line 1255, in _fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/fx/_symbolic_trace.py", line 890, in trace
(self.create_arg(fn(*args)),),
^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/fx/experimental/proxy_tensor.py", line 1603, in wrapped
out = f(*tensors) # type:ignore[call-arg]
^^^^^^^^^^^
File "<string>", line 1, in <lambda>
File "/usr/local/lib/python3.12/dist-packages/torch/export/_trace.py", line 1794, in wrapped_fn
return tuple(flat_fn(*args))
^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/_functorch/_aot_autograd/utils.py", line 204, in flat_fn
tree_out = fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/_functorch/_aot_autograd/graph_capture_wrappers.py", line 1507, in functional_call
out = mod(*args[params_len:], **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/fx/_symbolic_trace.py", line 864, in module_call_wrapper
return self.call_module(mod, forward, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/fx/experimental/proxy_tensor.py", line 2353, in call_module
return Tracer.call_module(self, m, forward, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/fx/_symbolic_trace.py", line 572, in call_module
ret_val = forward(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/fx/_symbolic_trace.py", line 857, in forward
return _orig_module_call(mod, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1778, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1789, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/export/_trace.py", line 2116, in forward
tree_out = mod(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/fx/_symbolic_trace.py", line 864, in module_call_wrapper
return self.call_module(mod, forward, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/fx/experimental/proxy_tensor.py", line 2353, in call_module
return Tracer.call_module(self, m, forward, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/fx/_symbolic_trace.py", line 572, in call_module
ret_val = forward(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/fx/_symbolic_trace.py", line 857, in forward
return _orig_module_call(mod, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1778, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1789, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/mnt/torch_triton_export_aotinductor.py", line 87, in forward
x = triton_silu_pytorch_op(x)
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/mnt/torch_triton_export_aotinductor.py", line 61, in triton_silu_pytorch_op
triton_silu_kernel[grid](x, out, n_elements, BLOCK_SIZE=1024)
File "/usr/local/lib/python3.12/dist-packages/triton/runtime/jit.py", line 370, in <lambda>
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/triton/runtime/jit.py", line 723, in run
bound_args, specialization, options = binder(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^
File "<string>", line 4, in dynamic_func
File "/usr/local/lib/python3.12/dist-packages/torch/fx/experimental/proxy_tensor.py", line 1654, in __torch_function__
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/fx/experimental/proxy_tensor.py", line 1725, in __torch_function__
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/_export/non_strict_utils.py", line 1159, in __torch_function__
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Cannot access data pointer of Tensor (e.g. FakeTensor, FunctionalTensor). If you're using torch.compile/export/fx, it is likely that we are erroneously tracing into a custom kernel. To fix this, please wrap the custom kernel into an opaque custom op. Please see the following for details: https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html
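The root cause can be reproduced without Triton or a GPU. The sketch below creates a FakeTensor on the CPU and tries to read its data pointer, which is what the Triton launcher does with its tensor arguments. It relies on torch._subclasses.fake_tensor, an internal PyTorch module, so treat it as an illustration. On recent PyTorch versions the call is expected to raise a RuntimeError like the one above, while older versions may only emit a deprecation warning.

```python
import torch
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode

# Factory calls inside FakeTensorMode return FakeTensors: they carry
# shape, dtype, and device metadata but own no real storage.
with FakeTensorMode():
    fake = torch.empty(8, 64)

print(type(fake).__name__, tuple(fake.shape), fake.dtype)

try:
    fake.data_ptr()
    data_ptr_failed = False
except RuntimeError as error:
    data_ptr_failed = True
    print(f"data_ptr() failed: {str(error)[:60]}...")

print(f"data_ptr() raised an error: {data_ptr_failed}")
```

Tracing with wrap_triton or with TorchDynamo avoids this because the kernel launch is recorded as a graph node instead of being executed on FakeTensors.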

triton_silu_pytorch_op and strict=True

Using triton_silu_pytorch_op with strict=True works because torch.export then uses TorchDynamo-based tracing, which can see through the Python function and trace into the Triton kernel launch, even though the Triton kernel is not registered as a transparent custom operation. In the exported graph below, the launch appears as a triton_kernel_wrapper_mutation node, preceded by the nodes that compute the launch grid.

$ python torch_triton_export_aotinductor.py --strict_export
Removing TorchInductor cache directory: /tmp/torchinductor_root/
Using silu_op: triton_silu_pytorch_op
--- Cleaning torch.compile cache ---
--- Compiling model with torch.compile(fullgraph=True) ---
--- Exporting model via torch.export (strict=True) ---
/usr/local/lib/python3.12/dist-packages/torch/utils/_config_module.py:540: FutureWarning: torch._dynamo.config.skip_code_recursive_on_recompile_limit_hit is deprecated and does not do anything. It will be removed in a future version of PyTorch.
config[key] = copy.deepcopy(getattr(self, key))
Model successfully traced and exported!

Graph Nodes Extracted:
opcode         name                                  target                          args                                            kwargs
-------------  ------------------------------------  ------------------------------  ----------------------------------------------  ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------
placeholder    p_linear1_weight                      p_linear1_weight                ()                                              {}
placeholder    p_linear1_bias                        p_linear1_bias                  ()                                              {}
placeholder    p_linear2_weight                      p_linear2_weight                ()                                              {}
placeholder    p_linear2_bias                        p_linear2_bias                  ()                                              {}
placeholder    x                                     x                               ()                                              {}
call_function  sym_size_int                          aten.sym_size.int               (x, 0)                                          {}
call_function  linear                                aten.linear.default             (x, p_linear1_weight, p_linear1_bias)           {}
call_function  empty_like                            aten.empty_like.default         (linear,)                                       {'pin_memory': False}
call_function  mul                                   <built-in function mul>         (64, sym_size_int)                              {}
call_function  add                                   <built-in function add>         (mul, 1024)                                     {}
call_function  sub                                   <built-in function sub>         (add, 1)                                        {}
call_function  floordiv                              <built-in function floordiv>    (sub, 1024)                                     {}
call_function  triton_kernel_wrapper_mutation_proxy  triton_kernel_wrapper_mutation  ()                                              {'kernel_idx': 0, 'constant_args_idx': 0, 'grid': [(floordiv, 1, 1)], 'tma_descriptor_metadata': {}, 'kwargs': {'in_ptr': linear, 'out_ptr': empty_like, 'n_elements': mul}}
call_function  linear_1                              aten.linear.default             (empty_like, p_linear2_weight, p_linear2_bias)  {}
output         output                                output                          ((linear_1,),)                                  {}

--- Compiling and Packaging via AOTInductor ---
/usr/lib/python3.12/copyreg.py:99: FutureWarning: `isinstance(treespec, LeafSpec)` is deprecated, use `isinstance(treespec, TreeSpec) and treespec.is_leaf()` instead.
return cls.__new__(cls, *args)
Compilation finished! Self-contained package saved to: /tmp/compiled_model.pt2

--- Running torch_compiled_model (fullgraph=True) ---

--- Loading AOTInductor Compiled Model Package & Running Inference ---
torch_compiled_model output shape: torch.Size([16, 128])
aotinductor_compiled_model output shape: torch.Size([16, 128])
eager output shape: torch.Size([16, 128])
torch.compile correctness vs eager? -> **True**
aotinductor correctness vs eager? -> **True**
torch.compile vs aotinductor match? -> **True**
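The mul, add, sub, and floordiv nodes in the exported graph above are the traced form of the launch grid computation triton.cdiv(n_elements, BLOCK_SIZE), with n_elements = 64 * batch and BLOCK_SIZE = 1024. The sketch below replays those four nodes in plain Python for a few batch sizes. The function name grid_from_graph is hypothetical, and the constants are taken from the script (out_features=64, BLOCK_SIZE=1024).

```python
import math


def grid_from_graph(batch: int, out_features: int = 64, block_size: int = 1024) -> tuple[int, int, int]:
    """Replay the mul / add / sub / floordiv nodes of the exported graph."""
    mul = out_features * batch      # n_elements = 64 * sym_size_int
    add = mul + block_size          # mul + 1024
    sub = add - 1                   # add - 1
    floordiv = sub // block_size    # sub // 1024
    return (floordiv, 1, 1)


for batch in (1, 8, 16, 17, 1024):
    n_elements = 64 * batch
    grid = grid_from_graph(batch)
    print(f"batch={batch:5d} n_elements={n_elements:6d} grid={grid}")
```

Because the grid is expressed in terms of the symbolic batch size rather than a constant, the exported program can launch the right number of program instances for any batch size within the declared dynamic range.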

triton_silu_triton_op and strict=False

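Because triton_silu_triton_op is registered with @triton_op and launches the kernel through wrap_triton, torch.export can export it with either strict=True or strict=False. With strict=False, the operator appears in the exported graph as a single custom_ops.triton_silu_triton_op.default node.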

$ python torch_triton_export_aotinductor.py --use_registered_triton_op
Removing TorchInductor cache directory: /tmp/torchinductor_root/
Using silu_op: triton_silu_triton_op
--- Cleaning torch.compile cache ---
--- Compiling model with torch.compile(fullgraph=True) ---
--- Exporting model via torch.export (strict=False) ---
Model successfully traced and exported!

Graph Nodes Extracted:
opcode         name                   target                                    args                                                       kwargs
-------------  ---------------------  ----------------------------------------  ---------------------------------------------------------  --------
placeholder    p_linear1_weight       p_linear1_weight                          ()                                                         {}
placeholder    p_linear1_bias         p_linear1_bias                            ()                                                         {}
placeholder    p_linear2_weight       p_linear2_weight                          ()                                                         {}
placeholder    p_linear2_bias         p_linear2_bias                            ()                                                         {}
placeholder    x                      x                                         ()                                                         {}
call_function  linear                 aten.linear.default                       (x, p_linear1_weight, p_linear1_bias)                      {}
call_function  triton_silu_triton_op  custom_ops.triton_silu_triton_op.default  (linear,)                                                  {}
call_function  linear_1               aten.linear.default                       (triton_silu_triton_op, p_linear2_weight, p_linear2_bias)  {}
output         output                 output                                    ((linear_1,),)                                             {}

--- Compiling and Packaging via AOTInductor ---
/usr/lib/python3.12/copyreg.py:99: FutureWarning: `isinstance(treespec, LeafSpec)` is deprecated, use `isinstance(treespec, TreeSpec) and treespec.is_leaf()` instead.
return cls.__new__(cls, *args)
/usr/local/lib/python3.12/dist-packages/torch/utils/_config_module.py:540: FutureWarning: torch._dynamo.config.skip_code_recursive_on_recompile_limit_hit is deprecated and does not do anything. It will be removed in a future version of PyTorch.
config[key] = copy.deepcopy(getattr(self, key))
Compilation finished! Self-contained package saved to: /tmp/compiled_model.pt2

--- Running torch_compiled_model (fullgraph=True) ---

--- Loading AOTInductor Compiled Model Package & Running Inference ---
torch_compiled_model output shape: torch.Size([16, 128])
aotinductor_compiled_model output shape: torch.Size([16, 128])
eager output shape: torch.Size([16, 128])
torch.compile correctness vs eager? -> **True**
aotinductor correctness vs eager? -> **True**
torch.compile vs aotinductor match? -> **True**

triton_silu_triton_op and strict=True

As in the strict=False case, triton_silu_triton_op can be exported because it is registered with @triton_op and launches the kernel through wrap_triton. With strict=True, the exported graph is the same as with strict=False.

$ python torch_triton_export_aotinductor.py --use_registered_triton_op --strict_export
Removing TorchInductor cache directory: /tmp/torchinductor_root/
Using silu_op: triton_silu_triton_op
--- Cleaning torch.compile cache ---
--- Compiling model with torch.compile(fullgraph=True) ---
--- Exporting model via torch.export (strict=True) ---
/usr/local/lib/python3.12/dist-packages/torch/utils/_config_module.py:540: FutureWarning: torch._dynamo.config.skip_code_recursive_on_recompile_limit_hit is deprecated and does not do anything. It will be removed in a future version of PyTorch.
config[key] = copy.deepcopy(getattr(self, key))
Model successfully traced and exported!

Graph Nodes Extracted:
opcode         name                   target                                    args                                                       kwargs
-------------  ---------------------  ----------------------------------------  ---------------------------------------------------------  --------
placeholder    p_linear1_weight       p_linear1_weight                          ()                                                         {}
placeholder    p_linear1_bias         p_linear1_bias                            ()                                                         {}
placeholder    p_linear2_weight       p_linear2_weight                          ()                                                         {}
placeholder    p_linear2_bias         p_linear2_bias                            ()                                                         {}
placeholder    x                      x                                         ()                                                         {}
call_function  linear                 aten.linear.default                       (x, p_linear1_weight, p_linear1_bias)                      {}
call_function  triton_silu_triton_op  custom_ops.triton_silu_triton_op.default  (linear,)                                                  {}
call_function  linear_1               aten.linear.default                       (triton_silu_triton_op, p_linear2_weight, p_linear2_bias)  {}
output         output                 output                                    ((linear_1,),)                                             {}

--- Compiling and Packaging via AOTInductor ---
/usr/lib/python3.12/copyreg.py:99: FutureWarning: `isinstance(treespec, LeafSpec)` is deprecated, use `isinstance(treespec, TreeSpec) and treespec.is_leaf()` instead.
return cls.__new__(cls, *args)
Compilation finished! Self-contained package saved to: /tmp/compiled_model.pt2

--- Running torch_compiled_model (fullgraph=True) ---

--- Loading AOTInductor Compiled Model Package & Running Inference ---
torch_compiled_model output shape: torch.Size([16, 128])
aotinductor_compiled_model output shape: torch.Size([16, 128])
eager output shape: torch.Size([16, 128])
torch.compile correctness vs eager? -> **True**
aotinductor correctness vs eager? -> **True**
torch.compile vs aotinductor match? -> **True**

Conclusions

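Triton kernels can be used in PyTorch not only via JIT compilation but also via ahead-of-time compilation with AOTInductor, provided that the kernel is visible to the tracer. Registering the wrapper with @triton_op and launching the kernel through wrap_triton works with both torch.export modes, whereas a plain Python wrapper around the kernel can only be exported with strict=True.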

Author

Lei Mao

Posted on

05-22-2026

Updated on

05-22-2026
