Tutorial: Import an ONNX Model into TensorFlow for Inference

In the last tutorial, we trained a CNN model in PyTorch and converted it to an ONNX model. In this tutorial, we will import that model into TensorFlow and use it for inference.
Before proceeding, make sure that you have completed the previous tutorial, as this one builds directly on it.
Converting ONNX Model to TensorFlow Model
The output folder has an ONNX model, which we will convert into TensorFlow format.
The ONNX project provides a Python package, onnx-tf, that loads an ONNX model and saves it as a TensorFlow graph. Install it with pip:
    pip install onnx_tf
We are now ready for conversion. Create a Python program with the below code and run it:
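    import onnx
    from onnx_tf.backend import prepare

    # Load the ONNX model exported in the previous tutorial
    onnx_model = onnx.load("output/model.onnx")
    # Wrap it in a TensorFlow representation and write it out as a graph file
    tf_rep = prepare(onnx_model)
    tf_rep.export_graph("output/model.pb")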
The output folder now contains three models: PyTorch, ONNX, and TensorFlow.
We are now ready to use the model in TensorFlow. Note that the code below works only with TensorFlow 1.x. For this tutorial, we are using 1.15, the last release in the 1.x series.
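Because the program relies on TensorFlow 1.x APIs such as tf.Session and tf.GraphDef, it can help to fail early on the wrong version. The following is a minimal sketch rather than part of the original program: the helper name is_tf1 is hypothetical, and it assumes the version string has the usual major.minor.patch form.

```python
def is_tf1(version):
    """Return True if a version string such as '1.15.0' belongs to the 1.x series."""
    major = version.split(".")[0]
    return major == "1"


# Example with literal version strings; in the real script you would pass
# tf.__version__ instead.
for candidate in ("1.15.0", "2.4.1"):
    status = "supported" if is_tf1(candidate) else "not supported by this tutorial"
    print(candidate, status)
```

With a check like that in place, the inference program itself follows.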
    import tensorflow as tf
    import numpy as np
    import cv2

    # Suppress Python-level warnings and TensorFlow's C++ log output
    import logging, os
    logging.disable(logging.WARNING)
    os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

    # Tensor names are the node names plus the output index ":0"
    INPUT_TENSOR_NAME = 'input.1:0'
    OUTPUT_TENSOR_NAME = 'add_4:0'
    IMAGE_PATH = "0.png"
    PB_PATH = "output/model.pb"

    # Load the image, convert it to grayscale, and shape it as (1, 1, 28, 28)
    img = cv2.imread(IMAGE_PATH)
    img = np.dot(img[..., :3], [0.299, 0.587, 0.114])
    img = cv2.resize(img, dsize=(28, 28), interpolation=cv2.INTER_AREA)
    img = img.reshape((1, 1, 28, 28))

    # Read the serialized graph definition from the .pb file
    with tf.gfile.FastGFile(PB_PATH, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())

    # Import the graph definition into a fresh graph
    with tf.Graph().as_default() as graph:
        tf.import_graph_def(graph_def, name="")

    input_tensor = graph.get_tensor_by_name(INPUT_TENSOR_NAME)
    output_tensor = graph.get_tensor_by_name(OUTPUT_TENSOR_NAME)

    # Run the image through the graph
    with tf.Session(graph=graph) as sess:
        output_vals = sess.run(output_tensor, feed_dict={input_tensor: img})

    # The index of the largest output value is the predicted digit
    prediction = int(np.argmax(np.array(output_vals).squeeze(), axis=0))
    print(prediction)
We start by importing the required modules and then disable the warnings generated by TensorFlow.
The names of the input and output tensors can be found by opening the model.pb file in the Netron tool.
The name and shape of the input node (input.1) and the output node (add_4) are visible in Netron.
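Netron shows node names, while get_tensor_by_name expects tensor names of the form node_name:output_index, which is why the program uses 'input.1:0' rather than 'input.1'. As a small illustration (the helper tensor_name is hypothetical, not a TensorFlow function), the two constants can be derived like this:

```python
def tensor_name(node_name, output_index=0):
    """Build a TensorFlow tensor name from a node name and an output index."""
    return "{}:{}".format(node_name, output_index)


# Node names as read from Netron; index 0 selects each node's first output.
INPUT_TENSOR_NAME = tensor_name("input.1")
OUTPUT_TENSOR_NAME = tensor_name("add_4")
print(INPUT_TENSOR_NAME, OUTPUT_TENSOR_NAME)
```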
The next few lines of code preprocess the image through OpenCV. We then open the TensorFlow model and create a session based on the graph.
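The grayscale step in the preprocessing is a weighted sum of the three color channels with the weights 0.299, 0.587, and 0.114. The sketch below reproduces it for a single pixel in plain Python; the helper to_gray is hypothetical. Note that cv2.imread returns channels in BGR order, so the program applies the weights to the channels in that order. For a black-and-white digit image, where all three channels are equal, the order does not change the result.

```python
WEIGHTS = (0.299, 0.587, 0.114)  # same weights as the np.dot call in the program


def to_gray(pixel):
    """Collapse a 3-channel pixel into one intensity via a weighted sum."""
    return sum(channel * weight for channel, weight in zip(pixel, WEIGHTS))


white = to_gray((255, 255, 255))  # stays close to 255 because the weights sum to 1
black = to_gray((0, 0, 0))        # stays at zero
print(round(white), black)
```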
Finally, by applying the argmax function, we classify the output into one of the ten classes defined by MNIST.
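To make this last step concrete, here is a dependency-free sketch of what argmax does with the ten output scores. The scores are made up for illustration and are not output from the actual model; the local argmax helper stands in for np.argmax.

```python
def argmax(values):
    """Return the index of the largest value (the first one if there are ties)."""
    return max(range(len(values)), key=lambda i: values[i])


# Hypothetical scores for digits 0-9; real values come from sess.run(...)
scores = [-1.2, 0.3, 0.1, -0.5, 2.0, 7.5, 0.0, -3.1, 1.1, 0.4]
prediction = argmax(scores)
print(prediction)
```

The position of the largest score is the predicted digit, so the model never has to output a class label directly.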
In this tutorial, we imported an ONNX model into TensorFlow and used it for inference. In the next part, we will build a computer vision application that runs at the edge powered by Intel’s Movidius Neural Compute Stick. The model uses an ONNX Runtime execution provider optimized for the OpenVINO Toolkit. Stay tuned.
Janakiram MSV’s Webinar series, “Machine Intelligence and Modern Infrastructure (MI2)” offers informative and insightful sessions covering cutting-edge technologies. Sign up for the upcoming MI2 webinar at http://mi2.live.