Jax: Does vmap or lax.scan do caching?

Created on 21 Jul 2020 · 6 Comments · Source: google/jax

Hi everyone,

I'm asking because I've noticed a very weird behaviour in jax-unirep whose cause I can't seem to find, even after a good amount of effort trying to narrow it down.

The weird behaviour is minimally reproduced in a GitHub gist: the output of a vmapped JAX function (vmap over the "sample" axis) that internally uses lax.scan (the RNN component) oscillates between two values, even when exactly the same inputs are provided. The behaviour was first reported by @hhefzi on our issue tracker.
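For readers who haven't seen the gist, here is a minimal sketch of the pattern being described, under stated assumptions: `rnn_final_state` is a hypothetical toy RNN, not jax-unirep's mLSTM. It runs `lax.scan` over the time axis of one sample and is mapped over the sample axis with `jax.vmap`. With pure functions, the vmapped result should match a plain Python loop over samples, and repeated calls should agree with each other.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax import lax


def rnn_final_state(w, seq):
    """Toy RNN (hypothetical): scan over the time axis of seq (time x features)."""
    def step(h, x):
        h = jnp.tanh(w * h + x)
        return h, h  # (new carry, per-step output)

    h0 = jnp.zeros(seq.shape[-1], dtype=seq.dtype)
    h_final, _ = lax.scan(step, h0, seq)
    return h_final


# vmap over the sample axis of the sequences only; the weight is shared
batched = jax.vmap(rnn_final_state, in_axes=(None, 0))

w = jnp.float32(0.5)
seqs = jnp.arange(24.0).reshape(2, 4, 3) / 24.0  # (samples, time, features)

out_vmap = batched(w, seqs)
out_loop = jnp.stack([rnn_final_state(w, s) for s in seqs])

# For pure code, vmap should match the per-sample loop, and calling the
# vmapped function again on the same inputs should reproduce the result.
print(np.allclose(out_vmap, out_loop))
print(np.array_equal(np.asarray(batched(w, seqs)), np.asarray(out_vmap)))
```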

I've inserted logging statements to check whether the same Python objects passed into the top-level function are propagated down into the functions that do the RNN math. They are, which hints to me that the "state" of the program is correct, i.e. the program consistently passes the correct inputs down from the top-level functions being called.

That leaves the JAX parts of jax-unirep, which I'm not quite sure how to debug. Hence the question: is there any caching involved in vmap or lax.scan? If so, that _might_ explain the oscillatory behaviour.

If not, I'm not sure what else could be causing the issue. I'm running dry on ideas, having already poured a few of my best hours of brainpower into getting to the root of this. Might you all have some alternative ideas?

Most helpful comment

Hey Eric!

vmap doesn't do any caching. lax.scan has caching so that scanning the same function twice over inputs with the same shapes doesn't recompile; it's like jit in that way.

The caching is sound so long as the functions involved are pure, i.e. they don't have side effects. Could there be side effects in the unirep code?
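To make the caching point concrete, here is a minimal sketch (the names `step`, `running_sum`, and `trace_count` are hypothetical, not from jax-unirep). A Python side effect inside a scanned function runs only while JAX traces it, so a second call with the same input shapes reuses the cached trace and skips the side effect entirely. The scan is wrapped in `jax.jit` here so the caching is easy to observe.

```python
import jax
import jax.numpy as jnp
from jax import lax

trace_count = [0]  # mutable Python state, used here only to observe tracing


def step(carry, x):
    # Python side effect: runs while JAX traces `step`, not every time
    # the compiled computation executes.
    trace_count[0] += 1
    carry = carry + x
    return carry, carry


@jax.jit  # wrapped in jit so the trace cache is easy to observe
def running_sum(xs):
    return lax.scan(step, jnp.float32(0.0), xs)


xs = jnp.arange(4.0)  # float32 values 0, 1, 2, 3
final, ys = running_sum(xs)
after_first = trace_count[0]

final, ys = running_sum(xs)  # same shape and dtype: cached, `step` is not re-traced
after_second = trace_count[0]

print(final, ys)  # final is 6.0; ys is the running sum 0, 1, 3, 6
print(after_first == after_second)  # the side effect did not run again
```

The same mechanism is why impure code can behave differently from what its Python source suggests: whatever the side effect computed at trace time is what the cached computation keeps using.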

You can disable lax.scan's caching by commenting out the @cache() decorators here and here. You can further disable all jit caching by commenting out the decorators here and here and here. That might help debug whether caching is an issue.
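As an alternative to editing JAX's source, the public `jax.disable_jit()` context manager runs jitted functions eagerly, op by op. The sketch below (the helper `jit_vs_eager` and the functions `pure`, `impure`, and `calls` are hypothetical) compares a function's compiled and eager results: for pure code they agree, while a trace-time side effect makes them differ. Whether `disable_jit` also bypasses lax.scan's own internal caching may depend on the JAX version, so treat this as a check on jit-level caching only.

```python
import jax
import jax.numpy as jnp
import numpy as np


def jit_vs_eager(fn, *args):
    """Run a jitted fn normally, then again with jit disabled (hypothetical helper)."""
    compiled = np.asarray(fn(*args))
    with jax.disable_jit():  # jitted functions execute eagerly in Python
        eager = np.asarray(fn(*args))
    return compiled, eager


@jax.jit
def pure(x):
    return jnp.sin(x) * 2.0


calls = [0]  # hypothetical mutable counter, i.e. a side effect


@jax.jit
def impure(x):
    calls[0] += 1  # under jit, this only runs at trace time
    return x + calls[0]


x = jnp.float32(1.0)
print(jit_vs_eager(pure, x))    # values agree (up to float rounding)
print(jit_vs_eager(impure, x))  # results differ: the compiled version froze the counter
```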

Thanks @mattjj! I have to prepare for a presentation tonight, so I'll be sure to get to this tomorrow morning!

If I may ask: I've already checked that no globals are used in the unirep functions. Additionally, all data loading is done outside of the functions, and the data are passed in as arguments. Closures are the tricky part for me, but I think we have written them "correctly", without side effects. Are there other patterns I might have missed that could also induce side effects?

If you're using numpy.random, then sampling random values has side-effects on the RNG state.
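A minimal sketch of what that side effect looks like under jit (the function names are hypothetical): a stateful NumPy generator is called only at trace time, so the "random" draw is frozen into the compiled function. JAX's own `jax.random` API avoids hidden state by taking an explicit key as an input.

```python
import jax
import jax.numpy as jnp
import numpy as np

rng = np.random.default_rng(0)  # stateful NumPy generator


@jax.jit
def add_numpy_noise(x):
    # rng.normal() runs at trace time; the sample becomes a constant baked
    # into the compiled function, and rng's state advances only then.
    return x + rng.normal()


x = jnp.float32(1.0)
a = add_numpy_noise(x)
b = add_numpy_noise(x)  # cached trace: same "random" value as before


@jax.jit
def add_jax_noise(key, x):
    # The key is an explicit input, so the randomness is a pure function of it
    return x + jax.random.normal(key)


key = jax.random.PRNGKey(0)
k1, k2 = jax.random.split(key)
c = add_jax_noise(k1, x)
d = add_jax_noise(k1, x)  # same key, same result
e = add_jax_noise(k2, x)  # a different key gives an independent sample
```

Called without jit, the NumPy version would instead draw a new value on every call, because each draw advances `rng`'s hidden state.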

I think Python closures themselves are immutable, but they can point to mutable objects like lists.
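A minimal sketch of that closure pitfall (`make_scaled`, `scaled`, and `weights` are hypothetical names): the closure can't rebind `weights`, but mutating the list it points to goes unnoticed by jit, which keeps using the value read at trace time. Passing the value as an explicit argument makes it part of the traced inputs.

```python
import jax
import jax.numpy as jnp


def make_scaled(weights):
    # The closure cannot rebind `weights`, but the list it points to can
    # still be mutated from outside.
    @jax.jit
    def scaled(x):
        return x * weights[0]  # read once, when `scaled` is traced
    return scaled


weights = [2.0]  # hypothetical mutable container captured by the closure
scaled = make_scaled(weights)

before = scaled(jnp.float32(1.0))  # 2.0
weights[0] = 3.0                   # mutate the captured list
after = scaled(jnp.float32(1.0))   # still 2.0: jit reused the old trace


@jax.jit
def scaled_explicit(x, w):
    # Passing w as an argument makes it part of the traced inputs
    return x * w


fixed = scaled_explicit(jnp.float32(1.0), 3.0)  # 3.0
```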

Can you help me repro the issue in a Python file rather than a notebook or gist? In particular, in the code on your issue tracker I'm not sure whether it's easy to get the file used by load_params(folderpath="20200707/iter_1/"). (I'm happy to install jax-unirep to make the repro work!)

I think this should work without any additional files (with jax-unirep installed). It shows the same oscillatory behavior.

    from jax_unirep import fit, get_reps
    import numpy as np

    sequences = ["MKLVIPJ", "MMLVIKJP", "MKLVIJJ"]

    # Train custom parameters on the three toy sequences
    params = fit(params=None, sequences=sequences, n_epochs=10)

    mut_seq = sequences[0]

    # Embed the same sequence repeatedly with default and custom parameters;
    # with pure functions, each printed sum should be identical across iterations
    for i in range(6):
        print('------Iteration {}------'.format(i))
        print('Default parameters-sum of embeddings: {}'.format(np.sum(get_reps(mut_seq)[0])))
        print('Custom parameters-sum of embeddings: {}'.format(np.sum(get_reps(mut_seq, params=params[0])[0])))

Output:

    ------Iteration 0------
    Default parameters-sum of embeddings: 220.5883331298828
    Custom parameters-sum of embeddings: 220.5883331298828
    ------Iteration 1------
    Default parameters-sum of embeddings: 220.5883331298828
    Custom parameters-sum of embeddings: 218.6575164794922
    ------Iteration 2------
    Default parameters-sum of embeddings: 220.5883331298828
    Custom parameters-sum of embeddings: 220.5883331298828
    ------Iteration 3------
    Default parameters-sum of embeddings: 220.5883331298828
    Custom parameters-sum of embeddings: 218.6575164794922
    ------Iteration 4------
    Default parameters-sum of embeddings: 220.5883331298828
    Custom parameters-sum of embeddings: 220.5883331298828
    ------Iteration 5------
    Default parameters-sum of embeddings: 220.5883331298828
    Custom parameters-sum of embeddings: 218.65750122070312

Looping back here, I installed jax in development mode in my jax-unirep development environment. Commenting out the two cache decorators didn't solve the oscillation problem. Digging further...

Ohhh my, I've nearly broken my eyeballs closely inspecting the inputs and outputs surrounding https://github.com/ElArkk/jax-unirep/blob/master/jax_unirep/layers.py#L163, and none of them oscillate, which suggests to me that we really do have pure functions that don't depend on global state.

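    # Bind params as the first argument of the step function
    step_func = partial(mLSTM1900_step, params)
    # Scan over the leading axis of batch, carrying (h, c) between steps
    (h_final, c_final), outputs = lax.scan(
        step_func, init=(h_t, c_t), xs=batch
    )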

In summary, I've inspected h_t, c_t, batch, and params, and they are constant from run to run. The only things that oscillate are h_final, c_final, and outputs. @mattjj, I'm a bit at a loss here. Digging into lax.scan is a bit "scary" for me at the moment. Do you have any suggestions?
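One way to look inside lax.scan without reading its source is to compare it against a plain-Python loop with the same semantics. This is a sketch under stated assumptions: `toy_step` is a hypothetical stand-in for `mLSTM1900_step` with the same `(params, carry, x)` call shape, and `loop_scan` is a hypothetical reference that covers only the basic `lax.scan(f, init, xs)` form (no `length`, `reverse`, or `unroll`).

```python
from functools import partial

import jax.numpy as jnp
import numpy as np
from jax import lax


def toy_step(params, carry, x):
    # Hypothetical stand-in for mLSTM1900_step:
    # (params, carry, x) -> (new_carry, output)
    h, c = carry
    c = c + jnp.tanh(params["w"] * x + params["u"] * h)
    h = jnp.tanh(c)
    return (h, c), h


def loop_scan(f, init, xs):
    # Plain-Python reference for lax.scan(f, init, xs): iterate over the
    # leading axis of xs, thread the carry, and stack per-step outputs.
    carry, ys = init, []
    for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
    return carry, jnp.stack(ys)


params = {"w": jnp.full(3, 0.5), "u": jnp.full(3, 0.25)}
h_t, c_t = jnp.zeros(3), jnp.zeros(3)
batch = jnp.linspace(-1.0, 1.0, 12).reshape(4, 3)  # (steps, features)

step_func = partial(toy_step, params)
(h_scan, c_scan), out_scan = lax.scan(step_func, init=(h_t, c_t), xs=batch)
(h_loop, c_loop), out_loop = loop_scan(step_func, (h_t, c_t), batch)

# With a pure step function these agree up to float rounding
print(np.allclose(h_scan, h_loop), np.allclose(c_scan, c_loop),
      np.allclose(out_scan, out_loop))
```

If the scan and loop versions agree but repeated runs still differ, `jax.debug.print` (available in recent JAX releases) can print traced values from inside the step function at run time.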
