I tried to use cholesky on complex arrays with vmap, but I get an Unimplemented error:
import jax
import jax.numpy as np
A = np.eye(3).astype(np.complex64)
A_batched = np.stack([A, A])
jax.vmap(np.linalg.cholesky)(A_batched)
RuntimeError: Unimplemented: Complex types are not implemented in Cholesky; got shape c64[2,3,3]:
This isn't a problem if I don't use vmap.
np.linalg.cholesky(A)
DeviceArray([[1.+0.j, 0.+0.j, 0.+0.j],
             [0.+0.j, 1.+0.j, 0.+0.j],
             [0.+0.j, 0.+0.j, 1.+0.j]], dtype=complex64)
Is this a bug, or am I doing something wrong? This looks really strange.
Thanks for the question!
What backend are you using (CPU/GPU/TPU)? (I'm guessing CPU.)
It looks like you're hitting an unimplemented case in XLA's Cholesky, which doesn't support complex types.
I think the reason it only comes up under vmap is that JAX falls back to an XLA-based implementation for batched inputs, rather than using the LAPACK-based one it uses in the unbatched case.
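Until a fix lands, one possible workaround (a sketch, not something proposed in this thread) is to swap `vmap` for `jax.lax.map`, which applies the function to each matrix of the batch in turn. Assuming, as described above, that the LAPACK path is taken whenever `cholesky` sees a single unbatched matrix, this should sidestep the XLA fallback:

```python
import jax
import jax.numpy as jnp

# A batch of two 3x3 complex64 identity matrices, as in the report above.
A = jnp.eye(3, dtype=jnp.complex64)
A_batched = jnp.stack([A, A])

# jax.lax.map lowers to a scan over the leading axis, so inside the loop
# body cholesky is traced on a single unbatched 3x3 matrix rather than on
# a batched operand.
L_batched = jax.lax.map(jnp.linalg.cholesky, A_batched)
print(L_batched.shape, L_batched.dtype)
```

The trade-off is that the matrices are factored one after another instead of in a single batched call, so for large batches this may be slower than a working `vmap`.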
@hawkinsp we could just write a loop in our lapack.pyx bindings to handle batch dimensions, right? I think that's how we handle batch dimensions in other kernels.
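The loop suggested above can be sketched at the Python level as follows. This only illustrates the pattern and is not the actual lapack.pyx code (which is Cython calling LAPACK directly): `batched_cholesky` is a hypothetical name, `np` here is plain NumPy rather than jax.numpy, and NumPy's `np.linalg.cholesky` applied to one 2-D matrix stands in for the per-matrix LAPACK call. All leading dimensions are collapsed into a single batch axis, each matrix is factored in turn, and the result is reshaped back.

```python
import numpy as np


def batched_cholesky(a):
    """Hypothetical sketch: apply an unbatched Cholesky kernel to every
    matrix in the leading (batch) dimensions of `a` with a plain loop."""
    a = np.asarray(a)
    if not np.issubdtype(a.dtype, np.inexact):
        # Integer input: factor in float64, as np.linalg.cholesky would.
        a = a.astype(np.float64)
    *batch_shape, m, n = a.shape
    if m != n:
        raise ValueError(f"expected square matrices, got {m}x{n}")
    # Collapse all batch dimensions into one so a single loop covers them.
    flat = a.reshape(-1, n, n)
    out = np.empty_like(flat)
    for i in range(flat.shape[0]):
        # Stand-in for the per-matrix LAPACK call (potrf).
        out[i] = np.linalg.cholesky(flat[i])
    # Restore the original batch shape.
    return out.reshape(*batch_shape, n, n)


# Usage: a batch of complex64 matrices, mirroring the report above.
A = np.eye(3, dtype=np.complex64)
L = batched_cholesky(np.stack([A, 2 * A]))
print(L.shape, L.dtype)
```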
I'm going to mark this as a "good first issue" because there are plenty of examples in lapack.pyx and lax_linalg.py to pattern-match off of.
(Also I un-assigned myself because I'm not working on this right now.)
Will attempt this!
Thanks for the quick responses, @mattjj @hawkinsp @am-khan! I just tested on GPU, and it is a CPU-only issue. Looking forward to the fix.
@am-khan Oops! Sorry: I didn't see you wanted to work on this. Thanks for volunteering!
PR #1956 fixes this, and also improves the efficiency of batched Cholesky on GPU at the same time.
@am-khan while @hawkinsp got the jump on you for this issue, it looks like there's a very similar TODO for batched triangular solves.
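The same loop pattern should carry over to triangular solves, with the twist that there are two batched operands. A hypothetical sketch, using SciPy's `scipy.linalg.solve_triangular` as the per-matrix kernel and assuming `a` has shape `(..., n, n)` and `b` has shape `(..., n, k)` with identical batch dimensions (no broadcasting between them):

```python
import numpy as np
from scipy.linalg import solve_triangular


def batched_solve_triangular(a, b, lower=True):
    """Hypothetical sketch: solve a[i] @ x[i] = b[i] for each matrix pair
    in the batch by looping over a per-matrix triangular solver."""
    a = np.asarray(a)
    b = np.asarray(b)
    if a.shape[:-2] != b.shape[:-2]:
        raise ValueError("a and b must share the same batch dimensions")
    n = a.shape[-1]
    # Collapse batch dimensions of both operands into one axis.
    flat_a = a.reshape(-1, n, n)
    flat_b = b.reshape(-1, *b.shape[-2:])
    out = [
        solve_triangular(flat_a[i], flat_b[i], lower=lower)
        for i in range(flat_a.shape[0])
    ]
    # Restore the batch shape of the right-hand side.
    return np.stack(out).reshape(b.shape)
```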
PR #1956 fixes this, but it requires a jaxlib rebuild (or waiting until we push new jaxlib wheels). Hope that helps!
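Once a jaxlib with the fix is installed, a quick sanity check might look like the sketch below. It builds a batch of random Hermitian positive-definite complex64 matrices (X X^H + 3I), then checks that the vmapped Cholesky agrees with per-matrix calls and that L L^H reconstructs the input. The batch size, seed, and tolerances are arbitrary choices, not values taken from the PR.

```python
import jax
import jax.numpy as jnp

key_re, key_im = jax.random.split(jax.random.PRNGKey(0))
shape = (4, 3, 3)
X = (jax.random.normal(key_re, shape)
     + 1j * jax.random.normal(key_im, shape)).astype(jnp.complex64)

# X X^H + 3I is Hermitian positive definite, so Cholesky is well defined.
X_H = jnp.conj(jnp.swapaxes(X, -1, -2))
A_batched = X @ X_H + 3 * jnp.eye(3, dtype=jnp.complex64)

# Batched path (the one that used to fail for complex inputs on CPU)...
L_vmap = jax.vmap(jnp.linalg.cholesky)(A_batched)
# ...versus one unbatched call per matrix.
L_loop = jnp.stack([jnp.linalg.cholesky(a) for a in A_batched])

agree = jnp.allclose(L_vmap, L_loop, atol=1e-5)
L_H = jnp.conj(jnp.swapaxes(L_vmap, -1, -2))
reconstructs = jnp.allclose(L_vmap @ L_H, A_batched, rtol=1e-4, atol=1e-4)
print(bool(agree), bool(reconstructs))
```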
Thanks @mattjj! Is there a corresponding issue, or can I just go ahead?
You can just go ahead, though if you feel the same amount of satisfaction from closing issues that I do, feel free to open one and mention that you're working on it :)