Dear jax team,
After the recent improvements to linalg.solve (thanks again!), I found that it has some issues with autodiff. I have not dug into the error messages yet, but they sound more like a minor bug in shape handling than a fundamental limitation.
```python
import jax
import jax.numpy as np
import numpy as onp

@jax.jit
def solve(A, b):
    return np.linalg.solve(A, b)

def to_gpu(arr):
    # Move a NumPy array to the default device as float32.
    return jax.device_put(arr.astype(onp.float32))

# A batch of 100 independent 3x3 systems A[i] @ x[i] = b[i].
A = to_gpu(onp.random.rand(100, 3, 3))
b = to_gpu(onp.random.rand(100, 3))
x = solve(A, b)
assert onp.allclose(jax.vmap(np.dot)(A, x), b, atol=1e-4, rtol=1e-4)

print("# BATCHED")
try:
    jac0 = jax.jacobian(solve, argnums=0)(A, b)  # error
except Exception as e:
    print(e)
    # triangular_solve requires both arguments to have the same number of dimensions and equal batch dimensions, got (100, 3, 3) and (100, 100, 3, 3)
try:
    jac1 = jax.jacobian(solve, argnums=1)(A, b)  # error
except Exception as e:
    print(e)
    # 'Zero' object is not subscriptable

print("# SINGLE")
try:
    jac0 = jax.jacobian(solve, argnums=0)(A[0], b[0])  # fine
except Exception as e:
    print(e)
try:
    jac1 = jax.jacobian(solve, argnums=1)(A[0], b[0])  # error
except Exception as e:
    print(e)
    # 'Zero' object is not subscriptable
```
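For reference, the single-system Jacobians have simple closed forms that a fixed version should reproduce. Differentiating `A x = b` gives `dx = A^{-1} (db - dA x)`, so `dx/db = A^{-1}` (shape `(3, 3)`) and `dx_i/dA_jk = -(A^{-1})_ij x_k` (shape `(3, 3, 3)`). The sketch below checks both, assuming a current JAX release; the seed, the variable names, and the `3 * I` shift that keeps the random matrix well conditioned are illustrative choices, not part of the original report.

```python
import jax
import jax.numpy as jnp
import numpy as onp

def solve(A, b):
    return jnp.linalg.solve(A, b)

rng = onp.random.RandomState(0)
# Illustrative: add 3*I so the float32 checks are not dominated by conditioning.
A0 = jnp.asarray(rng.rand(3, 3) + 3 * onp.eye(3), dtype=jnp.float32)
b0 = jnp.asarray(rng.rand(3), dtype=jnp.float32)

x0 = solve(A0, b0)
A_inv = jnp.linalg.inv(A0)

# dx = A^{-1} db  =>  dx/db = A^{-1}, shape (3, 3).
jac_b = jax.jacobian(solve, argnums=1)(A0, b0)

# dx = -A^{-1} dA x  =>  dx_i/dA_jk = -(A^{-1})_ij x_k, shape (3, 3, 3).
jac_A = jax.jacobian(solve, argnums=0)(A0, b0)
expected_jac_A = -jnp.einsum("ij,k->ijk", A_inv, x0)

print(jac_b.shape, jac_A.shape)
print(jnp.allclose(jac_b, A_inv, atol=1e-4))
print(jnp.allclose(jac_A, expected_jac_A, atol=1e-4))
```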
PR https://github.com/google/jax/pull/1152 should fix this.
OK, please try it out! Until we make another PyPI release, you'll need to install the Python part of jax from GitHub head, i.e.:

```shell
git clone https://github.com/google/jax.git
pip install -e jax
```
My code example now works for me :+1: Thank you very much!
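On the batched case: the full Jacobian of a batched solve with respect to `b` has shape `(batch, 3, batch, 3)` and is block diagonal, because each system depends only on its own `A[i]` and `b[i]`. When only per-system derivatives are needed, mapping `jax.jacobian` with `jax.vmap` returns just the diagonal blocks. A minimal sketch, assuming a current JAX release; it batches a single-system solve with `jax.vmap` rather than relying on how `jnp.linalg.solve` broadcasts a batched right-hand side, and the batch size of 4 and the `3 * I` shift are illustrative.

```python
import jax
import jax.numpy as jnp
import numpy as onp

def solve(A, b):
    # Solve a single 3x3 system; batching is added explicitly with vmap below.
    return jnp.linalg.solve(A, b)

batch = 4  # illustrative batch size (the report used 100)
rng = onp.random.RandomState(0)
# Adding 3*I keeps the random matrices well conditioned for float32 checks.
A = jnp.asarray(rng.rand(batch, 3, 3) + 3 * onp.eye(3), dtype=jnp.float32)
b = jnp.asarray(rng.rand(batch, 3), dtype=jnp.float32)

batched_solve = jax.vmap(solve)

# Full Jacobian of the batched map: output (batch, 3) w.r.t. b (batch, 3).
full_jac_b = jax.jacobian(batched_solve, argnums=1)(A, b)
print(full_jac_b.shape)  # (4, 3, 4, 3)

# Per-system Jacobians: one block per system, no cross-system zeros stored.
per_example_jac_b = jax.vmap(jax.jacobian(solve, argnums=1))(A, b)
per_example_jac_A = jax.vmap(jax.jacobian(solve, argnums=0))(A, b)
print(per_example_jac_b.shape)  # (4, 3, 3)
print(per_example_jac_A.shape)  # (4, 3, 3, 3)

# The diagonal blocks of the full Jacobian match the per-system Jacobians.
diag_blocks = jnp.stack([full_jac_b[n, :, n, :] for n in range(batch)])
print(jnp.allclose(diag_blocks, per_example_jac_b, atol=1e-6))

# Masking out the diagonal blocks leaves only cross-system terms, all zero.
mask = jnp.eye(batch)[:, None, :, None]  # shape (batch, 1, batch, 1)
off_diag = full_jac_b * (1.0 - mask)
print(jnp.allclose(off_diag, 0.0))

# Each per-system Jacobian w.r.t. b is simply inv(A[i]).
print(jnp.allclose(per_example_jac_b, jnp.linalg.inv(A), atol=1e-4))
```

The `vmap(jacobian(...))` form stores `batch * 9` entries for `b` instead of `batch**2 * 9`, which matters at the report's batch size of 100.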