CUDA Compilation Architecture Macro

Introduction

In C++, macros are often used to control which code gets compiled for different use cases. Similarly, in CUDA, it is often necessary to compile the same source file for different GPU architectures.

In this blog post, I would like to quickly discuss how to use the NVCC compilation architecture macro to control compilation for different GPU architectures.
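The basic pattern, as a minimal sketch (the complete examples follow below), is to guard architecture-specific device code with a preprocessor check on __CUDA_ARCH__:

__global__ void kernel()
{
#if __CUDA_ARCH__ >= 530
    // Code path compiled when the virtual architecture is >= 5.3.
#else
    // Fallback path compiled for older virtual architectures.
#endif
}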

Half Addition Example

According to the CUDA arithmetic instructions documentation, the FP16 add arithmetic instruction can only be performed on devices with compute capability >= 5.3.
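Note that the macro discussed in this post concerns the virtual architecture targeted at compile time. As a side check, the compute capability of the physical device can be queried at runtime with cudaGetDeviceProperties; a minimal sketch, assuming device index 0:

#include <cuda_runtime.h>
#include <iostream>

int main()
{
    // Query the compute capability of device 0 at runtime.
    cudaDeviceProp prop;
    cudaGetDeviceProperties(&prop, 0);
    std::cout << "Compute capability: " << prop.major << "." << prop.minor
              << std::endl;
    // FP16 add instructions require compute capability >= 5.3.
    bool const supports_fp16_add{prop.major > 5 ||
                                 (prop.major == 5 && prop.minor >= 3)};
    std::cout << "FP16 add supported: " << std::boolalpha
              << supports_fp16_add << std::endl;
}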

In this example, the architecture macro allows us to switch between different FP16 add implementations for different virtual GPU architectures.

No Architecture Macro

Without the architecture macro, we cannot control the device-side implementation for different virtual GPU architectures.

half_addition_no_macro.cu
#include <cmath>
#include <cstdint>
#include <cuda_fp16.h>
#include <functional>
#include <iomanip>
#include <iostream>
#include <vector>

#define checkCuda(val) check((val), #val, __FILE__, __LINE__)
void check(cudaError_t err, const char* const func, const char* const file,
           const int line)
{
    if (err != cudaSuccess)
    {
        std::cerr << "CUDA Runtime Error at: " << file << ":" << line
                  << std::endl;
        std::cerr << cudaGetErrorString(err) << " " << func << std::endl;
        std::exit(EXIT_FAILURE);
    }
}

template <class T>
float measure_cuda_performance(std::function<T(void)> bound_function,
                               const int num_repeats = 100,
                               const int num_warmups = 100)
{
    cudaEvent_t start;
    cudaEvent_t stop;

    cudaEventCreate(&start);
    cudaEventCreate(&stop);

    for (int i{0}; i < num_warmups; ++i)
    {
        bound_function();
    }
    cudaDeviceSynchronize();

    cudaEventRecord(start, 0);
    for (int i{0}; i < num_repeats; ++i)
    {
        bound_function();
    }
    cudaEventRecord(stop, 0);
    cudaEventSynchronize(stop);
    cudaError_t err{cudaGetLastError()};
    if (err != cudaSuccess)
    {
        std::cerr << "CUDA half addition kernel failed to execute."
                  << std::endl;
        std::cerr << cudaGetErrorString(err) << std::endl;
        std::exit(EXIT_FAILURE);
    }

    float time_elapsed{0.0f};
    cudaEventElapsedTime(&time_elapsed, start, stop);

    float latency{time_elapsed / static_cast<float>(num_repeats)};

    return latency;
}

__global__ void half_addition(__half* output, __half const* input_1,
                              __half const* input_2, uint32_t const n)
{
    const uint32_t idx{blockDim.x * blockIdx.x + threadIdx.x};
    const uint32_t stride{blockDim.x * gridDim.x};
    for (uint32_t i{idx}; i < n; i += stride)
    {
        output[i] = __hadd(input_1[i], input_2[i]);
    }
}

void launch_half_addition(__half* output, __half const* input_1,
                          __half const* input_2, uint32_t const n)
{
    dim3 threads_per_block{1024};
    dim3 blocks_per_grid{32};
    half_addition<<<blocks_per_grid, threads_per_block>>>(output, input_1,
                                                          input_2, n);
}

int main()
{
    constexpr uint32_t n{100000};
    constexpr float a{1.0f}, b{2.0f}, c{3.0f};

    std::vector<__half> input_1(n, __float2half(a));
    std::vector<__half> input_2(n, __float2half(b));
    std::vector<__half> output(n, __float2half(0.0f));

    __half *d_input_1, *d_input_2, *d_output;

    checkCuda(cudaMalloc(&d_input_1, n * sizeof(__half)));
    checkCuda(cudaMalloc(&d_input_2, n * sizeof(__half)));
    checkCuda(cudaMalloc(&d_output, n * sizeof(__half)));
    checkCuda(cudaMemcpy(d_input_1, input_1.data(), n * sizeof(__half),
                         cudaMemcpyHostToDevice));
    checkCuda(cudaMemcpy(d_input_2, input_2.data(), n * sizeof(__half),
                         cudaMemcpyHostToDevice));

    launch_half_addition(d_output, d_input_1, d_input_2, n);
    cudaDeviceSynchronize();
    cudaError_t err{cudaGetLastError()};
    if (err != cudaSuccess)
    {
        std::cerr << "CUDA half addition kernel failed to execute."
                  << std::endl;
        std::cerr << cudaGetErrorString(err) << std::endl;
        std::exit(EXIT_FAILURE);
    }

    checkCuda(cudaMemcpy(output.data(), d_output, n * sizeof(__half),
                         cudaMemcpyDeviceToHost));

    for (uint32_t i{0}; i < n; ++i)
    {
        if (std::abs(__half2float(output.at(i)) - c) > 1e-8)
        {
            std::cerr << "CUDA half addition kernel implementation has error!."
                      << std::endl;
            std::cout << "Expect " << c << " at position " << i << " but got "
                      << __half2float(output.at(i)) << std::endl;
            break;
        }
    }

    std::function<void(void)> bound_function{
        std::bind(launch_half_addition, d_output, d_input_1, d_input_2, n)};
    float latency{measure_cuda_performance(bound_function, 100, 100)};
    std::cout << std::fixed << std::setprecision(5) << "Latency: " << latency
              << " ms" << std::endl;
}

Although compiling the FP16 addition program against the virtual GPU architecture compute_52 did not produce a compilation error, the runtime sanity check shows that the CUDA kernel has issues. Compiling the same program against the virtual GPU architecture compute_53 works fine. This is expected, because the FP16 add arithmetic instruction __hadd can only be performed on virtual GPU architectures >= 5.3.

$ nvcc half_addition_no_macro.cu -o half_addition_no_macro --gpu-architecture=compute_52
$ ./half_addition_no_macro
CUDA half addition kernel implementation has error!.
Expect 3 at position 0 but got 1
Latency: 0.00313 ms
$ nvcc half_addition_no_macro.cu -o half_addition_no_macro --gpu-architecture=compute_53
$ ./half_addition_no_macro
Latency: 0.00286 ms
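To confirm what the compiler actually emitted, one option (a quick sanity check, assuming a standard CUDA toolkit installation) is to generate PTX for each virtual architecture and inspect the add instructions. When compiling for compute_53, the FP16 add should appear as an add instruction on .f16 operands, e.g., add.rn.f16:

$ nvcc -ptx half_addition_no_macro.cu -o half_addition_no_macro.ptx --gpu-architecture=compute_53
$ grep "add.rn.f16" half_addition_no_macro.ptx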

With Architecture Macro

For virtual GPU architectures < 5.3, if performance is less of a concern, we can still do the FP16 addition by converting the FP16 values to FP32, performing the FP32 addition, and converting the FP32 sum back to FP16. __CUDA_ARCH__ is the architecture macro representing the virtual GPU architecture currently being compiled for.

half_addition_with_macro.cu
#include <cmath>
#include <cstdint>
#include <cuda_fp16.h>
#include <functional>
#include <iomanip>
#include <iostream>
#include <vector>

#define checkCuda(val) check((val), #val, __FILE__, __LINE__)
void check(cudaError_t err, const char* const func, const char* const file,
           const int line)
{
    if (err != cudaSuccess)
    {
        std::cerr << "CUDA Runtime Error at: " << file << ":" << line
                  << std::endl;
        std::cerr << cudaGetErrorString(err) << " " << func << std::endl;
        std::exit(EXIT_FAILURE);
    }
}

template <class T>
float measure_cuda_performance(std::function<T(void)> bound_function,
                               const int num_repeats = 100,
                               const int num_warmups = 100)
{
    cudaEvent_t start;
    cudaEvent_t stop;

    cudaEventCreate(&start);
    cudaEventCreate(&stop);

    for (int i{0}; i < num_warmups; ++i)
    {
        bound_function();
    }
    cudaDeviceSynchronize();

    cudaEventRecord(start, 0);
    for (int i{0}; i < num_repeats; ++i)
    {
        bound_function();
    }
    cudaEventRecord(stop, 0);
    cudaEventSynchronize(stop);
    cudaError_t err{cudaGetLastError()};
    if (err != cudaSuccess)
    {
        std::cerr << "CUDA half addition kernel failed to execute."
                  << std::endl;
        std::cerr << cudaGetErrorString(err) << std::endl;
        std::exit(EXIT_FAILURE);
    }

    float time_elapsed{0.0f};
    cudaEventElapsedTime(&time_elapsed, start, stop);

    float latency{time_elapsed / static_cast<float>(num_repeats)};

    return latency;
}

__global__ void half_addition(__half* output, __half const* input_1,
                              __half const* input_2, uint32_t const n)
{
    const uint32_t idx{blockDim.x * blockIdx.x + threadIdx.x};
    const uint32_t stride{blockDim.x * gridDim.x};
    for (uint32_t i{idx}; i < n; i += stride)
    {
#if __CUDA_ARCH__ >= 530
        output[i] = __hadd(input_1[i], input_2[i]);
#else
        output[i] =
            __float2half(__half2float(input_1[i]) + __half2float(input_2[i]));
#endif
    }
}

void launch_half_addition(__half* output, __half const* input_1,
                          __half const* input_2, uint32_t const n)
{
    dim3 threads_per_block{1024};
    dim3 blocks_per_grid{32};
    half_addition<<<blocks_per_grid, threads_per_block>>>(output, input_1,
                                                          input_2, n);
}

int main()
{
    constexpr uint32_t n{100000};
    constexpr float a{1.0f}, b{2.0f}, c{3.0f};

    std::vector<__half> input_1(n, __float2half(a));
    std::vector<__half> input_2(n, __float2half(b));
    std::vector<__half> output(n, __float2half(0.0f));

    __half *d_input_1, *d_input_2, *d_output;

    checkCuda(cudaMalloc(&d_input_1, n * sizeof(__half)));
    checkCuda(cudaMalloc(&d_input_2, n * sizeof(__half)));
    checkCuda(cudaMalloc(&d_output, n * sizeof(__half)));
    checkCuda(cudaMemcpy(d_input_1, input_1.data(), n * sizeof(__half),
                         cudaMemcpyHostToDevice));
    checkCuda(cudaMemcpy(d_input_2, input_2.data(), n * sizeof(__half),
                         cudaMemcpyHostToDevice));

    launch_half_addition(d_output, d_input_1, d_input_2, n);
    cudaDeviceSynchronize();
    cudaError_t err{cudaGetLastError()};
    if (err != cudaSuccess)
    {
        std::cerr << "CUDA half addition kernel failed to execute."
                  << std::endl;
        std::cerr << cudaGetErrorString(err) << std::endl;
        std::exit(EXIT_FAILURE);
    }
    checkCuda(cudaMemcpy(output.data(), d_output, n * sizeof(__half),
                         cudaMemcpyDeviceToHost));

    for (uint32_t i{0}; i < n; ++i)
    {
        if (std::abs(__half2float(output.at(i)) - c) > 1e-8)
        {
            std::cerr << "CUDA half addition kernel implementation has error!."
                      << std::endl;
            std::cout << "Expect " << c << " at position " << i << " but got "
                      << __half2float(output.at(i)) << std::endl;
            break;
        }
    }

    std::function<void(void)> bound_function{
        std::bind(launch_half_addition, d_output, d_input_1, d_input_2, n)};
    float latency{measure_cuda_performance(bound_function, 100, 100)};
    std::cout << std::fixed << std::setprecision(5) << "Latency: " << latency
              << " ms" << std::endl;
}

In this implementation, when the virtual GPU architecture is compute_52, __float2half(__half2float(input_1[i]) + __half2float(input_2[i])) is compiled; when the virtual GPU architecture is compute_53, __hadd(input_1[i], input_2[i]) is compiled. The sanity check now passes for both architectures.

$ nvcc half_addition_with_macro.cu -o half_addition_with_macro --gpu-architecture=compute_52
$ ./half_addition_with_macro
Latency: 0.00305 ms
$ nvcc half_addition_with_macro.cu -o half_addition_with_macro --gpu-architecture=compute_53
$ ./half_addition_with_macro
Latency: 0.00292 ms
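The two code paths also do not have to live in separate binaries: nvcc can embed code for several architectures into a single fat binary, in which case __CUDA_ARCH__ is evaluated separately for each device compilation pass. A sketch of such a build, using standard nvcc options:

$ nvcc half_addition_with_macro.cu -o half_addition_with_macro \
      --generate-code arch=compute_52,code=sm_52 \
      --generate-code arch=compute_53,code=sm_53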

Caveats

This macro can be used in the implementation of GPU functions to determine the virtual architecture for which they are currently being compiled. The host code (the non-GPU code) must not depend on it. In other words, the __CUDA_ARCH__ macro should only be used inside device code, i.e., functions decorated with __device__ or __global__ (and, with a defined(__CUDA_ARCH__) guard, the device compilation pass of __host__ __device__ functions).

In the following example, we can see that the __CUDA_ARCH__ macro has no effect inside a host function: because __CUDA_ARCH__ is undefined during host compilation, the #else branch is always taken, regardless of the --gpu-architecture value.

host.cu
#include <iostream>

void test_host_function()
{
#if __CUDA_ARCH__ >= 530
    std::cout << "__CUDA_ARCH__ >= 530" << std::endl;
#else
    std::cout << "__CUDA_ARCH__ < 530" << std::endl;
#endif
}

int main() { test_host_function(); }
$ nvcc host.cu -o host --gpu-architecture=compute_52
$ ./host
__CUDA_ARCH__ < 530
$ nvcc host.cu -o host --gpu-architecture=compute_53
$ ./host
__CUDA_ARCH__ < 530
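For functions decorated with both __host__ and __device__, the supported pattern is to test whether the macro is defined at all, since __CUDA_ARCH__ is undefined during the host compilation pass. A minimal sketch (half_add is a hypothetical helper, not part of the examples above; it assumes a reasonably recent CUDA toolkit where the FP16 conversion functions are host-callable):

#include <cuda_fp16.h>

__host__ __device__ __half half_add(__half const x, __half const y)
{
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)
    // Device compilation pass, virtual architecture >= 5.3: native FP16 add.
    return __hadd(x, y);
#else
    // Host compilation pass, or device pass below 5.3: FP32 fallback.
    return __float2half(__half2float(x) + __half2float(y));
#endif
}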

Author: Lei Mao

Posted on: 05-01-2022

Updated on: 05-01-2022
