Tutorial: Train a Deep Learning Model in PyTorch and Export It to ONNX

17 Jul 2020 8:51am, by
This post is the third in a series of introductory tutorials on the Open Neural Network Exchange (ONNX), an initiative from AWS, Microsoft, and Facebook to define a standard for interoperability across machine learning platforms. See: Part 1, Part 2.

In this tutorial, we will train a Convolutional Neural Network in PyTorch and convert it into an ONNX model. Once we have the model in ONNX format, we can import it into other frameworks such as TensorFlow, either for inference or to reuse the model through transfer learning.

Setting up the Environment

The only prerequisite for this tutorial is Python 3.x. Make sure it is installed on your machine.

Create a Python virtual environment that will be used for this and the next tutorial.

python3 -m virtualenv pyt2tf

source pyt2tf/bin/activate

Create a file, requirements.txt, that lists the modules needed for the tutorial.

Note that we are using TensorFlow 1.x for this tutorial. You may see errors if you install any version of TensorFlow above 1.15.

Install the modules from the above file with pip.

pip install -r requirements.txt
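Because this series depends on TensorFlow 1.x, it helps to confirm which version pip actually installed. The helper below is a minimal sketch that is not part of the original tutorial. The function name `is_supported_tensorflow` is hypothetical, and the function only encodes the 1.15 ceiling mentioned above.

```python
def is_supported_tensorflow(tf_version: str) -> bool:
    """Return True for TensorFlow 1.x releases up to and including 1.15 (e.g. "1.15.0")."""
    # Only the major and minor components matter; the patch number is ignored.
    major, minor = (int(part) for part in tf_version.split(".")[:2])
    return major == 1 and minor <= 15
```

For example, after `import tensorflow as tf`, calling `is_supported_tensorflow(tf.__version__)` should return `True` in a correctly set-up environment.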

Finally, create a directory to save the model.

mkdir output

Train a CNN with MNIST Dataset

Let’s start by importing the modules needed for the program.

We will then define the neural network with appropriate layers.
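The original listing is not reproduced here, so the class below is an assumption. It is a minimal sketch whose layer layout follows the official PyTorch MNIST example: two 3x3 convolutions, max pooling, dropout, and two fully connected layers. The name `Net` and the exact layer sizes are illustrative. The sketch includes the imports it relies on.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    """Small CNN for 28x28 grayscale MNIST digits (layout modeled on the PyTorch MNIST example)."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3)   # 1x28x28  -> 32x26x26
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3)  # 32x26x26 -> 64x24x24
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(64 * 12 * 12, 128)        # after 2x2 max-pool: 64x12x12 = 9216 features
        self.fc2 = nn.Linear(128, 10)                  # one output per digit class

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)                        # keep the batch dimension, flatten the rest
        x = F.relu(self.fc1(x))
        x = self.dropout2(x)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)                 # log-probabilities over the 10 classes
```

Returning log-probabilities from `forward` pairs naturally with `F.nll_loss` during training and evaluation.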

Create a method to train the PyTorch model.
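A training method typically loops over mini-batches, computes the loss, backpropagates, and steps the optimizer. The function below is a hedged sketch of that loop. It assumes a model that returns log-probabilities, so `F.nll_loss` is the matching loss. The signature and the `log_interval` parameter are illustrative rather than the tutorial's exact code.

```python
import torch
import torch.nn.functional as F


def train(model: torch.nn.Module, device: torch.device,
          train_loader: torch.utils.data.DataLoader,
          optimizer: torch.optim.Optimizer, epoch: int, log_interval: int = 100) -> float:
    """Run one training epoch and return the mean loss over all batches."""
    model.train()  # enable training-mode behaviour such as dropout
    total_loss = 0.0
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()              # clear gradients from the previous step
        output = model(data)               # log-probabilities, shape (batch, 10)
        loss = F.nll_loss(output, target)  # negative log-likelihood of the true labels
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        if batch_idx % log_interval == 0:
            seen = batch_idx * len(data)
            print(f"Train epoch {epoch} [{seen}/{len(train_loader.dataset)}] loss={loss.item():.4f}")
    return total_loss / len(train_loader)
```

Returning the mean loss for the epoch makes it easy to spot training that diverges or stalls.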

Next, create a method to test and evaluate the model.
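Evaluation mirrors training, but it switches the model to eval mode and disables gradient tracking. Below is a minimal sketch under the same log-probability assumption. It uses the hypothetical name `evaluate` and returns accuracy as a percentage.

```python
import torch
import torch.nn.functional as F


def evaluate(model: torch.nn.Module, device: torch.device,
             test_loader: torch.utils.data.DataLoader) -> float:
    """Compute average loss and accuracy on a held-out set; return accuracy in percent."""
    model.eval()  # disable dropout for deterministic predictions
    test_loss, correct = 0.0, 0
    with torch.no_grad():  # no gradients needed for evaluation
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            # Sum per-sample losses so the average is over samples, not batches.
            test_loss += F.nll_loss(output, target, reduction="sum").item()
            pred = output.argmax(dim=1)  # index of the highest log-probability
            correct += (pred == target).sum().item()
    n = len(test_loader.dataset)
    accuracy = 100.0 * correct / n
    print(f"Test set: average loss {test_loss / n:.4f}, accuracy {correct}/{n} ({accuracy:.2f}%)")
    return accuracy
```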

With the network architecture, train, and test methods in place, let’s create the main method to create an instance of the neural network and train it with the MNIST dataset.

Within the main method, we download the MNIST dataset, preprocess it, and train the model for 10 epochs.

If you are training the model on a beefy box with a powerful GPU, you can change the device variable and tweak the number of epochs to get better accuracy. For the MNIST dataset, however, you will hit ~98% accuracy with just 10 epochs running on the CPU.
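Switching between CPU and GPU usually comes down to choosing a `torch.device`. Here is a minimal sketch, assuming a hypothetical `select_device` helper and an `EPOCHS` constant set to the 10 epochs mentioned above.

```python
import torch

EPOCHS = 10  # enough for roughly 98% on MNIST per the tutorial; raise it on a GPU box if you like


def select_device(prefer_gpu: bool = True) -> torch.device:
    """Return a CUDA device when one is requested and available, otherwise the CPU."""
    if prefer_gpu and torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")


device = select_device()
```

Move the model with `model.to(device)` and move each batch to the same device inside the training loop. Otherwise PyTorch raises a device-mismatch error.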

Once the training is done, you will find the saved model file in the output directory. This is the artifact we need to convert the model into ONNX format.

Exporting PyTorch Model to ONNX Format

PyTorch supports ONNX natively, which means we can convert the model without installing an additional module.

Let’s load the trained model from the previous step, create an input that matches the shape of the input tensor, and export the model to ONNX.

The neural network class is included in the export code so that the model architecture is available when the trained weights are loaded. The dummy input must also match the shape of the network's input tensor.

Running the export code creates a model.onnx file, which contains the ONNX version of the deep learning model originally trained in PyTorch.

You can open this in the Netron tool to explore the layers and the architecture of the neural network.

In the next part of this tutorial, we will import the ONNX model into TensorFlow and use it for inference. Stay tuned.

Janakiram MSV’s Webinar series, “Machine Intelligence and Modern Infrastructure (MI2)” offers informative and insightful sessions covering cutting-edge technologies. Sign up for the upcoming MI2 webinar at

Feature image: “Taking in the Wheat Sheaves” via New Old Stock.