Open Neural Network Exchange (ONNX) is an open format built to represent machine learning models. The end product of training any machine learning or deep learning algorithm is a model file that maps input data to output predictions. These files are stored in different formats depending on the framework that created them: .pkl for Scikit-learn, .pb for TensorFlow, .pth for PyTorch, and so on. Therein lies the problem: a model created and trained in one framework cannot simply be used or deployed in a different framework.
The intent behind ONNX is to be the "USB standard" of the machine learning world. Before USB, computers and peripherals used ad hoc proprietary interfaces; in much the same way, machine learning models today come in ad hoc, framework-specific formats. ONNX overcomes this framework lock-in by providing a universal intermediate format that frameworks can save to and load from. Developers can therefore build models in the framework of their choice without worrying about the deployment environment. ONNX also makes it easier to take advantage of the hardware acceleration offered by different frameworks and runtime environments.
Installing ONNX and other required libraries
Requirement: Python 3.6 or higher
1. Install ONNX
pip: pip install onnx
Conda: conda install -c conda-forge onnx
2. Install TensorFlow and onnx-tensorflow
pip install tensorflow
pip install tensorflow-addons
git clone https://github.com/onnx/onnx-tensorflow.git && cd onnx-tensorflow && pip install -e .
3. Install PyTorch and torchvision (the PyTorch package on PyPI is named torch)
pip install torch
pip install torchvision
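Before moving on, it can help to confirm that the interpreter meets the version requirement and that each package is importable. The sketch below assumes the import names `onnx`, `onnx_tf`, `tensorflow`, `torch` and `torchvision`; `check_environment` is a hypothetical helper, not part of any of these libraries.

```python
import importlib.util
import sys

# Import names assumed for the packages installed above
REQUIRED_MODULES = ("onnx", "onnx_tf", "tensorflow", "torch", "torchvision")


def check_environment(modules=REQUIRED_MODULES, min_python=(3, 6)):
    """Return (python_ok, {module_name: is_importable}) without importing the modules."""
    python_ok = sys.version_info[:2] >= min_python
    # find_spec locates a top-level module without executing it; None means it is not installed
    found = {name: importlib.util.find_spec(name) is not None for name in modules}
    return python_ok, found


if __name__ == "__main__":
    ok, found = check_environment()
    print("Python version OK:", ok)
    for name, present in found.items():
        print(f"{name:12s} {'found' if present else 'MISSING'}")
```

Because `find_spec` only locates packages, the check is fast and does not trigger the (slow) TensorFlow or PyTorch import.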
Converting a PyTorch model to TensorFlow
- Import required libraries and classes
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
import onnx
from onnx_tf.backend import prepare
```
- Define a basic CNN model
```python
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)
```
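The `320` in `fc1` and in `x.view(-1, 320)` is not arbitrary: it is the number of values left after the two convolution and pooling stages for a 28×28 MNIST image. A quick way to check it is the usual output-size formula `(n + 2*padding - kernel) // stride + 1`, with convolutions using stride 1 and `max_pool2d(x, 2)` using a 2×2 window with stride 2. The helper names below are illustrative only.

```python
def out_size(n, kernel, stride=1, padding=0):
    """Spatial output size of a convolution or pooling layer (no dilation)."""
    return (n + 2 * padding - kernel) // stride + 1


def flattened_features(input_size=28):
    n = out_size(input_size, kernel=5)   # conv1: 28 -> 24
    n = out_size(n, kernel=2, stride=2)  # max_pool2d(., 2): 24 -> 12
    n = out_size(n, kernel=5)            # conv2: 12 -> 8
    n = out_size(n, kernel=2, stride=2)  # max_pool2d(., 2): 8 -> 4
    channels = 20                        # out_channels of conv2
    return channels * n * n              # 20 * 4 * 4 values per image


print(flattened_features())
```

Dropout does not change tensor shapes, so it does not enter the calculation. If the input resolution changed, say to 32×32, the same arithmetic gives 20 × 5 × 5 = 500, and both `fc1` and the `view` call would have to be updated to match.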
- Create the train and test methods
```python
def train(model, device, train_loader, optimizer, epoch):
    model.train()  # enable dropout
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        # With batch_size=64 an epoch has 938 batches, so this logs once per epoch
        if batch_idx % 1000 == 0:
            print('Train Epoch: {} \tLoss: {:.6f}'.format(epoch, loss.item()))


def test(model, device, test_loader):
    model.eval()  # disable dropout
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.max(1, keepdim=True)[1]  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()
    test_loss /= len(test_loader.dataset)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))
```
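Because `forward` ends in `log_softmax`, the network returns log-probabilities, and `F.nll_loss` simply takes the negative log-probability assigned to the correct class (averaged over the batch in `train`, summed in `test`). The NumPy sketch below mirrors that computation and the accuracy count to make the math concrete; it is an illustration, not PyTorch's implementation.

```python
import numpy as np


def log_softmax(logits):
    """Row-wise log-softmax; subtracting the row max keeps exp() from overflowing."""
    shifted = logits - logits.max(axis=1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))


def nll_loss(log_probs, targets, reduction="mean"):
    """Negative log-likelihood of the target class for each row."""
    picked = -log_probs[np.arange(len(targets)), targets]
    return picked.sum() if reduction == "sum" else picked.mean()


def count_correct(log_probs, targets):
    """Same idea as output.max(1)[1] compared with the targets in test()."""
    return int((log_probs.argmax(axis=1) == targets).sum())
```

Summing per-sample losses and dividing by the dataset size at the end, as `test` does, gives an exact average even when the last batch is smaller than the others.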
- Download the train and test datasets, normalize them and create data loaders.
```python
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    batch_size=64, shuffle=True)

test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    batch_size=1000, shuffle=True)
```
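`ToTensor` scales 8-bit pixels from [0, 255] to [0.0, 1.0], and `Normalize((0.1307,), (0.3081,))` then subtracts 0.1307 and divides by 0.3081, the mean and standard deviation commonly used for MNIST pixel values. Any image fed to the trained or exported model later needs the same transformation. The NumPy function below is a hypothetical helper that reproduces it for one 28×28 grayscale image and adds the batch and channel axes (NCHW layout) the model expects.

```python
import numpy as np

MNIST_MEAN = 0.1307
MNIST_STD = 0.3081


def mnist_preprocess(pixels):
    """Mimic ToTensor() + Normalize() for one 28x28 uint8 image; returns shape (1, 1, 28, 28)."""
    x = np.asarray(pixels, dtype=np.float32) / 255.0  # ToTensor: [0, 255] -> [0, 1]
    x = (x - MNIST_MEAN) / MNIST_STD                     # Normalize: (x - mean) / std
    return x[np.newaxis, np.newaxis, :, :]               # add batch and channel axes (NCHW)
```

A black pixel (0) therefore maps to about -0.42 and a white pixel (255) to about 2.82, which is the input range the network sees during training.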
- Create the model, define the optimizer and train it
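```python
# Fall back to the CPU when no CUDA-capable GPU is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Net().to(device)
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)

for epoch in range(21):
    train(model, device, train_loader, optimizer, epoch)
    test(model, device, test_loader)
```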
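`optim.SGD(..., lr=0.01, momentum=0.5)` keeps a running velocity for each parameter. As described in the PyTorch documentation for the default settings (`dampening=0`, no Nesterov momentum, no weight decay), each step computes `v = momentum * v + grad` and then `p = p - lr * v`. The toy sketch below applies that rule to a one-dimensional quadratic, f(p) = (p - 3)², purely to illustrate the update; it is not PyTorch's internal code.

```python
def sgd_momentum(grad_fn, p0, lr=0.01, momentum=0.5, steps=2000):
    """Plain SGD with momentum in the PyTorch convention (dampening=0, no Nesterov)."""
    p, v = p0, 0.0
    for _ in range(steps):
        g = grad_fn(p)
        v = momentum * v + g  # velocity: decayed previous velocity plus current gradient
        p = p - lr * v        # parameter update
    return p


# Minimise f(p) = (p - 3)^2, whose gradient is 2 * (p - 3)
p_final = sgd_momentum(lambda p: 2.0 * (p - 3.0), p0=0.0)
print(round(p_final, 4))
```

With momentum 0.5 the velocity accumulates consistent gradients, so steps in a steady direction become up to about twice as large as plain SGD with the same learning rate.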
- Save the trained model
```python
torch.save(model.state_dict(), 'mnist.pth')
```
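`model.state_dict()` stores only the learned tensors, keyed by layer name, not the `Net` class itself, which is why `Net` has to be defined before the weights are loaded again. For this architecture the saved tensors should have the shapes below (the dropout layers have no parameters). The dictionary is written out by hand from the layer definitions, so treat it as a reference for sanity-checking `mnist.pth`, not as output captured from PyTorch.

```python
# Expected parameter shapes for Net, derived by hand from its layer definitions
EXPECTED_SHAPES = {
    "conv1.weight": (10, 1, 5, 5),  # out_channels, in_channels, kernel_h, kernel_w
    "conv1.bias": (10,),
    "conv2.weight": (20, 10, 5, 5),
    "conv2.bias": (20,),
    "fc1.weight": (50, 320),        # out_features, in_features
    "fc1.bias": (50,),
    "fc2.weight": (10, 50),
    "fc2.bias": (10,),
}


def num_elements(shape):
    """Number of scalar values in a tensor of the given shape."""
    count = 1
    for dim in shape:
        count *= dim
    return count


total_params = sum(num_elements(s) for s in EXPECTED_SHAPES.values())
print(total_params)
```

With PyTorch available, `{k: tuple(v.shape) for k, v in torch.load('mnist.pth', map_location='cpu').items()}` should equal `EXPECTED_SHAPES`.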
- Load the saved model. Generate a random input of the right shape and pass it to the PyTorch exporter, which traces the model with it and saves the result to an ONNX file.
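```python
trained_model = Net()
# map_location lets weights saved on a GPU be loaded on a CPU-only machine
trained_model.load_state_dict(torch.load('mnist.pth', map_location='cpu'))
trained_model.eval()  # export in inference mode (dropout disabled)

# Dummy input in NCHW layout: batch of 1, 1 channel, 28x28 pixels.
# Variable is a legacy no-op wrapper since PyTorch 0.4; a plain tensor works too.
dummy_input = Variable(torch.randn(1, 1, 28, 28))
torch.onnx.export(trained_model, dummy_input, "mnist.onnx")
```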
- Load the ONNX file and import it into TensorFlow
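```python
model = onnx.load('mnist.onnx')
onnx.checker.check_model(model)  # raises an exception if the exported graph is invalid
tf_rep = prepare(model)          # wrap the ONNX graph in a TensorFlow representation
```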
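Before relying on the converted model, it is worth checking that it reproduces the PyTorch model's outputs on the same input, since conversions can introduce small numerical differences. A minimal sketch, assuming both outputs have already been turned into NumPy arrays (for example `trained_model(dummy_input).detach().numpy()` after `trained_model.eval()`, and the first element returned by `tf_rep.run(dummy_input.numpy())`); `outputs_match` and its tolerances are illustrative choices.

```python
import numpy as np


def outputs_match(reference, candidate, rtol=1e-3, atol=1e-5):
    """Return (match, max_abs_diff) for two arrays of model outputs."""
    reference = np.asarray(reference, dtype=np.float64)
    # Reshape in case one framework adds or drops a singleton dimension
    candidate = np.asarray(candidate, dtype=np.float64).reshape(reference.shape)
    max_abs_diff = float(np.max(np.abs(reference - candidate)))
    close = bool(np.allclose(reference, candidate, rtol=rtol, atol=atol))
    # For a classifier, the predicted class must also agree
    same_class = bool(np.array_equal(reference.argmax(axis=-1), candidate.argmax(axis=-1)))
    return close and same_class, max_abs_diff
```

Reporting the maximum absolute difference alongside the boolean makes it easier to tell a genuine conversion problem from ordinary floating-point noise.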
- Run and test the TensorFlow model.
```python
import numpy as np
from IPython.display import display
from PIL import Image


def classify(path):
    # Resize to 28x28 and convert to 8-bit grayscale ('L'), as in MNIST
    img = Image.open(path).resize((28, 28)).convert('L')
    display(img)
    # Apply the same scaling and normalization used in training (ToTensor + Normalize).
    # MNIST digits are light on a dark background; if your image shows a dark digit
    # on a light background, invert it first (255 - pixels).
    x = np.asarray(img, dtype=np.float32) / 255.0
    x = (x - 0.1307) / 0.3081
    output = tf_rep.run(x[np.newaxis, np.newaxis, :, :])
    print('The digit is classified as', np.argmax(output))


print('Image 1:')
classify('two.png')
print('Image 2:')
classify('three.png')
```
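The model's output is a vector of 10 log-probabilities, so `np.argmax` gives the predicted digit and `np.exp` turns the scores back into probabilities. The helper below (a hypothetical name) lists the most likely candidates, which is handy when a hand-drawn digit is misclassified; it assumes `output` is the value returned by `tf_rep.run` for a single image.

```python
import numpy as np


def top_k_digits(log_probs, k=3):
    """Return [(digit, probability), ...] for the k most likely classes of one image."""
    # Flatten shapes such as (1, 10) or (1, 1, 10) to a length-10 vector
    probs = np.exp(np.asarray(log_probs, dtype=np.float64).reshape(-1))
    order = np.argsort(probs)[::-1][:k]  # class indices by descending probability
    return [(int(d), float(probs[d])) for d in order]
```

For example, `print(top_k_digits(output))` after running an image shows whether the runner-up digit was a close call or a distant second.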
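- Save the TensorFlow model. With onnx-tensorflow releases built on TensorFlow 2.x, export_graph writes a SavedModel directory at the given path rather than a single frozen-graph file.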
```python
tf_rep.export_graph('mnist.pb')
```