ResNet CIFAR Classification Using LibTorch C++ API

Introduction

The LibTorch C++ API has been around for a while. Although it claims to mimic the PyTorch Python API as closely as possible, and tries to give users a programming experience similar to that of the PyTorch Python API, it has not gained significant popularity over the years.

In this blog post, I implemented the ResNet family of neural networks, the CIFAR dataset, and a training and inference program for ResNet classification on the CIFAR dataset using the LibTorch C++ API. I would also like to briefly discuss my user experience with the LibTorch C++ API.

ResNet-CIFAR Classification

The LibTorch C++ API only provides the common building-block interfaces for neural networks and data. Probably because there are fewer community contributions, higher-level implementations of neural networks and datasets, such as ResNet and the CIFAR dataset, are not available. So I had to implement ResNet and the CIFAR dataset from scratch using the most basic building-block APIs.

The entire implementation can be found on my GitHub.

Dataset

Implementing the CIFAR dataset is not difficult, provided that the user is familiar with data I/O in C++. LibTorch also ships an implementation of the MNIST dataset, torch::data::datasets::MNIST, that we can mimic.

The MNIST and CIFAR datasets are stored as binary files, which makes reading them a lot easier. However, if a dataset is stored in other formats, such as JPEG, we would have to rely on additional dependencies, such as OpenCV or LibJPEG, to read the data, and depending on the dependency, the API might be complicated.
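
To illustrate why the binary layout is easy to read, here is a minimal sketch, independent of the repository mentioned above, that splits an in-memory CIFAR-10 binary buffer into records. It assumes the standard CIFAR-10 binary format: each record is 1 label byte followed by 3072 pixel bytes (the 1024 red values, then 1024 green, then 1024 blue, of a 32x32 image in row-major order). The names CifarRecord, parse_cifar10_records and pixel_at are hypothetical.

```cpp
#include <algorithm>
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <vector>

// Layout constants of the CIFAR-10 binary format.
constexpr std::size_t kImageSize = 32;
constexpr std::size_t kChannels = 3;
constexpr std::size_t kPixelBytes = kChannels * kImageSize * kImageSize;  // 3072
constexpr std::size_t kRecordBytes = 1 + kPixelBytes;                     // 3073

// Hypothetical record type used only for this sketch.
struct CifarRecord {
  std::uint8_t label;                            // class index in [0, 9]
  std::array<std::uint8_t, kPixelBytes> pixels;  // CHW: R plane, G plane, B plane
};

// Splits a raw buffer (e.g. the bytes of data_batch_1.bin) into fixed-size
// records. Throws if the buffer is not a whole number of records.
std::vector<CifarRecord> parse_cifar10_records(const std::vector<std::uint8_t>& buffer) {
  if (buffer.size() % kRecordBytes != 0) {
    throw std::runtime_error("buffer size is not a multiple of the record size");
  }
  std::vector<CifarRecord> records(buffer.size() / kRecordBytes);
  for (std::size_t i = 0; i < records.size(); ++i) {
    const std::uint8_t* record = buffer.data() + i * kRecordBytes;
    records[i].label = record[0];
    std::copy(record + 1, record + kRecordBytes, records[i].pixels.begin());
  }
  return records;
}

// Returns pixel (row, col) of channel c (0 = R, 1 = G, 2 = B).
std::uint8_t pixel_at(const CifarRecord& r, std::size_t c, std::size_t row, std::size_t col) {
  assert(c < kChannels && row < kImageSize && col < kImageSize);
  return r.pixels[c * kImageSize * kImageSize + row * kImageSize + col];
}
```

In a LibTorch dataset class modeled on torch::data::datasets::MNIST, each record's pixels could then be copied into a {3, 32, 32} tensor, for example with torch::from_blob followed by clone().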

DataLoader

Some key data augmentation transforms, such as random flip and random crop, are also missing from the LibTorch C++ library.
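
Since these transforms have to be written by hand, here is a minimal sketch of a random horizontal flip and the pad-then-random-crop augmentation commonly used for CIFAR (zero-pad by a few pixels, then crop back to the original size at a random offset). It works on a raw CHW float buffer rather than a torch::Tensor so that it compiles without LibTorch; the Image struct and all function names are hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <random>
#include <vector>

// Minimal CHW image container used only for this sketch.
struct Image {
  std::size_t channels, height, width;
  std::vector<float> data;  // size = channels * height * width, CHW order
  float& at(std::size_t c, std::size_t h, std::size_t w) { return data[(c * height + h) * width + w]; }
  float at(std::size_t c, std::size_t h, std::size_t w) const { return data[(c * height + h) * width + w]; }
};

// Mirrors the image left to right.
Image horizontal_flip(const Image& in) {
  Image out = in;
  for (std::size_t c = 0; c < in.channels; ++c)
    for (std::size_t h = 0; h < in.height; ++h)
      for (std::size_t w = 0; w < in.width; ++w)
        out.at(c, h, w) = in.at(c, h, in.width - 1 - w);
  return out;
}

// Zero-pads by `pad` pixels on every side, then crops a window of the original
// size whose top-left corner is (top, left) in padded coordinates.
Image pad_and_crop(const Image& in, std::size_t pad, std::size_t top, std::size_t left) {
  assert(top <= 2 * pad && left <= 2 * pad);
  Image out{in.channels, in.height, in.width, std::vector<float>(in.data.size(), 0.0f)};
  for (std::size_t c = 0; c < in.channels; ++c)
    for (std::size_t h = 0; h < in.height; ++h)
      for (std::size_t w = 0; w < in.width; ++w) {
        // Source position in the unpadded image; outside means zero padding.
        const auto src_h = static_cast<std::ptrdiff_t>(h + top) - static_cast<std::ptrdiff_t>(pad);
        const auto src_w = static_cast<std::ptrdiff_t>(w + left) - static_cast<std::ptrdiff_t>(pad);
        if (src_h >= 0 && src_h < static_cast<std::ptrdiff_t>(in.height) &&
            src_w >= 0 && src_w < static_cast<std::ptrdiff_t>(in.width)) {
          out.at(c, h, w) = in.at(c, static_cast<std::size_t>(src_h), static_cast<std::size_t>(src_w));
        }
      }
  return out;
}

// Random flip with probability 0.5, followed by a random padded crop.
Image random_augment(const Image& in, std::size_t pad, std::mt19937& rng) {
  std::bernoulli_distribution flip(0.5);
  std::uniform_int_distribution<std::size_t> offset(0, 2 * pad);
  const Image img = flip(rng) ? horizontal_flip(in) : in;
  const std::size_t top = offset(rng);
  const std::size_t left = offset(rng);
  return pad_and_crop(img, pad, top, left);
}
```

With LibTorch tensors, the same operations could be expressed roughly with torch::flip along the width dimension and torch::constant_pad_nd followed by narrow, applied per sample inside a custom dataset's get() method.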

Module

Because we are used to PyTorch's simple, well-abstracted neural network module creation, creating a LibTorch neural network class feels somewhat awkward, especially when we try to minimize the differences between the PyTorch implementation and the LibTorch implementation.

It is generally recommended to create the LibTorch neural network module using the following pattern.

#include <torch/torch.h>

struct LinearImpl : torch::nn::Module {
  LinearImpl(int64_t in, int64_t out);

  torch::Tensor forward(const torch::Tensor& input);

  torch::Tensor weight, bias;
};

TORCH_MODULE(Linear); // Linear is now a wrapper over std::shared_ptr<LinearImpl>.

All of LibTorch's native building blocks are created this way. For example, torch::nn::Conv2d is just a wrapper over std::shared_ptr<torch::nn::Conv2dImpl>.

However, if the actual module implementation is a class template, I have no idea how to “register” it using TORCH_MODULE.

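template <typename T>
struct LinearImpl : torch::nn::Module {
  LinearImpl(int64_t in, int64_t out);

  torch::Tensor forward(const torch::Tensor& input);

  torch::Tensor weight, bias;
};

// TORCH_MODULE(Linear); would not compile here: LinearImpl is now a class template, not a type.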
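
One possible workaround, offered as an untested assumption rather than a documented LibTorch recipe, is to skip the macro and write by hand the class that TORCH_MODULE would generate, but as a class template that derives from the holder of LinearImpl<T> and inherits its constructors. The sketch below imitates this in plain C++ so it compiles without LibTorch: SimpleHolder is a hypothetical, heavily simplified stand-in for torch::nn::ModuleHolder, and ScaleImpl and Scale are toy names. It also shows the shared-pointer semantics described above: copying the wrapper copies the handle, not the module.

```cpp
#include <cassert>
#include <memory>
#include <type_traits>
#include <utility>

// Hypothetical, heavily simplified stand-in for torch::nn::ModuleHolder<Impl>:
// it owns the module through a std::shared_ptr and forwards calls to it.
template <typename Impl>
class SimpleHolder {
 public:
  // Constructs the Impl in place from any arguments (holder copies excluded).
  template <typename First, typename... Rest,
            typename = std::enable_if_t<!std::is_base_of_v<SimpleHolder, std::decay_t<First>>>>
  explicit SimpleHolder(First&& first, Rest&&... rest)
      : impl_(std::make_shared<Impl>(std::forward<First>(first), std::forward<Rest>(rest)...)) {}

  Impl* operator->() const { return impl_.get(); }
  const std::shared_ptr<Impl>& ptr() const { return impl_; }

  // module(x) calls module->forward(x), mirroring the LibTorch call syntax.
  template <typename... Args>
  auto operator()(Args&&... args) const {
    return impl_->forward(std::forward<Args>(args)...);
  }

 private:
  std::shared_ptr<Impl> impl_;
};

// A templated toy "module", playing the role of LinearImpl<T>.
template <typename T>
struct ScaleImpl {
  explicit ScaleImpl(T factor) : factor(factor) {}
  T forward(T x) const { return factor * x; }
  T factor;
};

// Hand-written equivalent of TORCH_MODULE for a class template. With LibTorch,
// the analogous (untested) form would derive from
// torch::nn::ModuleHolder<LinearImpl<T>> and inherit its constructors.
template <typename T>
class Scale : public SimpleHolder<ScaleImpl<T>> {
 public:
  using SimpleHolder<ScaleImpl<T>>::SimpleHolder;
};
```

The inheriting-constructor line is the part the macro normally writes for you; whether the real ModuleHolder accepts a templated Impl this way should be verified against the installed LibTorch version.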

Creating building blocks is not straightforward either, as every building block has its own options data structure. In PyTorch, we can pass primitive types as arguments to create building blocks. In LibTorch, by contrast, to create a relatively sophisticated torch::nn::Conv2d, we have to study torch::nn::Conv2dOptions and use it to construct a torch::nn::Conv2d object.
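
To show what options-based construction looks like, here is a plain C++ imitation of the pattern: required arguments go to the constructor, and optional ones are set through chainable setters that return *this. ConvOptions and conv_output_size are hypothetical names, not LibTorch types. The helper applies the standard convolution output-size formula without dilation, floor((H + 2p - k) / s) + 1, which is what the padding and stride options ultimately control.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical imitation of an options struct such as torch::nn::Conv2dOptions.
class ConvOptions {
 public:
  ConvOptions(std::int64_t in_channels, std::int64_t out_channels, std::int64_t kernel_size)
      : in_channels_(in_channels), out_channels_(out_channels), kernel_size_(kernel_size) {}

  // Chainable setters for the optional arguments.
  ConvOptions& stride(std::int64_t value) { stride_ = value; return *this; }
  ConvOptions& padding(std::int64_t value) { padding_ = value; return *this; }
  ConvOptions& bias(bool value) { bias_ = value; return *this; }

  // Getters share the setters' names, as in the LibTorch options classes.
  std::int64_t in_channels() const { return in_channels_; }
  std::int64_t out_channels() const { return out_channels_; }
  std::int64_t kernel_size() const { return kernel_size_; }
  std::int64_t stride() const { return stride_; }
  std::int64_t padding() const { return padding_; }
  bool bias() const { return bias_; }

 private:
  std::int64_t in_channels_, out_channels_, kernel_size_;
  std::int64_t stride_ = 1;
  std::int64_t padding_ = 0;
  bool bias_ = true;
};

// Spatial output size of a convolution without dilation:
// floor((input + 2 * padding - kernel_size) / stride) + 1.
std::int64_t conv_output_size(std::int64_t input, const ConvOptions& o) {
  assert(input + 2 * o.padding() >= o.kernel_size() && o.stride() > 0);
  return (input + 2 * o.padding() - o.kernel_size()) / o.stride() + 1;
}
```

In LibTorch itself, the corresponding call is along the lines of `torch::nn::Conv2d conv(torch::nn::Conv2dOptions(3, 64, 3).stride(1).padding(1).bias(false));`, whereas PyTorch accepts `nn.Conv2d(3, 64, 3, stride=1, padding=1, bias=False)` directly.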

Optimization

The optimization support in LibTorch is also limited. The user does not have many choices of optimizers, learning rate schedulers, loss functions, etc. Writing auxiliary functions for training and evaluation relies heavily on auto and templates because the types of many LibTorch objects, such as data loaders, are not obvious.
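
When a scheduler is missing, the learning rate can be computed by hand each epoch. Below is a minimal sketch of a multi-step decay schedule, assuming the semantics of PyTorch's MultiStepLR: the base learning rate is multiplied by gamma once for every milestone the current epoch has reached. The function name multistep_lr is hypothetical.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical helper: learning rate at `epoch` under multi-step decay.
double multistep_lr(double base_lr, const std::vector<std::size_t>& milestones,
                    double gamma, std::size_t epoch) {
  double lr = base_lr;
  for (std::size_t milestone : milestones) {
    if (epoch >= milestone) lr *= gamma;  // one decay per milestone reached
  }
  return lr;
}
```

The resulting value would then have to be written into the optimizer manually, for example by iterating over optimizer.param_groups() and casting each group's options() to torch::optim::SGDOptions to call lr(new_lr); the exact options API should be checked against the LibTorch version in use.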

Portability

Currently, LibTorch does not support saving a module state dict. Therefore, a model saved from LibTorch can only be used in LibTorch, not in PyTorch. Similarly, a model saved from PyTorch cannot be used in LibTorch either. LibTorch supports loading a JIT model saved from PyTorch; however, it cannot export a JIT model.

Performance

Although I did not do a rigorous apples-to-apples comparison, I did not notice any significant performance improvement when using the LibTorch API to train a model or run inference.

Miscellaneous

Looking up LibTorch APIs is inconvenient, as most of the APIs are not documented. The user has to keep two browser windows open side by side: one showing the LibTorch API documentation for the API signature, and the other showing the PyTorch API documentation for the meaning of the arguments and the functionality of the API.

Conclusions

Although I believe LibTorch adds more useful APIs with every release, I would not recommend using it except for very special use cases. The performance gain is negligible compared to the programming inconvenience. Unless the LibTorch C++ API someday becomes 10x faster than the PyTorch Python API, I will not consider using it. Such a performance difference will certainly never happen, because at the lower level PyTorch itself uses the LibTorch C++ library.

References

ResNet CIFAR Classification Using LibTorch C++ API

https://leimao.github.io/blog/LibTorch-ResNet-CIFAR/

Author

Lei Mao

Posted on

07-01-2021

Updated on

07-01-2021
