Should this work?
import jax
import jax.numpy as np
@jax.jit
def sum_first_k(a, k):
    """Sum the first ``k`` elements of ``a`` (intended behavior).

    As written this does NOT trace under ``jax.jit``: ``k`` is a traced
    (abstract) argument, but ``lax.dynamic_slice`` needs its slice size
    ``(k,)`` to be concrete, because every intermediate shape must be fixed
    at trace time. Tracing therefore fails with the ``TypeError`` quoted
    below.
    """
    # NOTE(review): `lax` is not imported in this snippet; presumably
    # `from jax import lax` was already in scope in the original notebook.
    return np.sum(lax.dynamic_slice(a, (0,), (k,)))
# Raises TypeError ("Abstract value passed to `bool`...") while tracing.
sum_first_k(np.arange(3.0), 2)
Here's the traceback I get:
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-167-645715d2be42> in <module>()
----> 1 sum_first_k(np.arange(3.0), 2)
13 frames
/usr/local/lib/python3.6/dist-packages/jax/api.py in f_jitted(*args, **kwargs)
121 _check_args(args_flat)
122 flat_fun, out_tree = flatten_fun_leafout(f, in_tree)
--> 123 out = xla.xla_call(flat_fun, *args_flat, device_values=device_values)
124 return out if out_tree() is leaf else tree_unflatten(out_tree(), out)
125
/usr/local/lib/python3.6/dist-packages/jax/core.py in call_bind(primitive, f, *args, **params)
661 if top_trace is None:
662 with new_sublevel():
--> 663 ans = primitive.impl(f, *args, **params)
664 else:
665 tracers = map(top_trace.full_raise, args)
/usr/local/lib/python3.6/dist-packages/jax/interpreters/xla.py in xla_call_impl(fun, *args, **params)
604 def xla_call_impl(fun, *args, **params):
605 device_values = FLAGS.jax_device_values and params.pop('device_values')
--> 606 compiled_fun = xla_callable(fun, device_values, *map(abstractify, args))
607 try:
608 return compiled_fun(*args)
/usr/local/lib/python3.6/dist-packages/jax/linear_util.py in memoized_fun(f, *args)
206 if len(cache) > max_size:
207 cache.popitem(last=False)
--> 208 ans = call(f, *args)
209 cache[key] = (ans, f)
210 return ans
/usr/local/lib/python3.6/dist-packages/jax/interpreters/xla.py in xla_callable(fun, device_values, *abstract_args)
617 pvals = [pe.PartialVal((aval, core.unit)) for aval in abstract_args]
618 with core.new_master(pe.JaxprTrace, True) as master:
--> 619 jaxpr, (pval, consts, env) = pe.trace_to_subjaxpr(fun, master, False).call_wrapped(pvals)
620 assert not env # no subtraces here (though cond might eventually need them)
621 compiled, result_shape = compile_jaxpr(jaxpr, consts, *abstract_args)
/usr/local/lib/python3.6/dist-packages/jax/linear_util.py in call_wrapped(self, *args, **kwargs)
145
146 del gen
--> 147 ans = self.f(*args, **dict(self.params, **kwargs))
148 del args
149 while stack:
<ipython-input-165-9a17ef1ee2d8> in sum_first_k(a, k)
1 @jax.jit
2 def sum_first_k(a, k):
----> 3 return np.sum(lax.dynamic_slice(a, (0,), (k,)))
/usr/local/lib/python3.6/dist-packages/jax/lax/lax.py in dynamic_slice(operand, start_indices, slice_sizes)
607 return dynamic_slice_p.bind(
608 operand, start_indices, slice_sizes=tuple(slice_sizes),
--> 609 operand_shape=operand.shape)
610
611 def dynamic_update_slice(operand, update, start_indices):
/usr/local/lib/python3.6/dist-packages/jax/core.py in bind(self, *args, **kwargs)
145
146 tracers = map(top_trace.full_raise, args)
--> 147 out_tracer = top_trace.process_primitive(self, tracers, kwargs)
148 return full_lower(out_tracer)
149
/usr/local/lib/python3.6/dist-packages/jax/interpreters/partial_eval.py in process_primitive(self, primitive, tracers, params)
100 tracers = map(self.instantiate_const, tracers)
101 avals = [t.aval for t in tracers]
--> 102 out_aval = primitive.abstract_eval(*avals, **params)
103 eqn = JaxprEqn(tracers, None, primitive, (), False, False, params)
104 return JaxprTracer(self, PartialVal((out_aval, unit)), eqn)
/usr/local/lib/python3.6/dist-packages/jax/lax/lax.py in standard_abstract_eval(shape_rule, dtype_rule, *args, **kwargs)
1405 return ShapedArray(shape_rule(*args, **kwargs), dtype_rule(*args, **kwargs))
1406 elif least_specialized is ShapedArray:
-> 1407 return ShapedArray(shape_rule(*args, **kwargs), dtype_rule(*args, **kwargs))
1408 elif least_specialized is UnshapedArray:
1409 return UnshapedArray(dtype_rule(*args, **kwargs))
/usr/local/lib/python3.6/dist-packages/jax/lax/lax.py in _dynamic_slice_shape_rule(operand, start_indices, slice_sizes, operand_shape)
2608 "start_indices, got start_inidices length {} and slice_sizes {}.")
2609 raise TypeError(msg.format(len(start_indices), slice_sizes))
-> 2610 if not onp.all(onp.less_equal(slice_sizes, operand.shape)):
2611 msg = ("slice slice_sizes must be less than or equal to operand shape, "
2612 "got slice_sizes {} for operand shape {}.")
/usr/local/lib/python3.6/dist-packages/jax/core.py in __bool__(self)
340 def __getitem__(self, idx): return self.aval._getitem(self, idx)
341 def __nonzero__(self): return self.aval._nonzero(self)
--> 342 def __bool__(self): return self.aval._bool(self)
343 def __float__(self): return self.aval._float(self)
344 def __int__(self): return self.aval._int(self)
/usr/local/lib/python3.6/dist-packages/jax/abstract_arrays.py in error(self, *args)
36 def concretization_function_error(fun):
37 def error(self, *args):
---> 38 raise TypeError(concretization_err_msg(fun))
39 return error
40
TypeError: Abstract value passed to `bool`, which requires a concrete value. The function to be transformed can't be traced at the required level of abstraction. If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions instead.
I know XLA can't have variable sized outputs, but here I'm summing the outputs, so in principle that shouldn't be an issue.
No, it shouldn't work: actually it's not just that XLA (and JAX's jit, which is what's actually raising the error here for tracing reasons) require fixed output shapes, but all the shapes of the intermediates need to be fixed too. So summing the output doesn't help; that lax.dynamic_slice alone is a problem.
Here are two alternatives, both of which you probably know about:
from __future__ import print_function
from functools import partial
import jax
import jax.numpy as np
@partial(jax.jit, static_argnums=(1,))
def sum_first_k(a, k):
    """Return the sum of the leading ``k`` entries of the 1-D array ``a``.

    ``k`` is declared static (``static_argnums=(1,)``), so it is a plain
    Python value at trace time and the slice size is concrete. The cost is
    one compilation per distinct ``k``.
    """
    prefix = jax.lax.dynamic_slice(a, (0,), (k,))
    return np.sum(prefix)
