Predicated execution and conditional execution are two different approaches to handling control flow in programming, particularly in the context of parallel computing and GPU programming. Predicated execution involves executing all instructions but only committing the results of those instructions that meet a certain condition, while conditional execution involves executing instructions only if a certain condition is met.
In this blog post, I would like to quickly show a few examples of predicated execution and conditional execution using PyTorch and discuss how to choose between them in different scenarios.
Predicated Execution
A common PyTorch API that uses predicated execution is torch.where. The torch.where function takes a condition tensor and two other tensors, and it returns a new tensor where each element is selected from one of the two input tensors based on the corresponding value in the condition tensor.
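To make the selection semantics concrete without a GPU, here is a minimal plain-Python sketch of what torch.where does element by element for equal-length 1-D inputs. The helper name `where` and the two candidate lists are illustrative assumptions. The real torch.where also broadcasts its inputs, which this sketch omits.

```python
from __future__ import annotations

from typing import Sequence, TypeVar

T = TypeVar("T")


def where(condition: Sequence[bool], a: Sequence[T], b: Sequence[T]) -> list[T]:
    """Hypothetical element-wise select mirroring torch.where for 1-D inputs.

    No broadcasting: all three sequences must have the same length.
    """
    if not len(condition) == len(a) == len(b):
        raise ValueError("condition, a, and b must have the same length")
    # Both candidates already exist for every element; the condition only
    # decides which one is committed to the output.
    return [ai if ci else bi for ci, ai, bi in zip(condition, a, b)]


x = [-2.0, -0.5, 1.0, 1.5]
a = [2.0 * v for v in x]  # every element of "branch a" is computed
b = [v * v for v in x]    # every element of "branch b" is computed too
print(where([v > 0 for v in x], a, b))  # [4.0, 0.25, 2.0, 3.0]
```

Note that `a` and `b` are fully computed before `where` is called. That is the defining cost of predicated execution.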
In the following example, a neural network has two branches that produce outputs with exactly the same metadata (shape, dtype, device). We use torch.where to select the output of one of the two branches based on a scalar predicate. Because the predicate is a zero-dimensional boolean tensor, torch.where broadcasts it and selects one branch's output in its entirety. However, both branches are executed regardless of the predicate.
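import torch
import torch.nn as nn


class PredicatedModel(nn.Module):

    def __init__(self, hidden_dim: int) -> None:
        super().__init__()
        # Hypothetical branches: any two sub-networks whose outputs share
        # shape, dtype, and device would do.
        self.branch_a = nn.Linear(hidden_dim, hidden_dim)
        self.branch_b = nn.Sequential(nn.Linear(hidden_dim, hidden_dim), nn.ReLU())

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Both branches are executed and produce tensors with identical metadata.
        y_a = self.branch_a(x)
        y_b = self.branch_b(x)

        # Use the same scalar predicate as the conditional-execution examples.
        pred = x.mean() > 0

        # Predicated execution: both branches run, then one full output is selected.
        y = torch.where(pred, y_a, y_b)
        return y


# Hypothetical sizes; the original values are not shown.
batch_size, hidden_dim = 32, 128
device = "cuda" if torch.cuda.is_available() else "cpu"

x = torch.randn(batch_size, hidden_dim, device=device)
with torch.device(device):
    model = PredicatedModel(hidden_dim)
y = model(x)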
Conditional Execution
In contrast, conditional execution executes only the branch whose condition is met. In this example, we use torch.cond to achieve conditional execution. The torch.cond function takes a boolean predicate (a Python bool or a single-element boolean tensor), a function for the true case, a function for the false case, and a tuple of operands, and it executes only the function selected by the predicate.
x = torch.randn(batch_size, hidden_dim, device=device)
with torch.device(device):
    model = ConditionalModel(hidden_dim)
y = model(x)
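The ConditionalModel definition itself is not shown in this excerpt. As a stand-in, the following plain-Python sketch models the contract of torch.cond: given a predicate, a true function, a false function, and a tuple of operands, only the selected function is ever called. The helper `cond_sketch` and the two branch functions are hypothetical, and the call log exists only to make the "one branch runs" property visible.

```python
from __future__ import annotations

from typing import Callable


def cond_sketch(pred: bool, true_fn: Callable[..., list[float]],
                false_fn: Callable[..., list[float]],
                operands: tuple[list[float], ...]) -> list[float]:
    """Hypothetical conditional execution with torch.cond-style arguments.

    Only one branch is called. pred must be a concrete Python bool here,
    so its value must already be known by the caller (the host).
    """
    return true_fn(*operands) if pred else false_fn(*operands)


calls: list[str] = []  # records which branch functions actually ran


def branch_a(x: list[float]) -> list[float]:
    calls.append("a")
    return [2.0 * v for v in x]


def branch_b(x: list[float]) -> list[float]:
    calls.append("b")
    return [v * v for v in x]


x = [-1.0, 0.5, 2.0, 1.5]
pred = sum(x) / len(x) > 0  # same predicate as the PyTorch examples: mean(x) > 0
y = cond_sketch(pred, branch_a, branch_b, (x,))
print(y, calls)  # [-2.0, 1.0, 4.0, 3.0] ['a']
```

The requirement that `pred` be a concrete bool before the call is exactly where the host-device synchronization comes from when the predicate lives on the GPU.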
One caveat of torch.cond is that it results in a host-device synchronization. The branch to dispatch is decided on the host at runtime, so the predicate's value must first be read back from the GPU before the host can launch the kernels of the selected branch.
Changing the execution framework does not eliminate this synchronization either. Whether the model is compiled by AOTInductor or TensorRT for a GPU, or written with jax.cond in JAX and compiled by XLA for a TPU, it still results in a host-device synchronization, as verified by experiments not discussed here.
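Why the predicate forces a synchronization can be illustrated with a toy model of an asynchronous device stream. In the plain-Python sketch below, `ToyStream` is hypothetical and not a CUDA or PyTorch API: enqueued work only runs when the host calls `synchronize()`, and each call is counted. The predicated path never needs the predicate on the host, whereas the conditional path must synchronize before it can decide which branch to enqueue.

```python
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Callable


@dataclass
class ToyStream:
    """Hypothetical stand-in for an asynchronous device stream."""

    pending: list[Callable[[], None]] = field(default_factory=list)
    syncs: int = 0

    def enqueue(self, op: Callable[[], None]) -> None:
        self.pending.append(op)  # returns immediately, like a kernel launch

    def synchronize(self) -> None:
        self.syncs += 1  # the host blocks until all queued work has finished
        for op in self.pending:
            op()
        self.pending.clear()


def predicated(stream: ToyStream, mem: dict) -> None:
    # Everything is enqueued up front; the host never reads the predicate.
    stream.enqueue(lambda: mem.update(pred=mem["x"] > 0))
    stream.enqueue(lambda: mem.update(a=2.0 * mem["x"]))
    stream.enqueue(lambda: mem.update(b=mem["x"] ** 2))
    stream.enqueue(lambda: mem.update(y=mem["a"] if mem["pred"] else mem["b"]))


def conditional(stream: ToyStream, mem: dict) -> None:
    stream.enqueue(lambda: mem.update(pred=mem["x"] > 0))
    stream.synchronize()  # the host must read pred to pick a branch
    if mem["pred"]:
        stream.enqueue(lambda: mem.update(y=2.0 * mem["x"]))
    else:
        stream.enqueue(lambda: mem.update(y=mem["x"] ** 2))


for run in (predicated, conditional):
    stream, mem = ToyStream(), {"x": 3.0}
    run(stream, mem)
    stream.synchronize()  # final read-back of y, needed in both cases
    print(run.__name__, mem["y"], stream.syncs)
# predicated 6.0 1
# conditional 6.0 2
```

Both paths produce the same result, but the conditional one blocks the host one extra time. That extra stall is what no framework can compile away as long as the branch is chosen on the host.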
Predicated Execution VS Conditional Execution
In the above examples, the PredicatedModel and the ConditionalModel produce the same output for the same input. The difference is that the PredicatedModel executes both branches and then selects the output based on the condition, while the ConditionalModel executes only one branch, at the cost of one host-device synchronization.
Consequently, choosing between predicated execution and conditional execution depends on the specific use case. If both branches are lightweight, so that executing both adds minimal overhead, we should consider predicated execution. However, if at least one branch is heavy and the cost of the host-device synchronization is smaller than the overhead of executing both branches, we should consider conditional execution.
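The rule of thumb above can be written as a back-of-the-envelope cost model. The sketch below is an assumption, not a measurement: each branch has a cost, branch a is taken with probability `p_a`, and the host-device synchronization is folded into a single `sync_cost` number, even though in practice it also prevents the host from running ahead of the device. All numbers in the usage example are made up for illustration.

```python
from __future__ import annotations


def expected_costs(cost_a: float, cost_b: float, p_a: float,
                   sync_cost: float, select_cost: float = 0.0) -> tuple[float, float]:
    """Toy model (hypothetical): return (predicated, conditional) expected costs."""
    # Predicated execution always pays for both branches plus the select.
    predicated = cost_a + cost_b + select_cost
    # Conditional execution pays for the branch actually taken plus one sync.
    conditional = p_a * cost_a + (1.0 - p_a) * cost_b + sync_cost
    return predicated, conditional


def choose(cost_a: float, cost_b: float, p_a: float,
           sync_cost: float, select_cost: float = 0.0) -> str:
    predicated, conditional = expected_costs(cost_a, cost_b, p_a, sync_cost, select_cost)
    return "predicated" if predicated <= conditional else "conditional"


# Hypothetical costs in microseconds.
print(choose(5.0, 5.0, 0.5, sync_cost=20.0, select_cost=1.0))      # predicated
print(choose(500.0, 400.0, 0.5, sync_cost=20.0, select_cost=1.0))  # conditional
print(choose(500.0, 5.0, 1.0, sync_cost=20.0, select_cost=1.0))    # predicated
```

The last case shows a subtlety of the model: if the heavy branch is always the one taken, conditional execution only skips the cheap branch, so it pays off only when the work it skips outweighs the synchronization.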
Conditional Execution Kernel Fusion Optimization
I had been wondering why conditional execution cannot be optimized to avoid the host-device synchronization. For example, a neural network compiler can, in principle, see the code of the condition and of both branches, and generate a single kernel that evaluates the condition and runs the selected branch entirely on the GPU, without host-device synchronization.
In the following example, I have implemented a CUDA kernel that performs conditional execution on the GPU without host-device synchronization. The kernel first reduces the input to compute the condition and then executes one of the two branches based on the computed condition. Because the scalar predicate stays on the device, no host-side dynamic branch dispatch is needed. The kernel uses grid.sync() from cooperative groups as a grid-wide barrier between these phases, so it has to be launched as a cooperative kernel.
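// Complete program sketch. The kernel signature, the branch bodies, and the
// problem size are hypothetical stand-ins for parts not shown in this excerpt.
#include <cooperative_groups.h>
#include <cuda_runtime.h>

#include <cstdio>
#include <cstdlib>
#include <vector>

namespace cg = cooperative_groups;

#define CHECK_CUDA_ERROR(val) check((val), #val, __FILE__, __LINE__)
void check(cudaError_t err, char const* func, char const* file, int line)
{
    if (err != cudaSuccess)
    {
        std::fprintf(stderr, "CUDA error at %s:%d: %s %s\n", file, line,
                     cudaGetErrorString(err), func);
        std::exit(EXIT_FAILURE);
    }
}

// Hypothetical element-wise branches.
__device__ float branch_a(float v) { return 2.0f * v; }
__device__ float branch_b(float v) { return v * v; }

template <int NUM_THREADS>
__global__ void conditional_kernel(float const* x, float* y, float* block_sums,
                                   int* pred, size_t n)
{
    cg::grid_group grid = cg::this_grid();
    __shared__ float sdata[NUM_THREADS];
    unsigned int const tid{threadIdx.x};
    size_t const gtid{static_cast<size_t>(blockIdx.x) * blockDim.x + tid};
    size_t const stride{static_cast<size_t>(gridDim.x) * blockDim.x};

    // Each block reduces its grid-stride slice of x to one partial sum.
    float local_sum = 0.0f;
    for (size_t i{gtid}; i < n; i += stride) { local_sum += x[i]; }

    sdata[tid] = local_sum;
    __syncthreads();

    for (int s = blockDim.x / 2; s > 0; s >>= 1)
    {
        if (tid < s) { sdata[tid] += sdata[tid + s]; }
        __syncthreads();
    }

    if (tid == 0) { block_sums[blockIdx.x] = sdata[0]; }

    grid.sync();

    // One block reduces partial sums to a scalar predicate for the whole grid.
    if (blockIdx.x == 0)
    {
        float partial = 0.0f;
        for (size_t b{tid}; b < gridDim.x; b += NUM_THREADS) { partial += block_sums[b]; }

        sdata[tid] = partial;
        __syncthreads();

        for (int s = blockDim.x / 2; s > 0; s >>= 1)
        {
            if (tid < s) { sdata[tid] += sdata[tid + s]; }
            __syncthreads();
        }

        // Same predicate as the PyTorch examples: mean(x) > 0.
        if (tid == 0) { *pred = (sdata[0] / static_cast<float>(n)) > 0.0f ? 1 : 0; }
    }

    grid.sync();

    // The predicate never leaves the device; every thread applies one branch.
    bool const take_a{*pred != 0};
    for (size_t i{gtid}; i < n; i += stride) { y[i] = take_a ? branch_a(x[i]) : branch_b(x[i]); }
}

int main()
{
    constexpr int NUM_THREADS{256};
    size_t n{1 << 20};  // Hypothetical size; not const because it is passed via args.

    cudaDeviceProp prop{};
    CHECK_CUDA_ERROR(cudaGetDeviceProperties(&prop, 0));
    if (!prop.cooperativeLaunch)
    {
        std::printf("Device does not support cooperative kernel launch.\n");
        return 0;
    }

    std::vector<float> h_x(n, 0.5f), h_y(n);
    float *d_x, *d_y, *d_block_sums;
    int* d_pred;
    CHECK_CUDA_ERROR(cudaMalloc(&d_x, n * sizeof(float)));
    CHECK_CUDA_ERROR(cudaMalloc(&d_y, n * sizeof(float)));
    CHECK_CUDA_ERROR(
        cudaMemcpy(d_x, h_x.data(), n * sizeof(float), cudaMemcpyHostToDevice));

    // grid.sync() needs every block resident at once, so size the grid by occupancy.
    int blocks_per_sm = 0;
    CHECK_CUDA_ERROR(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
        &blocks_per_sm, conditional_kernel<NUM_THREADS>, NUM_THREADS, 0));
    int const num_blocks{blocks_per_sm * prop.multiProcessorCount};
    CHECK_CUDA_ERROR(cudaMalloc(&d_block_sums, num_blocks * sizeof(float)));
    CHECK_CUDA_ERROR(cudaMalloc(&d_pred, sizeof(int)));

    void* args[] = {&d_x, &d_y, &d_block_sums, &d_pred, &n};
    CHECK_CUDA_ERROR(cudaLaunchCooperativeKernel(
        (void*)conditional_kernel<NUM_THREADS>, num_blocks, NUM_THREADS, args, 0, nullptr));
    CHECK_CUDA_ERROR(
        cudaMemcpy(h_y.data(), d_y, n * sizeof(float), cudaMemcpyDeviceToHost));
    std::printf("y[0] = %f\n", h_y[0]);  // mean(x) = 0.5 > 0, so branch_a: 1.000000

    CHECK_CUDA_ERROR(cudaFree(d_x));
    CHECK_CUDA_ERROR(cudaFree(d_y));
    CHECK_CUDA_ERROR(cudaFree(d_block_sums));
    CHECK_CUDA_ERROR(cudaFree(d_pred));
    return 0;
}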
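The control flow of the fused kernel can also be traced on the CPU. The plain-Python sketch below is a sequential simulation, not GPU code: `simulate_conditional_kernel` and its parameters are hypothetical, blocks and threads become plain loops, and the two `grid.sync()` barriers become the boundaries between phases. It follows the same steps as the kernel: grid-stride partial sums, a shared-memory-style tree reduction per block, a second reduction by block 0 that yields the predicate mean(x) > 0, and finally one branch applied to every element.

```python
from __future__ import annotations

from typing import Callable


def tree_reduce(sdata: list[float]) -> float:
    """In-place tree reduction, like the kernel's `s >>= 1` loop.

    Assumes len(sdata) is a power of two, as blockDim.x is in the kernel.
    """
    s = len(sdata) // 2
    while s > 0:
        for tid in range(s):  # threads with tid < s are active in this step
            sdata[tid] += sdata[tid + s]
        s //= 2  # on the GPU, __syncthreads() separates the steps
    return sdata[0]


def simulate_conditional_kernel(x: list[float], num_blocks: int, num_threads: int,
                                branch_a: Callable[[float], float],
                                branch_b: Callable[[float], float]) -> tuple[list[float], bool]:
    """Hypothetical sequential simulation of conditional_kernel."""
    n = len(x)
    stride = num_blocks * num_threads
    # Phase 1: each block reduces its grid-stride slice of x to one partial sum.
    block_sums = []
    for block in range(num_blocks):
        sdata = [sum(x[i] for i in range(block * num_threads + tid, n, stride))
                 for tid in range(num_threads)]
        block_sums.append(tree_reduce(sdata))
    # grid.sync(): every partial sum is written before block 0 reads them.
    # Block 0 reduces the partial sums into a single device-side predicate.
    sdata = [sum(block_sums[b] for b in range(tid, num_blocks, num_threads))
             for tid in range(num_threads)]
    pred = tree_reduce(sdata) / n > 0.0
    # grid.sync(): all threads read the same predicate, then run one branch.
    y = [branch_a(v) if pred else branch_b(v) for v in x]
    return y, pred


x = [float(i) - 3.5 for i in range(16)]  # mean(x) = 4.0
y, pred = simulate_conditional_kernel(x, num_blocks=2, num_threads=4,
                                      branch_a=lambda v: 2.0 * v,
                                      branch_b=lambda v: v * v)
print(pred, y[:3])  # True [-7.0, -5.0, -3.0]
```

The predicate is computed and consumed inside the same simulated launch, which is the property that lets the real kernel skip the host round trip.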
Even though this seems to be a good solution, it is hard to generalize to all scenarios, because the conditions and branches can be arbitrarily complex, and the compiler may not be able to generate a single kernel that handles all cases. The example above also relies on a cooperative launch, which caps the grid at the number of blocks that can be resident on the device at the same time.
When the conditions and branches are simple, the compiler might be able to generate a single kernel that performs the conditional execution on the GPU without host-device synchronization, as demonstrated in the above example. However, its performance might be only slightly better than that of predicated execution, saving two kernel launch overheads and one branch execution. When the compiler also optimizes the predicated execution approach, it may, depending on the branch instructions, horizontally fuse the two branches, which makes the performance difference between the two approaches even smaller.
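What horizontal fusion buys the predicated approach can be sketched in plain Python, treating each list traversal as a stand-in for one kernel. The functions below are hypothetical illustrations, not compiler output: the unfused version makes three passes (branch a, branch b, select), while the fused version evaluates both independent branches and the select in a single pass. Both still compute both branches, so this remains predicated execution; only the number of passes changes.

```python
from __future__ import annotations

from typing import Callable


def predicated_unfused(x: list[float], pred: bool,
                       branch_a: Callable[[float], float],
                       branch_b: Callable[[float], float]) -> list[float]:
    """Hypothetical unfused predicated execution: one pass per "kernel"."""
    y_a = [branch_a(v) for v in x]                       # pass 1: branch a
    y_b = [branch_b(v) for v in x]                       # pass 2: branch b
    return [a if pred else b for a, b in zip(y_a, y_b)]  # pass 3: select


def predicated_fused(x: list[float], pred: bool,
                     branch_a: Callable[[float], float],
                     branch_b: Callable[[float], float]) -> list[float]:
    """Hypothetical fused predicated execution: a single pass over x."""
    out = []
    for v in x:  # both branches and the select share one loop
        a, b = branch_a(v), branch_b(v)
        out.append(a if pred else b)
    return out


x = [-1.0, 0.5, 2.0]
for pred in (True, False):
    print(predicated_fused(x, pred, lambda v: 2.0 * v, lambda v: v * v))
# [-2.0, 1.0, 4.0]
# [1.0, 0.25, 4.0]
```

On a GPU, the fused form would read `x` once and write `y` once, which is why it can approach the single conditional kernel in cost even though it does the work of both branches.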
For these reasons, it is not worth the effort to implement a general kernel-fusion optimization for conditional execution. Instead, the user has to choose between predicated execution and conditional execution based on the specific use case, as discussed in the previous section.