Kill PyTorch Distributed Training Processes
Introduction
PyTorch DistributedDataParallel is a convenient wrapper for distributed data parallel training. It is also compatible with distributed model parallel training. The major difference between PyTorch DistributedDataParallel and PyTorch DataParallel is that DistributedDataParallel is multi-process whereas DataParallel is single-process and multi-threaded. Usually, with DistributedDataParallel each process on each node uses one GPU, whereas with DataParallel each thread in the single process uses one GPU.
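For contrast, here is a minimal, self-contained DataParallel sketch (the placeholder model and tensor shapes are mine, not from the original post); a matching DistributedDataParallel script sketch follows the launch commands below.

import torch
import torch.nn as nn

# nn.DataParallel: a single Python process; PyTorch uses one thread per GPU
# and splits each input batch across the visible devices.
model = nn.Linear(10, 10)  # placeholder model
if torch.cuda.is_available():
    model = model.cuda()
dp_model = nn.DataParallel(model)

x = torch.randn(8, 10)
if torch.cuda.is_available():
    x = x.cuda()
y = dp_model(x)  # the batch is scattered across GPUs and the outputs gathered back
print(y.shape)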
To start PyTorch multi-node distributed training, we usually have to run python -m torch.distributed.launch on each node. For example, suppose we want to start a two-node distributed training job whose master node uses address 192.168.1.1 and port 1234.
On node one, we run the following command:
$ python -m torch.distributed.launch --nproc_per_node=NUM_GPUS_YOU_HAVE --nnodes=2 --node_rank=0 --master_addr="192.168.1.1" --master_port=1234 YOUR_TRAINING_SCRIPT.py (--arg1 --arg2 --arg3 and all other arguments of your training script)
On node two, we run the following command:
$ python -m torch.distributed.launch --nproc_per_node=NUM_GPUS_YOU_HAVE --nnodes=2 --node_rank=1 --master_addr="192.168.1.1" --master_port=1234 YOUR_TRAINING_SCRIPT.py (--arg1 --arg2 --arg3 and all other arguments of your training script)
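On the distributed side, YOUR_TRAINING_SCRIPT.py does roughly the following. This is only a sketch with a placeholder model and no training loop, assuming the env:// initialization and the --local_rank argument that torch.distributed.launch provides.

import argparse

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

def main():
    # torch.distributed.launch passes --local_rank to every process it spawns.
    parser = argparse.ArgumentParser()
    parser.add_argument("--local_rank", type=int, default=0)
    args = parser.parse_args()

    # One process per GPU: pin this process to its local GPU.
    torch.cuda.set_device(args.local_rank)
    # MASTER_ADDR, MASTER_PORT, RANK, and WORLD_SIZE are set by the launcher.
    dist.init_process_group(backend="nccl", init_method="env://")

    model = nn.Linear(10, 10).cuda(args.local_rank)  # placeholder model
    ddp_model = DDP(model, device_ids=[args.local_rank])

    # ... training loop using ddp_model goes here ...

    dist.destroy_process_group()

if __name__ == "__main__":
    main()

With this skeleton, each of the two launch commands above starts NUM_GPUS_YOU_HAVE such processes on its node.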
Zombie Processes
Sometimes we kill the multi-process program using Ctrl + C. When we check the GPU usage using nvidia-smi after the kill, we would usually see that the GPU utilization on the node has dropped to zero. However, when we restart the program, sometimes we would see the following error message:
Traceback (most recent call last):
This means that the address and the port are still occupied, so we are not allowed to start the distributed training using the previous address and port. Why does this happen? It is because when we hit Ctrl + C, only one process is killed while the remaining processes on the node are not, so they are still occupying the address and port. We could confirm this by running top.
PID USER PR NI VIRT RES SHR S %CPU %MEM TIME+ COMMAND
In this case, we have 8 GPUs on one node and thus 8 processes after launching the program. After hitting Ctrl + C, one process is killed and 7 processes are still left.
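Besides inspecting top, a quick way to confirm that the leftover processes still hold the port is to try binding to it. This is just a sketch using Python's standard socket module, with the placeholder address and port from above; run it on the master node that owns 192.168.1.1.

import socket

def port_in_use(host: str, port: int) -> bool:
    # If binding fails, some process (here, a leftover training process)
    # is still holding the address and port.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        try:
            s.bind((host, port))
        except OSError:
            return True
    return False

print(port_in_use("192.168.1.1", 1234))  # True while the leftover processes are alive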
In order to release these resources and free the address and port, we could write down the PIDs of these processes and kill each of them by hand. However, this looks dumb. Is there a smarter way to do it? To capture the PIDs automatically, we use ps instead of top.
$ kill $(ps aux | grep YOUR_TRAINING_SCRIPT.py | grep -v grep | awk '{print $2}')
Here $(ps aux | grep YOUR_TRAINING_SCRIPT.py | grep -v grep | awk '{print $2}') returns the PIDs of all the processes that are running YOUR_TRAINING_SCRIPT.py. For example, if the training script is ddp.py, we can list the PIDs with:
$ ps aux | grep ddp.py | grep -v grep | awk '{print $2}'
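If we prefer staying in Python, roughly the same cleanup can be done with the third-party psutil package (not used in the original post). This sketch matches processes by the training script name, with ddp.py as the placeholder.

import psutil

script_name = "ddp.py"  # placeholder, same as in the shell example above

# Mirror the ps | grep | awk | kill pipeline: find every process whose
# command line mentions the training script and terminate it.
for proc in psutil.process_iter(attrs=["pid", "cmdline"]):
    try:
        cmdline = proc.info["cmdline"] or []
        if any(script_name in arg for arg in cmdline):
            print(f"Killing PID {proc.info['pid']}")
            proc.terminate()  # sends SIGTERM, like plain kill
    except psutil.NoSuchProcess:
        pass

If SIGTERM is not enough, proc.kill() could be used instead to send SIGKILL, which corresponds to kill -9 in the shell.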