Jax: Question: What's the best way to implement a ring buffer?

Created on 15 Oct 2020  ·  6 Comments  ·  Source: google/jax

Hi there,

I'd like to create a ring buffer and I'm wondering what the best approach would be.

Let's say I want to store up to 1M arrays of shape (7, 13). I create my buffer as

import numpy as onp
import jax
import jax.numpy as jnp

capacity = 1000000
shape = (7, 13)

storage = jnp.zeros((capacity, *shape))
storage = jax.device_put(storage)  # put on my gpu, say
index = 0

To insert a new item in such a way that it ends up on the device as well, I see four options:

assert isinstance(new_item, onp.ndarray)  # ordinary numpy array (not jax.numpy)

# option 1
storage = jax.ops.index_update(storage, index, new_item)

# option 2
storage = jax.ops.index_update(storage, index, new_item)
storage = jax.device_put(storage)

# option 3
new_item = jax.device_put(new_item)
storage = jax.ops.index_update(storage, index, new_item)

# option 4
new_item = jax.device_put(new_item)
storage = jax.ops.index_update(storage, index, new_item)
storage = jax.device_put(storage)

Which option would be preferred (and why)?

I'm missing a good mental model for thinking about this, so any help would be appreciated!

All 6 comments

I've run some tests and it looks like simple insertion (any variant) is super duper slow.

I'm comparing performance against a basic python implementation that uses a collections.deque.

  • deque variant: insert 680ns, sample 42ms
  • variant from comment above: insert 7s (!), sample 22ms

N.B. sample operation is:

# deque variant
from random import sample
batch = sample(deque, 1024)
batch = jax.tree_multimap(lambda *leaves: jnp.stack(leaves, axis=0), batch)
batch = jax.device_put(batch)

# variant from comment above
idx = jax.random.randint(rng, shape=(1024,), minval=0, maxval=length)
batch = jax.tree_map(lambda leaf: leaf[idx], storage)

Should I just avoid using device_put on mutable data like this?
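For reference, here is a minimal self-contained sketch of the on-device sampling step. It assumes the buffer is a single array rather than a pytree; `sample_batch`, the small `capacity`, and the fill values are illustrative choices, not taken from the thread.

```python
import jax
import jax.numpy as jnp

capacity = 1000  # small stand-in for the 1M-slot buffer
shape = (7, 13)

# Fill slot i with the value i so sampled rows are easy to identify.
storage = jnp.arange(capacity, dtype=jnp.float32)[:, None, None] * jnp.ones(shape)

@jax.jit
def sample_batch(rng, storage, length):
    # Draw 1024 indices uniformly from the filled part of the buffer, [0, length).
    idx = jax.random.randint(rng, shape=(1024,), minval=0, maxval=length)
    # Gather the selected rows; the result stays on the device.
    return storage[idx]

rng = jax.random.PRNGKey(0)
batch = sample_batch(rng, storage, 500)
```

Because the gather runs on the device, no per-item host-to-device transfer is needed at sampling time.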

Option (1) should be fine, but everything (the whole loop) needs to be under a single jit-decorated function. If you do that, XLA is guaranteed to optimize away the intermediate arrays.

If you try to run this eagerly (without jit compiling the whole thing), jax.ops.index_update makes a whole copy of the array for each modification, which is indeed really slow.
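As a rough illustration of a jitted single-item update: later JAX releases replaced `jax.ops.index_update` with the `.at[...].set(...)` syntax, which is used here. The `donate_argnums` argument goes beyond what this thread discusses; it is included as an assumption about how one might let XLA reuse the input buffer rather than copy it, and whether that actually happens depends on the backend.

```python
from functools import partial

import jax
import jax.numpy as jnp
import numpy as onp

capacity = 1000  # small stand-in for the 1M-slot buffer
shape = (7, 13)

# donate_argnums=0 marks `storage` as donated: XLA may reuse its memory for the
# output instead of allocating a fresh copy. Backends without donation support
# fall back to copying (JAX emits a warning in that case).
@partial(jax.jit, donate_argnums=0)
def insert(storage, index, new_item):
    # storage.at[index].set(...) is the newer spelling of jax.ops.index_update.
    return storage.at[index].set(new_item)

storage = jnp.zeros((capacity, *shape))
new_item = onp.ones(shape, dtype=onp.float32)  # plain NumPy array, as in the question

# Always rebind the result; a donated input must not be used afterwards.
storage = insert(storage, 0, new_item)
```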

Thanks @shoyer

I feel stupid for asking this, but do you mean something like this?

@jax.jit
def insert(storage, index, length, new_item):
    storage = jax.ops.index_update(storage, index, new_item)
    index = (index + 1) % capacity
    length = jnp.minimum(capacity, length + 1)
    return storage, index, length

The insert would then be:

storage, index, length = insert(storage, index, length, new_item)

This is still slow, because I call it repeatedly for each individual insert. Should I collect batches of new_items to reduce the number of inserts?
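A sketch of what such a batched insert could look like, assuming the batch size never exceeds `capacity`. The name `insert_batch` is hypothetical, and `.at[...].set(...)` is the newer spelling of `jax.ops.index_update`.

```python
import jax
import jax.numpy as jnp

capacity = 8  # tiny capacity so the wrap-around is visible
shape = (7, 13)

@jax.jit
def insert_batch(storage, index, length, new_items):
    k = new_items.shape[0]  # static under jit; assumed to be <= capacity
    # Target slots for the k items, wrapping around the end of the buffer.
    slots = (index + jnp.arange(k)) % capacity
    storage = storage.at[slots].set(new_items)
    index = (index + k) % capacity
    length = jnp.minimum(capacity, length + k)
    return storage, index, length

storage = jnp.zeros((capacity, *shape))
index, length = 6, 6
new_items = jnp.ones((4, *shape))  # four items written with one call
storage, index, length = insert_batch(storage, index, length, new_items)
```

Here the four items land in slots 6, 7, 0 and 1, so one call replaces four single-item inserts.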


By the way, the reason I'm interested in this is to speed up sampling from my buffer, not insertion. I just need insertion to not be painfully slow (less than 20ms to get an overall speedup).

I meant that whatever computation you do that is inserting items, that should happen inside the same jit, e.g., under jax.lax.while_loop or jax.lax.scan:

@jax.jit
def entire_computation(inputs):
    def body_fun(state):
        storage, index, length = state
        new_item = ...  # compute "item" here
        new_state = insert(storage, index, length, new_item)
        return new_state
    ...  # define cond_fun and set init_value
    return jax.lax.while_loop(cond_fun, body_fun, init_value)

Inside a jit decorated function, it doesn't really matter when you call device_put(), because everything will get copied to the device as soon as you call the function.
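To make the pattern concrete, here is a small runnable sketch using `jax.lax.scan`. The item computation is a stand-in (each item is filled with its step number), the tiny `capacity` is for illustration only, and `.at[...].set(...)` is the newer spelling of `jax.ops.index_update`.

```python
import jax
import jax.numpy as jnp

capacity = 8
shape = (7, 13)

def insert(storage, index, length, new_item):
    storage = storage.at[index].set(new_item)
    index = (index + 1) % capacity
    length = jnp.minimum(capacity, length + 1)
    return storage, index, length

@jax.jit
def entire_computation(steps):
    def body_fun(state, step):
        storage, index, length = state
        # Stand-in for the real computation that produces each item.
        new_item = jnp.ones(shape, storage.dtype) * step
        return insert(storage, index, length, new_item), None

    init_state = (jnp.zeros((capacity, *shape)), jnp.int32(0), jnp.int32(0))
    final_state, _ = jax.lax.scan(body_fun, init_state, steps)
    return final_state

# Ten inserts into an 8-slot buffer: the last two overwrite slots 0 and 1.
storage, index, length = entire_computation(jnp.arange(10))
```

Since the whole loop is traced once and compiled as a single XLA computation, the per-insert dispatch overhead of calling a jitted `insert` from Python disappears.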

Does that help?

Thanks @shoyer that certainly helps. It also means that I cannot do what I was hoping to do, because new_item is generated outside my control.

But hey that's okay. Thanks for your help!

There are a few ways to get data from external systems, but they aren't super-usable at the moment:

  1. write an XLA custom call
  2. use XLA's infeed/outfeed

(See https://github.com/google/jax/issues/766 for discussion about making this easier)
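When items arrive from outside JAX, one alternative that is not proposed in this thread is to keep the ring buffer in host memory as a preallocated NumPy array and transfer only the sampled batch to the device. The sketch below is written under that assumption; `HostRingBuffer` is a hypothetical helper, not a JAX API.

```python
import numpy as onp
import jax

class HostRingBuffer:
    """Fixed-capacity ring buffer kept in host memory (hypothetical helper)."""

    def __init__(self, capacity, shape, dtype=onp.float32):
        self.storage = onp.zeros((capacity, *shape), dtype=dtype)
        self.capacity = capacity
        self.index = 0   # next slot to write
        self.length = 0  # number of valid items

    def insert(self, new_item):
        # In-place write into host memory: no copy of the whole buffer.
        self.storage[self.index] = new_item
        self.index = (self.index + 1) % self.capacity
        self.length = min(self.capacity, self.length + 1)

    def sample(self, rng, batch_size):
        # Gather the batch on the host, then do a single transfer to the device.
        idx = rng.integers(0, self.length, size=batch_size)
        return jax.device_put(self.storage[idx])

buffer = HostRingBuffer(capacity=8, shape=(7, 13))
for step in range(10):
    buffer.insert(onp.full((7, 13), step, dtype=onp.float32))
batch = buffer.sample(onp.random.default_rng(0), batch_size=4)
```

This trades on-device sampling for cheap in-place inserts, so whether it beats the deque variant depends on how the gather and transfer costs compare in practice.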
