I often hit messages like
"Singular value decomposition is only implemented on the CPU backend"
(https://github.com/google/jax/blob/27746b8c73f9ca9928da5da40b7382ae648a5f8d/jax/lax_linalg.py)
for many `jax.numpy.linalg` ops. Examples I've hit so far include:

- `jax.numpy.linalg.svd`
- `jax.numpy.linalg.eigh`

It would be nice to be able to run these on GPU.
Thanks!
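For reference, a minimal sketch of the kind of code that triggers this for me (the setup is illustrative, assuming a JAX install with the GPU backend active):

```python
# Illustrative repro sketch: on the GPU backend, each of these calls
# raises NotImplementedError (assumed setup, not taken verbatim from a session).
import jax.numpy as np

x = np.eye(4)
u, s, vt = np.linalg.svd(x)   # "Singular value decomposition is only
                              # implemented on the CPU backend"
w, v = np.linalg.eigh(x)      # analogous error for eigh
```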
Thanks for raising this. Given your specific request here, what do you think about changing the issue title to a request for linalg support on GPU? That's probably easier for us to add.
Done, thank you!
Can `jax.numpy.linalg.inv` run on the GPU?
As it happens, `jax.numpy.linalg.inv` should run on GPU right now! (It might not be terribly fast, since it uses a QR decomposition instead of an LU decomposition, and one that isn't necessarily that well tuned.)
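For the curious, here's a rough sketch of what a QR-based inverse looks like (a simplified illustration, not JAX's actual source, and assuming `jax.scipy.linalg.solve_triangular` is available in your version): for invertible real A, A = QR with Q orthogonal and R upper triangular, so A^-1 = R^-1 Q^T, which is a single triangular solve.

```python
import jax.numpy as np
from jax.scipy.linalg import solve_triangular

def inv_via_qr(a):
    # A = Q R  =>  A^-1 = R^-1 Q^T. Solving the upper-triangular
    # system R X = Q^T avoids forming R^-1 explicitly.
    # (Assumes a is square, invertible, and real-valued.)
    q, r = np.linalg.qr(a)
    return solve_triangular(r, q.T, lower=False)
```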
Another thing missing is `np.linalg.solve`:
```python
solve = np.linalg.solve(np.eye(3), np.ones(3))
```

```
Traceback (most recent call last):
  File "/opt/intellij-ue-2019.1/plugins/python/helpers/pydev/_pydevd_bundle/pydevd_exec2.py", line 3, in Exec
    exec(exp, global_vars, local_vars)
  File "<input>", line 1, in <module>
  File "/usr/local/google/_blaze_romann/513cef43ffae8d7478c0c7058e6a84e4/execroot/google3/blaze-out/k8-cuda9-py3-opt/bin/experimental/users/romann/ntk_tuner/train_and_eval.runfiles/google3/third_party/py/jax/numpy/linalg.py", line 237, in solve
    lu, pivots = lax_linalg.lu(a)
  File "/usr/local/google/_blaze_romann/513cef43ffae8d7478c0c7058e6a84e4/execroot/google3/blaze-out/k8-cuda9-py3-opt/bin/experimental/users/romann/ntk_tuner/train_and_eval.runfiles/google3/third_party/py/jax/lax_linalg.py", line 53, in lu
    lu, pivots = lu_p.bind(x)
  File "/usr/local/google/_blaze_romann/513cef43ffae8d7478c0c7058e6a84e4/execroot/google3/blaze-out/k8-cuda9-py3-opt/bin/experimental/users/romann/ntk_tuner/train_and_eval.runfiles/google3/third_party/py/jax/core.py", line 136, in bind
    return self.impl(*args, **kwargs)
  File "/usr/local/google/_blaze_romann/513cef43ffae8d7478c0c7058e6a84e4/execroot/google3/blaze-out/k8-cuda9-py3-opt/bin/experimental/users/romann/ntk_tuner/train_and_eval.runfiles/google3/third_party/py/jax/lax_linalg.py", line 358, in lu_impl
    lu, pivot = xla.apply_primitive(lu_p, operand)
  File "/usr/local/google/_blaze_romann/513cef43ffae8d7478c0c7058e6a84e4/execroot/google3/blaze-out/k8-cuda9-py3-opt/bin/experimental/users/romann/ntk_tuner/train_and_eval.runfiles/google3/third_party/py/jax/interpreters/xla.py", line 52, in apply_primitive
    compiled_fun = xla_primitive_callable(prim, *abstract_args, **params)
  File "/usr/local/google/_blaze_romann/513cef43ffae8d7478c0c7058e6a84e4/execroot/google3/blaze-out/k8-cuda9-py3-opt/bin/experimental/users/romann/ntk_tuner/train_and_eval.runfiles/google3/third_party/py/jax/util.py", line 174, in memoized_fun
    ans = cache[key] = fun(*args, **kwargs)
  File "/usr/local/google/_blaze_romann/513cef43ffae8d7478c0c7058e6a84e4/execroot/google3/blaze-out/k8-cuda9-py3-opt/bin/experimental/users/romann/ntk_tuner/train_and_eval.runfiles/google3/third_party/py/jax/interpreters/xla.py", line 58, in xla_primitive_callable
    built_c = primitive_computation(prim, *shapes, **params)
  File "/usr/local/google/_blaze_romann/513cef43ffae8d7478c0c7058e6a84e4/execroot/google3/blaze-out/k8-cuda9-py3-opt/bin/experimental/users/romann/ntk_tuner/train_and_eval.runfiles/google3/third_party/py/jax/util.py", line 174, in memoized_fun
    ans = cache[key] = fun(*args, **kwargs)
  File "/usr/local/google/_blaze_romann/513cef43ffae8d7478c0c7058e6a84e4/execroot/google3/blaze-out/k8-cuda9-py3-opt/bin/experimental/users/romann/ntk_tuner/train_and_eval.runfiles/google3/third_party/py/jax/interpreters/xla.py", line 69, in primitive_computation
    xla_result = translation_rule(prim)(c, *xla_args, **params)
  File "/usr/local/google/_blaze_romann/513cef43ffae8d7478c0c7058e6a84e4/execroot/google3/blaze-out/k8-cuda9-py3-opt/bin/experimental/users/romann/ntk_tuner/train_and_eval.runfiles/google3/third_party/py/jax/lax_linalg.py", line 363, in lu_translation_rule
    "LU decomposition is only implemented on the CPU backend")
NotImplementedError: LU decomposition is only implemented on the CPU backend
```
There is now an LU decomposition implementation that works on GPU. However, it may not be the most performant, since it's implemented in JAX itself rather than calling a native kernel. We would still do well to add a specialized GPU implementation that calls cuSolver or MAGMA.
(FYI, slight overlap with https://github.com/google/jax/issues/723)
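With the GPU LU path in place, the `np.linalg.solve` example above should now run. Here's a hedged sketch of the same LU route spelled out explicitly (assuming the `jax.scipy.linalg.lu_factor` / `lu_solve` wrappers are available in your version):

```python
import jax.numpy as np
from jax.scipy.linalg import lu_factor, lu_solve

a = np.eye(3)
b = np.ones(3)

# np.linalg.solve(a, b) goes through an LU factorization; conceptually:
lu_and_piv = lu_factor(a)      # factor P A = L U
x = lu_solve(lu_and_piv, b)    # two triangular solves (forward/back)
```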