
Forward versus Reverse Mode Automatic Differentiation understood as linear system solving

Alec Jacobson

September 20, 2018


Consider a function \(f : ℝ^n → ℝ^m\):

$$ y = f(x) $$

where the input is \(x∈ℝ^n\) and the output is \(y∈ℝ^m\). For example, the following non-linear function \(f\) maps three inputs to two outputs:

$$ \begin{bmatrix} y_1 \\ y_2 \end{bmatrix} = \begin{bmatrix} x_1^2+x_2 \\ x_2 \sin x_3 \end{bmatrix} $$

Writing this out as a computer program and forcing ourselves to conduct only a single simple operation at a time, we might see something like:

$$ v_1 = x_1^2 \\ v_2 = \sin x_3 \\ y_1 = v_1 + x_2 \\ y_2 = x_2 v_2 $$

where we introduce a vector \(v∈ℝ^k\) of \(k\) auxiliary variables (here \(k = 2\)). In general, these equations for \(v\) look like:

$$ v_i = g_i(x,v_1,v_2, ... , v_{i-1}). $$

In particular, \(v_i\) only depends on \(v_j\) if \(j < i\) (i.e., previously computed auxiliary variables). Similarly,

$$ y_i = h_i(x,v) $$

where \(y_i\) doesn't depend on any other \(y_j\).
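The example program above, written in Python, might look like the minimal sketch below (the name `f` and the list-based input are just illustrative choices). Each line performs one simple operation and produces one auxiliary variable or output:

```python
import math

def f(x):
    """Evaluate the example f: R^3 -> R^2, one simple operation per line."""
    x1, x2, x3 = x
    v1 = x1 ** 2         # auxiliary variable v_1
    v2 = math.sin(x3)    # auxiliary variable v_2
    y1 = v1 + x2         # output y_1
    y2 = x2 * v2         # output y_2
    return [y1, y2]

print(f([1.0, 2.0, 0.0]))  # [3.0, 0.0]
```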

Suppose we'd like to compute the matrix of derivatives of all output variables \(y\) with respect to the input variables \(x\), i.e., the \(m × n\) Jacobian \(J = ∇y\), where for the rest of this document \(∇\) means collecting the partial derivatives with respect to \(x\): \(∇ = (∂/∂x_1, ... ,∂/∂x_n)\).
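For the running example, differentiating by hand gives the \(2 × 3\) Jacobian we are after. Each row is the gradient of one output, and it makes a handy reference for checking the automatic results:

$$ J = ∇y = \begin{bmatrix} ∇y_1 \\ ∇y_2 \end{bmatrix} = \begin{bmatrix} 2x_1 & 1 & 0 \\ 0 & \sin x_3 & x_2 \cos x_3 \end{bmatrix} $$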

The goal of automatic differentiation is to compute this Jacobian for any given value of \(x\) with only minor modifications to our program for computing the output \(y\). So rather than explicitly coding up a separate program for computing \(∇y\), we'll sprinkle some code on top of the basic "single simple operations" above and compute \(∇y\) automatically.

At this point, most explanations of Automatic Differentiation start tearing away at the chain rule. Instead, we can look at automatic differentiation as the solution to a (typically very sparse) triangular linear system.

Let's first identify all of the variables we're dealing with:

$$ a = \begin{bmatrix} x \\ v \\ y \\ \end{bmatrix} $$

We'd like to compute \(∇y\), which is just the bottom rows of

$$ ∇a = \begin{bmatrix} ∇x \\ ∇v \\ ∇y \\ \end{bmatrix} $$

By applying \(∇\) to both sides of the equations defining each \(v_i\) (or similarly \(y_i\)) above we get:

$$ ∇v_i = ∇g_i(x,v_1,v_2, ... , v_{i-1}). $$

Let's see what this looks like for our simple example:

$$ ∇ v_1 = ∇(x_1^2) \\ ∇ v_2 = ∇(\sin x_3) \\ ∇ y_1 = ∇(v_1 + x_2) \\ ∇ y_2 = ∇(x_2 v_2) \\ $$

becomes

$$ ∇ v_1 = 2x_1∇x_1 \\ ∇ v_2 = \cos(x_3) ∇x_3 \\ ∇ y_1 = ∇v_1 + ∇x_2 \\ ∇ y_2 = x_2∇v_2 + v_2∇x_2 \\ $$
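Because each differentiated line only refers to quantities that are already computed, these equations can run right alongside the original program. A minimal sketch, assuming every \(∇\) quantity is stored as a plain Python list with one entry per \(∂/∂x_j\); `scale` and `add` are hypothetical helpers introduced just for this illustration:

```python
import math

def scale(a, g):
    """Scalar a times gradient row g."""
    return [a * gi for gi in g]

def add(g, h):
    """Elementwise sum of two gradient rows."""
    return [gi + hi for gi, hi in zip(g, h)]

def f_with_gradients(x):
    """Return y and ∇y, evaluating each ∇ equation next to its original line."""
    x1, x2, x3 = x
    dx1, dx2, dx3 = [1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]  # ∇x = I

    v1 = x1 ** 2
    dv1 = scale(2 * x1, dx1)                        # ∇v_1 = 2x_1 ∇x_1
    v2 = math.sin(x3)
    dv2 = scale(math.cos(x3), dx3)                  # ∇v_2 = cos(x_3) ∇x_3
    y1 = v1 + x2
    dy1 = add(dv1, dx2)                             # ∇y_1 = ∇v_1 + ∇x_2
    y2 = x2 * v2
    dy2 = add(scale(x2, dv2), scale(v2, dx2))       # ∇y_2 = x_2 ∇v_2 + v_2 ∇x_2
    return [y1, y2], [dy1, dy2]

y, J = f_with_gradients([1.0, 2.0, 0.0])
print(J)  # [[2.0, 1.0, 0.0], [0.0, 0.0, 2.0]]
```

The rows of `J` match the hand-derived Jacobian \([2x_1, 1, 0]\) and \([0, \sin x_3, x_2\cos x_3]\) at \(x = (1, 2, 0)\).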

Writing this in matrix form we have:

$$ \begin{bmatrix} -2x_1 & 0 & 0 & 1 & 0 & 0 & 0 \\ 0 & 0 & -\cos x_3 & 0 & 1 & 0 & 0 \\ 0 & -1 & 0 & -1 & 0 & 1 & 0 \\ 0 & -v_2 & 0 & 0 & -x_2 & 0 & 1 \end{bmatrix} \begin{bmatrix} ∇ x_1 \\ ∇ x_2 \\ ∇ x_3 \\ ∇ v_1 \\ ∇ v_2 \\ ∇ y_1 \\ ∇ y_2 \end{bmatrix} = \begin{bmatrix} 0 \\ 0 \\ 0 \\ 0 \end{bmatrix} $$

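Don't be scared of our variables \(x,v\) showing up in the matrix on the left. These are just values we'll be computing during the execution of our program for some input \(x\). Also note that each \(∇x_i\), \(∇v_i\), \(∇y_i\) is a row vector with \(n\) entries, so the zero on the right-hand side is really a \(4 × n\) block of zeros.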
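A quick numerical check of this system: multiply the coefficient matrix by the stacked gradient rows (worked out by hand from the differentiated equations) and confirm that every entry of the product vanishes. A minimal pure-Python sketch; the function names are hypothetical:

```python
import math

def coefficient_matrix(x):
    """4x7 matrix multiplying the rows [∇x1, ∇x2, ∇x3, ∇v1, ∇v2, ∇y1, ∇y2]."""
    x1, x2, x3 = x
    v2 = math.sin(x3)
    return [
        [-2 * x1, 0, 0, 1, 0, 0, 0],
        [0, 0, -math.cos(x3), 0, 1, 0, 0],
        [0, -1, 0, -1, 0, 1, 0],
        [0, -v2, 0, 0, -x2, 0, 1],
    ]

def stacked_gradients(x):
    """The 7 gradient rows (each of length n = 3), written out by hand."""
    x1, x2, x3 = x
    return [
        [1, 0, 0], [0, 1, 0], [0, 0, 1],          # ∇x = I
        [2 * x1, 0, 0],                           # ∇v_1
        [0, 0, math.cos(x3)],                     # ∇v_2
        [2 * x1, 1, 0],                           # ∇y_1
        [0, math.sin(x3), x2 * math.cos(x3)],     # ∇y_2
    ]

def matmul(A, G):
    """Plain list-of-lists matrix product A G."""
    return [[sum(A[i][p] * G[p][j] for p in range(len(G)))
             for j in range(len(G[0]))] for i in range(len(A))]

x = [0.7, -1.3, 2.1]
residual = matmul(coefficient_matrix(x), stacked_gradients(x))
print(max(abs(r) for row in residual for r in row))  # expect 0 up to rounding
```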

In general this set of equations will have the form:

$$ \begin{bmatrix} -B & -L+I & 0 \\ -R & -T & I \end{bmatrix} \begin{bmatrix} \nabla x \\ \nabla v \\ \nabla y \end{bmatrix} = \begin{bmatrix} 0 \\ 0 \end{bmatrix} $$

where \(L\) is a strictly lower triangular matrix (remember we only let \(v_i\) refer to previously computed \(v_j\), \(j < i\)).
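For the running example, with the variables ordered \(x_1, x_2, x_3, v_1, v_2, y_1, y_2\), reading the blocks off the rows of the \(4 × 7\) system gives the matrices below. In general \(B\) is \(k × n\), \(L\) is \(k × k\), \(R\) is \(m × n\), and \(T\) is \(m × k\). Here \(L = 0\) because neither \(v_1\) nor \(v_2\) uses another auxiliary variable:

$$ B = \begin{bmatrix} 2x_1 & 0 & 0 \\ 0 & 0 & \cos x_3 \end{bmatrix}, \quad L = \begin{bmatrix} 0 & 0 \\ 0 & 0 \end{bmatrix}, \quad R = \begin{bmatrix} 0 & 1 & 0 \\ 0 & v_2 & 0 \end{bmatrix}, \quad T = \begin{bmatrix} 1 & 0 \\ 0 & x_2 \end{bmatrix} $$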

The term \(∇x\) is simply the identity matrix \(I\). So, moving the \(∇x\) terms to the right-hand side, we can further reduce this to:

$$ \begin{bmatrix} -L+I & 0 \\ -T & I \end{bmatrix} \begin{bmatrix} \nabla v \\ \nabla y \end{bmatrix} = \begin{bmatrix} B \\ R \end{bmatrix} $$

Forgoing the block matrix form, and letting \(K = I - L\) (also lower triangular, with ones on its diagonal), we have two matrix-valued equations:

$$ K ∇v = B \\ ∇y = R + T ∇v $$

Let's consider two ways to solve these equations.

In the first way, we consider multiplying the first equation by \(K^{-1}\) on the left.

$$ K^{-1} K ∇v = K^{-1} B \\ ∇v = K^{-1} B $$

Then we can substitute this expression for \(∇v\) in the second equation, leaving:

$$ ∇y = R + T (K^{-1} B) $$
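This first option can be sketched in pure Python: solve \(K X = B\) by forward substitution, top row first (the same order in which the program introduces the \(v_i\)), then form \(R + T X\). The helper names are hypothetical, and \(B\), \(L\), \(R\), \(T\) are read off the rows of the example's \(4 × 7\) system. Since no \(v_i\) uses another \(v_j\) there, \(L = 0\) and the triangular solve happens to be trivial for this particular example:

```python
import math

def solve_lower(K, B):
    """Solve K X = B for lower-triangular K, top row first (forward substitution)."""
    k, n = len(K), len(B[0])
    X = [[0.0] * n for _ in range(k)]
    for i in range(k):                  # row i needs only rows 0..i-1 of X
        for j in range(n):
            s = B[i][j] - sum(K[i][p] * X[p][j] for p in range(i))
            X[i][j] = s / K[i][i]
    return X

def example_blocks(x):
    """B, L, R, T for the running example, read off the 4x7 system."""
    x1, x2, x3 = x
    v2 = math.sin(x3)
    B = [[2 * x1, 0.0, 0.0], [0.0, 0.0, math.cos(x3)]]
    L = [[0.0, 0.0], [0.0, 0.0]]        # no v_i uses another v_j here
    R = [[0.0, 1.0, 0.0], [0.0, v2, 0.0]]
    T = [[1.0, 0.0], [0.0, x2]]
    return B, L, R, T

def jacobian_forward(x):
    """∇y = R + T (K^{-1} B): solve for ∇v first, then apply T."""
    B, L, R, T = example_blocks(x)
    k = len(L)
    K = [[(1.0 if i == j else 0.0) - L[i][j] for j in range(k)] for i in range(k)]
    dv = solve_lower(K, B)              # ∇v = K^{-1} B
    return [[R[r][j] + sum(T[r][p] * dv[p][j] for p in range(k))
             for j in range(len(B[0]))] for r in range(len(R))]

print(jacobian_forward([1.0, 2.0, 0.0]))  # [[2.0, 1.0, 0.0], [0.0, 0.0, 2.0]]
```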

In the second way, we consider expressing \(∇v\) in the basis formed by \(K^{-1}\). That is, we let \(∇v = K^{-1} W\) for some yet-to-be-determined \(W\). (We know such a \(W\) exists as long as \(K\) is invertible.) We substitute \(K^{-1} W\) for \(∇v\) in both equations:

$$ K K^{-1} W = B \\ ∇y = R + (T K^{-1}) W $$

The first equation is just saying \(W = B\), and that means:

$$ ∇y = R + (T K^{-1}) B $$
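The second grouping can be sketched the same way, but now the first step computes \(Z = T K^{-1}\). Because \(K\) is lower triangular, this amounts to solving the upper-triangular system \(K^T Z^T = T^T\), so each row of \(Z\) is filled in from its last entry (the last-introduced variable) back to its first. A minimal pure-Python sketch with hypothetical names; for the running example \(L = 0\), so \(K = I\):

```python
import math

def right_divide_lower(T, K):
    """Return Z = T K^{-1} for lower-triangular K, i.e. solve K^T Z^T = T^T."""
    m, k = len(T), len(K)
    Z = [[0.0] * k for _ in range(m)]
    for r in range(m):                       # one output row at a time
        for i in reversed(range(k)):         # last-introduced variable first
            # entry (r, i) of Z K only involves Z[r][p] with p >= i
            s = T[r][i] - sum(Z[r][p] * K[p][i] for p in range(i + 1, k))
            Z[r][i] = s / K[i][i]
    return Z

def jacobian_reverse(x):
    """∇y = R + (T K^{-1}) B for the running example, where L = 0 so K = I."""
    x1, x2, x3 = x
    v2 = math.sin(x3)
    B = [[2 * x1, 0.0, 0.0], [0.0, 0.0, math.cos(x3)]]
    R = [[0.0, 1.0, 0.0], [0.0, v2, 0.0]]
    T = [[1.0, 0.0], [0.0, x2]]
    K = [[1.0, 0.0], [0.0, 1.0]]
    Z = right_divide_lower(T, K)             # Z = T K^{-1}, solved bottom-up
    return [[R[r][j] + sum(Z[r][p] * B[p][j] for p in range(len(B)))
             for j in range(len(B[0]))] for r in range(len(R))]
```

The test for `right_divide_lower` on a \(K\) with off-diagonal entries is where the bottom-up order actually matters.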

Obviously the result of these choices is (and has to be) the same. But whether we think of first multiplying 1) \(B\) on the left or 2) \(T\) on the right by \(K^{-1}\) is essentially the difference between forward and reverse automatic differentiation.

The crucial insight is that \(K\) is lower triangular. If we use option 1 (forward mode), then we divide by \(K\) on the left, which means we solve from the top down, or, more importantly, in the order in which we introduce variables in our program. So as we execute each line of our program, we can simultaneously solve one more row of the system with \(K\). If the number of input variables is small (\(n \ll m\)), then \(B\), which has \(n\) columns, will be skinny and keeping this running tally of solutions will be cheap.
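One common way to keep this running tally automatically, rather than writing the differentiated lines by hand, is operator overloading on dual numbers. This is a swapped-in illustration, not the sparse-matrix formulation above: each pass carries one column of the tally (one direction in \(x\)), so \(n\) passes recover all of \(∇y\), which is why forward mode is cheap when \(n\) is small. The `Dual` class and the `sin` helper are hypothetical and support only what the example needs:

```python
import math

class Dual:
    """A value paired with its derivative along one chosen input direction."""
    def __init__(self, val, dot=0.0):
        self.val, self.dot = val, dot

    def __add__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        return Dual(self.val + other.val, self.dot + other.dot)
    __radd__ = __add__

    def __mul__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        return Dual(self.val * other.val,
                    self.dot * other.val + self.val * other.dot)  # product rule
    __rmul__ = __mul__

    def __pow__(self, p):
        """Power with a constant exponent p (not a Dual)."""
        return Dual(self.val ** p, p * self.val ** (p - 1) * self.dot)

def sin(a):
    """Dual-aware sine: d(sin a) = cos(a) da."""
    return Dual(math.sin(a.val), math.cos(a.val) * a.dot)

def f(x):
    x1, x2, x3 = x
    v1 = x1 ** 2
    v2 = sin(x3)
    return [v1 + x2, x2 * v2]

def jacobian_by_columns(x):
    """One forward pass per input: column j of ∇y comes from seeding x_j."""
    cols = []
    for j in range(len(x)):
        xs = [Dual(xi, 1.0 if i == j else 0.0) for i, xi in enumerate(x)]
        cols.append([yi.dot for yi in f(xs)])
    return [list(row) for row in zip(*cols)]   # transpose columns into rows

print(jacobian_by_columns([1.0, 2.0, 0.0]))  # [[2.0, 1.0, 0.0], [0.0, 0.0, 2.0]]
```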

On the other hand, if we divide by \(K\) on the right (option 2, reverse mode), then we solve right to left, i.e., in reverse order through our variables as they were introduced:

$$ T K^{-1} = Z $$

is the same as

$$ (K^{T})^{-1} T^T = Z^T $$

where \(K^T\) is upper triangular.

This means we can't trace back our derivatives until we've seen all of the instructions. In machine learning, this has been rebranded as backpropagation. If the number of output variables is very small (\(m \ll n\)), then \(T^T\) will be skinny (only \(m\) columns) and conducting this solve will be cheap. This direction also lends itself well to directly computing the action of some vector \(u\) times \(∇y\), i.e., \(u^T ∇y\), known in machine learning as the "vjp" or vector-Jacobian product (the Jacobian-vector product \(∇y\,u\), or "jvp", is what forward mode produces naturally).
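Reverse mode is typically implemented by recording each operation and its local partial derivatives on a "tape" while the program runs, then sweeping the tape backwards once all instructions have been seen. The recorded pairs are, up to sign, the off-diagonal nonzeros of the sparse coefficient matrix, and the backward sweep plays the role of the transposed (upper-triangular) solve. A minimal sketch with hypothetical names (`Tape`, `record_example`, `vjp`); one sweep yields \(u^T ∇y\) for a chosen \(u ∈ ℝ^m\), so \(m\) sweeps give the whole Jacobian:

```python
import math

class Tape:
    """Stores every variable's value and its local partials w.r.t. earlier variables."""
    def __init__(self):
        self.vals = []       # value of each variable, in creation order
        self.parents = []    # for each variable i: list of (j, ∂a_i/∂a_j) pairs

    def var(self, val, parents=()):
        self.vals.append(val)
        self.parents.append(list(parents))
        return len(self.vals) - 1    # index of the new variable

def record_example(x):
    """Run the example program forward, recording it on a tape."""
    t = Tape()
    x1, x2, x3 = (t.var(xi) for xi in x)       # inputs have no parents
    a = t.vals                                 # alias to the live list of values
    v1 = t.var(a[x1] ** 2, [(x1, 2 * a[x1])])
    v2 = t.var(math.sin(a[x3]), [(x3, math.cos(a[x3]))])
    y1 = t.var(a[v1] + a[x2], [(v1, 1.0), (x2, 1.0)])
    y2 = t.var(a[x2] * a[v2], [(x2, a[v2]), (v2, a[x2])])
    return t, [x1, x2, x3], [y1, y2]

def vjp(t, inputs, outputs, u):
    """Return u^T ∇y by sweeping the tape from the last variable to the first."""
    adj = [0.0] * len(t.vals)
    for yi, ui in zip(outputs, u):
        adj[yi] += ui                          # seed the output adjoints
    for i in reversed(range(len(t.vals))):     # reverse order of creation
        for j, dij in t.parents[i]:
            adj[j] += adj[i] * dij             # push adjoint back along each edge
    return [adj[xi] for xi in inputs]

t, ins, outs = record_example([1.0, 2.0, 0.0])
print(vjp(t, ins, outs, [1.0, 0.0]))  # [2.0, 1.0, 0.0]
```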

Source: "A sparse matrix approach to reverse mode automatic differentiation in Matlab", [Forth & Sharma 2010]