Jax: "logit" parameter of jax.random.categorical is misnamed, actually a log probability

Created on 29 Apr 2020 · 7 comments · Source: google/jax

summary

jax.random.categorical (introduced in #1855) takes a parameter, logits, which specifies the categorical distribution to sample from.

I believe this parameter is misnamed. It behaves like a log probability, not a logit.

potential resolutions

  1. rename the parameter to logprob or something similar
  2. change implementation and tests to behave like a logit
  3. change to something less ambiguous, like a probability p

background

Maybe people use "logit" in different ways, but according to Wikipedia, _logit(p)_ = _ln( p / (1 - p) )_. This is also the behavior implemented in scipy.special.logit (and jax.scipy.special.logit).
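
A quick check (a sketch of mine, not from the original issue) that jax.scipy.special.logit matches that definition:

import jax.numpy as np
from jax.scipy.special import logit

p = np.array([0.4, 0.6])
print(logit(p))              # [-0.405  0.405]
print(np.log(p / (1 - p)))   # same values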

example / demo

import jax
import jax.numpy as np

key = jax.random.PRNGKey(seed=1)

p = np.array([0.4, 0.6])
n = 10000

arg = np.log(p / (1 - p))
# or:
# arg = jax.scipy.special.logit(p)
samples = jax.random.categorical(key, logits=arg, shape=(n,))

print(np.unique(samples, return_counts=True))
# (DeviceArray([0, 1], dtype=int32), DeviceArray([3059, 6941], dtype=int32))

arg = np.log(p)
samples = jax.random.categorical(key, logits=arg, shape=(n,))

print(np.unique(samples, return_counts=True))
# (DeviceArray([0, 1], dtype=int32), DeviceArray([3998, 6002], dtype=int32))

as you can see, the counts for log(p) follow the given distribution (roughly 40/60), but the counts using logit(p) do not (roughly 31/69).
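
Why those particular counts: jax.random.categorical samples category i with probability proportional to exp(logits[i]), i.e. softmax(logits). So log(p) recovers p exactly, while logit(p) yields softmax(logit(p)) = [4/13, 9/13] ≈ [0.308, 0.692], which is exactly the ~3059/6941 split above. A small check (my sketch, not from the original issue):

import jax
import jax.numpy as np

p = np.array([0.4, 0.6])

# categorical samples with probability proportional to exp(logits):
print(jax.nn.softmax(np.log(p)))            # [0.4    0.6]     -- recovers p
print(jax.nn.softmax(np.log(p / (1 - p))))  # [0.3077 0.6923]  -- i.e. 4/13, 9/13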


All 7 comments

Looking back at https://github.com/google/jax/pull/1855, and at the docstring, it seems like there was some ambiguity about what exactly the "logits" parameter refers to; e.g. @j-towns suggested clarifying the definition.

if the intended behavior is to be an un-normalized log-probability, then i would say "logit" is not the correct term.

Wow, thanks for the clear explanation!

The Wikipedia page also says this though:

In deep learning, the term logits layer is popularly used for the last neuron layer of neural networks used for classification tasks, which produce raw prediction values as real numbers ranging over (−∞, +∞).

Random internet folks, Google's ML glossary, PyTorch, and TF seem also to think that, at least in the ML rather than stats context, "logits" can mean "un-normalized log probabilities".

If we had a really good alternative we could change it. Calling it "logprob" seems not to be descriptive enough though because it misses the "un-normalized" bit.
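
To make the "un-normalized" part concrete (a sketch, not from the thread): adding any constant to the logits leaves the implied distribution unchanged, because softmax is shift-invariant:

import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
p = jnp.array([0.4, 0.6])
n = 10000

s1 = jax.random.categorical(key, logits=jnp.log(p), shape=(n,))
s2 = jax.random.categorical(key, logits=jnp.log(p) + 5.0, shape=(n,))  # shifted logits

print(jnp.unique(s1, return_counts=True))  # ~40/60 split
print(jnp.unique(s2, return_counts=True))  # same ~40/60 split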

change implementation and tests to behave like a logit

What would that entail, exactly?

Random internet folks, Google's ML glossary, PyTorch, and TF seem also to think that, at least in the ML rather than stats context, "logits" can mean "un-normalized log probabilities".

This matches my understanding of the convention as well. "Logits" does not imply the same strict definition as the literal logit() function.

cool, thanks for explaining @mattjj and @shoyer. i remember now how "logit" is often the word for neural net layers before softmax activation. i'm doing probabilistic modeling with latent mixtures where the word doesn't have the same connotation.

how about adding an option for probabilities p? this is offered by tfp.distributions.Categorical and torch.distributions.categorical.Categorical.

funny, a similar issue to this is open for pytorch https://github.com/pytorch/pytorch/issues/16291
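
Short of an API change, a one-line wrapper gives the same ergonomics (categorical_from_probs below is purely illustrative, not an existing JAX function):

import jax
import jax.numpy as jnp

def categorical_from_probs(key, probs, shape=None):
    # hypothetical convenience wrapper: accept probabilities, convert to
    # (un-normalized) log-probabilities before sampling
    return jax.random.categorical(key, logits=jnp.log(probs), shape=shape)

samples = categorical_from_probs(jax.random.PRNGKey(1), jnp.array([0.4, 0.6]), shape=(10000,))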

If you like the conventions of TensorFlow Probability, note that you can _already_ use TFP's Categorical with JAX. We probably don't advertise this well enough!

Alternatively, you can convert from probabilities to unnormalized log-probabilities just by using log(), e.g., jax.random.categorical(key, logits=jnp.log(p), shape=(n,)).

I would prefer to document either or both of these approaches in jax.random.categorical rather than to add a redundant argument.
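
For reference, a sketch of the TFP-on-JAX route (assuming tensorflow_probability is installed; the JAX-compatible build lives under tensorflow_probability.substrates):

import jax
import jax.numpy as jnp
from tensorflow_probability.substrates import jax as tfp

key = jax.random.PRNGKey(1)
p = jnp.array([0.4, 0.6])

# TFP's Categorical accepts probabilities directly; in the JAX substrate,
# sample() takes a JAX PRNGKey via the seed argument
dist = tfp.distributions.Categorical(probs=p)
samples = dist.sample(sample_shape=(10000,), seed=key)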

It sounds to me like we've resolved this issue. Closing, but feel free to reopen if there's something left to discuss!

thanks! yup, i was going to close.

thanks @shoyer for the FYI about tfp/jax. very useful!

