PyTorch to TensorFlow Model Conversion

In this post, we will learn how to convert a PyTorch model to TensorFlow.

If you are new to Deep Learning, you may be overwhelmed by which framework to use. We personally think PyTorch is the first framework you should learn, but it may not be the only one you will want to learn.

The good news is that you do not need to be married to a framework. You can train your model in PyTorch and then convert it to TensorFlow easily, as long as you are using standard layers. The best way to achieve this conversion is to first convert the PyTorch model to ONNX and then to TensorFlow / Keras format.
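At a high level, the round trip looks like this. Below is a minimal sketch, assuming a model built entirely from standard, ONNX-exportable layers; the file name is illustrative:

import onnx
import torch
from onnx2keras import onnx_to_keras

# A toy model standing in for your own network.
model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 8, kernel_size=3, padding=1),
    torch.nn.ReLU(),
)
model.eval()

# Step 1: trace-export PyTorch -> ONNX; the dummy input fixes the shapes.
dummy_input = torch.randn(1, 3, 224, 224)
torch.onnx.export(model, dummy_input, "model.onnx", input_names=["input"])

# Step 2: map ONNX -> Keras.
k_model = onnx_to_keras(onnx.load("model.onnx"), input_names=["input"])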

Same Result, Different Framework Using ONNX

As we observed in the earlier post on FCN ResNet-18 in PyTorch, the PyTorch implementation predicted the dromedary area in the picture more accurately than the TensorFlow FCN version:

Figure 1: PyTorch FCN ResNet18 activations
Figure 2: TensorFlow FCN ResNet50 activations

Suppose we would like to take these results and transfer them to another framework, for instance from PyTorch to TensorFlow. Is there a way to do it? The answer is yes. One possible way is to use the pytorch2keras library. As its name states, this tool provides an easy way to convert models between frameworks such as PyTorch and Keras. You can easily install it using pip:

pip3 install pytorch2keras

PyTorch Model Conversion Pipeline

As we can see from the pytorch2keras repo, the pipeline's logic is described in converter.py. Let's view its key points:

def pytorch_to_keras(
    model, args, input_shapes=None,
    change_ordering=False, verbose=False, name_policy=None,
    use_optimizer=False, do_constant_folding=False
):

    # ...

    # load a ModelProto structure with ONNX
    onnx_model = onnx.load(stream)

    # ...
    k_model = onnx_to_keras(onnx_model=onnx_model, input_names=input_names,
                            input_shapes=input_shapes, name_policy=name_policy,
                            verbose=verbose, change_ordering=change_ordering)

    return k_model

As you may have noticed, the tool is based on the Open Neural Network Exchange (ONNX). ONNX is an open-source AI project whose goal is to make possible the interchange of neural network models between different tools, so that you can choose the best combination of tools for your task. The resulting intermediate top-level ONNX ModelProto container is passed to the onnx_to_keras function of the onnx2keras tool for further layer mapping.
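If you want to peek at this intermediate container yourself, a few lines of the onnx package suffice (a sketch, assuming the model was saved to model.onnx as above):

import onnx

onnx_model = onnx.load("model.onnx")
onnx.checker.check_model(onnx_model)  # validate the ModelProto structure
print(onnx_model.graph.input)         # input tensors with their shapes
print([node.op_type for node in onnx_model.graph.node][:10])  # first few ops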

Let's examine the PyTorch ResNet18 conversion process, using the fully convolutional network architecture as an example:

import torch
from torch.autograd import Variable

# import transferring tool
from pytorch2keras.converter import pytorch_to_keras

# FullyConvolutionalResnet18 is the model class defined in our
# earlier PyTorch FCN ResNet-18 post.

def converted_fully_convolutional_resnet18(
    input_tensor, pretrained_resnet=True,
):
    # define input tensor
    input_var = Variable(torch.FloatTensor(input_tensor))

    # get PyTorch ResNet18 model
    model_to_transfer = FullyConvolutionalResnet18(pretrained=pretrained_resnet)
    model_to_transfer.eval()

    # convert PyTorch model to Keras
    model = pytorch_to_keras(
        model_to_transfer,
        input_var,
        [input_var.shape[-3:]],
        change_ordering=True,
        verbose=False,
        name_policy="keep",
    )

    return model

Now we can compare the PyTorch and TensorFlow FCN versions. Let's have a look at the first layers of the PyTorch FullyConvolutionalResnet18 model. It's worth noting that we used the torchsummary tool for visual consistency between the PyTorch and TensorFlow model summaries:

from torchsummary import summary

summary(model_to_transfer, input_size=input_var.shape[-3:])

The output is:

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1         [-1, 64, 363, 960]           9,408
       BatchNorm2d-2         [-1, 64, 363, 960]             128
              ReLU-3         [-1, 64, 363, 960]               0
         MaxPool2d-4         [-1, 64, 182, 480]               0
            Conv2d-5         [-1, 64, 182, 480]          36,864
       BatchNorm2d-6         [-1, 64, 182, 480]             128
              ReLU-7         [-1, 64, 182, 480]               0
            Conv2d-8         [-1, 64, 182, 480]          36,864
       BatchNorm2d-9         [-1, 64, 182, 480]             128
             ReLU-10         [-1, 64, 182, 480]               0
       BasicBlock-11         [-1, 64, 182, 480]               0
           Conv2d-12         [-1, 64, 182, 480]          36,864
      BatchNorm2d-13         [-1, 64, 182, 480]             128
             ReLU-14         [-1, 64, 182, 480]               0
           Conv2d-15         [-1, 64, 182, 480]          36,864
      BatchNorm2d-16         [-1, 64, 182, 480]             128
             ReLU-17         [-1, 64, 182, 480]               0

The TensorFlow model obtained after conversion with the pytorch_to_keras function contains the same layers as the initial PyTorch ResNet18 model, except for the TF-specific InputLayer and ZeroPadding2D layers; the latter corresponds to the padding parameter of torch.nn.Conv2d.
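To make this concrete, here is a minimal sketch (shapes taken from the summaries above) of how an explicit ZeroPadding2D followed by a "valid" Conv2D in Keras reproduces one padded torch.nn.Conv2d:

import tensorflow as tf

# Keras: padding is a separate layer, followed by a "valid" convolution.
x = tf.keras.Input(shape=(182, 480, 64))
y = tf.keras.layers.ZeroPadding2D(padding=1)(x)                     # (184, 482, 64)
y = tf.keras.layers.Conv2D(64, kernel_size=3, padding="valid")(y)  # (182, 480, 64)

# PyTorch: the same operation is a single layer:
# torch.nn.Conv2d(64, 64, kernel_size=3, padding=1)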

The summary below was produced with the built-in summary method of the tf.keras.Model class:

model.summary()

The corresponding layers in the output are marked with numbers in the rightmost column, giving the PyTorch-to-TF layer mapping:

Layer (type)                    Output Shape         Param #
===============================================================
input_0 (InputLayer)            [(None, 725, 1920, 3)] 0
_______________________________________________________________
125_pad (ZeroPadding2D)         (None, 731, 1926, 3) 0
_______________________________________________________________
125 (Conv2D)                    (None, 363, 960, 64) 9408    1
_______________________________________________________________
126 (BatchNormalization)        (None, 363, 960, 64) 256     2
_______________________________________________________________
127 (Activation)                (None, 363, 960, 64) 0       3
_______________________________________________________________
128_pad (ZeroPadding2D)         (None, 365, 962, 64) 0
_______________________________________________________________
128 (MaxPooling2D)              (None, 182, 480, 64) 0       4
_______________________________________________________________
129_pad (ZeroPadding2D)         (None, 184, 482, 64) 0
_______________________________________________________________
129 (Conv2D)                    (None, 182, 480, 64) 36864   5
_______________________________________________________________
130 (BatchNormalization)        (None, 182, 480, 64) 256     6
_______________________________________________________________
131 (Activation)                (None, 182, 480, 64) 0       7
_______________________________________________________________
132_pad (ZeroPadding2D)         (None, 184, 482, 64) 0
_______________________________________________________________
132 (Conv2D)                    (None, 182, 480, 64) 36864   8
_______________________________________________________________
133 (BatchNormalization)        (None, 182, 480, 64) 256     9
_______________________________________________________________
134 (Add)                       (None, 182, 480, 64) 0
_______________________________________________________________
135 (Activation)                (None, 182, 480, 64) 0       10
_______________________________________________________________
136_pad (ZeroPadding2D)         (None, 184, 482, 64) 0
_______________________________________________________________
136 (Conv2D)                    (None, 182, 480, 64) 36864   12
_______________________________________________________________
137 (BatchNormalization)        (None, 182, 480, 64) 256     13
_______________________________________________________________
138 (Activation)                (None, 182, 480, 64) 0       14
_______________________________________________________________
139_pad (ZeroPadding2D)         (None, 184, 482, 64) 0
_______________________________________________________________
139 (Conv2D)                    (None, 182, 480, 64) 36864   15
_______________________________________________________________
140 (BatchNormalization)        (None, 182, 480, 64) 256     16
_______________________________________________________________
141 (Add)                       (None, 182, 480, 64) 0
_______________________________________________________________
142 (Activation)                (None, 182, 480, 64) 0       17
_______________________________________________________________
143_pad (ZeroPadding2D)         (None, 184, 482, 64) 0
_______________________________________________________________

One mismatch worth noting: the BatchNormalization layers report 256 parameters in Keras versus 128 in torchsummary. Keras counts the non-trainable moving mean and variance along with the trainable scale and shift, while PyTorch stores the moving statistics as buffers rather than parameters.

The scheme below gives a visual representation of the FCN ResNet18 blocks for both versions, TensorFlow and PyTorch:

Figure 3: TensorFlow FCN ResNet18 model after conversion (on the left); PyTorch initial FCN ResNet18 (on the right)

Model graphs were generated with the Netron open-source viewer. It supports a wide range of model formats, including ONNX, TensorFlow, Caffe, and PyTorch. The saved model graph is passed as input to Netron, which then produces the detailed model chart.
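To reproduce such a chart for the converted model, save it to a format Netron understands; a sketch follows (the file name is illustrative):

# Save the converted Keras model to disk.
model.save("fcn_resnet18_converted.h5")

# Then open the file in Netron: the desktop app, https://netron.app,
# or the netron Python package:
# import netron
# netron.start("fcn_resnet18_converted.h5")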

Transferred Model Results

Thus, we converted the whole PyTorch FC ResNet-18 model with its weights to TensorFlow, changing the NCHW (batch size, channels, height, width) format to NHWC with the change_ordering=True parameter.

This is done because in the PyTorch model the shape of the input layer is 3×725×1920, whereas in TensorFlow it is changed to 725×1920×3, since the default data format in TF is NHWC. We should also remember that to obtain the same prediction shape as in PyTorch (1, 1000, 3, 8), we have to transpose the network output once more:

# NHWC: (1, 725, 1920, 3) -- TF-ordered input for inference
predict_image = tf.expand_dims(image, 0)
# NCHW: (1, 3, 725, 1920) -- PyTorch-ordered input for the conversion step
image = np.transpose(tf.expand_dims(image, 0).numpy(), [0, 3, 1, 2])

# get transferred torch ResNet18 with pre-trained ImageNet weights
model = converted_fully_convolutional_resnet18(
    input_tensor=image, pretrained_resnet=True,
)

# Perform inference.
# Instead of a 1×1000 vector, we get a 1×1000×n×m output,
# i.e. a probability map of size n×m for each of the 1000 classes,
# where n and m depend on the size of the image.
preds = model.predict(predict_image)
# NHWC: (1, 3, 8, 1000) back to NCHW: (1, 1000, 3, 8)
preds = tf.transpose(preds, (0, 3, 1, 2))
preds = tf.nn.softmax(preds, axis=1)

One more point worth mentioning is image preprocessing. Recall that for the TF fully convolutional ResNet50 a special preprocess_input utility function was applied. Here, however, for the converted model we use the same normalization as in the PyTorch FCN ResNet-18 case:

# Compose and Normalize are assumed to come from the albumentations
# library, which matches the transform(image=...)["image"] call below.
from albumentations import Compose, Normalize

# transform input image:
transform = Compose(
    [
        Normalize(
            # subtract mean
            mean=(0.485, 0.456, 0.406),
            # divide by standard deviation
            std=(0.229, 0.224, 0.225),
        ),
    ],
)
# apply image transformations, (725, 1920, 3)
image = transform(image=image)["image"]
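For completeness, here is a sketch of how the input image can be loaded before applying this transform, assuming OpenCV for I/O as in the earlier posts (the file name is illustrative):

import cv2

image = cv2.imread("camel.jpg")                 # BGR, uint8
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # RGB, matching the mean/std above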

Let’s explore the results:

Response map shape :  (1, 1000, 3, 8)
Predicted Class :  Arabian camel, dromedary, Camelus dromedarius tf.Tensor(354, shape=(), dtype=int64)
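The printed class id and response map can be recovered from preds, for instance as follows (a sketch; this is not necessarily the exact code of the post):

import tensorflow as tf

# Collapse each n×m probability map to one score per class, then take the argmax.
class_scores = tf.reduce_max(preds, axis=(2, 3))[0]  # shape: (1000,)
class_id = int(tf.math.argmax(class_scores))         # -> 354, the dromedary class

# The (3, 8) response map for the winning class:
response_map = preds[0, class_id]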

The predicted class is correct. Let's have a look at the response map:

Figure 4: Response map

You can see that the response area is the same as the one we got in the previous PyTorch FCN post:

Figure 5: Converted TF FCN ResNet18 result

