Converting a Simple Deep Learning Model from PyTorch to TensorFlow

Converting the model to TensorFlow

Now, we need to convert the .pt file to a .onnx file using the torch.onnx.export function. There are two things to take note of here: 1) torch.onnx.export needs a dummy input, which it passes through the PyTorch model to trace the operations being exported, and 2) the dummy input needs to have the shape (1, dimension(s) of a single input), i.e. a batch of one. For example, if a single input is an image array with the shape (number of channels, height, width), then the dummy input needs to have the shape (1, number of channels, height, width). The dummy input also serves as the input placeholder for the resulting TensorFlow model. The export call takes the PyTorch model, the dummy input, and the path of the .onnx file to write. I included the input and output names as arguments as well to make inference in TensorFlow easier.

Once we have the .onnx file, we use the prepare() function in ONNX-TF’s backend module to convert the model from ONNX to TensorFlow.

If you specified the input and output names in torch.onnx.export, the tensor dictionary of the converted model should contain the keys ‘input’ and ‘output’ along with their corresponding tensors, as shown in the output below. The tensor names ‘input:0’ and ‘Sigmoid:0’ are the ones to use during inference in TensorFlow.

{'fc0.bias': <tf.Tensor 'Const:0' shape=(50,) dtype=float32>,
 'fc0.weight': <tf.Tensor 'Const_1:0' shape=(50, 20) dtype=float32>,
 'fc1.bias': <tf.Tensor 'Const_2:0' shape=(50,) dtype=float32>,
 'fc1.weight': <tf.Tensor 'Const_3:0' shape=(50, 50) dtype=float32>,
 'last_fc.bias': <tf.Tensor 'Const_4:0' shape=(1,) dtype=float32>,
 'last_fc.weight': <tf.Tensor 'Const_5:0' shape=(1, 50) dtype=float32>,
 'input': <tf.Tensor 'input:0' shape=(1, 20) dtype=float32>,
 '7': <tf.Tensor 'add:0' shape=(1, 50) dtype=float32>,
 '8': <tf.Tensor 'Relu:0' shape=(1, 50) dtype=float32>,
 '9': <tf.Tensor 'add_1:0' shape=(1, 50) dtype=float32>,
 '10': <tf.Tensor 'Relu_1:0' shape=(1, 50) dtype=float32>,
 '11': <tf.Tensor 'add_2:0' shape=(1, 1) dtype=float32>,
 'output': <tf.Tensor 'Sigmoid:0' shape=(1, 1) dtype=float32>}
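The ‘:0’ suffix refers to the first output tensor of an op, so ‘input:0’ is the placeholder op named input and ‘Sigmoid:0’ is the output of the op named Sigmoid. Below is a minimal TF1-style sketch (via tf.compat.v1) of looking tensors up by those names and running inference. To stay self-contained, it builds a small stand-in graph (a placeholder named input, arbitrary fixed weights, a sigmoid) rather than loading the converted model, so the weights are assumptions and not the trained values.

```python
import numpy as np
import tensorflow as tf

tf1 = tf.compat.v1

# Stand-in graph that reuses the converted model's input/output tensor names.
graph = tf.Graph()
with graph.as_default():
    x = tf1.placeholder(tf.float32, shape=(1, 20), name="input")
    w = tf.constant(np.full((20, 1), 0.1, dtype=np.float32))  # arbitrary weights
    y = tf.sigmoid(tf.matmul(x, w), name="Sigmoid")

with tf1.Session(graph=graph) as sess:
    # Look up tensors by name, as you would with the names from the tensor listing.
    input_tensor = graph.get_tensor_by_name("input:0")
    output_tensor = graph.get_tensor_by_name("Sigmoid:0")
    sample = np.zeros((1, 20), dtype=np.float32)
    prediction = sess.run(output_tensor, feed_dict={input_tensor: sample})

print(prediction)  # all-zero input gives sigmoid(0) = 0.5
```

With the real converted model, you would instead load the graph exported by ONNX-TF and look up the same two names; the exact loading code depends on the ONNX-TF and TensorFlow versions.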