Inspecting and Visualizing Torch FX Graph

Introduction

PyTorch modules can consist of deeply nested submodules. Unlike ONNX models, whose graphs can be browsed directly, a PyTorch module can be challenging to understand because we have to read its source code to figure out its structure and operations. The Torch FX graph is an intermediate representation (IR) of a PyTorch module. It can be inspected and visualized to help us understand the structure and operations of the module.

In this blog post, I would like to quickly demonstrate how to inspect and visualize Torch FX graphs using the Torch FX FxGraphDrawer on a simple multi-layer perceptron (MLP) module. I will also show how to use TorchFunctionMode and TorchDispatchMode to log the function calls and operator dispatches during the execution of the module.

Torch FX Graph Inspection and Visualization

The following program creates a simple MLP module with two linear layers and a ReLU activation function. It then performs Torch FX symbolic tracing, Torch Export to ATen IR, and Torch Export to Core ATen IR on the MLP module. The resulting graphs are printed and visualized using FxGraphDrawer. Finally, it uses TorchFunctionMode and TorchDispatchMode to log the function calls and dispatches during the execution of the module.

torch_fx_graph_mlp.py
import torch
import torch.fx
from torch.fx.passes.graph_drawer import FxGraphDrawer

from torch.overrides import TorchFunctionMode, resolve_name
from torch.utils._python_dispatch import TorchDispatchMode


# https://docs.pytorch.org/docs/2.9/notes/extending.html#extending-all-torch-api-with-modes
class FunctionLog(TorchFunctionMode):

    def __torch_function__(self, func, types, args, kwargs=None):
        # print(f"Function Log: {resolve_name(func)}(*{args}, **{kwargs})")
        print(f"Function Log: {resolve_name(func)}")
        return func(*args, **(kwargs or {}))


class DispatchLog(TorchDispatchMode):

    def __torch_dispatch__(self, func, types, args, kwargs=None):
        # print(f"Dispatch Log: {func}(*{args}, **{kwargs})")
        print(f"Dispatch Log: {func}")
        return func(*args, **(kwargs or {}))


class MLP(torch.nn.Module):

    def __init__(self, input_dim=16, hidden_dim=32, output_dim=16):
        super().__init__()
        self.fc1 = torch.nn.Linear(input_dim, hidden_dim)
        self.fc2 = torch.nn.Linear(hidden_dim, output_dim)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        x = self.relu(self.fc1(x))
        x = self.fc2(x)
        return x


if __name__ == "__main__":

    input_dim = 16
    hidden_dim = 32
    output_dim = 16

    module = MLP(input_dim=input_dim,
                 hidden_dim=hidden_dim,
                 output_dim=output_dim)
    module.eval()

    # Torch FX Symbolic Tracing
    torch_symbolic_traced = torch.fx.symbolic_trace(module)
    torch_symbolic_traced_graph_drawer = FxGraphDrawer(
        torch_symbolic_traced, "mlp_symbolic_traced_graph_drawer")
    print("Torch FX Symbolic Traced Graph:")
    print(torch_symbolic_traced.graph)

    # Get the graph object and write to a file
    with open("mlp_symbolic_traced_graph.svg", "wb") as f:
        f.write(
            torch_symbolic_traced_graph_drawer.get_dot_graph().create_svg())

    # Torch Export to ATen IR
    args = (torch.randn(1, input_dim), )
    exported_program = torch.export.export(module, args)
    exported_aten_graph_module = exported_program.graph_module
    print("MLP Exported ATen Graph:")
    print(exported_aten_graph_module)

    exported_aten_graph_drawer = FxGraphDrawer(
        exported_aten_graph_module, "mlp_exported_aten_graph_drawer")
    with open("mlp_exported_aten_graph.svg", "wb") as f:
        f.write(exported_aten_graph_drawer.get_dot_graph().create_svg())

    # Torch Export to Core ATen IR
    core_aten_exported_program = exported_program.run_decompositions()
    core_aten_exported_aten_graph_module = core_aten_exported_program.graph_module
    print("MLP Core ATen Exported Graph:")
    print(core_aten_exported_aten_graph_module)

    core_aten_exported_aten_graph_drawer = FxGraphDrawer(
        core_aten_exported_aten_graph_module,
        "mlp_core_aten_exported_aten_graph_drawer")
    with open("mlp_core_aten_exported_aten_graph.svg", "wb") as f:
        f.write(
            core_aten_exported_aten_graph_drawer.get_dot_graph().create_svg())

    print("TorchFunctionMode logging:")
    with torch.inference_mode(), FunctionLog():
        # result = module(*args)
        # result = torch_symbolic_traced(*args)
        # result = exported_program.module()(*args)
        result = core_aten_exported_program.module()(*args)

    print("TorchDispatchMode logging:")
    with torch.inference_mode(), DispatchLog():
        # result = module(*args)
        # result = torch_symbolic_traced(*args)
        # result = exported_program.module()(*args)
        result = core_aten_exported_program.module()(*args)

We can use an NVIDIA PyTorch Docker container to run the PyTorch program.

$ docker run -it --rm --gpus all -v $(pwd):/mnt -w /mnt nvcr.io/nvidia/pytorch:25.11-py3
# graphviz and pydot are required for FxGraphDrawer
$ apt update && apt install -y graphviz
$ pip install pydot
$ python torch_fx_graph_mlp.py
Torch FX Symbolic Traced Graph:
graph():
    %x : [num_users=1] = placeholder[target=x]
    %fc1 : [num_users=1] = call_module[target=fc1](args = (%x,), kwargs = {})
    %relu : [num_users=1] = call_module[target=relu](args = (%fc1,), kwargs = {})
    %fc2 : [num_users=1] = call_module[target=fc2](args = (%relu,), kwargs = {})
    return fc2
MLP Exported ATen Graph:
GraphModule()



def forward(self, p_fc1_weight, p_fc1_bias, p_fc2_weight, p_fc2_bias, x):
    linear = torch.ops.aten.linear.default(x, p_fc1_weight, p_fc1_bias); x = p_fc1_weight = p_fc1_bias = None
    relu = torch.ops.aten.relu.default(linear); linear = None
    linear_1 = torch.ops.aten.linear.default(relu, p_fc2_weight, p_fc2_bias); relu = p_fc2_weight = p_fc2_bias = None
    return (linear_1,)

# To see more debug info, please use `graph_module.print_readable()`
MLP Core ATen Exported Graph:
GraphModule()



def forward(self, p_fc1_weight, p_fc1_bias, p_fc2_weight, p_fc2_bias, x):
    permute = torch.ops.aten.permute.default(p_fc1_weight, [1, 0]); p_fc1_weight = None
    addmm = torch.ops.aten.addmm.default(p_fc1_bias, x, permute); p_fc1_bias = x = permute = None
    relu = torch.ops.aten.relu.default(addmm); addmm = None
    permute_1 = torch.ops.aten.permute.default(p_fc2_weight, [1, 0]); p_fc2_weight = None
    addmm_1 = torch.ops.aten.addmm.default(p_fc2_bias, relu, permute_1); p_fc2_bias = relu = permute_1 = None
    return (addmm_1,)

# To see more debug info, please use `graph_module.print_readable()`
TorchFunctionMode logging:
Function Log: torch.Tensor.grad_fn.__get__
Function Log: torch.Tensor.grad_fn.__get__
Function Log: torch.Tensor.grad_fn.__get__
Function Log: torch.Tensor.grad_fn.__get__
Function Log: torch.Tensor.grad_fn.__get__
Function Log: torch.Tensor.grad_fn.__get__
Function Log: torch.Tensor.grad_fn.__get__
Function Log: torch.Tensor.grad_fn.__get__
Function Log: torch.Tensor.shape.__get__
Function Log: torch.Tensor.size
Function Log: None
Function Log: torch.Tensor.size
Function Log: None
Function Log: aten.permute.default
Function Log: aten.addmm.default
Function Log: aten.relu.default
Function Log: aten.permute.default
Function Log: aten.addmm.default
TorchDispatchMode logging:
Dispatch Log: aten.permute.default
Dispatch Log: aten.addmm.default
Dispatch Log: aten.relu.default
Dispatch Log: aten.permute.default
Dispatch Log: aten.addmm.default

The symbolic traced graph uses high-level torch.nn.Module calls to describe the module structure and operations. Because no input tensors are specified during symbolic tracing, the shapes of the tensors, especially those of the intermediate tensors, are unknown.

Symbolic Traced Graph
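
Besides printing the whole graph, the traced GraphModule can also be inspected node by node. The following is a minimal sketch that reuses the torch_symbolic_traced object from the program above; graph.print_tabular() additionally requires the tabulate package to be installed.

# Minimal sketch reusing `torch_symbolic_traced` from the program above.
# Each FX node records an opcode (placeholder, call_module, call_function,
# call_method, output), a target, and its arguments.
for node in torch_symbolic_traced.graph.nodes:
    print(node.op, node.target, node.args)

# Print the same information as a table (requires the `tabulate` package).
torch_symbolic_traced.graph.print_tabular()

# Print the Python code generated from the traced graph.
print(torch_symbolic_traced.code)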

torch.export is a PyTorch 2.x API for ahead-of-time (AOT) whole-graph capture of a model into an ExportedProgram object. It requires example inputs to run the model and captures the operations used during the execution. The exported ATen graph uses ATen operators to describe the module structure and operations, and the shapes of the tensors are known because example inputs are provided during the export. In this case, the aten.linear operator is used to represent the nn.Linear modules, and the aten.relu operator is used to represent the nn.ReLU module.

Exported ATen Graph
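
Because example inputs are provided during the export, each node in the exported graph carries shape and dtype metadata. As a minimal sketch, assuming the exported_program object from the program above, the FakeTensor stored in node.meta["val"] can be read to print the shape of every intermediate tensor.

# Minimal sketch reusing `exported_program` from the program above.
# torch.export records a FakeTensor in `node.meta["val"]` for every node that
# produces a tensor, which carries the shape and dtype of that tensor.
for node in exported_program.graph.nodes:
    val = node.meta.get("val", None)
    if isinstance(val, torch.Tensor):
        print(node.name, node.target, tuple(val.shape), val.dtype)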

Core ATen ops are the core subset of ATen operators that can be used to compose other operators. The Core ATen IR is fully functional, i.e., there are no side effects, and there are no in-place or _out variants in this opset. The previously exported ATen graph can be further lowered to the Core ATen IR using run_decompositions, which decomposes higher-level ATen operators into Core ATen operators. In this case, the aten.linear operator is decomposed into the aten.permute and aten.addmm operators, while the aten.relu operator remains the same.

Core ATen Exported ATen Graph
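
To check exactly which operators remain after the decomposition, we could collect the call_function nodes of the decomposed graph. The following is a minimal sketch that reuses the core_aten_exported_program object from the program above.

from collections import Counter

# Minimal sketch reusing `core_aten_exported_program` from the program above.
# Count the ATen operator overloads left after `run_decompositions`; for the
# MLP these are aten.permute.default, aten.addmm.default, and aten.relu.default.
op_counts = Counter(
    str(node.target) for node in core_aten_exported_program.graph.nodes
    if node.op == "call_function")
for op_name, count in op_counts.items():
    print(f"{op_name}: {count}")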

torch.export also supports exporting modules with data-dependent control flow expressed using control flow operators such as torch.cond. The branches of the control flow, however, will be saved as subgraphs in the exported graph module. The following program demonstrates a simple MLP module with dynamic control flow using torch.cond.

class DynamicControlFlowMLP(torch.nn.Module):

    def __init__(self, input_dim=16, hidden_dim=32, output_dim=16):
        super().__init__()
        self.fc1 = torch.nn.Linear(input_dim, hidden_dim)
        self.fc2 = torch.nn.Linear(hidden_dim, output_dim)
        self.fc3 = torch.nn.Linear(hidden_dim, output_dim)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        x = self.relu(self.fc1(x))

        # Dynamic control flow based on input using torch.cond
        # torch.cond captures both branches in torch.export
        def true_fn(x):
            return self.fc2(x)

        def false_fn(x):
            return self.fc3(x)

        condition = x.sum() > 0
        x = torch.cond(condition, true_fn, false_fn, (x, ))
        return x

After exporting the above module using torch.export, we can see that the true and false branches of torch.cond are saved as the subgraphs true_graph_0 and false_graph_0 in the exported GraphModule, whereas the previously exported GraphModule did not have any subgraphs.

GraphModule(
  (true_graph_0): <lambda>()
  (false_graph_0): <lambda>()
)



def forward(self, p_fc1_weight, p_fc1_bias, p_fc2_weight, p_fc2_bias, p_fc3_weight, p_fc3_bias, x):
    linear = torch.ops.aten.linear.default(x, p_fc1_weight, p_fc1_bias); x = p_fc1_weight = p_fc1_bias = None
    relu = torch.ops.aten.relu.default(linear); linear = None
    sum_1 = torch.ops.aten.sum.default(relu)
    gt = torch.ops.aten.gt.Scalar(sum_1, 0); sum_1 = None
    true_graph_0 = self.true_graph_0
    false_graph_0 = self.false_graph_0
    cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, (relu, p_fc2_bias, p_fc2_weight, p_fc3_bias, p_fc3_weight)); gt = true_graph_0 = false_graph_0 = relu = p_fc2_bias = p_fc2_weight = p_fc3_bias = p_fc3_weight = None
    getitem = cond[0]; cond = None
    return (getitem,)

Visualizing the exported graph module using FxGraphDrawer will not include the subgraphs. The subgraphs have to be visualized separately by calling FxGraphDrawer on each of them, as sketched below.
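
The following minimal sketch assumes the DynamicControlFlowMLP module above has been exported into an ExportedProgram named dynamic_exported_program; that variable name is illustrative and not part of the programs above.

# Minimal sketch. `dynamic_exported_program` is assumed to be the result of
# exporting DynamicControlFlowMLP with torch.export.export; the variable name
# is illustrative.
root_graph_module = dynamic_exported_program.graph_module

# Draw the top-level graph, which references but does not expand the branches.
with open("mlp_dynamic_control_flow_graph.svg", "wb") as f:
    f.write(
        FxGraphDrawer(
            root_graph_module,
            "mlp_dynamic_control_flow_graph").get_dot_graph().create_svg())

# Each branch subgraph (e.g. true_graph_0 and false_graph_0) is itself a
# GraphModule and has to be drawn separately.
for name, submodule in root_graph_module.named_children():
    if isinstance(submodule, torch.fx.GraphModule):
        with open(f"mlp_dynamic_control_flow_{name}.svg", "wb") as f:
            f.write(FxGraphDrawer(submodule, name).get_dot_graph().create_svg())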

It would be very desirable to have a visualization tool for PyTorch exported models, just like what Netron provides for ONNX models, especially for complex PyTorch models with dynamic control flow.

Author

Lei Mao

Posted on

12-31-2025

Updated on

12-31-2025
