JAX: How can I use JAX to generate code for the backward pass?

Created on 24 Sep 2019 · 4 Comments · Source: google/jax

Hi,
I want to use JAX to auto-generate code for calculating the backward pass. Any idea where to begin?

For example, given this:

import jax.numpy as np  # assumed; plain NumPy also works for this function

def sigmoid(x):
    return 1. / (1. + np.exp(-x))

I want to generate:

def sigmoid_backward(x):
    return 1. / (1. + np.exp(-x)) * (1. - 1. / (1. + np.exp(-x)))

All 4 comments

As far as I understand, there is commercial software that does the same thing for C++, and they have a prototype. You can find some literature on their website.

I think it would be great to have something along the same lines available in JAX.

Thanks for the question!

When you say "I want to generate: [Python code]", do you mean you want to generate Python source text? I'm going to assume yes, but correct me if I'm wrong. (If you just want a Python callable, take a look at JAX's autodiff cookbook.)
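If a Python callable is enough, JAX builds the derivative directly. The sketch below is a minimal check, assuming JAX is installed and using `jax.numpy` as `jnp`, that `jax.grad` agrees with the hand-written `sigmoid_backward` from the question; `jax.vmap` maps the scalar derivative over an array of inputs.

```python
import jax
import jax.numpy as jnp

def sigmoid(x):
    return 1. / (1. + jnp.exp(-x))

def sigmoid_backward(x):
    # Hand-written derivative from the question: s(x) * (1 - s(x))
    return 1. / (1. + jnp.exp(-x)) * (1. - 1. / (1. + jnp.exp(-x)))

# jax.grad returns a Python callable computing d sigmoid / dx for a scalar x
dsigmoid = jax.grad(sigmoid)

# jax.vmap applies the scalar derivative elementwise over an array
xs = jnp.linspace(-3., 3., 7)
auto = jax.vmap(dsigmoid)(xs)
manual = sigmoid_backward(xs)
print(jnp.allclose(auto, manual))  # True
```

The two agree up to float32 rounding, but note that `jax.grad` gives you a function, not source text.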

Generating Python source text is outside of JAX's scope. However, have you looked at Tangent? It might do exactly what you want.
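JAX does not emit Python source, but it can show the computation it traced for the backward pass. A minimal sketch, assuming a recent JAX: `jax.make_jaxpr` prints the derivative as a "jaxpr", JAX's intermediate representation, which lists primitive operations (such as `neg`, `exp`, `add`, `div`, `mul`) that you could translate by hand. The exact printed form varies between JAX versions.

```python
import jax
import jax.numpy as jnp

def sigmoid(x):
    return 1. / (1. + jnp.exp(-x))

# Trace the gradient function at a sample input; the result is a jaxpr
# (JAX's intermediate representation), not Python source text
backward_ir = jax.make_jaxpr(jax.grad(sigmoid))(1.0)
print(backward_ir)

# The traced operations are also accessible programmatically. Walking them
# is one possible starting point for a home-grown source emitter (an
# assumption on our part, not a JAX feature)
names = [eqn.primitive.name for eqn in backward_ir.jaxpr.eqns]
print(names)
```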

(By the way, I think your sigmoid_backward computes the derivative of sigmoid, i.e. the coefficients of the Jacobian, but it doesn't implement a backward pass, i.e. what we would call a vector-Jacobian product, or VJP. To be a VJP, it needs to take another argument, namely the "incoming gradients" shaped like the output.)
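To make the Jacobian/VJP distinction concrete, here is a minimal sketch, assuming JAX is installed. `sigmoid_vjp` is a hypothetical hand-written backward pass that takes the incoming gradient `g` (shaped like the output), and it is compared against `jax.vjp`. Because sigmoid is elementwise, its Jacobian is diagonal, and the question's `sigmoid_backward` computes exactly that diagonal.

```python
import jax
import jax.numpy as jnp

def sigmoid(x):
    return 1. / (1. + jnp.exp(-x))

def sigmoid_vjp(x, g):
    # Hypothetical hand-written backward pass: incoming gradient g times
    # the elementwise derivative s * (1 - s)
    s = sigmoid(x)
    return g * s * (1. - s)

x = jnp.array([-1., 0., 2.])
g = jnp.array([1., 2., 3.])  # incoming gradients, shaped like sigmoid(x)

# jax.vjp returns the forward output and a function mapping g to g^T J
y, vjp_fn = jax.vjp(sigmoid, x)
(auto,) = vjp_fn(g)  # one cotangent per primal input

# Full Jacobian for comparison; it is diagonal for an elementwise function
J = jax.jacobian(sigmoid)(x)
print(jnp.allclose(auto, g @ J), jnp.allclose(auto, sigmoid_vjp(x, g)))  # True True
```

Materializing `J` is only for illustration; a real backward pass like `sigmoid_vjp` never builds the full Jacobian.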

Since we've got a lot of issues, if Tangent is a good fit for what you need, let us know so we can close this one!

Yes. Although Tangent is not complete, at least it showed me the way. Thanks.
