JAX: forward-mode differentiation of lax.scan suddenly not working

Created on 6 Sep 2019 · 6 comments · Source: google/jax

When running a piece of code that was working earlier today, I am suddenly getting the following error:
NotImplementedError: Forward-mode differentiation rule for 'while' not implemented

If necessary, I can try to create a minimal repro, but I will say that if I replace a lax.scan call with a for loop, this error goes away.
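One way to check whether forward-mode differentiation through lax.scan works is to compare jax.jvp through a scan against the same computation written as a Python for loop. This is a hypothetical sketch, not the reporter's code: scan_power and loop_power are illustrative names, and it assumes a current JAX install. Both versions should produce the same primal and tangent.

```python
import jax
import jax.numpy as jnp
from jax import lax

def scan_power(x, n=4):
    # Compute x**n by multiplying the carry by x on each scan step
    def body(carry, _):
        return carry * x, None
    out, _ = lax.scan(body, jnp.ones_like(x), None, length=n)
    return out

def loop_power(x, n=4):
    # Same computation as a Python for loop (unrolled at trace time)
    out = jnp.ones_like(x)
    for _ in range(n):
        out = out * x
    return out

x, t = jnp.float32(1.5), jnp.float32(1.0)
scan_primal, scan_tangent = jax.jvp(scan_power, (x,), (t,))
loop_primal, loop_tangent = jax.jvp(loop_power, (x,), (t,))
```

For x = 1.5 both tangents should equal d(x**4)/dx = 4 * 1.5**3 = 13.5.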

All 6 comments

Sounds like it could be a regression! Though we have tests for this, they might leave some case uncovered.

A minimal repro would be really helpful, since our current tests pass.

OK, I have (finally) reduced this to a minimal repro; see here: https://pastebin.com/WzYk6Xj1

Let me know if anything doesn't make sense.

Thanks!

Thanks again for raising this! I only just started looking at it.

I think the issue (or at least one issue) was uncovered by #1175 or #1269: the loop being complained about is in the PRNG hash function. My guess at the root cause is that we're not dropping out of the autodiff trace properly somewhere on integer values, and that's hitting the PRNG code, which has while loops in it.

Some evidence for this guess: when I replaced the line `key, split = random.split(key)` with `key, split = key, key`, the error went away.
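A hypothetical sketch of the pattern under suspicion: a PRNG key threaded through a lax.scan carry and split on every step, differentiated with jax.jvp with respect to a float input only. The name noisy_sum and the step count are illustrative assumptions, not taken from the pastebin repro. The key is an integer (uint32) array, so it should not need a tangent. Because the output is linear in x, the tangent should equal the output divided by x.

```python
import jax
import jax.numpy as jnp
from jax import lax, random

def noisy_sum(x, key, steps=5):
    # Carry holds a float accumulator plus an integer-valued PRNG key
    def body(carry, _):
        acc, key = carry
        key, subkey = random.split(key)  # split runs the PRNG hash on integers
        acc = acc + x * random.normal(subkey)
        return (acc, key), None
    (final, _), _ = lax.scan(body, (jnp.zeros_like(x), key), None, length=steps)
    return final

key = random.PRNGKey(0)
x = jnp.float32(2.0)
# Differentiate only with respect to x; the key is closed over
primal_out, tangent_out = jax.jvp(lambda x: noisy_sum(x, key), (x,), (jnp.float32(1.0),))
```

With x = 2.0, primal_out should be exactly twice tangent_out.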

Still looking...

I think I introduced the real issue in #1224, because in these lines I'm effectively creating JVPTracers (with zeros) for integer-valued arguments. I think Dougal had some clever way of avoiding that before, and I clobbered it!
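The principle at stake (integer-valued inputs should not carry real tangents) is visible in the public jax.jvp API of current JAX versions. This is an illustration only, not the internal fix for the change in #1224: integer primals are paired with tangents of the special dtype jax.dtypes.float0, which stands for "no tangent". The function scale is a hypothetical example.

```python
import numpy as np
import jax
import jax.numpy as jnp

def scale(x, n):
    # x is a float input; n is an integer input with no meaningful tangent
    return x * n

x, n = jnp.float32(3.0), jnp.int32(4)
# Integer primals take float0 tangents rather than ordinary zeros
n_dot = np.zeros((), dtype=jax.dtypes.float0)
primal_out, tangent_out = jax.jvp(scale, (x, n), (jnp.float32(1.0), n_dot))
```

Here the tangent is just n times the tangent of x, since n contributes nothing.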

Whooo!!! As always, what a herculean effort, Matt!!

Thanks so much, Matt! It works like a charm!

