TensorFlow Inference for Estimator
Introduction
Although I am not a big fan of the high-level Estimator API in TensorFlow, more and more models now use Estimator for training, evaluation, and inference. Because everything is wrapped up, the machine learning process becomes less transparent, which is likely to cause a lot of trouble when engineering a fast inference process, especially for people who are not familiar with the high-level APIs.
In this blog post, I am going to provide comprehensive guidance on how to set up fast inference protocols for TensorFlow models based on Estimator.
Repository
The sample code for this tutorial was forked from Guillaume Genthial’s tf-estimator-basics, with some modifications. All the tests were conducted using an NVIDIA RTX 2080 Ti graphics card.
To learn more about the details of the model, please check Guillaume Genthial’s blog post.
Before running the inference tests, please train the model by running the following command in the terminal.
```bash
$ python train.py
```
Please also export the model to a SavedModel by running the following command in the terminal.
```bash
$ python export.py
```
I have also provided the pre-trained ckpt model and SavedModel in the GitHub repository.
Fast Inference Protocols
TensorFlow Estimator uses the predict method to do inference. The predict method takes an input_fn, which, when called, returns the input to the model (typically built from a generator). Without orchestration, if new data comes in batches, we would have to create an input_fn for each batch of new data and run the predict method. The predict method returns a generator over the prediction values corresponding to the input values produced by input_fn.
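To make this baseline concrete, the following is a minimal sketch of the per-batch pattern. The feature name 'number', the model directory, the example data, and the model_fn placeholder are assumptions for illustration, not the repository's exact code.

```python
import tensorflow as tf

# model_fn is the project's model function (placeholder here).
estimator = tf.estimator.Estimator(model_fn=model_fn, model_dir='model')

def make_input_fn(numbers):
    """Builds a fresh input_fn for one batch of newly arrived data."""
    def input_fn():
        features = {'number': [[n] for n in numbers]}  # shape [num_examples, 1]
        return tf.data.Dataset.from_tensor_slices(features).batch(1)
    return input_fn

stream_of_batches = [[1.0, 2.0], [3.0, 4.0]]  # stand-in for data arriving in batches

for numbers in stream_of_batches:
    # Every call to predict rebuilds the graph and reloads the checkpoint.
    predictions = list(estimator.predict(input_fn=make_input_fn(numbers)))
```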
The problem is that TensorFlow creates a graph and loads all the parameters of the model when predict is called. Once input_fn raises an end-of-input exception during the call, TensorFlow destroys the graph and releases the memory for all the parameters. This overhead takes a very long time. If we create an input_fn for each batch of new data during inference, this overhead makes the inference extremely slow.
There are generally two ways to make the inference of Estimator-based models faster: using predict while keeping the graph alive all the time, or converting Estimator-based models to SavedModel and serving them.
Keeping Graph Alive
As I mentioned previously, the graph is destroyed when input_fn raises an end-of-input exception during a function call. So if input_fn uses an indefinite generator, it will never raise an end-of-input exception, and the graph will stay alive all the time. Designing such an indefinite generator is therefore very important.
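Here is a minimal sketch of such an input_fn, assuming a single float feature named 'number' (the actual feature names and shapes depend on the model). Because the generator loops forever, TensorFlow never sees an end-of-input signal and the graph is never torn down.

```python
import tensorflow as tf

def make_live_input_fn(next_example_fn):
    """Returns an input_fn backed by a generator that never raises StopIteration."""
    def generator():
        while True:  # indefinite: the end-of-input exception is never raised
            yield {'number': [next_example_fn()]}

    def input_fn():
        dataset = tf.data.Dataset.from_generator(
            generator,
            output_types={'number': tf.float32},
            output_shapes={'number': tf.TensorShape([1])})
        return dataset.batch(1)

    return input_fn
```

Calling estimator.predict(input_fn=make_live_input_fn(...)) once then gives a prediction generator that can be pulled with next() whenever a new example is supplied, without rebuilding the graph.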
I have tested the vanilla Estimator predict by running predict.py in the repository. It takes 0.1152 seconds per example using a batch size of 1, which is extremely slow.
Marc Stogaitis has implemented FastPredict, a wrapper for the predict method of Estimator that uses an indefinite generator. I have applied his wrapper to the same model and tested it by running fast_predict.py. It takes 0.352 milliseconds per example using a batch size of 1, which is extremely fast. However, the shortcoming of his interface is that it only allows exactly one example at a time.
I modified Marc Stogaitis’s interface implementation so that it allows multiple examples to be fed each time, although the inference is still done using a batch size of 1. I have tested it by running faster_predict.py. It takes 0.238 milliseconds per example using a batch size of 1, which is somehow 30% faster than Marc Stogaitis’s implementation.
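The following is an illustrative wrapper in the same spirit, not the exact code of fast_predict.py or faster_predict.py; the 'number' feature name and its type are assumptions. It keeps a single prediction generator alive and lets the caller feed a list of examples per call.

```python
import tensorflow as tf

class KeepAlivePredictor(object):
    """Keeps the Estimator's prediction graph alive across calls."""

    def __init__(self, estimator):
        self.estimator = estimator
        self.pending = []          # examples queued for the indefinite generator
        self.predictions = None    # lazily created prediction generator

    def _generator(self):
        while True:                # never raises StopIteration, so the graph survives
            yield {'number': [self.pending.pop(0)]}

    def _input_fn(self):
        dataset = tf.data.Dataset.from_generator(
            self._generator,
            output_types={'number': tf.float32},
            output_shapes={'number': tf.TensorShape([1])})
        return dataset.batch(1)

    def predict(self, examples):
        """Accepts a list of examples and returns one prediction per example."""
        self.pending.extend(examples)
        if self.predictions is None:
            # The graph is built and the checkpoint is loaded only once, here.
            self.predictions = self.estimator.predict(input_fn=self._input_fn)
        return [next(self.predictions) for _ in range(len(examples))]
```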
Inference on SavedModel
Guillaume Genthial has talked about exporting the model to SavedModel and doing inference on it using predictor from tf.contrib in his blog post, so I am not going to elaborate too much on it here. I have tested it by running serve.py. It takes 0.178 milliseconds per example using a batch size of 1, which is 40% faster than my Estimator predict solution.
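For reference, a minimal sketch of predictor-based inference looks like the following. The export directory path and the 'number'/'output' key names are illustrative assumptions; they must match the exported serving signature.

```python
import tensorflow as tf
from tensorflow.contrib import predictor

# Path to a timestamped SavedModel export directory (illustrative).
predict_fn = predictor.from_saved_model('saved_model/1557000000')

# The keys and shapes must match the inputs and outputs of the serving signature.
result = predict_fn({'number': [2.0]})
print(result['output'])
```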
Changing Prediction Tensors
Sometimes you would like to change the default output tensors from the original settings. For example, you may want to extract some hidden-layer tensors. You can change the model_fn function passed to Estimator when loading the model:
```python
estimator = tf.estimator.Estimator(model_fn=model_fn, config=run_config)
```
The output node of an estimator built using the following model_fn is the predictions tensor, and its name in the graph is 'output' by default.
```python
def model_fn(features, labels, mode, params):
    ...
```
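For concreteness, a minimal single-output model_fn might look like the following sketch; the architecture, the 'number' feature name, and the shapes are illustrative assumptions, not the repository's actual code.

```python
import tensorflow as tf

def model_fn(features, labels, mode, params):
    # Illustrative architecture, assuming features['number'] and labels
    # both have shape [batch, 1].
    hidden = tf.layers.dense(features['number'], units=10, activation=tf.nn.relu)
    predictions = tf.layers.dense(hidden, units=1)
    if mode == tf.estimator.ModeKeys.PREDICT:
        # A single predictions tensor is exported under the default name 'output'.
        return tf.estimator.EstimatorSpec(mode, predictions=predictions)
    loss = tf.losses.mean_squared_error(labels=labels, predictions=predictions)
    train_op = tf.train.AdamOptimizer().minimize(
        loss, global_step=tf.train.get_global_step())
    return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)
```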
To add more output nodes, we pass a dictionary, {'hidden': hidden, 'predictions': predictions}, as the predictions. The names of the output nodes in the graph are then 'hidden' and 'predictions', respectively.
```python
def model_fn(features, labels, mode, params):
    ...
```
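Continuing the same illustrative sketch (again an assumption, not the repository's code), only the PREDICT branch changes:

```python
import tensorflow as tf

def model_fn(features, labels, mode, params):
    # Same illustrative architecture as above, now exporting two named outputs.
    hidden = tf.layers.dense(features['number'], units=10, activation=tf.nn.relu)
    predictions = tf.layers.dense(hidden, units=1)
    if mode == tf.estimator.ModeKeys.PREDICT:
        # The keys 'hidden' and 'predictions' become the output node names.
        return tf.estimator.EstimatorSpec(
            mode, predictions={'hidden': hidden, 'predictions': predictions})
    loss = tf.losses.mean_squared_error(labels=labels, predictions=predictions)
    train_op = tf.train.AdamOptimizer().minimize(
        loss, global_step=tf.train.get_global_step())
    return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)
```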
If we use predictor to do inference on the SavedModel, we simply extract the values from the returned dictionary using the output node names. We change the for nb in my_service(): loop in serve.py accordingly, as sketched below.
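With the hypothetical feature key 'number' from the earlier sketches (the real keys must match serve.py and the exported signature), the change would look roughly like this:

```python
# Before: only the default output, exposed under the key 'output'.
for nb in my_service():
    prediction = predict_fn({'number': [nb]})['output']

# After: extract each output by the names defined in model_fn's predictions dict.
for nb in my_service():
    result = predict_fn({'number': [nb]})
    hidden_values = result['hidden']
    prediction = result['predictions']
```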
Conclusions
Inference using SavedModel is a better inference protocol than Estimator-based predict.
Final Remarks
The internal implementation of the predictor class from tf.contrib consists of a bunch of input nodes, output nodes, and sessions to obtain the values from inference. However, in TensorFlow 2.0 there will be no tf.contrib, and the TensorFlow session will not be exposed to users. Doing inference using Estimator's predict method with a living graph will still work, but we will not be able to use predictor for the SavedModel anymore. Fortunately, TensorFlow 2.0 has an official tutorial on this which is simple and straightforward. I will probably elaborate on this when TensorFlow 2.0 comes out officially and it becomes necessary.