Numba: Support p= option in numpy.random.choice

Created on 10 Sep 2017 · 6 Comments · Source: numba/numba

In numpy.random.choice, the p= option is not supported. Please add support for this option.

For reference, the p= option specifies the probabilities for the different elements in the first argument.
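For reference, this is the plain NumPy call (outside any jitted function) that the request asks numba to support. A minimal sketch; the values and probabilities are illustrative:

```python
import numpy as np

values = np.array([10, 20, 30])
probs = np.array([0.5, 0.3, 0.2])  # one probability per element of values, summing to 1

# Draw 5 samples with replacement (the default), weighted by probs.
samples = np.random.choice(values, size=5, p=probs)
print(samples)  # output varies from run to run
```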

feature_request

All 6 comments

I would like to reiterate this. Being able to specify probabilities is essential for sampling from general distributions with a discrete outcome space.

If `outcomes` is the vector of possible outcomes and `p` the probability vector, a workaround for sampling `size` values with replacement is:

```
x = outcomes[np.searchsorted(np.cumsum(p), np.random.rand(size))]
```
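One caveat with this one-liner: because of floating-point rounding, `np.cumsum(p)` need not end at exactly 1.0, and a uniform draw above its last entry would map to an index one past the end of `outcomes`. A small check, with illustrative values (ten equally likely outcomes):

```python
import numpy as np

p = np.full(10, 0.1)            # ten equally likely outcomes
raw_cdf = np.cumsum(p)
print(raw_cdf[-1] < 1.0)        # True: sequential rounding leaves the total just below 1.0

# Dividing by the last entry makes it exactly 1.0 (x / x == 1.0 for finite, nonzero x),
# so a uniform draw in [0, 1) can never fall past the end of the array.
cdf = raw_cdf / raw_cdf[-1]
print(cdf[-1] == 1.0)           # True

u = np.random.rand(100_000)
idx = np.searchsorted(cdf, u, side="right")
print(idx.max() < p.size)       # True: every index is valid
```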

Similarly to @gvanderheide, I've successfully used a workaround based on the np.random.choice code published at
https://github.com/numpy/numpy/blob/76a76c78d1b049126153e81b0a9d137fa3e4947b/numpy/random/mtrand/mtrand.pyx#L1194:

Given the probability distribution `probability` and the sample size `size`, the code is:

```
import numpy as np

cumulative_distribution = np.cumsum(probability)
cumulative_distribution /= cumulative_distribution[-1]
uniform_samples = np.random.rand(size)
index = np.searchsorted(cumulative_distribution, uniform_samples, side="right")
```
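The same four lines can be wrapped in a function and compiled in nopython mode. A sketch, assuming numba is installed; the fallback decorator only exists so the snippet still runs (uncompiled) without it, and the name `choice_indices` is hypothetical:

```python
import numpy as np

try:
    from numba import njit
except ImportError:  # fallback so the sketch still runs without numba installed
    def njit(func):
        return func


@njit
def choice_indices(probability, size):
    """Return `size` indices drawn with replacement, weighted by `probability`."""
    cumulative_distribution = np.cumsum(probability)
    # Pin the last entry to exactly 1.0 so no draw maps past the end.
    cumulative_distribution = cumulative_distribution / cumulative_distribution[-1]
    uniform_samples = np.random.random(size)  # uniform floats in [0, 1)
    # side="right" never selects a zero-probability entry, even when a draw is exactly 0.0
    return np.searchsorted(cumulative_distribution, uniform_samples, side="right")


probability = np.asarray([0.60, 0.16, 0.12, 0.08, 0.04])
index = choice_indices(probability, 10)
```

The result is an array of indices, so the sampled values are `outcomes[index]` for whatever array of outcomes the probabilities refer to.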

Tested with a sample size of 1 and probability = np.asarray([0.60, 0.16, 0.12, 0.08, 0.04]). The frequencies over 100000000 generated indices were: [0.60007007 0.1599901 0.11995732 0.08000664 0.03997587]

Close enough, I guess.
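A smaller version of this check can draw all indices at once and count them with `np.bincount`. A sketch; the seed and the 1,000,000 draws (fewer than the 100,000,000 reported) are arbitrary choices:

```python
import numpy as np

np.random.seed(0)  # fixed seed so the check is repeatable
probability = np.asarray([0.60, 0.16, 0.12, 0.08, 0.04])
n_draws = 1_000_000

cumulative_distribution = np.cumsum(probability)
cumulative_distribution /= cumulative_distribution[-1]
index = np.searchsorted(cumulative_distribution, np.random.rand(n_draws), side="right")

# Fraction of draws that landed on each index; should approach `probability`.
frequencies = np.bincount(index, minlength=probability.size) / n_draws
print(frequencies)
print(np.abs(frequencies - probability).max())  # largest deviation from the target
```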

Echoing what @cedricsimar said above, I used the following numba-compatible workaround (it works with the nb.jit(nopython=True) decorator):

```
import numba as nb
import numpy as np

@nb.jit(nopython=True)
def rand_choice_nb(arr, prob):
    """
    :param arr: A 1D numpy array of values to sample from.
    :param prob: A 1D numpy array of probabilities for the given samples.
    :return: A random sample from the given array with a given probability.
    """
    return arr[np.searchsorted(np.cumsum(prob), np.random.random(), side="right")]
```
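Two things are worth guarding against in this version: `np.cumsum(prob)` can end slightly below 1.0 because of rounding, and numba does not bounds-check array indexing by default, so an index one past the end would not raise an IndexError (the result is undefined). A hedged sketch of a hardened variant, plus a second jitted function that calls it in a loop to draw several values; both names are hypothetical, `arr` is assumed to hold floats, and the fallback decorator only exists so the snippet runs without numba:

```python
import numpy as np

try:
    from numba import njit
except ImportError:  # fallback so the sketch still runs without numba installed
    def njit(func):
        return func


@njit
def rand_choice_nb_safe(arr, prob):
    """One weighted draw from arr (hypothetical hardened variant of rand_choice_nb)."""
    cdf = np.cumsum(prob)
    cdf = cdf / cdf[-1]  # last entry becomes exactly 1.0, so the index stays in range
    return arr[np.searchsorted(cdf, np.random.random(), side="right")]


@njit
def rand_choices_nb(arr, prob, k):
    """Draw k values with replacement by calling the jitted single-draw function."""
    out = np.empty(k, dtype=np.float64)  # assumes arr holds floats
    for i in range(k):
        out[i] = rand_choice_nb_safe(arr, prob)
    return out


values = np.array([1.0, 2.0, 3.0])
draws = rand_choices_nb(values, np.array([0.5, 0.0, 0.5]), 1000)
```

Recomputing the cumulative sum on every call costs O(len(prob)) per draw; when many samples come from the same distribution, computing it once and drawing all indices together is cheaper.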

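For this solution, if you have a set of values that don't sum to 1, simply apply a softmax to your "prob" array (note that a softmax exponentiates the values first, so it only preserves their relative proportions if they are log-weights; for plain non-negative weights, divide by their sum instead):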

```
import numpy as np

# assume prob is something like [1, 10, 5, 43, 2] (non-probabilistic)
exp = np.exp(prob)
prob = exp/np.sum(exp)
```
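On the example values, a softmax and a plain division by the sum give very different distributions. A sketch comparing them; subtracting the maximum before exponentiating is a standard overflow guard that does not change the softmax result, and is not part of the original snippet:

```python
import numpy as np

weights = np.array([1.0, 10.0, 5.0, 43.0, 2.0])  # the example values from the comment

# Softmax: exponentiate, then normalize. Subtracting the max first avoids overflow.
exp = np.exp(weights - weights.max())
softmax_prob = exp / exp.sum()

# Plain normalization: keeps the weights' relative proportions.
linear_prob = weights / weights.sum()

print(softmax_prob.round(3))
print(linear_prob.round(3))
```

Here the softmax puts more than 0.999999 of the mass on the value 43, while plain normalization gives it 43/61 (about 0.705), so the choice depends on what the raw values mean.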

Aye, I actually normalise the probabilities before they get passed in here, but good point!

Thanks, guys. I was using numpy.random.choice and saw it didn't work, so I switched to random.choices, which also doesn't work :) Glad to find a workaround here.