print(sum_first_k(np.arange(3.0), 2))
@jax.jit
def sum_first_k(a, k):
    """Sum the first ``k`` elements of ``a`` by masking rather than slicing.

    Every intermediate keeps the full length of ``a``, so ``k`` may remain
    a traced value under ``jit`` and no recompilation is needed per ``k``.
    """
    size = len(a)
    keep = np.arange(size) < k
    return np.where(keep, a, 0).sum()
print(sum_first_k(np.arange(3.0), 2))
The first is a way of solving the problem with recompilation. The second is a way to solve it with masking, for which XLA can still generate very efficient code by fusing the selection into the reduction rather than round-tripping several arrays to memory. A third strategy is to use a loop construct.
WDYT?
It would be nice if the error message said something like this, rather than
sending me down a rabbit hole. What actually happened is that I first tried
using indexing like a[:k], which generated an error encouraging me to try
lax.dynamic_slice.
On Tue, Jul 9, 2019 at 9:49 PM Matthew Johnson notifications@github.com
wrote:
No, it shouldn't work: actually it's not just that XLA (and JAX's jit,
which is what's actually raising the error here for tracing reasons)
require fixed output shapes, but all the shapes of the intermediates need
to be fixed too. So summing the output doesn't help; that
lax.dynamic_slice alone is a problem.
Here are two alternatives, both of which you probably know about:
from __future__ import print_function
from functools import partial
import jax
import jax.numpy as np
@partial(jax.jit, static_argnums=(1,))
def sum_first_k(a, k):
    return np.sum(jax.lax.dynamic_slice(a, (0,), (k,)))
print(sum_first_k(np.arange(3.0), 2))
@jax.jit
def sum_first_k(a, k):
    n = len(a)
    return np.sum(np.where(np.arange(n) < k, a, 0))
print(sum_first_k(np.arange(3.0), 2))
The first is a way of solving the problem with recompilation. The second
is a way to solve it with masking, for which XLA can still generate very
efficient code by fusing the selection into the reduction rather than
round-tripping several arrays to memory. A third strategy is to use a loop
construct.
WDYT?
—
You are receiving this because you authored the thread.
Reply to this email directly, view it on GitHub
https://github.com/google/jax/issues/1007?email_source=notifications&email_token=AAJJFVXUEEPZYGHSO7K6WE3P6VS7HA5CNFSM4H7I6LE2YY3PNVWWK3TUL52HS4DFVREXG43VMVBW63LNMVXHJKTDN5WW2ZLOORPWSZGODZSI4PI#issuecomment-509906493,
or mute the thread
https://github.com/notifications/unsubscribe-auth/AAJJFVWEJNLZMCW3JTWJYJTP6VS7HANCNFSM4H7I6LEQ
.
Is there a way to make the construct with static_argnums work with sum_first_k inside vmap?
Calling sum_first_k as a standalone function works fine:
# Direct call with a concrete Python int for k: this works.
print(sum_first_k(np.arange(3.0), 2))
But it doesn't work when calling it from vmap
# Map over k: in_axes=(None, 0) shares `a` and batches k along axis 0.
# With the static_argnums version of sum_first_k this raises the TypeError
# quoted below, because under vmap k is an abstract traced value rather
# than a concrete Python int.
vmap_sum_first_k = jax.vmap(sum_first_k,(None,0))
print(vmap_sum_first_k(np.arange(10.0), np.arange(4)))
TypeError: Abstract value passed to bool, which requires a concrete value. The function to be transformed can't be traced at the required level of abstraction. If using jit, try using static_argnums or applying jit to smaller subfunctions instead.
It's a little awkward, but you can make something like this work using explicit masking:
@jax.jit
def sum_first_k(a, k):
    """Sum the first ``k`` elements of ``a``, where ``k`` may be traced.

    Masked-out elements are selected away with ``np.where`` instead of being
    multiplied by a 0/1 mask. ``0 * inf`` and ``0 * nan`` are both ``nan``,
    so multiplying would let non-finite values in the excluded tail poison
    the sum. Because ``k`` is never used as a shape, this also works under
    ``jax.vmap`` over ``k``.
    """
    keep = np.arange(a.size) < k
    return np.sum(np.where(keep, a, 0))
Ok, thanks, that works. But it is slow in my use case: a has a shape of roughly 3e6 x 100, while each k is a slice rather than a number, with varying slice lengths (the median slice length is about 50e3). So explicit masking results in a huge number of multiplications by 0. That is too bad - other parts of my program are unbelievably fast using jax with gpu, much faster than anything else I tried.
I think it's worth adding that slice_sizes needs to be static to the dynamic_slice() docstring. I can send in a PR if that sounds good, WDYT?
I ran into the same issue as shoyer@ above, where I want a dynamic slice_sizes: I first tried indexing a[:k], was told to use dynamic_slice(), got this error message, poked around a bit, and then ended up here.
I think it's worth adding that slice_sizes needs to be static to the dynamic_slice() docstring. I can send in a PR if that sounds good, WDYT?
This is a good suggestion! I added clarifications to both the dynamic_slice documentation and this error message in: https://github.com/google/jax/pull/3795
I'm closing this issue, since hopefully users will see the more descriptive errors in the future and won't be misled.
Most helpful comment
It would be nice if the error message said something like this, rather than
sending me down a rabbit hole. What actually happened is that I first tried
using indexing like a[:k], which generated an error encouraging me to try
lax.dynamic_slice.
On Tue, Jul 9, 2019 at 9:49 PM Matthew Johnson notifications@github.com
wrote: