id,node_id,number,title,user,state,locked,assignee,milestone,comments,created_at,updated_at,closed_at,author_association,active_lock_reason,draft,pull_request,body,reactions,performed_via_github_app,state_reason,repo,type
671609109,MDU6SXNzdWU2NzE2MDkxMDk=,4300,General curve fitting method,35968931,closed,0,,,9,2020-08-02T12:35:49Z,2021-03-31T16:55:53Z,2021-03-31T16:55:53Z,MEMBER,,,,"Xarray should have a general curve-fitting function as part of its main API.

## Motivation

Yesterday I wanted to fit a simple decaying exponential function to the data in a DataArray and realised there currently isn't an immediate way to do this in xarray. You have to either pull out the `.values` (losing the power of dask), or use `apply_ufunc` (complicated).

This is an incredibly common, domain-agnostic task, so although I don't think we should support various kinds of unusual optimisation procedures (which could always go in an extension package instead), I think a basic fitting method is within scope for the main library. There are [SO questions](https://stackoverflow.com/questions/62987617/using-scipy-curve-fit-with-dask-xarray) asking how to achieve this. We already have [`.polyfit` and `polyval` anyway](https://github.com/pydata/xarray/pull/3733/files#), which are more specific.

(@AndrewWilliams3142 and @aulemahal I expect you will have thoughts on how to implement this generally.)

## Proposed syntax

I want something like this to work:

```python
import matplotlib.pyplot as plt
import numpy as np

def exponential_decay(xdata, A=10, L=5):
    return A*np.exp(-xdata/L)

# returns a dataset containing the optimised values of each parameter
fitted_params = da.fit(exponential_decay)

fitted_line = exponential_decay(da.x, A=fitted_params['A'], L=fitted_params['L'])

# Compare the data and the fitted curve on the same axes
fig, ax = plt.subplots()
da.plot(ax=ax)
fitted_line.plot(ax=ax)
```

It would also be nice to be able to fit in multiple dimensions.
That means, for example, both fitting a 2D function to 2D data:

```python
def hat(xdata, ydata, h=2, r0=1):
    r = xdata**2 + ydata**2
    return h*np.exp(-r/r0)

fitted_params = da.fit(hat)

fitted_hat = hat(da.x, da.y, h=fitted_params['h'], r0=fitted_params['r0'])
```

and repeatedly fitting a 1D function to 2D data:

```python
# da now has a y dimension too
fitted_params = da.fit(exponential_decay, fit_along=['x'])

# As fitted_params now has y-dependence, broadcasting means fitted_lines does too
fitted_lines = exponential_decay(da.x, A=fitted_params.A, L=fitted_params.L)
```

The latter would be useful for fitting the same curve to multiple model runs, but it means we need some kind of `fit_along` or `dim` argument, which would default to all dims. So the method docstring would end up something like

```python
def fit(self, f, fit_along=None, skipna=None, full=False, cov=False):
    """"""
    Fits the function f to the DataArray.

    Expects the function f to have a signature like
    `result = f(*coords, **params)`
    for example
    `result_da = f(da.xcoord, da.ycoord, da.zcoord, A=5, B=None)`
    The names of the `**params` kwargs will be used to name the output variables.

    Returns
    -------
    fit_results - A single dataset which contains the variables (for each parameter in the fitting function):
    `param1`
        The optimised fit coefficients for parameter one.
    `param1_residuals`
        The residuals of the fit for parameter one.
    ...
    """"""
```

## Questions

1) Should it wrap `scipy.optimize.curve_fit`, or reimplement it?
Wrapping it is simpler, but since it just calls `least_squares` [under the hood](https://github.com/scipy/scipy/blob/v1.5.2/scipy/optimize/minpack.py#L532-L834), reimplementing it would mean we could use the dask-powered version of `least_squares` (like [`da.polyfit`](https://github.com/pydata/xarray/blob/9058114f70d07ef04654d1d60718442d0555b84b/xarray/core/dataset.py#L5987) does).

2) What form should we expect the curve-defining function to come in?
`scipy.optimize.curve_fit` expects the curve to act as `ydata = f(xdata, *params) + eps`, but in xarray `xdata` could be one or multiple coords or dims, not necessarily a single array. Might it work to require a signature like `result_da = f(da.xcoord, da.ycoord, da.zcoord, ..., **params)`? Then the `.fit` method would work out how many coords to pass to `f` based on the dimensions of `da` and the `fit_along` argument. But then the order of the coord arguments in the signature of `f` would matter, which doesn't seem very xarray-like.

3) Is it okay to inspect the parameters of the curve-defining function?
If we tell the user the curve-defining function has to have a signature like `da = func(*coords, **params)`, then we could read the names of the parameters by inspecting the function kwargs. Is that a good idea, or might it end up being unreliable? Is the `inspect` standard library module the right thing to use for that? This could also be used to provide default guesses for the fitting parameters.","{""url"": ""https://api.github.com/repos/pydata/xarray/issues/4300/reactions"", ""total_count"": 4, ""+1"": 3, ""-1"": 0, ""laugh"": 0, ""hooray"": 0, ""confused"": 0, ""heart"": 0, ""rocket"": 0, ""eyes"": 1}",,completed,13221727,issue