PyTorch Leaf Tensor
The PyTorch leaf tensor is a concept that is sometimes confusing to users who are not familiar with PyTorch’s automatic differentiation engine.
In this blog post, I would like to quickly discuss the PyTorch leaf tensor concept from the perspective of mathematics without going into too much implementation detail.
Depending on whether a PyTorch tensor requires gradient and whether it is explicitly created by the user, there are four categories of PyTorch tensors. These two properties determine whether the tensor is a leaf tensor and whether its gradient will be populated.
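| Requires Grad | User Created | Is Leaf | Grad Populated |
|---|---|---|---|
| False | True | True | False |
| False | False | True | False |
| True | True | True | True |
| True | False | False | False |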
Here, “Requires Grad” is the `requires_grad` attribute of a `torch.Tensor`, indicating whether it is a constant or a variable; “User Created” is true means that a `torch.Tensor` is not the result of an operation, so the `grad_fn` attribute of the `torch.Tensor` is `None`; “Is Leaf” is true means that a `torch.Tensor` is a leaf node in a `torch.autograd` directed acyclic graph (DAG), which consists only of a root (tensor) node, many leaf (tensor) nodes, and many intermediate (backward function call) nodes; “Grad Populated” is true means that the gradient with respect to a `torch.Tensor` will be saved in the tensor object (for optimization), so that the `grad` attribute of the `torch.Tensor` will not be `None` after a backward pass.
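The four categories can be checked directly by inspecting `requires_grad`, `grad_fn`, `is_leaf`, and `grad` after a backward pass. This is a minimal sketch with hypothetical tensor names; reading `.grad` on the non-leaf tensor may also emit a `UserWarning`.

```python
import torch

# 1. User-created, does not require grad (a constant).
constant_user = torch.ones(3)
# 2. Result of an operation on constants, does not require grad.
constant_derived = constant_user * 2
# 3. User-created, requires grad (a variable such as a trainable parameter).
variable_user = torch.ones(3, requires_grad=True)
# 4. Result of an operation on a variable, so it requires grad too.
variable_derived = variable_user * 2

# Reduce everything to a scalar root node and run the backward pass.
root = (variable_derived + constant_derived).sum()
root.backward()

for name, t in [
    ("constant_user", constant_user),
    ("constant_derived", constant_derived),
    ("variable_user", variable_user),
    ("variable_derived", variable_derived),
]:
    grad_fn_name = type(t.grad_fn).__name__ if t.grad_fn is not None else None
    # Reading .grad of the non-leaf tensor may emit a UserWarning; that is expected.
    print(f"{name}: requires_grad={t.requires_grad}, grad_fn={grad_fn_name}, "
          f"is_leaf={t.is_leaf}, grad_populated={t.grad is not None}")
```

Only `variable_user` ends up with a populated `grad`, matching the single “True” in the last column of the table.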
In addition to the examples from the PyTorch documentation, which are rather confusing, here is a more concrete example illustrating the role of the leaf tensor in optimization.
```
$ python3 leaf_tensor.py
```
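A minimal sketch of the scenario discussed here, assuming only what the surrounding text describes: a user-created `variable_tensor_cpu` with `requires_grad=True`, a copy `variable_tensor_cuda` derived from it, and a scalar `loss`. The constant input and the loss formula are hypothetical. On a machine without a GPU, the sketch swaps the device copy for a dtype conversion, which is also a differentiable operation that returns a new non-leaf tensor.

```python
import torch

use_cuda = torch.cuda.is_available()

# The user-created variable: a leaf tensor that requires grad.
variable_tensor_cpu = torch.ones(3, requires_grad=True)

if use_cuda:
    # Copying to the GPU is an operation, so the copy is NOT a leaf tensor.
    variable_tensor_cuda = variable_tensor_cpu.cuda()
else:
    # Assumption for CPU-only machines: a dtype conversion is also a
    # differentiable operation that returns a new, non-leaf tensor.
    variable_tensor_cuda = variable_tensor_cpu.to(torch.float64)

# Hypothetical constant input on the same device and dtype as the copy.
constant_tensor = torch.full_like(variable_tensor_cuda, 3.0)

loss = (variable_tensor_cuda * constant_tensor).sum()
loss.backward()

print("variable_tensor_cpu.is_leaf: ", variable_tensor_cpu.is_leaf)
print("variable_tensor_cuda.is_leaf:", variable_tensor_cuda.is_leaf)
print("variable_tensor_cpu.grad:    ", variable_tensor_cpu.grad)
# Non-leaf tensor: .grad is None (PyTorch may also emit a UserWarning here).
print("variable_tensor_cuda.grad:   ", variable_tensor_cuda.grad)
```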
In some scenarios, the user would expect the variable tensor `variable_tensor_cuda` to have `grad` after the backward pass so that it can be optimized during neural network training. However, we could see that the `grad` of `variable_tensor_cuda` is `None`, whereas the `variable_tensor_cpu` tensor has `grad`. This means that `variable_tensor_cpu` is actually the variable being optimized. After the optimization step that follows the backward pass, the value of `variable_tensor_cuda` will not be the same as that of `variable_tensor_cpu` until the next forward pass is performed.
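The divergence between the two tensors after an optimization step can be sketched as follows. This assumes a plain `torch.optim.SGD` optimizer with a hypothetical learning rate of 0.1, and, to avoid requiring a GPU, a dtype conversion standing in for the CUDA copy (`variable_tensor_copy` and `fresh_copy` are hypothetical names).

```python
import torch

variable_tensor_cpu = torch.ones(3, requires_grad=True)
# Hypothetical optimizer and learning rate.
optimizer = torch.optim.SGD([variable_tensor_cpu], lr=0.1)

# Forward pass: the converted copy is a non-leaf tensor (stand-in for .cuda()).
variable_tensor_copy = variable_tensor_cpu.to(torch.float64)
loss = (variable_tensor_copy * 3.0).sum()
optimizer.zero_grad()
loss.backward()

# The optimizer updates only the leaf tensor, using its populated .grad:
# 1.0 - 0.1 * 3.0 = 0.7 for every element.
optimizer.step()
print("leaf after step:     ", variable_tensor_cpu)
# The copy was computed before the step, so it still holds the old values.
print("copy after step:     ", variable_tensor_copy)

# Only the next forward pass produces a copy that is in sync with the leaf.
fresh_copy = variable_tensor_cpu.to(torch.float64)
print("copy after next pass:", fresh_copy)
```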
In fact, PyTorch issues a warning when the user tries to access the `.grad` attribute of a non-leaf tensor, which by default has no `grad`.

```
leaf.py:6: UserWarning: The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. Its .grad attribute won't be populated during autograd.backward(). If you indeed want the .grad field to be populated for a non-leaf Tensor, use .retain_grad() on the non-leaf Tensor. If you access the non-leaf Tensor by mistake, make sure you access the leaf Tensor instead. See github.com/pytorch/pytorch/pull/30531 for more informations. (Triggered internally at /opt/pytorch/pytorch/build/aten/src/ATen/core/TensorBody.h:480.)
```
We could also visualize the DAG using the third-party library `torchviz`. The `torchviz` library could be installed using the following command.
```
$ sudo apt update
```
The `torch.autograd` DAG is built dynamically as the Python script is executed, and `torchviz` can visualize the DAG starting from a root tensor, which is the `loss` tensor in our example.
Notice that the DAG visualized using `torchviz` does not display leaf nodes that do not require grad. The blue box in the DAG diagram, although it has no tensor name, is the leaf tensor `variable_tensor_cpu` in our program.
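Without `torchviz`, the same DAG can be inspected by walking `grad_fn.next_functions` from the root tensor. In this sketch (tensor and function names are hypothetical), the leaf tensor that requires grad appears as an `AccumulateGrad` node whose `variable` attribute refers back to the leaf, while the constant that does not require grad shows up only as `None`, consistent with it being absent from the `torchviz` diagram. The walk treats the graph as a tree, which is fine for this small example but would visit shared nodes more than once in general.

```python
import torch

variable_tensor = torch.ones(3, requires_grad=True)  # leaf, requires grad
constant_tensor = torch.full((3,), 2.0)              # leaf, no grad

loss = (variable_tensor * constant_tensor).sum()


def collect_leaves(fn, depth=0, leaves=None):
    """Print the backward nodes reachable from fn and return the leaf tensors."""
    if leaves is None:
        leaves = []
    if fn is None:
        # Inputs that do not require grad have no backward node at all.
        return leaves
    print("  " * depth + type(fn).__name__)
    if hasattr(fn, "variable"):
        # AccumulateGrad nodes hold the leaf tensors whose .grad gets populated.
        leaves.append(fn.variable)
    for next_fn, _ in fn.next_functions:
        collect_leaves(next_fn, depth + 1, leaves)
    return leaves


leaves = collect_leaves(loss.grad_fn)
```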
Conventionally, only leaf tensors, usually the model parameters to be trained, deserve grad. Non-leaf tensors, such as the intermediate activation tensors, do not. Why would we need to keep a grad for the activation tensors? Even if we kept the grad in an activation tensor and applied it to the activation tensor values during optimization, those values would be overwritten in the next forward pass. So populating grad for non-leaf tensors is usually a waste of memory and computation.
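The same convention holds in a typical model. In this minimal sketch (the two-layer network, its sizes, and the loss are arbitrary choices), every parameter is a leaf tensor that requires grad and ends up with a populated `grad` after `backward()`, whereas the intermediate activation `hidden` does not.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))
inputs = torch.randn(2, 4)  # a constant leaf: input data does not require grad

hidden = model[1](model[0](inputs))  # intermediate activation (non-leaf)
output = model[2](hidden)
loss = output.pow(2).mean()
loss.backward()

# Parameters are leaf tensors that require grad, so their .grad is populated.
for name, param in model.named_parameters():
    print(name, "is_leaf:", param.is_leaf, "grad populated:", param.grad is not None)

# The activation is a non-leaf tensor, so no gradient is stored for it
# (reading .grad here may emit a UserWarning).
print("hidden", "is_leaf:", hidden.is_leaf, "grad populated:", hidden.grad is not None)
```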
However, in some “rare” use cases, the user needs the grad of non-leaf tensors, and PyTorch provides the API `torch.Tensor.retain_grad()` for that. But usually it does not make sense and is an indication of a problematic implementation.
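For completeness, a minimal sketch of `retain_grad()` on a hypothetical intermediate tensor; the commented values follow from the chain rule.

```python
import torch

variable_tensor = torch.ones(3, requires_grad=True)
intermediate = variable_tensor * 2  # non-leaf tensor with value 2.0
intermediate.retain_grad()          # ask autograd to keep its gradient

loss = (intermediate ** 2).sum()
loss.backward()

print(intermediate.is_leaf)  # still a non-leaf tensor
print(intermediate.grad)     # d loss / d intermediate = 2 * 2.0 = 4.0 per element
print(variable_tensor.grad)  # chain rule: 4.0 * 2 = 8.0 per element
```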