I currently have a naive binomial sampler which sums Bernoulli trials (drawn via jax.random.uniform) under the hood:
import jax.numpy as np
from jax import lax, random

def binomial(key, p, n=1, shape=()):
    # promote_shapes is numpyro's broadcasting helper (numpyro.distributions.util)
    p, n = promote_shapes(p, n)
    shape = shape or lax.broadcast_shapes(np.shape(p), np.shape(n))
    n_max = np.max(n)  # depends on the runtime value of n
    uniforms = random.uniform(key, shape + (n_max,))
    n = np.expand_dims(n, axis=-1)
    p = np.expand_dims(p, axis=-1)
    # keep only the first n of the n_max trials for each element
    mask = (np.arange(n_max) < n).astype(uniforms.dtype)
    p, uniforms = promote_shapes(p, uniforms)
    # each uniform < p is one Bernoulli(p) trial; the masked sum gives a Binomial(n, p) draw
    return np.sum(mask * lax.lt(uniforms, p), axis=-1, keepdims=False)
This works, but the biggest drawback is that it is not jittable due to the dynamic size argument passed to random.uniform (see error trace below). It is also wasteful in that it draws n_max uniform random floats for every output element and does an element-wise multiply with the mask. I was wondering if there is a more efficient way to sample from the binomial distribution using existing primitives. I am new to JAX, so any insights on improving this implementation would be really helpful.
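For reference, here is a minimal way to reproduce the jit failure with the sampler above (the key and the example values of p and n are arbitrary placeholders):

from jax import jit, random
import jax.numpy as np

key = random.PRNGKey(0)
p = np.array([0.3, 0.7])
n = np.array([5, 10])

binomial(key, p, n)       # works eagerly
jit(binomial)(key, p, n)  # fails: n_max = np.max(n) becomes a tracer, so the shape
                          # passed to random.uniform is abstract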
error trace
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
../numpyro/distributions/distribution.py:47: in rvs
return _rvs(self, *args, **kwargs)
../numpyro/distributions/distribution.py:36: in _rvs
vals = _instance._rvs(*args)
../numpyro/distributions/discrete.py:22: in _rvs
return binomial(self._random_state, p, n, shape)
../../../../miniconda3/envs/numpyro/lib/python3.6/site-packages/jax/api.py:107: in f_jitted
jaxtupletree_out = xla.xla_call(jaxtree_fun, *jaxtupletree_args)
../../../../miniconda3/envs/numpyro/lib/python3.6/site-packages/jax/core.py:541: in call_bind
ans = primitive.impl(f, *args, **kwargs)
../../../../miniconda3/envs/numpyro/lib/python3.6/site-packages/jax/interpreters/xla.py:436: in xla_call_impl
compiled_fun = xla_callable(fun, *map(abstractify, flat_args))
../../../../miniconda3/envs/numpyro/lib/python3.6/site-packages/jax/linear_util.py:146: in memoized_fun
ans = call(f, *args)
../../../../miniconda3/envs/numpyro/lib/python3.6/site-packages/jax/interpreters/xla.py:448: in xla_callable
jaxpr, (pval, consts, env) = pe.trace_to_subjaxpr(fun, master).call_wrapped(pvals)
../../../../miniconda3/envs/numpyro/lib/python3.6/site-packages/jax/linear_util.py:86: in call_wrapped
ans = self.f(*args, **self.kwargs)
../numpyro/distributions/util.py:285: in binomial
uniforms = random.uniform(key, shape + (n_max,))
../../../../miniconda3/envs/numpyro/lib/python3.6/site-packages/jax/api.py:107: in f_jitted
jaxtupletree_out = xla.xla_call(jaxtree_fun, *jaxtupletree_args)
../../../../miniconda3/envs/numpyro/lib/python3.6/site-packages/jax/core.py:544: in call_bind
ans = full_lower(top_trace.process_call(primitive, f, tracers, kwargs))
../../../../miniconda3/envs/numpyro/lib/python3.6/site-packages/jax/interpreters/partial_eval.py:85: in process_call
out_pv_const, consts = call_primitive.bind(fun, *in_consts, **params)
../../../../miniconda3/envs/numpyro/lib/python3.6/site-packages/jax/core.py:541: in call_bind
ans = primitive.impl(f, *args, **kwargs)
../../../../miniconda3/envs/numpyro/lib/python3.6/site-packages/jax/interpreters/xla.py:436: in xla_call_impl
compiled_fun = xla_callable(fun, *map(abstractify, flat_args))
../../../../miniconda3/envs/numpyro/lib/python3.6/site-packages/jax/linear_util.py:146: in memoized_fun
ans = call(f, *args)
../../../../miniconda3/envs/numpyro/lib/python3.6/site-packages/jax/interpreters/xla.py:448: in xla_callable
jaxpr, (pval, consts, env) = pe.trace_to_subjaxpr(fun, master).call_wrapped(pvals)
../../../../miniconda3/envs/numpyro/lib/python3.6/site-packages/jax/linear_util.py:86: in call_wrapped
ans = self.f(*args, **self.kwargs)
../../../../miniconda3/envs/numpyro/lib/python3.6/site-packages/jax/random.py:240: in uniform
bits = _random_bits(key, nbits, shape)
../../../../miniconda3/envs/numpyro/lib/python3.6/site-packages/jax/random.py:199: in _random_bits
if max_count >= onp.iinfo(onp.uint32).max:
../../../../miniconda3/envs/numpyro/lib/python3.6/site-packages/jax/core.py:256: in __bool__
def __bool__(self): return self.aval._bool(self)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = ShapedArray(bool[])
args = (Traced<ShapedArray(bool[]):JaxprTrace(level=-1/2)>,)
def error(self, *args):
> raise TypeError(concretization_err_msg(fun))
E TypeError: Abstract value passed to `bool`, which requires a concrete value. The function to be transformed can't be traced at the required level of abstraction. If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions instead.
../../../../miniconda3/envs/numpyro/lib/python3.6/site-packages/jax/abstract_arrays.py:38: TypeError
As I was looking into the code a bit, I think that making the shape arg to random.uniform jittable (i.e. allowing it to be an abstract value) would help use-cases like this one, where we need to draw a variable number of samples in [0, 1). That would cover binomial/multinomial samplers, and more generally anything where the shape isn't known until runtime. It would require a change to _random_bits and maybe other functions too, but I am unsure how feasible this is. 😄
Can you have shape be a static arg to jit? This would mean that binomial will get recompiled every time it is called with a different shape, but for multiple calls with the same shape (which seems like a common use-case) it should work.
The usage would be something like sampler = jit(binomial, static_argnums=(3,)). Alternatively, you could decorate the function with @partial(jit, static_argnums=(3,)) (using functools.partial).
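As a rough sketch of that (argument index 3 is shape in the signature above; static arguments have to be hashable, so shape needs to stay a plain Python tuple):

from functools import partial
from jax import jit

# wrap the existing sampler, treating the shape tuple (argument index 3) as static;
# each distinct shape triggers one compilation, repeated shapes hit the cache
binomial_jit = jit(binomial, static_argnums=(3,))

# equivalently, as a decorator on the definition:
# @partial(jit, static_argnums=(3,))
# def binomial(key, p, n=1, shape=()): ...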
Thanks for the suggestion. I don't think that would work as is because the size argument to random.uniform itself isn't static. It depends on n_max = np.max(n), but if we have both n and shape as static_args, it should work fine! This might not be ideal to use as a default decorator because in cases where n can take many possible values, e.g. if it's coming from a data minibatch, we probably would end up spending a lot of time jitting this and not reclaiming any performance benefits in return. Another not so ideal strategy might be to set n_max to some high value or take it in as a static function argument so as not to force recompilation as often.
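A rough sketch of that, marking both n (argument index 2) and shape (index 3) as static so that n_max is concrete while tracing (n then has to be a hashable Python int rather than an array):

from jax import jit, random

binomial_static = jit(binomial, static_argnums=(2, 3))

key = random.PRNGKey(0)
binomial_static(key, 0.3, 10, (5,))  # compiles for n=10, shape=(5,)
binomial_static(key, 0.6, 10, (5,))  # reuses the cached compilation
binomial_static(key, 0.6, 12, (5,))  # recompiles, since the static n changed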
e.g. if it's coming from a data minibatch, we probably would end up spending a lot of time jitting this
This is tricky... I'm looking forward to seeing a solution too.