Implicit Differentiation

Reference: https://implicit-layers-tutorial.org/implicit_functions/

Implicit Differentiation

Let $f(\mathbf{a},\mathbf{z})$ be a function of the form

\[f: \mathbb{R}^p\times\mathbb{R}^n\rightarrow\mathbb{R}^n\]

We will think of $f$ as a family of functions from $\mathbb{R}^n$ to itself parameterized by $\mathbf{a}\in\mathbb{R}^p$.

By the implicit function theorem, assuming that the Jacobian $\partial_\mathbf{z}f$ is nonsingular, the equation

\[f(\mathbf{a}, \mathbf{z}) = 0\]

locally defines a function

\[z^*: \mathbb{R}^p \rightarrow \mathbb{R}^n\]

which sends a parameter $\mathbf{a}\in\mathbb{R}^p$ to a solution $z^*(\mathbf{a})\in\mathbb{R}^n$ satisfying:

\[f(\mathbf{a}, z^*(\mathbf{a})) = 0\]
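
For example, with $p = n = 1$ and $f(a, z) = z^2 - a$, the equation $f(a, z) = 0$ has the two solution branches $z = \pm\sqrt{a}$ for $a > 0$, and each branch defines an implicit function $z^*$ wherever $\partial_z f = 2z \neq 0$.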

We’d like to compute the Jacobian

\[\partial_\mathbf{a}z^*\in\mathbb{R}^{n\times p}\]

in terms of the partial derivatives of $f$.

Let $\mathbf{a}_0\in\mathbb{R}^p$ be a parameter and set

\[\mathbf{z}_0=z^*(\mathbf{a}_0)\]

By the chain rule, the total derivative of $f(\mathbf{a}, z^*(\mathbf{a}))$ with respect to $\mathbf{a}$, evaluated at $\mathbf{a}_0$, is:

\[\frac{d}{d\mathbf{a}}\bigg|_{\mathbf{a}=\mathbf{a}_0} f(\mathbf{a}, z^*(\mathbf{a})) = \partial_\mathbf{a}f(\mathbf{a}_0, \mathbf{z}_0) + \partial_\mathbf{z}f(\mathbf{a}_0,\mathbf{z}_0)\,\partial_\mathbf{a}z^*(\mathbf{a}_0)\]

Since $f(\mathbf{a}, z^*(\mathbf{a})) = 0$ for every $\mathbf{a}$, this derivative is zero. Rearranging:

\[\partial_\mathbf{a}z^*(\mathbf{a}_0) = -(\partial_\mathbf{z}f(\mathbf{a}_0,\mathbf{z}_0))^{-1} \partial_\mathbf{a}f(\mathbf{a}_0, \mathbf{z}_0)\]
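
Continuing the example $f(a, z) = z^2 - a$ with the branch $z^*(a) = \sqrt{a}$: here $\partial_z f = 2z$ and $\partial_a f = -1$, so the formula gives

\[\partial_a z^*(a) = -(2\sqrt{a})^{-1}\cdot(-1) = \frac{1}{2\sqrt{a}}\]

which agrees with differentiating $\sqrt{a}$ directly. The same check can be run numerically. Below is a minimal JAX sketch of the formula for this example (the setup is illustrative, not from the reference):

```python
import jax
import jax.numpy as jnp

# Example problem: f(a, z) = z**2 - a, with root z*(a) = sqrt(a).
def f(a, z):
    return z**2 - a

a0 = 4.0
z0 = jnp.sqrt(a0)  # z*(a0), known in closed form for this example

# Partial derivatives of f at (a0, z0); scalars here, since p = n = 1.
df_da = jax.grad(f, argnums=0)(a0, z0)
df_dz = jax.grad(f, argnums=1)(a0, z0)

# Implicit differentiation formula: dz*/da = -(df/dz)^{-1} df/da.
dz_da = -df_da / df_dz
print(dz_da)                   # 0.25
print(jax.grad(jnp.sqrt)(a0))  # 0.25, agrees
```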

Fixed Points

As a special case, consider the fixed point equation:

\[f(\mathbf{a}, \mathbf{z}) = \mathbf{z}\]

We can identify solutions to this equation with the zeros of an implicit function by defining:

\[g(\mathbf{a}, \mathbf{z}) := f(\mathbf{a}, \mathbf{z}) - \mathbf{z}\]

Solutions $(\mathbf{a}_0, \mathbf{z}_0)$ to the fixed point equation satisfy:

\[g(\mathbf{a}_0, \mathbf{z}_0) = 0\]

Therefore, provided $\partial_\mathbf{z}g(\mathbf{a}_0,\mathbf{z}_0) = \partial_\mathbf{z}f(\mathbf{a}_0,\mathbf{z}_0) - \mathrm{Id}$ is nonsingular, the implicit function theorem lets us define $z^*(\mathbf{a})\in\mathbb{R}^n$ as the fixed point corresponding to $\mathbf{a}$. Furthermore, by the previous section:

\[\begin{align*} \partial_\mathbf{a}z^*(\mathbf{a}_0) &= -(\partial_\mathbf{z}g(\mathbf{a}_0,\mathbf{z}_0))^{-1} \partial_\mathbf{a}g(\mathbf{a}_0, \mathbf{z}_0) \\ &= (\mathrm{Id} - \partial_\mathbf{z}f(\mathbf{a}_0,\mathbf{z}_0))^{-1} \partial_\mathbf{a}f(\mathbf{a}_0, \mathbf{z}_0) \end{align*}\]
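
As a quick sanity check, here is a minimal JAX sketch of this formula in one dimension, using an illustrative contraction $f(a, z) = \tfrac{1}{2}\cos(z) + a$ and comparing against differentiating through an unrolled fixed-point iteration:

```python
import jax
import jax.numpy as jnp

# Illustrative 1-D contraction: z -> 0.5*cos(z) + a has a unique fixed point.
def f(a, z):
    return 0.5 * jnp.cos(z) + a

def fixed_point(a, z=0.0, num_iters=100):
    for _ in range(num_iters):
        z = f(a, z)
    return z

a0 = 1.0
z0 = fixed_point(a0)

# Scalar case of the formula: dz*/da = (1 - df/dz)^{-1} df/da.
df_da = jax.grad(f, argnums=0)(a0, z0)
df_dz = jax.grad(f, argnums=1)(a0, z0)
print(df_da / (1.0 - df_dz))

# Reference value: differentiate through the unrolled iteration itself.
print(jax.grad(fixed_point)(a0))
```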

Jacobian Vector Products

Let $h: \mathbb{R}^n\rightarrow\mathbb{R}^m$ be a function. The Jacobian-vector product (JVP) of $h$ maps a pair

\[(\mathbf{x}, \mathbf{v}) \in \mathbb{R}^n\times\mathbb{R}^n\]

to the pair

\[(h(\mathbf{x}), \partial h(\mathbf{x})\mathbf{v}) \in \mathbb{R}^m\times\mathbb{R}^m\]

Forward-mode automatic differentiation in JAX is implemented in terms of evaluating JVPs.
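
For instance, `jax.jvp` evaluates exactly this pair, without ever materializing the Jacobian; a minimal sketch:

```python
import jax
import jax.numpy as jnp

# An arbitrary function h: R^2 -> R^3.
def h(x):
    return jnp.array([jnp.sin(x[0]), x[0] * x[1], jnp.exp(x[1])])

x = jnp.array([0.5, 1.0])  # primal point
v = jnp.array([1.0, 0.0])  # tangent vector

# jax.jvp returns the pair (h(x), dh(x) @ v).
primal_out, tangent_out = jax.jvp(h, (x,), (v,))
```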

Let’s see how to evaluate the JVP of the fixed point function $z^*:\mathbb{R}^p\rightarrow\mathbb{R}^n$ from the previous section. Writing $\mathbf{z} = z^*(\mathbf{a})$, recall that:

\[\partial_\mathbf{a}z^*(\mathbf{a}) = (\mathrm{Id} - \partial_\mathbf{z}f(\mathbf{a},\mathbf{z}))^{-1} \partial_\mathbf{a}f(\mathbf{a}, \mathbf{z})\]

Let $\mathbf{v}\in\mathbb{R}^p$ be a tangent vector. We can use the JVP of $f$ with respect to the argument $\mathbf{a}$ to evaluate

\[\mathbf{u} := \partial_\mathbf{a}f(\mathbf{a}, \mathbf{z})\mathbf{v}\]

We now have to compute:

\[\mathbf{w} := (\mathrm{Id} - \partial_\mathbf{z}f(\mathbf{a},\mathbf{z}))^{-1}\mathbf{u}\]

Equivalently, $\mathbf{w}$ solves the linear system:

\[(\mathrm{Id} - \partial_\mathbf{z}f(\mathbf{a},\mathbf{z}))\mathbf{w} = \mathbf{u}\]

We can solve for $\mathbf{w}$ using a linear solver such as GMRES. Note that GMRES solves equations of the form

\[A\mathbf{x} = \mathbf{b}\]

and only requires a function that can evaluate the product $A\mathbf{x}$, rather than the full matrix $A$. In this case $A = \mathrm{Id} - \partial_\mathbf{z}f(\mathbf{a},\mathbf{z})$, so we can evaluate $A\mathbf{x} = \mathbf{x} - \partial_\mathbf{z}f(\mathbf{a},\mathbf{z})\mathbf{x}$ using the JVP of $f$ with respect to $\mathbf{z}$.
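
Putting the pieces together, here is a minimal sketch of the JVP of $z^*$ in JAX, reusing an illustrative contraction map for $f$ and `jax.scipy.sparse.linalg.gmres` as the matrix-free solver (the helper names are hypothetical, not from the reference):

```python
import jax
import jax.numpy as jnp
from jax.scipy.sparse.linalg import gmres

# Illustrative contraction on R^n, parameterized by a in R^p (here n = p = 2).
def f(a, z):
    return 0.5 * jnp.cos(z) + a

def fixed_point(a, num_iters=100):
    z = jnp.zeros_like(a)
    for _ in range(num_iters):
        z = f(a, z)
    return z

def zstar_jvp(a, v):
    z = fixed_point(a)  # z = z*(a)
    # u = d_a f(a, z) v, via the JVP of f in its first argument.
    _, u = jax.jvp(lambda a_: f(a_, z), (a,), (v,))
    # Matrix-free matvec for (Id - d_z f(a, z)), via the JVP in z.
    def matvec(w):
        _, jw = jax.jvp(lambda z_: f(a, z_), (z,), (w,))
        return w - jw
    # Solve (Id - d_z f) w = u for w = d_a z*(a) v.
    w, _ = gmres(matvec, u)
    return z, w

a0 = jnp.array([0.3, 1.0])
v = jnp.array([1.0, 0.0])
z0, w = zstar_jvp(a0, v)  # w approximates the directional derivative of z*
```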