
Introduction

PyTorch has a relatively simple interface for distributed training. To perform distributed training, the model only needs to be wrapped with DistributedDataParallel, and the training script only needs to be launched using torch.distributed.launch. Although PyTorch has offered a series of tutorials on distributed training, I found them either insufficient or overwhelming for beginners who want to do state-of-the-art PyTorch distributed training. Some key details were missing, and the use of Docker containers in distributed training was not mentioned at all.


In this blog post, I would like to present a simple implementation of PyTorch distributed training on CIFAR-10 classification using DistributedDataParallel-wrapped ResNet models. The use of Docker containers for distributed training and how to start distributed training with torch.distributed.launch will also be covered.

Examples

Source Code

The entire training script consists of about a hundred lines of code. Most of the code should be easy to understand.

import torch
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim

import torchvision
import torchvision.transforms as transforms

import argparse
import os
import random
import numpy as np

def set_random_seeds(random_seed=0):

    torch.manual_seed(random_seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(random_seed)
    random.seed(random_seed)

def evaluate(model, device, test_loader):

    model.eval()

    correct = 0
    total = 0
    with torch.no_grad():
        for data in test_loader:
            images, labels = data[0].to(device), data[1].to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = correct / total

    return accuracy

def main():

    num_epochs_default = 10000
    batch_size_default = 256 # 1024
    learning_rate_default = 0.1
    random_seed_default = 0
    model_dir_default = "saved_models"
    model_filename_default = "resnet_distributed.pth"

    # Each process runs on 1 GPU device specified by the local_rank argument.
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("--local_rank", type=int, help="Local rank. Necessary for using the torch.distributed.launch utility.")
    parser.add_argument("--num_epochs", type=int, help="Number of training epochs.", default=num_epochs_default)
    parser.add_argument("--batch_size", type=int, help="Training batch size for one process.", default=batch_size_default)
    parser.add_argument("--learning_rate", type=float, help="Learning rate.", default=learning_rate_default)
    parser.add_argument("--random_seed", type=int, help="Random seed.", default=random_seed_default)
    parser.add_argument("--model_dir", type=str, help="Directory for saving models.", default=model_dir_default)
    parser.add_argument("--model_filename", type=str, help="Model filename.", default=model_filename_default)
    parser.add_argument("--resume", action="store_true", help="Resume training from saved checkpoint.")
    argv = parser.parse_args()

    local_rank = argv.local_rank
    num_epochs = argv.num_epochs
    batch_size = argv.batch_size
    learning_rate = argv.learning_rate
    random_seed = argv.random_seed
    model_dir = argv.model_dir
    model_filename = argv.model_filename
    resume = argv.resume

    # Create directories outside the PyTorch program
    # Do not create directory here because it is not multiprocess safe
    '''
    if not os.path.exists(model_dir):
        os.makedirs(model_dir)
    '''

    model_filepath = os.path.join(model_dir, model_filename)

    # We need to use seeds to make sure that the models initialized in different processes are the same
    set_random_seeds(random_seed=random_seed)

    # Initializes the distributed backend, which will take care of synchronizing the nodes/GPUs
    torch.distributed.init_process_group(backend="nccl")
    # torch.distributed.init_process_group(backend="gloo")

    # Encapsulate the model on the GPU assigned to the current process
    model = torchvision.models.resnet18(pretrained=False)

    device = torch.device("cuda:{}".format(local_rank))
    model = model.to(device)
    ddp_model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank], output_device=local_rank)

    # Only the process with local_rank 0 saves the model, so the checkpoint is saved from device "cuda:0"
    # To resume, map the checkpoint saved from "cuda:0" to the current process's device
    if resume:
        map_location = {"cuda:0": "cuda:{}".format(local_rank)}
        ddp_model.load_state_dict(torch.load(model_filepath, map_location=map_location))

    # Prepare dataset and dataloader
    transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])

    # The dataset should be downloaded beforehand (see the Preparations section below)
    # download is set to False because downloading is not multiprocess safe
    train_set = torchvision.datasets.CIFAR10(root="data", train=True, download=False, transform=transform) 
    test_set = torchvision.datasets.CIFAR10(root="data", train=False, download=False, transform=transform)

    # Restricts data loading to a subset of the dataset exclusive to the current process
    train_sampler = DistributedSampler(dataset=train_set)

    train_loader = DataLoader(dataset=train_set, batch_size=batch_size, sampler=train_sampler, num_workers=8)
    # Test loader does not have to follow distributed sampling strategy
    test_loader = DataLoader(dataset=test_set, batch_size=128, shuffle=False, num_workers=8)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=learning_rate, momentum=0.9, weight_decay=1e-5)

    # Loop over the dataset multiple times
    for epoch in range(num_epochs):

        print("Local Rank: {}, Epoch: {}, Training ...".format(local_rank, epoch))
        
        # Save and evaluate model routinely
        if epoch % 10 == 0:
            if local_rank == 0:
                accuracy = evaluate(model=ddp_model, device=device, test_loader=test_loader)
                torch.save(ddp_model.state_dict(), model_filepath)
                print("-" * 75)
                print("Epoch: {}, Accuracy: {}".format(epoch, accuracy))
                print("-" * 75)

        ddp_model.train()

        for data in train_loader:
            inputs, labels = data[0].to(device), data[1].to(device)
            optimizer.zero_grad()
            outputs = ddp_model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

if __name__ == "__main__":
    
    main()

Caveats

The caveats are as follows:

  • Use --local_rank for argparse if we are going to use torch.distributed.launch to launch distributed training.
  • Set random seed to make sure that the models initialized in different processes are the same.
  • Use DistributedDataParallel to wrap the model for distributed training.
  • Use DistributedSampler for the training data loader.
  • When saving models, each node saves a copy of the checkpoint file to its local hard drive.
  • Downloading the dataset and creating directories should be avoided in the distributed training program because they are not multiprocess safe, unless we use some sort of barrier, such as torch.distributed.barrier (see the sketch after this list).
  • The inter-node communication bandwidth is extremely important for multi-node distributed training. Instead of randomly picking two computers on the network, try to use nodes from specialized computing clusters, since the communication between such nodes is highly optimized.
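
For example, if we did want to let the training program itself download the dataset and create the model directory, a barrier could be used so that only one process per node touches the file system while the other processes wait. The following is a minimal sketch, assuming the process group has already been initialized with torch.distributed.init_process_group; prepare_node_local_files is a hypothetical helper that is not part of the script above.

import os

import torch
import torchvision

def prepare_node_local_files(local_rank, model_dir="saved_models", data_dir="data"):

    # Only the process with local_rank 0 on each node touches the file system.
    if local_rank == 0:
        os.makedirs(model_dir, exist_ok=True)
        torchvision.datasets.CIFAR10(root=data_dir, train=True, download=True)
        torchvision.datasets.CIFAR10(root=data_dir, train=False, download=True)
    # The remaining processes wait at the barrier until the files are in place.
    torch.distributed.barrier()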

Launching Distributed Training

In this particular experiment, I tested the program using two nodes. Each node has 8 GPUs, and each GPU runs one process. We also used Docker containers to make sure that the environments were exactly the same and reproducible.

Docker Container

To start the Docker container, we have to copy the training script to each node in the distributed system and run the following command in a terminal on each node.

$ docker run -it --gpus all --rm -v $(pwd):/mnt --network=host pytorch:1.5.0

Here, pytorch:1.5.0 is a Docker image with PyTorch 1.5.0 installed (we could also use NVIDIA’s PyTorch NGC image), and --network=host makes sure that the distributed network communication between nodes is not blocked by Docker containerization.
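
For instance, if we used NVIDIA’s PyTorch NGC image instead of a custom pytorch:1.5.0 image, the command would look like the one below. The 20.03-py3 tag is only an illustrative assumption; any NGC tag whose PyTorch version matches the training script would do.

$ docker run -it --gpus all --rm -v $(pwd):/mnt --network=host nvcr.io/nvidia/pytorch:20.03-py3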

Preparations

Download the dataset on each node before starting distributed training.

$ mkdir -p data
$ cd data
$ wget -c --quiet https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
$ tar -xvzf cifar-10-python.tar.gz

Create the directory for saving models on each node before starting distributed training.

$ mkdir -p saved_models

Training From Scratch

In the Docker terminal of the first node, we run the following command.

$ python -m torch.distributed.launch --nproc_per_node=8 --nnodes=2 --node_rank=0 --master_addr="192.168.0.1" --master_port=1234 resnet_ddp.py

Here, 192.168.0.1 is the IP address of the first node.


In the Docker terminal of the second node, we run the following command.

$ python -m torch.distributed.launch --nproc_per_node=8 --nnodes=2 --node_rank=1 --master_addr="192.168.0.1" --master_port=1234 resnet_ddp.py

Note that the only difference between the two commands is --node_rank.


There will always be some delay between executing the two commands on the two nodes, but that is not a problem. The processes on the first node will wait until the processes on the second node join, and then all of them start training together.
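
To see this rendezvous behavior in isolation, we could launch a tiny script with the same torch.distributed.launch commands as above. The following is a minimal sketch (a hypothetical check_ddp.py, not part of this post); each process blocks in init_process_group until all 16 processes across the two nodes have joined.

import argparse

import torch

parser = argparse.ArgumentParser()
# torch.distributed.launch passes --local_rank to every process it spawns.
parser.add_argument("--local_rank", type=int)
argv = parser.parse_args()

# Blocks until every process on both nodes has joined the job.
torch.distributed.init_process_group(backend="nccl")
# With --nproc_per_node=8 and --nnodes=2, the world size should be 16.
print("Global rank {} of {} joined the job.".format(torch.distributed.get_rank(), torch.distributed.get_world_size()))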


The following messages are expected from the two nodes.

*****************************************
Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
*****************************************
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Local Rank: 5, Epoch: 0, Training ...
Local Rank: 4, Epoch: 0, Training ...
Local Rank: 3, Epoch: 0, Training ...
Local Rank: 7, Epoch: 0, Training ...
Local Rank: 6, Epoch: 0, Training ...
Local Rank: 2, Epoch: 0, Training ...
Local Rank: 1, Epoch: 0, Training ...
Local Rank: 0, Epoch: 0, Training ...
---------------------------------------------------------------------------
Epoch: 0, Accuracy: 0.0
---------------------------------------------------------------------------
Local Rank: 6, Epoch: 1, Training ...
Local Rank: 2, Epoch: 1, Training ...
Local Rank: 4, Epoch: 1, Training ...
Local Rank: 0, Epoch: 1, Training ...
Local Rank: 1, Epoch: 1, Training ...
Local Rank: 5, Epoch: 1, Training ...
Local Rank: 3, Epoch: 1, Training ...
Local Rank: 7, Epoch: 1, Training ...
Local Rank: 4, Epoch: 2, Training ...
Local Rank: 2, Epoch: 2, Training ...
Local Rank: 5, Epoch: 2, Training ...
Local Rank: 6, Epoch: 2, Training ...
Local Rank: 0, Epoch: 2, Training ...
Local Rank: 1, Epoch: 2, Training ...
Local Rank: 3, Epoch: 2, Training ...
Local Rank: 7, Epoch: 2, Training ...
Local Rank: 6, Epoch: 3, Training ...
Local Rank: 0, Epoch: 3, Training ...
Local Rank: 1, Epoch: 3, Training ...
Local Rank: 4, Epoch: 3, Training ...
Local Rank: 2, Epoch: 3, Training ...
Local Rank: 5, Epoch: 3, Training ...
Local Rank: 7, Epoch: 3, Training ...
Local Rank: 3, Epoch: 3, Training ...
Local Rank: 6, Epoch: 4, Training ...
Local Rank: 5, Epoch: 4, Training ...
Local Rank: 2, Epoch: 4, Training ...
Local Rank: 0, Epoch: 4, Training ...
Local Rank: 3, Epoch: 4, Training ...
Local Rank: 4, Epoch: 4, Training ...
Local Rank: 7, Epoch: 4, Training ...
Local Rank: 1, Epoch: 4, Training ...
Local Rank: 7, Epoch: 5, Training ...
Local Rank: 6, Epoch: 5, Training ...
Local Rank: 2, Epoch: 5, Training ...
Local Rank: 0, Epoch: 5, Training ...
Local Rank: 4, Epoch: 5, Training ...
Local Rank: 5, Epoch: 5, Training ...
Local Rank: 1, Epoch: 5, Training ...
Local Rank: 3, Epoch: 5, Training ...
Local Rank: 4, Epoch: 6, Training ...
Local Rank: 5, Epoch: 6, Training ...
Local Rank: 6, Epoch: 6, Training ...
Local Rank: 2, Epoch: 6, Training ...
Local Rank: 0, Epoch: 6, Training ...
Local Rank: 1, Epoch: 6, Training ...
Local Rank: 3, Epoch: 6, Training ...
Local Rank: 7, Epoch: 6, Training ...
Local Rank: 2, Epoch: 7, Training ...
Local Rank: 5, Epoch: 7, Training ...
Local Rank: 0, Epoch: 7, Training ...
Local Rank: 4, Epoch: 7, Training ...
Local Rank: 6, Epoch: 7, Training ...
Local Rank: 1, Epoch: 7, Training ...
Local Rank: 3, Epoch: 7, Training ...
Local Rank: 7, Epoch: 7, Training ...
Local Rank: 2, Epoch: 8, Training ...
Local Rank: 0, Epoch: 8, Training ...
Local Rank: 5, Epoch: 8, Training ...
Local Rank: 4, Epoch: 8, Training ...
Local Rank: 1, Epoch: 8, Training ...
Local Rank: 3, Epoch: 8, Training ...
Local Rank: 6, Epoch: 8, Training ...
Local Rank: 7, Epoch: 8, Training ...
Local Rank: 1, Epoch: 9, Training ...
Local Rank: 5, Epoch: 9, Training ...
Local Rank: 2, Epoch: 9, Training ...
Local Rank: 6, Epoch: 9, Training ...
Local Rank: 4, Epoch: 9, Training ...
Local Rank: 0, Epoch: 9, Training ...
Local Rank: 3, Epoch: 9, Training ...
Local Rank: 7, Epoch: 9, Training ...
Local Rank: 6, Epoch: 10, Training ...
Local Rank: 0, Epoch: 10, Training ...
Local Rank: 1, Epoch: 10, Training ...
Local Rank: 5, Epoch: 10, Training ...
Local Rank: 4, Epoch: 10, Training ...
Local Rank: 2, Epoch: 10, Training ...
Local Rank: 3, Epoch: 10, Training ...
Local Rank: 7, Epoch: 10, Training ...
---------------------------------------------------------------------------
Epoch: 10, Accuracy: 0.1747
---------------------------------------------------------------------------

Sometimes, even if the hosts have NCCL installed, distributed training can freeze if the communication via NCCL has problems. To troubleshoot, run distributed training on a single node to see whether training works there. For example, if a host has 8 GPUs, we could run two Docker containers on the host, each using 4 GPUs for training, as shown below.
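
One possible way to set up such a single-host sanity check, assuming the same --network=host Docker containers as above, is to give each container half of the GPUs via CUDA_VISIBLE_DEVICES and point both containers at the loopback address. Setting NCCL_DEBUG=INFO makes NCCL print its communicator setup, which helps to locate communication problems.

$ # In the first container, which uses GPUs 0-3.
$ CUDA_VISIBLE_DEVICES=0,1,2,3 NCCL_DEBUG=INFO python -m torch.distributed.launch --nproc_per_node=4 --nnodes=2 --node_rank=0 --master_addr="127.0.0.1" --master_port=1234 resnet_ddp.py

$ # In the second container, which uses GPUs 4-7.
$ CUDA_VISIBLE_DEVICES=4,5,6,7 NCCL_DEBUG=INFO python -m torch.distributed.launch --nproc_per_node=4 --nnodes=2 --node_rank=1 --master_addr="127.0.0.1" --master_port=1234 resnet_ddp.py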

Resume Training From Checkpoint

Sometimes, training gets disrupted for some reason. To resume training from the saved checkpoint, we add --resume as an argument to the program.


In the Docker terminal of the first node, we run the following command.

$ python -m torch.distributed.launch --nproc_per_node=8 --nnodes=2 --node_rank=0 --master_addr="192.168.0.1" --master_port=1234 resnet_ddp.py --resume

In the Docker terminal of the second node, we run the following command.

$ python -m torch.distributed.launch --nproc_per_node=8 --nnodes=2 --node_rank=1 --master_addr="192.168.0.1" --master_port=1234 resnet_ddp.py --resume

Kill Distributed Training

I have talked about how to kill PyTorch distributed training in “Kill PyTorch Distributed Training Processes”, so I am not going to elaborate on it here.

$ kill $(ps aux | grep resnet_ddp.py | grep -v grep | awk '{print $2}')
