CUDA Coalesced Memory Access

Introduction

In CUDA programming, accesses to GPU global memory from a CUDA kernel are often a major factor in kernel performance. To reduce global memory traffic, we would like to reduce the number of global memory transactions by coalescing global memory accesses and by caching reusable data in the fast shared memory.

In this blog post, I would like to discuss how to coalesce GPU global memory reads and writes, and use an example to show the performance improvement brought by coalescing both the global memory reads and the global memory writes.

CUDA Matrix Transpose

Implementations

In the following example, I implemented three CUDA kernels for (out-of-place) matrix transpose.

  1. The global memory read is coalesced whereas the global memory write is not.
  2. The global memory write is coalesced whereas the global memory read is not.
  3. The global memory read and write are both coalesced. This is implemented by using shared memory.
transpose.cu
#include <algorithm>
#include <cassert>
#include <chrono>
#include <cstdio>
#include <cstdlib>
#include <functional>
#include <iomanip>
#include <iostream>
#include <random>
#include <vector>

#include <cuda_runtime.h>

// Wraps a CUDA runtime call and forwards its status, its source text and the
// call site to check().
#define CHECK_CUDA_ERROR(val) check((val), #val, __FILE__, __LINE__)
// Terminates the program if a CUDA runtime call did not succeed.
//
// err:        status returned by the CUDA runtime call.
// func:       stringified source text of the call (supplied by the macro).
// file, line: location of the call (supplied by the macro).
//
// On failure, prints the location, the runtime's error string and the call
// text to std::cerr, then exits with EXIT_FAILURE. Otherwise returns normally.
void check(cudaError_t err, char const* func, char const* file, int line)
{
    if (err == cudaSuccess)
    {
        return;
    }
    std::cerr << "CUDA Runtime Error at: " << file << ":" << line << std::endl;
    std::cerr << cudaGetErrorString(err) << " " << func << std::endl;
    std::exit(EXIT_FAILURE);
}

// Checks the CUDA runtime's last-error state at the call site.
#define CHECK_LAST_CUDA_ERROR() check_last(__FILE__, __LINE__)
// Terminates the program if the CUDA runtime reports a pending error.
//
// Reads the status with cudaGetLastError(), which also clears it; this is
// what the kernel launchers in this file call right after each launch.
// On failure, prints the location and the runtime's error string to
// std::cerr and exits with EXIT_FAILURE.
void check_last(char const* file, int line)
{
    cudaError_t const last_error{cudaGetLastError()};
    if (last_error == cudaSuccess)
    {
        return;
    }
    std::cerr << "CUDA Runtime Error at: " << file << ":" << line << std::endl;
    std::cerr << cudaGetErrorString(last_error) << std::endl;
    std::exit(EXIT_FAILURE);
}

// Returns the average latency, in milliseconds, of one call to
// bound_function(stream).
//
// bound_function is first run num_warmups times without timing, and the
// stream is drained. Then num_repeats back-to-back calls are bracketed by two
// CUDA events recorded on `stream`. The elapsed time between the events is
// divided by num_repeats.
// NOTE: num_repeats must be non-zero, otherwise the final division is a
// floating-point division by zero.
template <class T>
float measure_performance(std::function<T(cudaStream_t)> bound_function,
                          cudaStream_t stream, size_t num_repeats = 100,
                          size_t num_warmups = 100)
{
    cudaEvent_t start;
    cudaEvent_t stop;
    CHECK_CUDA_ERROR(cudaEventCreate(&start));
    CHECK_CUDA_ERROR(cudaEventCreate(&stop));

    // Untimed warm-up runs, then wait for them to finish so they do not leak
    // into the timed region.
    for (size_t warmup{0}; warmup != num_warmups; ++warmup)
    {
        bound_function(stream);
    }
    CHECK_CUDA_ERROR(cudaStreamSynchronize(stream));

    // Timed region, delimited by events in stream order.
    CHECK_CUDA_ERROR(cudaEventRecord(start, stream));
    for (size_t repeat{0}; repeat != num_repeats; ++repeat)
    {
        bound_function(stream);
    }
    CHECK_CUDA_ERROR(cudaEventRecord(stop, stream));
    CHECK_CUDA_ERROR(cudaEventSynchronize(stop));
    CHECK_LAST_CUDA_ERROR();

    float time;
    CHECK_CUDA_ERROR(cudaEventElapsedTime(&time, start, stop));
    CHECK_CUDA_ERROR(cudaEventDestroy(start));
    CHECK_CUDA_ERROR(cudaEventDestroy(stop));

    return time / num_repeats;
}

