
Lei Mao

Machine Learning, Artificial Intelligence, Computer Science.


Introduction

Frozen graphs are commonly used for inference in TensorFlow and serve as a stepping stone for running inference in other frameworks. TensorFlow 1.x provided an interface to freeze models via tf.Session, and I previously wrote a blog post on how to use frozen models for inference in TensorFlow 1.x. However, since TensorFlow 2.x removed tf.Session, freezing models in TensorFlow 2.x has been a problem for many users.


In this blog post, I am going to show how to save, load, and run inference for frozen graphs in TensorFlow 2.x.

Materials

The sample code is available on my GitHub. It was modified from the official TensorFlow 2.x Fashion MNIST classification example.

Train Model and Export to Frozen Graph

We train a simple fully connected neural network to classify the Fashion MNIST data. For completeness, the model is saved as a SavedModel in the models directory. In addition, the model is frozen and saved as frozen_graph.pb in the frozen_models directory.

To train and export the model, please run the following command in the terminal.

$ python train.py

The printout also includes a reference prediction for a sample, computed with TensorFlow 2.x using the conventional (non-frozen) inference path.

Example prediction reference:
[3.9113933e-05 1.1972898e-07 5.2244545e-06 5.4371812e-06 6.1125693e-06
 1.1335548e-01 3.0090479e-05 2.8483599e-01 9.5160649e-04 6.0077089e-01]

The key to exporting the frozen graph is to convert the model to a ConcreteFunction, freeze the graph extracted from that ConcreteFunction by converting its variables to constants, and serialize the result to disk.

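    import tensorflow as tf
    from tensorflow.python.framework.convert_to_constants import (
        convert_variables_to_constants_v2)

    # Convert Keras model to ConcreteFunction
    # (`model` is the trained Keras model from train.py)
    full_model = tf.function(lambda x: model(x))
    full_model = full_model.get_concrete_function(
        tf.TensorSpec(model.inputs[0].shape, model.inputs[0].dtype))

    # Get frozen ConcreteFunction (variable reads are replaced by constants)
    frozen_func = convert_variables_to_constants_v2(full_model)
    frozen_func.graph.as_graph_def()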

    layers = [op.name for op in frozen_func.graph.get_operations()]
    print("-" * 50)
    print("Frozen model layers: ")
    for layer in layers:
        print(layer)

    print("-" * 50)
    print("Frozen model inputs: ")
    print(frozen_func.inputs)
    print("Frozen model outputs: ")
    print(frozen_func.outputs)

    # Save frozen graph from frozen ConcreteFunction to hard drive
    tf.io.write_graph(graph_or_graph_def=frozen_func.graph,
                      logdir="./frozen_models",
                      name="frozen_graph.pb",
                      as_text=False)
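To see what freezing actually changes, here is a minimal sketch on a tiny, untrained stand-in model (an assumption for illustration; it is not the Fashion MNIST network from train.py), written to a temporary directory rather than ./frozen_models. Before freezing, the ConcreteFunction reads its weights through `ReadVariableOp` nodes on captured variables; after `convert_variables_to_constants_v2`, the weights are embedded as `Const` nodes, which is what makes the serialized GraphDef self-contained.

```python
import os
import tempfile

import tensorflow as tf
from tensorflow.python.framework.convert_to_constants import (
    convert_variables_to_constants_v2,
)

# Hypothetical stand-in model: untrained, 28x28 inputs, 10 softmax outputs.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(28, 28)),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(10, activation="softmax"),
])

# Trace the model into a ConcreteFunction with a fixed input signature.
concrete = tf.function(lambda x: model(x)).get_concrete_function(
    tf.TensorSpec([None, 28, 28], tf.float32))

# Freeze: replace variable reads with embedded constants.
frozen = convert_variables_to_constants_v2(concrete)

# Compare the op types present before and after freezing.
op_types_before = {op.type for op in concrete.graph.get_operations()}
op_types_after = {op.type for op in frozen.graph.get_operations()}
print("ReadVariableOp before freezing:", "ReadVariableOp" in op_types_before)
print("ReadVariableOp after freezing:", "ReadVariableOp" in op_types_after)

# Serialize the frozen graph as a binary GraphDef, then parse it back.
out_dir = tempfile.mkdtemp()
path = tf.io.write_graph(frozen.graph, out_dir, "frozen_graph.pb", as_text=False)
graph_def = tf.compat.v1.GraphDef()
with open(path, "rb") as f:
    graph_def.ParseFromString(f.read())
print("Saved", os.path.getsize(path), "bytes with", len(graph_def.node), "nodes")
```

The parsed GraphDef contains no variable reads at all, so it can be loaded without restoring any checkpoint.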

Run Inference Using Frozen Graph

To run inference using the frozen graph in TensorFlow 2.x, please run the following command in the terminal.

$ python test.py

We also get the prediction for the same sample using the frozen graph. It is almost exactly the same as the reference value obtained with the conventional inference path; the remaining differences are in the last few digits (e.g., 2.8483599e-01 vs. 2.8483596e-01).

Example prediction reference:
[3.9113860e-05 1.1972921e-07 5.2244545e-06 5.4371812e-06 6.1125752e-06
 1.1335552e-01 3.0090479e-05 2.8483596e-01 9.5160597e-04 6.0077089e-01]
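"Almost exactly the same" can be made concrete by comparing the two printed vectors numerically. This sketch simply copies the printed values into NumPy; the tolerances are NumPy's `np.allclose` defaults, which is a common but arbitrary choice.

```python
import numpy as np

# Values copied from the two printouts above.
reference = np.array([
    3.9113933e-05, 1.1972898e-07, 5.2244545e-06, 5.4371812e-06, 6.1125693e-06,
    1.1335548e-01, 3.0090479e-05, 2.8483599e-01, 9.5160649e-04, 6.0077089e-01,
])
frozen = np.array([
    3.9113860e-05, 1.1972921e-07, 5.2244545e-06, 5.4371812e-06, 6.1125752e-06,
    1.1335552e-01, 3.0090479e-05, 2.8483596e-01, 9.5160597e-04, 6.0077089e-01,
])

# Largest element-wise deviation and a tolerance-based equality check.
max_abs_diff = np.max(np.abs(reference - frozen))
print(f"max |diff| = {max_abs_diff:.1e}")
print("allclose:", np.allclose(reference, frozen))
print("same predicted class:", reference.argmax() == frozen.argmax())
```

The largest deviation is about 4e-8, consistent with float32 rounding differences, and both vectors predict class index 9 (Ankle boot in the Fashion MNIST labeling).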

Because TensorFlow has effectively deprecated frozen graphs in favor of the SavedModel format, we have to use TensorFlow 1.x functions from tf.compat.v1 to load the frozen graph from disk.

    import tensorflow as tf

    # Load frozen graph using TensorFlow 1.x functions
    with tf.io.gfile.GFile("./frozen_models/frozen_graph.pb", "rb") as f:
        graph_def = tf.compat.v1.GraphDef()
        graph_def.ParseFromString(f.read())

    # Wrap frozen graph to ConcreteFunctions
    frozen_func = wrap_frozen_graph(graph_def=graph_def,
                                    inputs=["x:0"],
                                    outputs=["Identity:0"],
                                    print_graph=True)

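Once the frozen graph is loaded, we wrap it into a ConcreteFunction and run inference with it. The helper below imports the GraphDef inside tf.compat.v1.wrap_function and prunes the wrapped graph down to the requested input and output tensors.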

    import tensorflow as tf


    def wrap_frozen_graph(graph_def, inputs, outputs, print_graph=False):
        # Import the GraphDef into a TF1-style graph wrapped as a function
        def _imports_graph_def():
            tf.compat.v1.import_graph_def(graph_def, name="")

        wrapped_import = tf.compat.v1.wrap_function(_imports_graph_def, [])
        import_graph = wrapped_import.graph

        print("-" * 50)
        print("Frozen model layers: ")
        layers = [op.name for op in import_graph.get_operations()]
        if print_graph:
            for layer in layers:
                print(layer)
        print("-" * 50)

        # Prune to a ConcreteFunction mapping the named inputs to the outputs
        return wrapped_import.prune(
            tf.nest.map_structure(import_graph.as_graph_element, inputs),
            tf.nest.map_structure(import_graph.as_graph_element, outputs))
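To check the whole round trip without the files produced by train.py, here is a self-contained sketch. It freezes a tiny, untrained stand-in model (hypothetical, not the Fashion MNIST network), writes it to a temporary directory, reloads it with the same tf.compat.v1 calls used above, wraps it with a trimmed copy of `wrap_frozen_graph` (printing removed), and compares the frozen-graph prediction with a direct Keras call. The input and output tensor names are read from the frozen function instead of hard-coding `x:0` and `Identity:0`, since generated names may vary.

```python
import tempfile

import numpy as np
import tensorflow as tf
from tensorflow.python.framework.convert_to_constants import (
    convert_variables_to_constants_v2,
)


def wrap_frozen_graph(graph_def, inputs, outputs):
    # Trimmed copy of the helper above, without the layer printing.
    def _imports_graph_def():
        tf.compat.v1.import_graph_def(graph_def, name="")

    wrapped_import = tf.compat.v1.wrap_function(_imports_graph_def, [])
    import_graph = wrapped_import.graph
    return wrapped_import.prune(
        tf.nest.map_structure(import_graph.as_graph_element, inputs),
        tf.nest.map_structure(import_graph.as_graph_element, outputs))


# Hypothetical stand-in model (untrained) with Fashion MNIST-shaped inputs.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(28, 28)),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(10, activation="softmax"),
])

# Export: trace, freeze, and write a binary GraphDef.
concrete = tf.function(lambda x: model(x)).get_concrete_function(
    tf.TensorSpec([None, 28, 28], tf.float32))
frozen = convert_variables_to_constants_v2(concrete)
input_name = frozen.inputs[0].name
output_name = frozen.outputs[0].name
path = tf.io.write_graph(frozen.graph, tempfile.mkdtemp(), "frozen_graph.pb",
                         as_text=False)

# Load: parse the GraphDef and wrap it as a pruned ConcreteFunction.
with tf.io.gfile.GFile(path, "rb") as f:
    graph_def = tf.compat.v1.GraphDef()
    graph_def.ParseFromString(f.read())
frozen_func = wrap_frozen_graph(graph_def, [input_name], [output_name])

# Inference: outputs were given as a list, so the result is a list of tensors.
sample = np.random.default_rng(0).random((2, 28, 28), dtype=np.float32)
frozen_pred = frozen_func(tf.constant(sample))[0].numpy()
keras_pred = np.asarray(model(sample))
print("inputs:", input_name, "outputs:", output_name)
print("max |diff|:", np.max(np.abs(frozen_pred - keras_pred)))
```

Because the frozen graph computes the same operations with the same (now constant) weights, the two predictions should agree up to floating-point noise.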

Convert Frozen Graph to ONNX

If TensorFlow 1.x and tf2onnx are installed, the frozen graph can be converted to an ONNX model using the following command.

$ python -m tf2onnx.convert --input ./frozen_models/frozen_graph.pb --output model.onnx --outputs Identity:0 --inputs x:0

Convert Frozen Graph to UFF

The frozen graph can also be converted to a UFF model for TensorRT using the following command.

$ convert-to-uff frozen_graph.pb -t -O Identity -o frozen_graph.uff

A TensorRT 6.0 Docker image can be pulled from NVIDIA NGC.

$ docker pull nvcr.io/nvidia/tensorrt:19.12-py3

Conclusions

TensorFlow 2.x can also save, load, and run inference with frozen graphs. Frozen graphs exported from TensorFlow 2.x should be equivalent to those exported from TensorFlow 1.x.
