PyTorch Automatic Differentiation
Introduction
PyTorch automatic differentiation is the key to the success of training neural networks using PyTorch. Automatic differentiation usually has two modes, forward mode and reverse mode. For a function $f: \mathbb{R}^n \rightarrow \mathbb{R}^m$, forward mode is more suitable for the scenario where $m \gg n$, and reverse mode is more suitable for the scenario where $n \gg m$. In deep learning, $n$ is usually the number of parameters and $m$ is the number of outputs during training, and most likely $m = 1$. Therefore, in the past few years, deep learning frameworks such as PyTorch and TensorFlow have primarily focused on developing the reverse mode of automatic differentiation.
Recently, as the implementation of reverse-mode automatic differentiation has matured and there is increasing demand for forward-mode automatic differentiation in some deep learning research, PyTorch has slowly started adding support for forward-mode automatic differentiation.
In this blog post, I would like to show how to use PyTorch to compute gradients, specifically the Jacobian, using both forward-mode and reverse-mode automatic differentiation. More details about the mathematical foundations of automatic differentiation can be found in my article “Automatic Differentiation”.
PyTorch Automatic Differentiation
PyTorch 1.11 has started to add support for forward-mode automatic differentiation to torch.autograd. In addition, an official PyTorch library, functorch, has recently been released to provide JAX-like composable function transforms for PyTorch. This library was developed to overcome some limitations of native PyTorch, including some automatic differentiation deficiencies.
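To make the forward-mode API concrete before the benchmarks, here is a minimal sketch, not taken from this post's source code, that computes a single Jacobian-vector product with dual tensors from torch.autograd.forward_ad (the function and shapes are made up for illustration):

```python
import torch
import torch.autograd.forward_ad as fwAD

# An arbitrary element-wise function used only for illustration.
def f(x):
    return x ** 2 + x.sin()

primal = torch.randn(5)   # the point at which we differentiate
tangent = torch.randn(5)  # the direction of the Jacobian-vector product

with fwAD.dual_level():
    # A dual tensor carries the primal value and its tangent together.
    dual_input = fwAD.make_dual(primal, tangent)
    dual_output = f(dual_input)
    # unpack_dual returns the primal output and the tangent (the JVP).
    jvp = fwAD.unpack_dual(dual_output).tangent

print(jvp.shape)  # torch.Size([5])
```

Each forward pass of this kind yields one Jacobian-vector product, i.e., one directional derivative; assembling a full Jacobian this way requires one pass per input dimension, which is exactly why forward mode favors functions with few inputs.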
In the examples below, I would like to show how to compute the Jacobian using six different PyTorch interfaces. The test environment uses an Intel Core i9-9900K CPU and an NVIDIA RTX 2080 Ti GPU. All the source code can be downloaded from my GitHub.
Jacobian for Inputs
Let’s compute the Jacobian for a linear function and measure the performance of forward-mode and reverse-mode automatic differentiation. No batch size is considered. Notice that in this example we are treating the weight and bias as constants and the input as the variable, i.e., we are computing the Jacobian with respect to the inputs.
```python
# https://pytorch.org/functorch/stable/notebooks/jacobians_hessians.html
```
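As a rough illustration of what such a listing computes, here is a minimal sketch (the names, shapes, and linear function are assumptions, not this post's exact benchmark code) that evaluates the Jacobian with respect to the input using both torch.autograd and functorch:

```python
# A minimal sketch (assumed names and shapes): Jacobian of a linear function
# y = x @ W^T + b with respect to the input x.
import torch
from torch.autograd.functional import jacobian
from functorch import jacrev, jacfwd

n, m = 16, 4                 # assumed input and output dimensions
weight = torch.randn(m, n)   # treated as a constant
bias = torch.randn(m)        # treated as a constant
x = torch.randn(n)           # the variable we differentiate with respect to

def f(x):
    return x @ weight.t() + bias

jac_autograd = jacobian(f, x)   # torch.autograd, reverse mode, shape (m, n)
jac_rev = jacrev(f)(x)          # functorch, reverse mode, shape (m, n)
jac_fwd = jacfwd(f)(x)          # functorch, forward mode, shape (m, n)

# For a linear map, every Jacobian with respect to x equals the weight matrix.
assert torch.allclose(jac_autograd, weight)
assert torch.allclose(jac_rev, weight)
assert torch.allclose(jac_fwd, weight)
```

Because the function is linear, every Jacobian equals the weight matrix, which makes the results easy to verify.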
The performance is as expected. When $n \gg m$, reverse mode is much faster, whereas when $m \gg n$, forward mode is much faster.
```
$ python autograd.py
```
Jacobian for Weights
This time we are treating the input and bias as constants and the weight as the variable, i.e., we are computing the Jacobian with respect to the weights. Again, no batch size is considered. Computing the Jacobian for the weights is slightly more brain-twisting because the Jacobian is a 3D tensor instead of a 2D matrix.
```python
# https://pytorch.org/functorch/stable/notebooks/jacobians_hessians.html
```
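A minimal sketch of the weight case (again with assumed names and shapes, not this post's exact benchmark code) shows where the 3D Jacobian comes from:

```python
# A minimal sketch (assumed names and shapes): Jacobian of y = x @ W^T + b
# with respect to the weight W.
import torch
from functorch import jacrev, jacfwd

n, m = 16, 4
x = torch.randn(n)           # treated as a constant
bias = torch.randn(m)        # treated as a constant
weight = torch.randn(m, n)   # the variable we differentiate with respect to

def f(weight):
    return x @ weight.t() + bias

# The output has shape (m,) and the weight has shape (m, n), so the Jacobian
# is a 3D tensor of shape (m, m, n).
jac_rev = jacrev(f)(weight)
jac_fwd = jacfwd(f)(weight)
print(jac_rev.shape)  # torch.Size([4, 4, 16])
assert torch.allclose(jac_rev, jac_fwd)
```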
The performance is still as expected. Because the weight has $m \times n$ entries while the output has only $m$, the number of variables is always at least as large as the number of outputs, so reverse mode is much faster than forward mode.
```
$ python autograd_weights.py
```
The Jacobian for the bias can be computed similarly.
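For the linear function above, the Jacobian with respect to the bias is simply the $m \times m$ identity matrix, which a short sketch (again with assumed names, not this post's code) confirms:

```python
# A minimal sketch (assumed names): the Jacobian of y = x @ W^T + b with
# respect to the bias b is the m x m identity matrix.
import torch
from functorch import jacrev

n, m = 16, 4
x = torch.randn(n)
weight = torch.randn(m, n)
bias = torch.randn(m)

def f(bias):
    return x @ weight.t() + bias

jac_bias = jacrev(f)(bias)   # shape (m, m)
assert torch.allclose(jac_bias, torch.eye(m))
```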
Jacobian with Batch
Computing the Jacobian with a batch could be even more brain-twisting. However, in the worst scenario, we could iterate over each sample in the batch, compute the per-sample Jacobians (iteratively or in parallel), and stack them together, as sketched below.
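Here is a minimal sketch of that naive loop-and-stack approach, using torch.autograd.functional.jacobian with assumed names and shapes (not this post's exact code):

```python
# A minimal sketch (assumed names and shapes): compute one Jacobian per sample
# and stack the results along the batch dimension.
import torch
from torch.autograd.functional import jacobian

batch_size, n, m = 8, 16, 4
weight = torch.randn(m, n)
bias = torch.randn(m)
x_batch = torch.randn(batch_size, n)

def f(x):
    return x @ weight.t() + bias

# Shape (batch_size, m, n); simple but slow because of the Python loop.
jacobians = torch.stack([jacobian(f, x) for x in x_batch])
print(jacobians.shape)  # torch.Size([8, 4, 16])
```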
The following example shows how to compute the Jacobian with a batch using the PyTorch interfaces we discussed in the previous sections.
```python
# https://pytorch.org/functorch/stable/notebooks/jacobians_hessians.html
```
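A hedged sketch of the vectorized alternative with functorch (names and shapes are assumptions): vmap maps jacrev(f) or jacfwd(f) over the batch dimension so that no Python loop is needed.

```python
# A minimal sketch (assumed names and shapes): per-sample Jacobians for a
# whole batch computed with functorch's vmap, without a Python loop.
import torch
from functorch import jacrev, jacfwd, vmap

batch_size, n, m = 8, 16, 4
weight = torch.randn(m, n)
bias = torch.randn(m)
x_batch = torch.randn(batch_size, n)

def f(x):
    return x @ weight.t() + bias

batched_jac_rev = vmap(jacrev(f))(x_batch)   # shape (8, 4, 16)
batched_jac_fwd = vmap(jacfwd(f))(x_batch)   # shape (8, 4, 16)
assert torch.allclose(batched_jac_rev, batched_jac_fwd, atol=1e-6)
```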
Again, the performance is still as expected.
```
$ python autograd_batch.py
```
Conclusions
The functorch interface is much cleaner than the torch.autograd interface. In some use cases, it can also do what torch.autograd is restricted from doing. Again, I should emphasize that computing the full Jacobian is expensive. If we just want to compute gradients in forward mode or reverse mode, we do not have to compute the Jacobian explicitly.