from jax import grad

def square(x):
    return x**2

val = 3  # an integer; this is the case that raised the error
dfn = grad(square)
print(dfn(val))
I was surprised this threw an error. Changing it to val = 3.0 works as expected.
This seems like an important operation to offer :)
I think XLA doesn't implement the binary Pow operation for integer operands: we seem to be getting this error, and the kPow opcode doesn't appear in that list. Maybe it's just an oversight. I'll follow up with the XLA folks.
A second issue here is that grad should raise an error on non-floating-point argument types.
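As a rough illustration of the kind of check being proposed, here is a minimal sketch of a wrapper that rejects non-floating-point arguments before calling a gradient function. The name `check_float_args` is hypothetical and not part of JAX. It uses NumPy's `np.inexact`, which also admits complex dtypes, and it assumes the arguments can be converted with `np.asarray`.

```python
import numpy as np

def check_float_args(fn):
    """Hypothetical helper: wrap fn so integer/bool arguments raise TypeError."""
    def wrapped(*args):
        for i, a in enumerate(args):
            dtype = np.asarray(a).dtype
            # np.inexact covers floating and complex dtypes, not ints or bools.
            if not np.issubdtype(dtype, np.inexact):
                raise TypeError(
                    f"argument {i} has dtype {dtype}; expected a floating-point type"
                )
        return fn(*args)
    return wrapped
```

With JAX installed, it would be used as `dfn = check_float_args(grad(square))`, so that `dfn(3)` raises a `TypeError` with a clear message while `dfn(3.0)` behaves as before.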
@hawkinsp guessed that XLA's Pow HLO is meant to model std::pow, which apparently doesn't work on integer values either. We should solve this in JAX at the jax.numpy level.
Thanks for the issue report!
I added support for integer powers to jax.numpy.power.
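One common way to support integer exponents without relying on a backend Pow op is exponentiation by repeated squaring, which needs only multiplication. The sketch below is an illustration of that technique under the assumption of a non-negative Python `int` exponent; it is not necessarily how `jax.numpy.power` implements it. Because it uses only `*`, it should also work when `base` is an array.

```python
def int_pow(base, n):
    """Compute base**n for a non-negative int n using repeated squaring."""
    if n < 0:
        raise ValueError("negative exponents are not handled in this sketch")
    result = 1
    while n:
        if n & 1:              # low bit set: fold the current power into result
            result = result * base
        base = base * base     # square for the next bit
        n >>= 1
    return result
```

This takes O(log n) multiplications, e.g. `int_pow(2, 10)` computes 2**10 with four squarings rather than nine multiplications.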
I also filed https://github.com/google/jax/issues/424 for raising an error if taking grad of an integer.