Hi,
I noticed significant slowdowns in my code after switching from numpy to jax.numpy, and from other issues it seems the solution is to use jit. When I use jit in a single script file for testing it seems to work, but when I move the function I want to jit into a method of a class in a separate file, I run into problems.
```python
# odes.py
import jax.numpy as np
import numpy as onp
from jax import jit, jacfwd, grad
from jax.numpy import sin, cos, exp


class odes:
    def __init__(self):
        print("odes file initialized")

    @jit
    def simpleODE(self, t, q):
        return np.array([[q[1]], [cos(q[0])]])
```
```python
# main script
from odes import *
from jax import jit, jacfwd, grad

ODE = odes()
Jac = jacfwd(ODE.simpleODE, argnums=(1,))
q = np.ones(2)
A = Jac(0, q)
print(A)
```
This gives the following error:
TypeError: Argument '
You might be able to work with this pattern:
```python
from functools import partial
import jax.numpy as np
from jax import jit
from jax.numpy import cos

class odes:
    def __init__(self):
        print("odes file initialized")

    @partial(jit, static_argnums=(0,))
    def simpleODE(self, t, q):
        return np.array([[q[1]], [cos(q[0])]])
```
In words, we're marking the first argument (index 0, which is self) as a static argument.
What do you think?
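For reference, here is a sketch that puts the pattern above together with the original jacfwd call in one self-contained file (using the conventional jnp alias in place of np). It shows that, for the bound method ODE.simpleODE, argnums=(1,) refers to q, and what shape the resulting Jacobian has.

```python
from functools import partial

import jax.numpy as jnp
from jax import jacfwd, jit


class odes:
    def __init__(self):
        print("odes file initialized")

    # self (argument 0) is static, so jit never tries to convert it to an array
    @partial(jit, static_argnums=(0,))
    def simpleODE(self, t, q):
        return jnp.array([[q[1]], [jnp.cos(q[0])]])


ODE = odes()
# For the bound method ODE.simpleODE the positional arguments are (t, q),
# so argnums=(1,) differentiates with respect to q.
Jac = jacfwd(ODE.simpleODE, argnums=(1,))
q = jnp.ones(2)
(A,) = Jac(0.0, q)  # argnums is a tuple, so the result is a 1-tuple
# A has shape (2, 1, 2): the output shape (2, 1) followed by the input shape (2,)
print(A.shape)
```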
Ah, that works for me.
Thanks for the help! Was it due to some interaction between the JAX wrapper and the self object?
Glad to hear that helped!
Yes. The issue is that jit only knows how to compile numerical computations on arrays (i.e. what XLA can do), not arbitrary Python computations. In particular, it only knows how to work with array data types, not arbitrary classes, and in this case the self argument is an instance of odes. By using static_argnums we're telling jit to trace and compile only the computation applied to the other arguments, and to re-trace and re-compile whenever the first argument changes. Static arguments are compared by hash and equality, which for a class like odes that doesn't define __hash__ or __eq__ means by Python object identity. That re-tracing basically means jit lets Python handle everything to do with the self argument.
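The re-tracing behaviour described above can be observed directly with a Python side effect that only runs while jit is tracing. This is a minimal sketch; the Odes class, its k attribute, and the trace_count counter are hypothetical names chosen for illustration.

```python
from functools import partial

import jax
import jax.numpy as jnp

trace_count = 0  # hypothetical counter, incremented only while jit is tracing


class Odes:  # hypothetical example class
    def __init__(self, k):
        self.k = k  # plain Python float, baked into the compiled code as a constant

    @partial(jax.jit, static_argnums=(0,))
    def rhs(self, t, q):
        global trace_count
        trace_count += 1  # Python side effect: runs during tracing, not on cached calls
        return jnp.array([q[1], -self.k * jnp.sin(q[0])])


a = Odes(1.0)
b = Odes(2.0)
q = jnp.ones(2)

a.rhs(0.0, q)  # first call with a: trace + compile
a.rhs(0.0, q)  # same self object, same argument shapes/dtypes: cache hit
b.rhs(0.0, q)  # different self object: re-trace + re-compile
print(trace_count)  # 2
```

One consequence of the identity-based default hash: changing a.k after the first call does not trigger a re-trace, so the compiled code keeps using the old value. The JAX FAQ entry on using jit with methods discusses defining __hash__ and __eq__, or registering the class as a pytree, for that case.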
The partial decorator solved my problem! I was also trying to use syntax like `jacfwd(test_class.test_func, argnums=[1])`, where test_func is defined as `def test_func(self, x)`, to avoid differentiating with respect to the self argument, but that gave me a "tuple out of range" error. How should I fix this?
@mattjj any pointers as to how it would be done for jax.grad?
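The same static_argnums pattern should carry over to jax.grad, with the usual restriction that grad needs a scalar-valued function. A minimal sketch; the Pendulum class, its k attribute, and the energy method are hypothetical.

```python
from functools import partial

import jax.numpy as jnp
from jax import grad, jit


class Pendulum:  # hypothetical example class
    def __init__(self, k):
        self.k = k

    # Same pattern as for jacfwd: self is static under jit.
    @partial(jit, static_argnums=(0,))
    def energy(self, q):
        # jax.grad requires a scalar output
        return 0.5 * q[1] ** 2 - self.k * jnp.cos(q[0])


p = Pendulum(2.0)
q = jnp.array([0.0, 3.0])

# Bound method: q is positional argument 0, which is grad's default argnums.
dE = grad(p.energy)(q)  # analytically [k * sin(q0), q1] = [0.0, 3.0]
print(dE)
```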