JAX: [Question] Efficiently sampling from the binomial distribution

Created on 6 Mar 2019 · 4 comments · Source: google/jax

I currently have a naive binomial sampler which uses jax.random.bernoulli under the hood:

import jax.numpy as np
from jax import lax, random

# _promote_shapes is a small NumPyro helper that broadcasts its
# arguments against one another
def binomial(key, p, n=1, shape=()):
    p, n = _promote_shapes(p, n)
    shape = shape or lax.broadcast_shapes(np.shape(p), np.shape(n))
    n_max = np.max(n)
    # draw n_max uniforms per output element ...
    uniforms = random.uniform(key, shape + (n_max,))
    n = np.expand_dims(n, axis=-1)
    p = np.expand_dims(p, axis=-1)
    # ... and mask out the draws beyond each element's own n
    mask = (np.arange(n_max) < n).astype(uniforms.dtype)
    p, uniforms = _promote_shapes(p, uniforms)
    return np.sum(mask * lax.lt(uniforms, p), axis=-1)

This works, but the biggest drawback is that it is not jittable, because the size argument passed to random.uniform is dynamic (see the error trace below). It is also wasteful: it draws n_max uniform floats per output element and then discards most of them via the element-wise multiply with mask. I was wondering if there is a more efficient way to sample from the binomial distribution using existing primitives. I am new to JAX, so any insights on improving this implementation would be really helpful.
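For concreteness, here is a rough sketch of how I am invoking it (the key and the parameter values are just placeholders); the eager call works, while the jitted call produces the trace below:

from jax import jit, random

key = random.PRNGKey(0)
counts = binomial(key, p=0.4, n=10, shape=(100,))  # works in op-by-op mode

jit_binomial = jit(binomial)
counts = jit_binomial(key, 0.4, 10, (100,))  # fails: n_max is abstract under jit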

error trace

_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
../numpyro/distributions/distribution.py:47: in rvs
    return _rvs(self, *args, **kwargs)
../numpyro/distributions/distribution.py:36: in _rvs
    vals = _instance._rvs(*args)
../numpyro/distributions/discrete.py:22: in _rvs
    return binomial(self._random_state, p, n, shape)
../../../../miniconda3/envs/numpyro/lib/python3.6/site-packages/jax/api.py:107: in f_jitted
    jaxtupletree_out = xla.xla_call(jaxtree_fun, *jaxtupletree_args)
../../../../miniconda3/envs/numpyro/lib/python3.6/site-packages/jax/core.py:541: in call_bind
    ans = primitive.impl(f, *args, **kwargs)
../../../../miniconda3/envs/numpyro/lib/python3.6/site-packages/jax/interpreters/xla.py:436: in xla_call_impl
    compiled_fun = xla_callable(fun, *map(abstractify, flat_args))
../../../../miniconda3/envs/numpyro/lib/python3.6/site-packages/jax/linear_util.py:146: in memoized_fun
    ans = call(f, *args)
../../../../miniconda3/envs/numpyro/lib/python3.6/site-packages/jax/interpreters/xla.py:448: in xla_callable
    jaxpr, (pval, consts, env) = pe.trace_to_subjaxpr(fun, master).call_wrapped(pvals)
../../../../miniconda3/envs/numpyro/lib/python3.6/site-packages/jax/linear_util.py:86: in call_wrapped
    ans = self.f(*args, **self.kwargs)
../numpyro/distributions/util.py:285: in binomial
    uniforms = random.uniform(key, shape + (n_max,))
../../../../miniconda3/envs/numpyro/lib/python3.6/site-packages/jax/api.py:107: in f_jitted
    jaxtupletree_out = xla.xla_call(jaxtree_fun, *jaxtupletree_args)
../../../../miniconda3/envs/numpyro/lib/python3.6/site-packages/jax/core.py:544: in call_bind
    ans = full_lower(top_trace.process_call(primitive, f, tracers, kwargs))
../../../../miniconda3/envs/numpyro/lib/python3.6/site-packages/jax/interpreters/partial_eval.py:85: in process_call
    out_pv_const, consts = call_primitive.bind(fun, *in_consts, **params)
../../../../miniconda3/envs/numpyro/lib/python3.6/site-packages/jax/core.py:541: in call_bind
    ans = primitive.impl(f, *args, **kwargs)
../../../../miniconda3/envs/numpyro/lib/python3.6/site-packages/jax/interpreters/xla.py:436: in xla_call_impl
    compiled_fun = xla_callable(fun, *map(abstractify, flat_args))
../../../../miniconda3/envs/numpyro/lib/python3.6/site-packages/jax/linear_util.py:146: in memoized_fun
    ans = call(f, *args)
../../../../miniconda3/envs/numpyro/lib/python3.6/site-packages/jax/interpreters/xla.py:448: in xla_callable
    jaxpr, (pval, consts, env) = pe.trace_to_subjaxpr(fun, master).call_wrapped(pvals)
../../../../miniconda3/envs/numpyro/lib/python3.6/site-packages/jax/linear_util.py:86: in call_wrapped
    ans = self.f(*args, **self.kwargs)
../../../../miniconda3/envs/numpyro/lib/python3.6/site-packages/jax/random.py:240: in uniform
    bits = _random_bits(key, nbits, shape)
../../../../miniconda3/envs/numpyro/lib/python3.6/site-packages/jax/random.py:199: in _random_bits
    if max_count >= onp.iinfo(onp.uint32).max:
../../../../miniconda3/envs/numpyro/lib/python3.6/site-packages/jax/core.py:256: in __bool__
    def __bool__(self): return self.aval._bool(self)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

self = ShapedArray(bool[])
args = (Traced<ShapedArray(bool[]):JaxprTrace(level=-1/2)>,)

    def error(self, *args):
>     raise TypeError(concretization_err_msg(fun))
E     TypeError: Abstract value passed to `bool`, which requires a concrete value. The function to be transformed can't be traced at the required level of abstraction. If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions instead.

../../../../miniconda3/envs/numpyro/lib/python3.6/site-packages/jax/abstract_arrays.py:38: TypeError


All 4 comments

As I was looking into the code a bit, I think that if we could make the shape arg to random.uniform jittable, i.e. allow it to be an abstract value, it would help use cases like these where the number of samples to draw in [0, 1) isn't known until runtime, e.g. binomial/multinomial samplers. This would require a change to _random_bits and maybe other functions too, but I am unsure how feasible this is. 😄
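To make the request concrete, this is the kind of pattern I would like to be able to write (a hypothetical sketch; it fails today because the shape must be concrete at trace time):

from jax import jit, random

@jit
def draw(key, n):
    # under jit, n is a tracer, so (n,) is an abstract shape and
    # random.uniform cannot size its output
    return random.uniform(key, (n,))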

Can you have shape be a static arg to jit? This would mean that binomial will get recompiled every time it is called with a different shape, but for multiple calls with the same shape (which seems like a common use-case) it should work.

The usage would be something like sampler = jit(binomial, static_argnums=(3,)).

Alternatively you could decorate the function definition itself, e.g. with @partial(jit, static_argnums=(3,)) using functools.partial.

The usage would be something like sampler = jit(binomial, static_argnums=(3,)).

Thanks for the suggestion. I don't think that would work as-is, because the size argument to random.uniform itself isn't static: it depends on n_max = np.max(n). But if we make both n and shape static args, it should work fine! This might not be ideal as a default decorator, though: in cases where n can take many possible values, e.g. if it's coming from a data minibatch, we would probably end up spending most of the time jitting and not reclaiming any performance benefit in return. Another (not ideal) strategy might be to set n_max to some high value, or take it in as a static function argument, so as not to force recompilation as often.
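A minimal sketch of that combined approach (argument indices follow the signature above; note that with static_argnums, n has to be a hashable Python scalar rather than an array):

from jax import jit, random

# mark both n (arg index 2) and shape (arg index 3) as static;
# a fresh compilation happens once per distinct (n, shape) pair
sampler = jit(binomial, static_argnums=(2, 3))

key = random.PRNGKey(0)
counts = sampler(key, 0.4, 10, (100,))  # compiles
counts = sampler(key, 0.7, 10, (100,))  # cache hit: same static args
counts = sampler(key, 0.4, 50, (100,))  # recompiles: new static n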

e.g. if it's coming from a data minibatch, we would probably end up spending most of the time jitting

This is tricky... I'm looking forward to seeing a solution too.
