PyTorch Benchmark

Introduction

Benchmarking is critical for developing fast PyTorch training and inference applications that run on GPUs with CUDA.

In this blog post, I would like to discuss the correct way to benchmark PyTorch applications.

PyTorch Benchmark

Synchronization

PyTorch automatically performs the necessary synchronization when copying data between the CPU and the GPU or between two GPUs. Other CUDA operations, however, are launched asynchronously: the CPU thread enqueues the work on a CUDA stream and returns immediately, without waiting for it to finish. When no synchronizing operation occurs, the CPU thread and the CUDA stream can therefore be out of sync, and the CPU thread has no way of knowing when a given CUDA operation finishes.

If we measure the elapsed time of a PyTorch application with a CPU timer but without synchronization, the CUDA operations might still be running when the timer stops in the CPU thread, so the benchmark results will be incorrect, typically underestimating the true latency.
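To see the effect without a GPU, here is a CPU-only analogy. It assumes a worker thread stands in for the CUDA stream and a hypothetical `fake_kernel` (just a `time.sleep`) stands in for a kernel. `ThreadPoolExecutor.submit` returns a future immediately, much like a kernel launch returns control to the CPU thread, and `Future.result()` plays the role of `torch.cuda.synchronize()`. This is a sketch of the timing pitfall, not PyTorch code.

```python
import time
from concurrent.futures import ThreadPoolExecutor


def fake_kernel(duration_s: float) -> None:
    """Hypothetical stand-in for a GPU kernel: work that runs off the caller's thread."""
    time.sleep(duration_s)


def time_launch_ms(executor: ThreadPoolExecutor,
                   duration_s: float,
                   synchronize: bool) -> float:
    """Time one 'launch' with a CPU timer, optionally waiting for it to complete."""
    start = time.perf_counter()
    future = executor.submit(fake_kernel, duration_s)  # returns immediately
    if synchronize:
        future.result()  # analogous to torch.cuda.synchronize()
    end = time.perf_counter()
    future.result()  # drain outstanding work so measurements don't overlap
    return (end - start) * 1000


with ThreadPoolExecutor(max_workers=1) as executor:
    unsynced_ms = time_launch_ms(executor, 0.05, synchronize=False)
    synced_ms = time_launch_ms(executor, 0.05, synchronize=True)

print(f"Without synchronization: {unsynced_ms:.2f} ms")
print(f"With synchronization:    {synced_ms:.2f} ms")
```

The unsynchronized number reflects only the cost of submitting the work, which is the same mistake a CPU timer makes when it stops before the CUDA stream has finished.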

Warmup Runs

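The first few iterations of a benchmark are often slower than the rest because the GPU has not warmed up yet; one-time costs such as CUDA context initialization and the first memory allocations also fall into these early runs. As a best practice, we therefore always run a few warmup iterations that are excluded from the measured results.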
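A minimal, framework-agnostic sketch of the warmup pattern follows. The function name `benchmark_ms`, its `sync` hook, and the `work` function are hypothetical. With PyTorch on a GPU one would pass `torch.cuda.synchronize` as `sync`; the default no-op keeps the sketch runnable on any machine.

```python
import time
from typing import Callable


def benchmark_ms(fn: Callable[[], object],
                 num_warmups: int = 10,
                 num_repeats: int = 100,
                 sync: Callable[[], None] = lambda: None) -> float:
    """Return the mean latency of fn in milliseconds, excluding warmup runs."""
    # Warmup iterations absorb one-time costs and are never timed.
    for _ in range(num_warmups):
        fn()
    # Make sure warmup work has finished before the timed region starts.
    sync()

    start = time.perf_counter()
    for _ in range(num_repeats):
        fn()
    # Wait for all outstanding work before reading the timer.
    sync()
    end = time.perf_counter()
    return (end - start) * 1000 / num_repeats


call_count = 0


def work() -> None:
    """Hypothetical workload that also counts how often it is called."""
    global call_count
    call_count += 1
    sum(range(10_000))


mean_ms = benchmark_ms(work, num_warmups=5, num_repeats=20)
print(f"{call_count} calls in total, {mean_ms:.4f} ms per timed call")
```

All 25 calls run, but only the 20 timed calls contribute to `mean_ms`; the 5 warmup calls are excluded.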

CPU Timer, CUDA Timer and PyTorch Benchmark Utilities

Timestamps can be measured on the CPU using Python modules such as time or timeit. They can also be measured on the GPU using CUDA events, for example through torch.cuda.Event, PyTorch's wrapper around CUDA events. In addition, PyTorch provides its own benchmark utilities, torch.utils.benchmark, which take care of warmup runs and synchronization automatically and also support multi-threaded benchmarking.
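The sketch below contrasts the two device-aware options on a stand-in workload, assuming a 256-by-256 matrix multiplication chosen arbitrarily. The helper `cuda_event_time_ms` is hypothetical; `torch.cuda.Event`, `Event.record`, `Event.synchronize`, `Event.elapsed_time`, and `torch.utils.benchmark.Timer` are PyTorch APIs. The CUDA-event branch only runs when a GPU is present; otherwise the workload falls back to the CPU so `torch.utils.benchmark` can still be demonstrated.

```python
import torch
import torch.utils.benchmark as benchmark


def cuda_event_time_ms(fn, num_repeats: int = 100) -> float:
    """Mean latency of fn in milliseconds, measured with CUDA events (GPU only)."""
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    start_event.record()  # enqueued on the current CUDA stream
    for _ in range(num_repeats):
        fn()
    end_event.record()
    # elapsed_time() is only valid after end_event has completed on the GPU.
    end_event.synchronize()
    return start_event.elapsed_time(end_event) / num_repeats


# Stand-in workload; falls back to the CPU so the sketch still runs without a GPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
x = torch.rand(256, 256, device=device)


def matmul() -> torch.Tensor:
    return x @ x


# Timer.timeit performs its own warmup runs and, when CUDA is available,
# synchronizes before reading its timer. Measurement.mean is in seconds.
timer = benchmark.Timer(stmt="matmul()", globals={"matmul": matmul})
measurement = timer.timeit(100)
timer_ms = measurement.mean * 1000
print(f"torch.utils.benchmark: {timer_ms:.4f} ms")

if device.type == "cuda":
    # Manual warmup for the event-based measurement.
    for _ in range(10):
        matmul()
    torch.cuda.synchronize()
    print(f"CUDA events:           {cuda_event_time_ms(matmul):.4f} ms")
```

Note that the CUDA-event helper synchronizes on the end event rather than the whole device; either works for computing the elapsed time, since both wait until the end event has completed.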

Implementation

Let’s benchmark two PyTorch modules, a custom convolution layer and a ResNet50, using a CPU timer, a CUDA timer, and the PyTorch benchmark utilities.

In our custom CPU and CUDA timer implementations, we place the timer either outside the iteration loop (one measurement around all iterations) or inside it (one measurement per iteration). We also test what happens when synchronization is skipped.

# benchmark_pytorch.py
from timeit import default_timer as timer

import torch
import torch.nn as nn
import torch.utils.benchmark as benchmark
import torchvision  # Only needed for the ResNet50 benchmark.


@torch.no_grad()
def measure_time_host(
    model: nn.Module,
    input_tensor: torch.Tensor,
    num_repeats: int = 100,
    num_warmups: int = 10,
    synchronize: bool = True,
    continuous_measure: bool = True,
) -> float:
    """Return the mean latency in milliseconds, measured with a CPU timer."""

    # Warmup runs are not timed.
    for _ in range(num_warmups):
        _ = model(input_tensor)
    torch.cuda.synchronize()

    elapsed_time_ms = 0.0

    if continuous_measure:
        # A single timer around the whole iteration loop.
        start = timer()
        for _ in range(num_repeats):
            _ = model(input_tensor)
        if synchronize:
            # Wait for all queued CUDA work before stopping the timer.
            torch.cuda.synchronize()
        end = timer()
        elapsed_time_ms = (end - start) * 1000

    else:
        # One timer per iteration; the per-iteration times are summed.
        for _ in range(num_repeats):
            start = timer()
            _ = model(input_tensor)
            if synchronize:
                torch.cuda.synchronize()
            end = timer()
            elapsed_time_ms += (end - start) * 1000

    return elapsed_time_ms / num_repeats


@torch.no_grad()
def measure_time_device(
    model: nn.Module,
    input_tensor: torch.Tensor,
    num_repeats: int = 100,
    num_warmups: int = 10,
    synchronize: bool = True,
    continuous_measure: bool = True,
) -> float:
    """Return the mean latency in milliseconds, measured with CUDA events."""

    # Warmup runs are not timed.
    for _ in range(num_warmups):
        _ = model(input_tensor)
    torch.cuda.synchronize()

    elapsed_time_ms = 0.0

    if continuous_measure:
        # A single pair of events around the whole iteration loop.
        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)
        start_event.record()
        for _ in range(num_repeats):
            _ = model(input_tensor)
        end_event.record()
        if synchronize:
            # This has to be synchronized to compute the elapsed time.
            # Otherwise, elapsed_time() raises a runtime error because
            # end_event has not completed yet.
            torch.cuda.synchronize()
        elapsed_time_ms = start_event.elapsed_time(end_event)

    else:
        # One pair of events per iteration; the per-iteration times are summed.
        for _ in range(num_repeats):
            start_event = torch.cuda.Event(enable_timing=True)
            end_event = torch.cuda.Event(enable_timing=True)
            start_event.record()
            _ = model(input_tensor)
            end_event.record()
            if synchronize:
                # Same requirement as above: without it, elapsed_time() fails.
                torch.cuda.synchronize()
            elapsed_time_ms += start_event.elapsed_time(end_event)

    return elapsed_time_ms / num_repeats


@torch.no_grad()
def run_inference(model: nn.Module,
                  input_tensor: torch.Tensor) -> torch.Tensor:

    return model(input_tensor)


def main() -> None:

    num_warmups = 100
    num_repeats = 1000
    input_shape = (1, 3, 224, 224)

    device = torch.device("cuda:0")

    # To benchmark ResNet50 instead of the custom convolution layer,
    # uncomment the following line and comment out the nn.Conv2d model.
    # model = torchvision.models.resnet50()
    model = nn.Conv2d(in_channels=input_shape[1],
                      out_channels=256,
                      kernel_size=(5, 5))

    model.to(device)
    model.eval()

    # Input tensor
    input_tensor = torch.rand(input_shape, device=device)

    torch.cuda.synchronize()

    print("Latency Measurement Using CPU Timer...")
    for continuous_measure in [True, False]:
        for synchronize in [True, False]:
            try:
                latency_ms = measure_time_host(
                    model=model,
                    input_tensor=input_tensor,
                    num_repeats=num_repeats,
                    num_warmups=num_warmups,
                    synchronize=synchronize,
                    continuous_measure=continuous_measure,
                )
                print(f"|"
                      f"Synchronization: {synchronize!s:5}| "
                      f"Continuous Measurement: {continuous_measure!s:5}| "
                      f"Latency: {latency_ms:.5f} ms| ")
            except Exception:
                print(f"|"
                      f"Synchronization: {synchronize!s:5}| "
                      f"Continuous Measurement: {continuous_measure!s:5}| "
                      f"Latency: N/A ms| ")
            torch.cuda.synchronize()

    print("Latency Measurement Using CUDA Timer...")
    for continuous_measure in [True, False]:
        for synchronize in [True, False]:
            try:
                latency_ms = measure_time_device(
                    model=model,
                    input_tensor=input_tensor,
                    num_repeats=num_repeats,
                    num_warmups=num_warmups,
                    synchronize=synchronize,
                    continuous_measure=continuous_measure,
                )
                print(f"|"
                      f"Synchronization: {synchronize!s:5}| "
                      f"Continuous Measurement: {continuous_measure!s:5}| "
                      f"Latency: {latency_ms:.5f} ms| ")
            except Exception:
                # Without synchronization, elapsed_time() raises an error.
                print(f"|"
                      f"Synchronization: {synchronize!s:5}| "
                      f"Continuous Measurement: {continuous_measure!s:5}| "
                      f"Latency: N/A ms| ")
            torch.cuda.synchronize()

    print("Latency Measurement Using PyTorch Benchmark...")
    num_threads = 1
    # Named benchmark_timer to avoid confusion with the timeit timer above.
    benchmark_timer = benchmark.Timer(
        stmt="run_inference(model, input_tensor)",
        setup="from __main__ import run_inference",
        globals={
            "model": model,
            "input_tensor": input_tensor
        },
        num_threads=num_threads,
        label="Latency Measurement",
        sub_label="torch.utils.benchmark.")

    profile_result = benchmark_timer.timeit(num_repeats)
    # https://pytorch.org/docs/stable/_modules/torch/utils/benchmark/utils/common.html#Measurement
    # Measurement.mean is in seconds.
    print(f"Latency: {profile_result.mean * 1000:.5f} ms")


if __name__ == "__main__":

    main()

Docker Container

All benchmarks were run in an NVIDIA NGC PyTorch Docker container on a machine with an Intel Core i9-9900K CPU and an NVIDIA RTX 2080 Ti GPU.

$ docker run -it --rm --ipc=host --gpus all -v $(pwd):/mnt nvcr.io/nvidia/pytorch:21.08-py3

Benchmarks for Custom Convolution

$ python benchmark_pytorch.py
Latency Measurement Using CPU Timer...
|Synchronization: True | Continuous Measurement: True | Latency: 0.39115 ms|
|Synchronization: False| Continuous Measurement: True | Latency: 0.24755 ms|
|Synchronization: True | Continuous Measurement: False| Latency: 0.41779 ms|
|Synchronization: False| Continuous Measurement: False| Latency: 0.24659 ms|
Latency Measurement Using CUDA Timer...
|Synchronization: True | Continuous Measurement: True | Latency: 0.37559 ms|
|Synchronization: False| Continuous Measurement: True | Latency: N/A ms|
|Synchronization: True | Continuous Measurement: False| Latency: 0.39968 ms|
|Synchronization: False| Continuous Measurement: False| Latency: N/A ms|
Latency Measurement Using PyTorch Benchmark...
Latency: 0.37773 ms

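As expected, for the custom convolution layer, the CPU timer without synchronization underestimates the true latency of the module by roughly 37% to 41% (about 0.25 ms instead of 0.39 to 0.42 ms). The synchronized CPU timer, the CUDA timer, and the PyTorch benchmark utilities all report latencies between roughly 0.38 and 0.42 ms. The CUDA timer without synchronization produces no result at all, because elapsed_time() raises an error when the end event has not completed yet.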

Benchmarks for ResNet50

$ python benchmark_pytorch.py
Latency Measurement Using CPU Timer...
|Synchronization: True | Continuous Measurement: True | Latency: 2.08764 ms|
|Synchronization: False| Continuous Measurement: True | Latency: 2.08360 ms|
|Synchronization: True | Continuous Measurement: False| Latency: 2.23062 ms|
|Synchronization: False| Continuous Measurement: False| Latency: 2.09605 ms|
Latency Measurement Using CUDA Timer...
|Synchronization: True | Continuous Measurement: True | Latency: 2.08559 ms|
|Synchronization: False| Continuous Measurement: True | Latency: N/A ms|
|Synchronization: True | Continuous Measurement: False| Latency: 2.21680 ms|
|Synchronization: False| Continuous Measurement: False| Latency: N/A ms|
Latency Measurement Using PyTorch Benchmark...
Latency: 2.10802 ms

This time, the ResNet50 benchmarks using the CPU timer without synchronization are very close to those with synchronization. That does not mean that measuring latency without synchronization is correct. A module such as ResNet50 consists of many small CUDA layers, each of which runs very fast on the GPU. In that case, the CPU thread most likely takes about as long to launch the kernels as the GPU takes to execute them, so the CPU never gets far ahead of the GPU, and the benchmarks with and without synchronization can end up very close.

Conclusions

Any of these approaches is fine for benchmarking PyTorch applications: a CPU timer, a CUDA timer, or the PyTorch benchmark utilities, with the timer placed either outside or inside the iteration loop. What matters is that we synchronize the CPU thread with the CUDA stream before reading the timer, and that we benchmark the same way throughout all experiments.

Author

Lei Mao

Posted on

12-13-2021

Updated on

12-13-2021
