Jax: Working with data of different lengths using automasking

Created on 23 Nov 2020  ·  5 Comments  ·  Source: google/jax

I am working with multiple time series of different lengths that should all undergo the same operations.
Following the tweet from @marcvanzee, I attempted to use mask and vmap for these manipulations. Unfortunately, because of the lack of documentation, I am not sure I can fully decipher Marc's example.
The simplest operation I am attempting is concatenating two padded arrays. Is it possible to parallelize it using mask and vmap?

import itertools

import jax.numpy as np

# each row is a data_set of different length. Zero padding is used to put it all in one array.
data_sets_a = np.array([[1,2,0,0], [4,5,6,7]])
data_lengths_a = np.array([2, 4])

data_sets_b = np.array([[1,0], [1,2]])
data_lengths_b = np.array([1, 2])

num_data_sets = data_sets_a.shape[0]

# implementation using list comprehension
data_sets_concatenated_list = [np.concatenate([data_sets_a[i, :data_lengths_a[i]], data_sets_b[i, :data_lengths_b[i]]]) \
                               for i in range(num_data_sets)]
data_sets_concatenated = np.array(list(itertools.zip_longest(*data_sets_concatenated_list, fillvalue=0))).T

Is it possible to define something like:

def concatenate(array_1, array_2):
    return np.concatenate([array_1, array_2])

input_mask_1 = data_lengths_a
input_mask_2 = data_lengths_b
output_mask = data_lengths_a + data_lengths_b

and use vmap and mask to get the same result as in the code above?


All 5 comments

Hi noashin, I'm looking at similar stuff and I found this comment with an example of the masking feature https://github.com/google/jax/issues/2521#issuecomment-604759386

It's from March 2020, so I don't know if they've changed anything about the API since then. I'm trying to get it working myself too, so please say if you manage to get it working :)

So far no luck. For the more basic operations I need, I am using padded arrays; for the more complicated ones, I am currently looping.
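As an illustration of the padded-array approach for a basic operation, here is a minimal sketch of a per-series mean that ignores the zero padding. The helper name masked_mean is hypothetical (it is not from the thread or from JAX), and the arrays are the ones from the question; only jnp.arange, jnp.where, jnp.sum and jax.vmap are used.

```python
import jax
import jax.numpy as jnp


def masked_mean(row, length):
    """Mean of the first `length` entries of a zero-padded 1-D row (hypothetical helper)."""
    # Boolean mask that is True only for the valid (non-padding) positions.
    valid = jnp.arange(row.shape[0]) < length
    # Padding is replaced by 0 so it cannot contribute to the sum.
    total = jnp.sum(jnp.where(valid, row, 0))
    return total / length


data_sets_a = jnp.array([[1, 2, 0, 0], [4, 5, 6, 7]])
data_lengths_a = jnp.array([2, 4])

# vmap maps over the leading axis of both arguments: one (row, length) pair per series.
means = jax.vmap(masked_mean)(data_sets_a, data_lengths_a)
print(means)
```

The mask is built from the static padded width and the per-series length, so every row keeps the same shape and the whole batch runs without a Python loop.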

Not an exact answer, but looking at the masking tests and that comment:

import functools

import jax
import jax.numpy as jnp

@functools.partial(jax.mask, in_shapes=["n", "m"], out_shape="_")
def mask_concat(x, y):
    return jnp.pad(jnp.concatenate([x, y]), (0, 1))

foo = jax.vmap(mask_concat)([data_sets_a, data_sets_b], dict(n=data_lengths_a, m=data_lengths_b))[:, :-1]

> [[1 2 1 0 0 0]
   [4 5 6 7 1 2]]

Maybe there is something more involved and prettier that can get the job done without that final slicing, but it should work.

The explanation is that mask_concat will receive two 1d arrays (because of vmap), and the number of valid entries in each will be defined by the values in n and m, respectively. The resulting 1d array is declared with an unspecified length via out_shape="_" (it is n + m + 1, but we can skip setting it explicitly).
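jax.mask was an experimental API and appears to have been removed from later JAX releases, so the snippet above may not run on a current install. The following is a sketch of the same concatenation using only vmap and explicit length arguments; it is a swapped-in technique, not the masking transformation itself. The helper name concat_padded is hypothetical, and the output width is assumed to be the sum of the two padded widths.

```python
import jax
import jax.numpy as jnp


def concat_padded(x, len_x, y, len_y):
    """Concatenate the valid prefixes of two zero-padded rows (hypothetical helper)."""
    # Static upper bound on the output length: both rows completely full.
    out_size = x.shape[0] + y.shape[0]
    idx = jnp.arange(out_size)
    # Candidate values from x and y; indices are clipped so every gather is in bounds.
    from_x = x[jnp.clip(idx, 0, x.shape[0] - 1)]
    from_y = y[jnp.clip(idx - len_x, 0, y.shape[0] - 1)]
    # Positions [0, len_x) come from x, positions [len_x, len_x + len_y) from y.
    out = jnp.where(idx < len_x, from_x, from_y)
    # Everything past the combined valid length is zero padding.
    return jnp.where(idx < len_x + len_y, out, 0)


data_sets_a = jnp.array([[1, 2, 0, 0], [4, 5, 6, 7]])
data_lengths_a = jnp.array([2, 4])
data_sets_b = jnp.array([[1, 0], [1, 2]])
data_lengths_b = jnp.array([1, 2])

result = jax.vmap(concat_padded)(
    data_sets_a, data_lengths_a, data_sets_b, data_lengths_b
)
print(result)
```

Because the output shape depends only on the padded widths, no final slicing is needed, and the valid length of each output row is simply data_lengths_a + data_lengths_b.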

@myagues Thank you so much! This is definitely what I need.
I tried to apply it to an example where the function to be masked and vmapped operates on a vector and a matrix, and I can't get it to work.

import jax.numpy as jnp
from jax import mask, vmap

def my_func(vec, mat):
    return jnp.dot(mat, vec)

mat = jnp.ones((2, 3, 3))
vec = jnp.array([[1., 2., 3.], [1., 2., 0.]])

inds = jnp.array([3, 2])
inds_mat = jnp.array([[2, 2], [3, 3]])

masked_func = mask(my_func, in_shapes=["n", "m"], out_shape="_")
vmapped_masked_func = vmap(masked_func, (0, 0), 0)
res = vmapped_masked_func([vec, mat], dict(n=inds, m=inds_mat))

The error is FilteredStackTrace: AssertionError: length mismatch: [1, 2].
I think the problem is in the definition of the masking indices of the matrix, inds_mat, but I have no idea how to solve it.

I am not sure if it is possible to index with a "multi index" as in inds_mat. I have not found anything of the sort, but there is an example in the tests with lax.dot:
https://github.com/google/jax/blob/4d78d60a24f2cfd2b16a3cf8a2992b693ff1de05/tests/masking_test.py#L432-L437

The error you are getting is a length mismatch between mat and the in_shapes value you are declaring: "m" should be "(m, _)". Essentially, you are giving a matrix, but jax.mask is expecting a vector.

Since we are using jnp.dot, that second dimension needs to be a fixed value resulting in: in_shapes=["n", "(m, n)"]. However, I guess m takes into account only the leading value of inds_mat? Also, the resulting matrix will give "padded" values for shorter sequences:

import jax

masked_func = jax.mask(my_func, in_shapes=["n", "(m, n)"], out_shape="_")
jax.vmap(masked_func)([vec, mat], dict(n=inds, m=inds_mat))
> DeviceArray([[6., 6., 6.],
               [3., 3., 3.]], dtype=float32)
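For comparison, here is a sketch of the same matrix-vector product with the masks written out by hand instead of relying on jax.mask (a swapped-in technique, useful on JAX versions where jax.mask is unavailable). The helper name masked_matvec and the per-example row counts m = [2, 3] are assumptions made for illustration; the thread's inds_mat has a different shape. Unlike the result above, the padded output rows are zeroed explicitly.

```python
import jax
import jax.numpy as jnp


def masked_matvec(vec, mat, n, m):
    """mat @ vec using only the first n entries of vec and the first m rows of mat."""
    # Valid columns of mat correspond to valid entries of vec.
    col_valid = jnp.arange(mat.shape[1]) < n
    row_valid = jnp.arange(mat.shape[0]) < m
    # Zero the padded entries of vec so they do not contribute to the dot product.
    out = jnp.dot(mat, jnp.where(col_valid, vec, 0.0))
    # Zero the rows of the result that lie beyond the valid row count.
    return jnp.where(row_valid, out, 0.0)


mat = jnp.ones((2, 3, 3))
vec = jnp.array([[1., 2., 3.], [1., 2., 0.]])
n = jnp.array([3, 2])  # valid vector lengths (inds in the thread)
m = jnp.array([2, 3])  # assumed valid row counts, one scalar per example

res = jax.vmap(masked_matvec)(vec, mat, n, m)
print(res)
```

Passing one scalar per example for each masked dimension sidesteps the question of how a 2-D index array such as inds_mat would be interpreted.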