Save, Load and Inference From TensorFlow 2.x Frozen Graph
Introduction
Frozen graphs are commonly used for inference in TensorFlow and are stepping stones for inference in other frameworks. TensorFlow 1.x provided an interface to freeze models via tf.Session, and I previously wrote a blog post on how to use frozen models for inference in TensorFlow 1.x. However, because TensorFlow 2.x removed tf.Session, freezing models in TensorFlow 2.x has been a problem for many users.
In this blog post, I am going to show how to save, load, and run inference for frozen graphs in TensorFlow 2.x.
Materials
The sample code is available on my GitHub. It was modified from the official TensorFlow 2.x Fashion MNIST classification example.
Train Model and Export to Frozen Graph
We train a simple fully connected neural network to classify the Fashion MNIST data. For completeness, the model is saved in the SavedModel format in the models directory. In addition, the model is also frozen and saved as frozen_graph.pb in the frozen_models directory.
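For reference, a minimal sketch of what train.py might look like is shown below. The exact architecture, hyperparameters, and SavedModel subdirectory name are assumptions based on the official Fashion MNIST example; only the overall idea of training a small fully connected classifier and exporting it comes from this post.

import tensorflow as tf

# Load the Fashion MNIST dataset and scale pixel values to [0, 1].
(train_images, train_labels), _ = tf.keras.datasets.fashion_mnist.load_data()
train_images = train_images / 255.0

# A simple fully connected classifier (architecture is an assumption).
model = tf.keras.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(128, activation="relu"),
    tf.keras.layers.Dense(10, activation="softmax"),
])
model.compile(optimizer="adam",
              loss="sparse_categorical_crossentropy",
              metrics=["accuracy"])
model.fit(train_images, train_labels, epochs=5)

# Save the model in the SavedModel format (subdirectory name is an assumption).
model.save("./models/simple_model")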
To train and export the model, please run the following command in the terminal.
$ python train.py
The printout also contains a reference value for a sample prediction, computed in TensorFlow 2.x with the conventional inference protocol.
Example prediction reference:
The key to exporting the frozen graph is to convert the Keras model to a concrete function, freeze the graph captured by that concrete function into constants, and serialize it to the hard drive.
# Convert Keras model to ConcreteFunction
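A minimal sketch of these steps is shown below. It assumes model is the trained Keras model from train.py and that the model has a single input; the output path matches the frozen_models/frozen_graph.pb location used in this post.

import tensorflow as tf
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2

# Convert the Keras model to a ConcreteFunction with a fixed input signature.
full_model = tf.function(lambda x: model(x))
full_model = full_model.get_concrete_function(
    tf.TensorSpec(model.inputs[0].shape, model.inputs[0].dtype))

# Freeze the variables of the ConcreteFunction into graph constants.
frozen_func = convert_variables_to_constants_v2(full_model)

# Serialize the frozen graph to the frozen_models directory.
tf.io.write_graph(graph_or_graph_def=frozen_func.graph,
                  logdir="./frozen_models",
                  name="frozen_graph.pb",
                  as_text=False)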
Run Inference Using Frozen Graph
To run inference using the frozen graph in TensorFlow 2.x, please run the following command in the terminal.
$ python test.py
We also get the prediction value for the sample inference using the frozen graph. It is (almost) exactly the same as the reference value obtained with the conventional inference protocol.
Example prediction reference:
Because the frozen graph format has largely been deprecated by TensorFlow in favor of the SavedModel format, we have to use TensorFlow 1.x functions to load the frozen graph from the hard drive.
# Load frozen graph using TensorFlow 1.x functions
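The loading step could look like the following sketch, using the TensorFlow 1.x compatibility APIs that are still available in TensorFlow 2.x; the file path matches the one written by train.py above.

import tensorflow as tf

# Read the serialized frozen graph and parse it into a GraphDef.
with tf.io.gfile.GFile("./frozen_models/frozen_graph.pb", "rb") as f:
    graph_def = tf.compat.v1.GraphDef()
    graph_def.ParseFromString(f.read())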
Once the frozen graph is loaded, we wrap it into a concrete function and run inference.
def wrap_frozen_graph(graph_def, inputs, outputs, print_graph=False):
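One way to implement this helper is with tf.compat.v1.wrap_function and prune, as sketched below; the exact implementation in test.py may differ. The tensor names x:0 and Identity:0 are the same input and output names used by the conversion commands later in this post.

import tensorflow as tf

def wrap_frozen_graph(graph_def, inputs, outputs, print_graph=False):
    # Import the GraphDef into the graph of a wrapped tf.function.
    def _imports_graph_def():
        tf.compat.v1.import_graph_def(graph_def, name="")

    wrapped_import = tf.compat.v1.wrap_function(_imports_graph_def, [])
    import_graph = wrapped_import.graph

    if print_graph:
        # Optionally list all operations in the frozen graph.
        for op in import_graph.get_operations():
            print(op.name)

    # Prune the wrapped function to the requested input and output tensors,
    # which returns a callable ConcreteFunction.
    return wrapped_import.prune(
        tf.nest.map_structure(import_graph.as_graph_element, inputs),
        tf.nest.map_structure(import_graph.as_graph_element, outputs))

frozen_func = wrap_frozen_graph(graph_def=graph_def,
                                inputs=["x:0"],
                                outputs=["Identity:0"],
                                print_graph=True)

# Run a sample inference on the Fashion MNIST test images.
(_, _), (test_images, _) = tf.keras.datasets.fashion_mnist.load_data()
predictions = frozen_func(tf.constant(test_images / 255.0, dtype=tf.float32))[0]
print("Example prediction:", predictions[0].numpy())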
Convert Frozen Graph to ONNX
If TensorFlow 1.x and tf2onnx have been installed, the frozen graph can be converted to an ONNX model using the following command.
$ python -m tf2onnx.convert --input ./frozen_models/frozen_graph.pb --output model.onnx --outputs Identity:0 --inputs x:0
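As a quick sanity check, which is not part of the original scripts, the converted model can be run with onnxruntime. The onnxruntime usage and the dummy input below are assumptions; the input and output names are queried from the model rather than hardcoded.

import numpy as np
import onnxruntime as ort

session = ort.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])
input_name = session.get_inputs()[0].name    # expected to be "x:0"
output_name = session.get_outputs()[0].name  # expected to be "Identity:0"

# A dummy Fashion MNIST-shaped batch; replace with real test images.
dummy_images = np.random.rand(1, 28, 28).astype(np.float32)
onnx_predictions = session.run([output_name], {input_name: dummy_images})[0]
print(onnx_predictions)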
Convert Frozen Graph to UFF
The frozen graph can also be converted to a UFF model for TensorRT using the following command.
$ convert-to-uff frozen_graph.pb -t -O Identity -o frozen_graph.uff
The TensorRT 6.0 Docker image can be pulled from NVIDIA NGC.
$ docker pull nvcr.io/nvidia/tensorrt:19.12-py3
Conclusions
TensorFlow 2.x can also save, load, and run inference with frozen graphs. The frozen graphs produced by TensorFlow 2.x should be equivalent to those produced by TensorFlow 1.x.
References
Save, Load and Inference From TensorFlow 2.x Frozen Graph
https://leimao.github.io/blog/Save-Load-Inference-From-TF2-Frozen-Graph/