Running a neural network usually takes a lot of computation, but today even a smartphone can run a trained model. In this post we will look at how to convert a trained PyTorch model to CoreML format.
CoreML is the model format used by Apple’s on-device neural network runtime, so it allows running model inference on Apple devices. Your model can be integrated into your app using the CoreML framework, which can be installed only on macOS devices. The easiest way to train a new classification or detection model is to use the Create ML app, but its functionality is limited. For that reason, we will use PyTorch as the training framework.
The most common approach is to first convert the PyTorch model to ONNX format and then convert the ONNX model into CoreML format using different tools and libraries. We will discuss this in more detail in this post.
Step 1: PyTorch to ONNX
As we mentioned before in our “PyTorch Model Inference using ONNX and Caffe2” post: “Open Neural Network Exchange (ONNX) is an open format that lets users move deep learning models between different frameworks. This open format was initially proposed by Facebook and Microsoft but now is a widely accepted industry standard.”
Let’s take a look at PyTorch to ONNX conversion. The main function that converts PyTorch to ONNX is already implemented in PyTorch and is called torch.onnx.export.
Note: You need to import the ONNX library first and PyTorch second. Otherwise, you might get segmentation faults because of dynamic loading issues.
Let’s take a look at the code for converting PyTorch Model to ONNX format.
import onnx  # import ONNX before PyTorch (see the note above)
import torch


def save_onnx_from_torch(
    model, model_name, input_image, input_names=None, output_names=None, simplify=False,
):
    # Section 1: PyTorch model conversion --
    if input_names is None:
        input_names = ["input"]
    if output_names is None:
        output_names = ["output"]
    # set mode to evaluation and change device to cpu
    model.eval()
    model.cpu()
    onnx_filename = model_name + ".onnx"
    # export our model to ONNX format
    torch.onnx.export(
        model,
        input_image,
        onnx_filename,
        verbose=True,
        input_names=input_names,
        output_names=output_names,
    )
    # Section 2: Model testing
    onnx_model = check_onnx_model(model, onnx_filename, input_image)
    # Section 3: ONNX simplifier
    if simplify:
        filename = model_name + "_simplified.onnx"
        onnx_model_simplified, check = simplify_onnx(onnx_model, filename)
        onnx.checker.check_model(onnx_model_simplified)
        check_onnx_model(model, filename, input_image)
        return onnx_model_simplified
    else:
        return onnx_model
The function above is divided into three sections. Let’s take a closer look at each of them.
PyTorch model conversion
In our case we use a pre-trained classification model from torchvision, so the input is a tensor with one image and the output is a single tensor with predictions. Our code is compatible only with torchvision’s classification models: detection and segmentation models from torchvision have different output formats and use some layers that ONNX does not support by default. If your network has several inputs or outputs, you should list their names in the input_names or output_names argument.
The model is converted to ONNX with the torch.onnx.export function, which needs an example input. To make a forward pass through the network during conversion we need an example image, which can be a real picture or just a randomly generated tensor. As a result, the converted ONNX model is saved to disk.
To make sure the conversion succeeded and all the nodes are in place, we can visualize the resulting ONNX graph. For this you can use Netron, a viewer that shows the network architecture in a human-readable form. You can load your model into its official online tool, download the offline version, or use Netron’s Python API.
We won’t show the whole ONNX graph here, because it consists of too many nodes. Instead, you can take a look at a small part of it to get an idea:
Model testing
After conversion, we should always check the model for correctness. The best approach is to compare the predictions of the PyTorch and ONNX models on the same input image. Let us see what the code looks like:
import numpy as np
import onnxruntime


def check_onnx_output(filename, input_data, torch_output):
    # torch_output is a dict of output tensors; each one is compared with the ONNX result
    session = onnxruntime.InferenceSession(filename)
    input_name = session.get_inputs()[0].name
    result = session.run([], {input_name: input_data.numpy()})
    for test_result, gold_result in zip(result, torch_output.values()):
        np.testing.assert_almost_equal(
            gold_result.cpu().numpy(), test_result, decimal=3,
        )
    return result
In the code above, we start ONNX inference using onnxruntime.InferenceSession to get an output from the ONNX model.
After that, we compare the PyTorch and ONNX results using np.testing.assert_almost_equal for each network output. According to the NumPy documentation, the assert_almost_equal function “Raises an AssertionError if two items are not equal up to desired precision.” As a result, we compare each value from the two outputs up to three decimal places.
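To make the tolerance concrete, here is a minimal pure-Python sketch of the criterion that the NumPy documentation gives for assert_almost_equal, abs(desired - actual) < 1.5 * 10**(-decimal). The helper name is_almost_equal and the sample scores are hypothetical and only illustrate the rule; the actual checks in this post use np.testing.assert_almost_equal.

```python
def is_almost_equal(actual, desired, decimal=3):
    """Return True if the two numbers agree up to `decimal` decimal places.

    Hypothetical helper mirroring the documented rule
    abs(desired - actual) < 1.5 * 10**(-decimal).
    """
    return abs(desired - actual) < 1.5 * 10 ** (-decimal)


# With decimal=3 the allowed absolute difference is just under 0.0015.
pytorch_score = 0.8731  # illustrative values, not real model outputs
onnx_score = 0.8739

print(is_almost_equal(onnx_score, pytorch_score))          # difference 0.0008 -> True
print(is_almost_equal(onnx_score + 0.002, pytorch_score))  # difference 0.0028 -> False
```

In other words, decimal=3 tolerates small numerical drift between the two runtimes but still catches outputs that differ in the third decimal place or earlier.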
ONNX Simplifier
Now, if all tests have passed, we can move forward. Let’s simplify our ONNX model with the onnx-simplifier library. The simplifier is intended to make the computation graph smaller by removing redundant operations or replacing them with simpler ones. As a result, inference gets faster because there are fewer operations.
Note that we use torch version 1.4.0, because the model simplifier crashes with torch versions 1.5.0 and newer. We have not found a fix for this, but we hope it will work in future torch releases. As we mentioned before, all stable requirements are collected in the requirements.txt file.
The simplify function from the onnx-simplifier library requires an ONNX model as input.
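import onnx
from onnxsim import simplify

def simplify_onnx(onnx_model, filename):
    # simplify() returns the simplified model and a flag saying whether it passed validation
    simplified_model, check = simplify(onnx_model)
    onnx.save_model(simplified_model, filename)
    return simplified_model, check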
Now we can take a look at how our model has changed after simplifying:
As you can see, there is no BatchNormalization layer in our new model. The onnx-simplifier has merged every BatchNorm into the preceding convolution layer. You can find additional information about fusing BatchNormalization layers using this link.
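The fusion is possible because, at inference time, batch normalization is just a per-channel affine transform with fixed statistics, so it can be folded into the weight and bias of the convolution before it. The sketch below shows the arithmetic for a single channel with scalar values. All names and numbers are illustrative assumptions, and this is not the onnx-simplifier implementation.

```python
import math


def conv_then_bn(x, w, b, gamma, beta, mean, var, eps=1e-5):
    """Apply a one-weight 'convolution' followed by batch norm in inference mode."""
    y = w * x + b
    return gamma * (y - mean) / math.sqrt(var + eps) + beta


def fuse_conv_bn(w, b, gamma, beta, mean, var, eps=1e-5):
    """Fold the batch-norm parameters into the convolution weight and bias."""
    scale = gamma / math.sqrt(var + eps)
    return w * scale, (b - mean) * scale + beta


# Illustrative parameters for one channel (not taken from a real model).
w, b = 0.5, 0.1
gamma, beta, mean, var = 1.2, -0.3, 0.4, 2.0

fused_w, fused_b = fuse_conv_bn(w, b, gamma, beta, mean, var)

for x in (-1.0, 0.0, 2.5):
    reference = conv_then_bn(x, w, b, gamma, beta, mean, var)
    fused = fused_w * x + fused_b
    # Both paths agree up to floating-point rounding.
    assert math.isclose(reference, fused, rel_tol=1e-9, abs_tol=1e-12)
```

After folding, the batch-norm node no longer needs to exist in the graph, which matches what we see in the simplified model.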
Step 2: ONNX to CoreML
To convert the ONNX model to CoreML we will use the onnx_coreml library, in particular its convert function.
import platform

from onnx_coreml import convert


def convert_onnx_to_coreml(onnx_model, model_name, torch_model, input_data):
    model_coreml = convert(onnx_model, minimum_ios_deployment_target="13")
    coreml_filename = model_name + ".mlmodel"
    model_coreml.save(coreml_filename)
    # check that platform is macOS
    if platform.system() == "Darwin":
        check_coreml_model(coreml_filename, torch_model, input_data)
    return model_coreml
Now that we have the CoreML model, you can use the check_coreml_model function below to compare its output with the PyTorch model output, the same way we did with check_onnx_model. The tricky part is that you can only do so on a macOS device, since the coremltools library can run inference only on macOS and not on any other operating system.
import coremltools
import numpy as np
import torch


def check_coreml_model(coreml_filename, torch_model, input_data):
    # get PyTorch model output
    with torch.no_grad():
        torch_output = {"output": torch_model(input_data)}
    # get CoreML model output
    coreml_model = coremltools.models.MLModel(coreml_filename, useCPUOnly=True)
    # convert input to numpy and get coreml model prediction
    input_data = input_data.cpu().numpy()
    pred = coreml_model.predict({"input": input_data})
    for key in pred:
        np.testing.assert_almost_equal(
            torch_output[key].cpu().numpy(), pred[key], decimal=3,
        )
    print("CoreML model is checked!")
    return pred
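Because prediction with coremltools works only on macOS, it can help to keep the platform check in one small place. The sketch below uses only the standard library. The helper name can_run_coreml_prediction is hypothetical, and the optional system_name argument exists only so the logic can be exercised on any machine.

```python
import platform


def can_run_coreml_prediction(system_name=None):
    """Return True on macOS, which platform.system() reports as "Darwin"."""
    if system_name is None:
        system_name = platform.system()
    return system_name == "Darwin"


if can_run_coreml_prediction():
    print("macOS detected: the CoreML output can be compared with PyTorch")
else:
    print("Not macOS: skipping the CoreML prediction check")
```

On Linux or Windows the conversion itself still runs; only the comparison of CoreML predictions is skipped.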
Finally, we run both steps described above using the following code:
# random image to make a network forward pass
dummy_input = torch.randn(1, 3, input_size, input_size, device="cpu")
# save ONNX model
onnx_model = save_onnx_from_torch(
    torch_model, model_name, input_image=dummy_input, simplify=simp,
)
convert_onnx_to_coreml(
    onnx_model, model_name, torch_model=torch_model, input_data=dummy_input,
)
print("PyTorch model has been converted to CoreML format")
We first define a dummy input and then call the two functions, save_onnx_from_torch and convert_onnx_to_coreml. You should get the following output:
Translation to CoreML spec completed. Now compiling the CoreML model.
Model Compilation done.
CoreML model is checked!
PyTorch model has been converted to CoreML format
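After a successful run, the files on disk follow from the naming used in the functions above: model_name + ".onnx", optionally model_name + "_simplified.onnx", and model_name + ".mlmodel". Here is a small sketch that lists them; the helper name expected_artifacts and the model name "resnet18" are placeholders chosen for illustration.

```python
def expected_artifacts(model_name, simplify=False):
    """List the file names that the conversion functions in this post write to disk."""
    files = [model_name + ".onnx"]
    if simplify:
        files.append(model_name + "_simplified.onnx")
    files.append(model_name + ".mlmodel")
    return files


# "resnet18" is only a placeholder model name.
print(expected_artifacts("resnet18", simplify=True))
# -> ['resnet18.onnx', 'resnet18_simplified.onnx', 'resnet18.mlmodel']
```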
Summary
In this post, we looked at how to convert a PyTorch model to ONNX and then to CoreML format. The functions provided by PyTorch, onnx-simplifier, and onnx_coreml make the conversion process easier. Remember that different model architectures require different approaches, and always run all the checks to make sure that the conversion was correct.
We will come up with a post on how to use the model on an iOS device soon!