Xarray: Unexpected chunking behavior when using `xr.align` with `join='outer'`

Created on 30 May 2020 · 6 comments · Source: pydata/xarray

I just came across some unexpected behavior when using xr.align with the option join='outer' on two DataArrays that contain dask arrays and have different dimension lengths.

MCVE Code Sample

import numpy as np
import xarray as xr

short_time = xr.cftime_range('2000', periods=12)
long_time = xr.cftime_range('2000', periods=120)

data_short = np.random.rand(len(short_time))
data_long = np.random.rand(len(long_time))

# two DataArrays with different time lengths, both chunked along time with size 3
a = xr.DataArray(data_short, dims=['time'], coords={'time': short_time}).chunk({'time': 3})
b = xr.DataArray(data_long, dims=['time'], coords={'time': long_time}).chunk({'time': 3})

a, b = xr.align(a, b, join='outer')

Expected Output

As expected, a is filled with missing values:

a.plot()
b.plot()

[figure: line plots of a and b over time; a covers only its original 12 time steps and is padded with missing values over the rest of the outer-joined time axis]

But the filled values do not replicate the chunking of b along the time dimension. Instead, the padded values end up in one single chunk, which can be substantially larger than the others.

a.data

[dask array repr for a: the original values keep their chunks of size 3, but the padded values form one single, much larger chunk]

b.data

[dask array repr for b: uniform chunks of size 3 along time]
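The same thing can be read off the .chunks attribute directly (a quick sketch; the exact size of the merged chunk may differ between dask versions):

# continuing from the MCVE above
print(b.chunks)  # ((3,) * 40,) -- uniform chunks of size 3
print(a.chunks)  # something like ((3, 3, 3, 111),) -- the padded region collapses into one big chunk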

(Quick shoutout for the amazing HTML representation. This made diagnosing this problem super easy! 🥳)

Problem Description

I think for many problems it would be more appropriate if the padded portion of the array had a chunking scheme matching the longer array.

A practical example (which brought me to this issue) comes from the CMIP6 data archive, where some models provide output for several members, with some members running longer than others, which leads to problems when these are combined (see intake-esm/#225).
Basically, for that particular model there are 5 members with a runtime of 100 years and one member with a runtime of 300 years. I think using xr.align immediately creates a chunk that is 200 years long, which blows up the memory on every system I have tried this on.

Is there a way to work around this, or is this behavior intended and I am missing something?

cc'ing @dcherian @andersy005

Versions

Output of xr.show_versions()

INSTALLED VERSIONS

commit: None
python: 3.8.2 | packaged by conda-forge | (default, Apr 24 2020, 08:20:52)
[GCC 7.3.0]
python-bits: 64
OS: Linux
OS-release: 3.10.0-1127.el7.x86_64
machine: x86_64
processor: x86_64
byteorder: little
LC_ALL: None
LANG: en_US.UTF-8
LOCALE: en_US.UTF-8
libhdf5: 1.10.5
libnetcdf: 4.7.4

xarray: 0.15.1
pandas: 1.0.3
numpy: 1.18.4
scipy: 1.4.1
netCDF4: 1.5.3
pydap: None
h5netcdf: 0.8.0
h5py: 2.10.0
Nio: None
zarr: 2.4.0
cftime: 1.1.2
nc_time_axis: 1.2.0
PseudoNetCDF: None
rasterio: 1.1.3
cfgrib: None
iris: None
bottleneck: None
dask: 2.15.0
distributed: 2.15.2
matplotlib: 3.2.1
cartopy: 0.18.0
seaborn: None
numbagg: None
setuptools: 46.1.3.post20200325
pip: 20.1
conda: None
pytest: 5.4.2
IPython: 7.14.0
sphinx: None

Label: upstream issue

All 6 comments

Great diagnosis @jbusecke.

Ultimately this comes down to dask indexing:

import dask.array

arr = dask.array.from_array([0, 1, 2, 3], chunks=(1,))
print(arr.chunks)  # ((1, 1, 1, 1),)

# align calls reindex, which indexes with something like this
indexer = [0, 1, 2, 3] + [-1] * 111
print(arr[indexer].chunks)  # ((1, 1, 1, 112),)

# maybe something like this is a solution
lazy_indexer = dask.array.from_array(indexer, chunks=arr.chunks[0][0], name="idx")
print(arr[lazy_indexer].chunks)  # ((1,) * 115,) -- one chunk per element

cc @TomAugspurger, the issue here is that the big 112-element chunk takes down the cluster in https://github.com/NCAR/intake-esm/issues/225

Rechunking the indexer array is how I would be explicit about the desired chunk size. Opened https://github.com/dask/dask/issues/6270 to discuss this on the dask side.

Thanks @TomAugspurger

I think an upstream dask solution would be useful.

xarray automatically aligns objects everywhere, and this alignment is what is blowing things up. For this reason I think xarray should explicitly chunk the indexer when aligning. We could use a reasonable chunk size, like the median chunk size of the DataArray along that axis; this would respect the user's chunk size choices.
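A rough sketch of what that could look like (purely illustrative; the chunked_indexer helper and the median heuristic are just the idea floated above, not an existing xarray API):

import numpy as np
import dask.array

def chunked_indexer(indexer, source_chunks):
    # hypothetical helper: chunk the indexer with the median chunk size
    # of the source array along the indexed axis
    chunksize = int(np.median(source_chunks))
    return dask.array.from_array(np.asarray(indexer), chunks=chunksize)

arr = dask.array.from_array([0, 1, 2, 3], chunks=(1,))
indexer = [0, 1, 2, 3] + [-1] * 111
idx = chunked_indexer(indexer, arr.chunks[0])
print(arr[idx].chunks)  # many small, regular chunks instead of one 112-element chunk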

@shoyer What do you think?

The problem with chunking indexers is that dask then doesn't have any visibility into the indexing values, which means the graph now grows like the square of the number of chunks along an axis, instead of proportionally to the number of chunks.

The real operation that xarray needs here is Variable._getitem_with_mask, i.e., indexing with -1 remapped to a fill value:
https://github.com/pydata/xarray/blob/e8bd8665e8fd762031c2d9c87987d21e113e41cc/xarray/core/variable.py#L715

The padded portion of the array is used in indexing, but only so that the result is aligned for np.where to replace with the fill value. We don't actually look at those values at all.
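Conceptually, the operation boils down to something like this (a simplified NumPy sketch, not the actual implementation in variable.py):

import numpy as np

data = np.array([10., 20., 30., 40.])
indexer = np.array([0, 1, 2, 3] + [-1] * 6)

# the -1 entries just pick *some* value so the shapes line up;
# np.where then overwrites those positions with the fill value
result = np.where(indexer == -1, np.nan, data[indexer])
print(result)  # [10. 20. 30. 40. nan nan nan nan nan nan]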

I don't know the best way to handle this. One option might be to rewrite Dask's indexing functionality to "split" chunks that are much larger than their inputs into smaller pieces, even if they all come from the same input chunk?

> One option might be to rewrite Dask's indexing functionality to "split" chunks that are much larger than their inputs into smaller pieces, even if they all come from the same input chunk?

This is Tom's proposed solution in https://github.com/dask/dask/issues/6270

Just tried this with the newest dask version and can confirm that I do not get huge chunks anymore if I specify dask.config.set({"array.slicing.split_large_chunks": True}). I also needed to modify the example so that the arrays exceed the internal chunk size limit, otherwise no splitting happens:

import numpy as np
import xarray as xr
import dask

dask.config.set({"array.slicing.split_large_chunks": True})

short_time = xr.cftime_range('2000', periods=12)
long_time = xr.cftime_range('2000', periods=120)

data_short = np.random.rand(len(short_time))
data_long = np.random.rand(len(long_time))

# expand the arrays so the padded chunk is large enough for dask to split it
n = 1000
a = xr.DataArray(data_short, dims=['time'], coords={'time': short_time}).expand_dims(a=n, b=n).chunk({'time': 3})
b = xr.DataArray(data_long, dims=['time'], coords={'time': long_time}).expand_dims(a=n, b=n).chunk({'time': 3})

a, b = xr.align(a, b, join='outer')

With the option turned on I get this for a:

[dask array repr: the padded portion of a is now split into several smaller chunks]

With the defaults, I still get one giant chunk:

[dask array repr: the padded portion of a is still one giant chunk]

I'll try this soon in the real-world scenario described above. Just wanted to report back here.
