If we want a certain JAX array to be invariant to the gradient operator (i.e. treated as a constant, so no gradient flows through it), what is the JAX equivalent of the detach() function in PyTorch or the stop_gradient() operator in TensorFlow?
jax.lax.stop_gradient does this: https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.stop_gradient.html#jax.lax.stop_gradient
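A minimal sketch of what this looks like in practice, assuming a recent JAX install. The function `loss` and the input values are hypothetical, chosen only to show the effect: the wrapped value is treated as a constant by `jax.grad`, much like a detached tensor in PyTorch.

```python
import jax
import jax.numpy as jnp

def loss(x):
    # Hypothetical example: y is computed from x, but stop_gradient makes
    # jax.grad treat it as a constant (analogous to PyTorch's y.detach()).
    y = jax.lax.stop_gradient(x ** 2)
    return jnp.sum(x * y)

x = jnp.array([1.0, 2.0, 3.0])

# Without stop_gradient, d/dx sum(x**3) would be 3 * x**2 = [3, 12, 27].
# With it, y is a constant c, so d/dx sum(x * c) = c = x**2.
g = jax.grad(loss)(x)
print(g)  # [1. 4. 9.]
```

Note that `stop_gradient` only affects differentiation; the forward value of `y` is unchanged.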
Thanks!