CUDA Reduction

Introduction

Reduction is a common operation in parallel computing. It combines a sequence of elements into a single value using a binary operator. Typical uses are computing the sum, maximum, minimum, or product of the elements.
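
A parallel reduction combines elements in a different order than a sequential loop does, so the operator must be associative. The following is a minimal host-side C++ sketch of that idea. The function name `tree_reduce` is hypothetical and not part of this post's code. It combines elements pairwise in tree order, the same shape a GPU reduction uses, and the same function works for sum, maximum, minimum, and product.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <functional>
#include <vector>

// Hypothetical helper: reduce `values` in tree order, combining element i
// with element i + half on each pass, as a parallel reduction would.
// `op` must be associative; the input must be non-empty.
template <typename T, typename Op>
T tree_reduce(std::vector<T> values, Op op)
{
    assert(!values.empty());
    std::size_t n{values.size()};
    while (n > 1)
    {
        // Round up so that the middle element of an odd-length range is
        // carried over to the next pass unchanged.
        std::size_t const half{(n + 1) / 2};
        for (std::size_t i{0}; i + half < n; ++i)
        {
            values[i] = op(values[i], values[i + half]);
        }
        n = half;
    }
    return values[0];
}
```

Here is what it returns for the input {3, 1, 4, 1, 5, 9, 2, 6}:

- `std::plus<int>{}` yields 31.
- `std::multiplies<int>{}` yields 6480.
- Lambdas wrapping `std::max` and `std::min` yield 9 and 1.

Floating-point addition is not exactly associative. A GPU sum may therefore differ from a sequential sum in the last few bits.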

In this blog post, we will discuss the parallel reduction algorithm and its implementation in CUDA.

Batched Reduce Sum

In this example, we implement two batched reduce sum kernels in CUDA. A batched reduce sum kernel takes a batch of arrays and computes the sum of the elements of each array.
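
A host-side reference implementation is useful for checking kernel output against arbitrary inputs. The profiler in `profile_batched_reduce_sum` only checks an all-ones input. The following is a minimal sketch. The function name `batched_reduce_sum_reference` is hypothetical. It assumes the same row-major layout the kernels use, where batch `b` starts at `input_data + b * num_elements_per_batch`.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical host reference for a batched reduce sum: `input` holds
// batch_size rows of num_elements_per_batch elements (row-major), and
// output[b] is the sum of row b.
std::vector<float> batched_reduce_sum_reference(
    std::vector<float> const& input, std::size_t batch_size,
    std::size_t num_elements_per_batch)
{
    assert(input.size() == batch_size * num_elements_per_batch);
    std::vector<float> output(batch_size, 0.0f);
    for (std::size_t b{0}; b < batch_size; ++b)
    {
        for (std::size_t i{0}; i < num_elements_per_batch; ++i)
        {
            output[b] += input[b * num_elements_per_batch + i];
        }
    }
    return output;
}
```

For example, the input {1, 2, 3, 4, 5, 6} with `batch_size = 2` and `num_elements_per_batch = 3` produces {6, 15}.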

The idea of the reduction algorithm is simple. Each array in the batch is assigned one thread block with a fixed number of threads. Each thread reads multiple elements of the array from global memory and accumulates a partial sum in a register. Once every thread has its partial sum, there are two ways to reduce the partial sums to the final sum:

- Store the partial sums in shared memory and reduce them there (`batched_reduce_sum_v1`).
- Use warp-level primitives to reduce the partial sums held in registers within each warp, then run a much smaller reduction over the per-warp results in shared memory (`batched_reduce_sum_v2`).
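
To make the two phases of the shared memory variant concrete, here is a host-side C++ emulation of what a single thread block does. It is an illustrative sketch only. The function name `emulate_block_reduce_sum_v1` is hypothetical, and the code runs serially what the threads would run in parallel. It requires a power-of-two thread count, which is the same assumption the halving loop in `shared_data_reduce_sum_v1` relies on.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Host-side emulation of one thread block running block_reduce_sum_v1.
// shared[t] plays the role of thread t's shared memory slot; each inner loop
// over t stands in for threads that run in parallel on the GPU.
float emulate_block_reduce_sum_v1(std::vector<float> const& row,
                                  std::size_t num_threads)
{
    // The halving tree below is only exact for power-of-two block sizes.
    assert(num_threads > 0 && (num_threads & (num_threads - 1)) == 0);
    std::vector<float> shared(num_threads, 0.0f);
    // Phase 1: thread t accumulates elements t, t + num_threads, ...
    // (neighboring threads read neighboring addresses, i.e. coalesced).
    for (std::size_t t{0}; t < num_threads; ++t)
    {
        float sum{0.0f};
        for (std::size_t offset{t}; offset < row.size(); offset += num_threads)
        {
            sum += row[offset];
        }
        shared[t] = sum;
    }
    // Phase 2: tree reduction in shared memory; each stride corresponds to
    // one loop iteration followed by __syncthreads() in the kernel.
    for (std::size_t stride{num_threads / 2}; stride > 0; stride /= 2)
    {
        for (std::size_t t{0}; t < stride; ++t)
        {
            shared[t] += shared[t + stride];
        }
    }
    return shared[0];
}
```

Serial emulation of phase 2 is faithful to the kernel: at each stride, thread `t` writes only `shared[t]` and reads only `shared[t + stride]`, and these two sets never overlap.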

reduce_sum.cu
#include <cassert>
#include <cstdlib>
#include <functional>
#include <iostream>
#include <stdexcept>
#include <string>
#include <vector>

#include <cuda_runtime.h>

#define CHECK_CUDA_ERROR(val) check((val), #val, __FILE__, __LINE__)
void check(cudaError_t err, char const* func, char const* file, int line)
{
    if (err != cudaSuccess)
    {
        std::cerr << "CUDA Runtime Error at: " << file << ":" << line
                  << std::endl;
        std::cerr << cudaGetErrorString(err) << " " << func << std::endl;
        std::exit(EXIT_FAILURE);
    }
}

#define CHECK_LAST_CUDA_ERROR() check_last(__FILE__, __LINE__)
void check_last(char const* file, int line)
{
    cudaError_t const err{cudaGetLastError()};
    if (err != cudaSuccess)
    {
        std::cerr << "CUDA Runtime Error at: " << file << ":" << line
                  << std::endl;
        std::cerr << cudaGetErrorString(err) << std::endl;
        std::exit(EXIT_FAILURE);
    }
}

template <class T>
float measure_performance(std::function<T(cudaStream_t)> bound_function,
                          cudaStream_t stream, size_t num_repeats = 10,
                          size_t num_warmups = 10)
{
    cudaEvent_t start, stop;
    float time;

    CHECK_CUDA_ERROR(cudaEventCreate(&start));
    CHECK_CUDA_ERROR(cudaEventCreate(&stop));

    for (size_t i{0}; i < num_warmups; ++i)
    {
        bound_function(stream);
    }

    CHECK_CUDA_ERROR(cudaStreamSynchronize(stream));

    CHECK_CUDA_ERROR(cudaEventRecord(start, stream));
    for (size_t i{0}; i < num_repeats; ++i)
    {
        bound_function(stream);
    }
    CHECK_CUDA_ERROR(cudaEventRecord(stop, stream));
    CHECK_CUDA_ERROR(cudaEventSynchronize(stop));
    CHECK_LAST_CUDA_ERROR();
    CHECK_CUDA_ERROR(cudaEventElapsedTime(&time, start, stop));
    CHECK_CUDA_ERROR(cudaEventDestroy(start));
    CHECK_CUDA_ERROR(cudaEventDestroy(stop));

    float const latency{time / num_repeats};

    return latency;
}

std::string std_string_centered(std::string const& s, size_t width,
                                char pad = ' ')
{
    size_t const l{s.length()};
    // Throw an exception if width is too small.
    if (width < l)
    {
        throw std::runtime_error("Width is too small.");
    }
    size_t const left_pad{(width - l) / 2};
    size_t const right_pad{width - l - left_pad};
    std::string const s_centered{std::string(left_pad, pad) + s +
                                 std::string(right_pad, pad)};
    return s_centered;
}

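template <size_t NUM_THREADS>
__device__ float shared_data_reduce_sum_v1(float shared_data[NUM_THREADS])
{
    static_assert(NUM_THREADS % 32 == 0,
                  "NUM_THREADS must be a multiple of 32");
    // Halving the stride drops elements unless NUM_THREADS is a power of 2
    // (e.g. 96 threads would leave shared_data[2] unreduced).
    static_assert((NUM_THREADS & (NUM_THREADS - 1)) == 0,
                  "NUM_THREADS must be a power of 2");
    size_t const thread_idx{threadIdx.x};
#pragma unroll
    for (size_t stride{NUM_THREADS / 2}; stride > 0; stride /= 2)
    {
        if (thread_idx < stride)
        {
            shared_data[thread_idx] += shared_data[thread_idx + stride];
        }
        __syncthreads();
    }
    return shared_data[0];
}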

template <size_t NUM_WARPS>
__device__ float shared_data_reduce_sum_v2(float shared_data[NUM_WARPS])
{
    float sum{0.0f};
#pragma unroll
    for (size_t i{0}; i < NUM_WARPS; ++i)
    {
        // There are no shared memory bank conflicts here because all threads
        // in a warp read the same shared memory address, which results in a
        // broadcast.
        sum += shared_data[i];
    }
    return sum;
}

__device__ float warp_reduce_sum(float val)
{
    constexpr unsigned int FULL_MASK{0xffffffff};
#pragma unroll
    for (size_t offset{16}; offset > 0; offset /= 2)
    {
        val += __shfl_down_sync(FULL_MASK, val, offset);
    }
    // Only the first thread in the warp will return the correct result.
    return val;
}

template <size_t NUM_THREADS>
__device__ float block_reduce_sum_v1(float const* __restrict__ input_data,
                                     float shared_data[NUM_THREADS],
                                     size_t num_elements)
{
    static_assert(NUM_THREADS % 32 == 0,
                  "NUM_THREADS must be a multiple of 32");
    size_t const num_elements_per_thread{(num_elements + NUM_THREADS - 1) /
                                         NUM_THREADS};
    size_t const thread_idx{threadIdx.x};
    float sum{0.0f};
    for (size_t i{0}; i < num_elements_per_thread; ++i)
    {
        size_t const offset{thread_idx + i * NUM_THREADS};
        if (offset < num_elements)
        {
            sum += input_data[offset];
        }
    }
    shared_data[thread_idx] = sum;
    __syncthreads();
    float const block_sum{shared_data_reduce_sum_v1<NUM_THREADS>(shared_data)};
    return block_sum;
}

template <size_t NUM_THREADS, size_t NUM_WARPS = NUM_THREADS / 32>
__device__ float block_reduce_sum_v2(float const* __restrict__ input_data,
                                     float shared_data[NUM_WARPS],
                                     size_t num_elements)
{
    size_t const num_elements_per_thread{(num_elements + NUM_THREADS - 1) /
                                         NUM_THREADS};
    size_t const thread_idx{threadIdx.x};
    float sum{0.0f};
    for (size_t i{0}; i < num_elements_per_thread; ++i)
    {
        size_t const offset{thread_idx + i * NUM_THREADS};
        if (offset < num_elements)
        {
            sum += input_data[offset];
        }
    }
    sum = warp_reduce_sum(sum);
    if (threadIdx.x % 32 == 0)
    {
        shared_data[threadIdx.x / 32] = sum;
    }
    __syncthreads();
    float const block_sum{shared_data_reduce_sum_v2<NUM_WARPS>(shared_data)};
    return block_sum;
}

template <size_t NUM_THREADS>
__global__ void batched_reduce_sum_v1(float* __restrict__ output_data,
                                      float const* __restrict__ input_data,
                                      size_t num_elements_per_batch)
{
    static_assert(NUM_THREADS % 32 == 0,
                  "NUM_THREADS must be a multiple of 32");
    size_t const block_idx{blockIdx.x};
    size_t const thread_idx{threadIdx.x};
    __shared__ float shared_data[NUM_THREADS];
    float const block_sum{block_reduce_sum_v1<NUM_THREADS>(
        input_data + block_idx * num_elements_per_batch, shared_data,
        num_elements_per_batch)};
    if (thread_idx == 0)
    {
        output_data[block_idx] = block_sum;
    }
}

template <size_t NUM_THREADS>
__global__ void batched_reduce_sum_v2(float* __restrict__ output_data,
                                      float const* __restrict__ input_data,
                                      size_t num_elements_per_batch)
{
    static_assert(NUM_THREADS % 32 == 0,
                  "NUM_THREADS must be a multiple of 32");
    constexpr size_t NUM_WARPS{NUM_THREADS / 32};
    size_t const block_idx{blockIdx.x};
    size_t const thread_idx{threadIdx.x};
    __shared__ float shared_data[NUM_WARPS];
    float const block_sum{block_reduce_sum_v2<NUM_THREADS, NUM_WARPS>(
        input_data + block_idx * num_elements_per_batch, shared_data,
        num_elements_per_batch)};
    if (thread_idx == 0)
    {
        output_data[block_idx] = block_sum;
    }
}

template <size_t NUM_THREADS>
void launch_batched_reduce_sum_v1(float* output_data, float const* input_data,
                                  size_t batch_size,
                                  size_t num_elements_per_batch,
                                  cudaStream_t stream)
{
    size_t const num_blocks{batch_size};
    batched_reduce_sum_v1<NUM_THREADS><<<num_blocks, NUM_THREADS, 0, stream>>>(
        output_data, input_data, num_elements_per_batch);
    CHECK_LAST_CUDA_ERROR();
}

template <size_t NUM_THREADS>
void launch_batched_reduce_sum_v2(float* output_data, float const* input_data,
                                  size_t batch_size,
                                  size_t num_elements_per_batch,
                                  cudaStream_t stream)
{
    size_t const num_blocks{batch_size};
    batched_reduce_sum_v2<NUM_THREADS><<<num_blocks, NUM_THREADS, 0, stream>>>(
        output_data, input_data, num_elements_per_batch);
    CHECK_LAST_CUDA_ERROR();
}

float profile_batched_reduce_sum(
    std::function<void(float*, float const*, size_t, size_t, cudaStream_t)>
        batched_reduce_sum_launch_function,
    size_t batch_size, size_t num_elements_per_batch)
{
    size_t const num_elements{batch_size * num_elements_per_batch};

    cudaStream_t stream;
    CHECK_CUDA_ERROR(cudaStreamCreate(&stream));

    constexpr float element_value{1.0f};
    std::vector<float> input_data(num_elements, element_value);
    std::vector<float> output_data(batch_size, 0.0f);

    float* d_input_data;
    float* d_output_data;

    CHECK_CUDA_ERROR(cudaMalloc(&d_input_data, num_elements * sizeof(float)));
    CHECK_CUDA_ERROR(cudaMalloc(&d_output_data, batch_size * sizeof(float)));

    CHECK_CUDA_ERROR(cudaMemcpy(d_input_data, input_data.data(),
                                num_elements * sizeof(float),
                                cudaMemcpyHostToDevice));

    batched_reduce_sum_launch_function(d_output_data, d_input_data, batch_size,
                                       num_elements_per_batch, stream);
    CHECK_CUDA_ERROR(cudaStreamSynchronize(stream));

    // Verify the correctness of the kernel.
    CHECK_CUDA_ERROR(cudaMemcpy(output_data.data(), d_output_data,
                                batch_size * sizeof(float),
                                cudaMemcpyDeviceToHost));
    for (size_t i{0}; i < batch_size; ++i)
    {
        if (output_data.at(i) != num_elements_per_batch * element_value)
        {
            std::cout << "Expected: " << num_elements_per_batch * element_value
                      << " but got: " << output_data.at(i) << std::endl;
            throw std::runtime_error("Error: incorrect sum");
        }
    }
    std::function<void(cudaStream_t)> const bound_function{std::bind(
        batched_reduce_sum_launch_function, d_output_data, d_input_data,
        batch_size, num_elements_per_batch, std::placeholders::_1)};
    float const latency{measure_performance<void>(bound_function, stream)};
    std::cout << "Latency: " << latency << " ms" << std::endl;

    // Compute effective bandwidth.
    size_t num_bytes{num_elements * sizeof(float) + batch_size * sizeof(float)};
    float const bandwidth{(num_bytes * 1e-6f) / latency};
    std::cout << "Effective Bandwidth: " << bandwidth << " GB/s" << std::endl;

    CHECK_CUDA_ERROR(cudaFree(d_input_data));
    CHECK_CUDA_ERROR(cudaFree(d_output_data));
    CHECK_CUDA_ERROR(cudaStreamDestroy(stream));

    return latency;
}

int main()
{
    size_t const batch_size{2048};
    size_t const num_elements_per_batch{1024 * 256};

    constexpr size_t string_width{50U};
    std::cout << std_string_centered("", string_width, '~') << std::endl;
    std::cout << std_string_centered("NVIDIA GPU Device Info", string_width,
                                     ' ')
              << std::endl;
    std::cout << std_string_centered("", string_width, '~') << std::endl;

    // Query device name and peak memory bandwidth.
    int device_id{0};
    CHECK_CUDA_ERROR(cudaGetDevice(&device_id));
    cudaDeviceProp device_prop;
    CHECK_CUDA_ERROR(cudaGetDeviceProperties(&device_prop, device_id));
    std::cout << "Device Name: " << device_prop.name << std::endl;
    float const memory_size{static_cast<float>(device_prop.totalGlobalMem) /
                            (1 << 30)};
    std::cout << "Memory Size: " << memory_size << " GB" << std::endl;
    float const peak_bandwidth{
        static_cast<float>(2.0f * device_prop.memoryClockRate *
                           (device_prop.memoryBusWidth / 8) / 1.0e6)};
    std::cout << "Peak Bandwitdh: " << peak_bandwidth << " GB/s" << std::endl;

    std::cout << std_string_centered("", string_width, '~') << std::endl;
    std::cout << std_string_centered("Reduce Sum Profiling", string_width, ' ')
              << std::endl;
    std::cout << std_string_centered("", string_width, '~') << std::endl;

    std::cout << std_string_centered("", string_width, '=') << std::endl;
    std::cout << "Batch Size: " << batch_size << std::endl;
    std::cout << "Number of Elements Per Batch: " << num_elements_per_batch
              << std::endl;
    std::cout << std_string_centered("", string_width, '=') << std::endl;

    constexpr size_t NUM_THREADS_PER_BATCH{256};
    static_assert(NUM_THREADS_PER_BATCH % 32 == 0,
                  "NUM_THREADS_PER_BATCH must be a multiple of 32");
    static_assert(NUM_THREADS_PER_BATCH <= 1024,
                  "NUM_THREADS_PER_BATCH must be less than or equal to 1024");

    std::cout << "Batched Reduce Sum V1" << std::endl;
    float const latency_v1{profile_batched_reduce_sum(
        launch_batched_reduce_sum_v1<NUM_THREADS_PER_BATCH>, batch_size,
        num_elements_per_batch)};
    std::cout << std_string_centered("", string_width, '-') << std::endl;

    std::cout << "Batched Reduce Sum V2" << std::endl;
    float const latency_v2{profile_batched_reduce_sum(
        launch_batched_reduce_sum_v2<NUM_THREADS_PER_BATCH>, batch_size,
        num_elements_per_batch)};
    std::cout << std_string_centered("", string_width, '-') << std::endl;
}
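
The comment in `warp_reduce_sum` says that only the first thread in the warp returns the correct result. A host-side emulation of the shuffle sequence shows why. It is a sketch, and the function name `emulate_warp_reduce_sum` is hypothetical. It models the documented out-of-range behavior of `__shfl_down_sync`: a lane whose source lane lies past the end of the warp keeps its own value. Lane 0 ends up accumulating all 32 values, while higher lanes double-count some values and miss others. This is why `block_reduce_sum_v2` only lets threads with `threadIdx.x % 32 == 0` write to shared memory.

```cpp
#include <array>
#include <cassert>
#include <cstddef>

constexpr std::size_t WARP_SIZE{32};

// Host-side emulation of warp_reduce_sum. Each round models one
// __shfl_down_sync: lane i receives the value held by lane i + offset, or
// keeps its own value when i + offset falls outside the warp.
// Returns the value every lane holds after the final round.
std::array<float, WARP_SIZE> emulate_warp_reduce_sum(
    std::array<float, WARP_SIZE> vals)
{
    for (std::size_t offset{WARP_SIZE / 2}; offset > 0; offset /= 2)
    {
        // All lanes read their source value before any lane updates,
        // as in a real shuffle.
        std::array<float, WARP_SIZE> shuffled{};
        for (std::size_t lane{0}; lane < WARP_SIZE; ++lane)
        {
            std::size_t const src_lane{lane + offset};
            shuffled[lane] = src_lane < WARP_SIZE ? vals[src_lane] : vals[lane];
        }
        for (std::size_t lane{0}; lane < WARP_SIZE; ++lane)
        {
            vals[lane] += shuffled[lane];
        }
    }
    return vals;
}
```

Suppose lane `i` starts with the value `i`. Lane 0 ends with 496, which is 0 + 1 + ... + 31. Lane 31 never has an in-range source, so it simply doubles its own value five times and ends with 992.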

To build and run the reduce sum example, please run the following commands.

$ nvcc reduce_sum.cu -o reduce_sum
$ ./reduce_sum
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
NVIDIA GPU Device Info
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Device Name: NVIDIA GeForce RTX 3090
Memory Size: 23.6694 GB
Peak Bandwitdh: 936.096 GB/s
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Reduce Sum Profiling
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
==================================================
Batch Size: 2048
Number of Elements Per Batch: 262144
==================================================
Batched Reduce Sum V1
Latency: 2.42976 ms
Effective Bandwidth: 883.83 GB/s
--------------------------------------------------
Batched Reduce Sum V2
Latency: 2.44303 ms
Effective Bandwidth: 879.028 GB/s
--------------------------------------------------

It turns out that the two batched reduce sum kernels have similar performance. Both reach an effective bandwidth of about 94% of the GPU's peak bandwidth. Both are therefore bound by global memory bandwidth, and the strategy used for the final in-block reduction makes little difference. Note that on my system the effective bandwidth varies from run to run and at different times of the day, ranging from 750 GB/s to 900 GB/s.
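
The 94% figure follows from the formulas in `profile_batched_reduce_sum` and `main`. The following sketch reproduces the arithmetic. The function names are hypothetical. The memory clock of 9,751,000 kHz and the 384-bit bus width are assumed RTX 3090 values, chosen because they reproduce the logged peak of 936.096 GB/s.

```cpp
#include <cassert>
#include <cstddef>

// Effective bandwidth as computed in profile_batched_reduce_sum: bytes read
// plus bytes written, divided by latency. Latency is in ms, so
// bytes * 1e-6 / ms yields GB/s.
double effective_bandwidth_gbps(std::size_t batch_size,
                                std::size_t num_elements_per_batch,
                                double latency_ms)
{
    assert(latency_ms > 0.0);
    std::size_t const num_bytes{
        batch_size * num_elements_per_batch * sizeof(float) +
        batch_size * sizeof(float)};
    return static_cast<double>(num_bytes) * 1e-6 / latency_ms;
}

// Theoretical peak as computed in main(): 2 (double data rate) times the
// memory clock in kHz times the bus width in bytes, scaled to GB/s.
double peak_bandwidth_gbps(double memory_clock_khz, int bus_width_bits)
{
    return 2.0 * memory_clock_khz * (bus_width_bits / 8) / 1.0e6;
}
```

The V1 run reads and writes 2,147,491,840 bytes (2048 × 262144 × 4 + 2048 × 4) in 2.42976 ms. That gives about 883.83 GB/s, and 883.83 / 936.096 ≈ 0.944.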

Large Array Reduce Sum

What if the arrays are much larger and the batch size is much smaller? Each array is reduced by a single thread block. A thread block can have at most 1024 threads and runs on a single streaming multiprocessor (SM). A very small batch size therefore leaves most of the GPU idle, and both GPU utilization and effective bandwidth will be very low.

In this case, we can split each large array into multiple smaller arrays, treating each large array as if it were a batch of arrays. Each smaller array is then reduced by its own thread block, so multiple thread blocks work on each large array. Once the sums of the smaller arrays are computed, we run the batched reduce sum kernel again to reduce these partial sums to the final sum.

Concretely, suppose the data has shape (batch_size, num_elements_per_batch), where num_elements_per_batch is very large and batch_size is very small. The multi-pass procedure is:

1. Reshape the data to (batch_size * inner_batch_size, inner_num_elements_per_batch), where inner_batch_size * inner_num_elements_per_batch = num_elements_per_batch.
2. Run the batched reduce sum kernel. The result has shape (batch_size * inner_batch_size, 1).
3. Reshape the result to (batch_size, inner_batch_size) and treat it as a new (batch_size, num_elements_per_batch) problem.
4. Repeat until num_elements_per_batch is small enough for a single pass.
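
The sequence of reshapes can be planned on the host before any kernel is launched. The following is a minimal sketch. The function `plan_reduce_passes` and its parameter `chunk`, which stands for inner_num_elements_per_batch, are hypothetical. The sketch assumes every intermediate length divides evenly by `chunk`. A real implementation would also have to handle remainders, for example by zero-padding, which does not change a sum.

```cpp
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// Hypothetical planner for the multi-pass scheme: returns the
// (num_rows, num_elements_per_row) shape of each batched reduce sum launch.
// Assumes every intermediate length is divisible by `chunk`.
std::vector<std::pair<std::size_t, std::size_t>> plan_reduce_passes(
    std::size_t batch_size, std::size_t num_elements_per_batch,
    std::size_t chunk)
{
    assert(chunk > 1);
    std::vector<std::pair<std::size_t, std::size_t>> passes;
    std::size_t n{num_elements_per_batch};
    while (n > chunk)
    {
        assert(n % chunk == 0);
        std::size_t const inner_batch_size{n / chunk};
        // Reshape (batch_size, n) -> (batch_size * inner_batch_size, chunk).
        // The pass leaves batch_size * inner_batch_size partial sums, viewed
        // as (batch_size, inner_batch_size) for the next pass.
        passes.emplace_back(batch_size * inner_batch_size, chunk);
        n = inner_batch_size;
    }
    // Final pass: one row per original batch entry.
    passes.emplace_back(batch_size, n);
    return passes;
}
```

For example, `batch_size = 1`, `num_elements_per_batch = 2^30`, and `chunk = 1024` plan three launches with shapes (1048576, 1024), (1024, 1024), and (1, 1024). The output buffer of each pass serves as the input of the next.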

Of course, instead of running the batched reduce sum kernel and synchronizing multiple times, each thread block can add the partial sum of its smaller array directly to the final sum in global memory using atomic operations. However, this may or may not be slower than running the batched reduce sum kernel and synchronizing multiple times.
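
On the GPU, this alternative would use `atomicAdd` on a `float` in global memory. The following sketch emulates the idea on the host with `std::thread`, where each thread stands in for one thread block. The function names are hypothetical. C++ only gained `std::atomic<float>::fetch_add` in C++20, so the sketch uses a compare-and-swap loop instead. With atomics, the order in which partial sums are added is nondeterministic. Floating-point results can therefore differ in the last bits from run to run, which is not the case for the multi-pass approach.

```cpp
#include <atomic>
#include <cassert>
#include <cstddef>
#include <thread>
#include <vector>

// Atomically add `value` to `target` with a compare-and-swap loop.
void atomic_add(std::atomic<float>& target, float value)
{
    float expected{target.load()};
    // On failure, compare_exchange_weak reloads `expected` with the current
    // value, so the loop retries with fresh data.
    while (!target.compare_exchange_weak(expected, expected + value))
    {
    }
}

// Host-side emulation of the atomic variant: each std::thread plays the role
// of one thread block that reduces its chunk locally and then adds its
// partial sum to the single global result.
float chunked_sum_with_atomics(std::vector<float> const& data,
                               std::size_t num_chunks)
{
    assert(num_chunks > 0);
    std::atomic<float> total{0.0f};
    std::size_t const chunk_size{(data.size() + num_chunks - 1) / num_chunks};
    std::vector<std::thread> workers;
    for (std::size_t c{0}; c < num_chunks; ++c)
    {
        workers.emplace_back([&, c] {
            float partial{0.0f};
            std::size_t const begin{c * chunk_size};
            for (std::size_t i{begin};
                 i < begin + chunk_size && i < data.size(); ++i)
            {
                partial += data[i];
            }
            atomic_add(total, partial);
        });
    }
    for (auto& worker : workers)
    {
        worker.join();
    }
    return total.load();
}
```

Chunks that start past the end of `data` contribute a partial sum of zero. This makes the sketch safe even when `num_chunks` exceeds the number of elements.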


Author

Lei Mao

Posted on

07-30-2024

Updated on

07-30-2024
