Online Safe Softmax
Introduction
Online safe softmax is a numerically stable and efficient algorithm for computing the softmax function. It has inspired more sophisticated algorithms such as FlashAttention.
In this blog post, I will quickly discuss the original softmax function, the safe softmax function, and the online safe softmax function.
Original Softmax
The original softmax function $y = \text{softmax}(x)$ is defined as
$$
y_i = \frac{\exp(x_i)}{\sum_{j=1}^{n} \exp(x_j)}
$$
where $x = [x_1, x_2, …, x_n]$ is the input vector and $y = [y_1, y_2, …, y_n]$ is the output vector. The softmax function is widely used in machine learning models to normalize the output of a neural network to a probability distribution.
Typically, to compute each output $y_i$, each input $x_i$ has to be read from memory twice: once to compute the exponential $\exp(x_i)$ for the normalizer accumulation, and once to compute the numerator $\exp(x_i)$ of the output $y_i$. In total, the softmax computation performs 2 reads and 1 write per element of the output.
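As a concrete reference, here is a minimal NumPy sketch of the original softmax (the function name `naive_softmax` is just for illustration); the two passes over `x` correspond to the two reads per element described above.

```python
import numpy as np

def naive_softmax(x: np.ndarray) -> np.ndarray:
    # First pass over x: accumulate the normalizer sum_j exp(x_j).
    normalizer = np.sum(np.exp(x))
    # Second pass over x: compute each output y_i = exp(x_i) / normalizer.
    return np.exp(x) / normalizer
```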
Safe Softmax
The original softmax function is numerically unstable when the input vector $x$ contains large positive values, which cause $\exp(x_i)$ to overflow, or when all the values are very negative, which causes the denominator to underflow to zero. To make it numerically stable, we can subtract the maximum value of the input vector from each element of the input vector before applying the softmax function. The safe softmax function is defined as
$$
y_i = \frac{\exp(x_i - \max(x))}{\sum_{j=1}^{n} \exp(x_j - \max(x))}
$$
Subtracting a constant from every element does not change the result, because the factor $\exp(-\max(x))$ cancels between the numerator and the denominator. In this way, the maximum value of $\exp(x_i - \max(x))$ is 1.0, so the numerator never overflows, and the denominator $\sum_{j=1}^{n} \exp(x_j - \max(x))$ is at least 1.0 because the largest element contributes $\exp(0) = 1$, so it never overflows or underflows to zero. Consequently, the safe softmax function is numerically stable.
Typically, the maximum value of the input vector $\max(x)$ has to be computed in an additional pass before applying the safe softmax function. This increases the memory traffic to 3 reads and 1 write per element of the output.
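A minimal NumPy sketch of the safe softmax follows the same structure, with one extra pass to find the maximum.

```python
import numpy as np

def safe_softmax(x: np.ndarray) -> np.ndarray:
    # First pass over x: find the maximum value.
    x_max = np.max(x)
    # Second pass over x: accumulate the normalizer sum_j exp(x_j - max(x)).
    normalizer = np.sum(np.exp(x - x_max))
    # Third pass over x: compute each output y_i = exp(x_i - max(x)) / normalizer.
    return np.exp(x - x_max) / normalizer
```

For example, `naive_softmax(np.array([1000.0, 1000.0]))` returns `nan` values because $\exp(1000)$ overflows, while `safe_softmax` returns `[0.5, 0.5]`.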
Online Safe Softmax
The key idea of the online safe softmax is that the denominator $\sum_{j=1}^{n} \exp(x_j - \max(x))$ and the maximum $\max(x)$ can be computed simultaneously in a single pass over a stream of inputs, so that $\max(x)$ does not have to be computed beforehand, saving one read from memory per element.
Suppose that after processing the $m$-th input from the stream, the online maximum is $\max([x_1, x_2, …, x_m])$ and the online accumulated normalizer is $\sum_{j=1}^{m} \exp(x_j - \max([x_1, x_2, …, x_m]))$. When the $(m+1)$-th input arrives, there are two cases:
If $x_{m+1} \leq \max([x_1, x_2, …, x_m])$, then $\max([x_1, x_2, …, x_{m+1}]) = \max([x_1, x_2, …, x_m])$. The online accumulated normalizer becomes
$$
\sum_{j=1}^{m+1} \exp(x_j - \max([x_1, x_2, …, x_{m+1}])) = \sum_{j=1}^{m} \exp(x_j - \max([x_1, x_2, …, x_m])) + \exp(x_{m+1} - \max([x_1, x_2, …, x_m]))
$$
If $x_{m+1} > \max([x_1, x_2, …, x_m])$, then $\max([x_1, x_2, …, x_{m+1}]) = x_{m+1}$. The online accumulated normalizer becomes
$$
\begin{align}
\sum_{j=1}^{m+1} \exp(x_j - \max([x_1, x_2, …, x_{m+1}]))
&= \frac{\exp(\max([x_1, x_2, …, x_m]))}{\exp(\max([x_1, x_2, …, x_{m+1}]))} \sum_{j=1}^{m} \exp(x_j - \max([x_1, x_2, …, x_m])) + \exp(x_{m+1} - \max([x_1, x_2, …, x_{m+1}])) \\
&= \exp(\max([x_1, x_2, …, x_m]) - \max([x_1, x_2, …, x_{m+1}])) \sum_{j=1}^{m} \exp(x_j - \max([x_1, x_2, …, x_m])) + 1 \\
\end{align}
$$
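The two cases above can be sketched in a few lines of Python; the explicit loop is only for illustrating the streaming update, assuming the whole vector is available as a NumPy array (an optimized kernel would keep the running values in registers).

```python
import numpy as np

def online_safe_softmax(x: np.ndarray) -> np.ndarray:
    # Single pass over x: maintain the running maximum and the running
    # normalizer, rescaling the normalizer whenever the maximum changes.
    running_max = -np.inf
    normalizer = 0.0
    for value in x:
        new_max = max(running_max, value)
        # exp(running_max - new_max) equals 1 when the maximum is unchanged;
        # otherwise it rescales the previously accumulated normalizer.
        normalizer = normalizer * np.exp(running_max - new_max) + np.exp(value - new_max)
        running_max = new_max
    # Second pass over x: compute each output exp(x_i - max(x)) / normalizer.
    return np.exp(x - running_max) / normalizer
```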
In this way, the online safe softmax computation performs 2 reads and 1 write per element of the output. Because softmax is a memory-bound operation, the theoretical peak performance of the online safe softmax is the same as that of the original softmax, and it is $\frac{4}{3} \approx 1.33\times$ faster than the safe softmax.
Online Safe Softmax + TopK
The online safe softmax algorithm can be further fused with downstream memory-bound online algorithms such as TopK, argmax, and argmin. Because the input elements are already read during the online safe softmax pass, the downstream algorithms can reuse them without additional memory accesses, substantially improving performance compared to executing the algorithms without fusion, as sketched below.
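As a hypothetical illustration of such fusion (the function below is not part of the original algorithm description), the same streaming pass can track the argmax alongside the normalizer, yielding the index of the maximum element and its softmax probability without rereading the input.

```python
import numpy as np

def online_softmax_argmax(x: np.ndarray) -> tuple[int, float]:
    # Single fused pass over x: track the running maximum, its index, and
    # the running normalizer together.
    running_max = -np.inf
    argmax_index = -1
    normalizer = 0.0
    for i, value in enumerate(x):
        if value > running_max:
            argmax_index = i
        new_max = max(running_max, value)
        normalizer = normalizer * np.exp(running_max - new_max) + np.exp(value - new_max)
        running_max = new_max
    # The softmax probability of the maximum element is exp(0) / normalizer.
    return argmax_index, 1.0 / normalizer
```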