Hi,
I want to use JAX to auto-generate code for computing the backward pass. Any idea where to begin?
For example, given this:
```python
import numpy as np

def sigmoid(x):
    return 1. / (1. + np.exp(-x))
```
I want to generate:
```python
def sigmoid_backward(x):
    return 1. / (1. + np.exp(-x)) * (1. - 1. / (1. + np.exp(-x)))
```
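For reference, the target expression above is the standard identity for the sigmoid's derivative, written out with $\sigma$ inlined. A short derivation shows why the generated code has the shape $\sigma(x)\,(1 - \sigma(x))$:

```latex
\sigma(x) = \frac{1}{1 + e^{-x}}, \qquad
\sigma'(x) = \frac{e^{-x}}{(1 + e^{-x})^{2}}
           = \frac{1}{1 + e^{-x}} \cdot \frac{e^{-x}}{1 + e^{-x}}
           = \sigma(x)\,\bigl(1 - \sigma(x)\bigr),
\quad\text{since } 1 - \sigma(x) = \frac{e^{-x}}{1 + e^{-x}}.
```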
As far as I understand, there is commercial software that does this for C++, and they have a prototype. Their website has some literature on it.
I think it would be great to have something along the same lines available in JAX.
Thanks for the question!
When you say "I want to generate: [Python code]", do you mean you want to generate Python source text? I'm going to assume yes, but correct me if I'm wrong. (If you just want a Python callable, take a look at JAX's autodiff cookbook.)
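If a callable is enough, a minimal sketch along the lines of the cookbook might look like this. It assumes the `sigmoid` from the question is rewritten with `jax.numpy` instead of NumPy so that JAX can trace it, and it compares `jax.grad` against the hand-written derivative. Because `jax.grad` requires a scalar output, `jax.vmap` is used to apply it elementwise over an array.

```python
import jax
import jax.numpy as jnp

def sigmoid(x):
    # Same as the question's version, but using jax.numpy so JAX can trace it
    return 1. / (1. + jnp.exp(-x))

# jax.grad returns a Python callable (not source text) computing d sigmoid / dx
sigmoid_grad = jax.grad(sigmoid)

def sigmoid_backward(x):
    # The hand-written derivative from the question, ported to jax.numpy
    return 1. / (1. + jnp.exp(-x)) * (1. - 1. / (1. + jnp.exp(-x)))

# jax.grad needs a scalar-output function, so vmap it over a 1-D array of inputs
xs = jnp.linspace(-3., 3., 7)
auto = jax.vmap(sigmoid_grad)(xs)
manual = sigmoid_backward(xs)
print(jnp.allclose(auto, manual))  # True
```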
Generating Python source text is outside of JAX's scope. However, have you looked at Tangent? It might do exactly what you want.
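If the goal is mainly to *inspect* what JAX computes for the derivative, one option (an aside, not Python source generation) is `jax.make_jaxpr`, which prints JAX's intermediate representation of a traced function. The sketch below assumes a `jax.numpy` version of `sigmoid` and traces its gradient at an arbitrary sample scalar input.

```python
import jax
import jax.numpy as jnp

def sigmoid(x):
    return 1. / (1. + jnp.exp(-x))

# Trace the derivative at a sample scalar input; the result is a jaxpr
# (JAX's intermediate representation), not Python source text.
grad_jaxpr = jax.make_jaxpr(jax.grad(sigmoid))(0.5)
print(grad_jaxpr)
```

The printed jaxpr lists the primitive operations (such as `exp` and the multiplications of the backward pass) but would still need a separate printer to become Python code.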
(By the way, I think your sigmoid_backward computes the derivative of sigmoid, i.e. the coefficients of the Jacobian, but it doesn't implement a backward pass, i.e. what we would call a vector-Jacobian product, or VJP. To be a VJP, it needs to take another argument, namely the "incoming gradients" shaped like the output.)
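To illustrate the difference, here is a minimal sketch of a VJP-style backward pass. The name `sigmoid_vjp` is hypothetical: it takes the incoming gradient `g`, shaped like the output, and returns `g` multiplied by the Jacobian. Because sigmoid is elementwise, its Jacobian is diagonal, so the product reduces to an elementwise multiply. The result is checked against `jax.vjp`.

```python
import jax
import jax.numpy as jnp

def sigmoid(x):
    return 1. / (1. + jnp.exp(-x))

def sigmoid_vjp(x, g):
    # Hypothetical backward pass: g has the output's shape; the Jacobian of an
    # elementwise sigmoid is diagonal, so g @ J is just g * sigma'(x).
    s = sigmoid(x)
    return g * s * (1. - s)

x = jnp.array([-1., 0., 2.])
g = jnp.array([1., 2., 3.])  # incoming gradients, shaped like sigmoid(x)

# jax.vjp returns the primal output and a function mapping output cotangents
# to a tuple of input cotangents (one entry per primal argument)
y, vjp_fun = jax.vjp(sigmoid, x)
(auto,) = vjp_fun(g)
print(jnp.allclose(auto, sigmoid_vjp(x, g)))  # True
```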
Since we have a lot of open issues, if Tangent is a good fit for what you need, let us know so we can close this one!
Yes, although Tangent is not complete, at least it showed me the way. Thanks.