PyTorch Distributed Training

Introduction

PyTorch has a relatively simple interface for distributed training. To train a model in a distributed fashion, the model just has to be wrapped with DistributedDataParallel, and the training script just has to be launched with torch.distributed.launch. Although PyTorch offers a series of tutorials on distributed training, I found them either insufficient or overwhelming for beginners who want to do state-of-the-art PyTorch distributed training. Some key details were missing, and the usage of Docker containers in distributed training was not mentioned at all.

In this blog post, I present a simple implementation of PyTorch distributed training for CIFAR-10 classification using a DistributedDataParallel-wrapped ResNet model. The usage of Docker containers for distributed training and how to launch distributed training with torch.distributed.launch are also covered.
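Before going through the full example, here is a minimal sketch of the wrapping step described above, assuming one process per GPU launched via torch.distributed.launch (the helper name setup_ddp is purely illustrative and not part of the training script below):

import argparse

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel

def setup_ddp(model):

    # torch.distributed.launch passes --local_rank to each process; it
    # identifies the GPU this process should use on its node.
    parser = argparse.ArgumentParser()
    parser.add_argument("--local_rank", type=int)
    args = parser.parse_args()

    # Initialize the process group; NCCL is the usual backend for GPU training.
    dist.init_process_group(backend="nccl")

    device = torch.device("cuda:{}".format(args.local_rank))
    model = model.to(device)
    # After wrapping, gradients are averaged across all processes during backward().
    ddp_model = DistributedDataParallel(model, device_ids=[args.local_rank], output_device=args.local_rank)

    return ddp_model, device

Once wrapped, the model is trained with an ordinary training loop; the gradient synchronization across processes happens inside backward().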

Examples

Source Code

The entire training script is about 150 lines of code. Most of it should be easy to follow.

import torch
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim

import torchvision
import torchvision.transforms as transforms

import argparse
import os
import random
import numpy as np

def set_random_seeds(random_seed=0):

    torch.manual_seed(random_seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(random_seed)
    random.seed(random_seed)

def evaluate(model, device, test_loader):

    model.eval()

    correct = 0
    total = 0
    with torch.no_grad():
        for data in test_loader:
            images, labels = data[0].to(device), data[1].to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = correct / total

    return accuracy

def main():

    num_epochs_default = 10000
    batch_size_default = 256 # 1024
    learning_rate_default = 0.1
    random_seed_default = 0
    model_dir_default = "saved_models"
    model_filename_default = "resnet_distributed.pth"

    # Each process runs on 1 GPU device specified by the local_rank argument.
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("--local_rank", type=int, help="Local rank. Necessary for using the torch.distributed.launch utility.")
    parser.add_argument("--num_epochs", type=int, help="Number of training epochs.", default=num_epochs_default)
    parser.add_argument("--batch_size", type=int, help="Training batch size for one process.", default=batch_size_default)
    parser.add_argument("--learning_rate", type=float, help="Learning rate.", default=learning_rate_default)
    parser.add_argument("--random_seed", type=int, help="Random seed.", default=random_seed_default)
    parser.add_argument("--model_dir", type=str, help="Directory for saving models.", default=model_dir_default)
    parser.add_argument("--model_filename", type=str, help="Model filename.", default=model_filename_default)
    parser.add_argument("--resume", action="store_true", help="Resume training from saved checkpoint.")
    argv = parser.parse_args()

    local_rank = argv.local_rank
    num_epochs = argv.num_epochs
    batch_size = argv.batch_size
    learning_rate = argv.learning_rate
    random_seed = argv.random_seed
    model_dir = argv.model_dir
    model_filename = argv.model_filename
    resume = argv.resume

    # Create directories outside the PyTorch program
    # Do not create the directory here because it is not multiprocess safe
    '''
    if not os.path.exists(model_dir):
        os.makedirs(model_dir)
    '''

    model_filepath = os.path.join(model_dir, model_filename)

    # We need to use seeds to make sure that the models initialized in different processes are the same
    set_random_seeds(random_seed=random_seed)

    # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
    torch.distributed.init_process_group(backend="nccl")
    # torch.distributed.init_process_group(backend="gloo")

    # Encapsulate the model on the GPU assigned to the current process
    model = torchvision.models.resnet18(pretrained=False)

    device = torch.device("cuda:{}".format(local_rank))
    model = model.to(device)
    ddp_model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank], output_device=local_rank)

    # We only save the model which uses device "cuda:0"
    # To resume, the device for the saved model would also be "cuda:0"
    if resume == True:
        map_location = {"cuda:0": "cuda:{}".format(local_rank)}
        ddp_model.load_state_dict(torch.load(model_filepath, map_location=map_location))

    # Prepare dataset and dataloader
    transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])

    # Data should be prefetched
    # Download should be set to False, because it is not multiprocess safe
    train_set = torchvision.datasets.CIFAR10(root="data", train=True, download=False, transform=transform)
    test_set = torchvision.datasets.CIFAR10(root="data", train=False, download=False, transform=transform)

    # Restricts data loading to a subset of the dataset exclusive to the current process
    train_sampler = DistributedSampler(dataset=train_set)

    train_loader = DataLoader(dataset=train_set, batch_size=batch_size, sampler=train_sampler, num_workers=8)
    # Test loader does not have to follow the distributed sampling strategy
    test_loader = DataLoader(dataset=test_set, batch_size=128, shuffle=False, num_workers=8)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=learning_rate, momentum=0.9, weight_decay=1e-5)

    # Loop over the dataset multiple times
    for epoch in range(num_epochs):

        print("Local Rank: {}, Epoch: {}, Training ...".format(local_rank, epoch))

        # Save and evaluate the model routinely
        if epoch % 10 == 0:
            if local_rank == 0:
                accuracy = evaluate(model=ddp_model, device=device, test_loader=test_loader)
                torch.save(ddp_model.state_dict(), model_filepath)
                print("-" * 75)
                print("Epoch: {}, Accuracy: {}".format(epoch, accuracy))
                print("-" * 75)

        ddp_model.train()

        for data in train_loader:
            inputs, labels = data[0].to(device), data[1].to(device)
            optimizer.zero_grad()
            outputs = ddp_model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

if __name__ == "__main__":

    main()

Caveats

The caveats are as follows:

  • Add a --local_rank argument to argparse if torch.distributed.launch is used to launch distributed training.
  • Set the random seed to make sure that the models initialized in different processes are the same. (Update on 3/19/2021: PyTorch DistributedDataParallel now makes sure that the initial model states are the same across processes, so the purpose of setting the random seed becomes reproducing the distributed training.)
  • Use DistributedDataParallel to wrap the model for distributed training.
  • Use DistributedSampler for the training data loader.
  • When saving models, each node keeps a copy of the checkpoint file on its local hard drive.
  • Downloading the dataset and creating directories should be avoided in the distributed training program, as they are not multi-process safe, unless we use some sort of barrier, such as torch.distributed.barrier (see the sketch after this list).
  • The node communication bandwidth is extremely important for multi-node distributed training. Instead of randomly picking two computers on the network, try to use nodes from a specialized computing cluster, since the communication between those nodes is highly optimized.
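For the download caveat, one common pattern is to let only the first process on each node download the data while the other processes wait at a barrier. The following is a minimal sketch under the assumptions of the script above (download_cifar10_once is an illustrative helper name, and the process group is assumed to be initialized already):

import torch.distributed as dist
import torchvision

def download_cifar10_once(local_rank, transform):

    if local_rank != 0:
        # Processes with non-zero local rank wait here until the download has finished.
        dist.barrier()
    # Only the first process on each node actually downloads the data.
    train_set = torchvision.datasets.CIFAR10(root="data", train=True, download=(local_rank == 0), transform=transform)
    if local_rank == 0:
        # The downloading process releases the waiting ones once the data is on disk.
        dist.barrier()
    return train_set

In the blog post itself, the simpler route is taken: the dataset is downloaded and the model directory is created on each node before the training program starts.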

Launching Distributed Training

In this particular experiment, I tested the program using two nodes. Each node has 8 GPUs, and each GPU runs one process. We also used Docker containers to make sure that the environments are exactly the same and reproducible.
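Since all 16 processes participate in gradient synchronization, it can be helpful to verify the process topology once the process group has been initialized. The snippet below is a small sanity check under the same assumptions as the training script (print_process_info is just an illustrative helper):

import torch.distributed as dist

def print_process_info(local_rank):

    # dist.get_rank() returns the global rank (0-15 for 2 nodes x 8 GPUs),
    # while the --local_rank argument (0-7) only identifies the GPU within one node.
    global_rank = dist.get_rank()
    world_size = dist.get_world_size()
    print("Global Rank: {}, Local Rank: {}, World Size: {}".format(global_rank, local_rank, world_size))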

Docker Container

To start the Docker container, we copy the training script to each node in the distributed system and run the following command in the terminal of each node.

$ docker run -it --gpus all --rm -v $(pwd):/mnt --network=host pytorch:1.5.0

Here, pytorch:1.5.0 is a Docker image with PyTorch 1.5.0 installed (we could also use NVIDIA’s PyTorch NGC image), and --network=host makes sure that the distributed network communication between nodes is not blocked by Docker containerization.

Preparations

Download the dataset on each node before starting distributed training.

$ mkdir -p data
$ cd data
$ wget -c --quiet https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
$ tar -xvzf cifar-10-python.tar.gz

Create a directory for saving models before starting distributed training.

$ mkdir -p saved_models

Training From Scratch

In the Docker terminal of the first node, we run the following command.

$ python -m torch.distributed.launch --nproc_per_node=8 --nnodes=2 --node_rank=0 --master_addr="192.168.0.1" --master_port=1234 resnet_ddp.py

Here, 192.168.0.1 is the IP address of the first node.

In the Docker terminal of the second node, we run the following command.

$ python -m torch.distributed.launch --nproc_per_node=8 --nnodes=2 --node_rank=1 --master_addr="192.168.0.1" --master_port=1234 resnet_ddp.py

Note that the only difference between the two commands is --node_rank.

There will always be some delay between executing the two commands on the two nodes, but don’t worry: the first node waits for the second node during process group initialization, and then they start training together.

The following messages are expected from the two nodes.

*****************************************
Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
*****************************************
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Local Rank: 5, Epoch: 0, Training ...
Local Rank: 4, Epoch: 0, Training ...
Local Rank: 3, Epoch: 0, Training ...
Local Rank: 7, Epoch: 0, Training ...
Local Rank: 6, Epoch: 0, Training ...
Local Rank: 2, Epoch: 0, Training ...
Local Rank: 1, Epoch: 0, Training ...
Local Rank: 0, Epoch: 0, Training ...
---------------------------------------------------------------------------
Epoch: 0, Accuracy: 0.0
---------------------------------------------------------------------------
Local Rank: 6, Epoch: 1, Training ...
Local Rank: 2, Epoch: 1, Training ...
Local Rank: 4, Epoch: 1, Training ...
Local Rank: 0, Epoch: 1, Training ...
Local Rank: 1, Epoch: 1, Training ...
Local Rank: 5, Epoch: 1, Training ...
Local Rank: 3, Epoch: 1, Training ...
Local Rank: 7, Epoch: 1, Training ...
Local Rank: 4, Epoch: 2, Training ...
Local Rank: 2, Epoch: 2, Training ...
Local Rank: 5, Epoch: 2, Training ...
Local Rank: 6, Epoch: 2, Training ...
Local Rank: 0, Epoch: 2, Training ...
Local Rank: 1, Epoch: 2, Training ...
Local Rank: 3, Epoch: 2, Training ...
Local Rank: 7, Epoch: 2, Training ...
Local Rank: 6, Epoch: 3, Training ...
Local Rank: 0, Epoch: 3, Training ...
Local Rank: 1, Epoch: 3, Training ...
Local Rank: 4, Epoch: 3, Training ...
Local Rank: 2, Epoch: 3, Training ...
Local Rank: 5, Epoch: 3, Training ...
Local Rank: 7, Epoch: 3, Training ...
Local Rank: 3, Epoch: 3, Training ...
Local Rank: 6, Epoch: 4, Training ...
Local Rank: 5, Epoch: 4, Training ...
Local Rank: 2, Epoch: 4, Training ...
Local Rank: 0, Epoch: 4, Training ...
Local Rank: 3, Epoch: 4, Training ...
Local Rank: 4, Epoch: 4, Training ...
Local Rank: 7, Epoch: 4, Training ...
Local Rank: 1, Epoch: 4, Training ...
Local Rank: 7, Epoch: 5, Training ...
Local Rank: 6, Epoch: 5, Training ...
Local Rank: 2, Epoch: 5, Training ...
Local Rank: 0, Epoch: 5, Training ...
Local Rank: 4, Epoch: 5, Training ...
Local Rank: 5, Epoch: 5, Training ...
Local Rank: 1, Epoch: 5, Training ...
Local Rank: 3, Epoch: 5, Training ...
Local Rank: 4, Epoch: 6, Training ...
Local Rank: 5, Epoch: 6, Training ...
Local Rank: 6, Epoch: 6, Training ...
Local Rank: 2, Epoch: 6, Training ...
Local Rank: 0, Epoch: 6, Training ...
Local Rank: 1, Epoch: 6, Training ...
Local Rank: 3, Epoch: 6, Training ...
Local Rank: 7, Epoch: 6, Training ...
Local Rank: 2, Epoch: 7, Training ...
Local Rank: 5, Epoch: 7, Training ...
Local Rank: 0, Epoch: 7, Training ...
Local Rank: 4, Epoch: 7, Training ...
Local Rank: 6, Epoch: 7, Training ...
Local Rank: 1, Epoch: 7, Training ...
Local Rank: 3, Epoch: 7, Training ...
Local Rank: 7, Epoch: 7, Training ...
Local Rank: 2, Epoch: 8, Training ...
Local Rank: 0, Epoch: 8, Training ...
Local Rank: 5, Epoch: 8, Training ...
Local Rank: 4, Epoch: 8, Training ...
Local Rank: 1, Epoch: 8, Training ...
Local Rank: 3, Epoch: 8, Training ...
Local Rank: 6, Epoch: 8, Training ...
Local Rank: 7, Epoch: 8, Training ...
Local Rank: 1, Epoch: 9, Training ...
Local Rank: 5, Epoch: 9, Training ...
Local Rank: 2, Epoch: 9, Training ...
Local Rank: 6, Epoch: 9, Training ...
Local Rank: 4, Epoch: 9, Training ...
Local Rank: 0, Epoch: 9, Training ...
Local Rank: 3, Epoch: 9, Training ...
Local Rank: 7, Epoch: 9, Training ...
Local Rank: 6, Epoch: 10, Training ...
Local Rank: 0, Epoch: 10, Training ...
Local Rank: 1, Epoch: 10, Training ...
Local Rank: 5, Epoch: 10, Training ...
Local Rank: 4, Epoch: 10, Training ...
Local Rank: 2, Epoch: 10, Training ...
Local Rank: 3, Epoch: 10, Training ...
Local Rank: 7, Epoch: 10, Training ...
---------------------------------------------------------------------------
Epoch: 10, Accuracy: 0.1747
---------------------------------------------------------------------------

Sometimes, even if the hosts have NCCL installed, distributed training will hang if the NCCL communication has problems. To troubleshoot, run distributed training on a single node first to see whether training works there. For example, if a host has 8 GPUs, we could run two Docker containers on that host, each using 4 GPUs for training.
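If NCCL appears to be the culprit, enabling NCCL’s debug logging usually gives useful hints. A minimal sketch follows, assuming the environment variables are set before torch.distributed.init_process_group is called (the interface name eth0 is only an example and depends on your machine; exporting the variables in the shell before launching works as well):

import os

# NCCL reads these environment variables when the communicator is created,
# so set them before torch.distributed.init_process_group is called.
os.environ["NCCL_DEBUG"] = "INFO"
# If NCCL picks the wrong network interface, pinning it explicitly can help.
os.environ["NCCL_SOCKET_IFNAME"] = "eth0"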

Resume Training From Checkpoint

Sometimes, training gets disrupted for some reason. To resume training, we add --resume as an argument to the program.

In the Docker terminal of the first node, we run the following command.

$ python -m torch.distributed.launch --nproc_per_node=8 --nnodes=2 --node_rank=0 --master_addr="192.168.0.1" --master_port=1234 resnet_ddp.py --resume

In the Docker terminal of the second node, we run the following command.

$ python -m torch.distributed.launch --nproc_per_node=8 --nnodes=2 --node_rank=1 --master_addr="192.168.0.1" --master_port=1234 resnet_ddp.py --resume

Kill Distributed Training

I have discussed how to kill PyTorch distributed training in “Kill PyTorch Distributed Training Processes”, so I am not going to elaborate on it here.

$ kill $(ps aux | grep resnet_ddp.py | grep -v grep | awk '{print $2}')

Miscellaneous

In case the evaluation dataset is also very large, please consider using PyTorch distributed evaluation.
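A minimal sketch of such distributed evaluation is shown below, assuming the process group has been initialized and every process calls the function (distributed_evaluate is an illustrative helper, not part of the script above). Each process evaluates its own shard of the test set, and the per-process counts are summed with all_reduce. Note that DistributedSampler may pad the dataset so that it splits evenly across processes, which can slightly skew the measured accuracy.

import torch
import torch.distributed as dist
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

def distributed_evaluate(ddp_model, device, test_set, batch_size=128):

    sampler = DistributedSampler(dataset=test_set, shuffle=False)
    loader = DataLoader(dataset=test_set, batch_size=batch_size, sampler=sampler)

    ddp_model.eval()
    correct = torch.zeros(1, device=device)
    total = torch.zeros(1, device=device)
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            outputs = ddp_model(images)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum()

    # Sum the counts from all processes so that every process gets the global accuracy.
    dist.all_reduce(correct, op=dist.ReduceOp.SUM)
    dist.all_reduce(total, op=dist.ReduceOp.SUM)

    return (correct / total).item()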

Author: Lei Mao

Posted on: 04-26-2020

Updated on: 02-06-2022
