Jax: Compatibility with original numpy and scipy

Created on 15 Oct 2019 · 3 Comments · Source: google/jax

JAX's DeviceArray is not fully compatible with the original NumPy array.

When I pass DeviceArrays and gradient functions to the original scipy.optimize.minimize, it works with some of the minimizers but not with others (BFGS works, L-BFGS-B does not). It seems that DeviceArray breaks compatibility with SciPy.

I could wrap the results with onp.asarray(). Does this create significant overhead? If so, what's the most efficient way to achieve maximum compatibility with the original NumPy and SciPy? Since you do not plan to implement all of SciPy, it would be great if I could use JAX with the original SciPy at low cost.

Thanks!

All 3 comments

I've also noticed that SciPy's L-BFGS-B can't handle DeviceArray objects.

But this isn't really an issue with JAX. SciPy needs to be updated internally to call asarray() on arrays returned from functions being optimized.

Also, to be clear, we are interested in implementing a much larger portion of SciPy, including quasi-Newton optimization methods: https://github.com/google/jax/issues/1400
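For readers finding this thread later: JAX releases after this discussion ship `jax.scipy.optimize.minimize`, whose documented method is BFGS. A minimal sketch, assuming a JAX version that includes this module; `quadratic` is a made-up test function. The whole optimization stays in JAX, so the DeviceArray/NumPy conversion question does not arise.

````python
import jax.numpy as jnp
from jax.scipy.optimize import minimize

def quadratic(x):
    # Convex test function whose minimum is at x = 3 in every coordinate.
    return jnp.sum((x - 3.0) ** 2)

x0 = jnp.zeros(5)
# Gradients are computed internally with JAX autodiff, so no jac argument is
# passed, and no conversion to NumPy arrays is involved.
result = minimize(quadratic, x0, method="BFGS")
print(result.x, result.success)
````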

@shoyer Thanks for the reply! It would be really exciting to have that implemented.

In case anyone is interested: to make L-BFGS-B work for now, you need to wrap the Jacobian DeviceArray with onp.array() to make a copy. With onp.asarray(), SciPy complains that the array is not Fortran contiguous, which is a misleading error. The actual problem is that the array is read-only: https://github.com/pybind/pybind11/issues/935
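The read-only flag behind that misleading error can be inspected directly. A minimal sketch, assuming a recent JAX where `jax.numpy` arrays play the role of `DeviceArray`; `to_scipy` is a hypothetical helper name, not part of JAX or SciPy.

````python
import numpy as onp
import jax.numpy as jnp

g = jnp.arange(4.0)  # stand-in for a gradient returned by jax.grad

view = onp.asarray(g)  # NumPy wrapper around JAX's host buffer, marked read-only
copy = onp.array(g)    # onp.array copies by default, giving an ordinary array

print(view.flags.writeable)  # False
print(copy.flags.writeable)  # True

def to_scipy(a):
    # Hypothetical helper: return a writable float64 copy, the kind of array
    # the Fortran-backed L-BFGS-B code in SciPy accepts without complaint.
    return onp.array(a, dtype=onp.float64)

out = to_scipy(g)
print(out.dtype, out.flags.writeable)  # float64 True
````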

As for the overhead: not bad, but not negligible.
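````python
# Run in IPython/Jupyter: %timeit is an IPython magic.
import numpy as onp
import jax.numpy as np
from jax import grad, jit

x = np.ones(100)
myfunc = lambda x: np.sum(x**2)
gfunc = grad(myfunc)
jit_gfunc = jit(gfunc)
jit_gfunc_np1 = lambda x: onp.array(jit_gfunc(x))    # explicit writable copy
jit_gfunc_np2 = lambda x: onp.asarray(jit_gfunc(x))  # read-only NumPy array

%timeit gfunc(x)
%timeit jit_gfunc(x)
%timeit jit_gfunc_np1(x)
%timeit jit_gfunc_np2(x)
````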
````
100 loops, best of 3: 3.59 ms per loop
1000 loops, best of 3: 288 µs per loop
1000 loops, best of 3: 412 µs per loop
1000 loops, best of 3: 417 µs per loop
````
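One caveat when reading these numbers: JAX dispatches work asynchronously, so `%timeit jit_gfunc(x)` may partly measure dispatch, while `onp.array(...)` must wait for the result. Part of the 288 µs vs 412 µs gap may therefore be computation time rather than copy cost. A sketch of a fairer comparison using the standard `timeit` module and `block_until_ready()`; the function names are illustrative, and the timings will vary by machine.

````python
import timeit
import numpy as onp
import jax
import jax.numpy as jnp

x = jnp.ones(100)
gfunc = jax.jit(jax.grad(lambda x: jnp.sum(x ** 2)))
gfunc(x).block_until_ready()  # compile once so compilation is not timed

def device_only():
    # Wait for the asynchronous computation to finish, so the measurement
    # covers the actual work and not just dispatch.
    return gfunc(x).block_until_ready()

def with_copy():
    # onp.array waits for the result too, then copies it into NumPy memory.
    return onp.array(gfunc(x))

n = 1000
t_dev = timeit.timeit(device_only, number=n) / n
t_copy = timeit.timeit(with_copy, number=n) / n
print(f"jit + block_until_ready: {t_dev * 1e6:.1f} µs per call")
print(f"jit + onp.array:         {t_copy * 1e6:.1f} µs per call")
````

The difference between the two lines approximates the cost of the conversion itself.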

For the record, since it took me a while to get the proposed workaround working: it seems that one also needs to switch to double precision to use L-BFGS-B with JAX (see also https://github.com/google/jax/issues/936 and https://github.com/scipy/scipy/issues/5832). Here is a short self-contained use case:

````python
import jax
import jax.numpy as jnp
import numpy as onp
import scipy.optimize

# This needs to run at startup, before any JAX arrays are created
# https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#Double-(64bit)-precision
jax.config.update('jax_enable_x64', True)

def run(np, optname):
    print(f"\nRunning {optname} on {np.__name__}")
    def rosen(x):
        return np.sum(100.0 * (x[1:] - x[:-1] ** 2.0) ** 2.0 + (1 - x[:-1]) ** 2.0)

    # double precision is needed for L-BFGS-B and probably other optimizers
    # https://github.com/google/jax/issues/936
    # https://github.com/scipy/scipy/issues/5832
    x0 = np.array([1.3, 0.7, 0.8, 1.9, 1.2], dtype='float64')

    optopt = {'disp': True}

    if np is onp:
        return scipy.optimize.minimize(rosen, x0, method=optname, options=optopt)

    g_rosen = jax.grad(rosen)
    g_rosen_used = g_rosen
    if optname == 'L-BFGS-B':
        # L-BFGS-B needs the gradient copied out of JAX into a writable array
        # https://github.com/google/jax/issues/1510
        # asarray is not sufficient:
        # g_rosen_as_np = lambda x: onp.asarray(g_rosen_jit(x))
        g_rosen_jit = jax.jit(g_rosen)  # compile once, outside the callback
        g_rosen_used = lambda x: onp.array(g_rosen_jit(x))

    return scipy.optimize.minimize(rosen, x0, jac=g_rosen_used, method=optname, options=optopt)


print(run(onp, 'BFGS').x)
print(run(jnp, 'BFGS').x)
print(run(onp, 'L-BFGS-B').x)
print(run(jnp, 'L-BFGS-B').x)
````
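A variant of the same workaround: `scipy.optimize.minimize` accepts `jac=True`, meaning the objective returns `(value, gradient)` together, and `jax.value_and_grad` produces exactly that in one compiled call instead of evaluating `rosen` and its gradient separately. A minimal sketch under the same double-precision assumption; `rosen_and_grad` and `fun_and_jac` are illustrative names.

````python
import jax
import jax.numpy as jnp
import numpy as onp
import scipy.optimize

# Same requirement as above: enable 64-bit mode before creating any arrays.
jax.config.update("jax_enable_x64", True)

def rosen(x):
    return jnp.sum(100.0 * (x[1:] - x[:-1] ** 2.0) ** 2.0 + (1 - x[:-1]) ** 2.0)

# One compiled function returning both the objective and its gradient.
rosen_and_grad = jax.jit(jax.value_and_grad(rosen))

def fun_and_jac(x):
    f, g = rosen_and_grad(x)
    # Hand SciPy a Python float and a writable float64 NumPy copy.
    return float(f), onp.array(g, dtype=onp.float64)

x0 = onp.array([1.3, 0.7, 0.8, 1.9, 1.2])
# jac=True tells SciPy that fun returns (value, gradient).
res = scipy.optimize.minimize(fun_and_jac, x0, jac=True, method="L-BFGS-B")
print(res.x)
````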