// Ceiling division: the smallest q such that q * b >= a, i.e. how many
// chunks of size b are needed to cover a items. Precondition: b != 0.
//
// Computed as a / b plus one for a non-zero remainder, instead of
// (a + b - 1) / b, so that it cannot wrap around when a is close to the
// maximum size_t value.
constexpr size_t div_up(size_t a, size_t b)
{
    return a / b + (a % b != 0 ? 1 : 0);
}

// Transposes the M x N row-major matrix input_matrix into the N x M row-major
// matrix output_matrix, one element per thread.
//
// Expected launch layout (see launch_transpose_read_coalesced): a 2D grid
// whose x dimension covers the N input columns and whose y dimension covers
// the M input rows. Threads adjacent in threadIdx.x read adjacent input
// elements (coalesced reads) but write output elements M apart (strided
// writes).
template <typename T>
__global__ void transpose_read_coalesced(T* output_matrix,
                                         T const* input_matrix, size_t M,
                                         size_t N)
{
    // Widen to size_t before multiplying: blockIdx * blockDim is otherwise
    // evaluated in 32-bit unsigned arithmetic and can wrap for large extents.
    size_t const j{static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x};
    size_t const i{static_cast<size_t>(blockIdx.y) * blockDim.y + threadIdx.y};
    // (i, j) is a position in the M x N input.
    if ((i < M) && (j < N))
    {
        size_t const from_idx{i * N + j};
        size_t const to_idx{j * M + i};
        output_matrix[to_idx] = input_matrix[from_idx];
    }
}

// Transposes the M x N row-major matrix input_matrix into the N x M row-major
// matrix output_matrix, one element per thread.
//
// Expected launch layout (see launch_transpose_write_coalesced): a 2D grid
// whose x dimension covers the M output columns and whose y dimension covers
// the N output rows. Threads adjacent in threadIdx.x write adjacent output
// elements (coalesced writes) but read input elements N apart (strided
// reads).
template <typename T>
__global__ void transpose_write_coalesced(T* output_matrix,
                                          T const* input_matrix, size_t M,
                                          size_t N)
{
    // Widen to size_t before multiplying: blockIdx * blockDim is otherwise
    // evaluated in 32-bit unsigned arithmetic and can wrap for large extents.
    size_t const j{static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x};
    size_t const i{static_cast<size_t>(blockIdx.y) * blockDim.y + threadIdx.y};
    // (i, j) is a position in the N x M output.
    if ((i < N) && (j < M))
    {
        size_t const to_idx{i * M + j};
        size_t const from_idx{j * N + i};
        output_matrix[to_idx] = input_matrix[from_idx];
    }
}

// Enqueues transpose_read_coalesced on `stream` to transpose the M x N
// row-major device matrix input_matrix into the N x M device matrix
// output_matrix.
//
// Uses 32 x 32 thread blocks: grid x spans the N input columns and grid y
// spans the M input rows. Launch errors are checked immediately; the kernel
// itself runs asynchronously with respect to the host.
template <typename T>
void launch_transpose_read_coalesced(T* output_matrix, T const* input_matrix,
                                     size_t M, size_t N, cudaStream_t stream)
{
    constexpr size_t tile_dim{32};
    unsigned int const grid_x{static_cast<unsigned int>(div_up(N, tile_dim))};
    unsigned int const grid_y{static_cast<unsigned int>(div_up(M, tile_dim))};
    dim3 const block_shape{tile_dim, tile_dim};
    dim3 const grid_shape{grid_x, grid_y};
    transpose_read_coalesced<<<grid_shape, block_shape, 0, stream>>>(
        output_matrix, input_matrix, M, N);
    CHECK_LAST_CUDA_ERROR();
}

// Enqueues transpose_write_coalesced on `stream` to transpose the M x N
// row-major device matrix input_matrix into the N x M device matrix
// output_matrix.
//
// Uses 32 x 32 thread blocks: grid x spans the M output columns and grid y
// spans the N output rows. Launch errors are checked immediately; the kernel
// itself runs asynchronously with respect to the host.
template <typename T>
void launch_transpose_write_coalesced(T* output_matrix, T const* input_matrix,
                                      size_t M, size_t N, cudaStream_t stream)
{
    constexpr size_t tile_dim{32};
    unsigned int const grid_x{static_cast<unsigned int>(div_up(M, tile_dim))};
    unsigned int const grid_y{static_cast<unsigned int>(div_up(N, tile_dim))};
    dim3 const block_shape{tile_dim, tile_dim};
    dim3 const grid_shape{grid_x, grid_y};
    transpose_write_coalesced<<<grid_shape, block_shape, 0, stream>>>(
        output_matrix, input_matrix, M, N);
    CHECK_LAST_CUDA_ERROR();
}

