Frozen graphs are commonly used for inference in TensorFlow and are a stepping stone to inference in other frameworks. TensorFlow 1.x provided an interface for freezing models via tf.Session, and I previously wrote a blog post on how to use frozen models for inference in TensorFlow 1.x. However, since TensorFlow 2.x removed tf.Session, freezing models in TensorFlow 2.x has been a problem for many users.
In this blog post, I am going to show how to save, load, and run inference for frozen graphs in TensorFlow 2.x.
The sample code is available on my GitHub. It was modified from the official TensorFlow 2.x Fashion MNIST Classification example.
Train Model and Export to Frozen Graph
We train a simple fully connected neural network to classify the Fashion MNIST data. For completeness, the model is saved as a SavedModel in the models directory. In addition, the model is frozen and saved as frozen_graph.pb in the frozen_models directory.
To train and export the model, please run the following command in the terminal.
$ python train.py
The printouts also include a reference value for the sample inference, computed in TensorFlow 2.x using the conventional inference protocol.
Example prediction reference: [3.9113933e-05 1.1972898e-07 5.2244545e-06 5.4371812e-06 6.1125693e-06 1.1335548e-01 3.0090479e-05 2.8483599e-01 9.5160649e-04 6.0077089e-01]
The key to exporting the frozen graph is to convert the model to a concrete function, extract and freeze the graph from the concrete function, and serialize it to the hard drive.
    import tensorflow as tf
    from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2

    # `model` is the trained Keras model from the training step.
    # Convert Keras model to ConcreteFunction
    full_model = tf.function(lambda x: model(x))
    full_model = full_model.get_concrete_function(
        tf.TensorSpec(model.inputs[0].shape, model.inputs[0].dtype))

    # Get frozen ConcreteFunction
    frozen_func = convert_variables_to_constants_v2(full_model)
    frozen_func.graph.as_graph_def()

    # Print the operations, inputs, and outputs of the frozen graph
    layers = [op.name for op in frozen_func.graph.get_operations()]
    print("-" * 50)
    print("Frozen model layers: ")
    for layer in layers:
        print(layer)

    print("-" * 50)
    print("Frozen model inputs: ")
    print(frozen_func.inputs)
    print("Frozen model outputs: ")
    print(frozen_func.outputs)

    # Save frozen graph from frozen ConcreteFunction to hard drive
    tf.io.write_graph(graph_or_graph_def=frozen_func.graph,
                      logdir="./frozen_models",
                      name="frozen_graph.pb",
                      as_text=False)
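To check that freezing really did fold the variables into constants, a small self-contained sketch can help. It does not use the Fashion MNIST model from train.py; instead it assumes a hypothetical TinyClassifier (a tf.Module with one dense layer) as a stand-in, freezes it the same way, and compares the frozen output with the original output.

```python
import numpy as np
import tensorflow as tf
from tensorflow.python.framework.convert_to_constants import (
    convert_variables_to_constants_v2,
)


class TinyClassifier(tf.Module):
    """Hypothetical stand-in for the trained model: one dense layer + softmax."""

    def __init__(self):
        super().__init__()
        self.w = tf.Variable(tf.random.normal([4, 3], seed=0), name="w")
        self.b = tf.Variable(tf.zeros([3]), name="b")

    @tf.function(input_signature=[tf.TensorSpec([None, 4], tf.float32, name="x")])
    def __call__(self, x):
        return tf.nn.softmax(tf.matmul(x, self.w) + self.b)


model = TinyClassifier()

# Same two steps as in the export code: concrete function, then freeze.
concrete_func = model.__call__.get_concrete_function()
frozen_func = convert_variables_to_constants_v2(concrete_func)

# After freezing, no variable-handling ops should remain in the graph.
variable_op_types = {"VarHandleOp", "ReadVariableOp", "VariableV2"}
leftover_variable_ops = [
    op.name
    for op in frozen_func.graph.get_operations()
    if op.type in variable_op_types
]

# Run the same batch through the original model and the frozen function.
batch = tf.constant(np.random.rand(2, 4).astype(np.float32))
reference = model(batch).numpy()
# tf.nest.flatten handles both a list of outputs and a single tensor.
frozen_output = tf.nest.flatten(frozen_func(batch))[0].numpy()
max_abs_diff = float(np.max(np.abs(reference - frozen_output)))

print("Leftover variable ops:", leftover_variable_ops)
print("Max abs difference:", max_abs_diff)
```

If the list of leftover variable ops is empty and the difference is at floating-point noise level, the frozen function can be written out with tf.io.write_graph as shown above.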
Run Inference Using Frozen Graph
To run inference using the frozen graph in TensorFlow 2.x, please run the following command in the terminal.
$ python test.py
We also get the value for the sample inference using the frozen graph. It is (almost) exactly the same as the reference value we got using the conventional inference protocol.
Example prediction reference: [3.9113860e-05 1.1972921e-07 5.2244545e-06 5.4371812e-06 6.1125752e-06 1.1335552e-01 3.0090479e-05 2.8483596e-01 9.5160597e-04 6.0077089e-01]
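To put a number on "almost exactly the same", the two printed vectors can be compared directly. The sketch below copies the values from the two printouts above and uses only the Python standard library; the variable names are my own.

```python
# Values copied from the train.py and test.py printouts above.
reference = [3.9113933e-05, 1.1972898e-07, 5.2244545e-06, 5.4371812e-06,
             6.1125693e-06, 1.1335548e-01, 3.0090479e-05, 2.8483599e-01,
             9.5160649e-04, 6.0077089e-01]
frozen = [3.9113860e-05, 1.1972921e-07, 5.2244545e-06, 5.4371812e-06,
          6.1125752e-06, 1.1335552e-01, 3.0090479e-05, 2.8483596e-01,
          9.5160597e-04, 6.0077089e-01]

# Largest element-wise absolute difference between the two predictions.
max_abs_diff = max(abs(r - f) for r, f in zip(reference, frozen))

# Index of the most probable class for each prediction.
reference_class = max(range(len(reference)), key=reference.__getitem__)
frozen_class = max(range(len(frozen)), key=frozen.__getitem__)

print("Max abs difference:", max_abs_diff)
print("Predicted classes:", reference_class, frozen_class)
```

The largest difference stays below 1e-7, which is consistent with float32 rounding, and both predictions select the same class (index 9).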
Because frozen graphs are gradually being deprecated in TensorFlow and the SavedModel format is encouraged instead, we have to use TensorFlow 1.x compatibility functions to load the frozen graph from the hard drive.
    # Load frozen graph using TensorFlow 1.x functions
    with tf.io.gfile.GFile("./frozen_models/frozen_graph.pb", "rb") as f:
        graph_def = tf.compat.v1.GraphDef()
        graph_def.ParseFromString(f.read())

    # Wrap frozen graph to ConcreteFunctions
    frozen_func = wrap_frozen_graph(graph_def=graph_def,
                                    inputs=["x:0"],
                                    outputs=["Identity:0"],
                                    print_graph=True)
Once the frozen graph is loaded, we wrap it in a concrete function and run inference.
    def wrap_frozen_graph(graph_def, inputs, outputs, print_graph=False):
        def _imports_graph_def():
            tf.compat.v1.import_graph_def(graph_def, name="")

        # wrap_function requires a signature; the import function takes no arguments
        wrapped_import = tf.compat.v1.wrap_function(_imports_graph_def, [])
        import_graph = wrapped_import.graph

        print("-" * 50)
        print("Frozen model layers: ")
        layers = [op.name for op in import_graph.get_operations()]
        if print_graph:
            for layer in layers:
                print(layer)
        print("-" * 50)

        # Keep only the subgraph between the requested input and output tensors
        return wrapped_import.prune(
            tf.nest.map_structure(import_graph.as_graph_element, inputs),
            tf.nest.map_structure(import_graph.as_graph_element, outputs))
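The whole round trip (freeze, write, load, wrap, infer) can be tried end to end without the trained model. The following is a minimal sketch under stated assumptions: toy_model and weights are hypothetical stand-ins, the graph is written to a temporary directory rather than ./frozen_models, and the tensor names are read from the frozen function instead of being hard-coded as "x:0" and "Identity:0".

```python
import os
import tempfile

import numpy as np
import tensorflow as tf
from tensorflow.python.framework.convert_to_constants import (
    convert_variables_to_constants_v2,
)

# Hypothetical toy model: all-ones weights followed by a softmax.
weights = tf.Variable(tf.ones([4, 3]), name="w")


@tf.function(input_signature=[tf.TensorSpec([None, 4], tf.float32, name="x")])
def toy_model(x):
    return tf.nn.softmax(tf.matmul(x, weights))


# Freeze and remember the tensor names needed to wrap the graph later.
frozen_func = convert_variables_to_constants_v2(toy_model.get_concrete_function())
input_name = frozen_func.inputs[0].name
output_name = frozen_func.outputs[0].name

# Serialize the frozen graph to a temporary directory.
export_dir = tempfile.mkdtemp()
tf.io.write_graph(frozen_func.graph, export_dir, "frozen_graph.pb", as_text=False)

# Load the serialized GraphDef back from disk.
graph_def = tf.compat.v1.GraphDef()
with tf.io.gfile.GFile(os.path.join(export_dir, "frozen_graph.pb"), "rb") as f:
    graph_def.ParseFromString(f.read())


def _import_graph_def():
    tf.compat.v1.import_graph_def(graph_def, name="")


# Wrap the imported graph and prune it to the input/output tensors.
wrapped = tf.compat.v1.wrap_function(_import_graph_def, [])
restored_func = wrapped.prune(
    wrapped.graph.as_graph_element(input_name),
    wrapped.graph.as_graph_element(output_name),
)

# Run inference on a dummy batch of two samples.
batch = tf.constant(np.ones((2, 4), dtype=np.float32))
prediction = tf.nest.flatten(restored_func(batch))[0].numpy()

print("Input tensor:", input_name, "Output tensor:", output_name)
print(prediction)
```

Because every logit is identical in this toy setup, each row of the prediction is a uniform distribution over the three classes. Printing the tensor names is also a convenient way to find the values to pass as inputs and outputs when wrapping a real frozen graph.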
Convert Frozen Graph to ONNX
If TensorFlow 1.x and tf2onnx have been installed, the frozen graph can be converted to an ONNX model using the following command.
$ python -m tf2onnx.convert --input ./frozen_models/frozen_graph.pb --output model.onnx --outputs Identity:0 --inputs x:0
Convert Frozen Graph to UFF
The frozen graph can also be converted to a UFF model for TensorRT using the following command.
$ convert-to-uff frozen_graph.pb -t -O Identity -o frozen_graph.uff
The TensorRT 6.0 Docker image can be pulled from NVIDIA NGC.
$ docker pull nvcr.io/nvidia/tensorrt:19.12-py3
TensorFlow 2.x can still save, load, and run inference for frozen graphs. The frozen graphs from TensorFlow 2.x should be equivalent to the frozen graphs from TensorFlow 1.x.