Kill PyTorch Distributed Training Processes
Introduction
PyTorch DistributedDataParallel is a convenient wrapper for distributed data parallel training. It is also compatible with distributed model parallel training. The major difference between PyTorch DistributedDataParallel and PyTorch DataParallel is that DistributedDataParallel uses a multi-process algorithm, whereas DataParallel uses a single-process, multi-thread algorithm. Usually, with DistributedDataParallel, each process on each node drives one GPU, whereas with DataParallel, each thread in the single process drives one GPU.
To start PyTorch multi-node distributed training, we usually have to run python -m torch.distributed.launch on each node. For example, suppose we want to start a two-node distributed training job whose master node uses address 192.168.1.1 and port 1234.
On node one, we run the following command:
$ python -m torch.distributed.launch --nproc_per_node=NUM_GPUS_YOU_HAVE \
    --nnodes=2 --node_rank=0 --master_addr="192.168.1.1" --master_port=1234 \
    YOUR_TRAINING_SCRIPT.py
On node two, we run the following command:
$ python -m torch.distributed.launch --nproc_per_node=NUM_GPUS_YOU_HAVE \
    --nnodes=2 --node_rank=1 --master_addr="192.168.1.1" --master_port=1234 \
    YOUR_TRAINING_SCRIPT.py
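For reference, here is a minimal sketch of my own (not from the original post) of what YOUR_TRAINING_SCRIPT.py might contain; the model, data, and hyperparameters are placeholders. It follows the legacy torch.distributed.launch convention of passing --local_rank to each worker (newer torchrun-based launchers set the LOCAL_RANK environment variable instead). Each launched process drives exactly one GPU, which is why a training job on an 8-GPU node consists of 8 such processes.

import argparse

import torch
import torch.distributed as dist
import torch.nn as nn


def main():
    # torch.distributed.launch starts one copy of this script per GPU and
    # passes each copy its local rank on the node.
    parser = argparse.ArgumentParser()
    parser.add_argument("--local_rank", type=int, default=0)
    args = parser.parse_args()

    # Each process binds to exactly one GPU and joins the process group that the
    # launcher configured through the MASTER_ADDR/MASTER_PORT environment variables.
    torch.cuda.set_device(args.local_rank)
    dist.init_process_group(backend="nccl", init_method="env://")

    model = nn.Linear(128, 10).cuda(args.local_rank)  # placeholder model
    ddp_model = nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank])

    optimizer = torch.optim.SGD(ddp_model.parameters(), lr=0.01)
    for _ in range(10):  # placeholder training loop on random data
        inputs = torch.randn(32, 128).cuda(args.local_rank)
        targets = torch.randint(0, 10, (32,)).cuda(args.local_rank)
        loss = nn.functional.cross_entropy(ddp_model(inputs), targets)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    dist.destroy_process_group()


if __name__ == "__main__":
    main()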
Zombie Processes
Sometimes, we kill the multi-process program using Ctrl + C. When we check the GPU usage with nvidia-smi after the kill, we usually see that the GPU utilization on the node has dropped to zero. However, when we restart the program, we sometimes get the following error message:
Traceback (most recent call last):
  ...
RuntimeError: Address already in use
This means that the address and the port are still occupied, so we are not allowed to start the distributed training with the previous address and port. Why does this happen? When we hit Ctrl + C, only one process is killed; the rest of the processes on the node are not, and they are still holding the address and port. We can confirm this by running top.
  PID USER      PR  NI    VIRT    RES    SHR S  %CPU  %MEM     TIME+ COMMAND
  (the remaining python processes of the training script are listed here)
In this case, we have 8 GPUs on one node and thus 8 processes once the program starts. After hitting Ctrl + C, one process is killed and 7 processes are left.
To release these resources and free the address and port, we could write down the PIDs of these processes and kill each of them with kill. However, this is clumsy. Is there a smarter way to do it? To capture the PIDs automatically, we use ps instead of top.
$ kill $(ps aux | grep YOUR_TRAINING_SCRIPT.py | grep -v grep | awk '{print $2}')
Here, $(ps aux | grep YOUR_TRAINING_SCRIPT.py | grep -v grep | awk '{print $2}') returns the PIDs of all processes running YOUR_TRAINING_SCRIPT.py (the grep -v grep filters out the grep command itself). For example, if the training script is named ddp.py, we can first preview the PIDs that would be killed:
$ ps aux | grep ddp.py | grep -v grep | awk '{print $2}'
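As a supplementary sketch of my own (not from the original post), the same cleanup can also be written in Python using the third-party psutil package instead of a shell pipeline; the script name ddp.py is just the example name used above.

import os

import psutil  # third-party package: pip install psutil

SCRIPT_NAME = "ddp.py"  # example training script name from above

# Iterate over all processes and kill those whose command line mentions the
# training script, e.g. "python ddp.py ...".
for process in psutil.process_iter(["pid", "cmdline"]):
    if process.info["pid"] == os.getpid():
        continue  # skip this cleanup script itself, the analogue of grep -v grep
    cmdline = process.info["cmdline"] or []
    if any(SCRIPT_NAME in part for part in cmdline):
        print(f"Killing PID {process.info['pid']}")
        try:
            process.kill()
        except (psutil.NoSuchProcess, psutil.AccessDenied):
            pass

Matching on the command line mirrors what grep YOUR_TRAINING_SCRIPT.py does, so the same caveat applies: any unrelated process whose command line happens to contain the script name would be killed as well.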