CUDA Matrix Multiplication

Introduction

CUDA is a parallel computing platform and programming model that allows software to use certain types of graphics processing units (GPUs) for general-purpose processing, an approach called general-purpose computing on GPUs (GPGPU). It can significantly enhance the performance of programs that can exploit massive parallelism.

Matrix multiplication is a typical application that can be computed with massive parallelism. In this blog post, I would like to present a “hello-world” CUDA example of matrix multiplication and some preliminary optimizations.

Matrix Multiplication

There are two common forms of matrix multiplication: the ordinary matrix multiplication mm and the batched matrix multiplication bmm.

$$
\begin{align}
\mathbf{C}^{m \times p} &= \mathbf{A}^{m \times n} \mathbf{B}^{n \times p} \\
\mathbf{C}^{b \times m \times p} &= \mathbf{A}^{b \times m \times n} \mathbf{B}^{b \times n \times p} \\
\end{align}
$$

The reader could find the specifications of mm and bmm in the PyTorch documentation for torch.mm and torch.bmm.
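Element-wise, each entry of $\mathbf{C}$ is the dot product of a row of $\mathbf{A}$ and a column of $\mathbf{B}$, which is exactly what the innermost loop of the implementations below computes.

$$
\begin{align}
\mathbf{C}_{i,j} &= \sum_{k=1}^{n} \mathbf{A}_{i,k} \mathbf{B}_{k,j} \\
\end{align}
$$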

In the following example, we first implemented mm and bmm using C++ on the CPU. Then we implemented mm using CUDA and naturally extended the mm implementation to bmm. Finally, we verified the correctness of the mm and bmm CUDA implementations against the CPU reference.

Naive Implementation

This is the single source code file that contains the CPU and CUDA implementations for the matrix multiplication mm and the batched matrix multiplication bmm.

mm.cu
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <iomanip>
#include <iostream>
#include <random>
#include <stdexcept>
#include <vector>

#define BLOCK_DIM 32

#define checkCuda(val) check((val), #val, __FILE__, __LINE__)
void check(cudaError_t err, const char* const func, const char* const file,
const int line)
{
if (err != cudaSuccess)
{
std::cerr << "CUDA Runtime Error at: " << file << ":" << line
<< std::endl;
std::cerr << cudaGetErrorString(err) << " " << func << std::endl;
std::exit(EXIT_FAILURE);
}
}

template <typename T>
std::vector<T> create_rand_vector(size_t n)
{
std::random_device r;
std::default_random_engine e(r());
std::uniform_int_distribution<int> uniform_dist(-256, 256);

std::vector<T> vec(n);
for (size_t i{0}; i < n; ++i)
{
vec.at(i) = static_cast<T>(uniform_dist(e));
}

return vec;
}

// mat_1: m x n
// mat_2: n x p
// mat_3: m x p
template <typename T>
void mm(T const* mat_1, T const* mat_2, T* mat_3, size_t m, size_t n, size_t p)
{
// Compute the cells in mat_3 sequentially.
for (size_t i{0}; i < m; ++i)
{
for (size_t j{0}; j < p; ++j)
{
T acc_sum{0};
for (size_t k{0}; k < n; ++k)
{
acc_sum += mat_1[i * n + k] * mat_2[k * p + j];
}
mat_3[i * p + j] = acc_sum;
}
}
}

// mat_1: b x m x n
// mat_2: b x n x p
// mat_3: b x m x p
template <typename T>
void bmm(T const* mat_1, T const* mat_2, T* mat_3, size_t b, size_t m, size_t n,
size_t p)
{
// Iterate through the batch dimension.
for (size_t i{0}; i < b; ++i)
{
mm(mat_1 + i * (m * n), mat_2 + i * (n * p), mat_3 + i * (m * p), m, n,
p);
}
}

template <typename T>
__global__ void mm_kernel(T const* mat_1, T const* mat_2, T* mat_3, size_t m,
size_t n, size_t p)
{
// 2D block and 2D thread
// Each thread computes one cell in mat_3.
size_t i{blockIdx.y * blockDim.y + threadIdx.y};
size_t j{blockIdx.x * blockDim.x + threadIdx.x};

// Do not process outside the matrix.
// Do not forget the equal sign!
if ((i >= m) || (j >= p))
{
return;
}

T acc_sum{0};
for (size_t k{0}; k < n; ++k)
{
acc_sum += mat_1[i * n + k] * mat_2[k * p + j];
}
mat_3[i * p + j] = acc_sum;
}

// It should be straightforward to extend a kernel to support batching.
template <typename T>
__global__ void bmm_kernel(T const* mat_1, T const* mat_2, T* mat_3, size_t b,
size_t m, size_t n, size_t p)
{
// 2D block and 2D thread
// Each thread computes one cell in mat_3.
size_t i{blockIdx.y * blockDim.y + threadIdx.y};
size_t j{blockIdx.x * blockDim.x + threadIdx.x};
size_t l{blockIdx.z};

// Do not process outside the matrix.
// Do not forget the equal sign!
if ((i >= m) || (j >= p))
{
return;
}

T acc_sum{0};
for (size_t k{0}; k < n; ++k)
{
acc_sum += mat_1[l * m * n + i * n + k] * mat_2[l * n * p + k * p + j];
}
mat_3[l * m * p + i * p + j] = acc_sum;
}

template <typename T>
void mm_cuda(T const* mat_1, T const* mat_2, T* mat_3, size_t m, size_t n,
size_t p)
{
dim3 threads_per_block(BLOCK_DIM, BLOCK_DIM);
dim3 blocks_per_grid(1, 1);
blocks_per_grid.x = std::ceil(static_cast<double>(p) /
static_cast<double>(threads_per_block.x));
blocks_per_grid.y = std::ceil(static_cast<double>(m) /
static_cast<double>(threads_per_block.y));
mm_kernel<<<blocks_per_grid, threads_per_block>>>(mat_1, mat_2, mat_3, m, n,
p);
}

template <typename T>
void bmm_cuda(T const* mat_1, T const* mat_2, T* mat_3, size_t b, size_t m,
size_t n, size_t p)
{
dim3 threads_per_block(BLOCK_DIM, BLOCK_DIM);
dim3 blocks_per_grid(1, 1, 1);
blocks_per_grid.x = std::ceil(static_cast<double>(p) /
static_cast<double>(threads_per_block.x));
blocks_per_grid.y = std::ceil(static_cast<double>(m) /
static_cast<double>(threads_per_block.y));
blocks_per_grid.z = b;
bmm_kernel<<<blocks_per_grid, threads_per_block>>>(mat_1, mat_2, mat_3, b,
m, n, p);
}

template <typename T>
bool allclose(std::vector<T> const& vec_1, std::vector<T> const& vec_2,
T const& abs_tol)
{
if (vec_1.size() != vec_2.size())
{
return false;
}
for (size_t i{0}; i < vec_1.size(); ++i)
{
if (std::abs(vec_1.at(i) - vec_2.at(i)) > abs_tol)
{
std::cout << vec_1.at(i) << " " << vec_2.at(i) << std::endl;
return false;
}
}
return true;
}

template <typename T>
bool random_test_mm_cuda(size_t m, size_t n, size_t p)
{
std::vector<T> const mat_1_vec{create_rand_vector<T>(m * n)};
std::vector<T> const mat_2_vec{create_rand_vector<T>(n * p)};
std::vector<T> mat_3_vec(m * p);
std::vector<T> mat_4_vec(m * p);
T const* mat_1{mat_1_vec.data()};
T const* mat_2{mat_2_vec.data()};
T* mat_3{mat_3_vec.data()};
T* mat_4{mat_4_vec.data()};

mm(mat_1, mat_2, mat_3, m, n, p);

T *d_mat_1, *d_mat_2, *d_mat_4;

// Allocate device buffer.
checkCuda(cudaMalloc(&d_mat_1, sizeof(T) * mat_1_vec.size()));
checkCuda(cudaMalloc(&d_mat_2, sizeof(T) * mat_2_vec.size()));
checkCuda(cudaMalloc(&d_mat_4, sizeof(T) * mat_4_vec.size()));

// Copy data from host to device.
checkCuda(cudaMemcpy(d_mat_1, mat_1, sizeof(T) * mat_1_vec.size(),
cudaMemcpyHostToDevice));
checkCuda(cudaMemcpy(d_mat_2, mat_2, sizeof(T) * mat_2_vec.size(),
cudaMemcpyHostToDevice));

// Run matrix multiplication on GPU.
mm_cuda(d_mat_1, d_mat_2, d_mat_4, m, n, p);
cudaDeviceSynchronize();
cudaError_t err{cudaGetLastError()};
if (err != cudaSuccess)
{
std::cerr << "CUDA Matrix Multiplication kernel failed to execute."
<< std::endl;
std::cerr << cudaGetErrorString(err) << std::endl;
std::exit(EXIT_FAILURE);
}

// Copy data from device to host.
checkCuda(cudaMemcpy(mat_4, d_mat_4, sizeof(T) * mat_4_vec.size(),
cudaMemcpyDeviceToHost));

// Free device buffer.
checkCuda(cudaFree(d_mat_1));
checkCuda(cudaFree(d_mat_2));
checkCuda(cudaFree(d_mat_4));

return allclose<T>(mat_3_vec, mat_4_vec, 1e-4);
}

template <typename T>
bool random_test_bmm_cuda(size_t b, size_t m, size_t n, size_t p)
{
std::vector<T> const mat_1_vec{create_rand_vector<T>(b * m * n)};
std::vector<T> const mat_2_vec{create_rand_vector<T>(b * n * p)};
std::vector<T> mat_3_vec(b * m * p);
std::vector<T> mat_4_vec(b * m * p);
T const* mat_1{mat_1_vec.data()};
T const* mat_2{mat_2_vec.data()};
T* mat_3{mat_3_vec.data()};
T* mat_4{mat_4_vec.data()};

bmm(mat_1, mat_2, mat_3, b, m, n, p);

T *d_mat_1, *d_mat_2, *d_mat_4;

// Allocate device buffer.
checkCuda(cudaMalloc(&d_mat_1, sizeof(T) * mat_1_vec.size()));
checkCuda(cudaMalloc(&d_mat_2, sizeof(T) * mat_2_vec.size()));
checkCuda(cudaMalloc(&d_mat_4, sizeof(T) * mat_4_vec.size()));

// Copy data from host to device.
checkCuda(cudaMemcpy(d_mat_1, mat_1, sizeof(T) * mat_1_vec.size(),
cudaMemcpyHostToDevice));
checkCuda(cudaMemcpy(d_mat_2, mat_2, sizeof(T) * mat_2_vec.size(),
cudaMemcpyHostToDevice));

// Run matrix multiplication on GPU.
bmm_cuda(d_mat_1, d_mat_2, d_mat_4, b, m, n, p);
cudaDeviceSynchronize();
cudaError_t err{cudaGetLastError()};
if (err != cudaSuccess)
{
std::cerr << "CUDA Matrix Multiplication kernel failed to execute."
<< std::endl;
std::cerr << cudaGetErrorString(err) << std::endl;
std::exit(EXIT_FAILURE);
}

// Copy data from device to host.
checkCuda(cudaMemcpy(mat_4, d_mat_4, sizeof(T) * mat_4_vec.size(),
cudaMemcpyDeviceToHost));

// Free device buffer.
checkCuda(cudaFree(d_mat_1));
checkCuda(cudaFree(d_mat_2));
checkCuda(cudaFree(d_mat_4));

return allclose<T>(mat_3_vec, mat_4_vec, 1e-4);
}

template <typename T>
bool random_multiple_test_mm_cuda(size_t num_tests)
{
std::random_device r;
std::default_random_engine e(r());
std::uniform_int_distribution<int> uniform_dist(1, 256);

size_t m{0}, n{0}, p{0};
bool success{false};

for (size_t i{0}; i < num_tests; ++i)
{
m = static_cast<size_t>(uniform_dist(e));
n = static_cast<size_t>(uniform_dist(e));
p = static_cast<size_t>(uniform_dist(e));
success = random_test_mm_cuda<T>(m, n, p);
if (!success)
{
return false;
}
}

return true;
}

template <typename T>
bool random_multiple_test_bmm_cuda(size_t num_tests)
{
std::random_device r;
std::default_random_engine e(r());
std::uniform_int_distribution<int> uniform_dist(1, 256);

size_t b{0}, m{0}, n{0}, p{0};
bool success{false};

for (size_t i{0}; i < num_tests; ++i)
{
b = static_cast<size_t>(uniform_dist(e));
m = static_cast<size_t>(uniform_dist(e));
n = static_cast<size_t>(uniform_dist(e));
p = static_cast<size_t>(uniform_dist(e));
success = random_test_bmm_cuda<T>(b, m, n, p);
if (!success)
{
return false;
}
}

return true;
}

template <typename T>
float measure_latency_mm_cuda(size_t m, size_t n, size_t p, size_t num_tests,
size_t num_warmups)
{
cudaEvent_t startEvent, stopEvent;
float time{0.0f};

checkCuda(cudaEventCreate(&startEvent));
checkCuda(cudaEventCreate(&stopEvent));

T *d_mat_1, *d_mat_2, *d_mat_4;

// Allocate device buffer.
checkCuda(cudaMalloc(&d_mat_1, sizeof(T) * m * n));
checkCuda(cudaMalloc(&d_mat_2, sizeof(T) * n * p));
checkCuda(cudaMalloc(&d_mat_4, sizeof(T) * m * p));

for (size_t i{0}; i < num_warmups; ++i)
{
mm_cuda(d_mat_1, d_mat_2, d_mat_4, m, n, p);
}

checkCuda(cudaEventRecord(startEvent, 0));
for (size_t i{0}; i < num_tests; ++i)
{
mm_cuda(d_mat_1, d_mat_2, d_mat_4, m, n, p);
}
checkCuda(cudaEventRecord(stopEvent, 0));
checkCuda(cudaEventSynchronize(stopEvent));
cudaError_t err{cudaGetLastError()};
if (err != cudaSuccess)
{
std::cerr << "CUDA Matrix Multiplication kernel failed to execute."
<< std::endl;
std::cerr << cudaGetErrorString(err) << std::endl;
std::exit(EXIT_FAILURE);
}
checkCuda(cudaEventElapsedTime(&time, startEvent, stopEvent));

// Free device buffer.
checkCuda(cudaFree(d_mat_1));
checkCuda(cudaFree(d_mat_2));
checkCuda(cudaFree(d_mat_4));

float latency{time / num_tests};

return latency;
}

template <typename T>
float measure_latency_bmm_cuda(size_t b, size_t m, size_t n, size_t p,
size_t num_tests, size_t num_warmups)
{
cudaEvent_t startEvent, stopEvent;
float time{0.0f};

checkCuda(cudaEventCreate(&startEvent));
checkCuda(cudaEventCreate(&stopEvent));

T *d_mat_1, *d_mat_2, *d_mat_4;

// Allocate device buffer.
checkCuda(cudaMalloc(&d_mat_1, sizeof(T) * b * m * n));
checkCuda(cudaMalloc(&d_mat_2, sizeof(T) * b * n * p));
checkCuda(cudaMalloc(&d_mat_4, sizeof(T) * b * m * p));

for (size_t i{0}; i < num_warmups; ++i)
{
bmm_cuda(d_mat_1, d_mat_2, d_mat_4, b, m, n, p);
}

checkCuda(cudaEventRecord(startEvent, 0));
for (size_t i{0}; i < num_tests; ++i)
{
bmm_cuda(d_mat_1, d_mat_2, d_mat_4, b, m, n, p);
}
checkCuda(cudaEventRecord(stopEvent, 0));
checkCuda(cudaEventSynchronize(stopEvent));
cudaError_t err{cudaGetLastError()};
if (err != cudaSuccess)
{
std::cerr << "CUDA Matrix Multiplication kernel failed to execute."
<< std::endl;
std::cerr << cudaGetErrorString(err) << std::endl;
std::exit(EXIT_FAILURE);
}
checkCuda(cudaEventElapsedTime(&time, startEvent, stopEvent));

// Free device buffer.
checkCuda(cudaFree(d_mat_1));
checkCuda(cudaFree(d_mat_2));
checkCuda(cudaFree(d_mat_4));

float latency{time / num_tests};

return latency;
}

int main()
{
constexpr size_t num_tests{10};

assert(random_multiple_test_mm_cuda<int32_t>(num_tests));
assert(random_multiple_test_mm_cuda<float>(num_tests));
assert(random_multiple_test_mm_cuda<double>(num_tests));
assert(random_multiple_test_bmm_cuda<int32_t>(num_tests));
assert(random_multiple_test_bmm_cuda<float>(num_tests));
assert(random_multiple_test_bmm_cuda<double>(num_tests));

constexpr size_t num_measurement_tests{100};
constexpr size_t num_measurement_warmups{10};
size_t b{128}, m{1024}, n{1024}, p{1024};

float mm_cuda_int32_latency{measure_latency_mm_cuda<int32_t>(
m, n, p, num_measurement_tests, num_measurement_warmups)};
float mm_cuda_float_latency{measure_latency_mm_cuda<float>(
m, n, p, num_measurement_tests, num_measurement_warmups)};
float mm_cuda_double_latency{measure_latency_mm_cuda<double>(
m, n, p, num_measurement_tests, num_measurement_warmups)};

float bmm_cuda_int32_latency{measure_latency_bmm_cuda<int32_t>(
b, m, n, p, num_measurement_tests, num_measurement_warmups)};
float bmm_cuda_float_latency{measure_latency_bmm_cuda<float>(
b, m, n, p, num_measurement_tests, num_measurement_warmups)};
float bmm_cuda_double_latency{measure_latency_bmm_cuda<double>(
b, m, n, p, num_measurement_tests, num_measurement_warmups)};

std::cout << "Matrix Multiplication CUDA Latency" << std::endl;
std::cout << "m: " << m << " "
<< "n: " << n << " "
<< "p: " << p << std::endl;
std::cout << "INT32: " << std::fixed << std::setprecision(5)
<< mm_cuda_int32_latency << " ms" << std::endl;
std::cout << "FLOAT: " << std::fixed << std::setprecision(5)
<< mm_cuda_float_latency << " ms" << std::endl;
std::cout << "DOUBLE: " << std::fixed << std::setprecision(5)
<< mm_cuda_double_latency << " ms" << std::endl;

std::cout << "Batched Matrix Multiplication CUDA Latency" << std::endl;
std::cout << "b: " << b << " "
<< "m: " << m << " "
<< "n: " << n << " "
<< "p: " << p << std::endl;
std::cout << "INT32: " << std::fixed << std::setprecision(5)
<< bmm_cuda_int32_latency << " ms" << std::endl;
std::cout << "FLOAT: " << std::fixed << std::setprecision(5)
<< bmm_cuda_float_latency << " ms" << std::endl;
std::cout << "DOUBLE: " << std::fixed << std::setprecision(5)
<< bmm_cuda_double_latency << " ms" << std::endl;
}

Run Naive Example

Building and running the example requires an NVIDIA GPU. We also used the NVIDIA official CUDA Docker container to set up the build environment.

To start the Docker container, please run the following command on the host computer.

$ docker run -it --rm --gpus all --cap-add=SYS_PTRACE --security-opt seccomp=unconfined -v $(pwd):/mnt nvcr.io/nvidia/cuda:11.7.1-devel-ubuntu22.04

To build and run the application, please run the following commands in the Docker container.

$ cd /mnt/
$ nvcc mm.cu -o mm -std=c++14
$ ./mm
Matrix Multiplication CUDA Latency
m: 1024 n: 1024 p: 1024
INT32: 1.11436 ms
FLOAT: 0.98451 ms
DOUBLE: 4.10433 ms
Batched Matrix Multiplication CUDA Latency
b: 128 m: 1024 n: 1024 p: 1024
INT32: 125.26781 ms
FLOAT: 124.67697 ms
DOUBLE: 487.87039 ms

We should expect no assertion errors or any other kind of error during build and execution. The latencies were measured on an NVIDIA RTX 3090 GPU.

Matrix Multiplication Optimizations

CUDA kernel optimization is usually all about how to accelerate the data traffic without affecting the number of math operations. To get a CUDA kernel fully optimized for the GPU, the user would have to be very experienced with low-level GPU features, specifications, and CUDA programming. But this does not prevent us from doing some preliminary optimizations based on a shallow understanding of the GPU.

Make Matrix Multiplication More Math-Bound

GPUs are very friendly to math-bound operations. According to my previous blog post “Math-Bound VS Memory-Bound Operations”, if the number of operations remains the same and the number of memory IO bytes is reduced, the operation becomes more math-bound. That is to say, we want

$$
\begin{gather}
\frac{N_{\text{op}}}{N_{\text{byte}}} > \frac{\text{BW}_{\text{math}}}{\text{BW}_{\text{mem}}}
\end{gather}
$$

In our naive CUDA matrix multiplication implementation,

$$
\begin{align}
\mathbf{C}^{m \times p} &= \mathbf{A}^{m \times n} \mathbf{B}^{n \times p} \\
\end{align}
$$

We have to do $mnp$ multiplications and $mnp$ additions, i.e., $2mnp$ operations, $2mnp$ reads from memory, and $mp$ writes to memory. We could ignore the $mp$ writes in the memory IO count because the $2mnp$ reads usually dominate the $mp$ writes.

Suppose we are doing FP32 matrix multiplication, where each element takes 4 bytes,

$$
\begin{align}
\frac{N_{\text{op}}}{N_{\text{byte}}}
&= \frac{2 \times mnp}{2mnp \times 4} \\
&= \frac{1}{4} \\
\end{align}
$$

For a modern GPU such as the NVIDIA RTX 3090, whose theoretical FP32 math throughput is 35.58 TFLOPS and whose memory bandwidth is 0.936 TB/s,

$$
\begin{align}
\frac{\text{BW}_{\text{math}}}{\text{BW}_{\text{mem}}} &= \frac{35.58}{0.936} \\
&= 38.0 \\
\end{align}
$$

We could see that the naive CUDA matrix multiplication implementation, at $0.25$ operations per byte, does not even get close to being math-bound. Since $N_{\text{op}}$ is a constant in matrix multiplication, let’s see if we could reduce $N_{\text{byte}}$ by caching.

Ideally, if we could cache the two full operand matrices $\mathbf{A}^{m \times n}$ and $\mathbf{B}^{n \times p}$, we could make the matrix multiplication most math-bound. However, since the cache size is limited and the implementation is supposed to support matrix multiplications of arbitrary sizes, caching the full matrices is not feasible. For example, with $m = n = p = 1024$ in FP32, the two operand matrices alone take $8$ MB, far more than the fast on-chip memory a thread block can use.

Matrix Multiplication Decomposition

It is possible to decompose matrix multiplication mm into smaller matrix multiplications.

$$
\mathbf{A} =
\begin{bmatrix}
\mathbf{A}_{1,1}^{d \times d} & \mathbf{A}_{1,2}^{d \times d} & \cdots & \mathbf{A}_{1,n/d}^{d \times d} \\
\mathbf{A}_{2,1}^{d \times d} & \mathbf{A}_{2,2}^{d \times d} & \cdots & \mathbf{A}_{2,n/d}^{d \times d} \\
\vdots & \vdots & \ddots & \vdots \\
\mathbf{A}_{m/d,1}^{d \times d} & \mathbf{A}_{m/d,2}^{d \times d} & \cdots & \mathbf{A}_{m/d,n/d}^{d \times d} \\
\end{bmatrix}
$$

$$
\mathbf{B} =
\begin{bmatrix}
\mathbf{B}_{1,1}^{d \times d} & \mathbf{B}_{1,2}^{d \times d} & \cdots & \mathbf{B}_{1,p/d}^{d \times d} \\
\mathbf{B}_{2,1}^{d \times d} & \mathbf{B}_{2,2}^{d \times d} & \cdots & \mathbf{B}_{2,p/d}^{d \times d} \\
\vdots & \vdots & \ddots & \vdots \\
\mathbf{B}_{n/d,1}^{d \times d} & \mathbf{B}_{n/d,2}^{d \times d} & \cdots & \mathbf{B}_{n/d,p/d}^{d \times d} \\
\end{bmatrix}
$$

$$
\mathbf{C} =
\begin{bmatrix}
\mathbf{C}_{1,1}^{d \times d} & \mathbf{C}_{1,2}^{d \times d} & \cdots & \mathbf{C}_{1,p/d}^{d \times d} \\
\mathbf{C}_{2,1}^{d \times d} & \mathbf{C}_{2,2}^{d \times d} & \cdots & \mathbf{C}_{2,p/d}^{d \times d} \\
\vdots & \vdots & \ddots & \vdots \\
\mathbf{C}_{m/d,1}^{d \times d} & \mathbf{C}_{m/d,2}^{d \times d} & \cdots & \mathbf{C}_{m/d,p/d}^{d \times d} \\
\end{bmatrix}
$$

$$
\mathbf{C}_{i,j}^{d \times d} = \sum_{k=1}^{n/d} \mathbf{A}_{i,k}^{d \times d} \mathbf{B}_{k,j}^{d \times d}
$$

The decomposition does not alter the number of operations $N_{\text{op}}$.

$$
\begin{align}
N_{\text{op}} &= 2d^3 \left( \frac{n}{d} \right) \left( \frac{m}{d} \frac{p}{d}\right) \\
&= 2mnp \\
\end{align}
$$

Because the small matrices $\mathbf{A}_{i,k}^{d \times d}$ and $\mathbf{B}_{k,j}^{d \times d}$ could be cached, the memory IO bytes could be reduced, and the overall matrix multiplication could become more math-bound. Let’s calculate how many memory IO bytes are needed in this case, again assuming FP32.

$$
\begin{align}
N_{\text{byte}} &= 2d^2 \times 4 \times \left( \frac{n}{d} \right) \left( \frac{m}{d} \frac{p}{d}\right) \\
&= \frac{8mnp}{d} \\
\end{align}
$$

Therefore,

$$
\begin{align}
\frac{N_{\text{op}}}{N_{\text{byte}}}
&= \frac{2mnp}{\frac{8mnp}{d}} \\
&= \frac{d}{4} \\
\end{align}
$$

Notice that when $d=1$, the matrix multiplication falls back to the naive matrix multiplication. When $d$ becomes larger, the implementation becomes more math-bound.
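For example, with the tile size $d = 32$ used as BLOCK_DIM in the implementation below,

$$
\begin{align}
\frac{N_{\text{op}}}{N_{\text{byte}}}
&= \frac{32}{4} \\
&= 8 \\
\end{align}
$$

which is still below the $38.0$ threshold of the RTX 3090 for FP32, but a $32\times$ improvement over the $\frac{1}{4}$ of the naive implementation.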

Optimized Implementation

The following implementation decomposes the matrix multiplication into multiple small matrix multiplications using shared memory tiles. The source code could be found on GitHub.

mm_optimization.cu
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <iomanip>
#include <iostream>
#include <random>
#include <stdexcept>
#include <vector>

#define BLOCK_DIM 32

#define checkCuda(val) check((val), #val, __FILE__, __LINE__)
void check(cudaError_t err, const char* const func, const char* const file,
const int line)
{
if (err != cudaSuccess)
{
std::cerr << "CUDA Runtime Error at: " << file << ":" << line
<< std::endl;
std::cerr << cudaGetErrorString(err) << " " << func << std::endl;
std::exit(EXIT_FAILURE);
}
}

template <typename T>
std::vector<T> create_rand_vector(size_t n)
{
std::random_device r;
std::default_random_engine e(r());
std::uniform_int_distribution<int> uniform_dist(-256, 256);

std::vector<T> vec(n);
for (size_t i{0}; i < n; ++i)
{
vec.at(i) = static_cast<T>(uniform_dist(e));
}

return vec;
}

// mat_1: m x n
// mat_2: n x p
// mat_3: m x p
template <typename T>
void mm(T const* mat_1, T const* mat_2, T* mat_3, size_t m, size_t n, size_t p)
{
// Compute the cells in mat_3 sequentially.
for (size_t i{0}; i < m; ++i)
{
for (size_t j{0}; j < p; ++j)
{
T acc_sum{0};
for (size_t k{0}; k < n; ++k)
{
acc_sum += mat_1[i * n + k] * mat_2[k * p + j];
}
mat_3[i * p + j] = acc_sum;
}
}
}

template <typename T>
__global__ void mm_kernel(T const* mat_1, T const* mat_2, T* mat_3, size_t m,
size_t n, size_t p)
{
// 2D block and 2D thread
// Each thread computes one cell in mat_3.
size_t i{blockIdx.y * blockDim.y + threadIdx.y};
size_t j{blockIdx.x * blockDim.x + threadIdx.x};

// Do not process outside the matrix.
// Do not forget the equal sign!
if ((i >= m) || (j >= p))
{
return;
}

T acc_sum{0};
for (size_t k{0}; k < n; ++k)
{
acc_sum += mat_1[i * n + k] * mat_2[k * p + j];
}
mat_3[i * p + j] = acc_sum;
}

template <typename T>
__global__ void mm_kernel_optimized(T const* mat_1, T const* mat_2, T* mat_3,
size_t m, size_t n, size_t p)
{
__shared__ T mat_1_tile[BLOCK_DIM][BLOCK_DIM];
__shared__ T mat_2_tile[BLOCK_DIM][BLOCK_DIM];

T acc_sum{0};

for (size_t tile_idx{0};
tile_idx < ceilf(static_cast<float>(n) / BLOCK_DIM); ++tile_idx)
{
size_t i{blockIdx.y * blockDim.y + threadIdx.y};
size_t j{tile_idx * blockDim.x + threadIdx.x};
if ((i < m) && (j < n))
{
mat_1_tile[threadIdx.y][threadIdx.x] = mat_1[i * n + j];
}
else
{
mat_1_tile[threadIdx.y][threadIdx.x] = 0;
}
i = tile_idx * blockDim.y + threadIdx.y;
j = blockIdx.x * blockDim.x + threadIdx.x;
if ((i < n) && (j < p))
{
mat_2_tile[threadIdx.y][threadIdx.x] = mat_2[i * p + j];
}
else
{
mat_2_tile[threadIdx.y][threadIdx.x] = 0;
}
__syncthreads();
for (size_t k{0}; k < BLOCK_DIM; ++k)
{
acc_sum += mat_1_tile[threadIdx.y][k] * mat_2_tile[k][threadIdx.x];
}
__syncthreads();
}

// 2D block and 2D thread
// Each thread computes one cell in mat_3.
size_t i{blockIdx.y * blockDim.y + threadIdx.y};
size_t j{blockIdx.x * blockDim.x + threadIdx.x};

if ((i < m) && (j < p))
{
mat_3[i * p + j] = acc_sum;
}
}

template <typename T>
void mm_cuda(T const* mat_1, T const* mat_2, T* mat_3, size_t m, size_t n,
size_t p,
void (*f)(T const*, T const*, T*, size_t, size_t, size_t))
{
dim3 threads_per_block(BLOCK_DIM, BLOCK_DIM);
dim3 blocks_per_grid(1, 1);
blocks_per_grid.x = std::ceil(static_cast<double>(p) /
static_cast<double>(threads_per_block.x));
blocks_per_grid.y = std::ceil(static_cast<double>(m) /
static_cast<double>(threads_per_block.y));
f<<<blocks_per_grid, threads_per_block>>>(mat_1, mat_2, mat_3, m, n, p);
}

template <typename T>
bool allclose(std::vector<T> const& vec_1, std::vector<T> const& vec_2,
T const& abs_tol)
{
if (vec_1.size() != vec_2.size())
{
return false;
}
for (size_t i{0}; i < vec_1.size(); ++i)
{
if (std::abs(vec_1.at(i) - vec_2.at(i)) > abs_tol)
{
std::cout << vec_1.at(i) << " " << vec_2.at(i) << std::endl;
return false;
}
}
return true;
}

template <typename T>
bool random_test_mm_cuda(size_t m, size_t n, size_t p,
void (*f)(T const*, T const*, T*, size_t, size_t,
size_t))
{
std::vector<T> const mat_1_vec{create_rand_vector<T>(m * n)};
std::vector<T> const mat_2_vec{create_rand_vector<T>(n * p)};
std::vector<T> mat_3_vec(m * p);
std::vector<T> mat_4_vec(m * p);
T const* mat_1{mat_1_vec.data()};
T const* mat_2{mat_2_vec.data()};
T* mat_3{mat_3_vec.data()};
T* mat_4{mat_4_vec.data()};

mm(mat_1, mat_2, mat_3, m, n, p);

T *d_mat_1, *d_mat_2, *d_mat_4;

// Allocate device buffer.
checkCuda(cudaMalloc(&d_mat_1, sizeof(T) * mat_1_vec.size()));
checkCuda(cudaMalloc(&d_mat_2, sizeof(T) * mat_2_vec.size()));
checkCuda(cudaMalloc(&d_mat_4, sizeof(T) * mat_4_vec.size()));

// Copy data from host to device.
checkCuda(cudaMemcpy(d_mat_1, mat_1, sizeof(T) * mat_1_vec.size(),
cudaMemcpyHostToDevice));
checkCuda(cudaMemcpy(d_mat_2, mat_2, sizeof(T) * mat_2_vec.size(),
cudaMemcpyHostToDevice));

// Run matrix multiplication on GPU.
mm_cuda(d_mat_1, d_mat_2, d_mat_4, m, n, p, f);
cudaDeviceSynchronize();
cudaError_t err{cudaGetLastError()};
if (err != cudaSuccess)
{
std::cerr << "CUDA Matrix Multiplication kernel failed to execute."
<< std::endl;
std::cerr << cudaGetErrorString(err) << std::endl;
std::exit(EXIT_FAILURE);
}
// Copy data from device to host.
checkCuda(cudaMemcpy(mat_4, d_mat_4, sizeof(T) * mat_4_vec.size(),
cudaMemcpyDeviceToHost));

// Free device buffer.
checkCuda(cudaFree(d_mat_1));
checkCuda(cudaFree(d_mat_2));
checkCuda(cudaFree(d_mat_4));

return allclose<T>(mat_3_vec, mat_4_vec, 1e-4);
}

template <typename T>
bool random_multiple_test_mm_cuda(size_t num_tests,
void (*f)(T const*, T const*, T*, size_t,
size_t, size_t))
{
std::random_device r;
std::default_random_engine e(r());
std::uniform_int_distribution<int> uniform_dist(1, 256);

size_t m{0}, n{0}, p{0};
bool success{false};

for (size_t i{0}; i < num_tests; ++i)
{
m = static_cast<size_t>(uniform_dist(e));
n = static_cast<size_t>(uniform_dist(e));
p = static_cast<size_t>(uniform_dist(e));
success = random_test_mm_cuda<T>(m, n, p, f);
if (!success)
{
return false;
}
}

return true;
}

template <typename T>
float measure_latency_mm_cuda(size_t m, size_t n, size_t p, size_t num_tests,
size_t num_warmups,
void (*f)(T const*, T const*, T*, size_t, size_t,
size_t))
{
cudaEvent_t startEvent, stopEvent;
float time{0.0f};

checkCuda(cudaEventCreate(&startEvent));
checkCuda(cudaEventCreate(&stopEvent));

T *d_mat_1, *d_mat_2, *d_mat_4;

// Allocate device buffer.
checkCuda(cudaMalloc(&d_mat_1, sizeof(T) * m * n));
checkCuda(cudaMalloc(&d_mat_2, sizeof(T) * n * p));
checkCuda(cudaMalloc(&d_mat_4, sizeof(T) * m * p));

for (size_t i{0}; i < num_warmups; ++i)
{
mm_cuda(d_mat_1, d_mat_2, d_mat_4, m, n, p, f);
}

checkCuda(cudaEventRecord(startEvent, 0));
for (size_t i{0}; i < num_tests; ++i)
{
mm_cuda(d_mat_1, d_mat_2, d_mat_4, m, n, p, f);
}
checkCuda(cudaEventRecord(stopEvent, 0));
checkCuda(cudaEventSynchronize(stopEvent));
cudaError_t err{cudaGetLastError()};
if (err != cudaSuccess)
{
std::cerr << "CUDA Matrix Multiplication kernel failed to execute."
<< std::endl;
std::cerr << cudaGetErrorString(err) << std::endl;
std::exit(EXIT_FAILURE);
}
checkCuda(cudaEventElapsedTime(&time, startEvent, stopEvent));

// Free device buffer.
checkCuda(cudaFree(d_mat_1));
checkCuda(cudaFree(d_mat_2));
checkCuda(cudaFree(d_mat_4));

float latency{time / num_tests};

return latency;
}

int main()
{
constexpr size_t num_tests{10};

assert(random_multiple_test_mm_cuda<int32_t>(num_tests, mm_kernel));
assert(random_multiple_test_mm_cuda<float>(num_tests, mm_kernel));
assert(random_multiple_test_mm_cuda<double>(num_tests, mm_kernel));

assert(
random_multiple_test_mm_cuda<int32_t>(num_tests, mm_kernel_optimized));
assert(random_multiple_test_mm_cuda<float>(num_tests, mm_kernel_optimized));
assert(
random_multiple_test_mm_cuda<double>(num_tests, mm_kernel_optimized));

constexpr size_t num_measurement_tests{100};
constexpr size_t num_measurement_warmups{10};
const size_t m{1024}, n{1024}, p{1024};

float mm_cuda_int32_latency{measure_latency_mm_cuda<int32_t>(
m, n, p, num_measurement_tests, num_measurement_warmups, mm_kernel)};
float mm_cuda_float_latency{measure_latency_mm_cuda<float>(
m, n, p, num_measurement_tests, num_measurement_warmups, mm_kernel)};
float mm_cuda_double_latency{measure_latency_mm_cuda<double>(
m, n, p, num_measurement_tests, num_measurement_warmups, mm_kernel)};

std::cout << "Matrix Multiplication CUDA Latency" << std::endl;
std::cout << "m: " << m << " "
<< "n: " << n << " "
<< "p: " << p << std::endl;
std::cout << "INT32: " << std::fixed << std::setprecision(5)
<< mm_cuda_int32_latency << " ms" << std::endl;
std::cout << "FLOAT: " << std::fixed << std::setprecision(5)
<< mm_cuda_float_latency << " ms" << std::endl;
std::cout << "DOUBLE: " << std::fixed << std::setprecision(5)
<< mm_cuda_double_latency << " ms" << std::endl;

mm_cuda_int32_latency = measure_latency_mm_cuda<int32_t>(
m, n, p, num_measurement_tests, num_measurement_warmups,
mm_kernel_optimized);
mm_cuda_float_latency = measure_latency_mm_cuda<float>(
m, n, p, num_measurement_tests, num_measurement_warmups,
mm_kernel_optimized);
mm_cuda_double_latency = measure_latency_mm_cuda<double>(
m, n, p, num_measurement_tests, num_measurement_warmups,
mm_kernel_optimized);

std::cout << "Optimized Matrix Multiplication CUDA Latency" << std::endl;
std::cout << "m: " << m << " "
<< "n: " << n << " "
<< "p: " << p << std::endl;
std::cout << "INT32: " << std::fixed << std::setprecision(5)
<< mm_cuda_int32_latency << " ms" << std::endl;
std::cout << "FLOAT: " << std::fixed << std::setprecision(5)
<< mm_cuda_float_latency << " ms" << std::endl;
std::cout << "DOUBLE: " << std::fixed << std::setprecision(5)
<< mm_cuda_double_latency << " ms" << std::endl;
}
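As a quick sanity check on resources, the two shared memory tiles in mm_kernel_optimized occupy

$$
\begin{align}
2 \times 32 \times 32 \times 4 \text{ bytes} &= 8 \text{ KB} \\
\end{align}
$$

per thread block for FP32 (16 KB for FP64), which fits comfortably within the shared memory available to a thread block on modern GPUs.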

Run Optimized Example

In the same Docker container, build and run the application. We could see that the latencies of the INT32 and FP32 matrix multiplications improved to different degrees.

$ nvcc mm_optimization.cu -o mm_optimization --std=c++14
$ ./mm_optimization
Matrix Multiplication CUDA Latency
m: 1024 n: 1024 p: 1024
INT32: 1.04373 ms
FLOAT: 1.02149 ms
DOUBLE: 3.83370 ms
Optimized Matrix Multiplication CUDA Latency
m: 1024 n: 1024 p: 1024
INT32: 0.84207 ms
FLOAT: 0.81759 ms
DOUBLE: 3.95231 ms

Miscellaneous

There are more subtle factors affecting the performance, and there are more opportunities to further optimize the matrix multiplication implementation. But those require a more thorough understanding of the GPU and CUDA programming.
