Save, Load and Inference From TensorFlow Frozen Graph
Introduction
TensorFlow model saving has become easier than it was in the early days. Now you can either use Keras to save the model in h5 format or use tf.train.Saver to save checkpoint files. Loading those saved models is also easy, and you can find plenty of instructions in the official TensorFlow tutorials. There is another model format, pb, which is frequently seen in model zoos but hardly mentioned by TensorFlow official channels. pb stands for Protocol Buffers; it is a language-neutral, platform-neutral, extensible mechanism for serializing structured data. It is widely used in model deployment, for example by the fast inference tool TensorRT. While pb format models seem to be important, there is a lack of systematic tutorials on how to save, load, and run inference on pb format models in TensorFlow.
In this blog post, I am going to introduce how to save, load, and run inference on a frozen graph in TensorFlow 1.x. For the equivalent tasks in TensorFlow 2.x, please read the other blog post “Save, Load and Inference From TensorFlow 2.x Frozen Graph”.
Materials
The sample code is available on my GitHub. It was modified from my previous simple CNN model for classifying the CIFAR10 dataset.
Train Model
We have to train our model first. Train the model using the following command:
```
$ python main.py --train --test --epoch 30 --lr_decay 0.9 --dropout 0.5
```
The test accuracy after training is around 0.793900.
Save PB Model
The major components of a pb file are the graph structure and the parameters of your model. While the parameters are optional for a pb file, you need them for our task, since we need the parameters to do inference. Otherwise, people who download your pb file will not be able to deploy it.
This is the key code to save a pb file:
```python
from tensorflow.python.tools import freeze_graph
```
You are required to save a checkpoint of your model first, followed by saving the graph. Saving a checkpoint is easy: you just have to use tf.train.Saver and everything should be straightforward. In my code, I wrapped saving checkpoints with tf.train.Saver in the self.save method. Saving the graph is done with tf.train.write_graph. There are two arguments which might be confusing to new users, name and as_text. name is simply the file name of the saved graph. as_text is a boolean value indicating whether the saved graph is human-readable or not. By convention, if it is human-readable, the file extension we use will be .pbtxt; otherwise the file extension will be .pb. But this pb file will not contain the parameters you trained in your model.
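Below is a minimal sketch of these two saving steps. The variable, the file names, and the model directory are illustrative stand-ins, not the actual names used in the repository.
```python
import os
import tensorflow as tf

# A stand-in for the real model: a single trainable variable.
weights = tf.Variable(tf.random_normal([3, 3]), name='weights')

saver = tf.train.Saver()

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # ... training would happen here ...

    os.makedirs('model', exist_ok=True)  # ensure the output directory exists

    # Step 1: save the parameters as a checkpoint.
    saver.save(sess, 'model/model.ckpt')

    # Step 2: save the graph structure, without parameters.
    # as_text=True writes a human-readable .pbtxt file;
    # as_text=False would write a binary .pb file instead.
    tf.train.write_graph(sess.graph_def, 'model', 'model.pbtxt', as_text=True)
```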
We then need to freeze the graph, combining the graph and the parameters into a single pb file. There are two ways to freeze a graph.
The first method is to use the freeze_graph function. The argument descriptions of freeze_graph could be found here. If input_graph is a human-readable pbtxt file, input_binary should be False. If input_graph is a binary pb file, input_binary should be True. You will also need to specify the names of your output nodes; freeze_graph takes them as a single string, comma-separated if you have multiple outputs. restore_op_name and filename_tensor_name are being deprecated; using the values provided should work universally for all models. Leaving the rest of the arguments the same as mine should be fine. The pb file will be saved to the output_graph path you provided.
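Here is a minimal sketch of a freeze_graph call, continuing the illustrative file names from above; the output node name "output" is an assumption and should be replaced with the actual output node name of your model:
```python
from tensorflow.python.tools import freeze_graph

freeze_graph.freeze_graph(
    input_graph='model/model.pbtxt',      # graph saved by tf.train.write_graph
    input_saver='',
    input_binary=False,                   # False because input_graph is a pbtxt file
    input_checkpoint='model/model.ckpt',  # checkpoint saved by tf.train.Saver
    output_node_names='output',           # comma-separated string for multiple outputs
    restore_op_name='save/restore_all',   # deprecated; value is universal
    filename_tensor_name='save/Const:0',  # deprecated; value is universal
    output_graph='model/model.pb',        # where the frozen graph is written
    clear_devices=True,
    initializer_nodes='')
```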
The second method is to do the serialization yourself. I believe the first method is just a higher-level wrapper for the second method. The pb files generated from the two methods both pass the accuracy test that I am going to show below.
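Below is a minimal sketch of the manual approach, assuming the illustrative checkpoint from above and an output node named "output". It restores the checkpoint, replaces the variables in the graph with constants holding their trained values, and serializes the resulting GraphDef:
```python
import tensorflow as tf
from tensorflow.python.framework import graph_util

with tf.Session() as sess:
    # Restore the trained graph and parameters from the checkpoint.
    saver = tf.train.import_meta_graph('model/model.ckpt.meta')
    saver.restore(sess, 'model/model.ckpt')

    # Convert the variables in the graph to constants with their trained values.
    frozen_graph_def = graph_util.convert_variables_to_constants(
        sess, sess.graph_def, output_node_names=['output'])

    # Serialize the frozen GraphDef to a binary pb file.
    with tf.gfile.GFile('model/model.pb', 'wb') as f:
        f.write(frozen_graph_def.SerializeToString())
```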
After these steps, the model directory contains the checkpoint files, the pbtxt graph we saved, and the frozen pb file. The pb file is there!
Load PB Model
We wrote a class to load the model from pb files.
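A sketch of that class is shown below. The input placeholder shapes follow the CIFAR10 model (32x32 RGB images), and the output tensor name 'import/cnn/output:0' is an illustrative assumption; check the actual node names of your graph:
```python
import tensorflow as tf

class CNN(object):

    def __init__(self, model_filepath):
        self.model_filepath = model_filepath
        self.load_graph(model_filepath=self.model_filepath)

    def load_graph(self, model_filepath):
        print('Loading model...')
        # tf.InteractiveSession installs itself as the default session, and
        # its graph becomes the default graph for the ops defined below.
        self.sess = tf.InteractiveSession()

        # Read the frozen, binary pb file into a GraphDef.
        with tf.gfile.GFile(model_filepath, 'rb') as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())

        # If unsure about the tensor names, print them from the GraphDef:
        # for node in graph_def.node:
        #     print(node.name)

        # Define new placeholders and bind them to the inputs of the graph.
        self.input = tf.placeholder(tf.float32, shape=[None, 32, 32, 3], name='input')
        self.dropout_rate = tf.placeholder(tf.float32, shape=[], name='dropout_rate')
        tf.import_graph_def(graph_def, {'input': self.input, 'dropout_rate': self.dropout_rate})
        print('Model loading complete!')

    def test(self, data):
        # Find the output tensor by name and run inference on the input data.
        output_tensor = self.sess.graph.get_tensor_by_name('import/cnn/output:0')
        output = self.sess.run(output_tensor,
                               feed_dict={self.input: data, self.dropout_rate: 0})
        return output
```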
Working with models loaded from pb files is a little bit painful, since you have to work with tensor names all the time. If you are not sure about the tensor names you are working with, try printing out the names from graph_def.node. In our case, because we are going to do inference, we need to bind the inputs of the graph to placeholders so that we can feed values into the model. Getting the values of the parameters is also available via graph_def.node. Here I attached two placeholders to the graph using tf.import_graph_def(graph_def, {'input': self.input, 'dropout_rate': self.dropout_rate}). It should be noted that 'input' and 'dropout_rate' are the names of the inputs I defined in the original graph.
We also set up the test method. Simply find the tensor you are interested in, in our case the output tensor, and feed the input values using sess.run.
Inference from PB Model
To verify that the loaded graph is correct and working, we need to run some inference as a test.
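The test routine sketched below loads the CIFAR10 test set via Keras for brevity; the original repository uses its own data-loading and preprocessing utilities, which you should reuse to reproduce the reported accuracy exactly:
```python
import numpy as np
import tensorflow as tf

def test_from_frozen_graph(model_filepath):

    tf.reset_default_graph()

    # Load the CIFAR10 test set (illustrative; the repository has its own loader).
    (_, _), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
    x_test = x_test.astype(np.float32) / 255.0  # illustrative preprocessing
    y_test = y_test.flatten()

    # Test 500 samples.
    x_test, y_test = x_test[:500], y_test[:500]

    # CNN is the loading class defined above.
    model = CNN(model_filepath=model_filepath)
    predictions = model.test(data=x_test)
    accuracy = np.mean(np.argmax(predictions, axis=1) == y_test)
    print('Test accuracy: %f' % accuracy)
```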
Run the following command to test:
```
$ python test_pb.py
```
Here I tested 500 samples from the test set. If you want to test all the examples, you can write a for loop to do so. The test accuracy is 0.788000. Compared to the test accuracy of 0.793900 we got right after training, this suggests that the pb file we saved is valid.
Updates
2019/9/16
Thanks to the question raised by Yuqiong Li, I removed the usage of tf.InteractiveSession and replaced it with tf.Session. The new class to load pb files is as follows.
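A sketch of the updated class is shown below, with the same illustrative tensor names as before; only load_graph changes, so test is shown unchanged:
```python
import tensorflow as tf

class CNN(object):

    def __init__(self, model_filepath):
        self.model_filepath = model_filepath
        self.load_graph(model_filepath=self.model_filepath)

    def load_graph(self, model_filepath):
        print('Loading model...')
        self.graph = tf.Graph()

        # Read the frozen, binary pb file into a GraphDef.
        with tf.gfile.GFile(model_filepath, 'rb') as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())

        # Explicitly make self.graph the default graph while defining the
        # placeholders and importing the GraphDef; no side effect is needed.
        with self.graph.as_default():
            self.input = tf.placeholder(tf.float32, shape=[None, 32, 32, 3], name='input')
            self.dropout_rate = tf.placeholder(tf.float32, shape=[], name='dropout_rate')
            tf.import_graph_def(graph_def, {'input': self.input, 'dropout_rate': self.dropout_rate})

        # Bind a plain tf.Session to the graph we just built.
        self.sess = tf.Session(graph=self.graph)
        print('Model loading complete!')

    def test(self, data):
        # Find the output tensor by name and run inference on the input data.
        output_tensor = self.graph.get_tensor_by_name('import/cnn/output:0')
        output = self.sess.run(output_tensor,
                               feed_dict={self.input: data, self.dropout_rate: 0})
        return output
```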
There was nothing wrong with the previous implementation, but I placed the tf.InteractiveSession before the graphdef was loaded into the default graph, taking advantage of the side effect that tf.InteractiveSession sets its corresponding graph as the default graph globally. Therefore, simply replacing tf.InteractiveSession with tf.Session would not work in the previous implementation. This might cause some confusion for readers who really want to understand what is happening underneath. In the new implementation, I explicitly created the default graph using a Python resource manager and loaded the graphdef into it. No side effect is used, and therefore it should be much easier to understand.
2020/1/9
This blog and example were designed for TensorFlow 1.x. TensorFlow 2.x also supports the frozen graph. Please check the blog post “Save, Load and Inference From TensorFlow 2.x Frozen Graph”.
Final Remarks
Now you should be good to go with pb files in your deployment!
One additional caveat is that TensorFlow has been deprecating or changing a lot of APIs, including parts of freeze_graph. We have to keep ourselves updated on those functions.