Is there a function in JAX similar to TensorFlow's `tf.stop_gradient`? How can I do this in JAX?
It's not documented, but I think you're looking for `jax.lax.stop_gradient`. We should document it!
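Here is a minimal sketch of how it behaves, assuming a recent JAX install. `jax.lax.stop_gradient` returns its input unchanged in the forward pass but is treated as a constant when differentiating, which is the same role `tf.stop_gradient` plays in TensorFlow. The function `f` below is a hypothetical example, not taken from the thread.

```python
import jax


def f(x):
    # Hypothetical example: both terms equal x**2 in the forward pass,
    # but the second is wrapped in stop_gradient, so autodiff treats it
    # as a constant and it contributes nothing to the derivative.
    return x ** 2 + jax.lax.stop_gradient(x ** 2)


grad_f = jax.grad(f)

print(f(3.0))       # 18.0: forward value includes both terms
print(grad_f(3.0))  # 6.0: only d/dx of the first x**2 term (2 * x)
```

Without the `stop_gradient` call the derivative at `x = 3.0` would be `12.0`. Wrapping a subexpression this way is the usual approach when you want a value to participate in the computation (for example, a fixed target) while blocking gradients from flowing back through it.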