The @jit decorator has an optional argument, static_argnums, that tells it which inputs should be treated as constant. However, there is a minor problem with Python's decorator syntax: you can't use jit as a decorator and also pass it an argument.
Here's a simple example:
```python
from jax import jit

@jit(static_argnums=(0,))
def mymap(f, x):
    return f(x)
```
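This fails with TypeError: jit() takes at least 1 argument (1 given). The decorator line evaluates jit(static_argnums=(0,)) first, before mymap is passed in, so jit never receives the function it requires as its first argument.
The following works, but is a little uglier: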
```python
from jax import jit

def mymap(f, x):
    return f(x)

mymap = jit(mymap, static_argnums=(0,))
```
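For context on why static_argnums matters in this example: mymap's first argument is a Python function rather than an array, so jit has no way to trace it as an input. A minimal sketch of the working version in use, assuming a current JAX install (the square helper is hypothetical, added only for illustration):

```python
import jax.numpy as jnp
from jax import jit

def square(v):
    # Hypothetical helper, used only to show a function passed as an argument
    return v * v

def mymap(f, x):
    return f(x)

# f is a Python function, not an array, so jit cannot trace it.
# static_argnums=(0,) makes jit treat it as a compile-time constant
# (it must be hashable; each distinct f triggers its own compilation).
mymap = jit(mymap, static_argnums=(0,))

y = mymap(square, jnp.arange(3.0))
```

Because static arguments become part of jit's compilation cache key, passing a new function object on every call (for example a fresh lambda) would trigger a recompilation each time.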
It looks like this is a common issue and a bit of wrapping can solve it:
https://stackoverflow.com/questions/5929107/decorators-with-parameters
We've been using a pattern like this:
```python
from functools import partial
from jax import jit

@partial(jit, static_argnums=(0,))
def foo(f, x):
    return f(x)
```
What do you think of that? I kind of like it, but maybe I'm just used to it...
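To make the equivalence concrete: the expression on the decorator line, partial(jit, static_argnums=(0,)), is a callable that takes only the function, so applying it is the same as calling jit(fun, static_argnums=(0,)) directly. A small sketch checking that both forms agree (the function names are hypothetical):

```python
from functools import partial

import jax.numpy as jnp
from jax import jit

def square(v):
    # Hypothetical helper, passed as the static argument
    return v * v

# Decorator form: partial(jit, static_argnums=(0,)) waits for the function,
# then calls jit(mymap_decorated, static_argnums=(0,)).
@partial(jit, static_argnums=(0,))
def mymap_decorated(f, x):
    return f(x)

# Explicit form: the same call written out by hand.
def _mymap(f, x):
    return f(x)

mymap_explicit = jit(_mymap, static_argnums=(0,))

x = jnp.arange(3.0)
a = mymap_decorated(square, x)
b = mymap_explicit(square, x)
```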
Ah yes, that works! I never thought of using partial() inside a decorator, to be honest.
Is there a benefit of using partial?
I see these benefits of making jit a "decorator factory":
- you can use @jit as expected
- it behaves like numba.jit

For people who find this thread in a search, here is my jit decorator that accepts arguments:
```python
from functools import wraps
from jax import jit

@wraps(jit)
def jit_args(*args, **kwargs):
    # Store the jit options and return a decorator that applies them
    def get_jitted(fun):
        return jit(fun, *args, **kwargs)
    return get_jitted
```
Thanks for sharing that. I think it can be a good thing when users create convenience wrappers that suit their specific tastes and needs!
IIUC, when using jit_args you have to add an extra pair of parentheses even when you don't pass any options, like this:
```python
@jit_args()
def foo(x):
    return 2 * x
```
You could solve that with some fancy polymorphism, but then it starts to feel a bit complicated.
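For illustration, one way to write that "fancy polymorphism" is a wrapper whose first parameter defaults to None, so it can tell a bare @decorator apart from a @decorator(options) call. This is a sketch, not part of JAX; the name flexible_jit is hypothetical:

```python
from functools import partial

import jax.numpy as jnp
from jax import jit

def flexible_jit(fun=None, **jit_kwargs):
    """Hypothetical wrapper usable as @flexible_jit or @flexible_jit(**options)."""
    if fun is None:
        # Called with options only, e.g. @flexible_jit(static_argnums=(0,)):
        # return a decorator that waits for the function.
        return partial(flexible_jit, **jit_kwargs)
    # Called on a function (bare @flexible_jit, or via the partial above).
    return jit(fun, **jit_kwargs)

@flexible_jit
def double(x):
    return 2 * x

def square(v):
    return v * v

@flexible_jit(static_argnums=(0,))
def mymap(f, x):
    return f(x)

x = jnp.arange(3.0)
```

The cost is the one hinted at above: the wrapper can only distinguish the two cases if options are passed by keyword, and the two call paths make it harder to read than a plain partial(jit, ...).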
Our general philosophy is to make the JAX core API as simple and explicit as possible, even if that means some folks add convenience wrappers on top. (We want JAX to make those convenience wrappers straightforward to write!) In this case, the difference is minimal, but an upside is that all JAX transformations, even ones that don't take any parameters, work the same way: each one takes the function as its first argument.
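To illustrate the uniformity point: jit, grad, and vmap all take the function as their first argument, so parameter-free uses are bare decorators and parameterized ones go through partial in exactly the same way. A minimal sketch using jax.grad's argnums and jax.vmap's in_axes options (the decorated function names are hypothetical):

```python
from functools import partial

import jax
import jax.numpy as jnp

# Parameterized transformation: differentiate with respect to argument 1 (b).
@partial(jax.grad, argnums=1)
def d_db(a, b):
    # The decorated function returns d/db of a * b**2, i.e. 2 * a * b
    return a * b ** 2

# Parameterized transformation: s is shared, v is mapped over its leading axis.
@partial(jax.vmap, in_axes=(None, 0))
def scale(s, v):
    return s * v

# Parameter-free use: the transformation is applied as a bare decorator.
@jax.jit
def add_one(x):
    return x + 1.0
```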
@mattjj
Yes, you UC, and I didn't think about that.
That makes sense. Needing the empty parentheses is a sizeable downside.
Thanks for clearing that up.
Perhaps you should add a comment/example on the use of partial in the docstring of jit? I had the same question and found the answer here instead of the documentation.