Jax: jit decorator can't accept arguments

Created on 1 Jan 2019 · 6 Comments · Source: google/jax

The @jit decorator has an optional argument, static_argnums, that tells it which inputs to treat as constants. However, because of how Python's decorator syntax works, you can't use jit as a decorator and also pass it that argument.

Here's a simple example:

from jax import jit

@jit(static_argnums=(0,))
def mymap(f, x):
    return f(x)

This fails with TypeError: jit() takes at least 1 argument (1 given).
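Why this fails: Python evaluates the expression after @ first and then calls its result with the decorated function. So @jit(static_argnums=(0,)) means jit(static_argnums=(0,))(mymap), and the inner call reaches jit without the function it requires. The sketch below reproduces that first step. To keep it runnable without JAX, toy_jit is a hypothetical stand-in that only mimics jit's calling convention (function first, keyword options after); it does not compile anything.

```python
def toy_jit(fun, static_argnums=()):
    """Hypothetical stand-in for jax.jit: same call shape, no compilation."""
    def wrapped(*args):
        return fun(*args)
    wrapped.static_argnums = static_argnums  # record the option for inspection
    return wrapped


def mymap(f, x):
    return f(x)


# `@toy_jit(static_argnums=(0,))` would evaluate this call first, before
# Python gets a chance to pass `mymap` in, so the required `fun` is missing.
error = None
try:
    toy_jit(static_argnums=(0,))
except TypeError as err:
    error = err
print("TypeError:", error)

# Passing the function and the options in a single call works.
mymap_jitted = toy_jit(mymap, static_argnums=(0,))
print(mymap_jitted(abs, -3))  # prints: 3
```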

This works fine, but is a little uglier:

from jax import jit

def mymap(f, x):
    return f(x)
mymap = jit(mymap, static_argnums=(0,))

It looks like this is a common issue and a bit of wrapping can solve it:
https://stackoverflow.com/questions/5929107/decorators-with-parameters


All 6 comments

We've been using a pattern like this:

from functools import partial
from jax import jit

@partial(jit, static_argnums=(0,))
def foo(f, x):
    return f(x)

What do you think of that? I kind of like it, but maybe I'm just used to it...
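The partial pattern works because partial(jit, static_argnums=(0,)) is already a one-argument callable: applying it to foo calls jit(foo, static_argnums=(0,)). A minimal sketch of that equivalence, again using the hypothetical toy_jit stand-in (same calling convention as jit, no compilation) so it runs without JAX:

```python
from functools import partial


def toy_jit(fun, static_argnums=()):
    """Hypothetical stand-in for jax.jit: same call shape, no compilation."""
    def wrapped(*args):
        return fun(*args)
    wrapped.static_argnums = static_argnums  # record the option for inspection
    return wrapped


# Decorator form: partial binds the options; Python then supplies `foo`
# as the remaining positional argument.
@partial(toy_jit, static_argnums=(0,))
def foo(f, x):
    return f(x)


# Explicit form that the decorator above is equivalent to.
def bar(f, x):
    return f(x)
bar = toy_jit(bar, static_argnums=(0,))

print(foo(abs, -2), foo.static_argnums)  # prints: 2 (0,)
print(bar(abs, -2), bar.static_argnums)  # prints: 2 (0,)
```

Nothing here is specific to the stand-in: partial(jax.jit, static_argnums=(0,))(foo) likewise calls jax.jit(foo, static_argnums=(0,)).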

Ah yes, that works! I never thought of using partial() inside a decorator, to be honest.


Is there a benefit to using partial?

I see these benefits of making jit a "decorator factory":

  • More intuitive: at least @duvenaud and I came to the issues page because we couldn't use @jit as expected.
  • Similarity to the syntax of numba.jit.

For people who find this thread through a search, here is my jit decorator that accepts arguments:

from functools import wraps
from jax import jit


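@wraps(jit)  # copy jit's metadata (name, docstring) onto the factory
def jit_args(*args, **kwargs):
    # Return a decorator that applies jit with the stored options.
    def get_jitted(fun):
        return jit(fun, *args, **kwargs)
    return get_jitted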

Thanks for sharing that. I think it can be a good thing when users create convenience wrappers that suit their specific tastes and needs!

IIUC when using jit_args you have to add some extra parentheses, like this:

@jit_args()
def foo(x):
    return x + 1

You could solve that with some fancy polymorphism, but then it starts to feel a bit complicated.
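The "fancy polymorphism" could look roughly like the sketch below: a wrapper that checks whether it received a function (bare @dec) or only keyword options (@dec(...)). The names flexible and toy_jit are hypothetical; toy_jit stands in for jax.jit so the sketch runs without JAX, and it assumes the wrapped transformation takes the function as its first positional argument. It also shows the extra logic a core API would have to carry.

```python
from functools import partial, wraps


def toy_jit(fun, static_argnums=()):
    """Hypothetical stand-in for jax.jit: same call shape, no compilation."""
    def wrapped(*args):
        return fun(*args)
    wrapped.static_argnums = static_argnums  # record the option for inspection
    return wrapped


def flexible(transform):
    """Hypothetical helper: make transform(fun, **opts) usable as @dec or @dec(**opts)."""
    @wraps(transform)
    def decorator(fun=None, **opts):
        if fun is None:
            # Called as @dec(**opts): return a decorator that still needs `fun`.
            return partial(transform, **opts)
        # Called as bare @dec (or dec(fun, **opts)): apply the transformation now.
        return transform(fun, **opts)
    return decorator


my_jit = flexible(toy_jit)


@my_jit
def add_one(x):
    return x + 1


@my_jit(static_argnums=(0,))
def call_with(f, x):
    return f(x)


print(add_one(1), add_one.static_argnums)  # prints: 2 ()
print(call_with(abs, -5), call_with.static_argnums)  # prints: 5 (0,)
```

One caveat of this design is that options must be passed by keyword, because a positional first argument is always treated as the function.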

Our general philosophy is to make the JAX core API as simple and explicit as possible, even if that means some folks add convenience wrappers on top. (We want JAX to make those convenience wrappers straightforward to write!) In this case, the difference is minimal, but an upside is that all JAX transformations, even ones that don't take any parameters, work the same way.

@mattjj

Yes, you understand correctly, and I hadn't thought about that.
That makes sense: needing the empty parentheses is a sizeable downside.

Thanks for clearing that up.

Perhaps you should add a comment/example on the use of partial in the docstring of jit? I had the same question and found the answer here rather than in the documentation.
