PyTorch Triton Kernel Transparent Tracing and Compilation
Introduction
PyTorch lets the user create custom operations using torch.library.custom_op or C++/CUDA custom functions and classes. These custom operations are treated as opaque operators during tracing and compilation, which means that their internals are not visible to the PyTorch compiler and cannot be optimized by it. This is also how custom operations are usually handled in other deep learning inference frameworks, such as TensorRT.
PyTorch also lets the user create Triton kernel functions decorated with @triton.jit, Just-In-Time (JIT) compile them, and use them in models for training and inference, not only in eager execution but also in the torch.compile and torch.export compilation workflows. Triton kernel functions can of course be treated as opaque custom operations, with @torch.library.register_fake supplying the fake implementation that FakeTensor-based symbolic tracing needs. The disadvantage is that the compiler cannot optimize the Triton kernel, and Triton JIT compilation is only available in the Python environment.
If the user wants to give the compiler an opportunity to optimize the Triton kernel, or wants to use pre-compiled Triton kernels in a C++/CUDA environment, the Triton kernel implementation must be visible to the compiler. In this blog post, I will discuss how to make Triton kernels visible to tracing and compilation by torch.compile, torch.export, and AOTInductor.
PyTorch Triton Kernel Transparent Tracing and Compilation
In the following example, I created a simple SiLU Triton kernel triton_silu_kernel and wrapped it in two different Python functions. The first function triton_silu_triton_op is registered as a custom operation with @triton_op and uses wrap_triton to wrap the Triton kernel, which means it is treated as a transparent Triton operator during tracing and compilation: unlike an operation registered with torch.library.custom_op, its implementation stays visible to the compiler. The second function triton_silu_pytorch_op is a plain Python function that launches the kernel directly; it is not registered with @triton_op and does not use wrap_triton.
It turns out that both triton_silu_triton_op and triton_silu_pytorch_op can be traced and compiled by torch.compile(fullgraph=True). However, for torch.export, the Triton kernel triton_silu_kernel can only be exported in the following three cases:
- triton_silu_triton_op (registered with @triton_op and wrapped with wrap_triton) and torch.export with strict=False.
- triton_silu_triton_op (registered with @triton_op and wrapped with wrap_triton) and torch.export with strict=True.
- triton_silu_pytorch_op (not registered with @triton_op and no wrap_triton) and torch.export with strict=True.
The third case, triton_silu_pytorch_op (not registered with @triton_op and no wrap_triton) with strict=True, works because strict mode uses TorchDynamo-based tracing, which can see through the Python function and trace into the Triton kernel, even though the kernel is not registered as a transparent custom operation.
Triton Kernel, torch.compile, torch.export, and AOTInductor
import argparse
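The runs in the sections below toggle two command-line flags, --strict_export and --use_registered_triton_op. A minimal sketch of how a script could parse them is shown here; only the flag names come from those commands, while the parse_args and describe helpers and the help strings are hypothetical.

```python
import argparse


def parse_args(argv=None):
    parser = argparse.ArgumentParser(
        description="Export a Triton SiLU model with torch.export."
    )
    # Flag names match the commands in this post; help texts are assumptions.
    parser.add_argument(
        "--strict_export",
        action="store_true",
        help="Call torch.export with strict=True instead of strict=False.",
    )
    parser.add_argument(
        "--use_registered_triton_op",
        action="store_true",
        help="Use triton_silu_triton_op instead of triton_silu_pytorch_op.",
    )
    return parser.parse_args(argv)


def describe(args):
    # Map the parsed flags to the configuration names used in the headings.
    op_name = (
        "triton_silu_triton_op"
        if args.use_registered_triton_op
        else "triton_silu_pytorch_op"
    )
    return f"{op_name} and strict={args.strict_export}"
```

With no flags, describe(parse_args([])) yields the first configuration discussed below, triton_silu_pytorch_op and strict=False.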
triton_silu_pytorch_op and strict=False
In this case, torch.export fails with an error telling us to wrap the Triton kernel as an opaque operator. This is confusing, because it runs counter to our goal of keeping the kernel visible to the compiler.
$ python torch_triton_export_aotinductor.py
triton_silu_pytorch_op and strict=True
Exporting triton_silu_pytorch_op with strict=True works because strict mode uses TorchDynamo-based tracing, which can see through the Python function and trace into the Triton kernel, even though the kernel is not registered as a transparent custom operation.
$ python torch_triton_export_aotinductor.py --strict_export
triton_silu_triton_op and strict=False
Because triton_silu_triton_op is registered with @triton_op and wraps the kernel with wrap_triton, torch.export can export it regardless of whether strict=True or strict=False is used. This run uses strict=False.
$ python torch_triton_export_aotinductor.py --use_registered_triton_op
triton_silu_triton_op and strict=True
With strict=True, triton_silu_triton_op exports as well: an operation registered with @triton_op and wrapped with wrap_triton can be exported by torch.export regardless of the strict setting.
$ python torch_triton_export_aotinductor.py --use_registered_triton_op --strict_export
Conclusions
We can use Triton kernels not only via JIT in PyTorch but also via pre-compilation in AOTInductor.
References
PyTorch Triton Kernel Transparent Tracing and Compilation
https://leimao.github.io/blog/PyTorch-Triton-Kernel-Transparent-Tracing-and-Compilation/