Jax: Suggestion: expose `_jit_is_disabled` utility

Created on 8 Aug 2020 · 4 comments · Source: google/jax

Hi there,

This is just a small comment/suggestion. I just found the `_jit_is_disabled()` utility, and I'm finding it very useful for writing asserts that aren't possible in a jit context.

The general structure:

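```python
import jax

@jax.jit
def func(arr):
    # Private helper (JAX as of 2020): True when jit has been disabled,
    # e.g. inside a `with jax.disable_jit():` block.
    if jax.api._jit_is_disabled():
        assert arr.min() > 0, "data-dependent check failed"
    # implement the rest of the function
    ...
```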

Are you planning on making `_jit_is_disabled` a first-class citizen of the JAX API, e.g. `jax.jit_is_disabled()`?

Cheers!
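The pattern in the question above hinges on whether `arr` holds a concrete value. A minimal sketch of the difference, assuming a recent JAX release where `jax.disable_jit()` and `jax.errors.ConcretizationTypeError` are public, and using a plain unguarded `assert` instead of the private helper:

```python
import jax
import jax.numpy as jnp

@jax.jit
def func(arr):
    # Unguarded data-dependent check: needs arr to be a concrete value.
    assert arr.min() > 0, "data-dependent check failed"
    return arr * 2

x = -jnp.ones(3)

# With jit disabled, func runs eagerly on concrete arrays, so the assert fires.
eager_failed = False
with jax.disable_jit():
    try:
        func(x)
    except AssertionError:
        eager_failed = True

# Under jit, arr is an abstract tracer, so bool() on it fails during tracing.
traced_failed = False
try:
    func(x)
except jax.errors.ConcretizationTypeError:
    traced_failed = True

print(eager_failed, traced_failed)  # True True
```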

enhancement

All 4 comments

The predicate `_jit_is_disabled()` does not imply that arrays are concrete. In your example, this means that the expression under `assert` might not work as intended. For instance, suppose we `vmap` the function you've defined, as in:

```python
@jax.vmap
@jax.jit
def func(arr): ...
```

and then run:

```python
with jax.disable_jit():
  func(jax.numpy.ones((3, 4)))
```

This will result in an error:

```
jax.core.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected (in `bool`).
[...]
```

Even though jit is effectively off, `vmap` still leads to abstract interpretation of `func`, so `arr` is an abstract tracer value at the point of the data-dependent check.
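The failure mode described above can be reproduced with a short sketch, assuming a recent JAX release (where the error class is exposed as `jax.errors.ConcretizationTypeError` rather than under `jax.core`); `check_positive` is a hypothetical name standing in for `func`:

```python
import jax
import jax.numpy as jnp

def check_positive(arr):
    # Data-dependent Python assert: requires a concrete value of arr.
    assert arr.min() > 0, "data-dependent check failed"
    return arr * 2

batched = jax.vmap(jax.jit(check_positive))

raised = False
with jax.disable_jit():
    try:
        # jit is off, but vmap still traces check_positive,
        # so arr is a batching tracer rather than a concrete array.
        batched(jnp.ones((3, 4)))
    except jax.errors.ConcretizationTypeError:
        raised = True

print(raised)  # True
```

Note that the input is all ones, so the check would pass on concrete data; the error comes purely from `arr` being abstract.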

We don't plan to make this particular function part of the public API. More broadly, JAX's interface intentionally offers no means of asking whether a function is being evaluated under a transformation. In part, doing so runs counter to JAX's central functional-programming requirement (and indeed `func` behaves differently depending on state outside the function).

I see, right. I should really be checking if an input is concrete.

What's the preferred way to check if an array is concrete?

```python
import jax

# this feels a bit clunky
concrete = (jax.xla.DeviceArray, jax.abstract_arrays.ConcreteArray)
if isinstance(x, concrete):
    ...

# or would you try and fail silently?
try:
    ...
except jax.core.ConcretizationTypeError:
    pass
```

Thanks

I think we've recommended the try-and-fail (i.e. EAFP) approach before.
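A minimal sketch of the try-and-fail approach (EAFP: "easier to ask forgiveness than permission"), assuming a recent JAX release; `assert_if_concrete` is a hypothetical helper, not part of JAX:

```python
import jax
import jax.numpy as jnp

def assert_if_concrete(pred, msg):
    """Hypothetical helper: assert pred only when it has a concrete value."""
    try:
        ok = bool(pred)
    except jax.errors.ConcretizationTypeError:
        # pred is abstract (e.g. under jit or vmap): skip the check silently.
        return
    assert ok, msg

def func(arr):
    assert_if_concrete(arr.min() > 0, "data-dependent check failed")
    return arr * 2

bad = -jnp.ones(3)

out_jit = jax.jit(func)(bad)                     # traced: check skipped
out_vmap = jax.vmap(func)(jnp.stack([bad, bad]))  # batched: check skipped

fired = False
try:
    func(bad)  # eager: concrete value, so the assert fires
except AssertionError as e:
    fired = True
    print("caught:", e)  # caught: data-dependent check failed
```

The trade-off is that the check silently disappears whenever the value is abstract, so it only guards eager (or `disable_jit`) runs.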

I wonder if assertions are a special case here, and we should do something to enable them. This is a good example...
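Later JAX releases added `jax.experimental.checkify`, which records data-dependent checks as error values that survive `jit` and `vmap`. A minimal sketch, assuming a JAX version that ships `checkify` (the module is experimental and its API may change):

```python
import jax
import jax.numpy as jnp
from jax.experimental import checkify

def func(arr):
    # Functional check: recorded as an error value instead of a Python assert.
    checkify.check(arr.min() > 0, "data-dependent check failed")
    return arr * 2

# checkify makes the function return (error, output); jit can wrap it.
checked = jax.jit(checkify.checkify(func))

err_ok, _ = checked(jnp.ones(3))
err_bad, _ = checked(-jnp.ones(3))

print(err_ok.get())   # None
print(err_bad.get())  # a message containing "data-dependent check failed"
# err_bad.throw() would raise the recorded error on the host.
```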

I'm not a big fan of how it looks, but the try-and-fail feels more robust in this case.

For what it's worth, the ability to switch between concrete and abstract contexts is what I love most about JAX. It might be worth emphasizing this feature (to win the hearts and minds of those people who are still on the fence).
