I have not used JAX much.
As I understand it, to use JAX (for instance, to get the gradient of a function), one needs to define the function using jax.numpy functions instead of numpy functions.
How do I avoid double declarations?
Some would suggest using only JAX, but, aside from performance and correctness, I have concerns about behavior: do jax.numpy functions behave identically to numpy functions? What happens when numpy functions are updated and their behavior changes?
Wouldn't it be possible to create a JAX version of a function automatically, using a helper that inspects its code and replaces the numpy or scipy calls with their JAX analogs?
That's a really interesting idea, and it could be quite useful, although it would be a challenge to implement it well.
I was thinking about whether this could be done with no change to the JAX source. This is a terrible hack that is quite brittle, but something like this will work in simple cases:
import inspect
import numpy as np
from functools import wraps

def jaxify(func):
    import jax.numpy
    # Copy the function's globals and point the NumPy names at jax.numpy.
    namespace = func.__globals__.copy()
    namespace['np'] = namespace['numpy'] = jax.numpy
    # Neutralize the decorator so re-executing the source does not recurse.
    namespace['jaxify'] = lambda func: func
    # Re-execute the function's source inside the patched namespace.
    source = inspect.getsource(func)
    exec(source, namespace)
    return wraps(func)(namespace[func.__name__])

@jaxify
def my_func(N):
    return np.arange(N).sum()

my_func(10)
# DeviceArray(45, dtype=int32)
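A variant of the same hack, sketched here as an assumption rather than anything JAX provides, skips `inspect.getsource` and `exec` by building a new function object from the original code object with a patched globals dictionary. The helper name `rebind_globals` is hypothetical, and `math`/`cmath` stand in for `numpy`/`jax.numpy` so the sketch runs without JAX installed. It is just as brittle: only global names looked up directly by the wrapped function are swapped, so helper functions it calls still see the original modules.

```python
import cmath
import math
import types
from functools import wraps


def rebind_globals(func, **replacements):
    """Return a copy of `func` that sees `replacements` as its globals.

    Hypothetical helper for illustration; not part of JAX or NumPy.
    """
    # Copy the globals so the original function is left untouched.
    namespace = dict(func.__globals__)
    namespace.update(replacements)
    # Reuse the compiled code object, but bind it to the patched namespace.
    new_func = types.FunctionType(
        func.__code__,
        namespace,
        func.__name__,
        func.__defaults__,
        func.__closure__,
    )
    new_func.__kwdefaults__ = func.__kwdefaults__
    return wraps(func)(new_func)


def root(x):
    # Looks up the global name `math` at call time.
    return math.sqrt(x)


# Stand-in for swapping numpy -> jax.numpy: swap math -> cmath.
complex_root = rebind_globals(root, math=cmath)
```

With JAX installed, the analogous call would presumably be `rebind_globals(my_func, np=jax.numpy)`; the original `root` keeps using `math` and still raises `ValueError` for negative inputs.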
An alternative approach to overloading NumPy code in-place was explored in #1565 and prototyped in #611. If you're interested in that functionality, please chime in on #1565, since a major reason it wasn't merged was a relative lack of interested users (compared to the added complexity).
Numba does tricks like this inside numba.jit and it seems to work pretty well for their users.
That said, I think this would be very hard to do in JAX because we occasionally see people using original NumPy inside JAX functions. It's also not very explicit or composable.
I do like the idea of trying to support overrides of NumPy's API via NumPy's own protocols (#1565), which would at least solve most of the composability issues.
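For contrast with implicit rewriting, one explicit way to avoid declaring a function twice is to pass the array namespace as an argument. This is only a minimal sketch of that pattern, not something proposed in the thread; the parameter name `xp` is a hypothetical convention, and NumPy is used as the default so the example runs without JAX.

```python
import numpy as np


def my_func(n, xp=np):
    # `xp` is the array namespace to use: numpy by default,
    # or jax.numpy when the caller passes it in explicitly.
    return xp.arange(n).sum()


# Uses NumPy by default.
result = my_func(10)

# With JAX installed, the same definition could be reused as:
#   import jax.numpy as jnp
#   my_func(10, xp=jnp)
```

The choice of backend is visible at every call site, at the cost of threading `xp` through any helper functions that also need it.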