
Lei Mao

Machine Learning, Artificial Intelligence, Computer Science.



PyTorch DistributedDataParallel is a convenient wrapper for distributed data parallel training, and it is also compatible with distributed model parallel training. The major difference between PyTorch DistributedDataParallel and PyTorch DataParallel is that DistributedDataParallel uses a multi-process algorithm, whereas DataParallel uses a single-process, multi-thread algorithm. Usually, with DistributedDataParallel each process on each node uses one GPU, whereas with DataParallel each thread in the single process uses one GPU.
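As a rough illustration of the one-process-per-device model, here is a minimal sketch of a single DistributedDataParallel training step. It assumes the CPU `gloo` backend and a world size of 1 so that it runs on a machine without GPUs; in a real job the launcher starts one such process per GPU, each with its own rank, and `nccl` is the usual backend on GPUs. `find_free_port` and `train_one_step` are hypothetical helpers written only for this example.

```python
import socket

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP


def find_free_port() -> int:
    # Ask the OS for an unused TCP port so the rendezvous does not collide.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("127.0.0.1", 0))
        return s.getsockname()[1]


def train_one_step(rank: int, world_size: int, master_port: int) -> float:
    # Every process joins the same group; rank 0 hosts the TCP store.
    dist.init_process_group(
        backend="gloo",  # CPU backend; "nccl" is the usual choice on GPUs
        init_method=f"tcp://127.0.0.1:{master_port}",
        rank=rank,
        world_size=world_size,
    )
    try:
        torch.manual_seed(0)  # same initial weights in every process
        # On GPUs this would be DDP(model.cuda(local_rank), device_ids=[local_rank]).
        model = DDP(nn.Linear(4, 1))
        optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

        x = torch.randn(8, 4)  # each process would normally load its own data shard
        loss = model(x).pow(2).mean()
        optimizer.zero_grad()
        loss.backward()  # DDP averages gradients across all processes during backward
        optimizer.step()
        return loss.item()
    finally:
        dist.destroy_process_group()


if __name__ == "__main__":
    # Single-process demo (world_size=1); a launcher would start one such
    # process per GPU, each with a different rank.
    print(train_one_step(rank=0, world_size=1, master_port=find_free_port()))
```

With DataParallel, by contrast, a single process simply wraps the model with `nn.DataParallel(model)` and splits each batch across GPUs from threads, so no process group has to be initialized.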

To start PyTorch multi-node distributed training, we usually run a python -m torch.distributed.launch command on each node. For example, suppose we want to start a two-node distributed training job whose master node uses the address passed to --master_addr and port 1234.

On node one, we run the following command:

$ python -m torch.distributed.launch --nproc_per_node=NUM_GPUS_YOU_HAVE \
             --nnodes=2 --node_rank=0 --master_addr="" \
             --master_port=1234 YOUR_TRAINING_SCRIPT.py (--arg1 --arg2 --arg3 \
             and all other arguments of your training script)

On node two, we run the following command:

$ python -m torch.distributed.launch --nproc_per_node=NUM_GPUS_YOU_HAVE \
             --nnodes=2 --node_rank=1 --master_addr="" \
             --master_port=1234 YOUR_TRAINING_SCRIPT.py (--arg1 --arg2 --arg3 \
             and all other arguments of your training script)
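To make the roles of --nnodes, --node_rank and --nproc_per_node concrete, here is a minimal sketch of how we understand the launcher to number the processes: each node starts nproc_per_node processes, a process's global rank is node_rank * nproc_per_node + local_rank, and the world size is nnodes * nproc_per_node. The names `plan_processes` and `ProcessInfo` are hypothetical and only illustrate the arithmetic; they are not part of PyTorch.

```python
from dataclasses import dataclass
from typing import List


@dataclass(frozen=True)
class ProcessInfo:
    node_rank: int    # value of --node_rank on the node running this process
    local_rank: int   # index of the process (and its GPU) on that node
    global_rank: int  # rank of the process across all nodes
    world_size: int   # total number of processes across all nodes


def plan_processes(nnodes: int, nproc_per_node: int) -> List[ProcessInfo]:
    """List every process a launch with these settings would create (hypothetical helper)."""
    world_size = nnodes * nproc_per_node
    plan = []
    for node_rank in range(nnodes):
        for local_rank in range(nproc_per_node):
            global_rank = node_rank * nproc_per_node + local_rank
            plan.append(ProcessInfo(node_rank, local_rank, global_rank, world_size))
    return plan


if __name__ == "__main__":
    # Two nodes with 8 GPUs each, as in the commands above with NUM_GPUS_YOU_HAVE=8.
    for info in plan_processes(nnodes=2, nproc_per_node=8):
        print(info)
```

For two nodes with 8 GPUs each, the process with node_rank=1 and local_rank=0 gets global rank 8, and the world size is 16.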

Zombie Processes

Sometimes we kill the multi-process program with Ctrl + C. If we check the GPU usage with nvidia-smi after the kill, the GPU utilization on the node usually drops to zero. However, when we restart the program, we sometimes see the following error message:

Traceback (most recent call last):
  File "ddp.py", line 103, in <module>
  File "/usr/local/lib/python3.6/dist-packages/torch/distributed/distributed_c10d.py", line 393, in init_process_group
    store, rank, world_size = next(rendezvous_iterator)
  File "/usr/local/lib/python3.6/dist-packages/torch/distributed/rendezvous.py", line 172, in _env_rendezvous_handler
    store = TCPStore(master_addr, master_port, world_size, start_daemon, timeout)
RuntimeError: Address already in use

This means that the address and port are still occupied, so we cannot start distributed training with the same address and port. Why would this happen? When we hit Ctrl + C, only one process is killed, and the remaining processes on the node keep running. They therefore still occupy the address and port. We can confirm this by running top.
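Before relaunching, one way to check whether the master port is still held is simply to try binding to it. The sketch below assumes that the rendezvous store on the master node listens on all interfaces and that a bind failing with EADDRINUSE means a leftover process still owns the port; `port_in_use` is a hypothetical helper, not a PyTorch API.

```python
import errno
import socket


def port_in_use(port: int, host: str = "") -> bool:
    """Return True if host:port cannot be bound because it is already taken.

    host="" binds on all interfaces, which is an assumption about how the
    rendezvous store on the master node listens.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        try:
            s.bind((host, port))
        except OSError as e:
            if e.errno == errno.EADDRINUSE:
                return True
            raise  # other errors (e.g. permission denied) are real problems
    return False  # the test socket is closed on exit, so the port stays free


if __name__ == "__main__":
    # 1234 is the master port used in the launch commands above.
    print("port 1234 in use:", port_in_use(1234))
```

Running this on the master node right after a Ctrl + C gives a quick yes/no answer before we go hunting for the leftover processes.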

  PID USER      PR  NI    VIRT    RES    SHR S  %CPU %MEM     TIME+ COMMAND                          
 1750 root      20   0 39.210g 2.464g 538372 R 100.3  0.5   8:23.22 python                           
 1751 root      20   0 38.835g 2.407g 511792 R 100.3  0.5   8:14.84 python                           
 1753 root      20   0 39.013g 2.515g 519984 R 100.3  0.5   8:14.99 python                           
 1755 root      20   0 39.029g 2.507g 511508 R 100.3  0.5   8:14.73 python                           
 1752 root      20   0 38.837g 2.409g 511724 R 100.0  0.5   8:14.91 python                           
 1754 root      20   0 38.838g 2.409g 510980 R 100.0  0.5   8:14.91 python                           
 1756 root      20   0 38.845g 2.418g 512096 R 100.0  0.5   8:14.71 python                           
 1757 root      20   0 38.851g 2.423g 511780 R 100.0  0.5   8:14.90 python                                                     

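In this case, the node has 8 GPUs, so the launcher started 8 worker processes, one per GPU, in addition to its own python process. After hitting Ctrl + C, one process was killed, but the top output above still lists 8 python processes (PIDs 1750 to 1757) left running on the node.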

To release these resources and free the address and port, we could write down the PIDs of these processes and kill each of them with kill. However, this is tedious. Is there a smarter way? To capture the PIDs automatically, we use ps instead of top.

$ kill $(ps aux | grep YOUR_TRAINING_SCRIPT.py | grep -v grep | awk '{print $2}')

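Here $(ps aux | grep YOUR_TRAINING_SCRIPT.py | grep -v grep | awk '{print $2}') expands to the PIDs of all running processes whose command line contains YOUR_TRAINING_SCRIPT.py: ps aux lists every process, grep keeps the lines that mention the script, grep -v grep drops the grep process itself, and awk prints the second column, which is the PID.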
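The same lookup can also be scripted in Python, for example inside a cleanup tool. This is a minimal sketch, assuming the standard 11-column `ps aux` layout (USER, PID, %CPU, %MEM, VSZ, RSS, TTY, STAT, START, TIME, COMMAND); `pids_matching` and `kill_matching` are hypothetical helpers, not part of any library.

```python
import os
import signal
import subprocess
from typing import List


def pids_matching(ps_output: str, pattern: str) -> List[int]:
    """Return PIDs from `ps aux` output whose COMMAND column contains pattern."""
    pids = []
    for line in ps_output.splitlines()[1:]:  # skip the header row
        # Split into at most 11 fields so the COMMAND column keeps its spaces.
        fields = line.split(None, 10)
        if len(fields) == 11 and pattern in fields[10]:
            pids.append(int(fields[1]))  # the PID is the second column
    return pids


def kill_matching(pattern: str, sig: int = signal.SIGTERM) -> List[int]:
    """Send sig to every process whose command line contains pattern."""
    ps_output = subprocess.run(
        ["ps", "aux"], capture_output=True, text=True, check=True
    ).stdout
    me = os.getpid()  # never signal the process running this script
    pids = [pid for pid in pids_matching(ps_output, pattern) if pid != me]
    for pid in pids:
        try:
            os.kill(pid, sig)
        except ProcessLookupError:
            pass  # the process exited between ps and kill
    return pids
```

Calling `kill_matching("YOUR_TRAINING_SCRIPT.py")` mirrors the shell one-liner; SIGTERM is also what plain kill sends, and `signal.SIGKILL` (the equivalent of kill -9) can be passed for processes that ignore it. Because this version matches only the COMMAND column and spawns no grep, it needs no counterpart to grep -v grep.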

$ ps aux | grep ddp.py | grep -v grep | awk '{print $2}'