Would it be possible to add gradients for expit and logit?
Sure thing!
I put a quick implementation in https://github.com/google/jax/pull/326, for which I still need to write tests. Note that if you are on CPU, you may want to disable fast math to get correct inf/nan semantics, which at the moment you can do by setting the environment variable
XLA_FLAGS=--xla_cpu_enable_fast_math=false
The key point is that what you need here is not a gradient for these operators, but an implementation of them in terms of XLA (aka lax) primitives. JAX already knows how to automatically differentiate each of these primitives, and automatic differentiation is compositional: anything built out of differentiable primitives can be differentiated too.
Please feel free to add any more numpy or scipy functions you might need, in the same way!
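To illustrate the compositional point, here is a minimal sketch (not the code from the PR above) that defines expit and logit purely from jax.numpy operations, which lower to lax primitives, and then gets their derivatives from jax.grad with no hand-written gradient rules. Writing expit via tanh is a choice made here to avoid overflow in exp(-x) for large negative x; the names are illustrative only.

```python
import jax
import jax.numpy as jnp

def expit(x):
    # Logistic sigmoid, 1 / (1 + exp(-x)), written via tanh so that large
    # negative inputs don't overflow; tanh is a differentiable lax primitive.
    return 0.5 * (jnp.tanh(0.5 * x) + 1.0)

def logit(p):
    # Inverse of expit: log(p / (1 - p)), using log1p for accuracy near p = 0.
    return jnp.log(p) - jnp.log1p(-p)

# No custom gradient rules: jax.grad differentiates through the jnp/lax ops.
d_expit = jax.grad(expit)
d_logit = jax.grad(logit)

x, p = 0.5, 0.25
s = expit(x)
print(d_expit(x), s * (1.0 - s))          # autodiff vs closed form s(1 - s)
print(d_logit(p), 1.0 / (p * (1.0 - p)))  # autodiff vs closed form 1/(p(1 - p))
```

Comparing against the closed forms s(1 - s) and 1/(p(1 - p)) is a quick sanity check that the automatically derived gradients are the expected ones.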
Awesome, thanks!!
I committed these. Please let us know if you have any problems.
We also welcome PRs if you want to help out and add a few more functions!