I've been experimenting with distributed recently and have run into an issue when saving a result directly to a file using the netcdf4 engine. I've found that if I compute things before saving to a file (thus loading the result into memory before calling to_netcdf), things work OK. I've attached a minimal working example below.
Can others reproduce this? Part of me thinks there must be something wrong with my setup, because I'm somewhat surprised that something like this wouldn't have come up already (apologies in advance if that's the case).
In [1]: import dask
In [2]: import distributed
In [3]: import netCDF4
In [4]: import xarray as xr
In [5]: dask.__version__
Out[5]: '0.15.0'
In [6]: distributed.__version__
Out[6]: '1.17.1'
In [7]: netCDF4.__version__
Out[7]: '1.2.9'
In [8]: xr.__version__
Out[8]: '0.9.6'
In [9]: da = xr.DataArray([1., 2., 3.])
In [10]: da.to_netcdf('no-dask.nc')
In [11]: da.chunk().to_netcdf('dask.nc') # Not using distributed yet
In [12]: c = distributed.Client() # Launch a LocalCluster (now using distributed)
In [13]: c
Out[13]: <Client: scheduler='tcp://127.0.0.1:44576' processes=16 cores=16>
In [14]: da.chunk().to_netcdf('dask-distributed-netcdf4.nc', engine='netcdf4')
---------------------------------------------------------------------------
EOFError Traceback (most recent call last)
<ipython-input-14-98490239a35f> in <module>()
----> 1 da.chunk().to_netcdf('dask-distributed-netcdf4.nc', engine='netcdf4')
/nbhome/skc/miniconda3/envs/research/lib/python3.6/site-packages/xarray/core/dataarray.py in to_netcdf(self, *args, **kwargs)
1349 dataset = self.to_dataset()
1350
-> 1351 dataset.to_netcdf(*args, **kwargs)
1352
1353 def to_dict(self):
/nbhome/skc/miniconda3/envs/research/lib/python3.6/site-packages/xarray/core/dataset.py in to_netcdf(self, path, mode, format, group, engine, encoding, unlimited_dims)
975 return to_netcdf(self, path, mode, format=format, group=group,
976 engine=engine, encoding=encoding,
--> 977 unlimited_dims=unlimited_dims)
978
979 def __unicode__(self):
/nbhome/skc/miniconda3/envs/research/lib/python3.6/site-packages/xarray/backends/api.py in to_netcdf(dataset, path_or_file, mode, format, group, engine, writer, encoding, unlimited_dims)
571 try:
572 dataset.dump_to_store(store, sync=sync, encoding=encoding,
--> 573 unlimited_dims=unlimited_dims)
574 if path_or_file is None:
575 return target.getvalue()
/nbhome/skc/miniconda3/envs/research/lib/python3.6/site-packages/xarray/core/dataset.py in dump_to_store(self, store, encoder, sync, encoding, unlimited_dims)
916 unlimited_dims=unlimited_dims)
917 if sync:
--> 918 store.sync()
919
920 def to_netcdf(self, path=None, mode='w', format=None, group=None,
/nbhome/skc/miniconda3/envs/research/lib/python3.6/site-packages/xarray/backends/netCDF4_.py in sync(self)
334 def sync(self):
335 with self.ensure_open(autoclose=True):
--> 336 super(NetCDF4DataStore, self).sync()
337 self.ds.sync()
338
/nbhome/skc/miniconda3/envs/research/lib/python3.6/site-packages/xarray/backends/common.py in sync(self)
200
201 def sync(self):
--> 202 self.writer.sync()
203
204 def store_dataset(self, dataset):
/nbhome/skc/miniconda3/envs/research/lib/python3.6/site-packages/xarray/backends/common.py in sync(self)
177 import dask
178 if LooseVersion(dask.__version__) > LooseVersion('0.8.1'):
--> 179 da.store(self.sources, self.targets, lock=GLOBAL_LOCK)
180 else:
181 da.store(self.sources, self.targets)
/nbhome/skc/miniconda3/envs/research/lib/python3.6/site-packages/dask/array/core.py in store(sources, targets, lock, regions, compute, **kwargs)
922 dsk = sharedict.merge((name, updates), *[src.dask for src in sources])
923 if compute:
--> 924 Array._get(dsk, keys, **kwargs)
925 else:
926 from ..delayed import Delayed
/nbhome/skc/miniconda3/envs/research/lib/python3.6/site-packages/dask/base.py in _get(cls, dsk, keys, get, **kwargs)
102 get = get or _globals['get'] or cls._default_get
103 dsk2 = optimization_function(cls)(ensure_dict(dsk), keys, **kwargs)
--> 104 return get(dsk2, keys, **kwargs)
105
106 @classmethod
/nbhome/skc/miniconda3/envs/research/lib/python3.6/site-packages/distributed/client.py in get(self, dsk, keys, restrictions, loose_restrictions, resources, sync, **kwargs)
1762 if sync:
1763 try:
-> 1764 results = self.gather(packed)
1765 finally:
1766 for f in futures.values():
/nbhome/skc/miniconda3/envs/research/lib/python3.6/site-packages/distributed/client.py in gather(self, futures, errors, maxsize, direct)
1261 else:
1262 return self.sync(self._gather, futures, errors=errors,
-> 1263 direct=direct)
1264
1265 @gen.coroutine
/nbhome/skc/miniconda3/envs/research/lib/python3.6/site-packages/distributed/client.py in sync(self, func, *args, **kwargs)
487 return future
488 else:
--> 489 return sync(self.loop, func, *args, **kwargs)
490
491 def __str__(self):
/nbhome/skc/miniconda3/envs/research/lib/python3.6/site-packages/distributed/utils.py in sync(loop, func, *args, **kwargs)
232 e.wait(1000000)
233 if error[0]:
--> 234 six.reraise(*error[0])
235 else:
236 return result[0]
/nbhome/skc/miniconda3/envs/research/lib/python3.6/site-packages/six.py in reraise(tp, value, tb)
684 if value.__traceback__ is not tb:
685 raise value.with_traceback(tb)
--> 686 raise value
687
688 else:
/nbhome/skc/miniconda3/envs/research/lib/python3.6/site-packages/distributed/utils.py in f()
221 raise RuntimeError("sync() called from thread of running loop")
222 yield gen.moment
--> 223 result[0] = yield make_coro()
224 except Exception as exc:
225 logger.exception(exc)
/nbhome/skc/miniconda3/envs/research/lib/python3.6/site-packages/tornado/gen.py in run(self)
1013
1014 try:
-> 1015 value = future.result()
1016 except Exception:
1017 self.had_exception = True
/nbhome/skc/miniconda3/envs/research/lib/python3.6/site-packages/tornado/concurrent.py in result(self, timeout)
235 return self._result
236 if self._exc_info is not None:
--> 237 raise_exc_info(self._exc_info)
238 self._check_done()
239 return self._result
/nbhome/skc/miniconda3/envs/research/lib/python3.6/site-packages/tornado/util.py in raise_exc_info(exc_info)
/nbhome/skc/miniconda3/envs/research/lib/python3.6/site-packages/tornado/gen.py in run(self)
1019
1020 if exc_info is not None:
-> 1021 yielded = self.gen.throw(*exc_info)
1022 exc_info = None
1023 else:
/nbhome/skc/miniconda3/envs/research/lib/python3.6/site-packages/distributed/client.py in _gather(self, futures, errors, direct)
1154 six.reraise(type(exception),
1155 exception,
-> 1156 traceback)
1157 if errors == 'skip':
1158 bad_keys.add(key)
/nbhome/skc/miniconda3/envs/research/lib/python3.6/site-packages/six.py in reraise(tp, value, tb)
683 value = tp()
684 if value.__traceback__ is not tb:
--> 685 raise value.with_traceback(tb)
686 raise value
687
/nbhome/skc/miniconda3/envs/research/lib/python3.6/site-packages/distributed/protocol/pickle.py in loads()
57 def loads(x):
58 try:
---> 59 return pickle.loads(x)
60 except Exception:
61 logger.info("Failed to deserialize %s", x[:10000], exc_info=True)
EOFError: Ran out of input
If I load the data into memory first by invoking compute(), things work OK:
In [15]: da.chunk().compute().to_netcdf('dask-distributed-netcdf4.nc', engine='netcdf4')
Hmm. Can you try using scipy as an engine to write the netcdf file?
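For reference, that would look something like this in the same session (the output filename is just illustrative); the scipy engine writes netCDF3 through a pure-Python code path, so it avoids the netCDF4 objects entirely:

da.chunk().to_netcdf('dask-distributed-scipy.nc', engine='scipy')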
Honestly, I've barely used dask-distributed. Possibly @mrocklin has ideas.
It's failing to serialize something in the task graph, though I'm not sure what (I'm also surprised that the except clause didn't trigger and log the input). My first guess is that there is an open netCDF file object floating around within the task graph. If so, then we should endeavor to avoid doing this (or have some file object proxy that is (de)serializable).
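A quick way to test that guess is to try pickling a netCDF4.Variable directly; distributed has to pickle everything in the task graph, and objects backed by open file handles generally don't survive the round trip. A minimal sketch, reusing the file written above (the variable name is xarray's default for unnamed DataArrays):

import pickle

import netCDF4

nc = netCDF4.Dataset('no-dask.nc')
var = nc.variables['__xarray_dataarray_variable__']
try:
    pickle.loads(pickle.dumps(var))
except Exception as exc:  # netCDF4.Variable defines no pickle support
    print(type(exc).__name__, exc)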
As a short-term workaround, you might try starting a local cluster within the same process:

from distributed import Client

client = Client(processes=False)

This might help you avoid serialization issues. Generally we should resolve the issue regardless, though.
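Applied to the failing session above, that looks like this (a sketch; with an in-process cluster the graph is executed in local threads and never needs to be pickled):

client = distributed.Client(processes=False)  # becomes the default scheduler for dask calls
da.chunk().to_netcdf('dask-distributed-netcdf4.nc', engine='netcdf4')  # now succeeds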
cc'ing @rabernat, who seems to have the most experience here.
@shoyer @mrocklin thanks for your quick responses; I can confirm that both the workarounds you suggested work in my case.
Presumably there is some object in the task graph that we don't know how to serialize. This can be fixed either in XArray, by not including such an object in the graph (recreating it each time, or wrapping it in something serializable), or in Dask, by learning how to (de)serialize it.
I'm a little surprised that this doesn't work, because I thought we had made all of our xarray datastore objects pickleable.
The place to start is probably to write an integration test for this functionality. I notice now that our current tests only check reading netCDF files with dask-distributed:
https://github.com/pydata/xarray/blob/master/xarray/tests/test_distributed.py
I did a little bit of digging here, using @mrocklin's Client(processes=False) trick.
The problem seems to be that the arrays that we add to the writer in AbstractWritableDataStore.set_variables are not pickleable. To be more concrete, consider these lines:
https://github.com/pydata/xarray/blob/f83361c76b6aa8cdba8923080bb6b98560cf3a96/xarray/backends/common.py#L221-L232
target is currently a netCDF4.Variable object (or whatever the appropriate backend type is). Anything added to the writer eventually ends up as an argument to dask.array.store and hence gets put into the dask graph. When dask-distributed tries to pickle the dask graph, it fails on the netCDF4.Variable.
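Here is a small standalone sketch of that mechanism (the file name is illustrative): dask.array.store embeds the target object in the tasks it builds, so the target must be picklable before distributed can ship the graph to its workers:

import pickle

import dask.array
import netCDF4

nc = netCDF4.Dataset('store-target.nc', 'w')
nc.createDimension('x', 3)
target = nc.createVariable('v', 'f8', ('x',))

source = dask.array.ones(3, chunks=3)
graph = dask.array.store(source, target, compute=False)  # build the graph without executing it
try:
    pickle.dumps(dict(graph.dask))  # the graph contains the netCDF4.Variable itself
except TypeError as exc:
    print(exc)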
What we need to do instead is wrap these target arrays in the appropriate array wrappers, e.g., NetCDF4ArrayWrapper, adding __setitem__ methods to the array wrappers if needed. Unlike most backend array types, our array wrappers are pickleable, which is essential for use with dask-distributed.
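A rough sketch of that idea, with hypothetical names (the real wrapper would share the datastore's open file handle and locking rather than reopening on every write): keep only picklable state, and open the file lazily inside __setitem__:

import netCDF4

class PicklableArrayWrapper(object):
    # hypothetical stand-in for an xarray backend array wrapper
    def __init__(self, path, variable_name):
        # plain attributes only, so the default pickle machinery works
        self.path = path
        self.variable_name = variable_name

    def __setitem__(self, key, value):
        # reopen per write; production code would cache the handle and hold a lock
        with netCDF4.Dataset(self.path, 'a') as ds:
            ds.variables[self.variable_name][key] = value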
If anyone's curious, here's the traceback and code I used to debug this:
https://gist.github.com/shoyer/4564971a4d030cd43bba8241d3b36c73
> The place to start is probably to write an integration test for this functionality. I notice now that our current tests only check reading netCDF files with dask-distributed:
We should probably also write some tests for saving datasets with save_mfdataset and distributed.
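A rough sketch of what such a test could look like, loosely modeled on the existing test_distributed.py (the names and fixtures here are assumptions, not the actual suite):

import distributed
import xarray as xr

def test_distributed_save_mfdataset(tmpdir):  # tmpdir is the standard pytest fixture
    # a process-based LocalCluster, so graph pickling is actually exercised
    with distributed.Client():
        ds = xr.Dataset({'v': ('x', [1., 2., 3.])}).chunk()
        datasets = [ds.isel(x=slice(0, 2)), ds.isel(x=slice(2, 3))]
        paths = [str(tmpdir.join('part%d.nc' % i)) for i in range(2)]
        xr.save_mfdataset(datasets, paths)
        with xr.open_mfdataset(paths) as restored:
            assert (restored['v'].values == [1., 2., 3.]).all()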