Jax: Calling a jit'd function with or without kwarg name has different behavior

Created on 22 Mar 2019 · 10 Comments · Source: google/jax

Consider this example:

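import jax

def foo(x, foo_power=10):
    # ...

jit_foo = jax.jit(foo)
jit_foo(x, foo_power=11)  # keyword call: foo_power is treated as static
jit_foo(x, 11)            # positional call: foo_power is treated as dynamic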

It would be reasonable to expect the two calls to jit_foo to have the same behavior/semantics. However, the first call, which names the kwarg, will treat foo_power as a static arg, whereas the second call will treat it as a dynamic arg. This can cause errors if foo expects foo_power to be a compile-time constant (I ran into this while jitting the fft method I'm working on).

A workaround is to always list the positional indices of such kwargs in the static_argnums argument to jit, so they are static however the call is written. We might want to consider doing this automatically in jit to have more predictable behavior by default.
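A minimal sketch of that workaround, assuming a JAX version where static_argnums accepts a tuple of positional indices. The body of foo here is a hypothetical stand-in (the original body is elided above); it uses a Python-level loop so that foo_power has to be a concrete integer at trace time.

```python
import jax
import jax.numpy as jnp

def foo(x, foo_power=10):
    # Hypothetical body: range() needs a concrete Python int, so foo_power
    # must be a compile-time constant rather than a traced value.
    result = jnp.ones_like(x)
    for _ in range(foo_power):
        result = result * x
    return result

# Mark positional index 1 (foo_power) as static.
jit_foo = jax.jit(foo, static_argnums=(1,))

x = jnp.array([1.0, 2.0])
by_position = jit_foo(x, 3)            # foo_power passed positionally
by_keyword = jit_foo(x, foo_power=3)   # foo_power passed by name
```

With the index listed in static_argnums, the positional call no longer traces foo_power, so both call styles should give the same result.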

enhancement

All 10 comments

Yeah, we should either document this or change it.

Actually, this _is_ documented in the jit docstring. So my previous comment isn't right.

Do we want to change it?

I don't see where it's documented? But I will admit I didn't read the jit docstring before filing this bug, so I'm in favor of fixing the gotcha instead of documenting it :) I don't think it's urgent though.

From https://jax.readthedocs.io/en/latest/jax.html#jax.jit:

Keyword arguments and positional arguments specified by static_argnums can be anything at all. These are treated as static (see below).

The next line explains what "static" means:

Operations that only depend on static arguments will be constant-folded. Calling the jitted function with different values for these constants will trigger recompilation.
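To make the recompilation point concrete, here is a small sketch that counts how often a function is traced. The names scale and trace_log are made up for illustration; the only assumption is that Python side effects inside a jitted function run at trace time, not on cached calls.

```python
import jax
import jax.numpy as jnp

trace_log = []  # hypothetical counter: one entry per trace

def scale(x, factor):
    # This append is a Python side effect, so it runs only while tracing.
    trace_log.append(factor)
    return x * factor

# factor (positional index 1) is declared static.
jit_scale = jax.jit(scale, static_argnums=(1,))

x = jnp.arange(3.0)
jit_scale(x, 2)   # first call with factor=2: traced and compiled
jit_scale(x, 2)   # same static value: cached, no new trace
jit_scale(x, 3)   # new static value: traced and compiled again
num_traces = len(trace_log)
```

Each distinct value of a static argument gets its own compiled version, which is why static arguments are best kept to a small set of values.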

What do you think? Clarify that more? Change the kwarg behavior?

+1 to not reading docstrings. Seriously, we don't want people to have to do that unless necessary. Both you and @sussillo pointed out this kwarg call-site behavior as surprising, so I lean towards revising it.

+1 to revising behavior. As Skye's example nicely points out, it's quite surprising to change the function behavior as a result of naming an argument.

@mattjj, what's unclear from the docstring (and confusing) is that the keyword vs positional args distinction is _not_ made at the function definition, but rather at the call site. This makes sense from the implementation perspective, but IMO from the API perspective, it makes more sense to use how the args are specified in the definition since it can't change across calls, and especially if you use jit() as a function decorator.

So I think the kwarg behavior is fine and makes sense, but we should use introspection to do it based on the function definition, not the function call.
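A rough sketch of what definition-based classification could look like, using only the standard library's inspect module. The helper split_by_definition is hypothetical and is not part of JAX; it also assumes that "keyword argument" means a parameter that declares a default value or is keyword-only in the definition.

```python
import inspect

def split_by_definition(fun, *args, **kwargs):
    """Hypothetical helper: classify call arguments as dynamic or static
    from the function definition, not from how the call was written."""
    sig = inspect.signature(fun)
    bound = sig.bind(*args, **kwargs)
    bound.apply_defaults()
    dynamic, static = {}, {}
    for name, value in bound.arguments.items():
        param = sig.parameters[name]
        # Assumption: a parameter counts as a keyword argument if it has a
        # default value or is keyword-only in the definition.
        is_keyword = (param.default is not inspect.Parameter.empty
                      or param.kind is inspect.Parameter.KEYWORD_ONLY)
        (static if is_keyword else dynamic)[name] = value
    return dynamic, static

def foo(x, foo_power=10):
    return x ** foo_power

by_keyword = split_by_definition(foo, 2.0, foo_power=11)
by_position = split_by_definition(foo, 2.0, 11)
```

Both calls produce the same split here, with x dynamic and foo_power static, which is the call-site independence being asked for.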

Related code from the last time we took signature parsing pretty seriously:

https://github.com/HIPS/autograd/blob/96a03f44da43cd7044c61ac945c483955deba957/autograd/differential_operators.py#L149-L190

It's not just jit: vmap, pmap, grad, and value_and_grad treat kwargs this way too (not as an accident, but as a simplifying design choice that we're now considering revising).

+1 to @skye's comment, also "principle of least astonishment".