// Transposes the M x N row-major matrix input_matrix into the N x M row-major
// matrix output_matrix, staging one BLOCK_SIZE x BLOCK_SIZE tile per block in
// shared memory so that both the global reads and the global writes are
// coalesced.
//
// Launch requirements (see launch_transpose_read_write_coalesced):
//   - blockDim must be BLOCK_SIZE x BLOCK_SIZE, because the shared tile and
//     all tile offsets are sized by BLOCK_SIZE.
//   - Grid x spans div_up(N, BLOCK_SIZE) column tiles and grid y spans
//     div_up(M, BLOCK_SIZE) row tiles of the input.
template <typename T, size_t BLOCK_SIZE = 32>
__global__ void transpose_read_write_coalesced(T* output_matrix,
                                               T const* input_matrix, size_t M,
                                               size_t N)
{
    // BLOCK_SIZE + 1 for avoiding the shared memory bank conflicts.
    // https://leimao.github.io/blog/CUDA-Shared-Memory-Bank/
    // Try setting it to BLOCK_SIZE instead of BLOCK_SIZE + 1 to see the
    // performance drop.
    __shared__ T buffer[BLOCK_SIZE][BLOCK_SIZE + 1];

    // Origin of this block's tile in the input matrix. Offsets are computed in
    // size_t (BLOCK_SIZE is size_t) so they cannot wrap in 32-bit arithmetic,
    // and use BLOCK_SIZE so they always agree with the shared tile's extent.
    size_t const tile_row{blockIdx.y * BLOCK_SIZE};
    size_t const tile_col{blockIdx.x * BLOCK_SIZE};

    size_t const matrix_i{tile_row + threadIdx.y};
    size_t const matrix_j{tile_col + threadIdx.x};

    // We have two ways to write matrix data to the shared memory.
    // 1. Write transposed matrix data from the DRAM to the shared memory and
    //    write the non-transposed matrix data from the shared memory to DRAM.
    // 2. Write non-transposed matrix data from the DRAM to the shared memory
    //    and write the transposed matrix data from the shared memory to DRAM.
    // Both are expected to perform the same. In either case, one of the two
    // shared-memory accesses walks a column of `buffer`; the +1 padding is
    // what keeps that column access free of bank conflicts (for 4-byte T).

    if ((matrix_i < M) && (matrix_j < N))
    {
        // The first approach.
        buffer[threadIdx.x][threadIdx.y] = input_matrix[matrix_i * N + matrix_j];
        // The second approach.
        // buffer[threadIdx.y][threadIdx.x] = input_matrix[matrix_i * N + matrix_j];
    }

    // Make sure the buffer in a block is filled. Reached unconditionally by
    // every thread of the block.
    __syncthreads();

    // In the N x M output, this block writes the tile whose row origin is the
    // input tile's column origin and vice versa. threadIdx.x again runs along
    // the contiguous dimension, so the global writes are coalesced.
    size_t const matrix_transposed_i{tile_col + threadIdx.y};
    size_t const matrix_transposed_j{tile_row + threadIdx.x};

    if ((matrix_transposed_i < N) && (matrix_transposed_j < M))
    {
        size_t const to_idx{matrix_transposed_i * M + matrix_transposed_j};
        // The first approach.
        output_matrix[to_idx] = buffer[threadIdx.y][threadIdx.x];
        // The second approach.
        // output_matrix[to_idx] = buffer[threadIdx.x][threadIdx.y];
    }
}

// Enqueues transpose_read_write_coalesced on `stream` to transpose the M x N
// row-major device matrix input_matrix into the N x M device matrix
// output_matrix.
//
// The tile edge is also passed as the kernel's BLOCK_SIZE template argument,
// so the thread-block shape and the shared-memory tile always agree.
// Grid x spans the N input columns and grid y spans the M input rows.
// Launch errors are checked immediately; the kernel runs asynchronously.
template <typename T>
void launch_transpose_read_write_coalesced(T* output_matrix,
                                           T const* input_matrix, size_t M,
                                           size_t N, cudaStream_t stream)
{
    constexpr size_t tile_dim{32};
    dim3 const grid_shape{static_cast<unsigned int>(div_up(N, tile_dim)),
                          static_cast<unsigned int>(div_up(M, tile_dim))};
    dim3 const block_shape{tile_dim, tile_dim};
    transpose_read_write_coalesced<T, tile_dim>
        <<<grid_shape, block_shape, 0, stream>>>(output_matrix, input_matrix,
                                                 M, N);
    CHECK_LAST_CUDA_ERROR();
}

