Nsight Compute In Docker

Introduction

NVIDIA Nsight Compute is an interactive kernel profiler for CUDA applications that provides detailed performance metrics and API debugging via a user interface and a command-line tool. Users can run guided analysis and compare results with a customizable and data-driven user interface, as well as post-process and analyze results in their own workflows.

In this blog post, I would like to discuss how to install and use Nsight Compute in a Docker container so that we can use it, including its GUI, anywhere Docker is installed.

Nsight Compute

Build Docker Image

It is possible to install Nsight Compute inside a Docker image and use it anywhere. The Dockerfile for building the Nsight Compute Docker image is as follows.

nsight-compute.Dockerfile
FROM nvcr.io/nvidia/cuda:12.0.1-devel-ubuntu22.04

ENV DEBIAN_FRONTEND noninteractive

RUN apt-get update -y && \
    apt-get install -y --no-install-recommends \
        apt-transport-https \
        ca-certificates \
        dbus \
        fontconfig \
        gnupg \
        libasound2 \
        libfreetype6 \
        libglib2.0-0 \
        libnss3 \
        libsqlite3-0 \
        libx11-xcb1 \
        libxcb-glx0 \
        libxcb-xkb1 \
        libxcomposite1 \
        libxcursor1 \
        libxdamage1 \
        libxi6 \
        libxml2 \
        libxrandr2 \
        libxrender1 \
        libxtst6 \
        libgl1-mesa-glx \
        libxkbfile-dev \
        openssh-client \
        wget \
        xcb \
        xkb-data && \
    apt-get clean

# QT6 is required for the Nsight Compute UI.
RUN apt-get update -y && \
    apt-get install -y --no-install-recommends \
        qt6-base-dev && \
    apt-get clean
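
Note that the nvcr.io/nvidia/cuda devel base image already ships Nsight Compute, including the command-line tool ncu and the GUI front end ncu-ui, as part of the CUDA toolkit, which is why the Dockerfile only installs the X11 and Qt libraries the GUI depends on rather than Nsight Compute itself. As a quick sanity check from inside a running container, something like the following should confirm both tools are on the PATH.

$ ncu --version
$ which ncu-ui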

To build the Docker image, please run the following command.

$ docker build -f nsight-compute.Dockerfile --no-cache --tag=nsight-compute:12.0.1 .

Upload Docker Image

To upload the Docker image to Docker Hub, please run the following commands.

$ docker tag nsight-compute:12.0.1 leimao/nsight-compute:12.0.1
$ docker push leimao/nsight-compute:12.0.1

Pull Docker Image

To pull the Docker image, please run the following commands.

$ docker pull leimao/nsight-compute:12.0.1
$ docker tag leimao/nsight-compute:12.0.1 nsight-compute:12.0.1

Run Docker Container

To run the Docker container, please run the following commands.

$ xhost +
$ docker run -it --rm --gpus all -e DISPLAY=$DISPLAY -v /tmp/.X11-unix:/tmp/.X11-unix --cap-add=SYS_ADMIN --security-opt seccomp=unconfined -v $(pwd):/mnt --network=host nsight-compute:12.0.1
$ xhost -
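
A few notes on the flags, in case they need to be adapted: --gpus all exposes the host GPUs to the container; -e DISPLAY=$DISPLAY and -v /tmp/.X11-unix:/tmp/.X11-unix forward the host X11 display so that the Nsight Compute GUI can render; --cap-add=SYS_ADMIN and --security-opt seccomp=unconfined are needed because reading GPU performance counters requires elevated privileges (without them ncu typically fails with an ERR_NVGPUCTRPERM error); -v $(pwd):/mnt mounts the current working directory into the container at /mnt; and --network=host simplifies connecting to remote profiling targets. xhost + temporarily allows clients to connect to the local X server, and xhost - revokes that permission afterwards.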

Run Nsight Compute

To run the Nsight Compute GUI, please run the following command.

$ ncu-ui

We can now profile applications running inside the Docker container, applications on the Docker host machine shared via the Docker mount, and applications on a remote host such as a remote workstation or an embedded device.
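
For example, assuming the current working directory on the host was mounted to /mnt and contains an already-built executable named my_app (a placeholder name for illustration), the Nsight Compute command-line profiler can be run against it from inside the container, and the report file it writes is visible on the host through the same mount.

$ cd /mnt
$ ncu --set full -f -o my_app ./my_app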

Examples

Non-Coalesced Memory Access VS Coalesced Memory Access

In this example, we implemented a naive GEMM kernel that performs matrix multiplication on the GPU. The kernel is naive because it does not use any advanced techniques such as shared memory tiling. Our original goal was to read and write global memory in a coalesced manner. However, in the first version of the kernel, gemm_non_coalesced, we introduced a bug that swapped the row and column indices of the output matrix, so the kernel reads and writes global memory in a non-coalesced manner. In the second version of the kernel, gemm_coalesced, we fixed the bug so that the kernel reads and writes global memory in a coalesced manner. We profiled the two kernels using Nsight Compute and compared their performance.

gemm_naive.cu
#include <cmath>
#include <iostream>

#include <cuda_runtime.h>

#define CHECK_CUDA_ERROR(val) check_cuda((val), #val, __FILE__, __LINE__)
void check_cuda(cudaError_t err, const char* const func, const char* const file,
                const int line)
{
    if (err != cudaSuccess)
    {
        std::cerr << "CUDA Runtime Error at: " << file << ":" << line
                  << std::endl;
        std::cerr << cudaGetErrorString(err) << " " << func << std::endl;
        std::exit(EXIT_FAILURE);
    }
}

#define CHECK_LAST_CUDA_ERROR() check_cuda_last(__FILE__, __LINE__)
void check_cuda_last(const char* const file, const int line)
{
    cudaError_t const err{cudaGetLastError()};
    if (err != cudaSuccess)
    {
        std::cerr << "CUDA Runtime Error at: " << file << ":" << line
                  << std::endl;
        std::cerr << cudaGetErrorString(err) << std::endl;
        std::exit(EXIT_FAILURE);
    }
}

// Non-coalesced read and write from global memory.
template <typename T>
__global__ void gemm_non_coalesced(size_t m, size_t n, size_t k, T alpha,
                                   T const* A, size_t lda, T const* B,
                                   size_t ldb, T beta, T* C, size_t ldc)
{
    // Compute the row and column of C that this thread is responsible for.
    size_t const C_row_idx{blockIdx.x * blockDim.x + threadIdx.x};
    size_t const C_col_idx{blockIdx.y * blockDim.y + threadIdx.y};

    // Each thread computes
    // C[C_row_idx, C_col_idx] = alpha * A[C_row_idx, :] * B[:, C_col_idx] +
    // beta * C[C_row_idx, C_col_idx].
    if (C_row_idx < m && C_col_idx < n)
    {
        T sum{static_cast<T>(0)};
        for (size_t k_idx{0U}; k_idx < k; ++k_idx)
        {
            sum += A[C_row_idx * lda + k_idx] * B[k_idx * ldb + C_col_idx];
        }
        C[C_row_idx * ldc + C_col_idx] =
            alpha * sum + beta * C[C_row_idx * ldc + C_col_idx];
    }
}

template <typename T>
void launch_gemm_kernel_non_coalesced(size_t m, size_t n, size_t k,
                                      T const* alpha, T const* A, size_t lda,
                                      T const* B, size_t ldb, T const* beta,
                                      T* C, size_t ldc, cudaStream_t stream)
{
    dim3 const block_dim{32U, 8U, 1U};
    dim3 const grid_dim{
        (static_cast<unsigned int>(m) + block_dim.x - 1U) / block_dim.x,
        (static_cast<unsigned int>(n) + block_dim.y - 1U) / block_dim.y, 1U};
    gemm_non_coalesced<T><<<grid_dim, block_dim, 0U, stream>>>(
        m, n, k, *alpha, A, lda, B, ldb, *beta, C, ldc);
    CHECK_LAST_CUDA_ERROR();
}

// Coalesced read and write from global memory.
template <typename T>
__global__ void gemm_coalesced(size_t m, size_t n, size_t k, T alpha,
                               T const* A, size_t lda, T const* B, size_t ldb,
                               T beta, T* C, size_t ldc)
{
    // Compute the row and column of C that this thread is responsible for.
    size_t const C_col_idx{blockIdx.x * blockDim.x + threadIdx.x};
    size_t const C_row_idx{blockIdx.y * blockDim.y + threadIdx.y};

    // Each thread computes
    // C[C_row_idx, C_col_idx] = alpha * A[C_row_idx, :] * B[:, C_col_idx] +
    // beta * C[C_row_idx, C_col_idx].
    if (C_row_idx < m && C_col_idx < n)
    {
        T sum{static_cast<T>(0)};
        for (size_t k_idx{0U}; k_idx < k; ++k_idx)
        {
            sum += A[C_row_idx * lda + k_idx] * B[k_idx * ldb + C_col_idx];
        }
        C[C_row_idx * ldc + C_col_idx] =
            alpha * sum + beta * C[C_row_idx * ldc + C_col_idx];
    }
}

template <typename T>
void launch_gemm_kernel_coalesced(size_t m, size_t n, size_t k, T const* alpha,
                                  T const* A, size_t lda, T const* B,
                                  size_t ldb, T const* beta, T* C, size_t ldc,
                                  cudaStream_t stream)
{
    dim3 const block_dim{32U, 8U, 1U};
    dim3 const grid_dim{
        (static_cast<unsigned int>(n) + block_dim.x - 1U) / block_dim.x,
        (static_cast<unsigned int>(m) + block_dim.y - 1U) / block_dim.y, 1U};
    gemm_coalesced<T><<<grid_dim, block_dim, 0U, stream>>>(
        m, n, k, *alpha, A, lda, B, ldb, *beta, C, ldc);
    CHECK_LAST_CUDA_ERROR();
}

template <typename T>
void gemm_cpu(size_t m, size_t n, size_t k, T alpha, T const* A, size_t lda,
              T const* B, size_t ldb, T beta, T* C, size_t ldc)
{
    for (size_t i{0U}; i < m; ++i)
    {
        for (size_t j{0U}; j < n; ++j)
        {
            T sum{static_cast<T>(0)};
            for (size_t k_idx{0U}; k_idx < k; ++k_idx)
            {
                sum += A[i * lda + k_idx] * B[k_idx * ldb + j];
            }
            C[i * ldc + j] = alpha * sum + beta * C[i * ldc + j];
        }
    }
}

template <typename T>
void verify_outputs(size_t m, size_t n, size_t ldc, T const* C, T const* C_ref,
                    T abs_error_tol)
{
    for (size_t i{0U}; i < m; ++i)
    {
        for (size_t j{0U}; j < n; ++j)
        {
            T const abs_error{std::abs(C[i * ldc + j] - C_ref[i * ldc + j])};
            if (abs_error > abs_error_tol)
            {
                std::cerr << "Error: i = " << i << ", j = " << j
                          << ", abs_error = " << abs_error << std::endl;
                std::exit(EXIT_FAILURE);
            }
        }
    }
}

int main()
{
    size_t const m{1024U};
    size_t const n{1024U};
    size_t const k{1024U};
    float const alpha{1.0f};
    float const beta{0.0f};
    float const abs_error_tol{1e-5f};

    size_t const lda{k};
    size_t const ldb{n};
    size_t const ldc{n};

    cudaStream_t stream;
    CHECK_CUDA_ERROR(cudaStreamCreate(&stream));

    // Allocate memory on the host.
    float* A_host{nullptr};
    float* B_host{nullptr};
    float* C_host{nullptr};
    float* C_host_from_device{nullptr};
    CHECK_CUDA_ERROR(cudaMallocHost(&A_host, m * lda * sizeof(float)));
    CHECK_CUDA_ERROR(cudaMallocHost(&B_host, k * ldb * sizeof(float)));
    CHECK_CUDA_ERROR(cudaMallocHost(&C_host, m * ldc * sizeof(float)));
    CHECK_CUDA_ERROR(
        cudaMallocHost(&C_host_from_device, m * ldc * sizeof(float)));

    // Initialize A and B.
    for (size_t i{0U}; i < m; ++i)
    {
        for (size_t j{0U}; j < k; ++j)
        {
            A_host[i * lda + j] = static_cast<float>(i + j);
        }
    }
    for (size_t i{0U}; i < k; ++i)
    {
        for (size_t j{0U}; j < n; ++j)
        {
            B_host[i * ldb + j] = static_cast<float>(i + j);
        }
    }

    // Allocate memory on the device.
    float* A_device{nullptr};
    float* B_device{nullptr};
    float* C_device{nullptr};
    CHECK_CUDA_ERROR(cudaMalloc(&A_device, m * lda * sizeof(float)));
    CHECK_CUDA_ERROR(cudaMalloc(&B_device, k * ldb * sizeof(float)));
    CHECK_CUDA_ERROR(cudaMalloc(&C_device, m * ldc * sizeof(float)));

    // Copy A and B to the device.
    CHECK_CUDA_ERROR(cudaMemcpy(A_device, A_host, m * lda * sizeof(float),
                                cudaMemcpyHostToDevice));
    CHECK_CUDA_ERROR(cudaMemcpy(B_device, B_host, k * ldb * sizeof(float),
                                cudaMemcpyHostToDevice));

    // Run the CPU version.
    gemm_cpu(m, n, k, alpha, A_host, lda, B_host, ldb, beta, C_host, ldc);

    // Launch the non-coalesced kernel.
    launch_gemm_kernel_non_coalesced(m, n, k, &alpha, A_device, lda, B_device,
                                     ldb, &beta, C_device, ldc, stream);
    CHECK_CUDA_ERROR(cudaStreamSynchronize(stream));

    // Copy C from the device.
    CHECK_CUDA_ERROR(cudaMemcpy(C_host_from_device, C_device,
                                m * ldc * sizeof(float),
                                cudaMemcpyDeviceToHost));

    // Compare the results.
    verify_outputs(m, n, ldc, C_host_from_device, C_host, abs_error_tol);

    // Launch the coalesced kernel.
    launch_gemm_kernel_coalesced(m, n, k, &alpha, A_device, lda, B_device, ldb,
                                 &beta, C_device, ldc, stream);
    CHECK_CUDA_ERROR(cudaStreamSynchronize(stream));

    // Copy C from the device.
    CHECK_CUDA_ERROR(cudaMemcpy(C_host_from_device, C_device,
                                m * ldc * sizeof(float),
                                cudaMemcpyDeviceToHost));

    // Compare the results.
    verify_outputs(m, n, ldc, C_host_from_device, C_host, abs_error_tol);

    // Free the memory and destroy the stream.
    CHECK_CUDA_ERROR(cudaFree(A_device));
    CHECK_CUDA_ERROR(cudaFree(B_device));
    CHECK_CUDA_ERROR(cudaFree(C_device));
    CHECK_CUDA_ERROR(cudaFreeHost(A_host));
    CHECK_CUDA_ERROR(cudaFreeHost(B_host));
    CHECK_CUDA_ERROR(cudaFreeHost(C_host));
    CHECK_CUDA_ERROR(cudaFreeHost(C_host_from_device));
    CHECK_CUDA_ERROR(cudaStreamDestroy(stream));
}

To build the example using nvcc and profile the example using Nsight Compute, please run the following commands.

$ nvcc gemm_naive.cu -o gemm_naive
$ ncu --set full -f -o gemm_naive gemm_naive
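
If only the GEMM kernels are of interest, the Nsight Compute CLI can restrict profiling to kernels whose names match a filter. The exact flag spelling depends on the ncu version, so the command below is a sketch rather than a guaranteed invocation.

$ ncu --kernel-name regex:gemm --set full -f -o gemm_naive gemm_naive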

To view the profiling results using Nsight Compute GUI, please run the following command.

$ ncu-ui
Nsight Compute Naive GEMM Profile Details

From the Nsight Compute profile details, we can see that the first version of the kernel, gemm_non_coalesced, took 9.85 ms, whereas the second version of the kernel, gemm_coalesced, took 1.20 ms. Neither of the two kernels is well optimized, and Nsight Compute found a lot of issues with both of them. Specifically, the kernel gemm_non_coalesced is very memory-bound, and Nsight Compute tells us that “This kernel has non-coalesced global accesses resulting in a total of 941359104 excessive sectors (85% of the total 1109393408 sectors)”. For example, for the kernel gemm_non_coalesced, the L1/TEX cache statistics show that global loads take 16.51 sectors per request ($\frac{32 \times 1 + 1}{2} = 16.5$) and global stores take 32 sectors per request ($\frac{32 \times 1}{1} = 32$). For the kernel gemm_coalesced, which fixed the non-coalesced global memory access issue, global loads take 2.5 sectors per request ($\frac{32 \times 4 / 32 + 1}{2} = 2.5$) and global stores take 4 sectors per request ($\frac{32 \times 4 / 32}{1} = 4$).
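
As a sanity check, these numbers can be reproduced by hand, assuming 4-byte float elements, 32-byte memory sectors, and warps in which threadIdx.x varies fastest. In gemm_non_coalesced, threadIdx.x selects the row of $C$, so the 32 threads of a warp read $A$ with a stride of lda elements and each touches its own sector (32 sectors per request), while all 32 threads read the same element of $B$ (1 sector per request), giving $\frac{32 + 1}{2} = 16.5$ sectors per load request on average; the stores to $C$ are likewise strided and need 32 sectors per request. In gemm_coalesced, threadIdx.x selects the column of $C$, so a warp reads 32 consecutive elements of $B$ ($32 \times 4 = 128$ bytes, i.e. 4 sectors) and a single broadcast element of $A$ (1 sector), giving $\frac{4 + 1}{2} = 2.5$ sectors per load request, and the stores to $C$ similarly need 4 sectors per request.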

GitHub

All the Dockerfiles and examples are available on GitHub.
