Jax: cholesky not working with vmap on complex arrays

Created on 6 Jan 2020 · 9 Comments · Source: google/jax

I tried to use cholesky on complex arrays with vmap, but I get an Unimplemented error.

import jax
import jax.numpy as np

A = np.eye(3).astype(np.complex64)
A_batched = np.stack([A, A])
jax.vmap(np.linalg.cholesky)(A_batched)

RuntimeError: Unimplemented: Complex types are not implemented in Cholesky; got shape c64[2,3,3]: 

This isn't a problem if I don't use vmap.

np.linalg.cholesky(A)
DeviceArray([[1.+0.j, 0.+0.j, 0.+0.j],
             [0.+0.j, 1.+0.j, 0.+0.j],
             [0.+0.j, 0.+0.j, 1.+0.j]], dtype=complex64)

Is this a bug, or am I doing something wrong? This looks really strange.


Thanks for the question!

What backend are you using (CPU/GPU/TPU)? (I'm guessing CPU.)
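A quick way to answer the backend question is to ask JAX directly. This is a minimal sketch: `jax.default_backend()` and `jax.devices()` exist in current JAX releases, but the versions from the time of this issue may not provide `jax.default_backend()`.

```python
import jax

# Name of the platform JAX uses by default, e.g. "cpu", "gpu", or "tpu".
print(jax.default_backend())
# The devices JAX can see on that default platform.
print(jax.devices())
```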

It looks like you're hitting an unimplemented error in XLA.

I think the reason it only comes up when you use vmap is that, in the batched case, JAX falls back to an XLA-based implementation rather than the LAPACK-based one used in the unbatched case.
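Since the report says the unbatched call works, one possible workaround until a fix lands is to avoid `vmap` and apply the unbatched function to each matrix in turn. The sketch below uses `jax.lax.map`, which runs the function over the leading axis sequentially (via `lax.scan`), so each call sees a single `(n, n)` matrix. The helper name `batched_cholesky` is hypothetical, and whether this routes to the LAPACK path on a given jaxlib version is an assumption, not something confirmed in this thread.

```python
import jax
import jax.numpy as jnp

def batched_cholesky(a):
    # Hypothetical helper: apply the unbatched Cholesky to each
    # leading-axis slice in turn instead of vectorizing with vmap.
    return jax.lax.map(jnp.linalg.cholesky, a)

A = jnp.eye(3, dtype=jnp.complex64)
A_batched = jnp.stack([A, A])
L = batched_cholesky(A_batched)
print(L.shape)  # (2, 3, 3)
```

The trade-off is speed: `lax.map` processes the batch one matrix at a time instead of in a single batched kernel call.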

@hawkinsp we could just write a loop in our lapack.pyx bindings to handle batch dimensions, right? I think that's how we handle batch dimensions in other kernels.
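The "loop over batch dimensions" idea can be illustrated in plain NumPy. This is a hypothetical sketch of the pattern, not the actual `lapack.pyx` code: `batched_apply` is an illustrative name, and `kernel` stands in for a single-matrix LAPACK routine. It collapses all leading batch dimensions into one axis, calls the kernel once per matrix, and restores the original batch shape.

```python
import numpy as np

def batched_apply(kernel, a):
    """Apply an unbatched (n, n) -> (n, n) kernel over all leading batch dims.

    Hypothetical illustration; assumes a non-empty batch.
    """
    a = np.asarray(a)
    # Collapse every leading batch dimension into a single axis.
    flat = a.reshape((-1,) + a.shape[-2:])
    # One kernel call per matrix, then stack the results back together.
    out = np.stack([kernel(mat) for mat in flat])
    # Restore the original batch dimensions.
    return out.reshape(a.shape[:-2] + out.shape[-2:])

A = np.eye(3, dtype=np.complex64)
A_batched = np.stack([A, A])
L = batched_apply(np.linalg.cholesky, A_batched)
print(L.shape)  # (2, 3, 3)
```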

I'm going to mark this as a "good first issue" because there are plenty of examples in lapack.pyx and lax_linalg.py to pattern-match off of.

(Also I un-assigned myself because I'm not working on this right now.)

Will attempt this!

Thanks for the quick responses! @mattjj @hawkinsp @am-khan I just tested on GPU, and this is a CPU-only issue. Looking forward to the fix.

@am-khan Oops! Sorry: I didn't see you wanted to work on this. Thanks for volunteering!

PR #1956 fixes this, and also improves the efficiency of batched Cholesky on GPU at the same time.

@am-khan while @hawkinsp got the jump on you for this issue, it looks like there's a very similar TODO for batched triangular solves.

PR #1956 fixes this, but it requires a jaxlib rebuild (or waiting until we push new jaxlib wheels). Hope that helps!

Thanks @mattjj, is there a corresponding issue, or can I just go ahead?

You can just go ahead, though if you feel the same amount of satisfaction from closing issues that I do, feel free to open one and mention that you're working on it :)
