NVIDIA Tensor Core Programming

Introduction

NVIDIA Tensor Cores are dedicated accelerators for general matrix multiplication (GEMM) operations, available on NVIDIA GPUs since the Volta architecture. Because artificial intelligence computations are usually dominated by GEMM operations, Tensor Cores are critical for accelerating artificial intelligence applications.

NVIDIA Tensor Core

NVIDIA Tensor Cores specialize in performing GEMM operations in mixed precision, i.e., the GEMM input matrices are in lower precision whereas the GEMM output matrix is in higher precision. Mixed precision training and inference are key techniques for accelerating the training and inference of neural networks.
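To see why the output needs more precision than the inputs, consider the INT8 case used later in this article (IMMA, with INT8 inputs and INT32 outputs). The minimal host-side C++ sketch below (the helper names are illustrative, not part of any NVIDIA API) bounds the largest magnitude that a length-k INT8 dot product can reach and checks whether a given accumulator type can hold it.

```cpp
#include <cassert>
#include <cstdint>
#include <limits>

// Largest magnitude a dot product of two length-k INT8 vectors can reach:
// the worst case is (-128) * (-128) = 16384, summed k times.
std::int64_t max_int8_dot_magnitude(std::int64_t k)
{
    assert(k >= 0);
    std::int64_t const min_int8{std::numeric_limits<std::int8_t>::min()};
    return k * min_int8 * min_int8;
}

// True if an accumulator of type Acc can hold every length-k INT8 dot
// product without overflowing.
template <typename Acc>
bool accumulator_is_wide_enough(std::int64_t k)
{
    return max_int8_dot_magnitude(k) <=
           static_cast<std::int64_t>(std::numeric_limits<Acc>::max());
}
```

For k = 1024, as in the benchmark at the end of this article, the bound is 16,777,216: far beyond the INT16 range but well within INT32. With INT8 inputs, an INT32 accumulator can only overflow once k exceeds 131,071.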

NVIDIA Tensor Core GEMM Math

Because NVIDIA Tensor Cores are specifically designed for GEMM, the GEMM throughput using Tensor Cores is much higher than what can be achieved using NVIDIA CUDA Cores, which are designed for more general parallel programming.

NVIDIA GEMM Throughput Turing Tensor Core VS Pascal CUDA Core

In the NVIDIA Ampere architecture, each streaming multiprocessor (SM) has 4 Tensor Cores. In particular, the NVIDIA A100 GPU has 108 SMs, for a total of 432 Tensor Cores.

NVIDIA GA100 Full GPU with 128 SMs
Each NVIDIA Ampere SM Has 4 Tensor Cores

NVIDIA Tensor Cores are programmable. The warp-level Tensor Core programming API, WMMA (warp matrix multiply-accumulate), is declared in the mma.h header under the nvcuda::wmma namespace.

NVIDIA Tensor Core Programming

Matrix Multiplication Decomposition

NVIDIA CUDA allows users to program Tensor Core GEMM operations $D = AB + C$ at the warp level. Although each Tensor Core can only perform matrix multiplications of a few specific small sizes, which depend on the data type, a large GEMM can be divided into multiple small GEMMs and accumulations, as discussed in my previous article “CUDA Matrix Multiplication”.

Given a GEMM operation $D = AB + C$, where $D \in \mathbb{R}^{m \times n}$, $A \in \mathbb{R}^{m \times k}$, $B \in \mathbb{R}^{k \times n}$, $C \in \mathbb{R}^{m \times n}$, and assuming $m$, $n$, and $k$ are all multiples of a tile size $d$, the matrices can be divided into $d \times d$ submatrices (tiles).

$$
A =
\begin{bmatrix}
A_{1,1}^{d \times d} & A_{1,2}^{d \times d} & \cdots & A_{1,k/d}^{d \times d} \\
A_{2,1}^{d \times d} & A_{2,2}^{d \times d} & \cdots & A_{2,k/d}^{d \times d} \\
\vdots & \vdots & \ddots & \vdots \\
A_{m/d,1}^{d \times d} & A_{m/d,2}^{d \times d} & \cdots & A_{m/d,k/d}^{d \times d} \\
\end{bmatrix}
$$

$$
B =
\begin{bmatrix}
B_{1,1}^{d \times d} & B_{1,2}^{d \times d} & \cdots & B_{1,n/d}^{d \times d} \\
B_{2,1}^{d \times d} & B_{2,2}^{d \times d} & \cdots & B_{2,n/d}^{d \times d} \\
\vdots & \vdots & \ddots & \vdots \\
B_{k/d,1}^{d \times d} & B_{k/d,2}^{d \times d} & \cdots & B_{k/d,n/d}^{d \times d} \\
\end{bmatrix}
$$

$$
C =
\begin{bmatrix}
C_{1,1}^{d \times d} & C_{1,2}^{d \times d} & \cdots & C_{1,n/d}^{d \times d} \\
C_{2,1}^{d \times d} & C_{2,2}^{d \times d} & \cdots & C_{2,n/d}^{d \times d} \\
\vdots & \vdots & \ddots & \vdots \\
C_{m/d,1}^{d \times d} & C_{m/d,2}^{d \times d} & \cdots & C_{m/d,n/d}^{d \times d} \\
\end{bmatrix}
$$

$$
D =
\begin{bmatrix}
D_{1,1}^{d \times d} & D_{1,2}^{d \times d} & \cdots & D_{1,n/d}^{d \times d} \\
D_{2,1}^{d \times d} & D_{2,2}^{d \times d} & \cdots & D_{2,n/d}^{d \times d} \\
\vdots & \vdots & \ddots & \vdots \\
D_{m/d,1}^{d \times d} & D_{m/d,2}^{d \times d} & \cdots & D_{m/d,n/d}^{d \times d} \\
\end{bmatrix}
$$

Each $d \times d$ tile of $D$ is computed by accumulating multiple small GEMMs.

$$
D_{i_m,i_n}^{d \times d} = \sum_{i_k=1}^{k/d} A_{i_m,i_k}^{d \times d} B_{i_k,i_n}^{d \times d} + C_{i_m,i_n}^{d \times d}
$$
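The decomposition can be checked on the CPU. The C++ sketch below (all function names are illustrative, not from the program in this article) stores matrices in column-major order, as the program does, starts each $d \times d$ tile of $D$ from the matching tile of $C$, accumulates the $k/d$ tile products $A_{i_m,i_k} B_{i_k,i_n}$, and compares the result with a naive triple loop. It assumes $m$, $n$, and $k$ are multiples of $d$; tile indices are 0-based in code.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Offset of element (row, col) in a column-major matrix with leading dim ld.
inline std::size_t cm(std::size_t row, std::size_t col, std::size_t ld)
{
    return row + col * ld;
}

// Reference: D = A * B + C, with A: m x k, B: k x n, C and D: m x n.
std::vector<float> gemm_naive(std::vector<float> const& A,
                              std::vector<float> const& B,
                              std::vector<float> const& C, std::size_t m,
                              std::size_t n, std::size_t k)
{
    std::vector<float> D(C);
    for (std::size_t j{0}; j < n; ++j)
        for (std::size_t i{0}; i < m; ++i)
            for (std::size_t p{0}; p < k; ++p)
                D[cm(i, j, m)] += A[cm(i, p, m)] * B[cm(p, j, k)];
    return D;
}

// Tiled: tile (tm, tn) of D starts as tile (tm, tn) of C and accumulates the
// k / d small GEMMs A tile (tm, tk) * B tile (tk, tn).
std::vector<float> gemm_tiled(std::vector<float> const& A,
                              std::vector<float> const& B,
                              std::vector<float> const& C, std::size_t m,
                              std::size_t n, std::size_t k, std::size_t d)
{
    assert(m % d == 0 && n % d == 0 && k % d == 0);
    std::vector<float> D(C);
    for (std::size_t tn{0}; tn < n / d; ++tn)
        for (std::size_t tm{0}; tm < m / d; ++tm)
            for (std::size_t tk{0}; tk < k / d; ++tk)
                // One d x d x d GEMM, accumulated into the D tile.
                for (std::size_t j{0}; j < d; ++j)
                    for (std::size_t i{0}; i < d; ++i)
                        for (std::size_t p{0}; p < d; ++p)
                            D[cm(tm * d + i, tn * d + j, m)] +=
                                A[cm(tm * d + i, tk * d + p, m)] *
                                B[cm(tk * d + p, tn * d + j, k)];
    return D;
}

// Small integer values keep every float sum exact, so == comparison is safe.
std::vector<float> small_int_matrix(std::size_t size, std::size_t offset)
{
    std::vector<float> M(size);
    for (std::size_t i{0}; i < size; ++i)
        M[i] = static_cast<float>(static_cast<int>((i + offset) % 7) - 3);
    return M;
}

// Runs both versions on deterministic inputs and compares the outputs.
bool tiled_matches_naive(std::size_t m, std::size_t n, std::size_t k,
                         std::size_t d)
{
    std::vector<float> const A = small_int_matrix(m * k, 0);
    std::vector<float> const B = small_int_matrix(k * n, 1);
    std::vector<float> const C = small_int_matrix(m * n, 2);
    return gemm_tiled(A, B, C, m, n, k, d) == gemm_naive(A, B, C, m, n, k);
}
```

For example, tiled_matches_naive(32, 48, 64, 16) splits $D$ into a $2 \times 3$ grid of $16 \times 16$ tiles, each built from 4 tile products.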

In my previous article “CUDA Matrix Multiplication”, I used CUDA Cores and shared memory to perform the above computation, with each thread block computing one $D_{i_m,i_n}^{d \times d}$. This time, I will use Tensor Cores to compute exactly the same thing, with each warp computing one $D_{i_m,i_n}^{d \times d}$. More specifically, each warp produces a $16 \times 16$ tile of the $D$ matrix by accumulating $16 \times 16 \times 16$ GEMMs along the $k$ dimension, i.e., $d = 16$.
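As a quick size check for this tiling, the sketch below (the struct and function names are made up for illustration) counts the work: an $m \times n$ output needs $(m/16)(n/16)$ warps, each warp issues $k/16$ warp-level $16 \times 16 \times 16$ multiply-accumulates (one nvcuda::wmma::mma_sync call per step of the K loop), and each of those performs $16^3$ multiplies plus $16^3$ adds.

```cpp
#include <cassert>
#include <cstdint>

// Work breakdown when each warp owns one d x d tile of D (d = 16 here).
struct WarpWork
{
    std::uint64_t num_warps;      // one warp per d x d output tile
    std::uint64_t mmas_per_warp;  // d x d x d multiply-accumulates along k
    std::uint64_t ops_per_mma;    // d * d * d multiplies plus as many adds
};

WarpWork warp_work(std::uint64_t m, std::uint64_t n, std::uint64_t k,
                   std::uint64_t d = 16)
{
    assert(m % d == 0 && n % d == 0 && k % d == 0);
    return WarpWork{(m / d) * (n / d), k / d, 2 * d * d * d};
}
```

For the $1024 \times 1024 \times 1024$ benchmark, that is 4096 warps, each issuing 64 multiply-accumulates of 8192 operations, which multiplies back to $2 \times 1024^3$ operations.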

Matrix Multiplication Implementation Using NVIDIA Tensor Core

In this implementation, we use Tensor Cores to perform GEMM operations with HMMA (half-precision matrix multiply-accumulate) and IMMA (integer matrix multiply-accumulate) instructions. In addition, four GEMM variants involving transposed matrices have been implemented and verified.

  • $D = AB + C$, where $D \in \mathbb{R}^{m \times n}$, $A \in \mathbb{R}^{m \times k}$, $B \in \mathbb{R}^{k \times n}$, $C \in \mathbb{R}^{m \times n}$
  • $D = A^{\top}B + C$, where $D \in \mathbb{R}^{m \times n}$, $A \in \mathbb{R}^{k \times m}$, $B \in \mathbb{R}^{k \times n}$, $C \in \mathbb{R}^{m \times n}$
  • $D = AB^{\top} + C$, where $D \in \mathbb{R}^{m \times n}$, $A \in \mathbb{R}^{m \times k}$, $B \in \mathbb{R}^{n \times k}$, $C \in \mathbb{R}^{m \times n}$
  • $D = A^{\top}B^{\top} + C$, where $D \in \mathbb{R}^{m \times n}$, $A \in \mathbb{R}^{k \times m}$, $B \in \mathbb{R}^{n \times k}$, $C \in \mathbb{R}^{m \times n}$
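The four variants differ only in how $A$ and $B$ sit in column-major memory. The host-side C++ sketch below mirrors the bookkeeping done in launch_wmma_mm in the program that follows; the Layout enum and the GemmLayout struct are illustrative stand-ins for the nvcuda::wmma::row_major and nvcuda::wmma::col_major tags. A transposed operand is stored column-major in its own shape, which is byte-for-byte the untransposed operand in row-major order, so it gets a row-major fragment layout and a leading dimension equal to its stored row count.

```cpp
#include <cassert>
#include <cstdint>

// Stand-in for the nvcuda::wmma layout tags.
enum class Layout
{
    col_major,
    row_major
};

struct GemmLayout
{
    std::uint32_t lda, ldb, ldc;
    Layout frag_a, frag_b;
};

// All matrices are column-major without padding.
// A is m x k (k x m if transposed); B is k x n (n x k if transposed).
GemmLayout choose_layout(std::uint32_t m, std::uint32_t n, std::uint32_t k,
                         bool is_A_transpose, bool is_B_transpose)
{
    assert(m > 0 && n > 0 && k > 0);
    return GemmLayout{
        is_A_transpose ? k : m,  // number of rows of A as stored
        is_B_transpose ? n : k,  // number of rows of B as stored
        m,                       // number of rows of C
        is_A_transpose ? Layout::row_major : Layout::col_major,
        is_B_transpose ? Layout::row_major : Layout::col_major};
}
```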

In this implementation, we mainly focus on the matrix multiplication part of the GEMM operation by treating $C = 0$.

#include <cassert>
#include <chrono>
#include <cmath>
#include <cstdint>
#include <cstdlib>
#include <functional>
#include <iomanip>
#include <iostream>
#include <random>
#include <utility>
#include <vector>

#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <mma.h>

#define CHECK_CUDA_ERROR(val) check((val), #val, __FILE__, __LINE__)
template <typename T>
void check(T err, const char* const func, const char* const file,
           int const line)
{
    if (err != cudaSuccess)
    {
        std::cerr << "CUDA Runtime Error at: " << file << ":" << line
                  << std::endl;
        std::cerr << cudaGetErrorString(err) << " " << func << std::endl;
        std::exit(EXIT_FAILURE);
    }
}

#define CHECK_LAST_CUDA_ERROR() checkLast(__FILE__, __LINE__)
void checkLast(const char* const file, int const line)
{
    cudaError_t const err{cudaGetLastError()};
    if (err != cudaSuccess)
    {
        std::cerr << "CUDA Runtime Error at: " << file << ":" << line
                  << std::endl;
        std::cerr << cudaGetErrorString(err) << std::endl;
        std::exit(EXIT_FAILURE);
    }
}

template <class T>
float measure_performance(std::function<T(cudaStream_t)> bound_function,
                          cudaStream_t stream, int num_repeats = 100,
                          int num_warmups = 100)
{
    cudaEvent_t start, stop;
    float time;

    CHECK_CUDA_ERROR(cudaEventCreate(&start));
    CHECK_CUDA_ERROR(cudaEventCreate(&stop));

    for (int i{0}; i < num_warmups; ++i)
    {
        bound_function(stream);
    }

    CHECK_CUDA_ERROR(cudaStreamSynchronize(stream));

    CHECK_CUDA_ERROR(cudaEventRecord(start, stream));
    for (int i{0}; i < num_repeats; ++i)
    {
        bound_function(stream);
    }
    CHECK_CUDA_ERROR(cudaEventRecord(stop, stream));
    CHECK_CUDA_ERROR(cudaEventSynchronize(stop));
    CHECK_LAST_CUDA_ERROR();
    CHECK_CUDA_ERROR(cudaEventElapsedTime(&time, start, stop));
    CHECK_CUDA_ERROR(cudaEventDestroy(start));
    CHECK_CUDA_ERROR(cudaEventDestroy(stop));

    float const latency{time / num_repeats};

    return latency;
}

// All the data in the matrices are stored in column-major order,
// which is consistent with most of the cuBLAS GEMM APIs.
// For matrix A of shape M x N, the leading dimension is M.
// For matrix A that is transposed and is of shape N x M,
// the leading dimension is N.
// Matrix A: M x K, or K x M (if transposed).
// Matrix B: K x N, or N x K (if transposed).
// Matrix C: M x N.
// WMMA_FRAG_LAYOUT_A: nvcuda::wmma::row_major if A is
// transposed, otherwise nvcuda::wmma::col_major.
// WMMA_FRAG_LAYOUT_B: nvcuda::wmma::row_major if B is
// transposed, otherwise nvcuda::wmma::col_major.
template <typename T1, typename T2, int WMMA_M, int WMMA_N, int WMMA_K,
          typename WMMA_FRAG_LAYOUT_A, typename WMMA_FRAG_LAYOUT_B>
__global__ void wmma_gemm_a_col_major_b_col_major(
    T1 const* A, T1 const* B, T2* C, uint32_t m, uint32_t n, uint32_t k,
    uint32_t lda, uint32_t ldb, uint32_t ldc, bool is_A_transpose,
    bool is_B_transpose, float alpha, float beta)
{
    // Tile using a 2D grid.
    // Determine the warp 2D index.
    uint32_t const warpM{(blockIdx.x * blockDim.x + threadIdx.x) / warpSize};
    uint32_t const warpN{blockIdx.y * blockDim.y + threadIdx.y};

    // Declare the fragments.
    nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, WMMA_M, WMMA_N, WMMA_K, T1,
                           WMMA_FRAG_LAYOUT_A>
        a_frag{};
    nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, WMMA_M, WMMA_N, WMMA_K, T1,
                           WMMA_FRAG_LAYOUT_B>
        b_frag{};
    nvcuda::wmma::fragment<nvcuda::wmma::accumulator, WMMA_M, WMMA_N, WMMA_K,
                           T2>
        acc_frag{};
    nvcuda::wmma::fragment<nvcuda::wmma::accumulator, WMMA_M, WMMA_N, WMMA_K,
                           T2>
        c_frag{};

    // Make sure the accumulator starts from 0.
    nvcuda::wmma::fill_fragment(acc_frag, static_cast<T2>(0));

    // Loop over K.
    for (uint32_t ki{0}; ki < k; ki += WMMA_K)
    {
        // Determine the first element of the mma matrices on the linear memory.
        // Matrix A mma matrix
        uint32_t const matrix_mma_a_row_idx{is_A_transpose ? ki
                                                           : warpM * WMMA_M};
        uint32_t const matrix_mma_a_col_idx{is_A_transpose ? warpM * WMMA_M
                                                           : ki};
        // Matrix B mma matrix
        uint32_t const matrix_mma_b_row_idx{is_B_transpose ? warpN * WMMA_N
                                                           : ki};
        uint32_t const matrix_mma_b_col_idx{is_B_transpose ? ki
                                                           : warpN * WMMA_N};

        // Bounds checking
        if (matrix_mma_a_row_idx < (is_A_transpose ? k : m) &&
            matrix_mma_a_col_idx < (is_A_transpose ? m : k) &&
            matrix_mma_b_row_idx < (is_B_transpose ? n : k) &&
            matrix_mma_b_col_idx < (is_B_transpose ? k : n))
        {
            // Determine the memory address of the first element of the mma
            // matrices. Notice that all the matrices are assumed to be
            // column-major. Therefore, the indexing is different from the
            // row-major indexing that we commonly see.
            T1 const* matrix_mma_a_mptr{A + matrix_mma_a_row_idx +
                                        matrix_mma_a_col_idx * lda};
            T1 const* matrix_mma_b_mptr{B + matrix_mma_b_row_idx +
                                        matrix_mma_b_col_idx * ldb};
            // Load the mma matrix inputs.
            nvcuda::wmma::load_matrix_sync(a_frag, matrix_mma_a_mptr, lda);
            nvcuda::wmma::load_matrix_sync(b_frag, matrix_mma_b_mptr, ldb);

            // Perform the matrix multiplication
            nvcuda::wmma::mma_sync(acc_frag, a_frag, b_frag, acc_frag);
        }
    }

    // Load in the current value of c, scale it by beta, and add this to our
    // result scaled by alpha.
    uint32_t const matrix_mma_c_row_idx{warpM * WMMA_M};
    uint32_t const matrix_mma_c_col_idx{warpN * WMMA_N};

    if (matrix_mma_c_row_idx < m && matrix_mma_c_col_idx < n)
    {
        T2* matrix_mma_c_mptr{C + matrix_mma_c_row_idx +
                              matrix_mma_c_col_idx * ldc};
        nvcuda::wmma::load_matrix_sync(c_frag, matrix_mma_c_mptr, ldc,
                                       nvcuda::wmma::mem_col_major);
        // Let the compiler figure out how to do the elementwise operation.
        // Such elementwise operation can be scaling, accumulation,
        // quantization, etc.
        // https://docs.nvidia.com/cuda/archive/12.0.1/cuda-c-programming-guide/#id40
        // Be careful when dealing with the integer types.
        for (uint32_t i = 0; i < c_frag.num_elements; i++)
        {
            c_frag.x[i] = alpha * acc_frag.x[i] + beta * c_frag.x[i];
        }
        // Store the output
        nvcuda::wmma::store_matrix_sync(matrix_mma_c_mptr, c_frag, ldc,
                                        nvcuda::wmma::mem_col_major);
    }
}

template <typename T1, typename T2>
void launch_wmma_mm(T1 const* A, T1 const* B, T2* C, uint32_t m, uint32_t n,
                    uint32_t k, bool is_A_transpose, bool is_B_transpose,
                    cudaStream_t stream)
{
    // Assume there is no padding in our data.
    uint32_t const lda{is_A_transpose ? k : m};
    uint32_t const ldb{is_B_transpose ? n : k};
    uint32_t const ldc{m};
    float const alpha{1.0f};
    float const beta{0.0f};

    constexpr int WMMA_M{16};
    constexpr int WMMA_N{16};
    constexpr int WMMA_K{16};

    constexpr int WARP_SIZE{32};

    dim3 gridDim;
    dim3 blockDim;

    // blockDim.x must be a multiple of warpSize.
    // A block size of 128x4 means we have 16 (4x4) warps,
    // each warp computes a 16x16 output tile,
    // and a block computes a 64x64 output tile.
    // Each block has 4x4 warps, totalling 4x4x32 threads.
    int const num_warps_x = 4;
    int const num_warps_y = 4;
    blockDim.x = num_warps_x * WARP_SIZE;
    blockDim.y = num_warps_y;
    // Round up.
    gridDim.x = (m + (WMMA_M * num_warps_x - 1)) / (WMMA_M * num_warps_x);
    gridDim.y = (n + WMMA_N * num_warps_y - 1) / (WMMA_N * num_warps_y);

    // C = A * B
    if ((!is_A_transpose) && (!is_B_transpose))
    {
        wmma_gemm_a_col_major_b_col_major<T1, T2, WMMA_M, WMMA_N, WMMA_K,
                                          nvcuda::wmma::col_major,
                                          nvcuda::wmma::col_major>
            <<<gridDim, blockDim, 0, stream>>>(A, B, C, m, n, k, lda, ldb, ldc,
                                               is_A_transpose, is_B_transpose,
                                               alpha, beta);
    }
    // C = A^T * B
    else if ((is_A_transpose) && (!is_B_transpose))
    {
        wmma_gemm_a_col_major_b_col_major<T1, T2, WMMA_M, WMMA_N, WMMA_K,
                                          nvcuda::wmma::row_major,
                                          nvcuda::wmma::col_major>
            <<<gridDim, blockDim, 0, stream>>>(A, B, C, m, n, k, lda, ldb, ldc,
                                               is_A_transpose, is_B_transpose,
                                               alpha, beta);
    }
    // C = A * B^T
    else if ((!is_A_transpose) && (is_B_transpose))
    {
        wmma_gemm_a_col_major_b_col_major<T1, T2, WMMA_M, WMMA_N, WMMA_K,
                                          nvcuda::wmma::col_major,
                                          nvcuda::wmma::row_major>
            <<<gridDim, blockDim, 0, stream>>>(A, B, C, m, n, k, lda, ldb, ldc,
                                               is_A_transpose, is_B_transpose,
                                               alpha, beta);
    }
    // C = A^T * B^T
    else
    {
        wmma_gemm_a_col_major_b_col_major<T1, T2, WMMA_M, WMMA_N, WMMA_K,
                                          nvcuda::wmma::row_major,
                                          nvcuda::wmma::row_major>
            <<<gridDim, blockDim, 0, stream>>>(A, B, C, m, n, k, lda, ldb, ldc,
                                               is_A_transpose, is_B_transpose,
                                               alpha, beta);
    }
    CHECK_LAST_CUDA_ERROR();
}

// A and B are column-major matrices.
template <typename T1, typename T2>
void mm_a_col_major_b_col_major(T1 const* A, T1 const* B, T2* C, uint32_t m,
                                uint32_t n, uint32_t k, uint32_t lda,
                                uint32_t ldb, uint32_t ldc, bool is_A_transpose,
                                bool is_B_transpose)
{
    for (uint32_t ni{0}; ni < n; ++ni)
    {
        for (uint32_t mi{0}; mi < m; ++mi)
        {
            // Compute C[mi, ni]
            T2 accum{0};
            // C = A * B
            if ((!is_A_transpose) && (!is_B_transpose))
            {
                for (uint32_t ki{0}; ki < k; ++ki)
                {
                    // A[mi, ki] * B[ki, ni]
                    accum += A[ki * lda + mi] * B[ni * ldb + ki];
                }
            }
            // C = A^T * B
            else if ((is_A_transpose) && (!is_B_transpose))
            {
                for (uint32_t ki{0}; ki < k; ++ki)
                {
                    // A[ki, mi] * B[ki, ni]
                    accum += A[mi * lda + ki] * B[ni * ldb + ki];
                }
            }
            // C = A * B^T
            else if ((!is_A_transpose) && (is_B_transpose))
            {
                for (uint32_t ki{0}; ki < k; ++ki)
                {
                    // A[mi, ki] * B[ni, ki]
                    accum += A[ki * lda + mi] * B[ki * ldb + ni];
                }
            }
            // C = A^T * B^T
            else
            {
                for (uint32_t ki{0}; ki < k; ++ki)
                {
                    // A[ki, mi] * B[ni, ki]
                    accum += A[mi * lda + ki] * B[ki * ldb + ni];
                }
            }
            C[ni * ldc + mi] = accum;
        }
    }
}

template <typename T1, typename T2>
void launch_mm(T1 const* A, T1 const* B, T2* C, uint32_t m, uint32_t n,
               uint32_t k, bool is_A_transpose, bool is_B_transpose)
{
    // Assume there is no padding in our data.
    uint32_t const lda{is_A_transpose ? k : m};
    uint32_t const ldb{is_B_transpose ? n : k};
    uint32_t const ldc{m};
    mm_a_col_major_b_col_major(A, B, C, m, n, k, lda, ldb, ldc, is_A_transpose,
                               is_B_transpose);
}

void fill_random_float_values(float* arr, size_t n,
                              std::default_random_engine& e)
{
    std::uniform_real_distribution<float> uniform_dist(-256, 256);
    for (size_t i{0}; i < n; ++i)
    {
        arr[i] = uniform_dist(e);
    }
}

void fill_random_int8_values(int8_t* arr, size_t n,
                             std::default_random_engine& e)
{
    // The C++ standard does not allow std::uniform_int_distribution<int8_t>,
    // so draw int values in the int8_t range and narrow them.
    std::uniform_int_distribution<int> uniform_dist(-128, 127);
    for (size_t i{0}; i < n; ++i)
    {
        arr[i] = static_cast<int8_t>(uniform_dist(e));
    }
}

void fill_random_int32_values(int32_t* arr, size_t n,
                              std::default_random_engine& e)
{
    std::uniform_int_distribution<int32_t> uniform_dist(-128, 127);
    for (size_t i{0}; i < n; ++i)
    {
        arr[i] = uniform_dist(e);
    }
}

void float2half(__half* half_arr, float const* float_arr, size_t n)
{
    for (size_t i{0}; i < n; ++i)
    {
        half_arr[i] = __float2half(float_arr[i]);
    }
}

template <typename T>
float get_avg_abs_diff_ratio(T const* arr_1, T const* arr_2, size_t n)
{
    float sum_abs_diff_ratio{0};
    for (size_t i{0}; i < n; ++i)
    {
        sum_abs_diff_ratio += std::abs(static_cast<float>(arr_1[i]) -
                                       static_cast<float>(arr_2[i])) /
                              std::abs(static_cast<float>(arr_1[i]) +
                                       static_cast<float>(arr_2[i]));
    }
    return sum_abs_diff_ratio / n;
}

template <typename T>
bool array_equal(T const* arr_1, T const* arr_2, size_t n)
{
    for (size_t i{0}; i < n; ++i)
    {
        if (arr_1[i] != arr_2[i])
        {
            return false;
        }
    }
    return true;
}

void print_test_header(bool is_A_transpose, bool is_B_transpose)
{
    // C = A * B
    if ((!is_A_transpose) && (!is_B_transpose))
    {
        std::cout << "C = A * B" << std::endl;
    }
    // C = A^T * B
    else if ((is_A_transpose) && (!is_B_transpose))
    {
        std::cout << "C = A^T * B" << std::endl;
    }
    // C = A * B^T
    else if ((!is_A_transpose) && (is_B_transpose))
    {
        std::cout << "C = A * B^T" << std::endl;
    }
    // C = A^T * B^T
    else
    {
        std::cout << "C = A^T * B^T" << std::endl;
    }
}

int main()
{
    constexpr int num_repeats{10};
    constexpr int num_warmups{10};

    uint32_t const matrix_size_m{1024};
    uint32_t const matrix_size_n{1024};
    uint32_t const matrix_size_k{1024};
    std::cout << "Matrix Sizes" << std::endl;
    std::cout << "M: " << matrix_size_m << std::endl;
    std::cout << "N: " << matrix_size_n << std::endl;
    std::cout << "K: " << matrix_size_k << std::endl;

    std::default_random_engine random_engine(0);

    cudaStream_t stream;
    CHECK_CUDA_ERROR(cudaStreamCreate(&stream));

    // HMMA
    std::cout << "FP16 HMMA" << std::endl;
    std::vector<float> matrix_a_float(matrix_size_m * matrix_size_k);
    std::vector<float> matrix_b_float(matrix_size_k * matrix_size_n);
    std::vector<__half> matrix_a_half(matrix_size_m * matrix_size_k);
    std::vector<__half> matrix_b_half(matrix_size_k * matrix_size_n);
    std::vector<float> matrix_c_float(matrix_size_m * matrix_size_n);
    std::vector<float> matrix_c_float_reference(matrix_size_m * matrix_size_n);

    float* h_matrix_a_float{matrix_a_float.data()};
    float* h_matrix_b_float{matrix_b_float.data()};
    __half* h_matrix_a_half{matrix_a_half.data()};
    __half* h_matrix_b_half{matrix_b_half.data()};
    float* h_matrix_c_float{matrix_c_float.data()};
    float* h_matrix_c_float_reference{matrix_c_float_reference.data()};

    fill_random_float_values(h_matrix_a_float, matrix_a_float.size(),
                             random_engine);
    fill_random_float_values(h_matrix_b_float, matrix_b_float.size(),
                             random_engine);
    fill_random_float_values(h_matrix_c_float, matrix_c_float.size(),
                             random_engine);
    fill_random_float_values(h_matrix_c_float_reference,
                             matrix_c_float_reference.size(), random_engine);
    float2half(h_matrix_a_half, h_matrix_a_float, matrix_a_float.size());
    float2half(h_matrix_b_half, h_matrix_b_float, matrix_b_float.size());

    half *d_matrix_a_half, *d_matrix_b_half;
    float* d_matrix_c_float;

    CHECK_CUDA_ERROR(cudaMalloc(&d_matrix_a_half,
                                matrix_size_m * matrix_size_k * sizeof(half)));
    CHECK_CUDA_ERROR(cudaMalloc(&d_matrix_b_half,
                                matrix_size_k * matrix_size_n * sizeof(half)));
    CHECK_CUDA_ERROR(cudaMalloc(&d_matrix_c_float,
                                matrix_size_m * matrix_size_n * sizeof(float)));
    // Zero-initialize C: the kernel reads C and computes beta * C, and
    // uninitialized memory could hold NaN, for which 0 * NaN is still NaN.
    CHECK_CUDA_ERROR(cudaMemset(d_matrix_c_float, 0,
                                matrix_size_m * matrix_size_n * sizeof(float)));

    // Copy data from host to device.
    CHECK_CUDA_ERROR(cudaMemcpy(d_matrix_a_half, h_matrix_a_half,
                                matrix_a_float.size() * sizeof(__half),
                                cudaMemcpyHostToDevice));
    CHECK_CUDA_ERROR(cudaMemcpy(d_matrix_b_half, h_matrix_b_half,
                                matrix_b_float.size() * sizeof(__half),
                                cudaMemcpyHostToDevice));

    for (bool is_A_transpose : {true, false})
    {
        for (bool is_B_transpose : {true, false})
        {
            print_test_header(is_A_transpose, is_B_transpose);
            // Compute matrix multiplication reference output using CPU.
            launch_mm(h_matrix_a_float, h_matrix_b_float,
                      h_matrix_c_float_reference, matrix_size_m, matrix_size_n,
                      matrix_size_k, is_A_transpose, is_B_transpose);
            // Compute matrix multiplication output using CUDA WMMA.
            launch_wmma_mm(d_matrix_a_half, d_matrix_b_half, d_matrix_c_float,
                           matrix_size_m, matrix_size_n, matrix_size_k,
                           is_A_transpose, is_B_transpose, stream);
            CHECK_CUDA_ERROR(cudaStreamSynchronize(stream));

            CHECK_CUDA_ERROR(cudaMemcpy(h_matrix_c_float, d_matrix_c_float,
                                        matrix_c_float.size() * sizeof(float),
                                        cudaMemcpyDeviceToHost));

            float const avg_abs_diff_ratio{get_avg_abs_diff_ratio(
                h_matrix_c_float, h_matrix_c_float_reference,
                matrix_c_float.size())};
            if (avg_abs_diff_ratio > 0.01)
            {
                std::cout << "Got high average absolute diff ratio: "
                          << avg_abs_diff_ratio << std::endl;
            }

            // Performance measurement.
            std::function<void(cudaStream_t)> const function_hmma{std::bind(
                launch_wmma_mm<__half, float>, d_matrix_a_half, d_matrix_b_half,
                d_matrix_c_float, matrix_size_m, matrix_size_n, matrix_size_k,
                is_A_transpose, is_B_transpose, std::placeholders::_1)};
            float const latency_hmma{measure_performance(
                function_hmma, stream, num_repeats, num_warmups)};
            std::cout << std::fixed << std::setprecision(3)
                      << "HMMA Latency: " << latency_hmma << " ms" << std::endl;
        }
    }

    CHECK_CUDA_ERROR(cudaFree(d_matrix_a_half));
    CHECK_CUDA_ERROR(cudaFree(d_matrix_b_half));
    CHECK_CUDA_ERROR(cudaFree(d_matrix_c_float));

    // IMMA
    std::cout << "INT8 IMMA" << std::endl;
    std::vector<int8_t> matrix_a_int8(matrix_size_m * matrix_size_k);
    std::vector<int8_t> matrix_b_int8(matrix_size_k * matrix_size_n);
    std::vector<int32_t> matrix_c_int32(matrix_size_m * matrix_size_n);
    std::vector<int32_t> matrix_c_int32_reference(matrix_size_m *
                                                  matrix_size_n);

    int8_t* h_matrix_a_int8{matrix_a_int8.data()};
    int8_t* h_matrix_b_int8{matrix_b_int8.data()};
    int32_t* h_matrix_c_int32{matrix_c_int32.data()};
    int32_t* h_matrix_c_int32_reference{matrix_c_int32_reference.data()};

    fill_random_int8_values(h_matrix_a_int8, matrix_a_int8.size(),
                            random_engine);
    fill_random_int8_values(h_matrix_b_int8, matrix_b_int8.size(),
                            random_engine);
    fill_random_int32_values(h_matrix_c_int32, matrix_c_int32.size(),
                             random_engine);
    fill_random_int32_values(h_matrix_c_int32_reference,
                             matrix_c_int32_reference.size(), random_engine);

    // Verify INT8 IMMA against the CPU reference and profile it.
    int8_t *d_matrix_a_int8, *d_matrix_b_int8;
    int32_t* d_matrix_c_int32;

    CHECK_CUDA_ERROR(cudaMalloc(
        &d_matrix_a_int8, matrix_size_m * matrix_size_k * sizeof(int8_t)));
    CHECK_CUDA_ERROR(cudaMalloc(
        &d_matrix_b_int8, matrix_size_k * matrix_size_n * sizeof(int8_t)));
    CHECK_CUDA_ERROR(cudaMalloc(
        &d_matrix_c_int32, matrix_size_m * matrix_size_n * sizeof(int32_t)));
    // Zero-initialize C, which the kernel reads before scaling it by beta.
    CHECK_CUDA_ERROR(cudaMemset(
        d_matrix_c_int32, 0, matrix_size_m * matrix_size_n * sizeof(int32_t)));

    CHECK_CUDA_ERROR(cudaMemcpy(d_matrix_a_int8, h_matrix_a_int8,
                                matrix_a_int8.size() * sizeof(int8_t),
                                cudaMemcpyHostToDevice));
    CHECK_CUDA_ERROR(cudaMemcpy(d_matrix_b_int8, h_matrix_b_int8,
                                matrix_b_int8.size() * sizeof(int8_t),
                                cudaMemcpyHostToDevice));

    for (bool is_A_transpose : {true, false})
    {
        for (bool is_B_transpose : {true, false})
        {
            print_test_header(is_A_transpose, is_B_transpose);
            // Compute matrix multiplication reference output using CPU.
            launch_mm(h_matrix_a_int8, h_matrix_b_int8,
                      h_matrix_c_int32_reference, matrix_size_m, matrix_size_n,
                      matrix_size_k, is_A_transpose, is_B_transpose);
            // Compute matrix multiplication output using CUDA WMMA.
            launch_wmma_mm(d_matrix_a_int8, d_matrix_b_int8, d_matrix_c_int32,
                           matrix_size_m, matrix_size_n, matrix_size_k,
                           is_A_transpose, is_B_transpose, stream);
            CHECK_CUDA_ERROR(cudaStreamSynchronize(stream));
            CHECK_CUDA_ERROR(cudaMemcpy(h_matrix_c_int32, d_matrix_c_int32,
                                        matrix_c_int32.size() * sizeof(int32_t),
                                        cudaMemcpyDeviceToHost));
            // Integer matrix multiplications from CPU and CUDA should be
            // bitwise identical.
            assert(array_equal(h_matrix_c_int32, h_matrix_c_int32_reference,
                               matrix_c_int32.size()));

            // Performance measurement.
            std::function<void(cudaStream_t)> const function_imma{
                std::bind(launch_wmma_mm<int8_t, int32_t>, d_matrix_a_int8,
                          d_matrix_b_int8, d_matrix_c_int32, matrix_size_m,
                          matrix_size_n, matrix_size_k, is_A_transpose,
                          is_B_transpose, std::placeholders::_1)};
            float const latency_imma{measure_performance(
                function_imma, stream, num_repeats, num_warmups)};
            std::cout << std::fixed << std::setprecision(3)
                      << "IMMA Latency: " << latency_imma << " ms" << std::endl;
        }
    }

    CHECK_CUDA_ERROR(cudaFree(d_matrix_a_int8));
    CHECK_CUDA_ERROR(cudaFree(d_matrix_b_int8));
    CHECK_CUDA_ERROR(cudaFree(d_matrix_c_int32));

    CHECK_CUDA_ERROR(cudaStreamDestroy(stream));
}
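The launch configuration in launch_wmma_mm can be checked by hand. The plain host C++ sketch below reproduces the grid size computation and the warpM/warpN mapping from the kernel; Dim3 is a hypothetical 2D stand-in for CUDA's dim3, and TileOrigin and the helper functions are likewise illustrative. With 4 × 4 warps per block, each block covers a 64 × 64 tile of C.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical 2D stand-in for CUDA's dim3 (only x and y are modeled).
struct Dim3
{
    std::uint32_t x, y;
};

// Row and column of the top-left element of a 16 x 16 tile of C.
struct TileOrigin
{
    std::uint32_t row, col;
};

constexpr std::uint32_t WARP_SIZE{32};
constexpr std::uint32_t WMMA_M{16};
constexpr std::uint32_t WMMA_N{16};
constexpr std::uint32_t NUM_WARPS_X{4};
constexpr std::uint32_t NUM_WARPS_Y{4};

// 4 x 4 warps per block: 128 threads along x, 4 along y.
Dim3 block_dim() { return Dim3{NUM_WARPS_X * WARP_SIZE, NUM_WARPS_Y}; }

// Each block covers a 64 x 64 tile of C; round up to cover all of C.
Dim3 grid_dim(std::uint32_t m, std::uint32_t n)
{
    return Dim3{(m + WMMA_M * NUM_WARPS_X - 1) / (WMMA_M * NUM_WARPS_X),
                (n + WMMA_N * NUM_WARPS_Y - 1) / (WMMA_N * NUM_WARPS_Y)};
}

// Same warpM / warpN formulas as the kernel: 32 consecutive threads along x
// form one warp, and each warp owns one 16 x 16 tile of C.
TileOrigin warp_tile_origin(Dim3 block_idx, Dim3 thread_idx)
{
    Dim3 const bd = block_dim();
    std::uint32_t const warp_m{(block_idx.x * bd.x + thread_idx.x) /
                               WARP_SIZE};
    std::uint32_t const warp_n{block_idx.y * bd.y + thread_idx.y};
    return TileOrigin{warp_m * WMMA_M, warp_n * WMMA_N};
}
```

For the 1024 × 1024 output used here, the grid is 16 × 16 blocks. Thread (33, 3) of block (1, 2), for example, belongs to the warp whose tile starts at row 80 and column 176, inside the 64 × 64 region (rows 64 to 127, columns 128 to 191) owned by that block.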

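None of the transposed matrix multiplication implementations actually transposes the matrices. Instead, we used the row-major and column-major trick introduced in my previous article “Row-Major VS Column-Major”: a column-major matrix read as row-major is its transpose, so each variant only needs a different WMMA fragment layout and leading dimension.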
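The trick reduces to index arithmetic. In the C++ sketch below (helper names are illustrative), element $(p, i)$ of $A^{\top}$, stored column-major as a $k \times m$ matrix with leading dimension $k$, lands at the same offset as element $(i, p)$ of $A$ read row-major as an $m \times k$ matrix with the same leading dimension. That is why the kernel can load a transposed $A$ with nvcuda::wmma::row_major and lda = k without moving any data.

```cpp
#include <cassert>
#include <cstddef>

// Offset of element (row, col) in a column-major matrix with leading
// dimension ld (the number of rows).
constexpr std::size_t col_major_offset(std::size_t row, std::size_t col,
                                       std::size_t ld)
{
    return row + col * ld;
}

// Offset of element (row, col) in a row-major matrix with leading
// dimension ld (the number of columns).
constexpr std::size_t row_major_offset(std::size_t row, std::size_t col,
                                       std::size_t ld)
{
    return row * ld + col;
}

// True if, for an m x k matrix A, every A(i, p) read row-major with leading
// dimension k sits at the same offset as A^T(p, i) in a k x m column-major
// buffer with leading dimension k.
bool transpose_is_a_reinterpretation(std::size_t m, std::size_t k)
{
    for (std::size_t i{0}; i < m; ++i)
        for (std::size_t p{0}; p < k; ++p)
            if (row_major_offset(i, p, k) != col_major_offset(p, i, k))
                return false;
    return true;
}
```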

We also observed that, for matrices stored in column-major order, $C = A^{\top}B$ is the fastest and $C = AB^{\top}$ is the slowest for both the HMMA and the IMMA GEMM implementations on an NVIDIA RTX 3090 GPU.

$ nvcc mma.cu -o mma --gpu-architecture=compute_86
$ ./mma
Matrix Sizes
M: 1024
N: 1024
K: 1024
FP16 HMMA
C = A^T * B^T
HMMA Latency: 0.177 ms
C = A^T * B
HMMA Latency: 0.169 ms
C = A * B^T
HMMA Latency: 0.189 ms
C = A * B
HMMA Latency: 0.177 ms
INT8 IMMA
C = A^T * B^T
IMMA Latency: 0.129 ms
C = A^T * B
IMMA Latency: 0.090 ms
C = A * B^T
IMMA Latency: 0.170 ms
C = A * B
IMMA Latency: 0.129 ms
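To put these latencies in perspective, they can be converted to throughput: an $m \times n \times k$ GEMM performs $mnk$ multiplies and $mnk$ adds, i.e., $2mnk$ operations. The small C++ helper below is a hypothetical addition, not part of mma.cu.

```cpp
#include <cassert>
#include <cstdint>

// Throughput in tera-operations per second (TFLOPS for HMMA, TOPS for IMMA)
// of an m x n x k GEMM that takes latency_ms milliseconds, counting one
// multiply and one add per inner-product term: 2 * m * n * k operations.
double gemm_tera_ops_per_second(std::uint64_t m, std::uint64_t n,
                                std::uint64_t k, double latency_ms)
{
    assert(latency_ms > 0.0);
    double const num_ops{2.0 * static_cast<double>(m) *
                         static_cast<double>(n) * static_cast<double>(k)};
    return num_ops / (latency_ms * 1e-3) / 1e12;
}
```

Applied to the RTX 3090 numbers above, the fastest HMMA case ($C = A^{\top}B$, 0.169 ms) works out to roughly 12.7 TFLOPS and the fastest IMMA case (0.090 ms) to roughly 23.9 TOPS.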

Conclusions

NVIDIA Tensor Cores are programmable and can be used for accelerating computations that are dominated by GEMM operations.

Author

Lei Mao

Posted on

05-18-2023

Updated on

12-27-2023
