In #324 we got support for logit and expit from scipy.special. This provides all the building blocks to also support scipy.stats.logistic (in particular, its methods cdf, pdf, ppf, logcdf and logpdf). Having this would enable maximum-likelihood estimation of logistic regression using scipy.optimize.
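In logistic regression, P(y = 1 | x) is the standard logistic cdf evaluated at w*x + b, so the log-likelihood is a sum of logcdf terms. A minimal sketch below assumes jax.scipy.optimize.minimize with method="BFGS". Plain scipy.optimize.minimize with a jax.grad-computed jac should also work. The synthetic data and helper names are illustrative, not part of any proposed API. So is the logaddexp-based logcdf.

```python
import jax
import jax.numpy as jnp
from jax.scipy.optimize import minimize
from jax.scipy.special import expit

def logistic_logcdf(z):
    # log(expit(z)) = -log(1 + exp(-z)); logaddexp keeps this finite for large |z|
    return -jnp.logaddexp(0.0, -z)

def neg_log_likelihood(params, x, y):
    # P(y=1 | x) = cdf(w*x + b) and P(y=0 | x) = 1 - cdf(w*x + b) = cdf(-(w*x + b))
    w, b = params[0], params[1]
    z = w * x + b
    log_lik = y * logistic_logcdf(z) + (1.0 - y) * logistic_logcdf(-z)
    return -jnp.mean(log_lik)

# Illustrative synthetic data: true slope 2.0, true intercept -0.5
key_x, key_y = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key_x, (500,))
y = jax.random.bernoulli(key_y, expit(2.0 * x - 0.5)).astype(jnp.float32)

x0 = jnp.zeros(2)
result = minimize(neg_log_likelihood, x0, args=(x, y), method="BFGS")
print("estimated (w, b):", result.x)
```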
I'd like to implement this!
I won't have time to help over the coming weeks, but I did some preliminary investigation. The primitives in jax.scipy.special that are useful for logistic:

- ppf can be implemented as logit
- cdf can be implemented as expit
- pdf can be implemented as expit * (1 - expit)
- logcdf and logpdf are then log(cdf) and log(pdf)

The implementation of jax.scipy.stats.norm and its accompanying tests are helpful too:
https://github.com/google/jax/blob/master/jax/scipy/stats/norm.py
https://github.com/google/jax/blob/master/tests/scipy_stats_test.py
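Below is a minimal sketch of the building blocks listed above. They are written as free functions with hypothetical names (logistic_cdf and so on) rather than as the eventual jax.scipy.stats.logistic module. Only expit and logit from jax.scipy.special are assumed. The jax.grad line is a quick sanity check that pdf is the derivative of cdf.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import expit, logit

# Hypothetical helper names; each is a direct translation of the list above.
def logistic_cdf(x):
    return expit(x)

def logistic_ppf(q):
    # logit is the inverse of expit, so it inverts the cdf
    return logit(q)

def logistic_pdf(x):
    # d/dx expit(x) = expit(x) * (1 - expit(x))
    p = expit(x)
    return p * (1.0 - p)

def logistic_logcdf(x):
    return jnp.log(logistic_cdf(x))

def logistic_logpdf(x):
    return jnp.log(logistic_pdf(x))

x = jnp.array([-2.0, 0.0, 1.5])
print(logistic_cdf(x))
print(logistic_ppf(logistic_cdf(x)))  # recovers x up to float32 rounding
print(jax.grad(logistic_cdf)(1.5), logistic_pdf(1.5))  # should agree
```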
Note that SciPy implements pdf as exp(logpdf) and logpdf(x) as -x - 2. * log1p(exp(-x)), which requires log1p (fortunately also available in JAX).
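A sketch of that SciPy-style formula in JAX, with pdf = exp(logpdf). It deviates from a literal transcription in one way, and this is my own assumption. In float32 (JAX's default), exp(-x) overflows for x below roughly -88. The sketch therefore uses the symmetry of the density and evaluates the formula at -|x|.

```python
import jax.numpy as jnp

def logistic_logpdf(x):
    # Formula quoted above: logpdf(x) = -x - 2 * log1p(exp(-x)).
    # Assumption: evaluate at y = -|x| (valid because logpdf(x) == logpdf(-x)),
    # so exp(y) <= 1 and never overflows.
    y = -jnp.abs(jnp.asarray(x))
    return y - 2.0 * jnp.log1p(jnp.exp(y))

def logistic_pdf(x):
    # pdf as exp(logpdf), as SciPy does
    return jnp.exp(logistic_logpdf(x))

print(logistic_logpdf(jnp.array([-100.0, 0.0, 3.0])))
print(logistic_pdf(0.0))
```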
I don't know whether the SciPy implementation of pdf is more accurate than simply using expit * (1 - expit). In my own code I don't see any difference, but I haven't tested this rigorously.
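One rough way to probe the accuracy question (not a rigorous study) is to compare the formulations against a double-precision reference built with Python's math module. The comparison covers expit * (1 - expit), exp(logpdf), and the algebraically identical expit(x) * expit(-x). In float32, expit(x) rounds close to 1 for large positive x, so 1 - expit(x) loses most of its significant digits and the naive product should degrade there. The other two avoid that subtraction, assuming expit(-x) is itself computed accurately for small outputs. All helper names are hypothetical.

```python
import math
import jax.numpy as jnp
from jax.scipy.special import expit

def pdf_naive(x):
    # expit * (1 - expit): the subtraction cancels badly when expit(x) is near 1
    p = expit(x)
    return p * (1.0 - p)

def pdf_symmetric(x):
    # same quantity, since 1 - expit(x) == expit(-x), but with no subtraction
    return expit(x) * expit(-x)

def pdf_via_logpdf(x):
    # SciPy-style exp(logpdf), evaluated at -|x| (the density is symmetric)
    y = -jnp.abs(x)
    return jnp.exp(y - 2.0 * jnp.log1p(jnp.exp(y)))

def pdf_reference(x):
    # double-precision reference using Python floats
    e = math.exp(-abs(x))
    return e / (1.0 + e) ** 2

for v in [0.0, 5.0, 20.0]:
    xs = jnp.float32(v)
    print(v, pdf_reference(v), float(pdf_naive(xs)),
          float(pdf_symmetric(xs)), float(pdf_via_logpdf(xs)))
```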
Looking at the SciPy implementation, it also seems straightforward to add sf(x) (i.e. 1 - cdf(x)) as cdf(-x) and isf(q) as -ppf(q). The same would be easy to add for jax.scipy.stats.norm (I have no idea why they aren't present yet).
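Both shortcuts rely only on the distribution being symmetric about 0, so they carry over unchanged to the normal distribution. The sketch uses hypothetical helper names. It assumes expit/logit from jax.scipy.special and norm.cdf/norm.ppf from jax.scipy.stats.norm.

```python
import jax.numpy as jnp
from jax.scipy.special import expit, logit
from jax.scipy.stats import norm

def logistic_sf(x):
    # sf(x) = 1 - cdf(x) = cdf(-x) by symmetry; skips the subtraction from 1
    return expit(-x)

def logistic_isf(q):
    # isf(q) = ppf(1 - q) = -ppf(q) by symmetry
    return -logit(q)

def norm_sf(x):
    return norm.cdf(-x)

def norm_isf(q):
    return -norm.ppf(q)

x = jnp.array([-1.0, 0.5, 2.0])
print(logistic_sf(x) + expit(x))     # ones, up to rounding
print(logistic_isf(logistic_sf(x)))  # recovers x
print(norm_isf(0.025))               # the familiar ~1.96 critical value
```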
Thanks for outlining the steps, @rhalbersma!
This can be closed now that the feature has been merged.