PyTorch Automatic Differentiation

Introduction

PyTorch automatic differentiation is the key to the success of training neural networks using PyTorch. Automatic differentiation usually has two modes: forward mode and reverse mode. For a function $f: \mathbb{R}^n \rightarrow \mathbb{R}^m$, forward mode is more suitable for the scenario where $m \gg n$ and reverse mode is more suitable for the scenario where $n \gg m$. In deep learning, $n$ is usually the number of parameters, $m$ is the number of outputs during training, and most likely $m = 1$. Therefore, in the past few years, deep learning frameworks, such as PyTorch and TensorFlow, have primarily focused on developing automatic differentiation reverse mode.
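Concretely, one forward-mode pass computes a Jacobian-vector product and one reverse-mode pass computes a vector-Jacobian product, so materializing the full Jacobian $J \in \mathbb{R}^{m \times n}$ takes roughly $n$ forward passes or $m$ backward passes:

$$Jv = \frac{\partial f}{\partial x} v \in \mathbb{R}^{m} \quad \text{(forward mode, one column of } J \text{ per pass)}, \qquad v^{\top} J = v^{\top} \frac{\partial f}{\partial x} \in \mathbb{R}^{n} \quad \text{(reverse mode, one row of } J \text{ per pass)}.$$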

Recently, as the implementation of automatic differentiation reverse mode has matured and the demand for automatic differentiation forward mode in deep learning research has grown, PyTorch has slowly started adding support for automatic differentiation forward mode.

In this blog post, I would like to show how to use PyTorch to compute gradients, specifically the Jacobian, using automatic differentiation forward mode and reverse mode. More details about the mathematical foundations of automatic differentiation can be found in my article “Automatic Differentiation”.

PyTorch Automatic Differentiation

PyTorch 1.11 started to add support for automatic differentiation forward mode to torch.autograd. In addition, an official PyTorch library, functorch, has recently been released to allow JAX-like composable function transforms for PyTorch. This library was developed to overcome some limitations of native PyTorch, including some automatic differentiation deficiencies.
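To give a taste of what “composable function transforms” means, here is a minimal sketch (assuming functorch is installed alongside PyTorch 1.11): composing vmap and grad computes per-sample gradients without writing an explicit loop.

import torch
from functorch import grad, vmap


# A scalar-valued function of a single sample.
def f(x):
    return (x**2).sum()


batched_x = torch.randn(4, 3)
# vmap(grad(f)) maps the gradient computation over the batch dimension.
per_sample_grads = vmap(grad(f))(batched_x)
assert torch.allclose(per_sample_grads, 2 * batched_x)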

In the examples below, I would like to show how to compute the Jacobian using six different methods, covering both the torch.autograd and functorch interfaces. The test environment uses an Intel Core i9-9900K CPU and an NVIDIA RTX 2080 Ti GPU. All the source code can be downloaded from my GitHub.

Jacobian for Inputs

Let’s compute the Jacobian for a linear function and measure the performance of automatic differentiation forward and reverse modes. No batch size is considered. Notice that in this example we are treating weight and bias as constants and input as the variable, i.e., we are computing the Jacobian with respect to the inputs.
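Because the function is linear, $y = Wx + b$, the Jacobian with respect to the input is simply the weight matrix $W$ itself, which gives an easy closed-form reference to check against. A minimal self-contained sketch of this check (the full script follows below):

import torch
import torch.nn.functional as F
from functorch import jacrev

N, M = 16, 32
weight = torch.randn(M, N)
bias = torch.randn(M)
x = torch.randn(N)

# For y = W x + b, dy/dx = W, so the Jacobian equals the weight matrix.
(jac, ) = jacrev(lambda w, b, x: F.linear(x, w, b), argnums=(2, ))(weight, bias, x)
assert torch.allclose(jac, weight)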

autograd.py
# https://pytorch.org/functorch/stable/notebooks/jacobians_hessians.html
# Requires PyTorch >= 1.11

from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.benchmark import Timer
import torch.autograd.forward_ad as fwAD
from functorch import vmap, vjp, jvp, jacrev, jacfwd


# Linear layer mapping from R^N to R^M
def predict(weight, bias, x):
    return F.linear(x, weight, bias)


N = 16
M = 32
weight = torch.randn(M, N)
bias = torch.randn(M)
x = torch.randn(N) # feature vector

# Reverse Mode

primal = x.clone().requires_grad_()
cotangents = torch.eye(M)

# Method 1
# Use PyTorch autograd reverse mode + `for` loop.
rev_jacobian = []
# 1 forward pass.
output = predict(weight, bias, primal)
# M backward passes.
for cotangent in cotangents:
    # Compute the vjp, where v = cotangent.
    (jacobian_row, ) = torch.autograd.grad(outputs=(output, ),
                                           inputs=(primal, ),
                                           grad_outputs=(cotangent, ),
                                           retain_graph=True)
    rev_jacobian.append(jacobian_row)
jacobian = torch.stack(rev_jacobian)

# Run a sanity check for the Jacobian.
primal = x.clone().requires_grad_()
output = predict(weight, bias, primal)
# This will not work:
# output.backward()
# PyTorch backward assumes the output is a scalar unless an explicit gradient is provided.
external_grad = torch.ones_like(output)
# This is equivalent to
# output.sum().backward()
output.backward(gradient=external_grad)
grad = primal.grad
assert torch.allclose(jacobian.sum(dim=0), grad)

# Set the jacobian from method 1 as the reference.

# Method 2
# Using functorch vjp + vmap.
_, vjp_fn = vjp(partial(predict, weight, bias), primal)
# In PyTorch autograd reverse mode,
# there is no need to compute the full Jacobian in order to compute the grad.
assert torch.allclose(vjp_fn(external_grad)[0], grad)
# A vectorized implementation for computing Jacobian using vjp.
(rev_jacobian, ) = vmap(vjp_fn)(cotangents)
assert torch.allclose(rev_jacobian, jacobian)

# Method 3
# Use functorch jacrev.
# A vectorized implementation for computing Jacobian using vjp.
(rev_jacobian, ) = jacrev(predict, argnums=(2, ))(weight, bias, primal)
assert torch.allclose(rev_jacobian, jacobian)

# Forward Mode

primal = x.clone().requires_grad_()
tangents = torch.eye(N)

# Method 1
# Use PyTorch autograd forward mode + `for` loop.
fwd_jacobian = []
with fwAD.dual_level():
    # N forward passes.
    for tangent in tangents:
        dual_input = fwAD.make_dual(primal, tangent)
        # Tensors that do not have an associated tangent are automatically
        # considered to have a zero-filled tangent of the same shape.
        dual_output = predict(weight, bias, dual_input)
        # Unpacking the dual returns a namedtuple with ``primal`` and ``tangent``
        # as attributes.
        jacobian_column = fwAD.unpack_dual(dual_output).tangent
        fwd_jacobian.append(jacobian_column)
fwd_jacobian = torch.stack(fwd_jacobian).T
assert torch.allclose(fwd_jacobian, jacobian)

# Method 2
# Use functorch jvp + `for` loop.
fwd_jacobian = []
# No functorch vmap for jvp.
for tangent in tangents:
    _, jacobian_column = jvp(func=partial(predict, weight, bias),
                             primals=(primal, ),
                             tangents=(tangent, ))
    fwd_jacobian.append(jacobian_column)
fwd_jacobian = torch.stack(fwd_jacobian).T
assert torch.allclose(fwd_jacobian, jacobian)

# Method 3
# Use functorch jacfwd.
(fwd_jacobian, ) = jacfwd(predict, argnums=(2, ))(weight, bias, primal)
assert torch.allclose(fwd_jacobian, jacobian)

# Measure Performance

cpu = torch.device("cpu")
cuda = torch.device("cuda:0")
for device in [cuda, cpu]:
    for N, M in [(16, 10240), (10240, 16)]:
        print(f"N: {N}, M: {M}, Device: {device}")
        weight = torch.randn(M, N).to(device)
        bias = torch.randn(M).to(device)
        x = torch.randn(N).to(device)

        using_fwd = Timer(
            stmt="jacfwd(predict, argnums=(2,))(weight, bias, x)",
            globals=globals())
        using_bwd = Timer(
            stmt="jacrev(predict, argnums=(2,))(weight, bias, x)",
            globals=globals())

        jacfwd_timing = using_fwd.timeit(100)
        jacrev_timing = using_bwd.timeit(100)

        print(f"Forward mode jacfwd time: {jacfwd_timing.mean * 1000:.5f} ms")
        print(f"Reverse mode jacrev time: {jacrev_timing.mean * 1000:.5f} ms")

The performance is as expected. When $n \gg m$, reverse mode is much faster, whereas when $m \gg n$, forward mode is much faster.

$ python autograd.py 
N: 16, M: 10240, Device: cuda:0
Forward mode jacfwd time: 0.48102 ms
Reverse mode jacrev time: 2.29559 ms
N: 10240, M: 16, Device: cuda:0
Forward mode jacfwd time: 3.29483 ms
Reverse mode jacrev time: 0.41302 ms
N: 16, M: 10240, Device: cpu
Forward mode jacfwd time: 0.61181 ms
Reverse mode jacrev time: 135.02476 ms
N: 10240, M: 16, Device: cpu
Forward mode jacfwd time: 139.07241 ms
Reverse mode jacrev time: 0.40635 ms

Jacobian for Weights

This time we are treating input and bias as constants and weight as the variable, i.e., we are computing the Jacobian with respect to the weights. Again, no batch size is considered. Computing the Jacobian for weights is slightly more brain-twisting, as the Jacobian is a 3D tensor instead of a 2D matrix.
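Concretely, since $y_i = \sum_{l} W_{il} x_l + b_i$, each Jacobian entry is $\partial y_i / \partial W_{kl} = \delta_{ik} x_l$, so the Jacobian has shape $[M, M, N]$. A minimal self-contained sketch that uses this closed form as a reference check (the full script follows below):

import torch
import torch.nn.functional as F
from functorch import jacrev

N, M = 16, 32
weight = torch.randn(M, N)
bias = torch.randn(M)
x = torch.randn(N)

# Jacobian of y = W x + b with respect to W: d y_i / d W_{kl} = delta_{ik} * x_l.
(jac, ) = jacrev(lambda w, b, x: F.linear(x, w, b), argnums=(0, ))(weight, bias, x)
expected = torch.einsum("ik,l->ikl", torch.eye(M), x)
assert jac.shape == torch.Size([M, M, N])
assert torch.allclose(jac, expected)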

autograd_weights.py
# https://pytorch.org/functorch/stable/notebooks/jacobians_hessians.html
# Requires PyTorch >= 1.11

from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.benchmark import Timer
import torch.autograd.forward_ad as fwAD
from functorch import vmap, vjp, jvp, jacrev, jacfwd


# Linear layer mapping from R^N to R^M
def predict(weight, bias, x):
    return F.linear(x, weight, bias)


N = 16
M = 32
weight = torch.randn(M, N)
bias = torch.randn(M)
x = torch.randn(N) # feature vector

# Reverse Mode

primal = weight.clone().requires_grad_()
cotangents = torch.eye(M)

# Method 1
# Use PyTorch autograd reverse mode + `for` loop.
rev_jacobian = []
# 1 forward pass.
output = predict(primal, bias, x)
# M backward passes.
for cotangent in cotangents:
    # Compute the vjp, where v = cotangent.
    (jacobian_row, ) = torch.autograd.grad(outputs=(output, ),
                                           inputs=(primal, ),
                                           grad_outputs=(cotangent, ),
                                           retain_graph=True)
    rev_jacobian.append(jacobian_row)
jacobian = torch.stack(rev_jacobian)
# This is a "3D" jacobian since weight is 2D.
assert jacobian.shape == torch.Size([M, M, N])

# Run a sanity check for the Jacobian.
primal = weight.clone().requires_grad_()
output = predict(primal, bias, x)
# This will not work:
# output.backward()
# PyTorch backward assumes the output is a scalar unless an explicit gradient is provided.
external_grad = torch.ones_like(output)
# This is equivalent to
# output.sum().backward()
output.backward(gradient=external_grad)
grad = primal.grad
assert torch.allclose(jacobian.sum(dim=0), grad)

# Set the jacobian from method 1 as the reference.

# Method 2
# Using functorch vjp + vmap.
_, vjp_fn = vjp(partial(predict, bias=bias, x=x), primal)
# In PyTorch autograd reverse mode,
# there is no need to compute the full Jacobian in order to compute the grad.
assert torch.allclose(vjp_fn(external_grad)[0], grad)
# A vectorized implementation for computing Jacobian using vjp.
(rev_jacobian, ) = vmap(vjp_fn)(cotangents)
assert torch.allclose(rev_jacobian, jacobian)

# Method 3
# Use functorch jacrev.
# A vectorized implementation for computing Jacobian using vjp.
(rev_jacobian, ) = jacrev(predict, argnums=(0, ))(primal, bias, x)
assert torch.allclose(rev_jacobian, jacobian)

# Forward Mode

primal = weight.clone().requires_grad_()
# tangents = torch.eye(N)

# Method 1
# Use PyTorch autograd forward mode + `for` loop.
fwd_jacobian = []
with fwAD.dual_level():
    # M * N forward passes.
    for i in range(M):
        fwd_jacobian_columns = []
        for j in range(N):
            tangent = torch.zeros_like(primal)
            # print(tangent.shape)
            tangent[i, j] = 1
            dual_input = fwAD.make_dual(primal, tangent)
            # Tensors that do not have an associated tangent are automatically
            # considered to have a zero-filled tangent of the same shape.
            dual_output = predict(dual_input, bias, x)
            # Unpacking the dual returns a namedtuple with ``primal`` and ``tangent``
            # as attributes.
            jacobian_column = fwAD.unpack_dual(dual_output).tangent
            fwd_jacobian_columns.append(jacobian_column)
        fwd_jacobian.append(torch.stack(fwd_jacobian_columns).T)
fwd_jacobian = torch.stack(fwd_jacobian)
assert torch.allclose(fwd_jacobian, jacobian)

# Method 2
# Use functorch jvp + `for` loop.
fwd_jacobian = []
# No functorch vmap for jvp.
for i in range(M):
    fwd_jacobian_columns = []
    for j in range(N):
        # One-hot tangent selecting weight entry (i, j).
        tangent = torch.zeros_like(primal)
        tangent[i, j] = 1
        _, jacobian_column = jvp(func=partial(predict, bias=bias, x=x),
                                 primals=(primal, ),
                                 tangents=(tangent, ))
        fwd_jacobian_columns.append(jacobian_column)
    fwd_jacobian.append(torch.stack(fwd_jacobian_columns).T)
fwd_jacobian = torch.stack(fwd_jacobian)
assert torch.allclose(fwd_jacobian, jacobian)

# Method 3
# Use functorch jacfwd.
(fwd_jacobian, ) = jacfwd(predict, argnums=(0, ))(primal, bias, x)
assert torch.allclose(fwd_jacobian, jacobian)

# Measure Performance

cpu = torch.device("cpu")
cuda = torch.device("cuda:0")
for device in [cuda, cpu]:
    for N, M in [(16, 1024), (1024, 16)]:
        print(f"M x N: {M * N}, M: {M}, Device: {device}")
        weight = torch.randn(M, N).to(device)
        bias = torch.randn(M).to(device)
        x = torch.randn(N).to(device)

        using_fwd = Timer(
            stmt="jacfwd(predict, argnums=(0,))(weight, bias, x)",
            globals=globals())
        using_bwd = Timer(
            stmt="jacrev(predict, argnums=(0,))(weight, bias, x)",
            globals=globals())

        jacfwd_timing = using_fwd.timeit(100)
        jacrev_timing = using_bwd.timeit(100)

        print(f"Forward mode jacfwd time: {jacfwd_timing.mean * 1000:.5f} ms")
        print(f"Reverse mode jacrev time: {jacrev_timing.mean * 1000:.5f} ms")

The performance is still as expected. Because the weight has $M \times N$ elements while the output has only $M$ elements, the effective input dimension always dominates the output dimension, so reverse mode is much faster than forward mode.

$ python autograd_weights.py 
M x N: 16384, M: 1024, Device: cuda:0
Forward mode jacfwd time: 5.61630 ms
Reverse mode jacrev time: 0.43003 ms
M x N: 16384, M: 16, Device: cuda:0
Forward mode jacfwd time: 3.96956 ms
Reverse mode jacrev time: 0.41933 ms
M x N: 16384, M: 1024, Device: cpu
Forward mode jacfwd time: 406.97572 ms
Reverse mode jacrev time: 19.89774 ms
M x N: 16384, M: 16, Device: cpu
Forward mode jacfwd time: 302.21520 ms
Reverse mode jacrev time: 0.36077 ms

The Jacobian for bias can also be computed similarly.
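For completeness, since $\partial y_i / \partial b_k = \delta_{ik}$ for the linear function, the Jacobian with respect to the bias is just the $M \times M$ identity matrix. A minimal self-contained sketch:

import torch
import torch.nn.functional as F
from functorch import jacrev

N, M = 16, 32
weight = torch.randn(M, N)
bias = torch.randn(M)
x = torch.randn(N)

# For y = W x + b, dy/db is the identity matrix.
(bias_jacobian, ) = jacrev(lambda w, b, x: F.linear(x, w, b), argnums=(1, ))(weight, bias, x)
assert torch.allclose(bias_jacobian, torch.eye(M))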

Jacobian with Batch

Computing the Jacobian with batch could be even more brain-twisting. However, in the worst-case scenario, we could simply iterate over each sample in the batch, compute the per-sample Jacobians (iteratively or in parallel), and stack them together, as sketched below.
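A minimal self-contained sketch of this worst-case, per-sample baseline (assuming the same linear predict function as above):

import torch
import torch.nn.functional as F
from functorch import jacrev

B, N, M = 4, 16, 32
weight = torch.randn(M, N)
bias = torch.randn(M)
x = torch.randn(B, N)

# Compute one Jacobian per sample, then stack along the batch dimension.
per_sample_jacobians = [
    jacrev(lambda w, b, xi: F.linear(xi, w, b), argnums=(2, ))(weight, bias, xi)[0]
    for xi in x
]
batch_jacobian = torch.stack(per_sample_jacobians)
assert batch_jacobian.shape == torch.Size([B, M, N])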

The following example shows how to compute the Jacobian with batch using the PyTorch interfaces we discussed in the previous sections.

autograd_batch.py
# https://pytorch.org/functorch/stable/notebooks/jacobians_hessians.html
# Requires PyTorch >= 1.11

from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.benchmark import Timer
import torch.autograd.forward_ad as fwAD
from functorch import vmap, vjp, jvp, jacrev, jacfwd


# Linear layer mapping from R^N to R^M
def predict(weight, bias, x):
    return F.linear(x, weight, bias)


B = 4
N = 16
M = 32
weight = torch.randn(M, N)
bias = torch.randn(M)
x = torch.randn(B, N) # feature vector

# Reverse Mode

primal = x.clone().requires_grad_()
diagonal_matrix = torch.eye(M)
cotangents = torch.stack([cotangent.repeat(B, 1) for cotangent in diagonal_matrix])

# Method 1
# Use PyTorch autograd reverse mode + `for` loop.
rev_jacobian = []
# 1 forward pass.
output = predict(weight, bias, primal)
# M backward passes.
for cotangent in cotangents:
    # Compute the vjp, where v = cotangent.
    (jacobian_row, ) = torch.autograd.grad(outputs=(output, ),
                                           inputs=(primal, ),
                                           grad_outputs=(cotangent, ),
                                           retain_graph=True)
    rev_jacobian.append(jacobian_row)
# jacobian: [M, B, N]
jacobian = torch.stack(rev_jacobian)
# jacobian: [B, M, N]
jacobian = jacobian.transpose(1, 0)

# Run a sanity check for the Jacobian.
primal = x.clone().requires_grad_()
output = predict(weight, bias, primal)
# This will not work:
# output.backward()
# PyTorch backward assumes the output is a scalar unless an explicit gradient is provided.
external_grad = torch.ones_like(output)
# This is equivalent to
# output.sum().backward()
output.backward(gradient=external_grad)
grad = primal.grad
assert torch.allclose(jacobian.sum(dim=1), grad)

# Set the jacobian from method 1 as the reference.

# Method 2
# Using functorch vjp + vmap.
_, vjp_fn = vjp(partial(predict, weight, bias), primal)
# In PyTorch autograd reverse mode,
# there is no need to compute the full Jacobian in order to compute the grad.
assert torch.allclose(vjp_fn(external_grad)[0], grad)
# A vectorized implementation for computing Jacobian using vjp.
(rev_jacobian, ) = vmap(vjp_fn)(cotangents)
rev_jacobian = rev_jacobian.transpose(1, 0)
assert torch.allclose(rev_jacobian, jacobian)

# Method 3
# Use functorch jacrev + vmap.
# A vectorized implementation for computing Jacobian using vjp.
# https://pytorch.org/functorch/stable/generated/functorch.vmap.html#functorch.vmap
compute_batch_jacobian = vmap(jacrev(predict, argnums=(2, )), in_dims=(None, None, 0))
(rev_jacobian, ) = compute_batch_jacobian(weight, bias, primal)
assert torch.allclose(rev_jacobian, jacobian)

# Forward Mode

primal = x.clone().requires_grad_()
diagonal_matrix = torch.eye(N)
tangents = torch.stack([tangent.repeat(B, 1) for tangent in diagonal_matrix])

# Method 1
# Use PyTorch autograd forward mode + `for` loop.
fwd_jacobian = []
with fwAD.dual_level():
    # N forward passes.
    for tangent in tangents:
        dual_input = fwAD.make_dual(primal, tangent)
        # Tensors that do not have an associated tangent are automatically
        # considered to have a zero-filled tangent of the same shape.
        dual_output = predict(weight, bias, dual_input)
        # Unpacking the dual returns a namedtuple with ``primal`` and ``tangent``
        # as attributes.
        jacobian_column = fwAD.unpack_dual(dual_output).tangent
        fwd_jacobian.append(jacobian_column)
fwd_jacobian = torch.stack(fwd_jacobian).permute(1, 2, 0)
assert torch.allclose(fwd_jacobian, jacobian)

# Method 2
# Use functorch jvp + `for` loop.
fwd_jacobian = []
# No functorch vmap for jvp.
for tangent in tangents:
    _, jacobian_column = jvp(func=partial(predict, weight, bias), primals=(primal, ), tangents=(tangent, ))
    fwd_jacobian.append(jacobian_column)
fwd_jacobian = torch.stack(fwd_jacobian).permute(1, 2, 0)
assert torch.allclose(fwd_jacobian, jacobian)

# Method 3
# Use functorch jacfwd.
compute_batch_jacobian = vmap(jacfwd(predict, argnums=(2, )), in_dims=(None, None, 0))
(fwd_jacobian, ) = compute_batch_jacobian(weight, bias, primal)
assert torch.allclose(fwd_jacobian, jacobian)

# Measure Performance

cpu = torch.device("cpu")
cuda = torch.device("cuda:0")
for device in [cuda, cpu]:
    for B, N, M in [(4, 16, 10240), (4, 10240, 16)]:
        print(f"B: {B}, N: {N}, M: {M}, Device: {device}")
        weight = torch.randn(M, N).to(device)
        bias = torch.randn(M).to(device)
        x = torch.randn(B, N).to(device)

        using_fwd = Timer(
            stmt="vmap(jacfwd(predict, argnums=(2, )), in_dims=(None, None, 0))(weight, bias, x)",
            globals=globals())
        using_bwd = Timer(
            stmt="vmap(jacrev(predict, argnums=(2, )), in_dims=(None, None, 0))(weight, bias, x)",
            globals=globals())

        jacfwd_timing = using_fwd.timeit(100)
        jacrev_timing = using_bwd.timeit(100)

        print(f"Forward mode jacfwd time: {jacfwd_timing.mean * 1000:.5f} ms")
        print(f"Reverse mode jacrev time: {jacrev_timing.mean * 1000:.5f} ms")

Again, the performance is as expected.

$ python autograd_batch.py 
B: 4, N: 16, M: 10240, Device: cuda:0
Forward mode jacfwd time: 0.61777 ms
Reverse mode jacrev time: 8.34837 ms
B: 4, N: 10240, M: 16, Device: cuda:0
Forward mode jacfwd time: 13.53177 ms
Reverse mode jacrev time: 0.54021 ms
B: 4, N: 16, M: 10240, Device: cpu
Forward mode jacfwd time: 1.40448 ms
Reverse mode jacrev time: 563.17475 ms
B: 4, N: 10240, M: 16, Device: cpu
Forward mode jacfwd time: 574.62579 ms
Reverse mode jacrev time: 0.71066 ms

Conclusions

The functorch interface is much cleaner than the torch.autograd interface, and in some use cases it can do what torch.autograd is restricted from doing. Again, I should emphasize that computing the full Jacobian is expensive. If we just want to compute the gradients in forward mode or reverse mode, we don’t have to compute the Jacobian explicitly.
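For example, a minimal sketch of the usual training situation: for a scalar loss, a single reverse-mode pass yields the gradients directly, without ever materializing the full Jacobian.

import torch
import torch.nn.functional as F

N, M = 16, 32
weight = torch.randn(M, N, requires_grad=True)
bias = torch.randn(M, requires_grad=True)
x = torch.randn(N)

# One backward pass populates weight.grad and bias.grad; no Jacobian is built.
loss = F.linear(x, weight, bias).sum()
loss.backward()
assert weight.grad.shape == torch.Size([M, N])
assert bias.grad.shape == torch.Size([M])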
