TensorFlow Inference for Estimator

Introduction

Although I am not a big fan of the high-level Estimator API in TensorFlow, more and more models now use Estimator for training, evaluation, and inference. Because everything is wrapped up, the machine learning process becomes less transparent. This can cause a lot of trouble when engineering a fast inference process, especially for people who are not familiar with the high-level APIs.

In this blog post, I am going to provide comprehensive guidance on how to set up fast inference protocols for TensorFlow models based on Estimator.

Repository

The sample code for this tutorial was forked from Guillaume Genthial’s tf-estimator-basics, with some modifications. All the tests were conducted using an NVIDIA RTX 2080 Ti graphics card.

For more details about the model, please check Guillaume Genthial’s blog post.

Before running the inference tests, please train the model by running the following command in the terminal.

$ python train.py

Please also export the model to SavedModel by running the following command in the terminal.

$ python export.py

I have also provided the pre-trained ckpt model and SavedModel in the GitHub repository.

Fast Inference Protocols

TensorFlow Estimator uses the predict method to do inference. The predict method takes an input_fn which, when called, returns inputs for the model, typically drawn from a generator. Without orchestration, if new data comes in batches, we have to create an input_fn for each new batch and call the predict method again. The predict method returns a generator that yields the prediction for each input produced by input_fn.

The problem is that TensorFlow creates a graph and loads all the parameters of the model each time predict is called. Once input_fn raises an end-of-input exception during that call, TensorFlow destroys the graph and releases the memory held by the parameters. This setup and teardown takes a very long time. If we create an input_fn for each batch of new data during inference, the overhead makes inference extremely slow.
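The cost of calling predict once per batch can be illustrated without TensorFlow. The sketch below uses a hypothetical ToyEstimator whose predict mimics only the lifecycle described above (set up when called, torn down when the input runs out). The class, its setup counter, and the doubling "model" are assumptions made for illustration and are not part of the Estimator API.

```python
class ToyEstimator:
    """Hypothetical stand-in that mimics only the predict() lifecycle."""

    def __init__(self):
        self.setup_count = 0  # how many times the "graph" was built

    def predict(self, input_fn):
        # Generator function: the body runs on the first next() call.
        self.setup_count += 1          # models building the graph
        for features in input_fn():    # ends when the input is exhausted
            yield 2 * features         # toy "model": double the input
        # Falling out of the loop models destroying the graph.


def make_input_fn(batch):
    """Return an input_fn that yields the examples of one finite batch."""
    def input_fn():
        return iter(batch)
    return input_fn


estimator = ToyEstimator()
batches = [[1, 2], [3, 4], [5, 6]]

results = []
for batch in batches:
    # A new predict() call per batch repeats the expensive setup.
    results.extend(estimator.predict(make_input_fn(batch)))

print(results)                # [2, 4, 6, 8, 10, 12]
print(estimator.setup_count)  # 3, one setup per batch
```

In the toy version the setup is just a counter increment, but in a real Estimator each of those setups stands for building the graph and loading the parameters.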

There are generally two ways to make inference with Estimator-based models faster: using predict while keeping the graph alive all the time, and exporting the Estimator-based model to SavedModel and serving it.

Keeping Graph Alive

As I mentioned previously, the graph is destroyed when input_fn raises an end-of-input exception during a predict call. If input_fn draws from an indefinite generator, it never raises that exception, so the graph stays alive all the time. Designing such an indefinite generator is therefore the key step.
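The idea can be sketched in plain Python. The example below is a minimal illustration, not the code from the repository: ToyEstimator is a hypothetical stand-in that only counts how often predict has to set up, and KeepAlive is an assumed name for a wrapper that feeds it from a generator that never ends.

```python
class ToyEstimator:
    """Hypothetical stand-in that mimics only the predict() lifecycle."""

    def __init__(self):
        self.setup_count = 0

    def predict(self, input_fn):
        self.setup_count += 1          # models building the graph
        for features in input_fn():
            yield 2 * features         # toy "model": double the input


class KeepAlive:
    """Feed one example at a time through a single predict() call."""

    def __init__(self, estimator):
        self.estimator = estimator
        self.next_example = None
        self.predictions = None        # created lazily on first use

    def _endless(self):
        # Never raises StopIteration, so predict() never reaches its end.
        while True:
            yield self.next_example

    def predict(self, example):
        self.next_example = example
        if self.predictions is None:
            # The only predict() call made during the wrapper's lifetime.
            self.predictions = self.estimator.predict(self._endless)
        return next(self.predictions)


estimator = ToyEstimator()
keep_alive = KeepAlive(estimator)
outputs = [keep_alive.predict(x) for x in [1, 2, 3]]

print(outputs)                # [2, 4, 6]
print(estimator.setup_count)  # 1, the setup ran only once
```

Each call stores the new example, then pulls exactly one prediction, so the generator is resumed only when there is fresh input for it to yield.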

I have tested the vanilla Estimator predict by running predict.py in the repository. It takes 0.1152 seconds per example using a batch size of 1, which is extremely slow.

Marc Stogaitis has implemented FastPredict, a wrapper for the predict method of Estimator that uses an indefinite generator. I applied his wrapper to the same model and tested it by running fast_predict.py. It takes 0.352 milliseconds per example using a batch size of 1, which is extremely fast. However, the shortcoming of his interface is that it accepts exactly one example at a time.

I modified Marc Stogaitis’s interface so that multiple examples can be fed in each call, although inference is still done using a batch size of 1. I tested it by running faster_predict.py. It takes 0.238 milliseconds per example using a batch size of 1, which is about 30% faster than Marc Stogaitis’s implementation, for reasons I have not investigated.
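One way to accept several examples per call while still running them one at a time is to queue them in front of the indefinite generator. The sketch below is an illustration under assumptions, not the contents of faster_predict.py: ToyEstimator is a hypothetical stand-in for an Estimator, and MultiExamplePredictor is an assumed name.

```python
from collections import deque


class ToyEstimator:
    """Hypothetical stand-in that mimics only the predict() lifecycle."""

    def __init__(self):
        self.setup_count = 0

    def predict(self, input_fn):
        self.setup_count += 1          # models building the graph
        for features in input_fn():
            yield 2 * features         # toy "model": double the input


class MultiExamplePredictor:
    """Accept a list of examples per call; run them one by one."""

    def __init__(self, estimator):
        self.estimator = estimator
        self.pending = deque()         # examples waiting to be consumed
        self.predictions = None

    def _endless(self):
        # Pulled once per requested prediction, so the queue is
        # non-empty whenever this generator is resumed.
        while True:
            yield self.pending.popleft()

    def predict(self, examples):
        self.pending.extend(examples)
        if self.predictions is None:
            self.predictions = self.estimator.predict(self._endless)
        # Pull exactly as many predictions as examples were queued.
        return [next(self.predictions) for _ in range(len(examples))]


estimator = ToyEstimator()
predictor = MultiExamplePredictor(estimator)

first = predictor.predict([1, 2, 3])
second = predictor.predict([10])

print(first, second)          # [2, 4, 6] [20]
print(estimator.setup_count)  # 1
```

This relies on the stand-in pulling inputs lazily. A real input pipeline may request elements ahead of time, in which case the generator would need to block until data arrives rather than pop from an empty queue.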

Inference on SavedModel

In his blog post, Guillaume Genthial has described exporting the model to SavedModel and doing inference on it using predictor from tf.contrib, so I am not going to elaborate much on it. I tested it by running serve.py. It takes 0.178 milliseconds per example using a batch size of 1, which is about 25% less time per example than my Estimator predict solution.
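For orientation, loading an exported model with predictor might look roughly like the sketch below. It assumes TensorFlow 1.x (where tf.contrib exists) and relies on the Estimator export writing each SavedModel into a timestamped subdirectory. The helper names and the 'saved_model' directory in the usage note are assumptions for illustration; serve.py in the repository is the reference.

```python
from pathlib import Path


def latest_export_dir(base_dir):
    """Return the most recent timestamped export under base_dir."""
    subdirs = [p for p in Path(base_dir).iterdir()
               if p.is_dir() and 'temp' not in p.name]
    if not subdirs:
        raise FileNotFoundError('no SavedModel export in %s' % base_dir)
    # Export directories are named by timestamp, so the last one
    # in sorted order is the newest.
    return str(sorted(subdirs, key=lambda p: p.name)[-1])


def load_predict_fn(base_dir):
    """Load the newest SavedModel as a callable (needs TensorFlow 1.x)."""
    # Imported lazily so latest_export_dir works without TensorFlow.
    from tensorflow.contrib import predictor
    return predictor.from_saved_model(latest_export_dir(base_dir))
```

With a model exported to an assumed 'saved_model' directory, usage would be predict_fn = load_predict_fn('saved_model') followed by predict_fn({'number': [[1.0]]})['output'], where 'number' and 'output' are the input and output keys used in serve.py. The graph and session are created once when the predictor is loaded, and every later call reuses them.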

Changing Prediction Tensors

Sometimes you would like to change the output tensors from the original settings, for example to extract some hidden layer tensors. You can do this by changing the model_fn function passed to Estimator when loading the model:

estimator = tf.estimator.Estimator(model_fn=model_fn, config=run_config)

The output node of the estimator built with the following model_fn is the predictions tensor, and its name in the graph is 'output' by default.

import tensorflow as tf  # TensorFlow 1.x


def model_fn(features, labels, mode, params):
    # pylint: disable=unused-argument
    """Dummy model_fn"""
    if isinstance(features, dict):  # For serving
        features = features['feature']

    hidden = tf.layers.dense(features, 4)
    predictions = tf.layers.dense(hidden, 1)

    if mode == tf.estimator.ModeKeys.PREDICT:
        # A single tensor is exported under the default name 'output'.
        return tf.estimator.EstimatorSpec(mode, predictions=predictions)
    else:
        loss = tf.nn.l2_loss(predictions - labels)
        if mode == tf.estimator.ModeKeys.EVAL:
            return tf.estimator.EstimatorSpec(
                mode, loss=loss)

        elif mode == tf.estimator.ModeKeys.TRAIN:
            train_op = tf.train.AdamOptimizer(learning_rate=0.5).minimize(
                loss, global_step=tf.train.get_global_step())
            return tf.estimator.EstimatorSpec(
                mode, loss=loss, train_op=train_op)
        else:
            raise NotImplementedError()

To add more output nodes, we pass a dictionary, {'hidden': hidden, 'predictions': predictions}, as predictions. The output nodes in the graph are then named 'hidden' and 'predictions', respectively.

import tensorflow as tf  # TensorFlow 1.x


def model_fn(features, labels, mode, params):
    # pylint: disable=unused-argument
    """Dummy model_fn"""
    if isinstance(features, dict):  # For serving
        features = features['feature']

    hidden = tf.layers.dense(features, 4)
    predictions = tf.layers.dense(hidden, 1)

    if mode == tf.estimator.ModeKeys.PREDICT:
        # Each dictionary key becomes the name of an output node.
        return tf.estimator.EstimatorSpec(
            mode, predictions={'hidden': hidden, 'predictions': predictions})
    else:
        loss = tf.nn.l2_loss(predictions - labels)
        if mode == tf.estimator.ModeKeys.EVAL:
            return tf.estimator.EstimatorSpec(
                mode, loss=loss)

        elif mode == tf.estimator.ModeKeys.TRAIN:
            train_op = tf.train.AdamOptimizer(learning_rate=0.5).minimize(
                loss, global_step=tf.train.get_global_step())
            return tf.estimator.EstimatorSpec(
                mode, loss=loss, train_op=train_op)
        else:
            raise NotImplementedError()
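When predictions is a dictionary, Estimator's predict method (with its default yield_single_examples=True) yields one dictionary per example instead of one per batch. The plain-Python sketch below only illustrates that unpacking. The function name is hypothetical and the numbers are made-up placeholders, not outputs of the model.

```python
def split_into_examples(batched_predictions):
    """Split a dict of per-batch values into one dict per example."""
    keys = list(batched_predictions)
    batch_size = len(batched_predictions[keys[0]])
    for i in range(batch_size):
        # Take row i of every output tensor.
        yield {key: batched_predictions[key][i] for key in keys}


# Placeholder values for a batch of two examples.
batch = {
    'hidden': [[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8]],
    'predictions': [[1.0], [2.0]],
}

examples = list(split_into_examples(batch))
print(examples[0]['hidden'])       # [0.1, 0.2, 0.3, 0.4]
print(examples[1]['predictions'])  # [2.0]
```

Code that consumes the predict generator therefore indexes each yielded item by output name, for example item['hidden'].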

If we use predictor to do inference on the SavedModel, we simply extract the values from the returned dictionary using the output node names.

We change serve.py from

for nb in my_service():
    count += 1
    pred = predict_fn({'number': [[nb]]})['output']

to

for nb in my_service():
    count += 1
    pred = predict_fn({'number': [[nb]]})
    hidden = pred['hidden']
    predictions = pred['predictions']

Conclusions

Inference on SavedModel is the better inference protocol: in these tests it was faster than Estimator-based predict, even with the graph kept alive.
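The per-example timings reported in this post can be put side by side with a few lines of arithmetic. The snippet below only restates those measured numbers and computes ratios from them; the helper name is an assumption.

```python
def percent_less_time(new_ms, old_ms):
    """Relative reduction in per-example latency, in percent."""
    return 100.0 * (1.0 - new_ms / old_ms)


# Per-example latency in milliseconds at batch size 1, as reported above.
latency_ms = [
    ('predict.py (vanilla predict)', 115.2),
    ('fast_predict.py (FastPredict)', 0.352),
    ('faster_predict.py (multi-example feed)', 0.238),
    ('serve.py (SavedModel predictor)', 0.178),
]

baseline = latency_ms[0][1]
previous = None
for name, ms in latency_ms:
    line = '%-40s %8.3f ms %7.1fx vs vanilla' % (name, ms, baseline / ms)
    if previous is not None:
        line += '  %5.1f%% less time than previous' % percent_less_time(
            ms, previous)
    print(line)
    previous = ms
```

The dominant gain comes from no longer rebuilding the graph on every call; the differences among the three keep-alive style approaches are comparatively small.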

Final Remarks

Internally, the predictor class from tf.contrib consists of a set of input nodes, output nodes, and a session used to obtain the values from inference. However, in TensorFlow 2.0 there will be no tf.contrib, and the TensorFlow session will not be exposed to users. Doing inference using Estimator's predict method with a living graph would still work, but we will no longer be able to use predictor for the SavedModel. Fortunately, TensorFlow 2.0 has an official tutorial on this which is simple and straightforward. I will probably elaborate on this when TensorFlow 2.0 is officially released, if it proves necessary.

Author

Lei Mao

Posted on

08-29-2019

Updated on

08-29-2019
