I'm using this snippet from the README
from jax import jit, jacfwd, jacrev

def hessian(fun):
    return jit(jacfwd(jacrev(fun)))
combined with the "getting started with pytorch data loaders" colab. How do I compute and use the hessian of this neural network? So far I have tried:
- hessian_loss = hessian(loss)(params, x, y) raises TypeError: jacfwd() takes 2 positional arguments but 4 were given
- hessian_loss = hessian(loss)((params, x, y)) takes a long time with a small network and then returns AttributeError: 'PyLeaf' object has no attribute 'node_type'
- hessian_loss = hessian(lambda params_: loss(params_, x, y))(params) raises AttributeError: 'PyLeaf' object has no attribute 'node_type'

All I really need is to compute the eigenvalues and eigenvectors of this Hessian.
Note: repr(params) = [(DeviceArray{float32[10,784]}, DeviceArray{float32[10]})]
Some ideas that I have:
This is a great question, and one that we've been discussing!
One problem is that JAX gave you very unhelpful error messages for this use case. We need to improve those. Ultimately, the error messages need to communicate that the hessian function, like jacfwd and jacrev, only applies to array-input array-output functions (of one argument). In particular, these functions don't work for tuple/list/dict inputs, or with respect to multiple arguments, for the same reason in both cases.
Here's the heart of the issue: given the example params here (a one-element list of a pair of arrays), how would you want the value of hessian(lambda params: loss(params, x, y))(params) to be stored? More generally, given a tuple/list/dict argument, how should we represent the Hessian?
For arrays, there's a clear answer because it's easy to reason about adding axes. If a function fun takes arrays of shape (in_1, in_2, ..., in_n) to arrays of shape (out_1, out_2, ..., out_m), then it's reasonable for hessian(fun) to be a function that takes an array of shape (in_1, in_2, ..., in_n) to an array of shape (out_1, out_2, ..., out_m, in_1, in_2, ..., in_n, in_1, in_2, ..., in_n), though other conventions could be reasonable too. (As I wrote this, I got a sense of deja vu...)
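To make that convention concrete, here is a small sketch that reuses the hessian helper from the README; the function f is made up purely for illustration:

import jax.numpy as jnp
from jax import jit, jacfwd, jacrev

def hessian(fun):
    return jit(jacfwd(jacrev(fun)))

def f(x):                  # array input of shape (3,), scalar output of shape ()
    return jnp.sum(jnp.sin(x) ** 2)

x = jnp.ones(3)
H = hessian(f)(x)
print(H.shape)             # (3, 3): the input shape appended twice to the (empty) output shape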
But if fun is, say, a function that takes a tuple of scalars to a scalar, then what should hessian(fun) return? Some kind of nested tuple structure? How do we organize the nesting?
I really mean that as a question! Do you have a clear sense of what would make sense from your perspective? If there's a clear way to handle cases like these, we'll implement it!
The answer may ultimately be that we can only represent Hessians for array-input array-output functions, in which case flattening a container-input function into an array-input one, then working with the Hessian of the flattened function, may be the right answer. JAX could provide utilities for doing that (we had nice ones in Autograd, with a slick implementation).
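For the params in this issue, a hand-rolled version of that flattening might look like the sketch below (the flatten/unflatten helpers are made up for illustration, not JAX utilities, and loss(params, x, y) plus params, x, y are assumed from the colab):

import jax.numpy as jnp
from jax import jit, jacfwd, jacrev

def hessian(fun):
    return jit(jacfwd(jacrev(fun)))

# params, x, y and loss(params, x, y) are assumed from the colab;
# params is a one-element list holding (W, b), with W of shape (10, 784) and b of shape (10,).
W, b = params[0]
w_size, w_shape, b_shape = W.size, W.shape, b.shape

def flatten(params):
    W, b = params[0]
    return jnp.concatenate([W.ravel(), b.ravel()])

def unflatten(flat):
    return [(flat[:w_size].reshape(w_shape), flat[w_size:].reshape(b_shape))]

flat_params = flatten(params)                       # single vector of shape (7850,)
flat_loss = lambda flat: loss(unflatten(flat), x, y)
H = hessian(flat_loss)(flat_params)                 # dense (7850, 7850) array; O(n^2) memory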
In the meantime, flattening things yourself probably makes the most sense, unless you want to wait a few days for JAX to gain some flattening utilities. An alternative might be to use a Lanczos iteration together with a Hessian-vector product, which you can express easily in JAX. Then you'd only have to deal with vectors, rather than having to worry about how to represent matrices, and we know how to handle tuples/lists/dicts there. (But Lanczos would only be accurate for extremal eigenvalues, and its numerical effectiveness would depend on the conditioning of the Hessian, whereas direct eigenvalue algorithms would be independent of the conditioning.)
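For reference, a Hessian-vector product can be written in a few lines; this sketch composes jax.jvp with jax.grad, and works whether x and v are plain arrays or matching tuples/lists/dicts:

from jax import grad, jvp

def hvp(f, x, v):
    # Forward-over-reverse: differentiate grad(f) along direction v,
    # giving H(x) @ v without ever materializing the Hessian.
    return jvp(grad(f), (x,), (v,))[1]

# e.g. with the flattened loss from the sketch above (names assumed from there):
# Hv = hvp(flat_loss, flat_params, jnp.ones_like(flat_params))

Such an hvp could then be handed to an iterative eigensolver (for instance by wrapping it in a scipy.sparse.linalg.LinearOperator and calling eigsh), subject to the accuracy caveats about extremal eigenvalues noted above.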
> I really mean that as a question! Do you have a clear sense of what would make sense from your perspective? If there's a clear way to handle cases like these, we'll implement it!
I'm exploring JAX as a more straightforward way to do higher order derivatives. Right now I am using tensorflow (tf.hessians) but it quickly becomes clunky and it doesn't work in eager mode. My real use case is to do some analysis on the eigenvalues and eigenvectors of a neural network with a few thousand parameters.
> The answer may ultimately be that we can only represent Hessians for array-input array-output functions, in which case flattening a container-input function into an array-input one, then working with the Hessian of the flattened function, may be the right answer. JAX could provide utilities for doing that (we had nice ones in Autograd, with a slick implementation).
Flattening is currently the most straightforward way to do this in pytorch (which also has nice flattening utilities), e.g. when constructing Fisher information matrices in RL. Unfortunately, flattening in tensorflow graph mode cannot be used with tf.hessians unless the flattened version was used to predict the output of the network.
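To illustrate what is meant by pytorch's flattening utilities, here is a minimal sketch (the Linear layer is just a stand-in for whatever model is actually used):

import torch
from torch.nn.utils import parameters_to_vector, vector_to_parameters

model = torch.nn.Linear(784, 10)                  # stand-in model with 7850 parameters
flat = parameters_to_vector(model.parameters())   # single 1-D tensor of all parameters
# ... operate on the flat vector (e.g. build a Fisher- or Hessian-vector product) ...
vector_to_parameters(flat, model.parameters())    # copy values back into the model's parameters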
I think the most useful way to handle Hessians of arbitrary matrices, or compositions of matrices, would be:

- each layer i has parameters of shape [Input_i x Output_i], for i the index of the layer;
- jax.hessians with respect to the computed loss then gives a matrix of size [sum(prod(Input_i, Output_i)) x sum(prod(Input_i, Output_i))].

> An alternative might be to use a Lanczos iteration together with a Hessian-vector product, which you can express easily in JAX. Then you'd only have to deal with vectors, rather than having to worry about how to represent matrices, and we know how to handle tuples/lists/dicts there. (But Lanczos would only be accurate for extremal eigenvalues, and its numerical effectiveness would depend on the conditioning of the Hessian, whereas direct eigenvalue algorithms would be independent of the conditioning.)
I really want to compute the eigenspectrum (density of eigenvalues) and associated eigendirections so accuracy is important.
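For a network with only a few thousand parameters, the dense route is still workable; here is a sketch reusing the hypothetical hessian, flat_loss, and flat_params names from the flattening example above:

import jax.numpy as jnp

# hessian, flat_loss, and flat_params as in the flattening sketch above
H = hessian(flat_loss)(flat_params)    # dense (7850, 7850) Hessian
H = 0.5 * (H + H.T)                    # symmetrize to guard against small numerical asymmetry
eigvals, eigvecs = jnp.linalg.eigh(H)  # full spectrum (ascending) and orthonormal eigenvectors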
We had some conversations about this, and I think our plan is to:

- get hessian to work on containers (we chose a representation to go with).

I'm glad #201 didn't close this issue, because it seems to have broken hessian! I'm looking at it now.