Jax: lax.dynamic_slice inside jit

Created on 9 Jul 2019  ·  7 comments  ·  Source: google/jax

Should this work?

import jax
import jax.numpy as np
from jax import lax  # needed for lax.dynamic_slice below

@jax.jit
def sum_first_k(a, k):
  return np.sum(lax.dynamic_slice(a, (0,), (k,)))

sum_first_k(np.arange(3.0), 2)

Here's the traceback I get:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-167-645715d2be42> in <module>()
----> 1 sum_first_k(np.arange(3.0), 2)

13 frames
/usr/local/lib/python3.6/dist-packages/jax/api.py in f_jitted(*args, **kwargs)
    121     _check_args(args_flat)
    122     flat_fun, out_tree = flatten_fun_leafout(f, in_tree)
--> 123     out = xla.xla_call(flat_fun, *args_flat, device_values=device_values)
    124     return out if out_tree() is leaf else tree_unflatten(out_tree(), out)
    125 

/usr/local/lib/python3.6/dist-packages/jax/core.py in call_bind(primitive, f, *args, **params)
    661   if top_trace is None:
    662     with new_sublevel():
--> 663       ans = primitive.impl(f, *args, **params)
    664   else:
    665     tracers = map(top_trace.full_raise, args)

/usr/local/lib/python3.6/dist-packages/jax/interpreters/xla.py in xla_call_impl(fun, *args, **params)
    604 def xla_call_impl(fun, *args, **params):
    605   device_values = FLAGS.jax_device_values and params.pop('device_values')
--> 606   compiled_fun = xla_callable(fun, device_values, *map(abstractify, args))
    607   try:
    608     return compiled_fun(*args)

/usr/local/lib/python3.6/dist-packages/jax/linear_util.py in memoized_fun(f, *args)
    206       if len(cache) > max_size:
    207         cache.popitem(last=False)
--> 208       ans = call(f, *args)
    209       cache[key] = (ans, f)
    210     return ans

/usr/local/lib/python3.6/dist-packages/jax/interpreters/xla.py in xla_callable(fun, device_values, *abstract_args)
    617   pvals = [pe.PartialVal((aval, core.unit)) for aval in abstract_args]
    618   with core.new_master(pe.JaxprTrace, True) as master:
--> 619     jaxpr, (pval, consts, env) = pe.trace_to_subjaxpr(fun, master, False).call_wrapped(pvals)
    620     assert not env  # no subtraces here (though cond might eventually need them)
    621     compiled, result_shape = compile_jaxpr(jaxpr, consts, *abstract_args)

/usr/local/lib/python3.6/dist-packages/jax/linear_util.py in call_wrapped(self, *args, **kwargs)
    145 
    146     del gen
--> 147     ans = self.f(*args, **dict(self.params, **kwargs))
    148     del args
    149     while stack:

<ipython-input-165-9a17ef1ee2d8> in sum_first_k(a, k)
      1 @jax.jit
      2 def sum_first_k(a, k):
----> 3   return np.sum(lax.dynamic_slice(a, (0,), (k,)))

/usr/local/lib/python3.6/dist-packages/jax/lax/lax.py in dynamic_slice(operand, start_indices, slice_sizes)
    607   return dynamic_slice_p.bind(
    608       operand, start_indices, slice_sizes=tuple(slice_sizes),
--> 609       operand_shape=operand.shape)
    610 
    611 def dynamic_update_slice(operand, update, start_indices):

/usr/local/lib/python3.6/dist-packages/jax/core.py in bind(self, *args, **kwargs)
    145 
    146     tracers = map(top_trace.full_raise, args)
--> 147     out_tracer = top_trace.process_primitive(self, tracers, kwargs)
    148     return full_lower(out_tracer)
    149 

/usr/local/lib/python3.6/dist-packages/jax/interpreters/partial_eval.py in process_primitive(self, primitive, tracers, params)
    100       tracers = map(self.instantiate_const, tracers)
    101       avals = [t.aval for t in tracers]
--> 102       out_aval = primitive.abstract_eval(*avals, **params)
    103       eqn = JaxprEqn(tracers, None, primitive, (), False, False, params)
    104       return JaxprTracer(self, PartialVal((out_aval, unit)), eqn)

/usr/local/lib/python3.6/dist-packages/jax/lax/lax.py in standard_abstract_eval(shape_rule, dtype_rule, *args, **kwargs)
   1405     return ShapedArray(shape_rule(*args, **kwargs), dtype_rule(*args, **kwargs))
   1406   elif least_specialized is ShapedArray:
-> 1407     return ShapedArray(shape_rule(*args, **kwargs), dtype_rule(*args, **kwargs))
   1408   elif least_specialized is UnshapedArray:
   1409     return UnshapedArray(dtype_rule(*args, **kwargs))

/usr/local/lib/python3.6/dist-packages/jax/lax/lax.py in _dynamic_slice_shape_rule(operand, start_indices, slice_sizes, operand_shape)
   2608            "start_indices, got start_inidices length {} and slice_sizes {}.")
   2609     raise TypeError(msg.format(len(start_indices), slice_sizes))
-> 2610   if not onp.all(onp.less_equal(slice_sizes, operand.shape)):
   2611     msg = ("slice slice_sizes must be less than or equal to operand shape, "
   2612            "got slice_sizes {} for operand shape {}.")

/usr/local/lib/python3.6/dist-packages/jax/core.py in __bool__(self)
    340   def __getitem__(self, idx): return self.aval._getitem(self, idx)
    341   def __nonzero__(self): return self.aval._nonzero(self)
--> 342   def __bool__(self): return self.aval._bool(self)
    343   def __float__(self): return self.aval._float(self)
    344   def __int__(self): return self.aval._int(self)

/usr/local/lib/python3.6/dist-packages/jax/abstract_arrays.py in error(self, *args)
     36 def concretization_function_error(fun):
     37   def error(self, *args):
---> 38     raise TypeError(concretization_err_msg(fun))
     39   return error
     40 

TypeError: Abstract value passed to `bool`, which requires a concrete value. The function to be transformed can't be traced at the required level of abstraction. If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions instead.

I know XLA can't have variable sized outputs, but here I'm summing the outputs, so in principle that shouldn't be an issue.

documentation question


All 7 comments

No, it shouldn't work: actually it's not just that XLA (and JAX's jit, which is what's actually raising the error here for tracing reasons) require fixed output shapes, but all the shapes of the intermediates need to be fixed too. So summing the output doesn't help; that lax.dynamic_slice alone is a problem.

Here are two alternatives, both of which you probably know about:

from __future__ import print_function
from functools import partial

import jax
import jax.numpy as np

@partial(jax.jit, static_argnums=(1,))
def sum_first_k(a, k):
  return np.sum(jax.lax.dynamic_slice(a, (0,), (k,)))

print(sum_first_k(np.arange(3.0), 2))


@jax.jit
def sum_first_k(a, k):
  n = len(a)
  return np.sum(np.where(np.arange(n) < k, a, 0))

print(sum_first_k(np.arange(3.0), 2))

The first is a way of solving the problem with recompilation. The second is a way to solve it with masking, for which XLA can still generate very efficient code by fusing the selection into the reduction rather than round-tripping several arrays to memory. A third strategy is to use a loop construct.
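The third strategy (a loop construct) might look like the following; this is only a sketch using lax.fori_loop, not code from this thread:

```python
import jax
import jax.numpy as np
from jax import lax

@jax.jit
def sum_first_k(a, k):
  # fori_loop with a traced upper bound lowers to a while_loop,
  # so every intermediate value keeps a fixed shape under jit.
  return lax.fori_loop(0, k, lambda i, total: total + a[i], 0.0)
```

Because the loop carry is a scalar, no variable-sized intermediate array is ever created.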

WDYT?

It would be nice if the error message said something like this, rather than
sending me down a rabbit hole. What actually happened is that I first tried
using indexing like a[:k], which generated an error encouraging me to try
lax.dynamic_slice.


Is there a way to make the construct with static_argnums work with sum_first_k inside vmap?

Calling sum_first_k as a standalone function works fine:
print(sum_first_k(np.arange(3.0), 2))

But it doesn't work when calling it from vmap
vmap_sum_first_k = jax.vmap(sum_first_k, (None, 0))
print(vmap_sum_first_k(np.arange(10.0), np.arange(4)))

TypeError: Abstract value passed to bool, which requires a concrete value. The function to be transformed can't be traced at the required level of abstraction. If using jit, try using static_argnums or applying jit to smaller subfunctions instead.

It's a little awkward, but you can make something like this work using explicit masking:

@jax.jit
def sum_first_k(a, k):
  return np.sum(a * (np.arange(a.size) < k))
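Combined with the vmap call from the earlier question, this masked formulation appears to trace fine, since k only feeds a comparison rather than a shape; a sketch, not output from the thread:

```python
import jax
import jax.numpy as np

@jax.jit
def sum_first_k(a, k):
  # k is only used in a comparison, so all shapes stay fixed.
  return np.sum(a * (np.arange(a.size) < k))

# Map over a batch of k values while broadcasting the same a.
vmap_sum_first_k = jax.vmap(sum_first_k, (None, 0))
print(vmap_sum_first_k(np.arange(10.0), np.arange(4)))
```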

Ok, thanks, that works. But it is slow in my use case: a's dimensions are roughly 3e6 x 100, and each k is a slice, not a single number, with varying slice lengths (median around 50e3). So explicit masking results in a huge number of multiplications by 0. Which is too bad: other parts of my program are unbelievably fast using jax on GPU, much faster than anything else I tried.
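For what it's worth, when the goal is many variable-length slice sums, a prefix-sum trick can avoid the per-slice masking work entirely; this is only a sketch under assumed 1-D shapes, not something proposed in the thread, and slice_sums is a hypothetical helper:

```python
import jax
import jax.numpy as np

@jax.jit
def slice_sums(a, starts, ends):
  # Exclusive prefix sums: the sum of a[start:end] equals
  # c[end] - c[start], computed once in O(n) and looked up per slice.
  c = np.concatenate([np.zeros(1, a.dtype), np.cumsum(a)])
  return c[ends] - c[starts]
```

Each slice sum then costs two gathers instead of a full masked reduction over a.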

I think it's worth noting in the dynamic_slice() docstring that slice_sizes needs to be static. I can send in a PR if that sounds good, WDYT?

I ran into the same issue as shoyer@ above: I wanted dynamic slice_sizes, first tried indexing a[:k], was told to use dynamic_slice(), got this error message, poked around a bit, and then ended up here.

> I think it's worth adding that slice_sizes needs to be static to the dynamic_slice() docstring. I can send in a PR if that sounds good, WDYT?

This is a good suggestion! I added clarifications to both the dynamic_slice documentation and this error message in: https://github.com/google/jax/pull/3795

I'm closing this issue, since hopefully users will see the more descriptive errors in the future and won't be misled.

