I am working with multiple time series of different lengths that should all undergo the same operations.
Following the tweet from @marcvanzee, I attempted to use mask and vmap for these manipulations. Unfortunately, due to the lack of documentation, I am not sure I can fully decipher Marc's example.
The simplest operation I am attempting is concatenating two padded arrays. Is it possible to parallelize it using mask and vmap?
import itertools
import jax.numpy as np
# each row is a data_set of different length. Zero padding is used to put it all in one array.
data_sets_a = np.array([[1,2,0,0], [4,5,6,7]])
data_lengths_a = np.array([2, 4])
data_sets_b = np.array([[1,0], [1,2]])
data_lengths_b = np.array([1, 2])
num_data_sets = data_sets_a.shape[0]
# implementation using list comprehension
data_sets_concatenated_list = [
    np.concatenate([data_sets_a[i, :data_lengths_a[i]], data_sets_b[i, :data_lengths_b[i]]])
    for i in range(num_data_sets)
]
# pad the ragged rows with zeros again and stack them into one 2D array
data_sets_concatenated = np.array(list(itertools.zip_longest(*data_sets_concatenated_list, fillvalue=0))).T
Is it possible to define something like:
def concatenate(array_1, array_2):
    return np.concatenate([array_1, array_2])
input_mask_1 = data_lengths_a
input_mask_2 = data_lengths_b
output_mask = data_lengths_a + data_lengths_b
and use vmap and mask to get the same result as in the code above?
Hi noashin, I'm looking at similar stuff and I found this comment with an example of the masking feature https://github.com/google/jax/issues/2521#issuecomment-604759386
It's from March 2020, so I don't know if they've changed anything about the API since then. I'm trying to get it working myself too, so please say if you manage to get it working :)
So far no luck. For the more basic operations I need I am using padded arrays, but for the more complicated ones, I am currently looping.
Not an exact answer, but looking at the masking tests and that comment:
import functools
import jax
import jax.numpy as jnp
@functools.partial(jax.mask, in_shapes=["n", "m"], out_shape="_")
def mask_concat(x, y):
    return jnp.pad(jnp.concatenate([x, y]), (0, 1))
foo = jax.vmap(mask_concat)([data_sets_a, data_sets_b], dict(n=data_lengths_a, m=data_lengths_b))[:, :-1]
> [[1 2 1 0 0 0]
   [4 5 6 7 1 2]]
Maybe there is something more involved and prettier that can get the job done without that final slicing, but it should work.
The explanation is that, because of vmap, mask_concat receives two 1D arrays, and the number of entries to take into account in each is given by the values in n and m, respectively. The resulting 1D array has an undefined length (it is n + m + 1, but we can skip setting it explicitly by using "_").
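For comparison, here is a minimal sketch of the same padded concatenation that does not use jax.mask at all, only vmap, jnp.where and index arithmetic. It may help if jax.mask is not available in your JAX version. The helper name padded_concat is made up for this example, and it assumes the output should be zero-padded to the combined padded width of both inputs.

```python
import jax
import jax.numpy as jnp

def padded_concat(a, len_a, b, len_b):
    # Hypothetical helper: a and b are zero-padded 1D arrays whose first
    # len_a / len_b entries are valid.
    out_len = a.shape[0] + b.shape[0]
    idx = jnp.arange(out_len)
    # Clip the indices so every gather stays in bounds; out-of-range
    # positions are overwritten by jnp.where below anyway.
    from_a = a[jnp.clip(idx, 0, a.shape[0] - 1)]
    from_b = b[jnp.clip(idx - len_a, 0, b.shape[0] - 1)]
    # First len_a slots come from a, the next len_b slots from b, the rest is padding.
    return jnp.where(idx < len_a, from_a,
                     jnp.where(idx < len_a + len_b, from_b, 0))

data_sets_a = jnp.array([[1, 2, 0, 0], [4, 5, 6, 7]])
data_lengths_a = jnp.array([2, 4])
data_sets_b = jnp.array([[1, 0], [1, 2]])
data_lengths_b = jnp.array([1, 2])

result = jax.vmap(padded_concat)(data_sets_a, data_lengths_a, data_sets_b, data_lengths_b)
print(result)
# [[1 2 1 0 0 0]
#  [4 5 6 7 1 2]]
```

The new lengths are simply data_lengths_a + data_lengths_b, so they can be carried along next to the padded result for later steps.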
@myagues Thank you so much! This is definitely what I need.
I tried to apply it to an example where the function to be masked and vmapped operates on a vector and a matrix, and I can't get it to work.
import jax.numpy as jnp
from jax import mask, vmap
def my_func(vec, mat):
    return jnp.dot(mat, vec)
mat = jnp.ones((2, 3, 3))
vec = jnp.array([[1., 2., 3.], [1., 2., 0.]])
inds = jnp.array([3, 2])
inds_mat = jnp.array([[2, 2], [3, 3]])
masked_func = mask(my_func, in_shapes=["n", "m"], out_shape="_")
vmapped_masked_func = vmap(masked_func, (0, 0), 0)
res = vmapped_masked_func([vec, mat], dict(n=inds, m=inds_mat))
The error is FilteredStackTrace: AssertionError: length mismatch: [1, 2].
I think the problem is in the definition of the masking indices of the matrix, inds_mat, but I don't know how to solve it.
I am not sure if it is possible to index with a "multi index" as in inds_mat. I have not found anything of the sort, but there is an example in the tests with lax.dot:
https://github.com/google/jax/blob/4d78d60a24f2cfd2b16a3cf8a2992b693ff1de05/tests/masking_test.py#L432-L437
The error you are getting is a length mismatch between mat and the in_shapes value you are declaring: "m" should be "(m, _)". Essentially, you are giving a matrix, but jax.mask is expecting a vector.
Since we are using jnp.dot, that second dimension needs to be a fixed value, resulting in in_shapes=["n", "(m, n)"]. However, I guess m only takes the leading value of inds_mat into account? Also, the resulting matrix will still contain "padded" values for the shorter sequences:
masked_func = jax.mask(my_func, in_shapes=["n", "(m, n)"], out_shape="_")
jax.vmap(masked_func)([vec, mat], dict(n=inds, m=inds_mat))
> DeviceArray([[6., 6., 6.],
               [3., 3., 3.]], dtype=float32)
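If the goal is to also zero out the rows that belong to padding, one option that avoids jax.mask is to build the masks by hand and vmap over the lengths. This is only a sketch under assumptions: masked_matvec is a hypothetical helper, n is the number of valid vector entries (and matrix columns), and m = [2, 3] is an assumed number of valid matrix rows per batch element, not something taken from inds_mat above.

```python
import jax
import jax.numpy as jnp

def masked_matvec(vec, mat, n, m):
    # Hypothetical helper: only the first n entries of vec (and columns of mat)
    # and the first m rows of mat are treated as valid.
    col_valid = jnp.arange(vec.shape[0]) < n
    row_valid = jnp.arange(mat.shape[0]) < m
    vec = jnp.where(col_valid, vec, 0.0)
    mat = jnp.where(row_valid[:, None] & col_valid[None, :], mat, 0.0)
    # Padded rows of the result come out as exactly zero.
    return jnp.dot(mat, vec)

mat = jnp.ones((2, 3, 3))
vec = jnp.array([[1., 2., 3.], [1., 2., 0.]])
n = jnp.array([3, 2])  # valid vector lengths, same as inds above
m = jnp.array([2, 3])  # assumed valid row counts

res = jax.vmap(masked_matvec)(vec, mat, n, m)
print(res)
# [[6. 6. 0.]
#  [3. 3. 3.]]
```

The output keeps the padded shape, so the valid row counts in m still have to be tracked separately.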