I'm intending on working on cupy support in xarray along with @quasiben. Thanks for the warm welcome in the xarray dev meeting yesterday!
I'd like to use this issue to track cupy support and discuss certain design decisions. I appreciate there are also issues such as #4208, #3484 and #3232 which are related to cupy support, but maybe this could represent an umbrella issue for cupy specifically.
The main goal here is to improve support for array types other than numpy and dask in general. However, some cupy-specific compatibility code will likely be needed in xarray (@andersy005 raised issues with calling `__array__` on cupy in #3232, for example).
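For context, cupy disallows implicit conversion to numpy: `cupy.ndarray.__array__` raises rather than silently copying data off the device. Here is a minimal numpy-only sketch of that behavior; the `DeviceArray` class is a hypothetical stand-in, not cupy itself:

```python
import numpy as np

class DeviceArray:
    """Toy stand-in for cupy.ndarray: forbids implicit host transfer.

    Mirrors cupy's behavior, where __array__ raises instead of silently
    copying data off the GPU; this is not cupy's actual implementation.
    """

    def __init__(self, values):
        self._values = np.asarray(values)

    def __array__(self, dtype=None, copy=None):
        raise TypeError(
            "Implicit conversion to a NumPy array is not allowed; "
            "call .get() for an explicit device-to-host copy."
        )

    def get(self):
        # explicit conversion, analogous to cupy.ndarray.get()
        return self._values.copy()

arr = DeviceArray([1, 2, 3])
# np.asarray(arr) raises TypeError; arr.get() returns a numpy copy
```

This is why code paths inside xarray that call `numpy.asarray` on whatever `.data` holds will fail loudly for cupy rather than degrade gracefully.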
I would love to hear from folks wanting to use cupy with xarray to help build up some use cases for us to develop against. We have some ideas but more are welcome.
My first steps here will be to add some tests which use cupy. These will be skipped in the main CI, but we will also look at running the xarray tests on GPU CI as we develop. A few limited experiments that I've run seem to work, so I'll start with tests which reproduce those.
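As a sketch of how such tests could be gated on cupy's availability (this fallback pattern is an assumption about the test setup, not xarray's actual CI configuration):

```python
import importlib.util

import numpy as np

# Is cupy importable in this environment?  (GPU CI: yes; main CI: typically no.)
HAS_CUPY = importlib.util.find_spec("cupy") is not None

if HAS_CUPY:
    import cupy as xp
else:
    xp = np  # fall back to numpy so the same test code still exercises the logic

# in a pytest suite one would instead gate individual tests, e.g.:
# requires_cupy = pytest.mark.skipif(not HAS_CUPY, reason="requires cupy")

data = xp.arange(4)  # cupy.ndarray on GPU CI, numpy.ndarray otherwise
```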
@jacobtomlinson, thank you for getting this started. I'll be monitoring this issue closely. Let me know if I can help in any way.
This PR for adding pint support is a useful reference. https://github.com/pydata/xarray/pull/3238
@jacobtomlinson Any idea how this would play with the work that's been going on for units here; I'm specifically wondering if xarray ( pint ( cupy )) would/could work.
As far as I see it, the pieces to get this working are
and then finally testing that xarray( pint( cupy )) works automatically from there. https://github.com/hgrecco/pint/issues/964 was deferred due to CI/testing concerns, so it will be great to see what @jacobtomlinson can come up with here for xarray, since hopefully at some point it would be transferable to pint as well.
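To illustrate the nesting order under discussion, here is a toy sketch: the `Quantity` class below is a hypothetical stand-in for pint, and numpy stands in for the inner cupy array (only the `magnitude`/`units` names echo real pint API):

```python
import numpy as np

class Quantity:
    """Toy stand-in for pint.Quantity: wraps any duck array and tags units."""

    def __init__(self, magnitude, units):
        self.magnitude = magnitude  # inner duck array: numpy, cupy, sparse, ...
        self.units = units

    def __add__(self, other):
        if other.units != self.units:
            raise ValueError("unit mismatch")
        return Quantity(self.magnitude + other.magnitude, self.units)

# numpy stands in for cupy here; on GPU CI the inner layer would be cupy
inner = np.arange(3.0)           # cupy layer (innermost)
q = Quantity(inner, "m")         # pint( cupy ) layer
total = q + Quantity(inner, "m")
# xarray would then wrap `total` as DataArray data: xarray( pint( cupy ) )
```

The point of the nesting is that units apply to whatever duck array pint wraps, so xarray only ever needs to understand the outermost layer.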
_I've written this comment a few times to try and not come across as confrontational. I'm not intending to be at all, so please don't take it that way 😅. Tone is hard in comments! I'm just trying to figure out how to proceed quickly._
I've noticed a diverging theme that seems to be coming up in various conversations (see #3234 and #3245) around API design for alternative array implementations.
It seems to boil down to whether an array implementation has 1st party or 3rd party support within xarray.
For numpy and Dask they appear to be 1st party: they influence the main API of xarray, and xarray contains baked-in logic to create and work with them.
The work on pint so far points towards it being 3rd party. While I'm sure some compatibility code has gone into xarray, much of the logic lives out in an accessor library. Given that pint extends the numpy API, this makes sense.
I initially started this work assuming that cupy would be added as a 1st party type, given that it attempts to replicate the numpy API without additions. However, I'm not sure this is the right stance.
There are a few questions, such as "should .plot coerce cupy arrays to numpy?" (with the conversation in #3234 pointing towards no), which make me think that perhaps it should be more 3rd party.
I think it would help with API design, and speed things up here, if a decision were made on whether cupy (and sparse) are 1st or 3rd party.
Perhaps some core maintainers could weigh in here?
> While I'm sure some compatibility code has gone into xarray
Actually, I have been able to get by without compatibility code; the code changes outside of test_units are mostly refactoring (see #3706). This is probably because pint behaves (or tries to behave) like whatever it is wrapping, so the same may not hold for other duck arrays.
While adding support for cupy as a 1st party type might be less complicated, I think we should aim for 3rd party: that might be more maintainable, and adding support for similar duck arrays would become easier (I might be missing something, though).
While there are cases where interaction between numpy.ndarray and cupy.ndarray fails (e.g. #4231) – I still don't have any idea how to fix that in a general way – there are also several places where we want to explicitly convert to numpy (e.g. when plotting), but without using numpy.asarray, since that is commonly used to ensure objects are array-like.
Ideally, to make those work we'd have a standard way to explicitly get the data of a duck array as a numpy array (obj.to_numpy()?). Right now, we have several: for sparse it is todense(), for cupy it is get(), and for pint we have the magnitude property (and I'm sure there are many more).
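A hypothetical helper sketching what such a standard could dispatch to today; the name `to_numpy` and the attribute checks below are illustrative assumptions (real code would dispatch on types), not an agreed xarray API:

```python
import numpy as np

def to_numpy(data):
    """Explicitly convert a duck array to a numpy.ndarray.

    Avoids relying on np.asarray alone, since __array__ may be
    disallowed (cupy) or lossy (pint).  A real implementation would
    dispatch on concrete types rather than attribute names.
    """
    if hasattr(data, "get"):         # cupy.ndarray.get(): device -> host copy
        return data.get()
    if hasattr(data, "todense"):     # sparse.COO.todense()
        return np.asarray(data.todense())
    if hasattr(data, "magnitude"):   # pint.Quantity.magnitude strips the units
        return np.asarray(data.magnitude)
    return np.asarray(data)          # plain numpy / array-like fallback
```

Note that the ordering matters for nested duck arrays (e.g. pint wrapping cupy), which is another argument for agreeing on a single protocol rather than per-library special cases.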