In this post, we will learn how to convert a PyTorch model to TensorFlow.
If you are new to deep learning, you may be overwhelmed by the choice of framework. We personally think PyTorch is the first framework you should learn, but it may not be the only one you want to learn.
The good news is that you do not need to be married to a framework. You can train your model in PyTorch and then convert it to TensorFlow easily, as long as you are using standard layers. The best way to achieve this conversion is to first convert the PyTorch model to ONNX and then to TensorFlow / Keras format.
Same Result, Different Framework Using ONNX
As we saw in the earlier post about FCN ResNet-18 in PyTorch, the implemented model predicted the dromedary area in the picture more accurately than the TensorFlow FCN version:
Suppose we would like to capture these results and transfer them to another framework, for instance, from PyTorch to TensorFlow. Is there any way to do it? The answer is yes. One possible way is to use the pytorch2keras library. As its name states, this tool provides an easy way to convert models between PyTorch and Keras. You can install it using pip:
pip3 install pytorch2keras
PyTorch Model Conversion Pipeline
As we can see from the pytorch2keras repo, the pipeline’s logic is described in converter.py. Let’s view its key points:
def pytorch_to_keras(
    model, args, input_shapes=None,
    change_ordering=False, verbose=False, name_policy=None,
    use_optimizer=False, do_constant_folding=False
):
    # ...
    # load a ModelProto structure with ONNX
    onnx_model = onnx.load(stream)
    # ...
    # map the ONNX graph to Keras layers with onnx2keras
    k_model = onnx_to_keras(onnx_model=onnx_model, input_names=input_names,
                            input_shapes=input_shapes, name_policy=name_policy,
                            verbose=verbose, change_ordering=change_ordering)
    return k_model
As you may have noticed, the tool is based on the Open Neural Network Exchange (ONNX). ONNX is an open-source AI project whose goal is to make neural network models interchangeable between different tools, so that you can choose the best combination of them. The resulting intermediate top-level ONNX ModelProto container is passed to the onnx_to_keras function of the onnx2keras tool for further layer mapping.
Let’s examine the PyTorch ResNet18 conversion process, using the fully convolutional network architecture as an example:
import torch
from torch.autograd import Variable

# import transferring tool
from pytorch2keras.converter import pytorch_to_keras


def converted_fully_convolutional_resnet18(
    input_tensor, pretrained_resnet=True,
):
    # define input tensor
    input_var = Variable(torch.FloatTensor(input_tensor))
    # get PyTorch ResNet18 model
    # (FullyConvolutionalResnet18 is defined outside this snippet,
    # see the earlier PyTorch FCN post)
    model_to_transfer = FullyConvolutionalResnet18(pretrained=pretrained_resnet)
    model_to_transfer.eval()
    # convert PyTorch model to Keras
    model = pytorch_to_keras(
        model_to_transfer,
        input_var,
        [input_var.shape[-3:]],
        change_ordering=True,
        verbose=False,
        name_policy="keep",
    )
    return model
Now we can compare the PyTorch and TensorFlow FCN versions. Let’s have a look at the first few layers of the PyTorch FullyConvolutionalResnet18. Note that we used the torchsummary tool to keep the PyTorch and TensorFlow model summaries visually consistent:
from torchsummary import summary
summary(model_to_transfer, input_size=input_var.shape[-3:])
The output is:
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1         [-1, 64, 363, 960]           9,408
       BatchNorm2d-2         [-1, 64, 363, 960]             128
              ReLU-3         [-1, 64, 363, 960]               0
         MaxPool2d-4         [-1, 64, 182, 480]               0
            Conv2d-5         [-1, 64, 182, 480]          36,864
       BatchNorm2d-6         [-1, 64, 182, 480]             128
              ReLU-7         [-1, 64, 182, 480]               0
            Conv2d-8         [-1, 64, 182, 480]          36,864
       BatchNorm2d-9         [-1, 64, 182, 480]             128
             ReLU-10         [-1, 64, 182, 480]               0
       BasicBlock-11         [-1, 64, 182, 480]               0
           Conv2d-12         [-1, 64, 182, 480]          36,864
      BatchNorm2d-13         [-1, 64, 182, 480]             128
             ReLU-14         [-1, 64, 182, 480]               0
           Conv2d-15         [-1, 64, 182, 480]          36,864
      BatchNorm2d-16         [-1, 64, 182, 480]             128
             ReLU-17         [-1, 64, 182, 480]               0
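The parameter counts in this summary can be reproduced by hand. The sketch below assumes the standard ResNet-18 layout (a bias-free 7×7 stem convolution on 3 input channels and bias-free 3×3 convolutions on 64 channels); the helper names are ours and are not part of torchsummary. It also shows why the Keras summary later in this post reports 256 parameters for each batch normalization layer instead of 128: Keras additionally counts the moving mean and variance.

```python
def conv2d_params(in_channels, out_channels, kernel_size, bias=False):
    """Number of learnable values in a square 2D convolution."""
    weights = out_channels * in_channels * kernel_size * kernel_size
    return weights + (out_channels if bias else 0)


def batchnorm_params(channels, count_running_stats=False):
    """Scale and shift per channel; optionally also running mean and variance."""
    return channels * (4 if count_running_stats else 2)


stem_conv = conv2d_params(3, 64, kernel_size=7)    # Conv2d-1
block_conv = conv2d_params(64, 64, kernel_size=3)  # Conv2d-5, Conv2d-8, ...
bn_torch = batchnorm_params(64)                    # BatchNorm2d rows above
bn_keras = batchnorm_params(64, count_running_stats=True)

print(stem_conv, block_conv, bn_torch, bn_keras)  # 9408 36864 128 256
```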
The TensorFlow model obtained after conversion with the pytorch_to_keras function contains the same layers as the initial PyTorch ResNet18 model, except for the TF-specific InputLayer and ZeroPadding2D layers. In PyTorch, this padding is part of torch.nn.Conv2d, where it is set through the padding parameter.
The summary below was produced with the built-in Keras summary method of the tf.keras.Model class:
model.summary()
In the output, the number at the end of a row is the index of the corresponding layer in the PyTorch summary:
Layer (type)              Output Shape          Param #
===============================================================
input_0 (InputLayer)      [(None, 725, 1920, 3  0
_______________________________________________________________
125_pad (ZeroPadding2D)   (None, 731, 1926, 3)  0
_______________________________________________________________
125 (Conv2D)              (None, 363, 960, 64)  9408     1
_______________________________________________________________
126 (BatchNormalization)  (None, 363, 960, 64)  256      2
_______________________________________________________________
127 (Activation)          (None, 363, 960, 64)  0        3
_______________________________________________________________
128_pad (ZeroPadding2D)   (None, 365, 962, 64)  0
_______________________________________________________________
128 (MaxPooling2D)        (None, 182, 480, 64)  0        4
_______________________________________________________________
129_pad (ZeroPadding2D)   (None, 184, 482, 64)  0
_______________________________________________________________
129 (Conv2D)              (None, 182, 480, 64)  36864    5
_______________________________________________________________
130 (BatchNormalization)  (None, 182, 480, 64)  256      6
_______________________________________________________________
131 (Activation)          (None, 182, 480, 64)  0        7
_______________________________________________________________
132_pad (ZeroPadding2D)   (None, 184, 482, 64)  0
_______________________________________________________________
132 (Conv2D)              (None, 182, 480, 64)  36864    8
_______________________________________________________________
133 (BatchNormalization)  (None, 182, 480, 64)  256      9
_______________________________________________________________
134 (Add)                 (None, 182, 480, 64)  0
_______________________________________________________________
135 (Activation)          (None, 182, 480, 64)  0        10
_______________________________________________________________
136_pad (ZeroPadding2D)   (None, 184, 482, 64)  0
_______________________________________________________________
136 (Conv2D)              (None, 182, 480, 64)  36864    12
_______________________________________________________________
137 (BatchNormalization)  (None, 182, 480, 64)  256      13
_______________________________________________________________
138 (Activation)          (None, 182, 480, 64)  0        14
_______________________________________________________________
139_pad (ZeroPadding2D)   (None, 184, 482, 64)  0
_______________________________________________________________
139 (Conv2D)              (None, 182, 480, 64)  36864    15
_______________________________________________________________
140 (BatchNormalization)  (None, 182, 480, 64)  256      16
_______________________________________________________________
141 (Add)                 (None, 182, 480, 64)  0
_______________________________________________________________
142 (Activation)          (None, 182, 480, 64)  0        17
_______________________________________________________________
143_pad (ZeroPadding2D)   (None, 184, 482, 64)  0
_______________________________________________________________
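The spatial sizes in both summaries follow from the usual output-size formula for convolution and pooling windows. Here is a minimal sketch that reproduces the first few shapes for the 725×1920 input, assuming the standard ResNet-18 stem settings (7×7 convolution with stride 2 and padding 3, then 3×3 max pooling with stride 2 and padding 1). The helper name is hypothetical.

```python
def conv_output_size(size, kernel, stride, padding):
    """Spatial size after a convolution or pooling window (floor rounding)."""
    return (size + 2 * padding - kernel) // stride + 1


height, width = 725, 1920

# Stem convolution: 7x7 kernel, stride 2, padding 3 (assumed ResNet-18 settings).
conv_padded = (height + 2 * 3, width + 2 * 3)
conv_out = (
    conv_output_size(height, kernel=7, stride=2, padding=3),
    conv_output_size(width, kernel=7, stride=2, padding=3),
)

# Max pooling: 3x3 window, stride 2, padding 1.
pool_padded = (conv_out[0] + 2 * 1, conv_out[1] + 2 * 1)
pool_out = (
    conv_output_size(conv_out[0], kernel=3, stride=2, padding=1),
    conv_output_size(conv_out[1], kernel=3, stride=2, padding=1),
)

print(conv_padded)  # (731, 1926), the 125_pad row
print(conv_out)     # (363, 960), Conv2d-1 and 125 (Conv2D)
print(pool_padded)  # (365, 962), the 128_pad row
print(pool_out)     # (182, 480), MaxPool2d-4 and 128 (MaxPooling2D)
```

The explicit ZeroPadding2D rows in the Keras summary are the intermediate padded sizes; PyTorch applies the same padding inside the layer, so it never appears as a separate row.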
The scheme below gives a visual representation of the FCN ResNet18 blocks for both versions, TensorFlow and PyTorch:
The model graphs were generated with Netron, an open-source viewer. It supports a wide range of model formats, including ONNX, TensorFlow, Caffe, PyTorch and others. The saved model graph is passed as input to Netron, which produces a detailed model chart.
Transferred Model Results
Thus, we converted the whole PyTorch FC ResNet-18 model with its weights to TensorFlow, changing the NCHW (batch size, channels, height, width) format to NHWC with the change_ordering=True parameter.
This was done because the input layer of the PyTorch model has the shape 3×725×1920, whereas in TensorFlow it becomes 725×1920×3, since the default data format in TF is NHWC. We should also remember that, to obtain the same prediction shape as in PyTorch, (1, 1000, 3, 8), we need to transpose the network output once more:
import numpy as np
import tensorflow as tf

# `image` is the normalized (725, 1920, 3) array from the preprocessing step
# NHWC: (1, 725, 1920, 3)
predict_image = tf.expand_dims(image, 0)
# NCHW: (1, 3, 725, 1920)
image = np.transpose(tf.expand_dims(image, 0).numpy(), [0, 3, 1, 2])
# get transferred torch ResNet18 with pre-trained ImageNet weights
model = converted_fully_convolutional_resnet18(
    input_tensor=image, pretrained_resnet=True,
)
# Perform inference.
# Instead of a 1×1000 vector, we will get a
# 1×1000×n×m output (i.e. a probability map
# of size n × m for each of the 1000 classes,
# where n and m depend on the size of the image).
preds = model.predict(predict_image)
# NHWC: (1, 3, 8, 1000) back to NCHW: (1, 1000, 3, 8)
preds = tf.transpose(preds, (0, 3, 1, 2))
preds = tf.nn.softmax(preds, axis=1)
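The axis permutations are easy to mix up, so here is a small sketch in plain Python that tracks only the shapes. It follows the rule used by np.transpose and tf.transpose: output axis i takes its size from input axis perm[i]. The helper permute_shape and the two constants are illustrative names, not part of either library.

```python
def permute_shape(shape, perm):
    """Shape after a transpose: output axis i takes input axis perm[i]."""
    return tuple(shape[axis] for axis in perm)


NHWC_TO_NCHW = (0, 3, 1, 2)
NCHW_TO_NHWC = (0, 2, 3, 1)

image_nhwc = (1, 725, 1920, 3)   # what the converted Keras model consumes
image_nchw = permute_shape(image_nhwc, NHWC_TO_NCHW)

preds_nhwc = (1, 3, 8, 1000)     # what the converted Keras model returns
preds_nchw = permute_shape(preds_nhwc, NHWC_TO_NCHW)

print(image_nchw)  # (1, 3, 725, 1920)
print(preds_nchw)  # (1, 1000, 3, 8)

# The two permutations are inverses of each other.
assert permute_shape(image_nchw, NCHW_TO_NHWC) == image_nhwc
```

After the transpose, the class axis is axis 1, which is why the softmax above is taken with axis=1.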
One more point to mention is image preprocessing. Recall that for the TF fully convolutional ResNet50, the special preprocess_input utility function was applied. Here, however, for the model converted to TF, we use the same normalization as in the PyTorch FCN ResNet-18 case:
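# The call style below (transform(image=...)["image"]) matches the
# albumentations API, so that library is assumed for the import
from albumentations import Compose, Normalize

# transform input image:
transform = Compose(
    [
        Normalize(
            # subtract mean
            mean=(0.485, 0.456, 0.406),
            # divide by standard deviation
            std=(0.229, 0.224, 0.225),
        ),
    ],
)
# apply image transformations, (725, 1920, 3)
image = transform(image=image)["image"]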
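To make the arithmetic concrete, here is a minimal per-pixel sketch of this normalization in plain Python. It assumes 8-bit RGB input that is first scaled by a maximum pixel value of 255 and then standardized per channel, which is how albumentations' Normalize behaves by default, if that is the library in use. The function normalize_pixel is a hypothetical helper for illustration only.

```python
MEAN = (0.485, 0.456, 0.406)
STD = (0.229, 0.224, 0.225)


def normalize_pixel(rgb, mean=MEAN, std=STD, max_pixel_value=255.0):
    """Scale an RGB pixel to [0, 1], then standardize each channel."""
    return tuple(
        (value / max_pixel_value - m) / s
        for value, m, s in zip(rgb, mean, std)
    )


black = normalize_pixel((0, 0, 0))
white = normalize_pixel((255, 255, 255))

# Black maps to -mean/std and white to (1 - mean)/std in every channel.
print(black)
print(white)
```

Because the converted model carries over the PyTorch weights, it expects inputs in this same range rather than the range produced by the Keras preprocess_input function.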
Let’s explore the results:
Response map shape : (1, 1000, 3, 8)
Predicted Class : Arabian camel, dromedary, Camelus dromedarius tf.Tensor(354, shape=(), dtype=int64)
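This section does not show how the response map is reduced to a single label. One plausible reduction, sketched here on a toy map in plain Python, is to apply softmax over the class axis at every spatial location and pick the class with the highest probability anywhere on the map. The helpers softmax and predict_class and the toy values are illustrative assumptions, not code from the original pipeline.

```python
import math


def softmax(scores):
    """Numerically stable softmax over a list of scores."""
    shift = max(scores)
    exps = [math.exp(s - shift) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def predict_class(response_map):
    """response_map[c][i][j] is the score of class c at location (i, j)."""
    num_classes = len(response_map)
    rows, cols = len(response_map[0]), len(response_map[0][0])
    best_class, best_prob = -1, -1.0
    for i in range(rows):
        for j in range(cols):
            probs = softmax([response_map[c][i][j] for c in range(num_classes)])
            for c, p in enumerate(probs):
                if p > best_prob:
                    best_class, best_prob = c, p
    return best_class, best_prob


# Toy map: 3 classes on a 2x2 grid; class 1 responds strongly at (0, 1).
toy_map = [
    [[0.1, 0.2], [0.0, 0.3]],
    [[0.5, 4.0], [0.2, 0.1]],
    [[0.3, 0.1], [0.4, 0.2]],
]
class_idx, prob = predict_class(toy_map)
print(class_idx)  # 1
```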
The predicted class is correct. Let’s have a look at the response map:
You can see that the response area is the same as in the previous PyTorch FCN post: