Jax: No autodiff on linalg.solve

Created on 9 Aug 2019 · 3 Comments · Source: google/jax

Dear jax team,

After the recent improvements to linalg.solve (thanks again!), I found that it has some issues when being autodiff'ed. I have not yet dug into the error messages, but they sound more like a minor shape-handling bug than a fundamental limitation.
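For reference, here are the derivatives that `jax.jacobian` should produce for `solve`. They follow from differentiating A x = b, which is a standard identity and not something taken from the report:

```latex
\begin{aligned}
A x = b \;&\Rightarrow\; dA\,x + A\,dx = db \;\Rightarrow\; dx = A^{-1}\,(db - dA\,x),\\
\frac{\partial x_i}{\partial b_j} &= (A^{-1})_{ij}, \qquad
\frac{\partial x_i}{\partial A_{kl}} = -(A^{-1})_{ik}\,x_l .
\end{aligned}
```

For the batched inputs below (A of shape (100, 3, 3), b of shape (100, 3)), `jax.jacobian` returns an array of shape output shape + input shape. That is (100, 3, 100, 3, 3) with respect to A and (100, 3, 100, 3) with respect to b. Because the systems are independent, both are zero outside the diagonal batch blocks.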

import jax
import jax.numpy as np
import numpy as onp

@jax.jit
def solve(A, b):
    return np.linalg.solve(A, b)

def to_gpu(arr):
    return jax.device_put(arr.astype(onp.float32))

A = to_gpu(onp.random.rand(100, 3, 3))
b = to_gpu(onp.random.rand(100, 3))

x = solve(A, b)
assert onp.allclose(jax.vmap(np.dot)(A, x), b, atol=1e-4, rtol=1e-4)

print("# BATCHED")
try:
    jac0 = jax.jacobian(solve, argnums=0)(A, b)  # error
except Exception as e:
    print(e)
# triangular_solve requires both arguments to have the same number of dimensions and equal batch dimensions, got (100, 3, 3) and (100, 100, 3, 3)

try:
    jac1 = jax.jacobian(solve, argnums=1)(A, b)  # error
except Exception as e:
    print(e)
# 'Zero' object is not subscriptable

print("# SINGLE")
try:
    jac0 = jax.jacobian(solve, argnums=0)(A[0], b[0])  # fine
except Exception as e:
    print(e)

try:
    jac1 = jax.jacobian(solve, argnums=1)(A[0], b[0])  # error
except Exception as e:
    print(e)
# 'Zero' object is not subscriptable
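As a sanity check independent of JAX, the sketch below compares the analytic Jacobians of a single solve against central finite differences of `numpy.linalg.solve`. It then assembles the block-diagonal Jacobian with respect to A for a small batch, so you can see the shape `jax.jacobian(solve, argnums=0)` should return. The helper names (`solve_jacobians`, `finite_difference_jacobians`, `batched_jacobian_A`) are hypothetical and used only for illustration. The example uses float64 and a diagonally shifted random A (an assumption, chosen to keep the system well-conditioned) rather than the float32 GPU arrays from the report.

```python
import numpy as onp

def solve_jacobians(A, b):
    """Analytic Jacobians of x = solve(A, b) for one square system (hypothetical helper).

    Returns J_A with J_A[i, k, l] = dx_i/dA_kl and J_b with J_b[i, j] = dx_i/db_j.
    """
    Ainv = onp.linalg.inv(A)
    x = Ainv @ b
    J_b = Ainv                                  # dx/db = A^{-1}
    J_A = -onp.einsum("ik,l->ikl", Ainv, x)     # dx_i/dA_kl = -(A^{-1})_ik x_l
    return J_A, J_b

def finite_difference_jacobians(A, b, eps=1e-6):
    """Central finite differences of onp.linalg.solve, for comparison (hypothetical helper)."""
    n = b.shape[0]
    J_A = onp.zeros((n, n, n))
    J_b = onp.zeros((n, n))
    for k in range(n):
        for l in range(n):
            dA = onp.zeros_like(A)
            dA[k, l] = eps
            J_A[:, k, l] = (onp.linalg.solve(A + dA, b) - onp.linalg.solve(A - dA, b)) / (2 * eps)
    for j in range(n):
        db = onp.zeros_like(b)
        db[j] = eps
        J_b[:, j] = (onp.linalg.solve(A, b + db) - onp.linalg.solve(A, b - db)) / (2 * eps)
    return J_A, J_b

def batched_jacobian_A(A, b):
    """Full Jacobian of a batched solve w.r.t. A: shape (batch, n, batch, n, n).

    Systems are independent, so only the diagonal batch blocks are nonzero.
    """
    batch, n = b.shape
    J = onp.zeros((batch, n, batch, n, n))
    for i in range(batch):
        J[i, :, i] = solve_jacobians(A[i], b[i])[0]
    return J

rng = onp.random.default_rng(0)
A = rng.random((3, 3)) + 3 * onp.eye(3)   # diagonal shift keeps A well-conditioned
b = rng.random(3)

J_A, J_b = solve_jacobians(A, b)
J_A_fd, J_b_fd = finite_difference_jacobians(A, b)
print(onp.allclose(J_A, J_A_fd, atol=1e-6), onp.allclose(J_b, J_b_fd, atol=1e-6))

A_batch = rng.random((4, 3, 3)) + 3 * onp.eye(3)
b_batch = rng.random((4, 3))
J_batched = batched_jacobian_A(A_batch, b_batch)
print(J_batched.shape)
```

In JAX terms, the per-system blocks correspond to `jax.vmap(jax.jacobian(solve, argnums=0))(A, b)`, which avoids materializing the mostly-zero full Jacobian.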

All 3 comments

OK, please try it out! Until we make another PyPI release, you'll need to use the Python part of JAX from GitHub head, i.e.:

git clone https://github.com/google/jax.git
pip install -e jax

My code example now works for me :+1: Thank you very much!
