ResNet CIFAR Classification Using LibTorch C++ API
Introduction
The LibTorch C++ API has been around for a while. Although LibTorch claims to mimic the PyTorch Python API as closely as possible, so that users get a similar experience when programming with either, it has not gained significant popularity over the years.
For this blog post, I implemented the ResNet neural network family, a CIFAR dataset reader, and a training and inference program for ResNet classification on the CIFAR dataset. I would like to briefly discuss my experience with the LibTorch C++ API.
ResNet-CIFAR Classification
The LibTorch C++ API only provides the common building-block interfaces for neural networks and data. Probably because there are fewer community contributions, relatively high-level implementations such as ResNet models and the CIFAR dataset are not available. So I had to implement ResNet and the CIFAR dataset from scratch using the most basic building-block APIs.
The entire implementation can be found on my GitHub.
Dataset
Implementing the CIFAR dataset is not difficult, provided that the user is familiar with data I/O in C++. LibTorch also ships an implementation of the MNIST dataset that we can mimic.
The MNIST and CIFAR datasets are stored as binary files, which makes reading them a lot easier. However, if a dataset is stored in other formats, such as JPEG, we have to rely on other dependencies, such as OpenCV or libjpeg, to read the data, and depending on the dependency, the API might be complicated.
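To show why the binary format is easy to read: the CIFAR-10 binary distribution (data_batch_1.bin through data_batch_5.bin, plus test_batch.bin) consists of fixed-size records, each one label byte followed by 3072 pixel bytes (1024 red values, then 1024 green, then 1024 blue, each plane a row-major 32x32 image). A minimal decoding sketch with only the standard library; the names read_bytes, decode_cifar10, and CifarBatch are hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <fstream>
#include <iterator>
#include <string>
#include <vector>

// CIFAR-10 binary record: 1 label byte + 3 * 32 * 32 pixel bytes (R, G, B planes).
constexpr std::size_t kImageBytes = 3 * 32 * 32;
constexpr std::size_t kRecordBytes = 1 + kImageBytes;

struct CifarBatch {
  std::vector<std::int64_t> labels;  // one class index (0-9) per image
  std::vector<float> pixels;         // N * 3 * 32 * 32 values in [0, 1], CHW per image
};

// Read a whole binary file (e.g. data_batch_1.bin) into memory.
std::vector<std::uint8_t> read_bytes(const std::string& path) {
  std::ifstream file(path, std::ios::binary);
  return std::vector<std::uint8_t>(std::istreambuf_iterator<char>(file),
                                   std::istreambuf_iterator<char>());
}

// Decode a buffer that holds a whole number of CIFAR-10 records.
CifarBatch decode_cifar10(const std::vector<std::uint8_t>& buffer) {
  assert(buffer.size() % kRecordBytes == 0 && "buffer must hold whole records");
  const std::size_t n = buffer.size() / kRecordBytes;
  CifarBatch batch;
  batch.labels.reserve(n);
  batch.pixels.reserve(n * kImageBytes);
  for (std::size_t i = 0; i < n; ++i) {
    const std::uint8_t* record = buffer.data() + i * kRecordBytes;
    batch.labels.push_back(record[0]);
    // The planar R/G/B layout already matches the CHW layout LibTorch uses.
    for (std::size_t j = 0; j < kImageBytes; ++j) {
      batch.pixels.push_back(static_cast<float>(record[1 + j]) / 255.0f);
    }
  }
  return batch;
}
```

The vectors can then be turned into tensors with torch::from_blob(batch.pixels.data(), {n, 3, 32, 32}).clone() and torch::from_blob(batch.labels.data(), {n}, torch::kLong).clone(), where clone() makes the tensors own their memory. CIFAR-100 records differ slightly: they carry two label bytes (coarse and fine) before the pixels.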
DataLoader
Some key data augmentations, such as flip and random crop, are also missing from the LibTorch C++ library.
Module
Because we are used to the simple, abstracted module creation in PyTorch, creating a LibTorch neural network class feels somewhat awkward, especially when we try to minimize the differences between the PyTorch and LibTorch implementations.
It is generally recommended to create the LibTorch neural network module using the following pattern.
```cpp
struct LinearImpl : torch::nn::Module { /* constructor, forward(), parameters */ };
TORCH_MODULE(Linear);
```
All of LibTorch's native building blocks are created this way. For example, torch::nn::Conv2d is just a wrapper around std::shared_ptr<torch::nn::Conv2dImpl>.
However, if the actual module implementation is a templated class, I have no idea how to “register” it using TORCH_MODULE.
```cpp
template <typename T>
// ...
```
Creating building blocks is not straightforward either, as every building block has its own options data structure. In PyTorch, we can pass primitive types as arguments to create building blocks. In LibTorch, to create a relatively sophisticated torch::nn::Conv2d, we have to study torch::nn::Conv2dOptions and use it to construct a torch::nn::Conv2d object.
Optimization
The optimization capability in LibTorch is also limited. The user does not have many choices of optimizers, learning rate schedulers, loss functions, etc. Writing auxiliary functions for training and evaluation relies heavily on auto and templates because the types of many LibTorch objects are not obvious.
Portability
Currently, LibTorch does not support saving a module state dict. Therefore, a model saved from LibTorch can only be used in LibTorch, not in PyTorch. Similarly, a model saved from PyTorch cannot be used in LibTorch either. LibTorch supports loading a JIT model saved from PyTorch, but it cannot export a JIT model.
Performance
Although I did not do a rigorous apples-to-apples comparison, I did not observe any significant performance improvement when using the LibTorch API for training or inference.
Miscellaneous
Looking up LibTorch APIs is not convenient, as most of them are not documented. The user has to keep two browser windows side by side: one with the LibTorch API documentation for the API signature, and the other with the PyTorch API documentation for the meaning of the arguments and the functionality of the API.
Conclusions
Although I believe LibTorch adds more useful APIs with every release, I would not recommend using it except for very special use cases. The performance gain is negligible compared with the programming inconvenience. Unless the LibTorch C++ API someday becomes 10x faster than the PyTorch Python API, I will not consider using it, and such a performance difference will certainly not be possible, because PyTorch itself uses the LibTorch C++ API at the lower level.
References
ResNet CIFAR Classification Using LibTorch C++ API