Hi all,
Following up on #2197, I was hoping to implement natural gradient descent or Gauss-Newton for DNNs, which requires inverting a large p-by-p matrix, where p is the number of parameters of my neural network. I understand that if I only want to perform a small number of gradient steps, I could compute H^-1 g with the conjugate gradient (CG) method, using only Hessian-vector products rather than the explicit Hessian, but I'm not sure how best to do that in JAX. Is there an existing implementation of CG in JAX? I haven't been able to find one. In general, am I right to assume that running CG with a preconditioner (e.g. an approximation of the inverse Hessian) should be feasible in JAX?
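For reference, here's the kind of matrix-free Hessian-vector product I have in mind (a minimal sketch; `loss_fn` and `params` are hypothetical stand-ins for a real model):

```python
import jax
import jax.numpy as jnp

def loss_fn(params):
    # Toy objective standing in for a real DNN loss.
    return jnp.sum(params ** 2) + jnp.sum(jnp.sin(params))

def hvp(f, params, v):
    # Hessian-vector product via forward-over-reverse autodiff;
    # never materializes the p-by-p Hessian.
    return jax.jvp(jax.grad(f), (params,), (v,))[1]

params = jnp.arange(4.0)
g = jax.grad(loss_fn)(params)
print(hvp(loss_fn, params, g))  # H @ g, computed matrix-free
```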
Thank you!
I have a (mostly) working version of CG in JAX with preconditioner support in:
https://github.com/google/jax/pull/2566
Hopefully we'll be able to merge that in soon.
If you're feeling ambitious, you could give it a spin.
Thanks for the quick reply @shoyer! I'll check it out!
Are there any known issues with the CG implementation that I should be aware of, in case I run into problems?
We just merged the `cg` implementation, so you can try it out now if you install JAX from source. All the limitations _should_ be documented in the docstring, but please let us know if anything surprises you.
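For anyone landing here, a minimal usage sketch of `jax.scipy.sparse.linalg.cg` with a callable matvec and a preconditioner (the toy matrix, tolerance, and Jacobi preconditioner are illustrative only; in the natural-gradient setting the matvec would be a Hessian-vector product):

```python
import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg

# Small SPD system standing in for H x = g; illustrative values only.
A = jnp.array([[4.0, 1.0],
               [1.0, 3.0]])
b = jnp.array([1.0, 2.0])

def matvec(v):
    # In the natural-gradient setting, this would be a Hessian-vector
    # product instead of an explicit matrix multiply.
    return A @ v

x, _ = cg(matvec, b, tol=1e-6)

def precondition(v):
    # Jacobi (diagonal) preconditioner: a rough approximation of A^-1.
    return v / jnp.diag(A)

x_pc, _ = cg(matvec, b, M=precondition, tol=1e-6)
```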