Currently, the following script

```python
from jax.scipy.linalg import solve_triangular
import jax.numpy as jnp

dim = 130
x = jnp.broadcast_to(jnp.eye(dim), (2, dim, dim))
solve_triangular(x, x)  # raises the error below
```

raises the following error:

```
RuntimeError: Invalid argument: The rank of the operand and the padding configuration do not match: f32[126,126] vs dimensions { } dimensions { edge_padding_low: 2 } dimensions { }.
```
This does not fail when `dim <= 128` (… `(2,)` in the above script).

I think this is a duplicate of https://github.com/google/jax/issues/4773 and is already fixed at jaxlib head. (We're also in the process of pushing new wheels, so hopefully the deployed wheels will include the fix as well.)
Please reopen if that turns out not to be true!
Oh, thanks @hawkinsp! What a coincidence. :)
I think jaxlib 0.1.57 may already be available on PyPI, so try `pip install --upgrade jaxlib` and see if the issue persists.
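After upgrading, one way to confirm which jaxlib you actually have is a small version check. The sketch below assumes, based on this thread, that 0.1.57 is the first release containing the fix; the helper names `parse_version` and `jaxlib_has_fix` are hypothetical and not part of JAX.

```python
import re

# Assumption from this thread: jaxlib 0.1.57 is the first release with the fix.
MIN_JAXLIB = (0, 1, 57)


def parse_version(version: str) -> tuple:
    """Extract (major, minor, patch) from a string such as '0.1.57'
    or '0.1.57+cuda110'; any trailing local suffix is ignored."""
    match = re.match(r"(\d+)\.(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    return tuple(int(part) for part in match.groups())


def jaxlib_has_fix(version: str, minimum: tuple = MIN_JAXLIB) -> bool:
    """Return True if `version` is at least `minimum` (tuple comparison)."""
    return parse_version(version) >= minimum
```

To check the installed package, pass it `importlib.metadata.version("jaxlib")` (Python 3.8+), which reads the version of the installed distribution without importing jaxlib itself.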
Thanks, Matt! I can confirm that the issue has been resolved.