PyTorch Graph Symbolic Integer

Introduction

When converting a PyTorch eager model to a graph, graph breaks occur if model tracing encounters control flow that depends on input data rather than input meta-data, or operations whose output meta-data depend on input data rather than input meta-data. Meta-data includes properties such as tensor shapes and data types. The former graph break problem can sometimes be solved using conditional operations such as torch.cond, provided that certain constraints are satisfied, e.g., the output meta-data of the two branches of the conditional operation must be identical. The latter graph break problem is often more difficult to solve because it usually requires redesigning the neural network model.

Symbolic integers (SymInts) are used to represent integer variables that can span a range of values. A graph can allow dynamic shapes if the graph meta-data can be described using symbolic integers, which are bound to concrete values once the actual input meta-data is known.

In this blog post, I would like to quickly discuss what symbolic integers are and use a simple example to show how to specify symbolic integers in the graph when exporting a PyTorch model using torch.export.

The TopK Example

Suppose my model consists of only one torch.topk layer. The output shape of the torch.topk layer depends on the input tensor meta-data and on the k value. To allow the model to have dynamic shapes that depend on the k value, the k value must be represented as a symbolic integer in the graph.
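Concretely, in eager mode the last dimension of torch.topk's outputs equals k, so the output shape is data-dependent on the k argument, not just on the input tensor's shape:

```python
import torch

x = torch.randn(10, 100)

# The output shape along dim=-1 is k, not a function of x's shape alone.
values, indices = torch.topk(x, k=5, dim=-1)
print(values.shape)   # torch.Size([10, 5])

values, indices = torch.topk(x, k=17, dim=-1)
print(indices.shape)  # torch.Size([10, 17])
```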

To specify the k value as a symbolic integer via torch.export, there are a few ways, some of which do not work.

K As An Integer

If we specify the k value as a Python integer, it will not work because the graph will treat it as a constant instead of a symbolic integer.

import torch

class TopKModelWithInt(torch.nn.Module):
    def forward(self, x, k):
        # k is a Python int
        values, indices = torch.topk(x, k=k, dim=-1)
        return values, indices

model = TopKModelWithInt()
x = torch.randn(10, 100)
k_int = 5  # Python int, not a tensor

# Export with int argument
exported = torch.export.export(model, args=(x, k_int))

K Via Proxy Tensor

If we specify the k value as a proxy tensor, i.e., its value is encoded in the proxy tensor meta-data, it will work because integers in meta-data will be represented as symbolic integers in the graph.

import torch

class TopKModelWithProxyTensor(torch.nn.Module):
    def forward(self, x, k_proxy):
        # k_proxy is a proxy tensor whose meta-data encodes the integer value
        k = k_proxy.size(0)  # Extract the integer value from the proxy tensor meta-data
        values, indices = torch.topk(x, k=k, dim=-1)
        return values, indices

model = TopKModelWithProxyTensor()
x = torch.randn(10, 100)
k_proxy = torch.empty(1)  # Create a proxy tensor
k_proxy.resize_(5)  # Resize so the proxy tensor meta-data encodes the integer value 5

# Create a symbolic dimension for the k value
k_dim = torch.export.Dim("k_value", min=1, max=100)

# Export with proxy tensor argument
exported = torch.export.export(
    model, args=(x, k_proxy), dynamic_shapes={"x": None, "k_proxy": {0: k_dim}}
)

K Via Tensor

The aforementioned approach of using a proxy tensor and dynamic shapes to specify the k value looks somewhat hacky. Consequently, PyTorch allows symbolic integers to be specified directly using tensors, provided that the tensor is an input of the graph.

import torch

class TopKModelWithTensor(torch.nn.Module):
    def forward(self, x, k_tensor):
        # k_tensor is a tensor that directly specifies the symbolic integer value
        k = k_tensor.item()  # Extract the integer value from the tensor
        values, indices = torch.topk(x, k=k, dim=-1)
        return values, indices

model = TopKModelWithTensor()
x = torch.randn(10, 100)
k_tensor = torch.tensor(5)  # Create a tensor that directly specifies the integer value

# Export with tensor argument
exported = torch.export.export(model, args=(x, k_tensor))

Note that item() is used to extract the integer value from the tensor, and it typically causes graph breaks. This is because the output meta-data of operations that consume the extracted value is often data-dependent. The torch.topk operation is one such operation, because its output shape depends on the k_tensor value. PyTorch makes an exception for this case and allows the k value to be specified as a tensor input of the graph, evaluating k_tensor at runtime so that the k value is known before graph execution. The downsides are that if k_tensor is a tensor on GPU, this evaluation incurs a GPU to CPU copy, and that we cannot specify the range of k values via the standard dynamic shapes mechanism used in the previous proxy tensor approach. Other graph frameworks and runtimes, such as ONNX, usually do not allow this kind of behavior and might evaluate k_tensor at export time and treat it as a constant.

If k_tensor is not an immediate input of the graph but rather an intermediate tensor in the graph, then it causes a graph break that PyTorch cannot tolerate.

Conclusions

There are pros and cons for the “K Via Proxy Tensor” approach and the “K Via Tensor” approach.

The “K Via Proxy Tensor” approach looks more grammatically correct to me, and the graph can run purely on device without a GPU to CPU copy. However, if k is large, the user would have to construct a large proxy tensor whose only useful information is its meta-data and copy it to the device before graph execution, which can be inefficient. Of course, if the proxy tensor can be constructed as a meta tensor without any actual data, it becomes very efficient.

import torch

class TopKModelWithProxyTensor(torch.nn.Module):
    def forward(self, x, k_proxy):
        # k_proxy is a proxy tensor whose meta-data encodes the integer value
        k = k_proxy.size(0)  # Extract the integer value from the proxy tensor meta-data
        values, indices = torch.topk(x, k=k, dim=-1)
        return values, indices

model = TopKModelWithProxyTensor()
x = torch.randn(10, 100)
k_proxy = torch.empty(1, device="meta")  # Create a proxy meta tensor
k_proxy.resize_(5)  # Resize so the proxy tensor meta-data encodes the integer value 5

# Create a symbolic dimension for the k value
k_dim = torch.export.Dim("k_value", min=1, max=100)

# Export with proxy tensor argument
exported = torch.export.export(
    model, args=(x, k_proxy), dynamic_shapes={"x": None, "k_proxy": {0: k_dim}}
)

# Lower the exported model using AOTInductor
package_path = torch._inductor.aoti_compile_and_package(
    exported, package_path="model.pt2"
)

# Load the compiled package
lowered = torch._inductor.aoti_load_package(package_path)

# Run inference using the meta proxy tensor
selected_from_exported = exported.module()(x, k_proxy)
selected_from_lowered = lowered(x, k_proxy)

The “K Via Tensor” approach looks weird because it appears that PyTorch allows graph breaks in the graph, which is against the general principle of graph representation. However, it is usually more efficient: there is no need to construct a proxy tensor and copy it to the device before graph execution, and the GPU to CPU copy of a scalar k tensor usually does not introduce much overhead.

Author

Lei Mao

Posted on

04-05-2026

Updated on

04-05-2026