Predicated Execution VS Conditional Execution

Introduction

Predicated execution and conditional execution are two different approaches to handling control flow in programming, particularly in the context of parallel computing and GPU programming. Predicated execution involves executing all instructions but only committing the results of those instructions that meet a certain condition, while conditional execution involves executing instructions only if a certain condition is met.

In this blog post, I would like to quickly show a few examples of predicated execution and conditional execution using PyTorch and discuss how to choose between them in different scenarios.

Predicated Execution

A common PyTorch API that uses predicated execution is torch.where. The torch.where function takes a condition tensor and two other tensors, and it returns a new tensor where each element is selected from one of the two input tensors based on the corresponding value in the condition tensor.
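As a minimal sketch of the selection behavior described above, the following snippet applies torch.where to small tensors on the CPU. The tensor values are made up purely for illustration.

```python
import torch

# Hypothetical inputs chosen only for illustration.
cond = torch.tensor([True, False, True, False])
a = torch.tensor([1.0, 2.0, 3.0, 4.0])
b = torch.tensor([10.0, 20.0, 30.0, 40.0])

# Element-wise selection: take a[i] where cond[i] is True, otherwise b[i].
selected = torch.where(cond, a, b)
print(selected.tolist())  # [1.0, 20.0, 3.0, 40.0]

# A 0-dim (scalar) boolean condition broadcasts, so it selects a whole tensor.
scalar_cond = torch.tensor(False)
whole = torch.where(scalar_cond, a, b)
print(whole.tolist())  # [10.0, 20.0, 30.0, 40.0]
```

Note that both a and b must already be fully computed before torch.where is called, which is why this API corresponds to predicated execution.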

In the following example, we have two neural network branches that produce outputs with exactly the same metadata (shape, dtype, device). We can use torch.where to select the output of one of the two branches based on a condition tensor. However, both branches are executed regardless of the condition.

predicated_execution.py
import torch
import torch.nn as nn


class BranchA(nn.Module):

    def __init__(self, dim: int) -> None:
        super().__init__()
        self.net = nn.Sequential(nn.Linear(dim, dim), nn.ReLU())

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


class BranchB(nn.Module):

    def __init__(self, dim: int) -> None:
        super().__init__()
        self.net = nn.Sequential(nn.Linear(dim, dim), nn.GELU())

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


class PredicatedModel(nn.Module):

    def __init__(self, dim: int) -> None:
        super().__init__()
        self.branch_a = BranchA(dim)
        self.branch_b = BranchB(dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Both branches are executed and produce tensors with identical metadata.
        y_a = self.branch_a(x)
        y_b = self.branch_b(x)

        # Use the same scalar predicate as the conditional-execution examples.
        pred = x.mean() > 0

        # Predicated execution: both branches run, then one full output is selected.
        y = torch.where(pred, y_a, y_b)
        return y


if __name__ == "__main__":

    torch.manual_seed(0)
    torch.cuda.set_sync_debug_mode(debug_mode="warn")

    device = "cuda"

    batch_size = 8
    hidden_dim = 16

    x = torch.randn(batch_size, hidden_dim, device=device)

    with torch.device(device):
        model = PredicatedModel(hidden_dim)

    y = model(x)

Conditional Execution

In contrast, conditional execution executes only the branch selected by the condition. In this example, we use torch.cond to achieve conditional execution. The torch.cond function takes a condition tensor and two functions, and it executes only the function that corresponds to the value of the condition tensor.

conditional_execution.py
import torch
import torch.nn as nn


class BranchA(nn.Module):

    def __init__(self, dim: int) -> None:
        super().__init__()
        self.net = nn.Sequential(nn.Linear(dim, dim), nn.ReLU())

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


class BranchB(nn.Module):

    def __init__(self, dim: int) -> None:
        super().__init__()
        self.net = nn.Sequential(nn.Linear(dim, dim), nn.GELU())

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


class ConditionalModel(nn.Module):

    def __init__(self, dim: int) -> None:
        super().__init__()
        self.branch_a = BranchA(dim)
        self.branch_b = BranchB(dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        pred = x.mean() > 0

        def true_fn(x_: torch.Tensor) -> torch.Tensor:
            return self.branch_a(x_)

        def false_fn(x_: torch.Tensor) -> torch.Tensor:
            return self.branch_b(x_)

        # torch.cond traces both branches but executes only one branch at runtime.
        return torch.cond(pred, true_fn, false_fn, (x, ))


if __name__ == "__main__":

    torch.manual_seed(0)
    torch.cuda.set_sync_debug_mode(debug_mode="warn")

    device = "cuda"

    batch_size = 8
    hidden_dim = 16

    x = torch.randn(batch_size, hidden_dim, device=device)

    with torch.device(device):
        model = ConditionalModel(hidden_dim)

    y = model(x)

One caveat of torch.cond is that it results in host-device synchronization. The kernels of the two branches are dispatched from the host dynamically at runtime, so the host has to read the predicate value from the GPU before it knows which branch to launch.

Changing the execution framework does not eliminate this synchronization. In my experiments, which are not discussed here, the host-device synchronization was still present when the model was compiled by AOTInductor or TensorRT and run on GPU, and also when jax.cond in JAX was compiled by XLA and run on TPU.

Predicated Execution VS Conditional Execution

In the above examples, the PredicatedModel and the ConditionalModel produce the same output for the same input. The difference is that the PredicatedModel executes both branches and then selects the output based on the condition, while the ConditionalModel executes only one branch based on the condition, at the cost of one host-device synchronization.
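The following is a minimal CPU-only sketch of this difference. It is not the models above: for simplicity it assumes a plain Python `if` as a stand-in for torch.cond, and it counts branch executions with forward hooks. The names `calls`, `count_a`, and `count_b` are hypothetical helpers introduced only for this illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
dim = 16
branch_a = nn.Sequential(nn.Linear(dim, dim), nn.ReLU())
branch_b = nn.Sequential(nn.Linear(dim, dim), nn.GELU())

# Hypothetical counters: record how many times each branch runs.
calls = {"a": 0, "b": 0}


def count_a(module, inputs, output):
    calls["a"] += 1


def count_b(module, inputs, output):
    calls["b"] += 1


branch_a.register_forward_hook(count_a)
branch_b.register_forward_hook(count_b)

x = torch.randn(8, dim)
pred = x.mean() > 0

with torch.no_grad():
    # Predicated: run both branches, then select one full output.
    y_predicated = torch.where(pred, branch_a(x), branch_b(x))
    predicated_calls = dict(calls)

    # Conditional stand-in (assumption): a Python `if` reads the predicate on
    # the host. For a GPU tensor, this read is a host-device synchronization.
    calls.update(a=0, b=0)
    y_conditional = branch_a(x) if bool(pred) else branch_b(x)
    conditional_calls = dict(calls)

print(predicated_calls)   # both branches ran once
print(conditional_calls)  # exactly one branch ran
print(torch.allclose(y_predicated, y_conditional))
```

The two outputs agree, but the predicated path pays for both branches while the conditional path pays for one branch plus the host-side read of the predicate.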

Consequently, the choice between predicated execution and conditional execution depends on the specific use case. If both branches are lightweight, so that executing both of them adds minimal overhead to the system, we should consider using predicated execution. However, if at least one branch is heavy and the cost of the host-device synchronization is lower than the overhead of executing both branches, we should consider using conditional execution.
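This trade-off can be sketched with a deliberately simplified cost model. Everything in it is an assumption made for illustration: the function names, the idea that costs simply add up, and the timings, which are hypothetical numbers rather than measurements. In practice the cost of a synchronization also includes the stalled host-side launch pipeline, so real decisions should be based on profiling.

```python
def predicated_cost(t_a: float, t_b: float, t_select: float = 0.0) -> float:
    """Both branches always run, followed by a selection step."""
    return t_a + t_b + t_select


def conditional_cost(t_a: float, t_b: float, p_true: float,
                     t_sync: float) -> float:
    """One synchronization, then only the taken branch runs (expected cost)."""
    return t_sync + p_true * t_a + (1.0 - p_true) * t_b


def choose(t_a: float, t_b: float, p_true: float, t_sync: float,
           t_select: float = 0.0) -> str:
    """Pick the approach with the lower estimated cost."""
    cost_predicated = predicated_cost(t_a, t_b, t_select)
    cost_conditional = conditional_cost(t_a, t_b, p_true, t_sync)
    return "predicated" if cost_predicated <= cost_conditional else "conditional"


# Hypothetical timings in microseconds, not measurements.
light = choose(t_a=5.0, t_b=5.0, p_true=0.5, t_sync=50.0)
heavy = choose(t_a=500.0, t_b=5.0, p_true=0.1, t_sync=50.0)

print(light)  # predicated: 10 vs 55
print(heavy)  # conditional: 505 vs 104.5
```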

Conditional Execution Kernel Fusion Optimization

I used to wonder why conditional execution cannot be optimized so that the host-device synchronization is avoided. For example, a neural network compiler can, in principle, see the code of the condition and the two branches and generate one kernel that performs the conditional execution on the GPU without host-device synchronization.

In the following example, I have implemented a CUDA kernel that performs conditional execution on the GPU without host-device synchronization. The kernel first computes the condition and then executes one of the two branches based on the computed condition. It is launched cooperatively so that a grid-wide synchronization can separate the reduction that computes the condition from the branch execution. The scalar predicate stays on the device, so no host-side dynamic branch dispatch is needed.

conditional_execution.cu
#include <algorithm>
#include <cooperative_groups.h>
#include <cstdio>
#include <cstdlib>
#include <cuda_runtime.h>
#include <iostream>

namespace cg = cooperative_groups;

#define CHECK_CUDA_ERROR(val) check((val), #val, __FILE__, __LINE__)
void check(cudaError_t err, char const* func, char const* file, int line)
{
    if (err != cudaSuccess)
    {
        std::cerr << "CUDA Runtime Error at: " << file << ":" << line
                  << std::endl;
        std::cerr << cudaGetErrorString(err) << " " << func << std::endl;
        std::exit(EXIT_FAILURE);
    }
}

#define CHECK_LAST_CUDA_ERROR() check_last(__FILE__, __LINE__)
void check_last(char const* file, int line)
{
    cudaError_t const err{cudaGetLastError()};
    if (err != cudaSuccess)
    {
        std::cerr << "CUDA Runtime Error at: " << file << ":" << line
                  << std::endl;
        std::cerr << cudaGetErrorString(err) << std::endl;
        std::exit(EXIT_FAILURE);
    }
}

__device__ __forceinline__ bool compute_condition(float sum, int n)
{
    return (sum / static_cast<float>(n)) > 0.0f;
}

__device__ __forceinline__ float branch_true(float x)
{
    // Example branch: y = 2x + 1
    return 2.0f * x + 1.0f;
}

__device__ __forceinline__ float branch_false(float x)
{
    // Example branch: y = x^2 - 1
    return x * x - 1.0f;
}

template <size_t NUM_THREADS>
__global__ void
conditional_kernel(float const* __restrict__ x, float* __restrict__ y, size_t n,
                   float* __restrict__ block_sums, bool* __restrict__ pred_ptr)
{
    cg::grid_group grid = cg::this_grid();

    size_t const tid{threadIdx.x};
    size_t const gtid{blockIdx.x * NUM_THREADS + tid};
    size_t const stride{NUM_THREADS * gridDim.x};

    __shared__ float sdata[NUM_THREADS];

    // Each thread accumulates a grid-stride partial sum of the input.
    float local_sum = 0.0f;
    for (size_t i{gtid}; i < n; i += stride)
    {
        local_sum += x[i];
    }

    sdata[tid] = local_sum;
    __syncthreads();

    // Block-level tree reduction in shared memory.
    for (int s = blockDim.x / 2; s > 0; s >>= 1)
    {
        if (tid < s)
        {
            sdata[tid] += sdata[tid + s];
        }
        __syncthreads();
    }

    if (tid == 0)
    {
        block_sums[blockIdx.x] = sdata[0];
    }

    grid.sync();

    // One block reduces partial sums to a scalar predicate for the whole grid.
    if (blockIdx.x == 0)
    {
        float partial = 0.0f;
        for (size_t b{tid}; b < gridDim.x; b += NUM_THREADS)
        {
            partial += block_sums[b];
        }

        sdata[tid] = partial;
        __syncthreads();

        for (int s = blockDim.x / 2; s > 0; s >>= 1)
        {
            if (tid < s)
            {
                sdata[tid] += sdata[tid + s];
            }
            __syncthreads();
        }

        if (tid == 0)
        {
            *pred_ptr = compute_condition(sdata[0], n);
        }
    }
    grid.sync();

    // Every thread reads the same device-side predicate, so the branch is
    // uniform across the grid.
    bool const pred{*pred_ptr};

    for (size_t idx{gtid}; idx < n; idx += stride)
    {
        float const xi{x[idx]};
        y[idx] = pred ? branch_true(xi) : branch_false(xi);
    }
}

int main()
{
    constexpr size_t N{8};
    constexpr size_t NUM_THREADS{256};
    // Non-const copy of the size, because the cooperative launch API takes
    // the kernel arguments as an array of void pointers.
    size_t n{N};

    float h_x[N] = {-2.0f, -1.0f, -0.5f, 0.0f, 0.5f, 1.0f, 2.0f, 3.0f};
    float h_y[N] = {0.0f};

    float* d_x = nullptr;
    float* d_y = nullptr;
    float* d_block_sums = nullptr;
    bool* d_pred = nullptr;

    cudaDeviceProp prop{};
    CHECK_CUDA_ERROR(cudaGetDeviceProperties(&prop, 0));
    if (!prop.cooperativeLaunch)
    {
        std::printf("Device does not support cooperative kernel launch.\n");
        return 0;
    }

    CHECK_CUDA_ERROR(cudaMalloc(&d_x, n * sizeof(float)));
    CHECK_CUDA_ERROR(cudaMalloc(&d_y, n * sizeof(float)));
    CHECK_CUDA_ERROR(
        cudaMemcpy(d_x, h_x, n * sizeof(float), cudaMemcpyHostToDevice));

    // A cooperative launch requires all blocks to be resident at the same
    // time, so the grid size is capped by the occupancy limit.
    int blocks_per_sm = 0;
    CHECK_CUDA_ERROR(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
        &blocks_per_sm, conditional_kernel<NUM_THREADS>, NUM_THREADS, 0));

    int const max_blocks{blocks_per_sm * prop.multiProcessorCount};
    int const blocks{std::max(
        1, std::min(static_cast<int>((n + NUM_THREADS - 1) / NUM_THREADS),
                    max_blocks))};

    CHECK_CUDA_ERROR(cudaMalloc(&d_block_sums, blocks * sizeof(float)));
    CHECK_CUDA_ERROR(cudaMalloc(&d_pred, sizeof(bool)));

    void* args[] = {&d_x, &d_y, &n, &d_block_sums, &d_pred};
    CHECK_CUDA_ERROR(cudaLaunchCooperativeKernel(
        reinterpret_cast<void*>(conditional_kernel<NUM_THREADS>), blocks,
        NUM_THREADS, args, 0, 0));
    CHECK_LAST_CUDA_ERROR();

    CHECK_CUDA_ERROR(
        cudaMemcpy(h_y, d_y, n * sizeof(float), cudaMemcpyDeviceToHost));
    CHECK_CUDA_ERROR(cudaDeviceSynchronize());

    for (size_t i{0}; i < n; ++i)
    {
        std::printf("x=%6.2f -> y=%6.2f\n", h_x[i], h_y[i]);
    }

    CHECK_CUDA_ERROR(cudaFree(d_x));
    CHECK_CUDA_ERROR(cudaFree(d_y));
    CHECK_CUDA_ERROR(cudaFree(d_block_sums));
    CHECK_CUDA_ERROR(cudaFree(d_pred));

    return 0;
}
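To make the intended result of the kernel easy to check, here is a small host-side reference in plain Python. It is only a sketch of the same arithmetic (mean-based predicate, then one of the two example branches), and the function name `reference` is a hypothetical helper rather than part of the CUDA program.

```python
def reference(xs: list[float]) -> list[float]:
    """Compute what the kernel is expected to produce for the input xs."""
    # Scalar predicate: is the mean of the input greater than zero?
    pred = (sum(xs) / len(xs)) > 0.0
    if pred:
        # Same as branch_true in the kernel: y = 2x + 1
        return [2.0 * x + 1.0 for x in xs]
    # Same as branch_false in the kernel: y = x^2 - 1
    return [x * x - 1.0 for x in xs]


# Same input values as h_x in the CUDA program.
xs = [-2.0, -1.0, -0.5, 0.0, 0.5, 1.0, 2.0, 3.0]
ys = reference(xs)
print(ys)  # [-3.0, -1.0, 0.0, 1.0, 2.0, 3.0, 5.0, 7.0]
```

The input sums to 3.0, so the mean is 0.375 and the true branch is taken for every element.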

Even though the fused kernel above seems to be a good solution, it is hard to generalize this approach to all scenarios, because the conditions and branches can be arbitrarily complex, and the compiler may not be able to generate a single kernel that handles all cases.

When the conditions and branches are simple, the compiler might be able to generate a single kernel that performs the conditional execution on the GPU without host-device synchronization, as demonstrated in the above example. However, its performance might be only slightly better than that of the predicated execution approach, since it only saves two kernel launch overheads and the execution of one branch. When the predicated execution approach is also optimized by the compiler, the two branches may be horizontally fused, depending on their instructions, which makes the performance difference between the two approaches even smaller.

For these reasons, it is not worth the effort to implement a general kernel fusion optimization for conditional execution. Instead, the user has to choose between predicated execution and conditional execution based on the specific use case, as discussed in the previous section.

Author: Lei Mao

Posted on: 07-01-2026

Updated on: 07-01-2026