// Returns true when the first `size` elements of data_1 and data_2 compare
// equal pairwise (using operator!=), false at the first mismatch.
// An empty range (size == 0) is considered equal and is never dereferenced.
template <typename T>
bool is_equal(T const* data_1, T const* data_2, size_t size)
{
    T const* lhs{data_1};
    T const* rhs{data_2};
    T const* const lhs_end{data_1 + size};
    for (; lhs != lhs_end; ++lhs, ++rhs)
    {
        if (*lhs != *rhs)
        {
            return false;
        }
    }
    return true;
}

// Checks one transpose implementation against a CPU reference.
//
// Fills an M x N row-major matrix with values drawn from
// std::uniform_real_distribution<T>(-256, 256), using an std::mt19937 seeded
// with 0 so every call sees the same data. It then:
//   1. copies the matrix to the device,
//   2. runs transpose_function on a freshly created stream,
//   3. copies the N x M result back,
//   4. compares it element by element with the CPU transpose.
// Exact equality is used because a transpose only moves values around.
// Returns true when every element matches.
//
// T must be a floating-point type (required by uniform_real_distribution).
template <typename T>
bool verify_transpose_implementation(
    std::function<void(T*, T const*, size_t, size_t, cudaStream_t)>
        transpose_function,
    size_t M, size_t N)
{
    // Fixed random seed for reproducibility
    std::mt19937 rng{0};
    cudaStream_t stream;
    size_t const matrix_size{M * N};
    std::vector<T> matrix(matrix_size, 0.0f);
    std::vector<T> matrix_transposed(matrix_size, 1.0f);
    std::vector<T> expected(matrix_size, 2.0f);

    std::uniform_real_distribution<T> value_dist(-256, 256);
    for (T& value : matrix)
    {
        value = value_dist(rng);
    }

    // CPU reference: element (r, c) of the N x M result is element (c, r) of
    // the M x N input.
    for (size_t r{0}; r < N; ++r)
    {
        for (size_t c{0}; c < M; ++c)
        {
            expected[r * M + c] = matrix[c * N + r];
        }
    }

    T* d_matrix;
    T* d_matrix_transposed;
    CHECK_CUDA_ERROR(cudaMalloc(&d_matrix, matrix_size * sizeof(T)));
    CHECK_CUDA_ERROR(cudaMalloc(&d_matrix_transposed, matrix_size * sizeof(T)));
    CHECK_CUDA_ERROR(cudaStreamCreate(&stream));
    CHECK_CUDA_ERROR(cudaMemcpy(d_matrix, matrix.data(),
                                matrix_size * sizeof(T),
                                cudaMemcpyHostToDevice));
    transpose_function(d_matrix_transposed, d_matrix, M, N, stream);
    CHECK_CUDA_ERROR(cudaStreamSynchronize(stream));
    CHECK_CUDA_ERROR(cudaMemcpy(matrix_transposed.data(), d_matrix_transposed,
                                matrix_size * sizeof(T),
                                cudaMemcpyDeviceToHost));

    bool const all_match{
        is_equal(matrix_transposed.data(), expected.data(), matrix_size)};

    CHECK_CUDA_ERROR(cudaFree(d_matrix));
    CHECK_CUDA_ERROR(cudaFree(d_matrix_transposed));
    CHECK_CUDA_ERROR(cudaStreamDestroy(stream));
    return all_match;
}

// Benchmarks one transpose implementation on an M x N matrix and prints its
// average latency to std::cout as "Latency: <ms> ms" with three decimals.
//
// Runs 10 untimed warm-up calls and 100 timed calls on a dedicated stream.
// The device buffers are allocated but never filled here; only the timing is
// of interest, not the transposed values.
template <typename T>
void profile_transpose_implementation(
    std::function<void(T*, T const*, size_t, size_t, cudaStream_t)>
        transpose_function,
    size_t M, size_t N)
{
    constexpr int num_repeats{100};
    constexpr int num_warmups{10};
    cudaStream_t stream;
    size_t const matrix_size{M * N};
    T* d_matrix;
    T* d_matrix_transposed;
    CHECK_CUDA_ERROR(cudaMalloc(&d_matrix, matrix_size * sizeof(T)));
    CHECK_CUDA_ERROR(cudaMalloc(&d_matrix_transposed, matrix_size * sizeof(T)));
    CHECK_CUDA_ERROR(cudaStreamCreate(&stream));

    // Fix every argument except the stream, which measure_performance supplies.
    std::function<void(cudaStream_t)> const run_on_stream{
        [&](cudaStream_t s)
        { transpose_function(d_matrix_transposed, d_matrix, M, N, s); }};
    float const latency_ms{
        measure_performance(run_on_stream, stream, num_repeats, num_warmups)};
    std::cout << std::fixed << std::setprecision(3) << "Latency: " << latency_ms
              << " ms" << std::endl;

    CHECK_CUDA_ERROR(cudaFree(d_matrix));
    CHECK_CUDA_ERROR(cudaFree(d_matrix_transposed));
    CHECK_CUDA_ERROR(cudaStreamDestroy(stream));
}

