Hi all,
I would like to use JAX to compute the diagonal elements of a Hessian matrix, i.e. the second partial derivatives \partial^2 y / \partial x_j^2. What's the most efficient way to do this? I know that for columns of the Hessian I could use Hessian-vector products, but what can I do in this case to avoid computing full Hessians?
Unfortunately, computing the diagonal elements of a Hessian is fundamentally just as expensive as computing the full Hessian (i.e., there aren't any tricks that JAX or any other library could use). See also https://github.com/google/jax/issues/564 and https://github.com/HIPS/autograd/issues/445 for more discussion.
(Assuming there isn't any additional structure to your problem that would make it easier, such as having a component with a diagonal or otherwise sparse Jacobian)
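To see why concretely: with no structure to exploit, each diagonal entry needs its own Hessian-vector product against a standard basis vector, so extracting the diagonal takes n HVPs, the same asymptotic cost as materializing the full Hessian. Here's a minimal sketch (f, hvp, and hessian_diag are illustrative names, not library functions):

import jax
import jax.numpy as jnp

# hypothetical stand-in for a scalar-valued function of R^n
def f(x):
    return jnp.sum(jnp.tanh(x) ** 2)

def hvp(f, x, v):
    # forward-over-reverse Hessian-vector product
    return jax.jvp(jax.grad(f), (x,), (v,))[1]

def hessian_diag(f, x):
    # One HVP per standard basis vector e_i yields H[i, i]; that's n
    # HVPs in total, i.e. the same cost as forming the whole Hessian.
    n = x.shape[0]
    basis = jnp.eye(n, dtype=x.dtype)
    return jnp.array([hvp(f, x, basis[i])[i] for i in range(n)])

x = jnp.arange(1.0, 4.0)
print(hessian_diag(f, x))
print(jnp.diag(jax.hessian(f)(x)))  # should match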
If my Jacobian is diagonal, what is the most elegant way to get the diagonal of the Hessian? In Autograd it was elementwise_grad twice, but for some reason I can't wrap my head around vmap and jacobian for a general form in JAX...
Thanks @jekbradbury! My Jacobian does not have any sparsity structure I can exploit, unfortunately.
@clemisch if you have a function with a diagonal Jacobian, I believe that means it must act elementwise (or act elementwise up to an additive vector constant). For such a function, the equivalent of Autograd's elementwise_grad is vmap(grad(f)), where f is the version that acts on a scalar.
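For instance, here's a minimal sketch of that correspondence (f_scalar is a hypothetical elementwise example, not anything from this thread):

import jax
import jax.numpy as jnp

# scalar version of the elementwise function
def f_scalar(x):
    return jnp.tanh(x)

# Autograd's elementwise_grad(f) corresponds to vmap(grad(f_scalar)),
# and applying grad twice before vmapping gives the Hessian diagonal.
elementwise_grad = jax.vmap(jax.grad(f_scalar))
elementwise_grad2 = jax.vmap(jax.grad(jax.grad(f_scalar)))

x = jnp.linspace(-1.0, 1.0, 4)
print(elementwise_grad(x))   # elementwise first derivatives
print(elementwise_grad2(x))  # diagonal of the Hessian of x -> sum(tanh(x))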
If my Jacobian is diagonal, what is the most elegant way to get the diagonal of the Hessian?
Whether you can use vmap may depend on whether your function is rank-polymorphic. Let's assume it's not, so that we have a Python function f modeling a mathematical function f : R^n -> R which we're promised has a diagonal Hessian.
Mathematically, if we can compute a Hessian-vector product (HVP), then we can read off the diagonal entries of a diagonal Hessian by applying an HVP to an all-ones vector, since (H\mathbf{1})_i = \sum_j H_{ij} = H_{ii} when H is diagonal. Here's one way to do it:
from jax import jvp, grad, hessian
import jax.numpy as jnp
import numpy.random as npr

rng = npr.RandomState(0)
a = rng.randn(4)
x = rng.randn(4)

# function with diagonal Hessian that isn't rank-polymorphic
def f(x):
    assert x.ndim == 1
    return jnp.sum(jnp.tanh(a * x))

# forward-over-reverse Hessian-vector product
def hvp(f, x, v):
    return jvp(grad(f), (x,), (v,))[1]

print(hessian(f)(x))                # full Hessian
print(jnp.diag(hessian(f)(x)))      # its diagonal
print(hvp(f, x, jnp.ones_like(x)))  # HVP with ones recovers the diagonal
$ python issue3801.py
[[-0.03405464  0.          0.          0.        ]
 [ 0.          0.10269941  0.          0.        ]
 [ 0.          0.         -0.65265197  0.        ]
 [ 0.          0.          0.          2.9311912 ]]
[-0.03405464  0.10269941 -0.65265197  2.9311912 ]
[-0.03405464  0.10269941 -0.65265197  2.9311912 ]
That hvp implementation is in the autodiff cookbook.
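As a small follow-up sketch (not an official API): when the Hessian is known to be diagonal, the HVP-with-ones trick is the whole computation, so it can be wrapped and jitted directly; hessian_diag is a hypothetical helper name.

from functools import partial
import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnums=0)
def hessian_diag(f, x):
    # Only valid when f's Hessian is diagonal: the HVP with an
    # all-ones vector then returns exactly the diagonal entries.
    return jax.jvp(jax.grad(f), (x,), (jnp.ones_like(x),))[1]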
Thank you @jekbradbury and @mattjj for the explanation!
Thanks @ibulu!