Jax: Question about block_until_ready() on tuple

Created on 17 May 2020 · 2 Comments · Source: google/jax

I want to time the following:
opt_state = update(itr, grad(loss)(get_params(opt_state)), opt_state).

opt_state is a Python tuple so I can't call block_until_ready() directly.

What is the best way to ensure that opt_state has been fully computed before the host reads the timer, so that I get accurate timings?

  • nothing; does containment in a native Python container imply the values have already been consumed?
  • tree_map and call block_until_ready() over all the leaves of opt_state
  • make opt_state a JAX type and call block_until_ready() once (If so, how to convert it to JAX type?)
  • directly consume from the host in some other way?

All 2 comments

I think tree-mapping block_until_ready is a decent idea. I don't think it should add noticeable overheads (based on my guess about how much time the computation itself will take).

nothing; does containment in a native Python container imply the values have already been consumed?

No, loops won't do anything special. The only thing that blocks the Python thread (e.g. so that timers are accurate) is executing a non-jax operation on it (like printing a value, which will entail blocking until that value is ready and then also transferring it to the CPU) or block_until_ready.
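To make the point about blocking concrete, here is a minimal sketch of the difference between timing only the dispatch and timing the finished computation. The function `heavy` and the array shape are hypothetical stand-ins, not anything from the original question, and the actual numbers will depend on the backend.

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def heavy(x):
    # Hypothetical stand-in for an expensive step: a short chain of matmuls.
    for _ in range(10):
        x = jnp.tanh(x @ x)
    return x


x = jnp.ones((500, 500))
heavy(x).block_until_ready()  # warm-up call so compilation is not timed

start = time.perf_counter()
y = heavy(x)  # may return before the device has finished the work
dispatch_time = time.perf_counter() - start

y.block_until_ready()  # wait until the result is actually ready
total_time = time.perf_counter() - start

print(f"without blocking: {dispatch_time:.6f}s, with blocking: {total_time:.6f}s")
```

Only the second measurement includes the time the computation itself takes; the first can stop the clock while work is still in flight.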

make opt_state a JAX type and call block_until_ready() once (If so, how to convert it to JAX type?)

We used to have JaxTuples! But they make the system much more complex, both in terms of "front-end" transformation stuff and "back-end" low-level runtime stuff.

directly consume from the host in some other way?

That works, e.g. printing the values, but then you'd also be timing the transfer-to-host time as well as whatever operation (e.g. printing) is being performed.

So yeah I'm thinking tree_map(lambda x: x.block_until_ready(), opt_state)! But also if update is jitted then you can just do tree_flatten(opt_state)[0][0].block_until_ready(), since all results of a jitted function become available at the same time.
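As a rough sketch of the tree_map approach: the `loss`, `update`, and `opt_state` below are hypothetical stand-ins (a toy momentum step over a tuple of tuples), not the optimizer from the question. The part that matters is that block_until_ready is called, with parentheses, on every leaf before the timer is read.

```python
import time

import jax
import jax.numpy as jnp
from jax import tree_util


def loss(params):
    # Hypothetical toy loss over a (weights, bias) tuple.
    w, b = params
    return jnp.sum((w @ jnp.ones(3) + b) ** 2)


@jax.jit
def update(params, velocity):
    # Hypothetical momentum step; returns a nested tuple (a pytree).
    grads = jax.grad(loss)(params)
    velocity = tree_util.tree_map(lambda v, g: 0.9 * v + g, velocity, grads)
    params = tree_util.tree_map(lambda p, v: p - 0.01 * v, params, velocity)
    return params, velocity


params = (jnp.ones((4, 3)), jnp.zeros(4))
opt_state = (params, tree_util.tree_map(jnp.zeros_like, params))

# Warm-up call so that compilation time is not included in the measurement.
opt_state = update(*opt_state)
tree_util.tree_map(lambda x: x.block_until_ready(), opt_state)

start = time.perf_counter()
opt_state = update(*opt_state)
# Block on every leaf of the tuple before stopping the clock.
tree_util.tree_map(lambda x: x.block_until_ready(), opt_state)
elapsed = time.perf_counter() - start

print(f"one update step: {elapsed:.6f}s")
```

Newer JAX releases also provide a top-level jax.block_until_ready that accepts a whole pytree; whether it is available depends on the installed version, so the sketch sticks to the per-leaf method.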

Thanks for the very detailed reply as always @mattjj :)

No, loops won't do anything special. The only thing that blocks the Python thread (e.g. so that timers are accurate) is executing a non-jax operation on it

Interesting, good to know!

We used to have JaxTuples! But they make the system much more complex, both in terms of "front-end" transformation stuff and "back-end" low-level runtime stuff.

Haha so I'm not crazy, I remember noticing these before I think! The way JAX handles nested containers is super nice. I suppose it's one of the simpler features but honestly one of my favourite things about JAX btw.

That works, e.g. printing the values, but then you'd also be timing the transfer-to-host time as well as whatever operation (e.g. printing) is being performed.

Good point, I guess that's why block_until_ready() is useful in the first place.

So yeah I'm thinking tree_map(lambda x: x.block_until_ready(), opt_state)! But also if update is jitted then you can just do tree_flatten(opt_state)[0][0].block_until_ready(), since all results of a jitted function become available at the same time.

Ah, yes update is jitted so I think this is what I'll go with, thanks for pointing out this additional simplification.
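For reference, a minimal sketch of that shortcut, with a hypothetical jitted `update` and a toy nested-tuple state. It relies on the statement above that all outputs of a single jitted call become ready together; if the state were assembled from several separate calls, blocking on every leaf would be the safer choice.

```python
import jax
import jax.numpy as jnp
from jax import tree_util


@jax.jit
def update(state):
    # Hypothetical jitted step returning a nested tuple of arrays.
    w, (m, v) = state
    return w - 0.1 * m, (0.9 * m + w, 0.99 * v + w * w)


opt_state = (jnp.ones(3), (jnp.zeros(3), jnp.zeros(3)))
opt_state = update(opt_state)

# Flatten the nested tuple into a list of leaves and block on the first one.
# Assumption: every leaf came from the same jitted call.
leaves, treedef = tree_util.tree_flatten(opt_state)
leaves[0].block_until_ready()
```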
