Creating and Modifying ONNX Model Using ONNX Python API

Introduction

Open Neural Network Exchange (ONNX) is an open standard format for representing machine learning models. It is one of the most widely used model exchange formats, supported by a community of partners who have implemented it in many frameworks and tools.

In this blog post, I would like to discuss how to use the ONNX Python API to create and modify ONNX models.

ONNX Data Structure

An ONNX model is represented using protocol buffers. Specifically, the entire model is encoded according to the schema defined in onnx.proto.

The major ONNX protocol buffers used to describe a neural network are ModelProto, GraphProto, NodeProto, TensorProto, and ValueInfoProto.

Key ONNX Protos    Description
ModelProto         It contains the model metadata and the GraphProto.
GraphProto         It contains the nodes, the node initializers, and the IO tensors of the model.
NodeProto          It represents a node in the graph. It contains the input and output tensor names, the initializer tensor names, and the node attributes.
TensorProto        It represents a node initializer (a constant tensor used by a node). In addition to the data type and shape, the concrete values are stored.
ValueInfoProto     It represents an IO tensor of the model, for which only the data type and shape are defined.
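
To see how these protos nest inside each other, we can load any saved ONNX model and walk down the hierarchy. A minimal sketch, assuming an ONNX model file is available (for example the convnet.onnx created in the next section):

import onnx

# ModelProto is the top-level container.
model = onnx.load("convnet.onnx")

# GraphProto holds the nodes, initializers, and IO tensors.
graph = model.graph

# NodeProto, TensorProto, and ValueInfoProto live in repeated fields of the graph.
print(type(graph.node[0]).__name__)         # NodeProto
print(type(graph.initializer[0]).__name__)  # TensorProto
print(type(graph.input[0]).__name__)        # ValueInfoProto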

Creating ONNX Model

To better understand the ONNX protocol buffers, let’s create a dummy convolutional classification neural network, consisting of convolution, batch normalization, ReLU, and global average pooling layers, from scratch using the ONNX Python API (the ONNX helper functions in onnx.helper).

import numpy as np
import onnx


def create_initializer_tensor(
        name: str,
        tensor_array: np.ndarray,
        data_type: onnx.TensorProto = onnx.TensorProto.FLOAT
) -> onnx.TensorProto:

    # (TensorProto)
    initializer_tensor = onnx.helper.make_tensor(
        name=name,
        data_type=data_type,
        dims=tensor_array.shape,
        vals=tensor_array.flatten().tolist())

    return initializer_tensor


def main() -> None:

    # Create a dummy convolutional neural network.

    # IO tensors (ValueInfoProto).
    model_input_name = "X"
    X = onnx.helper.make_tensor_value_info(model_input_name,
                                           onnx.TensorProto.FLOAT,
                                           [None, 3, 32, 32])
    model_output_name = "Y"
    model_output_channels = 10
    Y = onnx.helper.make_tensor_value_info(model_output_name,
                                           onnx.TensorProto.FLOAT,
                                           [None, model_output_channels, 1, 1])

    # Create a Conv node (NodeProto).
    # https://github.com/onnx/onnx/blob/rel-1.9.0/docs/Operators.md#conv
    conv1_output_node_name = "Conv1_Y"
    # Dummy weights for conv.
    conv1_in_channels = 3
    conv1_out_channels = 32
    conv1_kernel_shape = (3, 3)
    conv1_pads = (1, 1, 1, 1)
    conv1_W = np.ones(shape=(conv1_out_channels, conv1_in_channels,
                             *conv1_kernel_shape)).astype(np.float32)
    conv1_B = np.ones(shape=(conv1_out_channels)).astype(np.float32)
    # Create the initializer tensor for the weights.
    conv1_W_initializer_tensor_name = "Conv1_W"
    conv1_W_initializer_tensor = create_initializer_tensor(
        name=conv1_W_initializer_tensor_name,
        tensor_array=conv1_W,
        data_type=onnx.TensorProto.FLOAT)
    conv1_B_initializer_tensor_name = "Conv1_B"
    conv1_B_initializer_tensor = create_initializer_tensor(
        name=conv1_B_initializer_tensor_name,
        tensor_array=conv1_B,
        data_type=onnx.TensorProto.FLOAT)

    conv1_node = onnx.helper.make_node(
        name="Conv1",  # Name is optional.
        op_type="Conv",
        # Must follow the order of input and output definitions.
        # https://github.com/onnx/onnx/blob/rel-1.9.0/docs/Operators.md#inputs-2---3
        inputs=[
            model_input_name, conv1_W_initializer_tensor_name,
            conv1_B_initializer_tensor_name
        ],
        outputs=[conv1_output_node_name],
        # The following arguments are attributes.
        kernel_shape=conv1_kernel_shape,
        # Default values for other attributes: strides=[1, 1], dilations=[1, 1], groups=1
        pads=conv1_pads,
    )

    # Create a BatchNorm node (NodeProto).
    bn1_output_node_name = "BN1_Y"
    # Dummy parameters for batchnorm.
    bn1_scale = np.random.randn(conv1_out_channels).astype(np.float32)
    bn1_bias = np.random.randn(conv1_out_channels).astype(np.float32)
    bn1_mean = np.random.randn(conv1_out_channels).astype(np.float32)
    bn1_var = np.random.rand(conv1_out_channels).astype(np.float32)
    # Create the initializer tensors.
    bn1_scale_initializer_tensor_name = "BN1_Scale"
    bn1_bias_initializer_tensor_name = "BN1_Bias"
    bn1_mean_initializer_tensor_name = "BN1_Mean"
    bn1_var_initializer_tensor_name = "BN1_Var"
    bn1_scale_initializer_tensor = create_initializer_tensor(
        name=bn1_scale_initializer_tensor_name,
        tensor_array=bn1_scale,
        data_type=onnx.TensorProto.FLOAT)
    bn1_bias_initializer_tensor = create_initializer_tensor(
        name=bn1_bias_initializer_tensor_name,
        tensor_array=bn1_bias,
        data_type=onnx.TensorProto.FLOAT)
    bn1_mean_initializer_tensor = create_initializer_tensor(
        name=bn1_mean_initializer_tensor_name,
        tensor_array=bn1_mean,
        data_type=onnx.TensorProto.FLOAT)
    bn1_var_initializer_tensor = create_initializer_tensor(
        name=bn1_var_initializer_tensor_name,
        tensor_array=bn1_var,
        data_type=onnx.TensorProto.FLOAT)

    bn1_node = onnx.helper.make_node(
        name="BN1",  # Name is optional.
        op_type="BatchNormalization",
        inputs=[
            conv1_output_node_name, bn1_scale_initializer_tensor_name,
            bn1_bias_initializer_tensor_name, bn1_mean_initializer_tensor_name,
            bn1_var_initializer_tensor_name
        ],
        outputs=[bn1_output_node_name],
    )

    # Create a ReLU node (NodeProto).
    relu1_output_node_name = "ReLU1_Y"

    relu1_node = onnx.helper.make_node(
        name="ReLU1",  # Name is optional.
        op_type="Relu",
        inputs=[bn1_output_node_name],
        outputs=[relu1_output_node_name],
    )

    # Create a GlobalAveragePool node (NodeProto).
    avg_pool1_output_node_name = "Avg_Pool1_Y"

    avg_pool1_node = onnx.helper.make_node(
        name="Avg_Pool1",  # Name is optional.
        op_type="GlobalAveragePool",
        inputs=[relu1_output_node_name],
        outputs=[avg_pool1_output_node_name],
    )

    # Create a Conv node (NodeProto).
    # https://github.com/onnx/onnx/blob/rel-1.9.0/docs/Operators.md#conv
    # Dummy weights for conv.
    conv2_in_channels = conv1_out_channels
    conv2_out_channels = model_output_channels
    conv2_kernel_shape = (1, 1)
    conv2_pads = (0, 0, 0, 0)
    conv2_W = np.ones(shape=(conv2_out_channels, conv2_in_channels,
                             *conv2_kernel_shape)).astype(np.float32)
    conv2_B = np.ones(shape=(conv2_out_channels)).astype(np.float32)
    # Create the initializer tensor for the weights.
    conv2_W_initializer_tensor_name = "Conv2_W"
    conv2_W_initializer_tensor = create_initializer_tensor(
        name=conv2_W_initializer_tensor_name,
        tensor_array=conv2_W,
        data_type=onnx.TensorProto.FLOAT)
    conv2_B_initializer_tensor_name = "Conv2_B"
    conv2_B_initializer_tensor = create_initializer_tensor(
        name=conv2_B_initializer_tensor_name,
        tensor_array=conv2_B,
        data_type=onnx.TensorProto.FLOAT)

    conv2_node = onnx.helper.make_node(
        name="Conv2",
        op_type="Conv",
        inputs=[
            avg_pool1_output_node_name, conv2_W_initializer_tensor_name,
            conv2_B_initializer_tensor_name
        ],
        outputs=[model_output_name],
        kernel_shape=conv2_kernel_shape,
        pads=conv2_pads,
    )

    # Create the graph (GraphProto)
    graph_def = onnx.helper.make_graph(
        nodes=[conv1_node, bn1_node, relu1_node, avg_pool1_node, conv2_node],
        name="ConvNet",
        inputs=[X],  # Graph input
        outputs=[Y],  # Graph output
        initializer=[
            conv1_W_initializer_tensor, conv1_B_initializer_tensor,
            bn1_scale_initializer_tensor, bn1_bias_initializer_tensor,
            bn1_mean_initializer_tensor, bn1_var_initializer_tensor,
            conv2_W_initializer_tensor, conv2_B_initializer_tensor
        ],
    )

    # Create the model (ModelProto)
    model_def = onnx.helper.make_model(graph_def, producer_name="onnx-example")
    model_def.opset_import[0].version = 13

    model_def = onnx.shape_inference.infer_shapes(model_def)

    onnx.checker.check_model(model_def)

    onnx.save(model_def, "convnet.onnx")


if __name__ == "__main__":

    main()

Once the ONNX model has been created, we can further verify it by running inference with ONNX Runtime.
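
For example, here is a minimal sketch of such a verification, assuming the onnxruntime package is installed (the input name "X" and the file name convnet.onnx come from the script above):

import numpy as np
import onnxruntime as ort

# Run the freshly created model on a random input.
session = ort.InferenceSession("convnet.onnx",
                               providers=["CPUExecutionProvider"])
x = np.random.randn(1, 3, 32, 32).astype(np.float32)
# The first argument lists the requested outputs; None means all outputs.
outputs = session.run(None, {"X": x})
print(outputs[0].shape)  # Expected: (1, 10, 1, 1)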

Modifying ONNX Model

Modifying an ONNX model is a little more complicated, since all the information is encoded as protocol buffers and there is no ONNX helper function for modifying an existing protocol buffer.

Fortunately, we can assign values to the non-repeated fields defined in onnx.proto directly. For the repeated fields, we cannot assign a new value to the field as a whole, but we are allowed to modify its elements in place, or use the Python binding interface, such as extend and pop, to add and remove items.
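
A minimal sketch of the difference, assuming the convnet.onnx model created above:

import onnx

model = onnx.load("convnet.onnx")

# Non-repeated (singular) fields can be assigned directly.
model.producer_name = "onnx-example-modified"
model.graph.name = "ConvNetRenamed"

# Repeated fields cannot be reassigned as a whole ...
# model.graph.node = []  # would raise an AttributeError
# ... but their elements can be modified in place, extended, or popped.
model.graph.node[0].name = "Conv1_Renamed"  # in-place modification
last_node = model.graph.node.pop()          # remove the last element
model.graph.node.extend([last_node])        # and append it again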

import onnx
from typing import Iterable


def print_tensor_data(initializer: onnx.TensorProto) -> None:

    if initializer.data_type == onnx.TensorProto.DataType.FLOAT:
        print(initializer.float_data)
    elif initializer.data_type == onnx.TensorProto.DataType.INT32:
        print(initializer.int32_data)
    elif initializer.data_type == onnx.TensorProto.DataType.INT64:
        print(initializer.int64_data)
    elif initializer.data_type == onnx.TensorProto.DataType.DOUBLE:
        print(initializer.double_data)
    elif initializer.data_type == onnx.TensorProto.DataType.UINT64:
        print(initializer.uint64_data)
    else:
        raise NotImplementedError

    return


def dims_prod(dims: Iterable) -> int:

    prod = 1
    for dim in dims:
        prod *= dim

    return prod


def main() -> None:

    model = onnx.load("convnet.onnx")
    onnx.checker.check_model(model)

    graph_def = model.graph

    initializers = graph_def.initializer

    # Modify initializer
    for initializer in initializers:
        # Data type:
        # https://github.com/onnx/onnx/blob/rel-1.9.0/onnx/onnx.proto
        print("Tensor information:")
        print(
            f"Tensor Name: {initializer.name}, Data Type: {initializer.data_type}, Shape: {initializer.dims}"
        )
        print("Tensor value before modification:")
        print_tensor_data(initializer)
        # Replace the values with new values.
        if initializer.data_type == onnx.TensorProto.DataType.FLOAT:
            for i in range(dims_prod(initializer.dims)):
                initializer.float_data[i] = 2
        print("Tensor value after modification:")
        print_tensor_data(initializer)
        # If we want to change the data type and dims, we need to create new tensors from scratch.
        # onnx.helper.make_tensor

    # Modify nodes
    nodes = graph_def.node
    for node in nodes:
        print(node.name)
        print(node.op_type)
        print(node.input)
        print(node.output)
        # Modify batchnorm attributes.
        if node.op_type == "BatchNormalization":
            print("Attributes before adding:")
            for attribute in node.attribute:
                print(attribute)
            # Add epsilon for the BN nodes.
            epsilon_attribute = onnx.helper.make_attribute("epsilon", 1e-06)
            node.attribute.extend([epsilon_attribute])
            # node.attribute.pop()  # Pop an attribute if necessary.
            print("Attributes after adding:")
            for attribute in node.attribute:
                print(attribute)

    inputs = graph_def.input
    for graph_input in inputs:
        input_shape = []
        for d in graph_input.type.tensor_type.shape.dim:
            if d.dim_value == 0:
                input_shape.append(None)
            else:
                input_shape.append(d.dim_value)
        print(
            f"Input Name: {graph_input.name}, Input Data Type: {graph_input.type.tensor_type.elem_type}, Input Shape: {input_shape}"
        )

    outputs = graph_def.output
    for graph_output in outputs:
        output_shape = []
        for d in graph_output.type.tensor_type.shape.dim:
            if d.dim_value == 0:
                output_shape.append(None)
            else:
                output_shape.append(d.dim_value)
        print(
            f"Output Name: {graph_output.name}, Output Data Type: {graph_output.type.tensor_type.elem_type}, Output Shape: {output_shape}"
        )

    # To modify inputs and outputs, we would rather create new inputs and outputs.
    # Using onnx.helper.make_tensor_value_info and onnx.helper.make_model

    onnx.checker.check_model(model)
    onnx.save(model, "convnets_modified.onnx")


if __name__ == "__main__":

    main()
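
As the last comment in the script notes, graph inputs and outputs are best replaced wholesale with new ValueInfoProto objects rather than edited field by field. A minimal sketch, assuming we want to fix the dynamic batch dimension of convnet.onnx to 1 (the output file name is arbitrary):

import onnx

model = onnx.load("convnet.onnx")

# Build new IO tensors (ValueInfoProto) with a fixed batch size of 1.
new_input = onnx.helper.make_tensor_value_info("X", onnx.TensorProto.FLOAT,
                                               [1, 3, 32, 32])
new_output = onnx.helper.make_tensor_value_info("Y", onnx.TensorProto.FLOAT,
                                                [1, 10, 1, 1])

# graph.input and graph.output are repeated fields: pop the old entries
# and extend with the new ones.
model.graph.input.pop()
model.graph.input.extend([new_input])
model.graph.output.pop()
model.graph.output.extend([new_output])

onnx.checker.check_model(model)
onnx.save(model, "convnet_fixed_batch.onnx")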

Source Code

The source code of the implementation is available on GitHub.

Miscellaneous

Dealing with the ONNX protocol buffers is complicated and error-prone. The protocol buffer representation also depends on the ONNX IR version and the opset version. In some scenarios it would be desirable to have a high-level, abstracted interface that allows us to modify the model without going through the low-level data structures. Fortunately, we have ONNX GraphSurgeon, which can help us do exactly that, as the sketch below illustrates.
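
A minimal sketch using the onnx-graphsurgeon package (assuming it is installed, for example via pip install onnx-graphsurgeon), applied to the convnet.onnx file from above:

import onnx
import onnx_graphsurgeon as gs

# Import the ONNX model into a GraphSurgeon graph.
graph = gs.import_onnx(onnx.load("convnet.onnx"))

# Nodes, inputs, and outputs are plain Python objects that can be inspected
# and modified without touching the protocol buffers directly.
for node in graph.nodes:
    print(node.op, node.name)

# Remove dangling tensors/nodes and re-sort the graph topologically.
graph.cleanup().toposort()

# Export back to an ONNX ModelProto and save it.
onnx.save(gs.export_onnx(graph), "convnet_gs.onnx")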
