jax.random.randint silently returns zero for large values of maxval

Created on 15 Jan 2020 · 10 Comments · Source: google/jax

import jax

key_1 = jax.random.PRNGKey(1)
key_2 = jax.random.PRNGKey(2)
print("output for key_1:", jax.random.randint(key_1, (), 0, 2**31))
print("output for key_2:", jax.random.randint(key_2, (), 0, 2**31))

yields

output for key_1: 0
output for key_2: 0

And indeed, the function returns 0 for any key you pass it, so long as maxval >= 2**31.

I guess this is an issue with the dtype, since maxval=2**31-1 works as expected.

Labels: better_errors, good first issue

All 10 comments

You have an integer overflow.

There are two things going on. The first is that JAX disables 64-bit numbers by default, so unless you specify the right configuration options, you only get 32-bit numbers to work with. See:
https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#Double-(64bit)-precision
This is something we already plan to fix, but it is the current state.

The second is that the default type of randint is a signed 32-bit integer. 2 ** 31 is out of range for an int32. So you can either use a different limit, or if 64-bit mode is enabled, you can pass dtype=jnp.int64 to get a 64-bit integer.
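A minimal sketch of both points. It assumes a recent JAX release and that 64-bit mode has not already been turned on elsewhere (for example via the JAX_ENABLE_X64 environment variable):

```python
import jax
import jax.numpy as jnp
import numpy as np

key = jax.random.PRNGKey(1)

# Default configuration: 64-bit types are disabled, so randint yields int32.
# 2**31 - 1 is the largest maxval that still fits in an int32.
x32 = jax.random.randint(key, (), 0, 2**31 - 1)
print(x32.dtype)  # int32

# The largest int32 is 2**31 - 1, so 2**31 is one past the end of the range.
print(np.iinfo(np.int32).max)  # 2147483647

# Opt in to 64-bit mode. The flag is process-wide; normally you set it once at startup.
jax.config.update("jax_enable_x64", True)
x64 = jax.random.randint(key, (), 0, 2**31, dtype=jnp.int64)
print(x64.dtype)  # int64
```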

Does that answer the question?

I'm judging by the word "silently" that part of the question is whether it's feasible to raise an error here, at least in some cases (like when upper is a concrete value, as here, so the error is compile/trace-time detectable).

WDYT of that part of the issue @hawkinsp? IIUC we could add an error check for this specific case, though value-dependent overflows would be silent I'm guessing.

In this specific case (converting an out-of-range Python scalar to a JAX int32 array) we can and probably should warn or error.

@mattjj I'm interested in fixing this! Could you give me some pointers on where I should add the error?

Thanks for the responses. You're right - I'm mostly asking if it's possible to raise some sort of error or warning in this case. Apologies for not making that clear from the get-go :)

@TuanNguyen27 woo awesome!

I think it would have to be in this jax.random.randint function. The tricky bit in cases like this can be to detect whether minval/maxval are known at trace/compile time, in which case we have the option to raise this error.

Another layer of complexity is that minval / maxval can be arrays, not just (Python) scalars. Maybe we should start by doing this check only when the arguments are Python scalars, since that's the case where we expect the conversion might overflow (Python ints are arbitrary-precision bignums).

Luckily, checking that minval/maxval is a Python integer also tells us that it's known at trace/compile time. That is, if we do something like isinstance(maxval, int) (types.IntType only exists in Python 2), we'll only get True if the value is a Python scalar and therefore, by necessity, known at trace/compile time. Once that check passes, we can also compare its value against the range of the dtype we're about to convert it to and raise this error.
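A minimal sketch of that check. The helper name check_python_int_bound is hypothetical (it is not part of JAX), and it uses NumPy's iinfo for the dtype bounds. Only plain Python ints are checked; arrays, NumPy scalars, and tracers pass through untouched, since their values may not be concrete and they already carry a dtype:

```python
import numpy as np


def check_python_int_bound(name, value, dtype):
    """Hypothetical helper: raise if a Python int cannot be represented in `dtype`.

    Anything that is not a plain Python int (NumPy/JAX arrays, NumPy scalars,
    tracers) is skipped: its value may not be known at trace/compile time, and
    it already has its own dtype, so no Python-int conversion happens here.
    """
    if not isinstance(value, int):
        return
    info = np.iinfo(dtype)
    if not info.min <= value <= info.max:
        raise OverflowError(
            f"{name}={value} is out of range for {np.dtype(dtype).name} "
            f"[{info.min}, {info.max}]"
        )


# A randint-like function could run this on its bounds before converting them.
check_python_int_bound("maxval", 2**31 - 1, np.int32)  # fits: no error
try:
    check_python_int_bound("maxval", 2**31, np.int32)
except OverflowError as err:
    print(err)
```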

WDYT? (I could be mistaken about any part of this, but it's my current understanding!)

@TuanNguyen27 actually, @hawkinsp suggested to me out-of-band that we should catch this error more globally, whenever we are creating an int32 value from a Python scalar. That's much smarter! Let me find a better place to put this logic...

Agreed. I was going to comment that there are a few other places in random.py where minval/maxval are used, but addressing this issue globally solves that!

Doing this globally at PyInt -> int32 conversion sounds fantastic!

One additional interesting thing about this case that I'm finding out by continuing to debug the code where this came up originally:

randint samples from the half-open interval [minval, maxval), so the requirement that maxval < 2**31 means that randint presumably can never return 2**31-1, even though that value is a valid int32 and should be obtainable. I'm not sure this would realistically ever matter, so I wouldn't recommend spending much effort fixing it, but it's a consequence of this limitation that I just noticed.
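To make the consequence concrete: in a half-open draw the largest possible sample is maxval - 1, so capping an int32 maxval at 2**31 - 1 caps samples at 2**31 - 2. For contrast only, NumPy's Generator.integers (a NumPy API, not a JAX one) has an endpoint parameter that makes the upper bound inclusive:

```python
import numpy as np

INT32_MAX = int(np.iinfo(np.int32).max)  # 2**31 - 1
rng = np.random.default_rng(0)

# Half-open [low, high): with high capped at INT32_MAX, the only value in
# [INT32_MAX - 1, INT32_MAX) is INT32_MAX - 1; INT32_MAX itself is unreachable.
half_open = rng.integers(INT32_MAX - 1, INT32_MAX, size=1000, dtype=np.int32)
print(half_open.max() == INT32_MAX - 1)  # True

# Closed [low, high] via endpoint=True: INT32_MAX becomes reachable.
closed = rng.integers(INT32_MAX, INT32_MAX, size=1000, dtype=np.int32, endpoint=True)
print(closed.max() == INT32_MAX)  # True
```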

There are two ways Python scalars get consumed by XLA computations: as arguments (to either jitted functions or op-by-op primitives) and as closed-over constants to jitted functions.
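For the argument path, one quick way to look (assuming a recent JAX with jax.make_jaxpr and 64-bit mode off) is to trace a function with a Python scalar argument. The scalar becomes an input variable of the jaxpr, and its abstract value already records the canonicalized dtype:

```python
import jax


def add_one(x):
    return x + 1


# 3 is passed as an argument, so it becomes an input variable of the jaxpr.
closed = jax.make_jaxpr(add_one)(3)
print(closed)

# The input's abstract value carries the canonicalized dtype (int32 when 64-bit mode is off).
print(closed.jaxpr.invars[0].aval.dtype)
```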

Arguments are handled by xla.device_put (actually a totally different function from jax.device_put), which first calls xla.canonicalize_dtype and then a device_put handler. There's a device_put handler for Python scalars (which ultimately calls dtypes.coerce_to_array), but I think it's actually the canonicalize_dtype step that we want to change, since that's the step that down-casts to the new dtype. (We actually want to delete xla.canonicalize_dtype and all its handlers, since IIUC the new jax.dtypes module will provide a better solution, but for now we still have xla.canonicalize_dtype to deal with.) The canonicalize_dtype handler for scalars ultimately calls onp.asarray after calling dtypes.canonicalize_dtype (could it have just called dtypes.coerce_to_array instead? I think maybe yes...), so that's where we do the PyInt -> ndarray[int32] conversion and where we'd want this check for arguments.
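A rough sketch of what a range-checked scalar canonicalization could look like. The name canonicalize_python_scalar and its enable_x64 parameter are hypothetical stand-ins (this is not the real xla.canonicalize_dtype code), and it assumes the 32-bit defaults int32/float32/complex64 when 64-bit mode is off:

```python
import numpy as np


def canonicalize_python_scalar(x, enable_x64=False):
    """Hypothetical sketch: convert a Python scalar to a NumPy array of its
    canonical dtype, raising instead of silently wrapping out-of-range ints."""
    if isinstance(x, bool):  # check bool first: bool is a subclass of int
        dtype = np.dtype(np.bool_)
    elif isinstance(x, int):
        dtype = np.dtype(np.int64 if enable_x64 else np.int32)
        info = np.iinfo(dtype)
        if not info.min <= x <= info.max:
            raise OverflowError(f"Python int {x} is out of range for {dtype.name}")
    elif isinstance(x, float):
        dtype = np.dtype(np.float64 if enable_x64 else np.float32)
    elif isinstance(x, complex):
        dtype = np.dtype(np.complex128 if enable_x64 else np.complex64)
    else:
        raise TypeError(f"not a Python scalar: {type(x).__name__}")
    # For ints, the range check above has already run, so asarray only ever
    # sees values that fit in `dtype`.
    return np.asarray(x, dtype=dtype)
```

The design point is that the check runs before the conversion: depending on the NumPy version, np.asarray on an out-of-range Python int may wrap, warn, or raise, so relying on it alone is not enough.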

Closed-over constant scalars will be literals in jaxprs, which are staged into XLA here. You'll notice that also calls xla.canonicalize_dtype, so as long as we make that work we should be good to go. In the case that scalar constants aren't inlined as literals, they'd be handled by this other line, which I was surprised didn't itself call canonicalize_dtype until I noticed that our xla_bridge._JaxComputationBuilder.Constant also has its own set of handlers, for which _python_scalar_handler would be used. This code could probably be de-duplicated.
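To see the literal inlining, a small check assuming a recent JAX where the ClosedJaxpr returned by jax.make_jaxpr exposes .consts: a Python scalar the function closes over shows up neither as a jaxpr input nor as a separate constant, only inline in the equation that uses it:

```python
import jax
import jax.numpy as jnp

OFFSET = 3  # a Python scalar that the traced function closes over


def add_offset(x):
    return x + OFFSET


closed = jax.make_jaxpr(add_offset)(jnp.int32(1))
print(closed)

# One input (x) and no hoisted constants: OFFSET was inlined as a literal.
print(len(closed.jaxpr.invars), len(closed.consts))
```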

Here's the upshot of what we'd need to do:

I think among all those only the first is actually "used", in that it'll form a dtype-canonicalized ndarray out of a Python scalar before xla._device_put_scalar is ever called, and in that, given our convention for inlining Python scalars as literals in jaxprs, I don't think we call xla_bridge._JaxComputationBuilder.Constant on any Python scalars directly. But we should probably update all of these to be sure. I think they're the only ways Python scalars could be handed off to an XLA computation.

Of course, if you want to de-duplicate any potentially-redundant logic here, that would be welcome too :P Otherwise I think @hawkinsp or I will get to it at some point (@hawkinsp cleaned parts of this up most recently).
