Issue #4300: General curve fitting method (pydata/xarray). Opened 2020-08-02, closed as completed 2021-03-31, 9 comments.

Xarray should have a general curve-fitting function as part of its main API.

Motivation

Yesterday I wanted to fit a simple decaying exponential function to the data in a DataArray and realised there currently isn't an immediate way to do this in xarray. You have to either pull out the `.values` (losing the power of dask) or use `apply_ufunc` (complicated).

This is an incredibly common, domain-agnostic task, so although I don't think we should support various kinds of unusual optimisation procedures (which could always go in an extension package instead), I think a basic fitting method is within scope for the main library. There are Stack Overflow questions asking how to achieve this.

We already have `.polyfit` and `polyval` anyway, which are more specific. (@AndrewWilliams3142 and @aulemahal I expect you will have thoughts on how to implement this generally.)

Proposed syntax

I want something like this to work:

```python
import numpy as np

def exponential_decay(xdata, A=10, L=5):
    return A * np.exp(-xdata / L)

# Returns a dataset containing the optimised values of each parameter
fitted_params = da.fit(exponential_decay)

fitted_line = exponential_decay(da.x, A=fitted_params['A'], L=fitted_params['L'])

# Compare
da.plot(ax=ax)
fitted_line.plot(ax=ax)
```

It would also be nice to be able to fit in multiple dimensions. That means, for example, not only fitting a 2D function to 2D data:

```python
def hat(xdata, ydata, h=2, r0=1):
    r = xdata**2 + ydata**2
    return h * np.exp(-r / r0)

fitted_params = da.fit(hat)

fitted_hat = hat(da.x, da.y, h=fitted_params['h'], r0=fitted_params['r0'])
```

but also repeatedly fitting a 1D function to 2D data:

```python
# da now has a y dimension too
fitted_params = da.fit(exponential_decay, fit_along=['x'])

# As fitted_params now has y-dependence, broadcasting means fitted_lines does too
fitted_lines = exponential_decay(da.x, A=fitted_params.A, L=fitted_params.L)
```

The latter would be useful for fitting the same curve to multiple model runs, but means we need some kind of `fit_along` or `dim` argument, which would default to all dims.

So the method docstring would end up something like

```python
def fit(self, f, fit_along=None, skipna=None, full=False, cov=False):
    """
    Fits the function f to the DataArray.

    Expects the function f to have a signature like
    `result = f(*coords, **params)`
    for example
    `result_da = f(da.xcoord, da.ycoord, da.zcoord, A=5, B=None)`
    The names of the `**params` kwargs will be used to name the output variables.

    Returns
    -------
    fit_results
        A single dataset which contains, for each parameter in the fitting function, the variables:
        `param1`
            The optimised fit coefficients for parameter one.
        `param1_residuals`
            The residuals of the fit for parameter one.
        ...
    """
```

Questions

1) Should it wrap `scipy.optimize.curve_fit`, or reimplement it?

Wrapping it is simpler, but since it just calls `least_squares` [under the hood](https://github.com/scipy/scipy/blob/v1.5.2/scipy/optimize/minpack.py#L532-L834), reimplementing it would mean we could use the dask-powered version of `least_squares` (as [`da.polyfit`](https://github.com/pydata/xarray/blob/9058114f70d07ef04654d1d60718442d0555b84b/xarray/core/dataset.py#L5987) does).
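For the wrapping route, something along these lines already seems possible today with `apply_ufunc` plus `scipy.optimize.curve_fit`. This is only a rough sketch, not a proposed implementation: the helper names `fit_along_dim` and `_fit_1d` are hypothetical, it assumes the curve takes a single coord, and the dask arguments assume a recent xarray.

```python
# Rough sketch of the "wrap curve_fit" route: apply the fit once per 1D slice along `dim`.
import numpy as np
import xarray as xr
from scipy.optimize import curve_fit

def fit_along_dim(da, func, dim, param_names):
    """Fit func(x, **params) to da along dim, returning one variable per parameter."""

    def _fit_1d(y, x):
        # curve_fit expects f(xdata, *params); with p0=None it infers the
        # number of parameters from func's signature and starts from all ones
        popt, _pcov = curve_fit(func, x, y)
        return popt

    popt = xr.apply_ufunc(
        _fit_1d,
        da,
        da[dim],
        input_core_dims=[[dim], [dim]],
        output_core_dims=[["param"]],
        vectorize=True,       # loop _fit_1d over all the non-core dims
        dask="parallelized",  # keep dask-backed arrays lazy
        dask_gufunc_kwargs={"output_sizes": {"param": len(param_names)}},
        output_dtypes=[float],
    )
    # Unstack the parameter axis into one named variable per fitted parameter
    return xr.Dataset({name: popt.isel(param=i) for i, name in enumerate(param_names)})

# e.g. fit_along_dim(da, exponential_decay, dim="x", param_names=["A", "L"])
```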

2) What form should we expect the curve-defining function to come in?

`scipy.optimize.curve_fit` expects the curve to act as `ydata = f(xdata, *params) + eps`, but in xarray `xdata` could be one or multiple coords or dims, not necessarily a single array. Might it work to require a signature like `result_da = f(da.xcoord, da.ycoord, da.zcoord, ..., **params)`? Then the `.fit` method would work out how many coords to pass to `f` based on the dimensions of `da` and the `fit_along` argument. But then the order of coord arguments in the signature of `f` would matter, which doesn't seem very xarray-like.
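To make that concrete, the coord-passing inside `.fit` could be as simple as the following hypothetical sketch, which is also what makes the argument order in `f` significant:

```python
# Hypothetical sketch: build the positional coord arguments from fit_along (default: all dims)
dims = fit_along if fit_along is not None else list(da.dims)
coords = [da[d] for d in dims]   # one coord per fitted dimension, in fit_along order
result = f(*coords, **params)    # f must therefore accept its coord arguments in that order
```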

3) Is it okay to inspect parameters of the curve-defining function?

If we tell the user the curve-defining function has to have a signature like `da = func(*coords, **params)`, then we could read the names of the parameters by inspecting the function kwargs. Is that a good idea or might it end up being unreliable? Is the `inspect` standard library module the right thing to use for that? This could also be used to provide default guesses for the fitting parameters.
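For reference, reading the parameter names and their defaults with `inspect` would look roughly like this (the helper name `get_fit_params` is made up for illustration):

```python
import inspect

def get_fit_params(func, n_coords):
    """Return {param_name: default_guess} for every argument after the coord arguments."""
    sig = inspect.signature(func)
    params = list(sig.parameters.values())[n_coords:]   # skip the leading coord args
    return {
        p.name: (p.default if p.default is not inspect.Parameter.empty else None)
        for p in params
    }

# e.g. get_fit_params(exponential_decay, n_coords=1) -> {"A": 10, "L": 5}
```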
