Neural Network Mixed Precision Training
Introduction
Mixed precision has been widely used as a black box in modern deep learning frameworks to accelerate neural network training and inference.
In this blog post, I would like to quickly discuss the theory behind mixed precision, focusing primarily on neural network training.
Neural Network Mixed Precision Training
The motivation for low bit-width mixed precision training or inference is straightforward. If we could represent all the values encountered during neural network training or inference using low bit-width formats, usually narrower than 32 bits, then, given a fixed memory bandwidth, more values could be processed per unit of time, usually on more specialized processing units, and neural network training or inference would therefore be accelerated.
The drawback of low bit-width formats is that they usually have a smaller dynamic range and lower precision compared to 32-bit formats, such as FP32, so value overflow and underflow are more likely to happen. Therefore, directly using low bit-width formats everywhere throughout the neural network is not feasible if many values fall outside the representable range of the low bit-width format or cannot be represented by it with sufficient precision.
In fact, the neural network forward pass (inference) is usually not very sensitive to low bit-width formats. The values involved in this process are mostly weight tensors and activation tensors, so just making the weight tensors and activation tensors low bit-width significantly improves performance while maintaining accuracy. Some tensors, such as the batch normalization statistics, usually have to remain in high bit-width in order to achieve the same accuracy. But since they are usually fused into other operations or do not dominate the entire process, it’s OK to process those tensors in high bit-width. For many neural networks, even without post-training quantization, FP16 is able to achieve the same accuracy as FP32. Under some special settings, such as using FP8 only for the weight tensor and the input activation tensor of certain operations, such as GEMM, FP8 is also able to achieve the same accuracy as FP32 for some neural networks.
The neural network backward pass, however, is usually sensitive to low bit-width formats. In particular, some gradient values can be very small but still useful, and the values in the gradient tensors often fall outside the representable range of low bit-width formats, such as FP16. In almost all cases, those values underflow and are therefore flushed to 0. Suppose that, for the gradient tensor with respect to the activation $\frac{\partial L}{\partial \mathbf{x}}$ of some layer, many elements, such as $\frac{\partial L}{\partial \mathbf{x}_i}$, are flushed to 0 due to the low bit-width. Then the gradients with respect to all the weights that contribute to those activations, such as $\frac{\partial L}{\partial \mathbf{w}_j} = \frac{\partial L}{\partial \mathbf{x}_i} \frac{\partial \mathbf{x}_i}{\partial \mathbf{w}_j}$, will also be 0 due to the chain rule. Therefore, representing the small gradients with respect to activations in low bit-width formats is critical for training a neural network. But the question is, how can we represent those small gradients using low bit-width formats?
Suppose $x$ is a small value that is outside the representable range of a low bit-width format. Given some scaling factor $s$, we have
$$
x = \left(s x\right) \frac{1}{s}
$$
where $sx$ can be represented by the low bit-width format.
We could therefore represent those underflowing gradient tensors using a scaled tensor and a scaling factor. Suppose that, for the gradient tensor with respect to the activation $\frac{\partial L}{\partial \mathbf{x}}$ of some layer, many elements, such as $\frac{\partial L}{\partial \mathbf{x}_i}$, are flushed to 0 due to the low bit-width. Instead of using the actual underflowing gradient tensor, we use the scaled and then unscaled gradient $\frac{1}{s} \left(s\frac{\partial L}{\partial \mathbf{x}_i}\right) \frac{\partial \mathbf{x}_i}{\partial \mathbf{w}_j}$ for updating the weight. Notice that $\frac{1}{s} \left(s\frac{\partial L}{\partial \mathbf{x}_i}\right) \frac{\partial \mathbf{x}_i}{\partial \mathbf{w}_j}$ can still underflow, but this only means that the weight update is 0; back propagation is no longer blocked by underflow in the gradient tensor with respect to the activation.
The magnitude of the weight update can be very small, and when it is added to a weight in low bit-width format, it can have no effect on the weight. Therefore, a master copy of the weights should be kept in a high bit-width format, such as FP32, and the weight update should be performed in that format after the weight update has been computed efficiently in low bit-width format.
It is a property of floating point numbers that when a small value is added to a much larger value, the small value can have no effect on the result. This is also why neural network training is usually not sensitive to the precision of the weight tensors and activation tensors. To see an example, consider the following C++ code in which a small FP32 value is added to a large FP32 value.
$ g++ add_small_value.cpp -o add_small_value
The final neural network from mixed precision training will usually be converted to a high bit-width format, such as FP32, and saved. This introduces no loss in model accuracy. If the neural network is deployed in a low bit-width format, such as FP16, and that format is the one used in mixed precision training, then it should work right out of the box.
If there is a discrepancy between the precision in which the network is trained and the precision in which it is deployed, there can be issues in some scenarios. For example, if the neural network is trained in FP32 and deployed in FP16, there can be some accuracy loss, although I would expect it not to be significant. However, if the neural network is trained in FP32/FP16 and deployed in FP8, the accuracy loss can be much larger. To remedy the accuracy loss, FP8 post-training quantization might work.
Automatic Mixed Precision Training
Using scaling factor(s) for the backward pass involves a few design choices during training:
- The scaling factor(s) can be the same for every layer or different for each layer in the neural network.
The former is a special case of the latter. Implementing the former is very straightforward. The user would just have to scale the loss $L$ by $s$ after the forward pass and unscale the gradients with respect to the weights before the weight update. Implementing the latter requires keeping a scaling factor for each layer specifically, which should not be difficult either.
- The scaling factor(s) can be constant or dynamic during training.
The former is a special case of the latter. Picking a scaling factor that is constant throughout the training may require some human trial and error. Sometimes, the gradient tensors with respect to activations do not underflow at the beginning of training. However, as the training progresses, they start to underflow. This requires the scaling factors to be adjusted dynamically during training.
- The scaling factor(s) can be applied to the gradient tensors only, or also to other tensors such as weight tensors and activation tensors.
The scaling factors are applied to the gradient tensors because these often cannot be represented by low bit-width formats. In inference, those scaling factors are gone. However, for some bit-width formats, such as FP8, the weight tensors and activation tensors can also fall outside the representable range, thus requiring scaling factors. In inference, those scaling factors remain for scaling and unscaling values between FP32/FP16 and FP8. The automatic mixed precision training then also becomes quantization aware training.
As of now, the existing automatic mixed precision frameworks for FP32/FP16 mixed precision training, such as NVIDIA APEX and PyTorch Automatic Mixed Precision (an NVIDIA APEX variant/subset), implement the algorithm for the scenario where the same scaling factor is used for every layer, the scaling factor can be dynamic during training, and the scaling factor is applied only to the gradient tensors. It works well for most neural networks.
Concretely, the FP32/FP16 automatic mixed precision algorithm uses the following procedure at a high level:
- Maintain a primary copy of weights in FP32.
- Initialize the scaling factor $S$ to a large value.
- For each iteration:
a. Make an FP16 copy of the weights.
b. Forward propagation (FP16 weights and activations).
c. Multiply the resulting loss with the scaling factor $S$.
d. Backward propagation (FP16 weights, activations, and their gradients).
e. If there is an Inf or NaN in weight gradients:
i. Reduce $S$.
ii. Skip the weight update and move to the next iteration.
f. Multiply the weight gradient with $1/S$.
g. Complete the weight update (including gradient clipping, etc.).
h. If there hasn’t been an Inf or NaN in the last $N$ iterations, increase $S$.
More aggressive FP32/FP16/FP8 mixed precision training is still under more thorough study by industry and the research community. Some experiments have shown that per-tensor scaling factors for some layers might be required to achieve full precision accuracy for some neural networks. This would suggest that, for some neural networks, the FP32/FP16/FP8 mixed precision algorithm should also be forward pass (inference) quantization aware, which makes it more complicated than the FP32/FP16 automatic mixed precision algorithm.
Mixed Precision Training VS Quantization Aware Training
Mixed precision training and quantization aware training share a few similarities. As I mentioned above, aggressive mixed precision training using extremely low bit-width formats, such as FP8, might also need to be quantization aware for some neural networks, making it less distinguishable from quantization aware training. So they are not mutually exclusive and can be used simultaneously.
The major conceptual difference between mixed precision training and quantization aware training is that mixed precision training emphasizes training performance whereas quantization aware training does not. Mixed precision training uses a real low bit-width format that usually has hardware support for accelerating both the forward pass and the backward pass in neural network training. Quantization aware training, however, does not even require the low bit-width format and its corresponding hardware support. If the low bit-width format and its hardware support are available, the quantized forward pass can be conducted in the low bit-width format with hardware acceleration. If they are not available, the quantized forward pass can still be simulated without hardware acceleration. Simulation is how quantization aware training is implemented in most frameworks, so that the framework does not have to depend on special hardware. Because quantization aware training does not emphasize training performance, the backward pass does not have to be run in low bit-width either.
References
Neural Network Mixed Precision Training
https://leimao.github.io/blog/Neural-Network-Mixed-Precision-Training/