Benchmarking is critical for developing fast PyTorch training and inference applications that use GPUs and CUDA.
In this blog post, I would like to discuss the correct way to benchmark PyTorch applications.
PyTorch Benchmark
Synchronization
PyTorch automatically performs the necessary synchronization when copying data between the CPU and a GPU or between two GPUs. However, because CUDA kernel launches are asynchronous, when there are no such operations the CPU thread and the CUDA stream can be out of sync, and the CPU thread never knows when a particular CUDA operation finishes.
If the user uses a CPU timer to measure the elapsed time of a PyTorch application without synchronization, then when the timer stops in the CPU thread, the CUDA operations might still be running, and the benchmark results will be incorrect.
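As a minimal sketch of the pitfall and its fix, assuming a CUDA device is available (the model and input here are placeholders, not the modules benchmarked later):

```python
from timeit import default_timer as timer

import torch
import torchvision

model = torchvision.models.resnet50().cuda().eval()
x = torch.randn(1, 3, 224, 224, device="cuda")

# Incorrect: the forward pass only enqueues CUDA kernels, so the
# timer stops while the GPU may still be computing.
start = timer()
_ = model(x)
end = timer()  # mostly measures kernel-launch overhead

# Correct: block the CPU thread until all queued CUDA work is done.
torch.cuda.synchronize()
start = timer()
_ = model(x)
torch.cuda.synchronize()
end = timer()
print(f"Latency: {(end - start) * 1000:.3f} ms")
```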
Warmup Runs
In a benchmark, the first few runs can be slow if the GPU has not been warmed up, because of one-time costs such as CUDA context creation, memory-pool allocation, and cuDNN algorithm selection. As a best practice, we always run a couple of warmup iterations that are not counted in the profiling results, as sketched below.
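A minimal sketch of such warmup iterations, reusing the model and x placeholders from the previous snippet (the iteration count is arbitrary):

```python
num_warmups = 10

# Untimed warmup iterations absorb one-time costs such as CUDA
# context creation, memory-pool allocation, and cuDNN autotuning.
with torch.no_grad():
    for _ in range(num_warmups):
        _ = model(x)
torch.cuda.synchronize()  # ensure warmup work has finished

# ... timed iterations start only after this point ...
```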
CPU Timer, CUDA Timer and PyTorch Benchmark Utilities
Timestamps can be measured on the CPU using Python utilities such as time or timeit. They can also be measured on the GPU using CUDA events, for example via PyTorch's torch.cuda.Event wrapper. In addition, PyTorch has its own benchmark utilities (torch.utils.benchmark) that help the user run benchmarks; they take care of warmup runs and synchronization automatically and also include support for multi-threaded benchmarking.
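For instance, a minimal torch.utils.benchmark sketch might look like this (the model and input are again placeholders); Timer performs warmup and CUDA synchronization internally:

```python
import torch
import torch.utils.benchmark as benchmark
import torchvision

model = torchvision.models.resnet50().cuda().eval()
x = torch.randn(1, 3, 224, 224, device="cuda")

# Timer handles warmup runs and CUDA synchronization for us.
t = benchmark.Timer(
    stmt="model(x)",
    globals={"model": model, "x": x},
    num_threads=1,
)
print(t.timeit(100))  # prints a Measurement with the mean latency
```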
Implementation
Let’s benchmark a couple of PyTorch modules, including a custom convolution layer and a ResNet50, using a CPU timer, a CUDA timer, and the PyTorch benchmark utilities.
In our custom CPU and CUDA benchmark implementations, we will try placing the timer both outside and inside the iteration loop, and we will also test the consequences of skipping synchronization.
```python
# benchmark_pytorch.py
from timeit import default_timer as timer

import torch
import torch.nn as nn
import torchvision
import torch.utils.benchmark as benchmark
```
```python
def measure_time_device(model, input_tensor, num_repeats, num_warmups,
                        synchronize=True, continuous_measure=True):
    # NOTE: this wrapper signature is an assumption; the surrounding
    # definitions were lost, and only the measurement logic is original.
    # Warmup runs are not counted in the measurement.
    for _ in range(num_warmups):
        _ = model.forward(input_tensor)
    torch.cuda.synchronize()

    elapsed_time_ms = 0

    if continuous_measure:
        # Timer outside the iteration loop: one event pair around all repeats.
        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)
        start_event.record()
        for _ in range(num_repeats):
            _ = model.forward(input_tensor)
        end_event.record()
        if synchronize:
            # This has to be synchronized to compute the elapsed time.
            # Otherwise, there will be a runtime error.
            torch.cuda.synchronize()
        elapsed_time_ms = start_event.elapsed_time(end_event)
    else:
        # Timer inside the iteration loop: one event pair per repeat.
        for _ in range(num_repeats):
            start_event = torch.cuda.Event(enable_timing=True)
            end_event = torch.cuda.Event(enable_timing=True)
            start_event.record()
            _ = model.forward(input_tensor)
            end_event.record()
            if synchronize:
                # This has to be synchronized to compute the elapsed time.
                # Otherwise, there will be a runtime error.
                torch.cuda.synchronize()
            elapsed_time_ms += start_event.elapsed_time(end_event)

    return elapsed_time_ms / num_repeats
```
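A possible driver for the function above; the measure_time_device name and the iteration counts are assumptions made in the reconstruction:

```python
if __name__ == "__main__":
    model = torchvision.models.resnet50().cuda().eval()
    input_tensor = torch.randn(1, 3, 224, 224, device="cuda")

    # Timer outside the loop (continuous) with synchronization enabled.
    latency_ms = measure_time_device(
        model=model,
        input_tensor=input_tensor,
        num_repeats=100,
        num_warmups=10,
        synchronize=True,
        continuous_measure=True,
    )
    print(f"Average latency: {latency_ms:.3f} ms")
```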
This time, the ResNet50 benchmarks using the CPU timer without synchronization are very close to those with synchronization. But that does not mean the way we measured the latency was correct. For PyTorch modules that consist of many small CUDA layers, each of which runs very fast on the GPU, the CPU-side kernel-launch overhead can dominate, so the CPU thread barely gets ahead of the GPU and the benchmarks with and without synchronization can end up very close.
Conclusions
Benchmarking PyTorch applications using a CPU timer, a CUDA timer, or PyTorch Benchmark, with the timer placed outside or inside the iteration loop, are all fine, as long as we do not forget to synchronize the CPU thread with the CUDA stream and we keep the benchmarking methodology consistent across all experiments.