Chain Rule

Introduction

Extending the univariable chain rule to multivariable functions can sometimes be confusing. Using the de novo chain rule expression, which is based on the matrix multiplication of Jacobian matrices, the chain rule can be expressed in a more intuitive way.

In this blog post, I would like to discuss the de novo chain rule expression, how it unifies the univariable chain rule and multivariable chain rule, and how it can be applied to different areas of mathematics.

The Confusing Chain Rule Expressions

Univariable Chain Rule

Let $w = f(x)$ be a differentiable function of $x$ and $x = g(t)$ be a differentiable function of $t$. Then $w = f(g(t))$ is a differentiable function of $t$ and

$$\frac{dw}{dt} = \frac{df}{dx} \frac{dx}{dt}$$

Multivariable Chain Rule

We often see the following expression of the multivariable chain rule.

Let $w = f(x_1, x_2, \ldots, x_m)$ be a differentiable function of $m$ independent variables, and for each $i \in [1, m]$, let $x_i = g_i(t_1, t_2, \ldots, t_n)$ be a differentiable function of $n$ independent variables. Then $w = f(g_1(t_1, t_2, \ldots, t_n), g_2(t_1, t_2, \ldots, t_n), \ldots, g_m(t_1, t_2, \ldots, t_n))$ is a differentiable function of $n$ independent variables and

$$\frac{\partial w}{\partial t_j} = \sum_{i=1}^{m} \frac{\partial w}{\partial x_i} \frac{\partial x_i}{\partial t_j}$$

for each $j \in [1, n]$.
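As a quick sanity check, here is a minimal NumPy sketch that compares the chain-rule sum against a finite-difference estimate. The specific $f$ and $g_i$ below are arbitrary illustrative choices, not functions from the text.

```python
import numpy as np

# Numerical sanity check of the multivariable chain rule with m = 2, n = 2.
# f(x1, x2) = x1^2 * x2, x1 = g1(t1, t2) = t1 + t2, x2 = g2(t1, t2) = t1 * t2.
f = lambda x1, x2: x1**2 * x2
g1 = lambda t1, t2: t1 + t2
g2 = lambda t1, t2: t1 * t2
w = lambda t1, t2: f(g1(t1, t2), g2(t1, t2))

t1, t2 = 1.3, -0.7
x1, x2 = g1(t1, t2), g2(t1, t2)

# Chain rule: dw/dt1 = (dw/dx1)(dx1/dt1) + (dw/dx2)(dx2/dt1)
dw_dx1, dw_dx2 = 2 * x1 * x2, x1**2
dx1_dt1, dx2_dt1 = 1.0, t2
chain = dw_dx1 * dx1_dt1 + dw_dx2 * dx2_dt1

# Finite-difference approximation of dw/dt1 for comparison.
eps = 1e-6
fd = (w(t1 + eps, t2) - w(t1 - eps, t2)) / (2 * eps)

print(np.allclose(chain, fd, atol=1e-5))  # True
```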

There is nothing wrong with this expression. But its form is different from the univariable chain rule expression. When it comes to vector calculus, which is often used in neural networks, these expressions become confusing and less useful.

The De Novo Chain Rule Expression

The de novo chain rule expression is more intuitive and more applicable to different areas of mathematics.

In calculus, the chain rule is a formula that expresses the derivative of the composition of two differentiable functions $f: Y \to Z$ and $g: X \to Y$ in terms of the derivatives of $f$ and $g$. More precisely, if $h = f \circ g: X \to Z$ is the composition of $f$ and $g$ such that $h(x) = f(g(x))$, the chain rule states that

$$h'(x) = f'(g(x)) \, g'(x)$$

or equivalently

$$h' = (f' \circ g) \cdot g'$$

Note that the domains X, Y, and Z are not limited to real numbers.

More generally, suppose $X \subseteq \mathbb{R}^n$, $Y \subseteq \mathbb{R}^m$, and $Z \subseteq \mathbb{R}^k$. We have

$$h(x) = f(g(x))$$

where $x \in \mathbb{R}^n$, $g(x) \in \mathbb{R}^m$, and $f(g(x)) \in \mathbb{R}^k$.

The chain rule becomes

$$h'(x) = f'(g(x)) \, g'(x)$$

The chain rule can be equivalently expressed using the matrix multiplication of Jacobian matrices

$$J_h(x) = J_f(g(x)) \, J_g(x)$$

where $J_h(x)$ is the Jacobian matrix of $h$ with respect to $x$, which is a matrix in $\mathbb{R}^{k \times n}$, $J_f(g(x))$ is the Jacobian matrix of $f$ with respect to $g(x)$, which is a matrix in $\mathbb{R}^{k \times m}$, and $J_g(x)$ is the Jacobian matrix of $g$ with respect to $x$, which is a matrix in $\mathbb{R}^{m \times n}$.

The chain rule using the Jacobian matrix unifies the univariable chain rule and multivariable chain rule.
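To make this concrete, here is a minimal NumPy sketch that verifies $J_h(x) = J_f(g(x)) J_g(x)$ against a finite-difference Jacobian. The specific $f$, $g$, and dimensions are arbitrary illustrative choices.

```python
import numpy as np

# Numerical check of J_h(x) = J_f(g(x)) J_g(x) with n = 3, m = 2, k = 2.
def g(x):           # g: R^3 -> R^2
    return np.array([x[0] * x[1], np.sin(x[2])])

def jacobian_g(x):  # J_g(x) in R^{2 x 3}
    return np.array([[x[1], x[0], 0.0],
                     [0.0, 0.0, np.cos(x[2])]])

def f(y):           # f: R^2 -> R^2
    return np.array([y[0] + y[1]**2, np.exp(y[0])])

def jacobian_f(y):  # J_f(y) in R^{2 x 2}
    return np.array([[1.0, 2.0 * y[1]],
                     [np.exp(y[0]), 0.0]])

def h(x):           # h = f o g
    return f(g(x))

x = np.array([0.5, -1.2, 0.8])

# Chain rule: Jacobian of the composition as a matrix product.
jacobian_chain = jacobian_f(g(x)) @ jacobian_g(x)

# Finite-difference Jacobian of h for comparison.
eps = 1e-6
jacobian_fd = np.zeros((2, 3))
for j in range(3):
    e = np.zeros(3)
    e[j] = eps
    jacobian_fd[:, j] = (h(x + e) - h(x - e)) / (2 * eps)

print(np.allclose(jacobian_chain, jacobian_fd, atol=1e-5))  # True
```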

Examples

Gradient of Linear Functions

Consider a linear function $f(x) = a^{\top} x$, where $a \in \mathbb{R}^n$ and $x \in \mathbb{R}^n$. The gradient of $f$ with respect to $x$ is given by

$$\nabla f(x) = a$$

This is very easy and straightforward to verify using the scalar form of the linear function.
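For instance, a minimal NumPy sketch, with arbitrary random $a$ and $x$, that confirms the gradient against finite differences:

```python
import numpy as np

# f(x) = a^T x has gradient a; check against central finite differences.
rng = np.random.default_rng(0)
a = rng.normal(size=4)
x = rng.normal(size=4)

f = lambda x: a @ x

eps = 1e-6
grad_fd = np.array([(f(x + eps * e) - f(x - eps * e)) / (2 * eps)
                    for e in np.eye(4)])

print(np.allclose(grad_fd, a))  # True
```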

Consider a linear function $f(x) = Ax$, where $A \in \mathbb{R}^{m \times n}$ and $x \in \mathbb{R}^n$. The Jacobian matrix of $f$ with respect to $x$ is given by

$$J_f(x) = [\nabla f_1(x), \nabla f_2(x), \ldots, \nabla f_m(x)]^{\top}$$

Because $f_i(x) = A_i x$, where $A_i$ is the $i$-th row of $A$, we have

$$\nabla f_i(x) = A_i^{\top}$$

Therefore, the Jacobian matrix of $f$ with respect to $x$ is given by

$$J_f(x) = [\nabla f_1(x), \nabla f_2(x), \ldots, \nabla f_m(x)]^{\top} = [A_1^{\top}, A_2^{\top}, \ldots, A_m^{\top}]^{\top} = (A^{\top})^{\top} = A$$

Consider linear functions $f(X) = AX$, where $A \in \mathbb{R}^{m \times n}$ and $X \in \mathbb{R}^{n \times k}$, which is also known as matrix multiplication. Because

$$f(X) = [f(X_{:,1}), f(X_{:,2}), \ldots, f(X_{:,k})]$$

where $X_{:,i}$ is the $i$-th column of $X$, and $f(X_{:,i}) = A X_{:,i}$, the Jacobian matrix of $f$ with respect to $X_{:,i}$ is given by

$$J_f(X_{:,i}) = A$$
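A minimal NumPy sketch, with an arbitrary random $A$ and $x$, that confirms the Jacobian of $f(x) = Ax$ is $A$ itself:

```python
import numpy as np

# For f(x) = A x, the Jacobian is A; verify with central finite differences.
rng = np.random.default_rng(0)
A = rng.normal(size=(3, 4))
x = rng.normal(size=4)

f = lambda x: A @ x

eps = 1e-6
jacobian_fd = np.column_stack([(f(x + eps * e) - f(x - eps * e)) / (2 * eps)
                               for e in np.eye(4)])

print(np.allclose(jacobian_fd, A))  # True
```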

Gradient of Quadratic Functions

Consider a quadratic function $f(x) = x^{\top} A x$, where $A \in \mathbb{R}^{n \times n}$ and $x \in \mathbb{R}^n$. We could express $f$ as a functional composition instead.

We define the following new functions $g$ and $h$, such that $f = g \circ h$.

We first define the function g as

$$g(x) = x_1^{\top} x_2$$

where $x = \begin{bmatrix} x_1 \\ x_2 \end{bmatrix}$ and $x_1, x_2 \in \mathbb{R}^n$.

The derivative of $g$ with respect to $x$ is given by

$$g'(x) = \left[ \frac{\partial g}{\partial x_1}(x), \frac{\partial g}{\partial x_2}(x) \right] = [x_2^{\top}, x_1^{\top}]$$

We then define the function h as

$$h(x) = \begin{bmatrix} h_1(x) \\ h_2(x) \end{bmatrix} = \begin{bmatrix} Ix \\ Ax \end{bmatrix} = \begin{bmatrix} x \\ Ax \end{bmatrix}$$

where $x \in \mathbb{R}^n$, $h_1(x) = Ix$, and $h_2(x) = Ax$.

The Jacobian matrix of h with respect to x is given by

$$J_h(x) = \begin{bmatrix} J_{h_1}(x) \\ J_{h_2}(x) \end{bmatrix} = \begin{bmatrix} I \\ A \end{bmatrix}$$

Then we have

$$f(x) = g(h(x)) = g\left( \begin{bmatrix} x \\ Ax \end{bmatrix} \right) = x^{\top} A x$$

Using the chain rule, we have

$$
\begin{aligned}
f'(x) &= g'(h(x)) \, J_h(x) \\
&= [(Ax)^{\top}, x^{\top}] \begin{bmatrix} I \\ A \end{bmatrix} \\
&= (Ax)^{\top} I + x^{\top} A \\
&= x^{\top} A^{\top} + x^{\top} A \\
&= x^{\top} (A^{\top} + A)
\end{aligned}
$$

Therefore,

$$\nabla f(x) = f'(x)^{\top} = \left( x^{\top} (A^{\top} + A) \right)^{\top} = (A^{\top} + A)^{\top} x = (A + A^{\top}) x$$
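A minimal NumPy sketch, with an arbitrary (not necessarily symmetric) random $A$, that checks this gradient against finite differences:

```python
import numpy as np

# Gradient of f(x) = x^T A x should be (A + A^T) x; check numerically.
rng = np.random.default_rng(0)
A = rng.normal(size=(4, 4))   # not necessarily symmetric
x = rng.normal(size=4)

f = lambda x: x @ A @ x

eps = 1e-6
grad_fd = np.array([(f(x + eps * e) - f(x - eps * e)) / (2 * eps)
                    for e in np.eye(4)])

print(np.allclose(grad_fd, (A + A.T) @ x, atol=1e-5))  # True
```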

Hessian of Quadratic Functions

Because we have already derived the gradient of quadratic functions, we can easily derive their Hessian.

The gradient of quadratic functions is given by

$$\nabla f(x) = (A + A^{\top}) x$$

The Hessian of a quadratic function is the Jacobian matrix of its gradient with respect to $x$.

The Hessian of quadratic functions is given by

$$H(x) = J_{\nabla f}(x) = A + A^{\top}$$
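A minimal NumPy sketch, with an arbitrary random $A$, that recovers the Hessian by finite-differencing the gradient derived above:

```python
import numpy as np

# Hessian of f(x) = x^T A x should be A + A^T; estimate it by finite
# differences of the analytic gradient (A + A^T) x.
rng = np.random.default_rng(0)
A = rng.normal(size=(4, 4))
x = rng.normal(size=4)

grad = lambda x: (A + A.T) @ x

eps = 1e-6
hessian_fd = np.column_stack([(grad(x + eps * e) - grad(x - eps * e)) / (2 * eps)
                              for e in np.eye(4)])

print(np.allclose(hessian_fd, A + A.T, atol=1e-5))  # True
```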

Least Squares

The least squares problem is a common optimization problem in machine learning and statistics. It can be formulated as minimizing the sum of squared differences between the observed values and the predicted values.

The least squares problem objective function could be defined as

$$f(w) = \lVert Xw - y \rVert^2$$

where $X \in \mathbb{R}^{m \times n}$ is the design matrix, $w \in \mathbb{R}^n$ is the vector of coefficients, $y \in \mathbb{R}^m$ is the vector of observed values, and $\lVert \cdot \rVert$ is the $L_2$ norm.

The least squares objective function can usually be rewritten as

$$
\begin{aligned}
f(w) &= (Xw - y)^{\top} (Xw - y) \\
&= ((Xw)^{\top} - y^{\top})(Xw - y) \\
&= (w^{\top} X^{\top} - y^{\top})(Xw - y) \\
&= w^{\top} X^{\top} X w - y^{\top} X w - w^{\top} X^{\top} y + y^{\top} y \\
&= w^{\top} X^{\top} X w - y^{\top} X w - y^{\top} X w + y^{\top} y \\
&= w^{\top} X^{\top} X w - 2 y^{\top} X w + y^{\top} y
\end{aligned}
$$

We have to find the optimal w such that the gradient of f with respect to w is equal to zero.

$$
\begin{aligned}
\nabla f(w) &= \nabla \left( w^{\top} X^{\top} X w - 2 y^{\top} X w + y^{\top} y \right) \\
&= \nabla \left( w^{\top} X^{\top} X w \right) - \nabla \left( 2 y^{\top} X w \right) + 0 \\
&= \left( X^{\top} X + (X^{\top} X)^{\top} \right) w - 2 X^{\top} y \\
&= \left( X^{\top} X + X^{\top} X \right) w - 2 X^{\top} y \\
&= 2 X^{\top} X w - 2 X^{\top} y
\end{aligned}
$$
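A minimal NumPy sketch, with arbitrary random $X$, $y$, and $w$, that checks this gradient against finite differences:

```python
import numpy as np

# Check the gradient 2 X^T X w - 2 X^T y of f(w) = ||Xw - y||^2 numerically.
rng = np.random.default_rng(0)
X = rng.normal(size=(6, 3))
y = rng.normal(size=6)
w = rng.normal(size=3)

f = lambda w: np.sum((X @ w - y) ** 2)
grad = 2 * X.T @ X @ w - 2 * X.T @ y

eps = 1e-6
grad_fd = np.array([(f(w + eps * e) - f(w - eps * e)) / (2 * eps)
                    for e in np.eye(3)])

print(np.allclose(grad, grad_fd, atol=1e-4))  # True
```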

It is also feasible to use the chain rule to derive the gradient of the least squares objective function.

We define the following new functions $g$ and $h$, such that $f = g \circ h$.

We first define the function g as

$$g(x) = \lVert x \rVert^2 = x^{\top} x$$

where $x \in \mathbb{R}^m$.

The gradient of g with respect to x is given by

$$\nabla g(x) = 2x$$

This is easy and straightforward to verify using the scalar form of the function.

We then define the function h as

$$h(w) = Xw - y$$

where $w \in \mathbb{R}^n$.

The Jacobian matrix of h with respect to w is given by

$$J_h(w) = X$$

Then we have

$$f(w) = g(h(w)) = g(Xw - y) = \lVert Xw - y \rVert^2$$

Using the chain rule, we have

$$
\begin{aligned}
f'(w) &= g'(h(w)) \, J_h(w) \\
&= 2 (Xw - y)^{\top} X \\
&= 2 (w^{\top} X^{\top} - y^{\top}) X \\
&= 2 w^{\top} X^{\top} X - 2 y^{\top} X
\end{aligned}
$$

Therefore,

$$\nabla f(w) = f'(w)^{\top} = \left( 2 w^{\top} X^{\top} X - 2 y^{\top} X \right)^{\top} = 2 X^{\top} X w - 2 X^{\top} y$$
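Setting this gradient to zero gives the normal equations $X^{\top} X w = X^{\top} y$. A minimal NumPy sketch, with arbitrary random data, that compares their solution against np.linalg.lstsq:

```python
import numpy as np

# Solve the normal equations X^T X w = X^T y and compare with the
# least-squares solver np.linalg.lstsq.
rng = np.random.default_rng(0)
X = rng.normal(size=(8, 3))
y = rng.normal(size=8)

w_normal = np.linalg.solve(X.T @ X, X.T @ y)
w_lstsq, *_ = np.linalg.lstsq(X, y, rcond=None)

print(np.allclose(w_normal, w_lstsq))  # True
```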
