When running a piece of code that was working earlier today, I am suddenly getting the following error:
NotImplementedError: Forward-mode differentiation rule for 'while' not implemented
If necessary I can try to create a minimal repro, but I can say that if I replace a lax.scan call with a for loop, the error goes away.
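For anyone hitting the same error, here is a rough sketch of the kind of pattern being described. It is a hypothetical function, not the code from the report: a lax.scan whose carry holds a float value plus a PRNG key that gets split on every step, differentiated in forward mode with jax.jvp, next to the equivalent Python for loop. The names step, f_scan and f_loop are made up for illustration. On a JAX version without this bug, both versions should give the same answer.

```python
import jax
import jax.numpy as jnp
from jax import lax, random

# Hypothetical loop body: the carry is (float value, PRNG key).
def step(carry, _):
    x, key = carry
    key, subkey = random.split(key)  # integer-valued key, split every step
    noise = random.normal(subkey)    # float noise drawn from the subkey
    return (2.0 * x + noise, key), None

def f_scan(x, key):
    # Three steps with lax.scan; xs is None, so length must be given.
    (x, _), _ = lax.scan(step, (x, key), None, length=3)
    return x

def f_loop(x, key):
    # The same three steps unrolled with a Python for loop.
    carry = (x, key)
    for _ in range(3):
        carry, _ = step(carry, None)
    return carry[0]

key = random.PRNGKey(0)
x0 = jnp.float32(1.0)
dx = jnp.float32(1.0)

# Forward-mode derivative with respect to x only; the key is closed over.
p_scan, t_scan = jax.jvp(lambda x: f_scan(x, key), (x0,), (dx,))
p_loop, t_loop = jax.jvp(lambda x: f_loop(x, key), (x0,), (dx,))
print(p_scan, t_scan)
print(p_loop, t_loop)
```

Each step multiplies x by 2 and adds noise that does not depend on x, so both tangents should come out as 8.0.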
Ok, I have (finally) reduced this to a minimal repro, see here: https://pastebin.com/WzYk6Xj1
Let me know if anything doesn't make sense.
Thanks!
Thanks again for raising this! I only just started looking at it.
I think the issue (or at least one issue) was uncovered by #1175 or #1269: the loop being complained about is in the PRNG hash function. My guess at the root cause is that somewhere we're not properly dropping integer values out of the autodiff trace, so the trace reaches the PRNG code, which has while loops in it.
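To make the "integer values" point concrete: a PRNG key is an integer array, so forward-mode differentiation has nothing meaningful to propagate through it. In recent JAX versions this shows up in the public API, where the output tangent for an integer-valued output comes back with the special dtype jax.dtypes.float0. This is a sketch of the user-facing behavior in current JAX, not of the internals discussed in this thread; g is a hypothetical function.

```python
import jax
import jax.numpy as jnp
from jax import random

def g(x, key):
    # Hypothetical function: scale x by a uniform draw and return the new key.
    key, subkey = random.split(key)
    return x * random.uniform(subkey), key

key = random.PRNGKey(0)  # legacy-style key: an integer array of shape (2,)

(y, new_key), (y_dot, key_dot) = jax.jvp(
    lambda x: g(x, key), (jnp.float32(3.0),), (jnp.float32(1.0),)
)
print(y_dot)          # the tangent of x * u with respect to x is just u
print(key_dot.dtype)  # float0: the integer key output carries no tangent
```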
Some evidence for this guess: when I replaced the line key, split = random.split(key) with key, split = key, key, the error went away.
Still looking...
I think I introduced the real issue in #1224, because in these lines I'm effectively creating JVPTracers (with zeros) for integer-valued arguments. I think Dougal had some clever way of avoiding that before, and I clobbered it!
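On the input side, recent JAX versions make the "no tangent for integer-valued arguments" rule explicit: jax.jvp only accepts an integer primal (such as a PRNG key) together with a tangent of dtype jax.dtypes.float0, which marks that argument as having no tangent instead of building an integer zero array for it. This is a sketch of the current public contract, assuming a modern JAX, and not the internal change in #1224 or the fix that followed; h is a hypothetical function.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax import random

def h(x, key):
    # Hypothetical function taking a float and an integer-valued PRNG key.
    _, subkey = random.split(key)
    return x * random.uniform(subkey)

key = random.PRNGKey(0)
x0 = jnp.float32(2.0)

# Tangents for integer primals must have dtype float0 ("no tangent here").
key_tangent = np.zeros(key.shape, dtype=jax.dtypes.float0)

y, y_dot = jax.jvp(h, (x0, key), (jnp.float32(1.0), key_tangent))
print(y, y_dot)
```

Since only x has a real tangent, y_dot is just the uniform draw, and y is twice that.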
Whooo!!! As always, what a herculean effort, Matt!!
Thanks so much Matt! It works like a charm!