PyTorch Automatic Mixed Precision Training
Introduction
Performing neural network training in lower precision, such as half precision, is significantly faster than training in full precision. PyTorch provides an automatic mixed precision (AMP) training interface that automatically handles the data type casting and gradient scaling during training so that the loss in numerical precision and model quality is minimized.
In this blog post, I would like to demonstrate how to use the PyTorch automatic mixed precision training interface by training a ResNet50 model on the CIFAR10 dataset.
PyTorch Automatic Mixed Precision Training
Enabling automatic mixed precision training using the torch.amp API is straightforward.
Full Precision Training
The following code shows how a neural network is usually trained in full precision.
```python
for epoch in range(num_epochs):
```
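A minimal, runnable version of such a full precision training loop is sketched below. The toy model, data, and hyperparameters are illustrative placeholders, not the ResNet50/CIFAR10 training code used for the results later in this post.

```python
import torch
import torch.nn as nn

# Toy setup so the sketch runs end to end; stand-ins for the real
# ResNet50/CIFAR10 model, data loader, and hyperparameters.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = nn.Linear(32, 10).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
train_loader = [(torch.randn(8, 32), torch.randint(0, 10, (8,))) for _ in range(4)]
num_epochs = 2

for epoch in range(num_epochs):
    for inputs, targets in train_loader:
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        # Forward and backward passes run entirely in float32.
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
```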
Automatic Mixed Precision Training
The following code shows how to use the PyTorch automatic mixed precision (AMP) training interface to train a neural network in mixed full and half precision. The use_amp flag can be used to dynamically enable or disable automatic mixed precision training at runtime.
```python
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
```
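Building on the GradScaler line above, a minimal mixed precision training loop might look like the following sketch. It reuses the toy setup from the full precision sketch above, and the `use_amp` flag is the same runtime toggle described earlier.

```python
import torch

# Reuses `model`, `optimizer`, `criterion`, `train_loader`, `device`, and
# `num_epochs` from the toy full precision sketch above.
use_amp = True

# The gradient scaler becomes a no-op when AMP is disabled.
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

for epoch in range(num_epochs):
    for inputs, targets in train_loader:
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        # Run the forward pass with selected ops in float16 when AMP is enabled.
        with torch.cuda.amp.autocast(enabled=use_amp):
            outputs = model(inputs)
            loss = criterion(outputs, targets)
        # Scale the loss before backward to avoid float16 gradient underflow,
        # then unscale and step; steps with inf/NaN gradients are skipped.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
```

Because both GradScaler and autocast accept enabled=use_amp, setting use_amp to False makes this loop fall back to ordinary full precision training.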
ResNet50 Training on CIFAR10 Dataset
The ResNet50 models were trained on the CIFAR10 dataset, with and without automatic mixed precision, on a machine with an Intel Core i9-9900K CPU and an NVIDIA GeForce RTX 3090 GPU. The source code can be found on GitHub.
Full Precision Training
The full precision training took 17 minutes to complete with a test accuracy of 87.6%.
```bash
$ python train.py --model_dir saved_models --model_filename resnet50_cifar10_fp32.pt
```
Automatic Mixed Precision Training
The automatic mixed precision training took 12 minutes to complete with a test accuracy of 88.7%.
```bash
$ python train.py --model_dir saved_models --model_filename resnet50_cifar10_amp.pt --use_amp
```
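As a rough, hypothetical illustration of how the --use_amp flag could be wired up (the actual train.py may be organized differently), such command line arguments could be parsed and forwarded to the AMP utilities like this:

```python
import argparse

# Hypothetical argument parsing consistent with the commands above;
# only the flag names come from the commands, the rest is assumed.
parser = argparse.ArgumentParser(description="Train ResNet50 on CIFAR10.")
parser.add_argument("--model_dir", type=str, default="saved_models")
parser.add_argument("--model_filename", type=str, required=True)
parser.add_argument("--use_amp", action="store_true",
                    help="Enable automatic mixed precision training.")
args = parser.parse_args()

# `args.use_amp` can then be passed as the `use_amp` toggle in the
# mixed precision training loop shown earlier.
```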
Conclusions
The PyTorch AMP accelerated ResNet50 training on the CIFAR10 dataset by roughly 1.4x (17 minutes versus 12 minutes) without any loss of accuracy.