Data is arguably one of the most important ingredients in deep learning. In many applications today, not only the training data but also the evaluation data keeps growing. In my previous post “PyTorch Distributed Training”, I discussed how to run PyTorch distributed training to accelerate model training. In some cases, however, model evaluation needs to be accelerated by distributed computing as well.
In this blog post, I would like to discuss how to use PyTorch and TorchMetrics to run distributed evaluation. Specifically, I will evaluate the pre-trained ResNet-18 model from TorchVision models on a subset of the ImageNet evaluation dataset.
Evaluation Dataset Preparation
Instead of using the full ImageNet dataset, we will use a smaller subset of it, distributed as imagenet_1k.zip, for evaluation. The dataset is roughly 260 MB and can be downloaded from MIT Han Lab.
$ wget https://hanlab.mit.edu/files/OnceForAll/ofa_cvpr_tutorial/imagenet_1k.zip
$ unzip imagenet_1k.zip
$ ls imagenet_1k
labels.txt  synset_words.txt  train  val
Docker Container
To make all the experiments reproducible, we used the NVIDIA NGC PyTorch Docker image.
$ docker run -it --gpus all --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 --network host -v $(pwd):/mnt -w /mnt nvcr.io/nvidia/pytorch:22.01-py3
In addition, please install TorchMetrics 0.7.1 inside the Docker container.
$ pip install torchmetrics==0.7.1
Single-Node Single-GPU Evaluation
We first implement single-node single-GPU evaluation, evaluate the pre-trained ResNet-18 with it, and use the resulting accuracy as the reference. The implementation is derived from the official PyTorch ImageNet example and should be easy for most PyTorch users to follow.
import os
import time
from enum import Enum

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models

# Most of the code was copied from
# https://github.com/pytorch/examples/blob/00ea159a99f5cb3f3301a9bf0baa1a5089c7e217/imagenet/main.py


class Summary(Enum):
    NONE = 0
    AVERAGE = 1
    SUM = 2
    COUNT = 3

    def display(self, batch):
        entries = [self.prefix + self.batch_fmtstr.format(batch)]
        entries += [str(meter) for meter in self.meters]
        print("\t".join(entries))

    def display_summary(self):
        entries = [" *"]
        entries += [meter.summary() for meter in self.meters]
        print(" ".join(entries))


class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self, name, fmt=":f", summary_type=Summary.AVERAGE):
        self.name = name
        self.fmt = fmt
        self.summary_type = summary_type
        self.reset()

    def summary(self):
        fmtstr = ""
        if self.summary_type is Summary.NONE:
            fmtstr = ""
        elif self.summary_type is Summary.AVERAGE:
            fmtstr = "{name} {avg:.3f}"
        elif self.summary_type is Summary.SUM:
            fmtstr = "{name} {sum:.3f}"
        elif self.summary_type is Summary.COUNT:
            fmtstr = "{name} {count:.3f}"
        else:
            raise ValueError("invalid summary type %r" % self.summary_type)

        return fmtstr.format(**self.__dict__)


def accuracy(output, target, topk=(1, )):
    """Computes the accuracy over the k top predictions for the specified values of k"""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res

    with torch.no_grad():
        end = time.time()
        for i, (images, target) in enumerate(val_loader):
            if cuda_device is not None:
                images = images.to(cuda_device, non_blocking=True)
            if torch.cuda.is_available():
                target = target.to(cuda_device, non_blocking=True)

            # compute output
            output = model(images)
            loss = criterion(output, target)

            # measure accuracy and record loss
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            losses.update(loss.item(), images.size(0))
            top1.update(acc1[0], images.size(0))
            top5.update(acc5[0], images.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            print_freq = 10
            if i % print_freq == 0:
                progress.display(i)

        progress.display_summary()

    return top1.avg


def evaluate_imagenet():

    # Specify the GPU used for evaluation
    cuda_device = torch.device("cuda:0")

    # evaluate on validation set
    acc1 = validate(val_loader, model, criterion, cuda_device)


if __name__ == "__main__":

    evaluate_imagenet()
Although the pre-trained ResNet-18 model was evaluated on only a subset of the ImageNet evaluation dataset, the resulting top-1 accuracy of 69.300% is quite close to the 69.758% measured on the full ImageNet evaluation dataset and reported on the TorchVision models webpage.
$ python single_gpu_evaluation.py
Test: [ 0/250] Time 1.002 ( 1.002) Loss 7.6721e-01 (7.6721e-01) Acc@1 50.00 ( 50.00) Acc@5 100.00 (100.00)
Test: [ 10/250] Time 0.003 ( 0.094) Loss 2.6849e-01 (4.3245e-01) Acc@1 75.00 ( 88.64) Acc@5 100.00 ( 97.73)
Test: [ 20/250] Time 0.003 ( 0.051) Loss 6.7839e-01 (9.7392e-01) Acc@1 75.00 ( 73.81) Acc@5 100.00 ( 95.24)
Test: [ 30/250] Time 0.003 ( 0.036) Loss 9.0414e-01 (9.1842e-01) Acc@1 75.00 ( 76.61) Acc@5 100.00 ( 92.74)
Test: [ 40/250] Time 0.004 ( 0.028) Loss 3.0528e+00 (9.2618e-01) Acc@1 25.00 ( 75.61) Acc@5 100.00 ( 93.29)
Test: [ 50/250] Time 0.004 ( 0.024) Loss 1.0280e+00 (8.8351e-01) Acc@1 50.00 ( 75.00) Acc@5 100.00 ( 94.12)
Test: [ 60/250] Time 0.003 ( 0.020) Loss 1.4846e+00 (8.8949e-01) Acc@1 50.00 ( 75.00) Acc@5 100.00 ( 93.44)
Test: [ 70/250] Time 0.004 ( 0.018) Loss 4.3598e-01 (8.5601e-01) Acc@1 100.00 ( 76.06) Acc@5 100.00 ( 94.01)
Test: [ 80/250] Time 0.003 ( 0.016) Loss 1.4468e+00 (9.1462e-01) Acc@1 75.00 ( 75.62) Acc@5 100.00 ( 93.52)
Test: [ 90/250] Time 0.007 ( 0.016) Loss 1.9861e-02 (8.7991e-01) Acc@1 100.00 ( 76.37) Acc@5 100.00 ( 93.68)
Test: [100/250] Time 0.003 ( 0.015) Loss 2.7188e+00 (9.0537e-01) Acc@1 25.00 ( 75.50) Acc@5 50.00 ( 93.07)
Test: [110/250] Time 0.007 ( 0.014) Loss 1.4175e+00 (9.3323e-01) Acc@1 75.00 ( 75.23) Acc@5 75.00 ( 92.79)
Test: [120/250] Time 0.003 ( 0.013) Loss 2.5602e+00 (9.9448e-01) Acc@1 50.00 ( 73.76) Acc@5 75.00 ( 91.74)
Test: [130/250] Time 0.003 ( 0.012) Loss 1.5190e+00 (1.0699e+00) Acc@1 50.00 ( 72.14) Acc@5 100.00 ( 90.84)
Test: [140/250] Time 0.003 ( 0.012) Loss 2.1484e+00 (1.0803e+00) Acc@1 75.00 ( 71.63) Acc@5 75.00 ( 90.96)
Test: [150/250] Time 0.005 ( 0.011) Loss 4.9566e-01 (1.0947e+00) Acc@1 100.00 ( 71.69) Acc@5 100.00 ( 90.73)
Test: [160/250] Time 0.003 ( 0.011) Loss 7.2313e-01 (1.0947e+00) Acc@1 75.00 ( 72.05) Acc@5 100.00 ( 90.68)
Test: [170/250] Time 0.003 ( 0.010) Loss 2.2315e+00 (1.1162e+00) Acc@1 75.00 ( 71.35) Acc@5 75.00 ( 90.35)
Test: [180/250] Time 0.003 ( 0.010) Loss 2.2024e+00 (1.1441e+00) Acc@1 50.00 ( 70.44) Acc@5 50.00 ( 90.06)
Test: [190/250] Time 0.011 ( 0.010) Loss 8.5283e-01 (1.1437e+00) Acc@1 75.00 ( 70.42) Acc@5 75.00 ( 90.18)
Test: [200/250] Time 0.004 ( 0.009) Loss 5.9278e-02 (1.1665e+00) Acc@1 100.00 ( 69.65) Acc@5 100.00 ( 89.68)
Test: [210/250] Time 0.009 ( 0.009) Loss 1.9074e-01 (1.1649e+00) Acc@1 100.00 ( 69.79) Acc@5 100.00 ( 89.81)
Test: [220/250] Time 0.004 ( 0.009) Loss 1.8873e+00 (1.1957e+00) Acc@1 50.00 ( 69.34) Acc@5 100.00 ( 89.37)
Test: [230/250] Time 0.005 ( 0.009) Loss 1.4360e+00 (1.1980e+00) Acc@1 50.00 ( 69.37) Acc@5 100.00 ( 89.39)
Test: [240/250] Time 0.004 ( 0.009) Loss 3.5062e+00 (1.1984e+00) Acc@1 25.00 ( 69.40) Acc@5 50.00 ( 89.32)
 * Acc@1 69.300 Acc@5 89.600
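To make the top-k bookkeeping in the listing above more concrete, here is a minimal, dependency-free sketch of the same idea on toy data. It is not the code from the PyTorch example: `topk_correct`, `batch_accuracy`, and `RunningAverage` are hypothetical names introduced only for illustration, and plain Python lists stand in for tensors.

```python
def topk_correct(scores, target, k):
    """Return True if `target` is among the k highest-scoring class indices."""
    ranked = sorted(range(len(scores)), key=lambda c: scores[c], reverse=True)
    return target in ranked[:k]


def batch_accuracy(batch_scores, batch_targets, k):
    """Percentage of samples in one batch whose target is in the top k."""
    hits = sum(topk_correct(s, t, k) for s, t in zip(batch_scores, batch_targets))
    return 100.0 * hits / len(batch_targets)


class RunningAverage:
    """Sample-weighted running average, in the spirit of AverageMeter (hypothetical)."""

    def __init__(self):
        self.total = 0.0
        self.count = 0

    def update(self, value, n=1):
        # Weight each batch value by the number of samples it covers.
        self.total += value * n
        self.count += n

    @property
    def avg(self):
        return self.total / self.count


# Toy data: two batches of two samples over three classes.
batches = [
    ([[0.1, 0.7, 0.2], [0.5, 0.3, 0.2]], [1, 2]),
    ([[0.2, 0.2, 0.6], [0.6, 0.3, 0.1]], [2, 1]),
]
top1, top2 = RunningAverage(), RunningAverage()
for scores, targets in batches:
    top1.update(batch_accuracy(scores, targets, k=1), n=len(targets))
    top2.update(batch_accuracy(scores, targets, k=2), n=len(targets))

print(f"Acc@1 {top1.avg:.3f} Acc@2 {top2.avg:.3f}")
```

Weighting each batch by its size matters when the last batch is smaller than the others; otherwise the average of per-batch accuracies would not equal the accuracy over all samples.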
TorchMetrics Single-Node Multi-GPU Evaluation
TorchMetrics provides module-based metrics that can run evaluation on a single GPU, on multiple GPUs, or on multiple nodes. Below is the corresponding TorchMetrics implementation of the ResNet-18 evaluation for the single-node multi-GPU case.
import os

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler
import torchmetrics


# Actually, they are just local rank and local world size
def metric_ddp(rank, world_size):

    with torch.no_grad():
        for i, (images, target) in enumerate(val_loader):
            if cuda_device is not None:
                images = images.to(cuda_device, non_blocking=True)
            if torch.cuda.is_available():
                target = target.to(cuda_device, non_blocking=True)

            # compute output
            output = model(images)
            loss = criterion(output, target)

            acc = metric(output, target)

            print_freq = 10
            if rank == 0 and i % print_freq == 0:  # print only for rank 0
                print(f"Accuracy on batch {i}: {acc}")

    # metric on all batches and all accelerators using custom accumulation
    # accuracy is same across both accelerators
    acc = metric.compute()
    print(f"Accuracy on all data: {acc}, accelerator rank: {rank}")

    # Resetting internal state such that metric ready for new data
    metric.reset()

    # cleanup
    dist.destroy_process_group()


if __name__ == "__main__":

    world_size = 1  # number of gpus to parallelize over
    mp.spawn(metric_ddp, args=(world_size, ), nprocs=world_size, join=True)
Notice that we intentionally set world_size to 1 to force the evaluation to use a single GPU. Run this way, the multi-GPU implementation produces exactly the same evaluation accuracy as the single-GPU reference.
$ python distributed_evaluation.py
Accuracy on batch 0: 0.75
Accuracy on batch 10: 0.75
Accuracy on batch 20: 0.5
Accuracy on batch 30: 0.75
Accuracy on batch 40: 1.0
Accuracy on batch 50: 0.5
Accuracy on batch 60: 0.75
Accuracy on batch 70: 0.75
Accuracy on batch 80: 1.0
Accuracy on batch 90: 0.5
Accuracy on batch 100: 0.5
Accuracy on batch 110: 0.5
Accuracy on batch 120: 0.75
Accuracy on batch 130: 0.75
Accuracy on batch 140: 0.5
Accuracy on batch 150: 0.5
Accuracy on batch 160: 0.75
Accuracy on batch 170: 0.5
Accuracy on batch 180: 0.75
Accuracy on batch 190: 1.0
Accuracy on batch 200: 1.0
Accuracy on batch 210: 1.0
Accuracy on batch 220: 0.75
Accuracy on batch 230: 1.0
Accuracy on batch 240: 0.25
Accuracy on all data: 0.6930000185966492, accelerator rank: 0
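The calls `metric(output, target)`, `metric.compute()`, and `metric.reset()` in the implementation follow an accumulate-then-compute pattern: each call on a batch returns that batch's value while also updating an internal state, and `compute()` reports the value over everything seen so far. The following is a rough pure-Python sketch of that pattern, not the TorchMetrics implementation. `RunningAccuracy` is a hypothetical class, and it takes predicted labels rather than model logits to stay dependency-free.

```python
class RunningAccuracy:
    """Hypothetical stand-in for a module metric; not the TorchMetrics class."""

    def __init__(self):
        self.reset()

    def reset(self):
        # Clear the accumulated state so the metric is ready for new data.
        self.correct = 0
        self.total = 0

    def update(self, preds, targets):
        # Accumulate counts rather than per-batch ratios.
        self.correct += sum(int(p == t) for p, t in zip(preds, targets))
        self.total += len(targets)

    def __call__(self, preds, targets):
        # Update the running state and return the accuracy of this batch only.
        self.update(preds, targets)
        batch_correct = sum(int(p == t) for p, t in zip(preds, targets))
        return batch_correct / len(targets)

    def compute(self):
        # Accuracy over every sample seen since the last reset.
        return self.correct / self.total


metric = RunningAccuracy()
batch_1 = metric([1, 0, 2, 2], [1, 1, 2, 0])  # 2 of 4 correct
batch_2 = metric([0, 0, 1, 1], [0, 0, 1, 2])  # 3 of 4 correct
overall = metric.compute()                    # 5 of 8 correct
print(batch_1, batch_2, overall)
metric.reset()
```

Keeping counts as the state, rather than averaging ratios, is what makes it possible to merge states from several processes later without losing exactness.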
Next, let’s use two GPUs for evaluation by changing world_size from 1 to 2, namely,
if __name__ == "__main__":

    world_size = 2  # number of gpus to parallelize over
    mp.spawn(metric_ddp, args=(world_size, ), nprocs=world_size, join=True)
The multi-GPU evaluation implementation using two GPUs also produces exactly the same evaluation accuracy. Also notice that each process now iterates over fewer batches, because the evaluation data is split across the two GPUs.
$ python distributed_evaluation.py
Accuracy on batch 0: 0.25
Accuracy on batch 10: 0.75
Accuracy on batch 20: 1.0
Accuracy on batch 30: 1.0
Accuracy on batch 40: 0.75
Accuracy on batch 50: 0.75
Accuracy on batch 60: 1.0
Accuracy on batch 70: 0.75
Accuracy on batch 80: 0.5
Accuracy on batch 90: 0.75
Accuracy on batch 100: 1.0
Accuracy on batch 110: 0.75
Accuracy on batch 120: 0.25
Accuracy on all data: 0.6930000185966492, accelerator rank: 0
Accuracy on all data: 0.6930000185966492, accelerator rank: 1
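A small simulation may help explain why both ranks report the same number and why the batch count per process drops. The sketch below is pure Python and only models the idea: every rank evaluates a strided shard of the dataset, and the per-rank counts are summed before a single division. The strided split with wrap-around padding is modeled on how `DistributedSampler` behaves without shuffling. The dataset size of 1000 and the batch size of 4 are assumptions inferred from the logs above (250 batches, per-batch accuracies in steps of 0.25), and all function names are hypothetical.

```python
import math


def shard_indices(num_samples, world_size, rank):
    """Strided split with padding so that every rank gets the same count."""
    per_rank = math.ceil(num_samples / world_size)
    total = per_rank * world_size
    indices = list(range(num_samples))
    indices += indices[: total - num_samples]  # pad by wrapping around
    return indices[rank:total:world_size]


def distributed_accuracy(is_correct, world_size, batch_size):
    """Sum per-rank (correct, total) counts, then divide once."""
    correct = total = 0
    batches_per_rank = []
    for rank in range(world_size):
        shard = shard_indices(len(is_correct), world_size, rank)
        batches_per_rank.append(math.ceil(len(shard) / batch_size))
        correct += sum(is_correct[i] for i in shard)
        total += len(shard)
    return correct / total, batches_per_rank


# 1000 samples of which 693 are "correct", mirroring the 69.3% above.
is_correct = [1] * 693 + [0] * 307
acc_1, batches_1 = distributed_accuracy(is_correct, world_size=1, batch_size=4)
acc_2, batches_2 = distributed_accuracy(is_correct, world_size=2, batch_size=4)
print(acc_1, batches_1)
print(acc_2, batches_2)
```

Under these assumptions, one process sees 250 batches and two processes see 125 batches each, which is consistent with the last printed batch index dropping from 240 to 120. One thing the sketch also makes visible: when the dataset size is not divisible by the number of processes, the padded samples are counted twice, so the distributed result may deviate slightly from the single-process one.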
TorchMetrics Multi-Node Multi-GPU Evaluation
Launching multi-node multi-GPU evaluation requires a launcher such as torch.distributed.launch. I have discussed the usage of torch.distributed.launch for PyTorch distributed training in my previous post “PyTorch Distributed Training”, so I will not elaborate on it here. More information can also be found in the official PyTorch example “Launching and Configuring Distributed Data Parallel Applications”.
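When a launcher starts the processes instead of `mp.spawn`, each process usually has to discover its own position in the job. The sketch below assumes the launcher exports the `RANK`, `LOCAL_RANK`, and `WORLD_SIZE` environment variables, as the PyTorch launch utilities typically do. `read_dist_env` is a hypothetical helper, and the fallback defaults are my own choice so that the script also runs as a plain single process.

```python
import os


def read_dist_env(environ=os.environ):
    """Read the rank layout from launcher-provided environment variables."""
    rank = int(environ.get("RANK", 0))              # global rank across all nodes
    local_rank = int(environ.get("LOCAL_RANK", 0))  # rank within this node
    world_size = int(environ.get("WORLD_SIZE", 1))  # total number of processes
    return rank, local_rank, world_size


# Simulated environment for the third process on the second of two 4-GPU nodes.
fake_env = {"RANK": "6", "LOCAL_RANK": "2", "WORLD_SIZE": "8"}
rank, local_rank, world_size = read_dist_env(fake_env)
print(f"rank {rank}/{world_size}, local rank {local_rank}")
```

In an evaluation script, the local rank would typically select the GPU on the node, while the global rank and world size would be passed to the process group and the sampler.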
Caveats
Let $N$ be the number of nodes on which the application is running and $G$ be the number of GPUs per node. The total number of application processes running across all the nodes at one time is called the world_size, $W$, and the number of processes running on each node is referred to as the local_world_size, $L$.
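These quantities are tied together in a simple way if we assume that every node runs the same number of processes and that ranks are assigned node by node. Both are assumptions of this sketch, not guarantees of any launcher, and the function names are hypothetical.

```python
def total_world_size(num_nodes, local_world_size):
    """W = N * L when every node runs L processes."""
    return num_nodes * local_world_size


def global_rank(node_rank, local_rank, local_world_size):
    """Global rank of a process, assuming ranks are assigned node by node."""
    return node_rank * local_world_size + local_rank


N, L = 2, 4
W = total_world_size(N, L)
ranks = [global_rank(n, l, L) for n in range(N) for l in range(L)]
print(W, ranks)
```

With $N = 2$ and $L = 4$, the eight processes receive the global ranks 0 through 7, and the single-node case is simply $N = 1$, where the global rank and the local rank coincide.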
In the single-node multi-GPU scenario, world_size and nprocs have the same value, and that value should be smaller than or equal to the number of GPUs in the node. The world_size in this context really means the local_world_size of the node. So in the single-node multi-GPU scenario, world_size and nprocs have to be exactly the same by definition.
For example, in the single-node multi-GPU scenario, suppose $N = 1$ and $G = 8$. When $W = L = 8$, each process can use one GPU; when $W = L = 1$, the single process can use up to 8 GPUs.
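The example can be written down as a small helper that splits the GPUs of one node evenly among its local processes. This is a sketch under the assumption that $G$ is divisible by $L$ and that GPUs are handed out in contiguous blocks; `device_ids_for` is a hypothetical name.

```python
def device_ids_for(local_rank, local_world_size, gpus_per_node):
    """GPU indices available to one local process under an even, contiguous split."""
    n = gpus_per_node // local_world_size  # GPUs per process
    return list(range(local_rank * n, (local_rank + 1) * n))


G = 8
print(device_ids_for(local_rank=3, local_world_size=8, gpus_per_node=G))
print(device_ids_for(local_rank=0, local_world_size=1, gpus_per_node=G))
print(device_ids_for(local_rank=1, local_world_size=2, gpus_per_node=G))
```

With $L = 8$ every process gets exactly one GPU, with $L = 1$ the single process gets all eight, and intermediate values such as $L = 2$ give each process a block of four.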
That’s why in our single-node multi-GPU evaluation implementation, we have the following code for spawning jobs, where world_size = nprocs.