In this post, we continue to consider how to speed up inference quickly and painlessly if we already have a trained model in PyTorch. In the previous post
- We discussed what ONNX and TensorRT are and why they are needed
- Configured the environment for PyTorch and TensorRT Python API
- Loaded and launched a pre-trained model using PyTorch
- Converted the PyTorch model to ONNX format
- Visualized ONNX Model in Netron
- Used NVIDIA TensorRT for inference
- Found out what CUDA streams are
- Learned about TensorRT Context, Engine, Builder, Network, and Parser
- Tested performance
You can find this post here: https://learnopencv.com/how-to-convert-a-model-from-pytorch-to-tensorrt-and-speed-up-inference/.
However, in the previous post, we used TensorRT Python API, although TensorRT supports C++ API too. Let us close the gap and take a closer look at the C++ API as well. But first, let’s compare the pros and cons of both approaches.
Python API vs C++ API
When it comes to TensorRT, both the Python API and the C++ API will let you achieve good performance and solve the problem. So which approach you should choose depends only on your current task, not on the framework.
Python API benefits
- You can reuse the data preprocessing and postprocessing you have already implemented for training. We need to know what transformations were made during training to replicate them for inference. In the case of the C++ API, we have to re-implement the same transformations using only the available C++ libraries, which, as you know, is not always possible. So before you use some exotic transformation, make sure that you can replicate it on the C++ side too. It’s good practice to set meaningful default parameters explicitly, because they can differ between the Python and C++ versions of the same framework.
- You don’t have to learn C++ if you’re not familiar with it. Just enjoy simplicity, flexibility, and intuitive Python.
C++ API benefits
- The TensorRT C++ API supports more platforms than the Python API. For example, with the Python API, inference cannot be done on Windows x64. To find out more about supported platforms, please refer to: https://docs.nvidia.com/deeplearning/tensorrt/support-matrix/index.html
- C++ supports multithreading. Some models require heavy parallelization, and you can easily create threads in C++ (see the sketch after this list). But if you use Python based on CPython (the most widely used implementation), it cannot run more than one system thread at a time due to the GIL.
- A pure C++ library can be used in real-time applications, in contrast with a slow Python script. So in any performance-critical scenario, as well as in situations where safety is important, for example in automotive, NVIDIA recommends using the C++ API.
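As a quick illustration of the multithreading point above, here is a minimal sketch (hypothetical example code, not tied to TensorRT): several C++ threads do CPU-bound work truly in parallel, which CPython threads cannot do because of the GIL.
#include <cmath>
#include <iostream>
#include <numeric>
#include <thread>
#include <vector>

int main()
{
    std::vector<double> partial(4, 0.0);
    std::vector<std::thread> workers;
    for (int t = 0; t < 4; ++t)
    {
        // each thread runs its own CPU-bound chunk in parallel with the others
        workers.emplace_back([t, &partial]() {
            for (int i = 0; i < 10000000; ++i)
            {
                partial[t] += std::sin(i * 0.001);
            }
        });
    }
    for (auto& w : workers)
    {
        w.join();
    }
    std::cout << "sum: " << std::accumulate(partial.begin(), partial.end(), 0.0) << std::endl;
    return 0;
}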
Environment Setup for TensorRT
To reproduce the experiments mentioned in this article, you’ll need an NVIDIA graphics card. Any architecture starting from Maxwell (compute capability 5.0) will do. You can find your GPU’s compute capability in the table here: https://developer.nvidia.com/cuda-gpus#compute. Don’t forget to install the appropriate drivers.
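If you prefer to check the compute capability programmatically, here is a minimal sketch (not part of the original post) that prints it for the first GPU using the CUDA runtime API:
#include <cuda_runtime.h>
#include <iostream>

int main()
{
    cudaDeviceProp prop{};
    // query the properties of GPU 0
    if (cudaGetDeviceProperties(&prop, 0) != cudaSuccess)
    {
        std::cerr << "No CUDA-capable device found" << std::endl;
        return -1;
    }
    std::cout << prop.name << ": compute capability " << prop.major << "." << prop.minor << std::endl;
    return 0;
}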
Install PyTorch, ONNX, and OpenCV
Install Python 3.6 or later and run
python3 -m pip install -r requirements.txt
requirements.txt content:
torch==1.2.0
torchvision==0.4.0
albumentations==0.4.5
onnx==1.4.1
opencv-python==4.2.0.34
The code was tested with the specified versions, but it’s okay to try other versions if you already have some of these components installed.
Install TensorRT
- Install CMake version 3.10 or later
- Download and install NVIDIA CUDA 10.0 or later following the official instructions: link
- Download and extract the CuDNN library for your CUDA version (login required): link
- Download and extract the NVIDIA TensorRT library for your CUDA version (login required): link. The minimum required version is 6.0.1.5
- Add the paths to CUDA, TensorRT, and CuDNN to the PATH variable (or LD_LIBRARY_PATH)
- Build or install a pre-built version of OpenCV and OpenCV Contrib. The minimum required version is 4.0.0.
We are now ready for our experiment.
Convert pre-trained PyTorch model to ONNX
We have already done all this work in the previous article, so here we just give the listing of the Python script. If you need further clarification, please refer to this: How to Convert a Model from PyTorch to TensorRT and Speed Up Inference
# imports used by the listing below
import cv2
import torch
import onnx
from torchvision import models
from albumentations import Compose, Normalize, Resize
from albumentations.pytorch import ToTensor

def preprocess_image(img_path):
# transformations for the input data
transforms = Compose([
Resize(224, 224, interpolation=cv2.INTER_NEAREST),
Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]),
ToTensor(),
])
# read input image
input_img = cv2.imread(img_path)
# do transformations
input_data = transforms(image=input_img)["image"]
# prepare batch
batch_data = torch.unsqueeze(input_data, 0)
return batch_data
def postprocess(output_data):
# get class names
with open("imagenet_classes.txt") as f:
classes = [line.strip() for line in f.readlines()]
# calculate human-readable value by softmax
confidences = torch.nn.functional.softmax(output_data, dim=1)[0] * 100
# find top predicted classes
_, indices = torch.sort(output_data, descending=True)
i = 0
# print the top classes predicted by the model
while confidences[indices[0][i]] > 0.5:
class_idx = indices[0][i]
print(
"class:",
classes[class_idx],
", confidence:",
confidences[class_idx].item(),
"%, index:",
class_idx.item(),
)
i += 1
def main():
# load pre-trained model ------------------------------------
model = models.resnet50(pretrained=True)
# preprocessing stage ---------------------------------------
input = preprocess_image("turkish_coffee.jpg").cuda()
# inference stage -------------------------------------------
model.eval()
model.cuda()
output = model(input)
# post-processing stage -------------------------------------
postprocess(output)
# convert to ONNX -------------------------------------------
ONNX_FILE_PATH = "resnet50.onnx"
torch.onnx.export(model, input, ONNX_FILE_PATH, input_names=["input"], output_names=["output"], export_params=True)
onnx_model = onnx.load(ONNX_FILE_PATH)
# check that the model converted fine
onnx.checker.check_model(onnx_model)
print("Model was successfully converted to ONNX format.")
print("It was saved to", ONNX_FILE_PATH)
Use TensorRT C++ API
1. Preprocessing: Prepare the input image for inference in OpenCV
To get the same result in TensorRT as in PyTorch, we prepare the data for inference and repeat all the preprocessing steps that we’ve taken before. Let’s create a function preprocessImage which accepts the path to the input image, a float pointer (we will allocate the memory outside of the function) where we will store the tensor after all transformations, and the input size of the model.
void preprocessImage(const std::string& image_path, float* gpu_input, const nvinfer1::Dims& dims)
Read the image using OpenCV as we did in Python and upload it to GPU:
{
cv::Mat frame = cv::imread(image_path);
if (frame.empty())
{
std::cerr << "Input image " << image_path << " load failed\n";
return;
}
cv::cuda::GpuMat gpu_frame;
// upload image to GPU
gpu_frame.upload(frame);
Resize:
auto input_width = dims.d[2];
auto input_height = dims.d[1];
auto channels = dims.d[0];
auto input_size = cv::Size(input_width, input_height);
// resize
cv::cuda::GpuMat resized;
cv::cuda::resize(gpu_frame, resized, input_size, 0, 0, cv::INTER_NEAREST);
Normalize:
cv::cuda::GpuMat flt_image;
resized.convertTo(flt_image, CV_32FC3, 1.f / 255.f);
cv::cuda::subtract(flt_image, cv::Scalar(0.485f, 0.456f, 0.406f), flt_image, cv::noArray(), -1);
cv::cuda::divide(flt_image, cv::Scalar(0.229f, 0.224f, 0.225f), flt_image, 1, -1);
ToTensor (copy data to output float pointer channel by channel):
std::vector< cv::cuda::GpuMat > chw;
for (size_t i = 0; i < channels; ++i)
{
chw.emplace_back(cv::cuda::GpuMat(input_size, CV_32FC1, gpu_input + i * input_width * input_height));
}
cv::cuda::split(flt_image, chw);
}
2. Post-processing
We know that as output we’ll get an array of 1000 float numbers, so we can already repeat the post-processing step. The output of inference will be in GPU memory, so to begin with we should copy it to the CPU.
void postprocessResults(float *gpu_output, const nvinfer1::Dims &dims, int batch_size)
{
// get class names
auto classes = getClassNames("imagenet_classes.txt");
// copy results from GPU to CPU
std::vector< float > cpu_output(getSizeByDim(dims) * batch_size);
cudaMemcpy(cpu_output.data(), gpu_output, cpu_output.size() * sizeof(float), cudaMemcpyDeviceToHost);
The softmax formula is softmax(x_i) = exp(x_i) / Σ_j exp(x_j).
To calculate it, we take the exponent of each element of cpu_output and then sum them all up. After that, let’s find and print the top classes predicted by the network.
// calculate softmax
std::transform(cpu_output.begin(), cpu_output.end(), cpu_output.begin(), [](float val) {return std::exp(val);});
auto sum = std::accumulate(cpu_output.begin(), cpu_output.end(), 0.0);
// find top classes predicted by the model
std::vector< int > indices(getSizeByDim(dims) * batch_size);
// generate sequence 0, 1, 2, 3, ..., 999
std::iota(indices.begin(), indices.end(), 0);
std::sort(indices.begin(), indices.end(), [&cpu_output](int i1, int i2) {return cpu_output[i1] > cpu_output[i2];});
// print results
int i = 0;
while (cpu_output[indices[i]] / sum > 0.005)
{
if (classes.size() > indices[i])
{
std::cout << "class: " << classes[indices[i]] << " | ";
}
std::cout << "confidence: " << 100 * cpu_output[indices[i]] / sum << "% | index: " << indices[i] << "n";
++i;
}
}
Tip: if in your case the output is much larger than 1000 values, it’s not a good idea to copy it from GPU to CPU. You could do the post-processing step right on the GPU using CUDA kernels or Thrust (see the sketch below).
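For instance, here is a minimal sketch (not part of the original code) of finding the top-1 class index directly on the GPU with Thrust, so that only a single index, rather than the whole output tensor, crosses to the host:
#include <thrust/device_ptr.h>
#include <thrust/extrema.h>

// returns the index of the largest value in a device buffer of 'size' floats
int argmaxOnGpu(float* gpu_output, size_t size)
{
    thrust::device_ptr<float> ptr(gpu_output);
    // max_element runs on the device; only the resulting offset is used on the host
    auto max_iter = thrust::max_element(ptr, ptr + size);
    return static_cast<int>(max_iter - ptr);
}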
3. Set up the TensorRT logger
To simplify the code, let us use some utilities. TensorRT’s builder and engine require a logger to capture errors, warnings, and other information during the build and inference phases. In our case, we’re only going to print out errors, ignoring warnings. You can get more info from the logger, including conversion steps and optimizations, with Severity::kVERBOSE or just by removing the condition.
class Logger : public nvinfer1::ILogger
{
public:
void log(Severity severity, const char* msg) override {
// remove this 'if' if you need more logged info
if ((severity == Severity::kERROR) || (severity == Severity::kINTERNAL_ERROR)) {
std::cout << msg << "\n";
}
}
} gLogger;
Create a TRTUniquePtr definition for unique pointers to TensorRT’s classes:
// destroy TensorRT objects if something goes wrong
struct TRTDestroy
{
template< class T >
void operator()(T* obj) const
{
if (obj)
{
obj->destroy();
}
}
};
template< class T >
using TRTUniquePtr = std::unique_ptr< T, TRTDestroy >;
Calculate the size of a tensor, given all its dimensions:
size_t getSizeByDim(const nvinfer1::Dims& dims)
{
size_t size = 1;
for (size_t i = 0; i < dims.nbDims; ++i)
{
size *= dims.d[i];
}
return size;
}
Get the class names from the file imagenet_classes.txt:
std::vector< std::string > getClassNames(const std::string& imagenet_classes)
{
std::ifstream classes_file(imagenet_classes);
std::vector< std::string > classes;
if (!classes_file.good())
{
std::cerr << "ERROR: can't read file with classes names.n";
return classes;
}
std::string class_name;
while (std::getline(classes_file, class_name))
{
classes.push_back(class_name);
}
return classes;
}
4. Initialize the model in TensorRT
Now it’s time to parse the ONNX model and initialize the TensorRT Context and Engine. To do that we need to create an instance of Builder. The builder can create a Network and generate an Engine (optimized for your platform/hardware) from this network. When we create the Network, we can define its structure with flags, but in our case it’s enough to use the default flag, which means all tensors have an implicit batch dimension. With the Network definition we can create an instance of the Parser and, finally, parse our ONNX file.
void parseOnnxModel(const std::string& model_path, TRTUniquePtr<nvinfer1::ICudaEngine>& engine,
TRTUniquePtr< nvinfer1::IExecutionContext >& context)
{
TRTUniquePtr< nvinfer1::IBuilder > builder{nvinfer1::createInferBuilder(gLogger)};
TRTUniquePtr< nvinfer1::INetworkDefinition > network{builder->createNetwork()};
TRTUniquePtr< nvonnxparser::IParser > parser{nvonnxparser::createParser(*network, gLogger)};
// parse ONNX
if (!parser->parseFromFile(model_path.c_str(), static_cast< int >(nvinfer1::ILogger::Severity::kINFO)))
{
std::cerr << "ERROR: could not parse the model.\n";
return;
}
It’s possible to configure some engine parameters, such as the maximum memory the TensorRT engine is allowed to use, or to enable FP16 mode. We should also specify the batch size.
TRTUniquePtr< nvinfer1::IBuilderConfig > config{builder->createBuilderConfig()};
// allow TensorRT to use up to 1GB of GPU memory for tactic selection.
config->setMaxWorkspaceSize(1ULL << 30);
// use FP16 mode if possible
if (builder->platformHasFastFp16())
{
config->setFlag(nvinfer1::BuilderFlag::kFP16);
}
// we have only one image in batch
builder->setMaxBatchSize(1);
After that, we can generate the Engine and create the executable Context. The engine takes input data, performs inference, and emits the inference output.
engine.reset(builder->buildEngineWithConfig(*network, *config));
context.reset(engine->createExecutionContext());
}
Tip: Initialization can take a lot of time, because TensorRT tries to find out the best and fastest way to execute your network on your platform. To do it only once and then use the already created engine, you can serialize your engine (a sketch of this is shown below). Serialized engines are not portable across different GPU models, platforms, or TensorRT versions. Engines are specific to the exact hardware and software they were built on. More info can be found here: https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#serial_model_c.
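As a minimal sketch (assuming the gLogger and TRTUniquePtr utilities defined above and TensorRT 6’s API; the helper names and file handling are only an illustration), saving and loading an engine could look like this:
#include <fstream>
#include <iterator>
#include <vector>

// write the serialized engine to disk
void saveEngine(nvinfer1::ICudaEngine& engine, const std::string& file_name)
{
    TRTUniquePtr<nvinfer1::IHostMemory> serialized{engine.serialize()};
    std::ofstream out(file_name, std::ios::binary);
    out.write(static_cast<const char*>(serialized->data()), serialized->size());
}

// read the engine back; the runtime (created with nvinfer1::createInferRuntime(gLogger))
// should stay alive as long as the returned engine is used
TRTUniquePtr<nvinfer1::ICudaEngine> loadEngine(nvinfer1::IRuntime& runtime, const std::string& file_name)
{
    std::ifstream in(file_name, std::ios::binary);
    std::vector<char> bytes((std::istreambuf_iterator<char>(in)), std::istreambuf_iterator<char>());
    return TRTUniquePtr<nvinfer1::ICudaEngine>{runtime.deserializeCudaEngine(bytes.data(), bytes.size(), nullptr)};
}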
5. Main pipeline
So what would the full pipeline look like in C++? Let’s take a look at the main function:
int main(int argc, char* argv[])
{
if (argc < 3)
{
std::cerr << "usage: " << argv[0] << " model.onnx image.jpgn";
return -1;
}
std::string model_path(argv[1]);
std::string image_path(argv[2]);
int batch_size = 1;
Parse the model and initialize the engine and the context:
TRTUniquePtr< nvinfer1::ICudaEngine > engine{nullptr};
TRTUniquePtr< nvinfer1::IExecutionContext > context{nullptr};
parseOnnxModel(model_path, engine, context);
Once we have the initialized engine, we can find out the dimensions of the input and output in our program, and knowing them, we can allocate the memory required for the input and output data. In general, a model can have a bunch of inputs and outputs, but in our case we know that we have only one input and one output. So the vector of buffers that we create to store memory for the input and output has just two elements.
std::vector< nvinfer1::Dims > input_dims; // we expect only one input
std::vector< nvinfer1::Dims > output_dims; // and one output
std::vector< void* > buffers(engine->getNbBindings()); // buffers for input and output data
for (size_t i = 0; i < engine->getNbBindings(); ++i)
{
auto binding_size = getSizeByDim(engine->getBindingDimensions(i)) * batch_size * sizeof(float);
cudaMalloc(&buffers[i], binding_size);
if (engine->bindingIsInput(i))
{
input_dims.emplace_back(engine->getBindingDimensions(i));
}
else
{
output_dims.emplace_back(engine->getBindingDimensions(i));
}
}
if (input_dims.empty() || output_dims.empty())
{
std::cerr << "Expect at least one input and one output for networkn";
return -1;
}
The final step: preprocess the image, do inference, get the results and, of course, free the used memory. That’s all!
// preprocess input data
preprocessImage(image_path, (float*)buffers[0], input_dims[0]);
// inference
context->enqueue(batch_size, buffers.data(), 0, nullptr);
// post-process results
postprocessResults((float *) buffers[1], output_dims[0], batch_size);
for (void* buf : buffers)
{
cudaFree(buf);
}
return 0;
}
6. Build Application
Time to test! To build the application we recommend using CMake. Please download CMakeLists.txt from the provided source files (or write your own). Then build and launch the app:
mkdir build
cd build
cmake -DOpenCV_DIR=[path-to-opencv-build] -DTensorRT_DIR=[path-to-tensorrt] ..
make -j8
trt_sample[.exe] resnet50.onnx turkish_coffee.jpg
For testing purposes, we use the following image:
All results were obtained with the following configuration:
Ubuntu 18.04.4, AMD® Ryzen 7 2700x eight-core processor × 16, GeForce RTX 2070 SUPER, TensorRT 6.0.1.5, OpenCV 4.2, CUDA 10.2
7. Accuracy Test
We did some ad-hoc testing that is summarized in the table below.
| Class | Index | PyTorch | TensorRT: FP32 | TensorRT: FP16 |
| --- | --- | --- | --- | --- |
| cup | 968 | 92.430747% | 92.4308% | 92.2583% |
| espresso | 967 | 6.138075% | 6.13806% | 6.27826% |
| coffee mug | 504 | 0.728557% | 0.728555% | 0.743995% |
As we can see, the predicted classes match. The confidence values are almost the same in FP32 mode (the difference is less than 1e-05). In FP16 mode the error is larger (~0.002), but it is still small enough to get correct predictions.
Keep in mind that there is no guarantee you’ll get the same error in tests with different hardware, software, or even a different input picture. The error depends on the tactics TensorRT selects while building the engine, so it can differ from card to card.
8. Speed-up using TensorRT
To compare timings in PyTorch and TensorRT, we don’t measure the model initialization time, because we initialize the model only once. So we compare the inference time and the detection time (preprocessing + inference + post-processing). At the first launch, CUDA initializes and caches some data, so the first call of any CUDA function is slower than usual. To account for this, we run inference a few times and take the average time. And here is what we get:
In our example, we have achieved 4-6 times speed-up in FP16 mode and 2-3 times speed-up in FP32 mode.
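For reference, here is a minimal sketch (assuming the engine, context, and buffers from the main pipeline above; the helper name and the number of runs are arbitrary) of how the average inference time can be measured:
#include <chrono>

// returns the average time of a single enqueue() call in milliseconds
double measureInferenceMs(nvinfer1::IExecutionContext& context, std::vector<void*>& buffers, int batch_size, int n_runs)
{
    // warm-up: the first CUDA calls are slower because of initialization and caching
    context.enqueue(batch_size, buffers.data(), 0, nullptr);
    cudaDeviceSynchronize();
    auto start = std::chrono::steady_clock::now();
    for (int i = 0; i < n_runs; ++i)
    {
        context.enqueue(batch_size, buffers.data(), 0, nullptr);
    }
    // wait for all asynchronous launches to finish before stopping the timer
    cudaDeviceSynchronize();
    auto end = std::chrono::steady_clock::now();
    return std::chrono::duration<double, std::milli>(end - start).count() / n_runs;
}
A call like measureInferenceMs(*context, buffers, batch_size, 100), placed after the buffer allocation step in main, gives a stable average for comparison with PyTorch.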