Jax: Could we have support for scipy.stats.logistic?

Created on 12 Jan 2020 · 4 Comments · Source: google/jax

In #324 we got support for logit and expit from scipy.special. This provides all the building blocks to also support scipy.stats.logistic (in particular, its members cdf, pdf, ppf, logcdf and logpdf). Having this would enable maximum-likelihood estimation of logistic regression using scipy.optimize.
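
To make the use case concrete, here is a rough sketch of such a maximum-likelihood fit. It is only an illustration under assumptions: the data set is made up, the names X, y, neg_log_likelihood and grad_nll are hypothetical, and the log-likelihood is written directly with expit (log(expit(z)) plays the role of logcdf) rather than with any jax.scipy.stats.logistic API.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.scipy.special import expit
from scipy.optimize import minimize

# Tiny synthetic data set (made up for illustration): intercept column plus one feature.
X = jnp.array([[1.0, -2.0], [1.0, -1.0], [1.0, 0.0],
               [1.0, 1.0], [1.0, 2.0], [1.0, 3.0]])
y = jnp.array([0.0, 0.0, 1.0, 0.0, 1.0, 1.0])

def neg_log_likelihood(w):
    z = X @ w
    # P(y=1) = cdf(z) = expit(z); P(y=0) = 1 - cdf(z) = expit(-z)
    return -jnp.sum(y * jnp.log(expit(z)) + (1.0 - y) * jnp.log(expit(-z)))

# JAX supplies the gradient; SciPy drives the optimisation.
grad_nll = jax.grad(neg_log_likelihood)

result = minimize(
    fun=lambda w: float(neg_log_likelihood(jnp.asarray(w))),
    x0=np.zeros(2),
    jac=lambda w: np.asarray(grad_nll(jnp.asarray(w)), dtype=np.float64),
    method="BFGS",
)
print(result.x, result.fun)
```

The wrappers convert between NumPy and JAX arrays because scipy.optimize works with NumPy float64 values, while JAX defaults to float32.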

Labels: enhancement, good first issue

All 4 comments

I'd like to implement this!

I'd like to implement this!

I won't have time to help in the coming weeks, but I did some preliminary investigation. The primitives in jax.scipy.special that are useful for logistic:

  • [ ] ppf can be implemented as logit
  • [ ] cdf can be implemented as expit
  • [ ] pdf can be implemented as expit * (1 - expit)
  • [ ] logcdf and logpdf are then log(cdf) and log(pdf)
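
The checklist above could be sketched roughly as follows. This is a minimal sketch, not the merged implementation: it covers only the standard logistic distribution (no loc or scale arguments, which is an assumption made to keep it short), and the function names simply mirror the SciPy method names.

```python
import jax.numpy as jnp
from jax.scipy.special import expit, logit

def cdf(x):
    # cdf of the standard logistic distribution is the sigmoid
    return expit(x)

def ppf(p):
    # quantile function is the inverse of the cdf
    return logit(p)

def pdf(x):
    s = expit(x)
    return s * (1.0 - s)

def logcdf(x):
    return jnp.log(cdf(x))

def logpdf(x):
    return jnp.log(pdf(x))

x = jnp.array([-2.0, 0.0, 1.5])
print(cdf(x))
print(ppf(cdf(x)))  # should recover x up to rounding
print(pdf(x))
```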

The implementation of jax.scipy.stats.norm and its accompanying tests are helpful too: https://github.com/google/jax/blob/master/jax/scipy/stats/norm.py
https://github.com/google/jax/blob/master/tests/scipy_stats_test.py

Note that SciPy implements pdf as exp(logpdf), and logpdf(x) as -x - 2. * log1p(exp(-x)), which requires log1p (fortunately also available in JAX).
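
For reference, the formulation quoted above might look like this in JAX. It is a sketch of the formula as stated in this comment (standard logistic only, no loc or scale), not a copy of SciPy's source, and the comparison against expit * (1 - expit) is only there as a sanity check.

```python
import jax.numpy as jnp
from jax.scipy.special import expit

def logpdf(x):
    # Form quoted above: -x - 2 * log1p(exp(-x))
    return -x - 2.0 * jnp.log1p(jnp.exp(-x))

def pdf(x):
    # pdf defined through logpdf rather than the other way round
    return jnp.exp(logpdf(x))

x = jnp.linspace(-5.0, 5.0, 11)
direct = expit(x) * (1.0 - expit(x))
# Largest absolute difference between the two formulations on this grid
print(jnp.max(jnp.abs(pdf(x) - direct)))
```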

I don't know whether the SciPy implementation of pdf is more accurate than simply using expit * (1 - expit) (in my own code I don't see any difference, but I haven't tested this rigorously).
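
One place where the two forms can be told apart is the far tail in single precision. The sketch below is a spot check, not a rigorous accuracy study: it assumes float32 inputs (JAX's default) and a single hand-picked point x = 20, where expit(x) rounds to exactly 1, so the product form underflows to 0 before the log is taken.

```python
import jax.numpy as jnp
from jax.scipy.special import expit

x = jnp.array(20.0, dtype=jnp.float32)

# log(pdf) with pdf = expit * (1 - expit): 1 - expit(x) is exactly 0 in float32 here
naive = jnp.log(expit(x) * (1.0 - expit(x)))

# logpdf written as -x - 2 * log1p(exp(-x)) stays finite
stable = -x - 2.0 * jnp.log1p(jnp.exp(-x))

print(naive, stable)
```

For moderate x, such as values between -5 and 5, the two forms agree to within float32 rounding, which would be consistent with not seeing a difference in everyday use.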

Looking at the SciPy implementation, it also seems rather straightforward to add sf(x) (i.e. 1 - cdf(x)) as cdf(-x) and isf(x) as -ppf(x). These would also be very straightforward to add for jax.scipy.stats.norm (I have no idea why they aren't present yet).
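
A minimal sketch of those two additions for the logistic case, again assuming the standard distribution without loc or scale:

```python
import jax.numpy as jnp
from jax.scipy.special import expit, logit

def sf(x):
    # 1 - cdf(x) equals cdf(-x) by symmetry of the logistic distribution
    return expit(-x)

def isf(q):
    # inverse of sf, i.e. -ppf(q)
    return -logit(q)

x = jnp.array([-3.0, 0.0, 2.0])
print(sf(x) + expit(x))  # each entry should be close to 1
print(isf(sf(x)))        # should recover x up to rounding
```

Writing sf as expit(-x) instead of 1 - expit(x) avoids the subtraction, which loses precision when cdf(x) is close to 1.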

Thanks for outlining the steps, @rhalbersma!

This can be closed now that the feature has been merged.
