Jax: autodiff support for jax.numpy.linalg.eig

Created on 17 Apr 2020  ·  34 Comments  ·  Source: google/jax

Note that eigh is already taken care of.

enhancement

All 34 comments

I'm pretty sure we could achieve this with a minor variation of our existing JVP rule for eigh, replacing U.T.conj() -> inv(U) (of course it should really use an LU solve rather than computing the inverse directly).
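A minimal sketch of what that variation could look like for the eigenvalues only, assuming a diagonalizable input with distinct eigenvalues. The helper name `eigvals_tangent_sketch` is made up for illustration and is not part of JAX; it computes the diagonal of inv(U) @ dA @ U with a linear solve instead of an explicit inverse.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # double precision for the sanity check


def eigvals_tangent_sketch(a, da):
    """Return (eigenvalues, eigenvalue tangents) for a square matrix `a`.

    Hypothetical helper; assumes `a` is diagonalizable with distinct eigenvalues.
    """
    w, u = jnp.linalg.eig(a)
    # inv(U) @ dA @ U, computed with a linear solve rather than an explicit inverse.
    projected = jnp.linalg.solve(u, da.astype(u.dtype) @ u)
    # The tangent of eigenvalue i is the i-th diagonal entry.
    return w, jnp.diagonal(projected)


a = jnp.array([[2.0, 1.0, 0.0], [0.5, 3.0, 1.0], [0.0, -1.0, 1.0]])
da = jnp.array([[0.1, 0.0, 0.2], [0.0, -0.3, 0.0], [0.4, 0.0, 0.5]])
w, dw = eigvals_tangent_sketch(a, da)

# Sanity check: the eigenvalues sum to trace(a), so their tangents sum to trace(da).
print(jnp.sum(dw), jnp.trace(da))
```

The trace identity is a convenient check because it does not depend on the order in which the eigenvalues are returned.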

Just wanted to throw in a +1 for wanting this to be implemented.

@shoyer do you have a reference for it? I've just been working through the math by hand and it seems what you said is correct, except that you have to do a slightly awkward correction to ensure that dU.T @ U has ones down the diagonal (which I think is required - this comes from the constraint that the eigenvectors are normalized). Anyway I think I will draft an implementation today.

Edit: It's implemented in Autograd https://github.com/HIPS/autograd/blob/master/autograd/numpy/linalg.py#L152-L173, with a reference to https://arxiv.org/pdf/1701.00392.pdf, eq 4.77.

Edit 2: The jvp equations in that paper are 4.60 and 4.63, but I _think_ 4.63 (the jvp for the eigenvectors) is wrong. The statement above 4.63 ("...can not influence the amplitude of the eigenvectors...") is correct but I don't think they translated that constraint correctly into math. I've tried implementing their version, and my own, neither are working yet so not 100% sure whether I'm right about this.

Also @shoyer how should I do a solve (inv(a) @ b for square matrices a and b)? I think I can't use jax.numpy.linalg.solve from jax.lax_linalg because of a circular dependency.

For now I'll use inv as you suggested above.

Section 3.1 of the reference cited in a comment under eigh_jvp_rule (in lax_linalg.py) works through the general case of how to calculate the eigenvector derivative:
https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf

> Also @shoyer how should I do a solve (inv(a) @ b for square matrices a and b)? I think I can't use jax.numpy.linalg.solve from jax.lax_linalg because of circular dependency.

My suggestion would be to use a local import, in the JVP rule, e.g.,

def eig_jvp_rule(...):
  from jax.numpy.linalg import solve
  ...

You could also try refactoring, but this is the usual hack for circular dependency challenges.

I'd be tempted to at least try the refactoring of moving the guts of solve into lax_linalg.py.

Cool, I've moved it in the draft pr, wasn't too bad to do. Still getting incorrect values for the eig derivatives though 😔.

The JVP seems to be correct now that I've relaxed the test tolerance slightly, but the VJPs are way out, I'm not sure why that is yet.

I notice that testEighGrad is currently skipped because 'Test fails with numeric errors', I wonder if the problems I'm seeing are related, since the eig jvp is mostly copied from the eigh jvp.

OK I think I know why what I have is incorrect. Eigenvectors are only unique up to (complex) scalar multiple. The eigenvectors returned by numpy.linalg.eig are normalized so that they have length 1 (I already knew this), and also so that their _largest_ component is real (see http://www.netlib.org/lapack/lapack-3.1.1/html/dgeev.f.html). That constraint I was not previously aware of and I think it might take some work to correct the derivations + implementation that I have.

A similar issue might also explain why the derivative tests for eigh are failing - the eigenvectors are normalized so they have length 1, but are still only unique up to multiplication by a complex scalar whose absolute value is 1 (i.e. there is one degree of freedom per eigenvector). It's not clear from the low level eigh docs (http://www.netlib.org/lapack/lapack-3.1.1/html/zheevd.f.html) how this non-uniqueness is addressed.

Edit: just running np.linalg.eigh on a couple of inputs it looks like the eigenvectors are normalized so that the _first_ component of each is real. It seems a bit strange that eigh uses a different convention to eig, and this means that you'll get np.linalg.eigh(x) != np.linalg.eig(x) for complex, hermitian x. The eigh convention should be easier to differentiate, and maybe we should change our eig_p primitive to match the eigh convention, so that lax_linalg.eig(x) == lax_linalg.eigh(x) for all hermitian x.
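To make the non-uniqueness concrete, here is a small sketch. It assumes a "first component real and non-negative" convention of the kind described above; `fix_gauge_first_real` is a hypothetical helper, not a JAX function, and it assumes the first component of every eigenvector is non-zero.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def fix_gauge_first_real(v):
    """Rotate each column of `v` so that its first component is real and >= 0.

    Hypothetical helper; assumes no column has a zero first component.
    """
    phase = v[0] / jnp.abs(v[0])  # one unit-modulus phase per column
    return v / phase              # broadcasts over rows; column norms are unchanged


x = jnp.array([[2.0, 1.0 - 1.0j], [1.0 + 1.0j, 3.0]])  # complex Hermitian
w, v = jnp.linalg.eigh(x)

# Multiplying each eigenvector by an arbitrary unit-modulus scalar gives
# another equally valid set of unit-length eigenvectors.
theta = jnp.array([0.3, -1.2])
v_rotated = v * jnp.exp(1j * theta)

# After fixing the gauge, both sets collapse to the same representative.
v_fixed = fix_gauge_first_real(v)
v_rotated_fixed = fix_gauge_first_real(v_rotated)
```

Applying the gauge fix inside the traced function is one way to make a phase-dependent objective well defined, since JAX can differentiate through the division.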

We could certainly pick a new convention for normalizing vectors from eig if that makes it easier to differentiate. The downside is that this would probably require a bit more computation. If it's only O(n^2) time, I would say definitely go for it; it's more questionable if we need dense matrix/matrix multiplication, which is O(n^3). In the latter case we might add an optional argument for making eig differentiable.

For what it's worth, I have a _feeling_ that the right answer for how to differentiate eig/eigh in most cases is _don't_, precisely because eigen-decomposition is often not a well defined function. The right function to differentiate is something downstream of eigen-decomposition where the outputs of the numerical method become a well defined function, e.g., the result of a matrix power series calculation. If we can characterize the full set of such "well defined functions of eigen-decompositions" then perhaps those are the right primitives for which to define auto-diff rules.

Yeah I agree. It would be very weird and likely a bug if a user implemented a function that depended on the length of an eigenvector, since the normalization is essentially an implementation detail. Catering for these design decisions with correct derivatives is also really awkward, so maybe we should indeed look for another level at which to provide derivatives.

@ianwilliamson do you have a use case for eig derivatives? Would be useful to know a bit about it.

Also I think it is reasonable to support eigh of a real symmetric matrix, where there is a pretty obvious and straightforward unique value and derivative.

> Also I think it is reasonable to support eigh of a real symmetric matrix, where there is a pretty obvious and straightforward unique value and derivative.

Even eigh is only uniquely defined if the eigenvalues are unique. If degeneracies are valid (common in many applications) it isn't a well defined function.
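A tiny illustration of the degenerate case, as a sketch with made-up numbers: when an eigenvalue is repeated, any rotation within its eigenspace yields another valid orthonormal set of eigenvectors, so the decomposition is not a single-valued function of the input.

```python
import math

import jax.numpy as jnp

# A real symmetric matrix with a repeated eigenvalue: (1, 1, 3).
a = jnp.diag(jnp.array([1.0, 1.0, 3.0]))
w = jnp.array([1.0, 1.0, 3.0])

# Rotate the basis of the two-dimensional degenerate eigenspace by an arbitrary angle.
t = 0.7
u = jnp.array([
    [math.cos(t), -math.sin(t), 0.0],
    [math.sin(t), math.cos(t), 0.0],
    [0.0, 0.0, 1.0],
])

# Both the identity and `u` satisfy a @ U == U * w with orthonormal columns.
residual_identity = a @ jnp.eye(3) - jnp.eye(3) * w
residual_rotated = a @ u - u * w
```

Both residuals vanish (up to rounding), so a solver is free to return either basis, or any other rotation of it.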

We had a related discussion when I was fixing eigh in autograd: https://github.com/HIPS/autograd/pull/527

Essentially, the vjp there works for objective functions that do not depend on the arbitrary phase of the eigenvectors (the "gauge"), and the tests are written for such functions. This is because in a general solver this phase is just arbitrary, so even finite-difference derivatives won't work, i.e. eig(X) and eig(X + dX) can spit out eigenvectors with arbitrary phase difference. It sounds like in jax you are actually setting the gauge (largest element to be real), so you could try to make the vjp account for that and match the finite-difference derivative under that gauge, but I think you can't really expect the user to know that you're doing that. Meaning that if I'm a user and I have a function that depends on the phase of an eigenvector, the correct way to do it is to manually set the gauge to whatever I want it to be, in a way tracked by jax. Or in other words: you can first get the vjp to work for gauge-independent functions, and then add the normalization on top of that.

The problem with degeneracies is harder. In one of my packages, I purposefully add small noise to avoid symmetries that result in degeneracies, but that's obviously a workaround. Here's a paper that could provide some indication on how this could be handled ("generalized gradient"), but I don't really understand it well: https://oaktrust.library.tamu.edu/bitstream/handle/1969.1/184232/document-1.pdf?sequence=1

> @ianwilliamson do you have a use case for eig derivatives? Would be useful to know a bit about it.

Just in case this is helpful, I've never actually had a use case for eig derivatives per se since I've always had Hermitian matrices available, but the last time I reached for an eigh derivative it was because I needed to find a representative set of inputs yielding a singular Jacobian (a determinant derivative would have worked fine, but that was a bit slower and more unstable iirc -- I stopped searching when eigh derivatives were good enough). The scipy.linalg package was more helpful to me than the numpy wrapper because of its ability to single out a range of eigenvalues.

Most natural uses of an eig derivative I think would follow a similar pattern of having a deterministic scheme for choosing a particular eigenvalue (smallest magnitude, largest real part, etc) that relates to the problem being studied, or perhaps as inputs to a symmetric function.
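As a sketch of that pattern, the function below picks the eigenvalue with the largest real part and differentiates its real part. It assumes a JAX version in which jax.numpy.linalg.eigvals has a first-order derivative rule; `largest_real_part` is a name made up for this example. The derivative is only meaningful where the selected eigenvalue is simple and no other eigenvalue ties with it.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def largest_real_part(a):
    """Real part of the eigenvalue of `a` with the largest real part."""
    w = jnp.linalg.eigvals(a)
    # argmax is a deterministic selection rule; the gradient flows only
    # through the selected eigenvalue.
    return w[jnp.argmax(w.real)].real


a = jnp.diag(jnp.array([1.0, 2.0, 3.0]))
g = jax.grad(largest_real_part)(a)
print(g)
```

Other selection rules mentioned above, such as smallest magnitude, would follow the same shape with a different key passed to argmax or argmin.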

I know you asked this in the context of eigenvector normalization, and fwiw I've always had to normalize them myself in whichever way suits the current problem and have never needed their derivatives except to compute higher-order derivatives of eigenvalues. Sorry I can't be more help there.

> The problem with degeneracies is harder. In one of my packages, I purposefully add small noise to avoid symmetries that result in degeneracies, but that's obviously a workaround. Here's a paper that could provide some indication on how this could be handled ("generalized gradient"), but I don't really understand it well: https://oaktrust.library.tamu.edu/bitstream/handle/1969.1/184232/document-1.pdf?sequence=1

Jax doesn't support subgradients at all does it? E.g., grad(abs)(0.)==1 even though the subdifferential there is the entire closed interval [-1, 1].

> Jax doesn't support subgradients at all does it? E.g., grad(abs)(0.)==1 even though the subdifferential there is the entire closed interval [-1, 1].

Ohh I see what this is about. Yeah I wouldn't expect this to be something that will be supported in jax. By the way, #3112 and #3114 might also be of interest to you.

Hello, just wondering what the status is of the implementation of the np.linalg.eig function in JAX? I am working as a quantum physicist and really like the JAX library; I successfully used it for an optimization problem involving the eigh function in a previous project. For a new project, however, I am dealing with non-Hermitian matrices, so I require the eig function.

@LuukCoopmans np.linalg.eig is implemented but its derivatives are not. Do you need to be able to differentiate eig?

@j-towns yes I need to be able to differentiate it.

Cool, as you can see in the comments above, the derivative for the eigen-_vectors_ is quite awkward to get right because they’re only defined up to ‘gauge’ (that is up to multiplication by a complex scalar with absolute value 1).

@LuukCoopmans sorry to keep quizzing you, but does your objective function depend on the whole output of eig or just on the eigenvalues? The latter might be easier to support.

In the short term you might be interested in using JAX’s custom_jvp and custom_vjp for implementing your own workarounds where we haven’t managed to implement derivatives, like in this case.
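A rough sketch of such a workaround for the eigenvalues only, using jax.custom_jvp. The function name `eigvals_workaround` is invented for this example, and the rule assumes a diagonalizable matrix with distinct eigenvalues. It only supports first derivatives, because the rule itself calls eig with eigenvectors.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


@jax.custom_jvp
def eigvals_workaround(a):
    # Take the eigenvalues from the same call used in the JVP rule so that
    # the primal output and the tangent share one ordering.
    return jnp.linalg.eig(a)[0]


@eigvals_workaround.defjvp
def _eigvals_workaround_jvp(primals, tangents):
    (a,), (da,) = primals, tangents
    w, u = jnp.linalg.eig(a)
    # d(lambda_i) is the i-th diagonal entry of inv(U) @ dA @ U.
    dw = jnp.diagonal(jnp.linalg.solve(u, da.astype(u.dtype) @ u))
    return w, dw


def loss(a):
    w = eigvals_workaround(a)
    return jnp.sum(w ** 2).real  # equals trace(a @ a) for real `a`


a = jnp.array([[1.0, 2.0], [3.0, 4.0]])
g = jax.grad(loss)(a)
print(g)
```

Since the loss equals trace(a @ a), its gradient can be compared against 2 * a.T as a check of the rule.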

@j-towns actually I find this an interesting problem. In physics the quantum wavefunction (an eigenvector) is always defined up to a 'gauge' in the same way you describe. However, this gauge is usually not important for the calculation of quantities of interest (expectation values) because it gets multiplied out. Say O is a matrix and v is an eigenvector of some other matrix O'; then we are interested in quantities like v.T.conj() @ O @ v. Also, in my case I eventually take the absolute value squared of the eigenvector, so the phase is not important. I can however see that this might cause a problem for the derivative, because the gauge on the eigenvector and on its derivative can come back different, if I am correct?
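The phase invariance of such quantities is easy to check numerically. This is a sketch with arbitrary illustrative values for O and v, and it only addresses the unit-modulus phase, not the overall scale of the eigenvector.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def expectation(o, v):
    """Compute v^dagger O v; jnp.vdot conjugates its first argument."""
    return jnp.vdot(v, o @ v)


o = jnp.array([[1.0, 2.0j], [-2.0j, 0.5]])  # an arbitrary Hermitian matrix
v = jnp.array([0.6, 0.8j])                  # an arbitrary unit-length vector
phase = jnp.exp(1j * 0.9)                   # an arbitrary gauge rotation

value = expectation(o, v)
value_rotated = expectation(o, phase * v)

# The element-wise squared magnitudes are unaffected by the phase as well.
prob = jnp.abs(v) ** 2
prob_rotated = jnp.abs(phase * v) ** 2
```

The phase cancels because it appears once conjugated and once unconjugated in v^dagger O v.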

> However, this gauge is usually not important for the calculation of quantities of interest (expectation values) because it gets multiplied out

This is my experience as well. I've only ever needed eigenvector derivatives in scenarios where the gauge didn't matter to the final calculation. I usually _did_ need a particular magnitude, e.g. normalizing to |v|=1; jax does however easily support differentiating that normalization step, so I'm not sure that really matters for an eig() derivative.

> Also in my case I eventually take the absolute value squared of the eigenvector so the phase is not important.

This quantity isn't uniquely defined either, and it's similar to the gauge problem. Eigenvectors are only unique up to a non-zero constant multiple from the relevant field.

As I noted above in https://github.com/google/jax/issues/2748#issuecomment-627444706, I think np.linalg.eig is rarely the right level at which to calculate derivatives. We have conventions for how to pick the gauge for calculations, but those aren't necessarily consistent with derivatives. I think the problem of calculating reverse mode derivatives of eig may be fundamentally undefined from a mathematical perspective -- there does not necessarily exist a single choice of gauge for which the eig function is entirely continuous.

Instead, we need higher level auto-diff primitives, corresponding to well defined functions that are invariant of gauge. For example, we can calculate derivatives for any matrix-valued function of a Hermitian matrix (see https://github.com/FluxML/Zygote.jl/pull/355). We could add helper functions for calculating these sorts of things, ideally with support for calculating the underlying functions in flexible ways (e.g., using eig internally).
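A sketch of what such a helper might look like for the Hermitian case, built on eigh. The name `hermitian_matrix_function` is hypothetical, and the approach assumes non-degenerate eigenvalues so that the eigh derivative is finite. The result U f(Λ) U^H does not depend on the phases of the eigenvectors.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import expm

jax.config.update("jax_enable_x64", True)


def hermitian_matrix_function(f, a):
    """Apply the scalar function `f` to a Hermitian matrix via U f(Λ) U^H."""
    w, u = jnp.linalg.eigh(a)
    # Scale column j of U by f(w_j), then project back.
    return (u * f(w)) @ u.conj().T


a = jnp.array([[2.0, 0.5, 0.0], [0.5, 1.0, 0.3], [0.0, 0.3, -1.0]])

exp_a = hermitian_matrix_function(jnp.exp, a)

# Differentiate a scalar function of the matrix-valued result.
grad_trace = jax.grad(
    lambda m: jnp.trace(hermitian_matrix_function(jnp.exp, m))
)(a)

reference = expm(a)  # independent reference for the matrix exponential
```

For this symmetric input, both `exp_a` and `grad_trace` can be compared against `reference`, since the derivative of trace(exp(A)) with respect to A is exp(A).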

> Instead, we need higher level auto-diff primitives, corresponding to well defined functions that are invariant of gauge.

That makes sense. Do you think it's still reasonable to support eigenvalue derivatives except on the measure-zero sets where they don't exist (either raising an error or providing a default value in such cases, sort of like how abs is handled)?

> Do you think it's still reasonable to support eigenvalue derivatives except on the measure-zero sets where they don't exist (either raising an error or providing a default value in such cases, sort of like how abs is handled)?

Yes, absolutely!

This is basically what we do currently for eigh. If there are degeneracies, then the derivative with respect to the eigenvectors will be all NaN.

Is there a straightforward way for us to provide eigenvalue derivatives without providing eigenvector derivatives (since this gauge issue only affects evectors afaict)? Do you think we ought to have a primitive which only returns eigenvalues?

> Is there a straightforward way for us to provide eigenvalue derivatives without providing eigenvector derivatives (since this gauge issue only affects evectors afaict)? Do you think we ought to have a primitive which only returns eigenvalues?

:+1: That would definitely solve my problem!

I'm working on a project in which we would like to compute gradients for a function that depends on the eigenvalues of non-Hermitian matrices (but not the eigenvectors). From what I understand, the difficulty lies in computing gradients for the eigenvectors of eig due to ambiguity in the phase.
Would it be possible to implement gradients only for eigvals first (which internally calls eig without computing eigenvectors and only returns the eigenvalues)?

I think this would already cover a large fraction of applications in theoretical/mathematical physics.

Hey @nikikilbertus, this has become straightforward since I wrote that comment because the JAX eig primitive now has kwargs to turn off the computation of eigenvectors. https://github.com/google/jax/pull/4941 should do the trick for you, just make sure your objective function uses jax.numpy.linalg.eigvals, rather than jax.numpy.linalg.eig.

Note that unfortunately second (and higher) derivatives aren't supported, I hope that's good enough to get you somewhere. If you need second derivs that should be possible but might be a bit more tricky to implement.
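For reference, a minimal usage sketch, assuming a JAX version that includes that change. The objective `spectral_objective` and the matrix are made up for illustration; only first derivatives are taken, in line with the limitation noted above.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def spectral_objective(a):
    """Sum of squared eigenvalue magnitudes of a general (non-Hermitian) matrix."""
    w = jnp.linalg.eigvals(a)  # eigenvalues only; no eigenvectors are requested
    return jnp.sum(jnp.abs(w) ** 2)


# A non-symmetric matrix with one real eigenvalue and a complex-conjugate pair.
a = jnp.array([[2.0, 1.0, 0.0], [0.5, 3.0, 1.0], [0.0, -1.0, 1.0]])

g = jax.grad(spectral_objective)(a)
print(g)
```

Because the objective is a symmetric function of the eigenvalues, it is insensitive to the order in which they are returned, which makes it straightforward to compare against finite differences.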

Thanks so much @j-towns this is great! 👏

There is a good post on the gauge problem:
https://re-ra.xyz/Gauge-Problem-in-Automatic-Differentiation/
There is also a related discussion in TensorFlow:
https://github.com/tensorflow/tensorflow/pull/33808
