# TensorFlow Inference for Estimator

## Introduction

Although I am not a big fan of the high-level `Estimator` API in TensorFlow, more and more models now use `Estimator` for training, evaluation, and inference. Because everything is wrapped up, the machine learning process becomes less transparent, which can make it hard to engineer fast inference, especially for people who are not familiar with the high-level APIs.

In this blog post, I am going to provide comprehensive guidance on how to set up fast inference protocols for TensorFlow models based on `Estimator`.

## Repository

The sample code for this tutorial was forked from Guillaume Genthial’s tf-estimator-basics, with some modifications. All the tests were conducted on an NVIDIA RTX 2080 Ti graphics card.

For more details about the model, please check Guillaume Genthial’s blog post.

Before running the inference tests, please train the model by running the following command in the terminal:

```bash
$ python train.py
```

Please also export the model to `SavedModel` by running the following command in the terminal:

```bash
$ python export.py
```

I have also provided the pre-trained `ckpt` model and the `SavedModel` in the GitHub repository.

## Fast Inference Protocols

TensorFlow `Estimator` uses the `predict` method to do inference. The `predict` method takes an `input_fn`, which, when called, feeds input from a generator to the model. Without orchestration, if new data arrives in batches, we would have to create an `input_fn` for each batch of new data and run the `predict` method on it. The `predict` method returns a generator that yields the predictions corresponding to the inputs produced by `input_fn`.

The problem is that TensorFlow creates a graph and loads all the parameters of the model when the generator returned by `predict` starts producing predictions. Once `input_fn` raises an end-of-input exception, TensorFlow destroys the graph and releases the memory for all the parameters. This setup and teardown is very expensive. During inference, if we create an `input_fn` for each batch of new data, this overhead is paid for every batch and makes inference extremely slow.
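To make the cost concrete, here is a toy simulation in plain Python (no TensorFlow). `ToyEstimator` is a hypothetical stand-in that mimics only the life cycle of `Estimator.predict`: the "graph" is built (counted in `loads`) when iteration starts, and the generator ends when the data from `input_fn` runs out. Creating a fresh `input_fn` per batch pays the setup once per batch.

```python
class ToyEstimator:
    """Hypothetical stand-in for tf.estimator.Estimator (no TensorFlow needed).

    Only mimics the life cycle of predict(): setup happens once per predict()
    generator, and the generator ends when input_fn's data runs out.
    """

    def __init__(self):
        self.loads = 0  # how many times the "graph" was built and restored

    def predict(self, input_fn):
        self.loads += 1  # stands in for graph construction + checkpoint restore
        for x in input_fn():
            yield 2 * x  # stands in for the model's forward pass


def naive_predict(estimator, batches):
    """Create a new input_fn for every batch, as in the naive approach."""
    results = []
    for batch in batches:
        # The default argument binds the current batch. The iterator is
        # finite, so every predict() call sets up and tears down again.
        input_fn = lambda b=batch: iter(b)
        results.extend(estimator.predict(input_fn=input_fn))
    return results


estimator = ToyEstimator()
print(naive_predict(estimator, [[1, 2], [3], [4, 5]]))  # [2, 4, 6, 8, 10]
print(estimator.loads)  # 3: one setup per batch
```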

There are generally two ways to make inference with `Estimator`-based models faster: calling `predict` while keeping the graph alive the whole time, and converting the `Estimator`-based model to a `SavedModel` and serving that.

### Keeping Graph Alive

As mentioned previously, the graph is destroyed when `input_fn` raises an end-of-input exception. If `input_fn` uses an indefinite generator, it never raises that exception, so the graph stays alive the whole time. Designing such an indefinite generator is therefore the key.
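A minimal version of the trick, again simulated in plain Python: the hypothetical `ToyEstimator` and `Slot` are stand-ins, not TensorFlow classes. The generator reads whatever example was most recently placed in a shared slot and never terminates, so a single `predict` generator serves every request. The sketch assumes the pipeline pulls exactly one input per requested prediction, with no prefetching or read-ahead; the toy guarantees this by construction, but a real input pipeline might not if buffering is added.

```python
class ToyEstimator:
    """Hypothetical stand-in for tf.estimator.Estimator (no TensorFlow needed)."""

    def __init__(self):
        self.loads = 0

    def predict(self, input_fn):
        self.loads += 1  # "graph" built once per predict() generator
        for x in input_fn():
            yield 2 * x


class Slot:
    """Holds the next example; the caller overwrites it before each request."""

    def __init__(self):
        self.value = None


def indefinite_generator(slot):
    # Never returns, so input_fn never signals end-of-input
    # and the predict() generator (and its graph) stays alive.
    while True:
        yield slot.value


estimator = ToyEstimator()
slot = Slot()
predictions = estimator.predict(input_fn=lambda: indefinite_generator(slot))

outputs = []
for example in [1, 2, 3, 4]:
    slot.value = example               # feed the new example
    outputs.append(next(predictions))  # pull exactly one prediction
print(outputs)          # [2, 4, 6, 8]
print(estimator.loads)  # 1: the "graph" was built only once
```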

I have tested the vanilla `Estimator` `predict` by running `predict.py` in the repository. It takes 0.1152 seconds per example using a batch size of 1, which is extremely slow.
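One hypothetical way to obtain a per-example latency figure like this is sketched below; I am not claiming this is how `predict.py` measures it. `ms_per_example` and its `predict_one` argument are illustrative names: `predict_one` is any callable that runs inference on a single example.

```python
import time


def ms_per_example(predict_one, examples, warmup=1):
    """Average wall-clock milliseconds per example at batch size 1.

    The first `warmup` calls are run but not timed, so one-off
    initialization is not averaged into the result.
    """
    examples = list(examples)
    for x in examples[:warmup]:
        predict_one(x)
    timed = examples[warmup:]
    if not timed:
        raise ValueError("need more examples than warmup calls")
    start = time.perf_counter()
    for x in timed:
        predict_one(x)
    elapsed = time.perf_counter() - start
    return 1000.0 * elapsed / len(timed)


print(f"{ms_per_example(lambda x: x * 2, range(1000)):.6f} ms per example")
```

For the vanilla approach, each `predict_one` call would build and tear down the graph, so the setup cost shows up in every timed call, not just the warmup.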

Marc Stogaitis has implemented FastPredict, a wrapper around the `predict` method of `Estimator` that uses an indefinite generator. I applied his wrapper to the same model and tested it by running `fast_predict.py`. It takes 0.352 milliseconds per example using a batch size of 1, which is extremely fast. However, the shortcoming of his interface is that it only accepts exactly one example per call.

I modified Marc Stogaitis’s implementation so that it accepts multiple examples per call, although inference is still done using a batch size of 1. I tested it by running `faster_predict.py`. It takes 0.238 milliseconds per example using a batch size of 1, which, for reasons I have not pinned down, is about 30% less time per example than Marc Stogaitis’s implementation.
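The multi-example interface can be sketched as a small wrapper class. `MultiExamplePredictor` and `ToyEstimator` are hypothetical names, and this is neither Marc Stogaitis’s code nor the code in `faster_predict.py`; the toy estimator only mimics the life cycle of `Estimator.predict`. The wrapper queues the incoming examples, lazily creates one `predict` generator over an indefinite generator, and pulls one prediction per queued example. With a real `Estimator`, `input_fn` would have to wrap the generator in a dataset (for example with `tf.data.Dataset.from_generator`), which this sketch does not show.

```python
from collections import deque


class ToyEstimator:
    """Hypothetical stand-in for tf.estimator.Estimator (no TensorFlow needed)."""

    def __init__(self):
        self.loads = 0

    def predict(self, input_fn):
        self.loads += 1  # "graph" built once per predict() generator
        for x in input_fn():
            yield 2 * x


class MultiExamplePredictor:
    """Keeps one predict() generator alive and accepts a list per call.

    Examples still go through the model one at a time (batch size 1);
    only the calling interface accepts several.
    """

    def __init__(self, estimator):
        self._estimator = estimator
        self._pending = deque()
        self._predictions = None  # created lazily on the first call

    def _generator(self):
        # Indefinite: it is only asked for an item when one is pending.
        while True:
            yield self._pending.popleft()

    def predict(self, examples):
        examples = list(examples)
        self._pending.extend(examples)
        if self._predictions is None:
            self._predictions = self._estimator.predict(input_fn=self._generator)
        # One next() per queued example, so the queue never runs dry.
        return [next(self._predictions) for _ in examples]


estimator = ToyEstimator()
predictor = MultiExamplePredictor(estimator)
print(predictor.predict([1, 2, 3]))  # [2, 4, 6]
print(predictor.predict([10]))       # [20]
print(estimator.loads)               # 1
```

Because the queue is only read when a prediction is requested, an empty call returns an empty list without touching the model.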

### Inference on SavedModel

In his blog post, Guillaume Genthial discussed exporting the model to a `SavedModel` and running inference on it using `predictor` from `tf.contrib`, so I will not elaborate too much on it here. I tested it by running `serve.py`. It takes 0.178 milliseconds per example using a batch size of 1, which is about 25% less time per example (roughly 1.3x the throughput) than my `Estimator` `predict` solution.
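Because "X% faster" is ambiguous, the short calculation below derives both the reduction in time per example and the throughput ratio from the latencies measured above, converted to milliseconds. The labels are just descriptive names for the four protocols.

```python
# Per-example latency at batch size 1, as reported above (milliseconds).
latency_ms = {
    "Estimator.predict": 115.2,
    "FastPredict": 0.352,
    "multi-example wrapper": 0.238,
    "SavedModel predictor": 0.178,
}


def compare(slow, fast):
    """Return (fractional latency reduction, throughput ratio) of fast vs slow."""
    a, b = latency_ms[slow], latency_ms[fast]
    return (a - b) / a, a / b


for slow, fast in [("FastPredict", "multi-example wrapper"),
                   ("multi-example wrapper", "SavedModel predictor")]:
    reduction, ratio = compare(slow, fast)
    print(f"{fast} vs {slow}: {reduction:.0%} less time, {ratio:.2f}x throughput")
# multi-example wrapper vs FastPredict: 32% less time, 1.48x throughput
# SavedModel predictor vs multi-example wrapper: 25% less time, 1.34x throughput
```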

### Changing Prediction Tensors

Sometimes you may want to change the default output tensors, for example to extract some hidden-layer tensors. You can do this by changing the `model_fn` function passed to `Estimator` when loading the model:

```python
estimator = tf.estimator.Estimator(model_fn=model_fn, config=run_config)
```

For an `estimator` built with the following `model_fn`, the output node is the `predictions` tensor, and its output name is `'output'` by default.

```python
def model_fn(features, labels, mode, params):
    ...
```

To add more output nodes, we pass the dictionary `{'hidden': hidden, 'predictions': predictions}` as the predictions instead. The output nodes are then named `'hidden'` and `'predictions'`, respectively.

```python
def model_fn(features, labels, mode, params):
    ...
```
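The naming rule can be restated as a few lines of plain Python. This is a toy model, not TensorFlow code, and `export_output_keys` is a hypothetical name: my understanding is that when the predictions are a single tensor, the exported prediction signature wraps it under the default key `'output'`, and when they are a dictionary, its keys become the output names.

```python
def export_output_keys(predictions):
    """Toy model of how prediction outputs are named in the export.

    A single (non-dict) prediction is wrapped under the default key 'output';
    a dict keeps its own keys.
    """
    if isinstance(predictions, dict):
        return dict(predictions)
    return {"output": predictions}


print(sorted(export_output_keys("predictions_tensor")))  # ['output']
print(sorted(export_output_keys({"hidden": "h", "predictions": "p"})))  # ['hidden', 'predictions']
```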

If we use `predictor` to do inference on the `SavedModel`, we simply extract the values from the returned dictionary using the output node names.

We change `serve.py` from

```python
for nb in my_service():
    ...
```

to

```python
for nb in my_service():
    ...
```
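Only the first line of each loop survives here, so the following is a hedged sketch of the change rather than the original `serve.py`. `fake_predict_fn` is a hypothetical stand-in for the callable returned by `tf.contrib.predictor.from_saved_model`: like the real one, it takes a dictionary of inputs and returns a dictionary of outputs keyed by output name. The input key `'x'` and the toy arithmetic are assumptions.

```python
def fake_predict_fn(inputs):
    """Hypothetical stand-in for a SavedModel predictor (no TensorFlow needed)."""
    x = inputs["x"]
    hidden = [v + 1 for v in x]            # pretend hidden-layer activations
    predictions = [2 * h for h in hidden]  # pretend model output
    return {"hidden": hidden, "predictions": predictions}


def serve(numbers):
    results = []
    for nb in numbers:
        out = fake_predict_fn({"x": [nb]})
        # Before the change there was a single entry, out["output"];
        # now each requested tensor is looked up by its output name.
        results.append((out["hidden"][0], out["predictions"][0]))
    return results


print(serve([1, 2, 3]))  # [(2, 4), (3, 6), (4, 8)]
```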

## Conclusions

In these tests, inference using `SavedModel` was a better inference protocol than `Estimator`-based `predict`.

## Final Remarks

The internal implementation of the `predictor` class from `tf.contrib` consists of a set of input nodes, output nodes, and a session used to obtain the values from inference. However, TensorFlow 2.0 will have no `tf.contrib`, and TensorFlow `Session` will not be exposed to users. Doing inference with `Estimator`’s `predict` method on a living graph would still work, but we will no longer be able to use `predictor` for the `SavedModel`. Fortunately, TensorFlow 2.0 has an official tutorial on this, which is simple and straightforward. I will probably elaborate on this once TensorFlow 2.0 is officially released, if it becomes necessary.

## References

TensorFlow Inference for Estimator

https://leimao.github.io/blog/TensorFlow-Estimator-SavedModel/