// Entry point. First verifies all three transpose kernels against a CPU
// reference on every shape from 1 x 1 to 64 x 64, then benchmarks each
// kernel on a 12800 x 12800 float matrix.
//
// Returns EXIT_SUCCESS when every verification passes. Otherwise it prints
// the failing kernel and shape to std::cerr and returns EXIT_FAILURE.
int main()
{
    using transpose_launcher =
        void (*)(float*, float const*, size_t, size_t, cudaStream_t);
    struct named_transpose
    {
        char const* name;
        transpose_launcher launch;
    };
    // Shared by the unit tests and the benchmark; order matches the output.
    named_transpose const transposes[]{
        {"Transpose Write Coalesced", &launch_transpose_write_coalesced<float>},
        {"Transpose Read Coalesced", &launch_transpose_read_coalesced<float>},
        {"Transpose Read and Write Coalesced",
         &launch_transpose_read_write_coalesced<float>}};

    // Unit tests. Shapes up to 64 x 64 cover both full and partial 32 x 32
    // tiles. The result is checked explicitly rather than inside assert(),
    // whose argument is not evaluated at all when compiled with -DNDEBUG.
    for (size_t m{1}; m <= 64; ++m)
    {
        for (size_t n{1}; n <= 64; ++n)
        {
            for (named_transpose const& transpose : transposes)
            {
                if (!verify_transpose_implementation<float>(transpose.launch,
                                                            m, n))
                {
                    std::cerr << transpose.name
                              << " failed verification for a " << m << " x "
                              << n << " matrix." << std::endl;
                    return EXIT_FAILURE;
                }
            }
        }
    }

    // M: Number of rows.
    size_t const M{12800};
    // N: Number of columns.
    size_t const N{12800};
    std::cout << M << " x " << N << " Matrix" << std::endl;
    for (named_transpose const& transpose : transposes)
    {
        std::cout << transpose.name << std::endl;
        profile_transpose_implementation<float>(transpose.launch, M, N);
    }
    return EXIT_SUCCESS;
}

Performance

The performance of the three CUDA kernels was measured using a $12800 \times 12800$ matrix. We used a square matrix for the performance measurement so that the comparison between coalesced global memory reads and coalesced global memory writes is as fair as possible.

Compiling with -Xptxas -O0 disables the optimizations performed by ptxas, the assembler that compiles PTX into GPU machine code, for the CUDA kernels. Note that this flag only affects ptxas; it does not turn off every optimization in the NVCC toolchain. We can see that the kernel with coalesced global memory writes is much faster than the kernel with coalesced global memory reads, at least for this use case. With both global memory reads and writes coalesced, the kernel achieves the best performance of the three kernels.

$ nvcc transpose.cu -o transpose -Xptxas -O0
$ ./transpose
12800 x 12800 Matrix
Transpose Write Coalesced
Latency: 5.220 ms
Transpose Read Coalesced
Latency: 7.624 ms
Transpose Read and Write Coalesced
Latency: 4.804 ms

Using -Xptxas -O3, which is the ptxas default, enables the full set of ptxas optimizations for the CUDA kernels. In this case, the performance ranking of the three kernels remains the same.

$ nvcc transpose.cu -o transpose -Xptxas -O3
$ ./transpose
12800 x 12800 Matrix
Transpose Write Coalesced
Latency: 2.924 ms
Transpose Read Coalesced
Latency: 5.337 ms
Transpose Read and Write Coalesced
Latency: 2.345 ms

All the measurements were performed on a platform that has an Intel i9-9900K CPU and an NVIDIA RTX 3090 GPU.

Conclusions

In a CUDA kernel implementation, we should try to coalesce both the global memory reads and the global memory writes whenever possible.

References

Author

Lei Mao

Posted on

03-19-2023

Updated on

03-19-2023

Licensed under


Comments