Jax: Unimplemented: binary integer op 'power'

Created on 10 Dec 2018 · 4 Comments · Source: google/jax

from jax import grad

def square(x):
  return x**2

val = 3
dfn = grad(square)
print(dfn(val))

I was surprised this threw an error. Changing it to val = 3.0 works as expected.
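For comparison, here is a minimal sketch of the float version that works, plus a workaround for an integer value: convert it to `float32` before calling `grad`. It assumes a recent JAX release. The cast is only a workaround, not the fix discussed in the comments.

```python
import jax
import jax.numpy as jnp

def square(x):
    return x ** 2

dfn = jax.grad(square)

# d/dx x**2 = 2x, so the gradient at 3.0 is 6.0.
print(dfn(3.0))

# Workaround for an integer value: convert it to a floating-point dtype first.
val = 3
print(dfn(jnp.asarray(val, dtype=jnp.float32)))
```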

All 4 comments

This seems like an important operation to offer :)

I think the binary integer pow function isn't implemented in XLA, since we seem to be getting this error and the kPow opcode doesn't seem to appear in that list. Maybe just an oversight. I'll follow up with XLA folks.

A second issue here is that grad should raise an error on non-floating-point argument types.
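A minimal sketch of the proposed check follows. It is written as a hypothetical wrapper, `float_only_grad` (not a JAX API), that rejects non-floating-point inputs before delegating to `jax.grad`. Recent JAX versions perform a comparable check inside `jax.grad` itself, so this only illustrates the idea.

```python
import jax
import jax.numpy as jnp

def float_only_grad(fun):
    """Hypothetical wrapper: like jax.grad, but rejects non-floating-point inputs up front."""
    g = jax.grad(fun)

    def wrapped(x):
        dtype = jnp.asarray(x).dtype
        # Differentiation is only meaningful for floating-point arguments.
        if not jnp.issubdtype(dtype, jnp.floating):
            raise TypeError(f"grad requires a floating-point argument, got {dtype}")
        return g(x)

    return wrapped
```

With this wrapper, `float_only_grad(square)(3)` fails immediately with a clear `TypeError` instead of surfacing an XLA "Unimplemented" error.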

@hawkinsp guessed that XLA's Pow HLO is meant to model std::pow, which apparently doesn't work on integer values either. We should solve this in JAX at the jax.numpy level.
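One common way to compute integer powers without a floating-point `pow` is exponentiation by squaring. The sketch below assumes a non-negative `int32` exponent and uses `lax.while_loop` so that the exponent may be a traced value. `int_power` is a hypothetical name, and this is not claimed to be the code that was added to `jax.numpy.power`.

```python
import jax
import jax.numpy as jnp
from jax import lax

def int_power(base, exponent):
    """Sketch: integer power by repeated squaring (assumes exponent >= 0)."""
    base = jnp.asarray(base, dtype=jnp.int32)
    exponent = jnp.asarray(exponent, dtype=jnp.int32)

    def cond(state):
        _, _, e = state
        # Keep looping while exponent bits remain.
        return e > 0

    def body(state):
        acc, b, e = state
        # If the lowest bit of e is set, multiply the current square into the result.
        acc = jnp.where((e & 1) == 1, acc * b, acc)
        # Square the base and shift to the next bit of the exponent.
        return acc, b * b, e >> 1

    acc, _, _ = lax.while_loop(cond, body, (jnp.ones_like(base), base, exponent))
    return acc
```

Each iteration halves the exponent, so the loop runs about log2(exponent) times. Results are not checked for int32 overflow.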

Thanks for the issue report!

I added support for integer powers to jax.numpy.power.
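With integer support in place, `jnp.power` accepts integer arrays directly. The following usage sketch assumes a recent JAX release. It also shows `jax.lax.integer_pow` as a lower-level alternative, which takes its exponent as a static Python `int`.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.arange(5)  # integer array: [0, 1, 2, 3, 4]

# An integer base with an integer exponent stays in an integer dtype.
squares = jnp.power(x, 2)
print(squares, squares.dtype)

# lax.integer_pow takes the exponent as a static Python int.
cubes = lax.integer_pow(x, 3)
print(cubes)
```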

I also filed https://github.com/google/jax/issues/424 for raising an error if taking grad of an integer.