JAX: Computing the diagonal elements of a Hessian

Created on 20 Jul 2020 · 9 comments · Source: google/jax

Hi all,

I would like to use JAX to compute the diagonal elements of a Hessian matrix, i.e., the second partial derivatives \partial^2 y / \partial x_j^2. What's the most efficient way to do this? I know that for columns of the Hessian I could use Hessian-vector products, but what can I do in this case to avoid computing the full Hessian?

question

All 9 comments

Unfortunately, computing the diagonal elements of a Hessian is fundamentally just as expensive as computing the full Hessian (i.e., there aren't any tricks that JAX or any other library could use). See also https://github.com/google/jax/issues/564 and https://github.com/HIPS/autograd/issues/445 for more discussion.

(Assuming there isn't any additional structure to your problem that would make it easier, such as having a component with a diagonal or otherwise sparse Jacobian)
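To make the cost concrete: each Hessian-vector product reveals one diagonal entry, since e_i^T (H e_i) = H[i, i] for the i-th standard basis vector e_i, so recovering the full diagonal this way takes n autodiff passes, the same asymptotic cost as building H column by column. A minimal sketch of that approach (the helper hessian_diag and the test function are illustrative, not from this thread):

from jax import jvp, grad, vmap
import jax.numpy as jnp

def hvp(f, x, v):
  # forward-over-reverse Hessian-vector product
  return jvp(grad(f), (x,), (v,))[1]

def hessian_diag(f, x):
  # one HVP per standard basis vector e_i; e_i . (H e_i) = H[i, i]
  basis = jnp.eye(x.size, dtype=x.dtype)
  return vmap(lambda e: jnp.vdot(e, hvp(f, x, e)))(basis)

def f(x):
  # illustrative function with a dense (non-diagonal) Hessian
  return jnp.tanh(x).sum() ** 2

print(hessian_diag(f, jnp.arange(1.0, 5.0)))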

If my Jacobian is diagonal, what is the most elegant way to get the diagonal of the Hessian? In Autograd it was elementwise_grad applied twice, but for some reason I can't wrap my head around how to combine vmap and jacobian for a general form in JAX...

Thanks @jekbradbury! My Jacobian does not have any sparsity structure I can exploit, unfortunately.

@clemisch if you have a function with a diagonal Jacobian, I believe that means it must act elementwise (or act elementwise up to an additive constant vector). For such a function, the equivalent of Autograd's elementwise_grad is vmap(grad(f)), where f is the version that acts on a scalar.
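For concreteness, here's a minimal sketch of that pattern (the elementwise function g is illustrative, not from this thread):

from jax import grad, vmap
import jax.numpy as jnp

def g(xi):
  # scalar -> scalar version of the elementwise function
  return jnp.tanh(xi) ** 2

x = jnp.linspace(-1.0, 1.0, 5)

dg = vmap(grad(g))         # equivalent of Autograd's elementwise_grad
d2g = vmap(grad(grad(g)))  # applied twice: per-element second derivatives,
                           # i.e. the Hessian diagonal of x -> jnp.sum(g(x))

print(dg(x))
print(d2g(x))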

If my jacobian is diagonal, what is the most elegant way to get the diagonal of the hessian?

Whether you can use vmap may depend on whether your function is rank-polymorphic. Let's assume it's not, so that we have a Python function f modeling a mathematical function f : R^n -> R which we're promised has a diagonal Hessian.

Mathematically, if we can compute a Hessian-vector product (HVP), then we can reveal the diagonal entries of a diagonal Hessian by applying an HVP to an all-ones vector. Here's one way to do it:

from jax import jvp, grad, hessian
import jax.numpy as jnp
import numpy.random as npr

rng = npr.RandomState(0)
a = rng.randn(4)
x = rng.randn(4)

# function with diagonal Hessian that isn't rank-polymorphic
def f(x):
  assert x.ndim == 1
  return jnp.sum(jnp.tanh(a * x))

def hvp(f, x, v):
  # forward-over-reverse Hessian-vector product
  return jvp(grad(f), (x,), (v,))[1]

print(hessian(f)(x))                # full (diagonal) Hessian, for reference
print(jnp.diag(hessian(f)(x)))      # its diagonal
print(hvp(f, x, jnp.ones_like(x)))  # same diagonal from a single HVP
$ python issue3801.py
[[-0.03405464  0.          0.          0.        ]
 [ 0.          0.10269941  0.          0.        ]
 [ 0.          0.         -0.65265197  0.        ]
 [ 0.          0.          0.          2.9311912 ]]
[-0.03405464  0.10269941 -0.65265197  2.9311912 ]
[-0.03405464  0.10269941 -0.65265197  2.9311912 ]

That hvp implementation is from the JAX autodiff cookbook (https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html).

Thank you @jekbradbury and @mattjj for the explanation!

thanks @ibulu!
