Xarray should have a general curve-fitting function as part of its main API.
Yesterday I wanted to fit a simple decaying exponential function to the data in a DataArray and realised there currently isn't an immediate way to do this in xarray. You have to either pull out the .values (losing the power of dask), or use apply_ufunc (complicated).
This is an incredibly common, domain-agnostic task, so although I don't think we should support various kinds of unusual optimisation procedures (which could always go in an extension package instead), I think a basic fitting method is within scope for the main library. There are SO questions asking how to achieve this.
We already have `.polyfit` and `polyval` anyway, which are more specific. (@AndrewWilliams3142 and @aulemahal I expect you will have thoughts on how to implement this generally.)
I want something like this to work:
```python
def exponential_decay(xdata, A=10, L=5):
    return A * np.exp(-xdata / L)

# returns a dataset containing the optimised values of each parameter
fitted_params = da.fit(exponential_decay)
fitted_line = exponential_decay(da.x, A=fitted_params['A'], L=fitted_params['L'])

# Compare
da.plot(ax)
fitted_line.plot(ax)
```
It would also be nice to be able to fit in multiple dimensions. That means both for example fitting a 2D function to 2D data:
```python
def hat(xdata, ydata, h=2, r0=1):
    r = xdata**2 + ydata**2
    return h * np.exp(-r / r0)

fitted_params = da.fit(hat)
fitted_hat = hat(da.x, da.y, h=fitted_params['h'], r0=fitted_params['r0'])
```
but also repeatedly fitting a 1D function to 2D data:
```python
# da now has a y dimension too
fitted_params = da.fit(exponential_decay, fit_along=['x'])

# As fitted_params now has y-dependence, broadcasting means fitted_lines does too
fitted_lines = exponential_decay(da.x, A=fitted_params.A, L=fitted_params.L)
```
The latter would be useful for fitting the same curve to multiple model runs, but means we need some kind of fit_along or dim argument, which would default to all dims.
So the method docstring would end up like
```python
def fit(self, f, fit_along=None, skipna=None, full=False, cov=False):
    """
    Fits the function f to the DataArray.

    Expects the function f to have a signature like
    `result = f(*coords, **params)`
    for example
    `result_da = f(da.xcoord, da.ycoord, da.zcoord, A=5, B=None)`
    The names of the `**params` kwargs will be used to name the output variables.

    Returns
    -------
    fit_results : Dataset
        A single dataset which contains the variables (for each parameter
        in the fitting function):
        `param1`
            The optimised fit coefficients for parameter one.
        `param1_residuals`
            The residuals of the fit for parameter one.
        ...
    """
```
1) Should it wrap `scipy.optimize.curve_fit`, or reimplement it?
Wrapping it is simpler, but since it just calls `least_squares` [under the hood](https://github.com/scipy/scipy/blob/v1.5.2/scipy/optimize/minpack.py#L532-L834), reimplementing it would mean we could use the dask-powered version of `least_squares` (like [`da.polyfit` does](https://github.com/pydata/xarray/blob/9058114f70d07ef04654d1d60718442d0555b84b/xarray/core/dataset.py#L5987)).
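To make the "just wrap it" option concrete, here is a minimal, dask-free sketch of what such a method could do internally. The `fit` helper and its arguments are hypothetical, and the parameter names are hard-coded here rather than read from the function's signature:

```python
import numpy as np
import xarray as xr
from scipy.optimize import curve_fit

def exponential_decay(xdata, A=10, L=5):
    return A * np.exp(-xdata / L)

def fit(da, f, coord):
    # Flatten to plain arrays, delegate to scipy, rewrap as a Dataset
    popt, pcov = curve_fit(f, da[coord].values, da.values)
    names = ["A", "L"]  # in practice these would come from f's signature
    return xr.Dataset(dict(zip(names, popt)))

x = np.linspace(0, 20, 100)
da = xr.DataArray(exponential_decay(x, A=4.0, L=3.0),
                  coords={"x": x}, dims="x")
fitted_params = fit(da, exponential_decay, "x")
```

This loses dask support (it eagerly pulls out `.values`), which is exactly the trade-off being discussed.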
2) What form should we expect the curve-defining function to come in?
`scipy.optimize.curve_fit` expects the curve to act as `ydata = f(xdata, *params) + eps`, but in xarray `xdata` could be one or multiple coords or dims, not necessarily a single array. Might it work to require a signature like `result_da = f(da.xcoord, da.ycoord, da.zcoord, ..., **params)`? Then the `.fit` method would work out how many coords to pass to `f` based on the dimensions of the `da` and the `fit_along` argument. But then the order of coord arguments in the signature of `f` would matter, which doesn't seem very xarray-like.
3) Is it okay to inspect parameters of the curve-defining function?
If we tell the user the curve-defining function has to have a signature like `da = func(*coords, **params)`, then we could read the names of the parameters by inspecting the function kwargs. Is that a good idea or might it end up being unreliable? Is the `inspect` standard library module the right thing to use for that? This could also be used to provide default guesses for the fitting parameters.
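As a quick check, `inspect.signature` does make this straightforward for plainly defined functions, and the default values could double as initial guesses:

```python
import inspect

import numpy as np

def exponential_decay(xdata, A=10, L=5):
    return A * np.exp(-xdata / L)

sig = inspect.signature(exponential_decay)
# Treat every argument with a default as a fitting parameter;
# the defaults can then serve as initial guesses (p0)
params = {name: p.default for name, p in sig.parameters.items()
          if p.default is not inspect.Parameter.empty}
print(params)  # {'A': 10, 'L': 5}
```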
My comments
Q.1 : For now `xr.apply_ufunc` does not accept core dimensions that are chunked, which would be kind of sad for curve fitting. However, dask's least-squares method does, which is part of the reason why I used it in `polyfit`.
On the other hand, scipy's least-squares procedure is not simple. Curve fitting is quite complex, and rewriting all the code to use dask might be too ambitious a project, and surely out of scope for xarray...
Q.3 : For simple, directly declared functions, `inspect` does a good job, but it can get tricky with wrapped functions, which might arise in more complex workflows. Could we have a `params` arg that takes in a list of names?
I am also trying to get results similar to scipy's `curve_fit` with xarray and dask. Is there a workaround I can use to fit a sinusoidal function with the current functions/methods?
This is the function I use to fit a seasonal trend with scipy:
```python
t = 365

def timeseries_function_season(x, a0, a1, a2):
    return a0 + a1 * np.cos(2 * np.pi / t * x) + a2 * np.sin(2 * np.pi / t * x)

timeseries_model_fit, pcov = curve_fit(timeseries_function_season, x, y)
```
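One workaround that works today is to wrap `scipy.optimize.curve_fit` in `xr.apply_ufunc` with `vectorize=True`, so the fit runs once per slice along the non-core dimensions. A self-contained sketch with synthetic data (the `pixel` dimension and the parameter values are made up for illustration):

```python
import numpy as np
import xarray as xr
from scipy.optimize import curve_fit

t = 365.0

def timeseries_function_season(x, a0, a1, a2):
    return a0 + a1 * np.cos(2 * np.pi / t * x) + a2 * np.sin(2 * np.pi / t * x)

# Synthetic data: two "pixels", each an exact seasonal cycle
x = np.arange(730.0)
da = xr.DataArray(
    np.stack([timeseries_function_season(x, 2.0, 1.5, 0.5),
              timeseries_function_season(x, 2.3, 1.5, 0.5)]),
    dims=("pixel", "time"), coords={"time": x})

def fit_1d(y, x):
    popt, _ = curve_fit(timeseries_function_season, x, y)
    return popt

# One fit per pixel; 'time' is the core dim that each fit consumes
fit_params = xr.apply_ufunc(
    fit_1d, da, da["time"],
    input_core_dims=[["time"], ["time"]],
    output_core_dims=[["param"]],
    vectorize=True,
)
print(fit_params.shape)  # (2, 3): (pixel, [a0, a1, a2])
```

For dask-backed data you would additionally need `dask="parallelized"` plus the output dtype/size information, with the caveat above that the core dimension itself must not be chunked.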
This sounds very cool! :) I'm not sure that I have much to add, but given @aulemahal 's good point about the complexity of rewriting curve_fit from scratch, it seems that maybe a good first step would just be to wrap the existing scipy functionality?
Alternatively, given that xr.apply_ufunc can already do this (though it's probably complicated), perhaps it would be good to just have an example in the documentation?
+1 for just wrapping the existing functionality in SciPy for now. If we want a version of `curve_fit` that supports dask, I would suggest implementing `curve_fit` with dask first, and then using that from xarray.
I am OK with using `inspect` from the standard library for determining _default_ parameter names. `inspect.signature` is reasonably robust. But there should definitely be an optional argument for setting parameter names explicitly.
@TomNicholas I'm a bit confused about how the `fit_along` argument would actually work. If you had 2D data and wanted to fit a 1D function to one of the dimensions, wouldn't you have to either take a mean (or slice?) across the other dimension?
Edit: It's been a hot day here, so apologies if this turns out to be a dumb q haha
@AndrewWilliams3142 fair question: what I was envisaging was taking slices along that dimension(s), performing the curve fitting once for each slice (which should parallelize through apply_ufunc), then returning the optimised fitting parameters as a DataArray/Dataset which varied along that dimension. For example:
```python
# 2D dataarray of surface height with x & t dependence
height_data

def pulse_shape(x, peak_height, peak_location, FWHM):
    return peak_height * np.exp(-((x - peak_location) / FWHM) ** 2.0)

# returned fit_params has t dependence
fit_params = height_data.fit(pulse_shape, fit_along='x')

# Plot a graph of change in peak height over t
fit_params['peak_height'].plot(x='t')
```
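In plain numpy/scipy terms, `fit_along='x'` would reduce to one `curve_fit` call per slice along `t`, stacking the optimised parameters. A runnable sketch with synthetic data (the parameter values and initial guess `p0` are made up):

```python
import numpy as np
from scipy.optimize import curve_fit

def pulse_shape(x, peak_height, peak_location, FWHM):
    return peak_height * np.exp(-((x - peak_location) / FWHM) ** 2.0)

# Synthetic height_data: the pulse drifts and grows over 3 time steps
x = np.linspace(-10, 10, 200)
height_data = np.stack([pulse_shape(x, 1.0 + 0.5 * t, 0.2 * t, 2.0)
                        for t in range(3)])

# What fit_along='x' would do: one fit per slice along the other dim(s)
fit_params = np.array([curve_fit(pulse_shape, x, row, p0=[1.0, 0.0, 1.0])[0]
                       for row in height_data])
print(fit_params[:, 0])  # peak_height per t: ~[1.0, 1.5, 2.0]
```

The `.fit` method would then wrap the resulting array back into a Dataset indexed by `t`, with one variable per parameter.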
cheers @TomNicholas , that's helpful. :) I've started messing with the idea in this Gist if you want to have a look.
It's pretty hacky at the moment, but might be helpful as a testbed. (And a way of getting my head around how apply_ufunc would work in this context)
@AndrewWilliams3142 I've tried to extend this to a 3d matrix (timeseries of 2d matrices) using Dask, it seems to work! Have a look here https://gist.github.com/clausmichele/8350e1f7f15e6828f29579914276de71