PyTorch Automatic Mixed Precision Training

Introduction

Performing neural network training in lower precision, such as half precision, is significantly faster than training in full precision. PyTorch provides an automatic mixed precision (AMP) training interface that automatically handles the data type casting and gradient scaling during training so that the loss in model quality is minimized.
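
For example, inside a torch.autocast region, precision-tolerant CUDA operations such as matrix multiplication automatically execute in float16, even though the input tensors remain float32. The following toy snippet, which is mine rather than from the original training script, illustrates this behavior.

import torch

a = torch.randn(8, 8, device="cuda")
b = torch.randn(8, 8, device="cuda")

# Inside the autocast region, CUDA matrix multiplications are
# performed in float16 automatically.
with torch.autocast(device_type="cuda", dtype=torch.float16):
    c = a @ b

print(a.dtype)  # torch.float32: the inputs are left untouched.
print(c.dtype)  # torch.float16: the product was computed in half precision.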

In this blog post, I would like to demonstrate how to use the PyTorch automatic mixed precision training interface, using the example of training a ResNet50 model on the CIFAR10 dataset.

PyTorch Automatic Mixed Precision Training

Enabling automatic mixed precision training using the torch.amp API is straightforward.

Full Precision Training

The following code shows how a neural network is usually trained in full precision.

for epoch in range(num_epochs):
    for inputs, labels in train_loader:
        # The data transfer takes ~100 ms on Intel i9-9900K + NVIDIA RTX 3090.
        inputs = inputs.to(device)
        labels = labels.to(device)

        # Zero the parameter gradients.
        optimizer.zero_grad()

        # Forward pass and loss computation in full precision (float32).
        outputs = model(inputs)
        loss = criterion(outputs, labels)

        # Backward pass and parameter update, also in float32.
        loss.backward()
        optimizer.step()
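
The loop above assumes that device, model, criterion, optimizer, train_loader, and num_epochs have already been created. The actual setup lives in the GitHub script; the following is only a minimal sketch, in which the transforms and hyperparameters are illustrative choices rather than the original ones.

import torch
import torchvision
import torchvision.transforms as transforms

device = torch.device("cuda:0")

# Illustrative CIFAR10 data pipeline; the normalization statistics
# and batch size are assumptions, not taken from the original script.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465),
                         (0.2470, 0.2435, 0.2616)),
])
train_set = torchvision.datasets.CIFAR10(root="data", train=True,
                                         download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_set, batch_size=128,
                                           shuffle=True, num_workers=2)

# Illustrative model and optimizer; the hyperparameters are placeholders.
model = torchvision.models.resnet50(num_classes=10).to(device)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1,
                            momentum=0.9, weight_decay=5e-4)
num_epochs = 100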

Automatic Mixed Precision Training

The following code shows how to use the PyTorch automatic mixed precision (AMP) training interface to train a neural network in mixed full and half precision.

The use_amp flag can be used to dynamically enable or disable automatic mixed precision training at runtime; a sketch of how it might be exposed as a command-line argument follows the code.

scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
for epoch in range(num_epochs):
    for inputs, labels in train_loader:
        # The data transfer takes ~100 ms on Intel i9-9900K + NVIDIA RTX 3090.
        inputs = inputs.to(device)
        labels = labels.to(device)

        # Zero the parameter gradients.
        optimizer.zero_grad()

        # Run the forward pass under autocast so that eligible
        # operations execute in float16 for faster training.
        with torch.autocast(device_type="cuda",
                            dtype=torch.float16,
                            enabled=use_amp):
            outputs = model(inputs)
            loss = criterion(outputs, labels)

        # Scale the loss to prevent float16 gradients from underflowing,
        # step the optimizer with unscaled gradients, and update the
        # scale factor for the next iteration.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
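
The --use_amp flag passed to the training script in the next section presumably sets this use_amp variable. The original argument parsing is in the GitHub script; a minimal sketch of how such a flag could be defined with argparse:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--model_dir", type=str, default="saved_models")
parser.add_argument("--model_filename", type=str, required=True)
# store_true makes AMP opt-in: use_amp is False unless --use_amp is passed.
parser.add_argument("--use_amp", action="store_true")
args = parser.parse_args()
use_amp = args.use_amp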

ResNet50 Training on CIFAR10 Dataset

The ResNet50 models were trained on the CIFAR10 dataset, with and without automatic mixed precision, on a machine with an Intel Core i9-9900K CPU and an NVIDIA GeForce RTX 3090 GPU. The source code can be found on GitHub.

Full Precision Training

The full precision training took about 17 minutes to complete, with a test accuracy of 87.6%.

$ python train.py --model_dir saved_models --model_filename resnet50_cifar10_fp32.pt
Training Model...
Training Elapsed Time: 00:17:13
Evaluating Model...
Test Accuracy: 0.876

Automatic Mixed Precision Training

The automatic mixed precision training took about 12.5 minutes to complete, with a test accuracy of 88.7%.

$ python train.py --model_dir saved_models --model_filename resnet50_cifar10_amp.pt --use_amp
Training Elapsed Time: 00:12:29
Evaluating Model...
Test Accuracy: 0.887
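
A note on the evaluation step: GradScaler exists only to keep float16 gradients from underflowing during the backward pass, so it is not needed at inference time. A sketch of an evaluation loop, reusing the names from the sketches above and assuming a test_loader; the actual evaluation code is in the GitHub script:

correct = 0
total = 0
model.eval()
with torch.no_grad():
    for inputs, labels in test_loader:
        inputs = inputs.to(device)
        labels = labels.to(device)
        # No gradients are computed here, so autocast alone is
        # sufficient and no gradient scaling is required.
        with torch.autocast(device_type="cuda", dtype=torch.float16,
                            enabled=use_amp):
            outputs = model(inputs)
        correct += (outputs.argmax(dim=1) == labels).sum().item()
        total += labels.size(0)
print(f"Test Accuracy: {correct / total:.3f}")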

Conclusions

The PyTorch AMP accelerated the ResNet50 model training on the CIFAR10 dataset by roughly 1.4x (17 min 13 s versus 12 min 29 s) with no loss of accuracy; in fact, the test accuracy improved slightly, from 0.876 to 0.887.

Author: Lei Mao
Posted on: 06-08-2024
Updated on: 06-08-2024
