JAX: how to vectorize custom functions with vmap

Created on 24 Mar 2020 · 7 comments · Source: google/jax

Hello everyone,
I am trying to vectorize a function that takes two parameters: (model_params, x). Both parameters are lists of NumPy arrays.

I would like to vectorize this function over the second parameter. My problem is that the shapes of the arrays in x are not fixed. I tried padding them with zeros to a common shape and discarding the extra values inside the function, but it didn't work.
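One way to make the padding idea work, sketched under assumptions: `model` and `pad_batch` below are hypothetical, and each example is a 1-D array. Instead of slicing the padding off inside the function (a slice whose length differs per example would need a data-dependent shape, which JAX cannot trace under vmap), pass a boolean mask and zero out the padded entries with `jnp.where`.

```python
import jax
import jax.numpy as jnp
import numpy as np

def pad_batch(arrays, pad_value=0.0):
    """Stack variable-length 1-D arrays into (batch, max_len) plus a boolean mask."""
    max_len = max(a.shape[0] for a in arrays)
    padded = np.full((len(arrays), max_len), pad_value, dtype=np.float32)
    mask = np.zeros((len(arrays), max_len), dtype=bool)
    for i, a in enumerate(arrays):
        padded[i, : a.shape[0]] = a
        mask[i, : a.shape[0]] = True
    return jnp.asarray(padded), jnp.asarray(mask)

def model(params, x, mask):
    # Hypothetical per-example function: weighted sum of squares over valid entries.
    # Padded positions are zeroed with jnp.where instead of being sliced away,
    # so every example keeps the same static shape.
    contrib = jnp.where(mask, params * x ** 2, 0.0)
    return jnp.sum(contrib)

# params is shared (None); x and mask are mapped over their leading axis (0).
batched_model = jax.vmap(model, in_axes=(None, 0, 0))

xs = [np.array([1.0, 2.0]), np.array([3.0]), np.array([1.0, 1.0, 1.0])]
x_pad, mask = pad_batch(xs)
out = batched_model(2.0, x_pad, mask)
```

The padded entries still cost arithmetic; the mask only keeps them from affecting the result.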

I would appreciate any suggestions.

Thank you!

Hey @cagrikymk!

Could you provide a small code example showing what you want to do?

I have a batch loss function with the following signature:
jax_batch_loss(parameters, systems, true_vals)
_parameters_ is a list of NumPy arrays and scalar values; its structure is fixed, and these are the values I am trying to optimize.

_systems_ is a list of lists, and the inner lists consist of NumPy arrays with varying shapes.

_true_vals_ is a list consisting of scalar values.

In the body of that batch loss function, I iterate over the systems and true values, compute each system's loss with my model parameters, and aggregate the results.
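A minimal sketch of the sequential batch loss described above, assuming a hypothetical per-system loss `system_loss` (the real model is not shown in the thread) and a two-element `parameters` list. Because each system is processed in its own call, its arrays may have any shape, and `jax.grad` still differentiates through the Python loop.

```python
import jax
import jax.numpy as jnp

def system_loss(parameters, system, true_val):
    # Hypothetical per-system loss: prediction = weight * (sum of all entries
    # in every array of the system) + offset, compared to the true value.
    weight, offset = parameters
    pred = weight * sum(jnp.sum(arr) for arr in system) + offset
    return (pred - true_val) ** 2

def jax_batch_loss(parameters, systems, true_vals):
    # Sequential version: one call per system, so shapes may differ freely.
    total = 0.0
    for system, true_val in zip(systems, true_vals):
        total = total + system_loss(parameters, system, true_val)
    return total / len(systems)

parameters = [jnp.array(1.0), 0.5]
systems = [[jnp.ones(3), jnp.ones((2, 2))], [jnp.ones(5)]]
true_vals = [7.0, 5.0]
loss = jax_batch_loss(parameters, systems, true_vals)
grads = jax.grad(jax_batch_loss)(parameters, systems, true_vals)
```

This works, but the loop runs one small computation per system, which is what the question hopes to batch.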

I would like to parallelize the batch loss function. I assume I could use vmap for this task, but I couldn't figure out how.

Hey @cagrikymk, you should probably use something like this:
jax_batch_loss = vmap(loss_fn, in_axes=(None, 0, 0))
assuming the batch dimension is the leading (0th) axis of every array in systems and true_vals.
If this doesn't work, could you post a more detailed example of your loss here?
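A sketch of this suggestion, again with a hypothetical `loss_fn`. `in_axes=(None, 0, 0)` passes `parameters` through unchanged and maps over the leading axis of every leaf of `systems` and `true_vals`. That only works if `systems` is a single list of stacked arrays rather than a Python list of per-system lists; `jax.tree_util.tree_map` with `jnp.stack` does that conversion when all systems have identical shapes.

```python
import jax
import jax.numpy as jnp

def loss_fn(parameters, system, true_val):
    # Hypothetical per-system loss; `system` is a list of arrays for ONE system.
    weight, offset = parameters
    pred = weight * sum(jnp.sum(arr) for arr in system) + offset
    return (pred - true_val) ** 2

# parameters: not mapped (None); systems, true_vals: mapped over axis 0.
batched_loss = jax.vmap(loss_fn, in_axes=(None, 0, 0))

def jax_batch_loss(parameters, systems, true_vals):
    return jnp.mean(batched_loss(parameters, systems, true_vals))

# Four systems with identical shapes, as a Python list of per-system lists...
system_list = [[jnp.ones(3), jnp.ones((2, 2))] for _ in range(4)]
# ...stacked leaf-by-leaf into [array(4, 3), array(4, 2, 2)].
systems = jax.tree_util.tree_map(lambda *xs: jnp.stack(xs), *system_list)
true_vals = jnp.array([7.0, 7.5, 8.0, 7.5])
parameters = [jnp.array(1.0), 0.5]

per_system = batched_loss(parameters, systems, true_vals)
mean_loss = jax_batch_loss(parameters, systems, true_vals)
```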

Thank you for the help @joaogui1!
That gave me a shape-related error. My problem is that the shapes within a batch vary.
I assume I could pad them with zeros, but that would increase the amount of computation.
Are there any restrictions on using vmap?
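vmap requires every mapped array to have the same size along the mapped axis, and the function is traced once with a single static shape per argument, so a ragged batch cannot be vmapped directly. A sketch of one alternative to padding everything to the global maximum, assuming many systems share shapes: group systems by shape and vmap each group separately. `loss_fn` and `bucketed_batch_loss` are hypothetical names.

```python
from collections import defaultdict

import jax
import jax.numpy as jnp

def loss_fn(parameters, system, true_val):
    # Hypothetical per-system loss (same toy model as before).
    weight, offset = parameters
    pred = weight * sum(jnp.sum(arr) for arr in system) + offset
    return (pred - true_val) ** 2

batched_loss = jax.vmap(loss_fn, in_axes=(None, 0, 0))

def bucketed_batch_loss(parameters, systems, true_vals):
    # Group systems whose arrays have identical shapes; no padding is needed
    # inside a group because every member already has the same shapes.
    buckets = defaultdict(list)
    for system, true_val in zip(systems, true_vals):
        key = tuple(arr.shape for arr in system)
        buckets[key].append((system, true_val))
    total = 0.0
    for members in buckets.values():
        group = [system for system, _ in members]
        # Stack each leaf across the group: leading axis = group size.
        stacked = jax.tree_util.tree_map(lambda *xs: jnp.stack(xs), *group)
        group_true = jnp.array([t for _, t in members])
        total = total + jnp.sum(batched_loss(parameters, stacked, group_true))
    return total / len(systems)

parameters = [jnp.array(1.0), 0.5]
systems = [[jnp.ones(3)], [jnp.ones(3)], [jnp.ones(5)]]
true_vals = [3.0, 3.5, 5.0]
loss = bucketed_batch_loss(parameters, systems, true_vals)
```

Each distinct shape triggers its own trace (and its own compilation if jitted), so this pays off when the number of distinct shapes is small; padding to a few fixed sizes is a common compromise otherwise.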

Ideally, I would like the framework to distribute the elements of the batch across different cores, for example in a round-robin fashion. Can I use pmap for this?
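pmap can spread work across devices, but it has the same same-shape restriction as vmap, and the size of the axis it maps over must not exceed the number of available devices. A sketch assuming the batch size is divisible by `jax.local_device_count()`: reshape the batch to `(n_devices, per_device, ...)`, pmap over devices, and vmap within each device. `loss_fn` is hypothetical; on a single-CPU machine this runs with one device.

```python
import jax
import jax.numpy as jnp

def loss_fn(parameters, system, true_val):
    # Hypothetical per-system loss (same toy model as before).
    weight, offset = parameters
    pred = weight * sum(jnp.sum(arr) for arr in system) + offset
    return (pred - true_val) ** 2

n_dev = jax.local_device_count()

# Outer pmap splits across devices; inner vmap batches within each device.
# parameters are broadcast (None) at both levels.
sharded_loss = jax.pmap(
    jax.vmap(loss_fn, in_axes=(None, 0, 0)), in_axes=(None, 0, 0)
)

def shard(x):
    # (n_dev * per_device, ...) -> (n_dev, per_device, ...): contiguous chunks.
    return x.reshape((n_dev, -1) + x.shape[1:])

batch_size = 4 * n_dev
systems = [jnp.ones((batch_size, 3)), jnp.ones((batch_size, 2, 2))]
true_vals = jnp.full((batch_size,), 7.0)
parameters = [jnp.array(1.0), jnp.array(0.5)]

losses = sharded_loss(
    parameters, jax.tree_util.tree_map(shard, systems), shard(true_vals)
)
mean_loss = jnp.mean(losses)
```

The reshape assigns contiguous chunks to devices rather than round-robin; interleaving the batch before reshaping would give a round-robin-like assignment if that matters.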

I am doing molecular dynamics (similar to https://github.com/google/jax-md) and am probably hitting a bug somewhere: the code runs smoothly with the CPU version of the library, but I get a segfault with the GPU version (both jax and jaxlib are up to date). As soon as I isolate the section causing the segfault, I will share a minimal snippet that reproduces the error.

@cagrikymk Please confirm you have jaxlib 0.1.45 or newer.

@hawkinsp I installed everything from scratch and the problem disappeared. I think that error was on my end.

I think this can be closed.
