JAX: Inverse Fisher/Hessian as GD Preconditioner

Created on 31 Mar 2020 · 3 comments · Source: google/jax

Hi all,

Following up on #2197, I was hoping to implement natural gradient descent or Gauss-Newton for DNNs, which requires inverting a large p-by-p matrix, where p is the number of parameters of my neural network. I understand that if I only want to perform a small number of gradient steps, I could compute H^-1 * g with the conjugate gradient (CG) method instead of forming the inverse explicitly, but I'm not sure how best to do that in JAX. Is there an existing implementation of CG in JAX? I haven't been able to find one. In general, am I right to assume that applying the full inverse Hessian as a gradient preconditioner should be feasible via CG in JAX?
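
For what it's worth, here is a minimal sketch of the matrix-free approach discussed in this thread: build a Hessian-vector product from `jax.jvp` and `jax.grad`, then hand it to the `jax.scipy.sparse.linalg.cg` solver that came out of the PR below. The quadratic `loss` (with hypothetical `A` and `b`) is a stand-in for a real network loss, chosen so the Hessian is positive definite and plain CG applies:

```python
import jax
import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg

# Hypothetical toy problem: a least-squares loss whose Hessian A^T A is
# positive definite, standing in for a real DNN loss.
A = jnp.eye(100) + 0.1
b = jnp.ones(100)

def loss(params):
    return 0.5 * jnp.sum((A @ params - b) ** 2)

params = jnp.zeros(100)
g = jax.grad(loss)(params)

def hvp(v):
    # Hessian-vector product without materializing the p-by-p Hessian:
    # forward-over-reverse differentiation of the gradient.
    return jax.jvp(jax.grad(loss), (params,), (v,))[1]

# Solve H x = g matrix-free; x approximates H^-1 * g (a Newton step).
x, _ = cg(hvp, g)
```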

Thank you!

All 3 comments

I have a (mostly) working version of CG in JAX with preconditioner support in:
https://github.com/google/jax/pull/2566

Hopefully we'll be able to merge that in soon.

If you're feeling ambitious, you could give it a spin.
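
Assuming the interface that eventually landed, the solver's `M` argument takes an operator that applies an approximation of H^-1. Here is a sketch of a Jacobi (diagonal) preconditioner, reusing `hvp`, `params`, and `g` from the toy setup above:

```python
# Recover diag(H) with one HVP per basis vector; fine at toy scale,
# but too costly for a large DNN (where diag(H) would be approximated).
diag_H = jax.vmap(lambda e: e @ hvp(e))(jnp.eye(params.shape[0]))

# M should apply an approximation of H^-1 to a vector.
x, _ = cg(hvp, g, M=lambda v: v / diag_H)
```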

Thanks for the quick reply @shoyer! I'll check it out!

Are there any existing issues with the CG implementation that I should be aware of in case I run into problems?

We just merged the CG implementation, so you can try it out now if you install JAX from source. All the limitations _should_ be documented in the docstring, but please let us know if anything surprises you.
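
The merged solver mirrors SciPy's `cg` interface, so `tol`, `atol`, and `maxiter` can cap the work per step. For example, a truncated-Newton (Hessian-free) style update with the `hvp` and `g` from the sketches above might look like:

```python
# Stop after at most 50 CG iterations; an inexact H^-1 * g is often
# good enough for one optimization step.
x, _ = cg(hvp, g, tol=1e-5, maxiter=50)
```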
