html_url,issue_url,id,node_id,user,created_at,updated_at,author_association,body,reactions,performed_via_github_app,issue
https://github.com/pydata/xarray/issues/3232#issuecomment-1183301651,https://api.github.com/repos/pydata/xarray/issues/3232,1183301651,IC_kwDOAMm_X85Gh8AT,2448579,2022-07-13T14:31:55Z,2022-07-13T14:32:01Z,MEMBER,"> I'd be happy to turn this into a PR with some tests.

Absolutely!","{""total_count"": 0, ""+1"": 0, ""-1"": 0, ""laugh"": 0, ""hooray"": 0, ""confused"": 0, ""heart"": 0, ""rocket"": 0, ""eyes"": 0}",,482543307
https://github.com/pydata/xarray/issues/3232#issuecomment-851581057,https://api.github.com/repos/pydata/xarray/issues/3232,851581057,MDEyOklzc3VlQ29tbWVudDg1MTU4MTA1Nw==,14808389,2021-05-31T16:12:35Z,2021-06-01T20:01:07Z,MEMBER,"Changing the `xarray` internals is not too much work: we need to get `xarray.core.utils.is_duck_array` to return `True` if the object defines either `__array_namespace__` or both `__array_ufunc__` and `__array_function__` (or all three), and we'd need a short test demonstrating that objects that implement only `__array_namespace__` survive unchanged when wrapped by an `xarray` object (e.g. something like `isinstance(xr.DataArray(pytorch_object).mean().data, torch.Tensor)`). We might still be a bit too early with this, though: the PR which adds `__array_namespace__` to `numpy` has not been merged into `numpy:main` yet.","{""total_count"": 1, ""+1"": 1, ""-1"": 0, ""laugh"": 0, ""hooray"": 0, ""confused"": 0, ""heart"": 0, ""rocket"": 0, ""eyes"": 0}",,482543307
https://github.com/pydata/xarray/issues/3232#issuecomment-851426576,https://api.github.com/repos/pydata/xarray/issues/3232,851426576,MDEyOklzc3VlQ29tbWVudDg1MTQyNjU3Ng==,14808389,2021-05-31T11:32:05Z,2021-05-31T11:32:05Z,MEMBER,"I don't, unfortunately (there's the partial example in https://github.com/pydata/xarray/issues/3232#issuecomment-769789746, though).
This is nothing usable right now, but the `pytorch` maintainers are currently looking into providing support for `__array_namespace__` (NEP 47). Once there has been sufficient progress in both [`numpy`](https://github.com/numpy/numpy/pull/18585) and [`pytorch`](https://github.com/pytorch/pytorch/issues/58743), we won't have to change much in xarray (i.e. allowing `__array_namespace__` instead of `__array_ufunc__` / `__array_function__` for duck arrays) to make this work without any wrapper code. You (or anyone interested) might still want to maintain a ""pytorch-xarray"" convenience library to allow something like `arr.torch.grad(dim=""x"")`.","{""total_count"": 0, ""+1"": 0, ""-1"": 0, ""laugh"": 0, ""hooray"": 0, ""confused"": 0, ""heart"": 0, ""rocket"": 0, ""eyes"": 0}",,482543307
https://github.com/pydata/xarray/issues/3232#issuecomment-786599239,https://api.github.com/repos/pydata/xarray/issues/3232,786599239,MDEyOklzc3VlQ29tbWVudDc4NjU5OTIzOQ==,14808389,2021-02-26T11:47:55Z,2021-02-26T11:48:09Z,MEMBER,@Duane321: with `xarray>=0.17.0` you should be able to remove the `__getattribute__` trick.,"{""total_count"": 0, ""+1"": 0, ""-1"": 0, ""laugh"": 0, ""hooray"": 0, ""confused"": 0, ""heart"": 0, ""rocket"": 0, ""eyes"": 0}",,482543307
https://github.com/pydata/xarray/issues/3232#issuecomment-771066618,https://api.github.com/repos/pydata/xarray/issues/3232,771066618,MDEyOklzc3VlQ29tbWVudDc3MTA2NjYxOA==,14808389,2021-02-01T18:34:00Z,2021-02-01T23:39:51Z,MEMBER,"I can't reproduce that:

```python
In [4]: da.loc[""a1""]
Out[4]:
tensor([0.4793, 0.7493], dtype=torch.float32)
Coordinates:
    a
```

I added a `torch_array_type` to `pycompat.py`.

`torch.Tensor` defines `values`, so the issue is this: https://github.com/pydata/xarray/blob/8cc34cb412ba89ebca12fc84f76a9e452628f1bc/xarray/core/variable.py#L221

@shoyer, any ideas? For now, I guess we can remove it using `__getattribute__`.
With that, you will have to cast the data first if you want to access `torch.Tensor.values`:

```python
torch.Tensor(tensor).values()
```

Not sure if that's the best way, but that would look like this:
pytorch wrapper class

```python
import functools
from typing import Optional, Tuple

import numpy as np
import torch
import xarray as xr


def wrap_torch(f):
    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        # TODO: use a dict comprehension if there are functions that rely on the order of the parameters
        if ""axis"" in kwargs:
            # torch calls that parameter 'dim' instead of 'axis'; axis=None (reduce over everything) is dropped
            axis = kwargs.pop(""axis"")
            if axis is not None:
                kwargs[""dim""] = axis
        if ""keepdims"" in kwargs:
            # ... and 'keepdim' instead of 'keepdims'
            kwargs[""keepdim""] = kwargs.pop(""keepdims"")

        return f(*args, **kwargs)

    return wrapper


class DTypeWrapper:
    def __init__(self, dtype):
        self.dtype = dtype
        if dtype.is_complex:
            self.kind = ""c""
        elif dtype.is_floating_point:
            self.kind = ""f""
        elif dtype == torch.bool:
            self.kind = ""b""
        elif dtype == torch.uint8:
            self.kind = ""u""
        else:
            # I don't know pytorch at all, so falling back to ""i"" might not be the best choice
            self.kind = ""i""

    def __getattr__(self, name):
        return getattr(self.dtype, name)

    def __repr__(self):
        return repr(self.dtype)


IMPLEMENTED_FUNCTIONS = {
    np.mean: wrap_torch(torch.mean),
    # torch.mean does not skip NaNs; use torch.nanmean where the installed pytorch provides it
    np.nanmean: wrap_torch(getattr(torch, ""nanmean"", torch.mean)),
}


class XArrayTensor(torch.Tensor):
    def __new__(cls, data=None, dims: Optional[Tuple[str, ...]] = None, requires_grad=False):
        if data is None:
            data = torch.Tensor()
        return torch.Tensor._make_subclass(cls, data, requires_grad)

    def __init__(self, data=None, dims: Optional[Tuple[str, ...]] = None, requires_grad=False):
        # same signature as __new__, so positional arguments line up
        self.dims = dims

    def __array_function__(self, func, types, args, kwargs):
        if func not in IMPLEMENTED_FUNCTIONS or any(not issubclass(t, torch.Tensor) for t in types):
            return NotImplemented
        return IMPLEMENTED_FUNCTIONS[func](*args, **kwargs)

    def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
        # NEP 13 signature: only plain calls of registered ufuncs are handled
        if method != ""__call__"" or ufunc not in IMPLEMENTED_FUNCTIONS:
            return NotImplemented
        return IMPLEMENTED_FUNCTIONS[ufunc](*inputs, **kwargs)

    def __getattribute__(self, name):
        if name == ""values"":
            raise AttributeError(
                ""'values' has been removed for compatibility with xarray.""
                "" To access it, use `torch.Tensor(tensor).values()`.""
            )
        return object.__getattribute__(self, name)

    @property
    def shape(self):
        return tuple(super().shape)

    @property
    def dtype(self):
        return DTypeWrapper(super().dtype)


tensor = XArrayTensor(torch.rand(3, 2))
print(tensor)
print(tensor.shape)
print(tensor.dtype)
print(tensor.ndim)

da = xr.DataArray(tensor, coords={""a"": [""a1"", ""a2"", ""a3""], ""b"": [""b1"", ""b2""]}, dims=[""a"", ""b""])
print(da)
print(da.data)
print(da.mean(dim=""a""))
```
With that, I can execute `mean` and get back a `torch.Tensor` wrapped by a `DataArray` without modifying the `xarray` code.

For a list of features that do not support duck arrays yet, see [Working with numpy-like arrays](https://xarray.pydata.org/en/stable/duckarrays.html) (that list should be pretty complete, but if you think there's something missing please open a new issue). For `np.mean(da)`: be aware that `DataArray` does not define `__array_function__` yet (see #3917), and that with `np.mean` you have to fall back to positional axes (`np.mean(da, axis=0)`) instead of dimension names (`da.mean(dim=""a"")`).

> If the API matching were complete, would the following be possible?

No, it won't be, because this is fragile: any new method of `DataArray` could shadow the methods of the wrapped object. Also, without tight integration `xarray` does not know what to do with the result, so you would always get the underlying data instead of a new `DataArray`. Instead, we recommend extension packages ([extending xarray](https://xarray.pydata.org/en/stable/internals.html#extending-xarray)), so with a hypothetical `xarray-pytorch` library you would write `some_sum.torch.backward()` instead of `some_sum.backward()`. That is a bit more work, but it also gives you a lot more control. For an example, see [pint-xarray](https://github.com/xarray-contrib/pint-xarray).","{""total_count"": 0, ""+1"": 0, ""-1"": 0, ""laugh"": 0, ""hooray"": 0, ""confused"": 0, ""heart"": 0, ""rocket"": 0, ""eyes"": 0}",,482543307
https://github.com/pydata/xarray/issues/3232#issuecomment-766470557,https://api.github.com/repos/pydata/xarray/issues/3232,766470557,MDEyOklzc3VlQ29tbWVudDc2NjQ3MDU1Nw==,14808389,2021-01-25T00:33:35Z,2021-01-25T00:33:35Z,MEMBER,"> Looks like you need to patch that internally just a bit, probably adding pytorch to NON_NUMPY_SUPPORTED_ARRAY_TYPES.

Defining `__array_function__` (and the other properties listed in the [docs](https://xarray.pydata.org/en/latest/internals.html)) should be enough: https://github.com/pydata/xarray/blob/a0c71c1508f34345ad7eef244cdbbe224e031c1b/xarray/core/variable.py#L232-L235","{""total_count"": 1, ""+1"": 1, ""-1"": 0, ""laugh"": 0, ""hooray"": 0, ""confused"": 0, ""heart"": 0, ""rocket"": 0, ""eyes"": 0}",,482543307
https://github.com/pydata/xarray/issues/3232#issuecomment-655751621,https://api.github.com/repos/pydata/xarray/issues/3232,655751621,MDEyOklzc3VlQ29tbWVudDY1NTc1MTYyMQ==,13301940,2020-07-08T20:54:15Z,2020-07-08T20:54:15Z,MEMBER,"> @jacobtomlinson gave CuPy a go a few months back. I seem to remember that he ran into a few problems but it would be good to get those documented here.

I've been test-driving xarray objects backed by CuPy arrays, and one issue I keep running into is that operations (such as plotting) that expect NumPy arrays fail due to xarray's implicit conversion to NumPy arrays via `np.asarray()`. CuPy decided not to allow implicit conversion to NumPy arrays (see https://github.com/cupy/cupy/pull/3421). I am wondering whether there is a plan for dealing with this issue.

Here's a small, reproducible example:

```python
[23]: ds.tmin.data.device

[24]: ds.isel(time=0, lev=0).tmin.plot()  # Fails
```
Traceback

```python
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
in
----> 1 ds.isel(time=0, lev=0).tmin.plot()

/glade/work/abanihi/softwares/miniconda3/envs/rapids/lib/python3.7/site-packages/xarray/plot/plot.py in __call__(self, **kwargs)
    444
    445     def __call__(self, **kwargs):
--> 446         return plot(self._da, **kwargs)
    447
    448     @functools.wraps(hist)

/glade/work/abanihi/softwares/miniconda3/envs/rapids/lib/python3.7/site-packages/xarray/plot/plot.py in plot(darray, row, col, col_wrap, ax, hue, rtol, subplot_kws, **kwargs)
    198         kwargs[""ax""] = ax
    199
--> 200     return plotfunc(darray, **kwargs)
    201
    202

/glade/work/abanihi/softwares/miniconda3/envs/rapids/lib/python3.7/site-packages/xarray/plot/plot.py in newplotfunc(darray, x, y, figsize, size, aspect, ax, row, col, col_wrap, xincrease, yincrease, add_colorbar, add_labels, vmin, vmax, cmap, center, robust, extend, levels, infer_intervals, colors, subplot_kws, cbar_ax, cbar_kwargs, xscale, yscale, xticks, yticks, xlim, ylim, norm, **kwargs)
    684
    685         # Pass the data as a masked ndarray too
--> 686         zval = darray.to_masked_array(copy=False)
    687
    688         # Replace pd.Intervals if contained in xval or yval.

/glade/work/abanihi/softwares/miniconda3/envs/rapids/lib/python3.7/site-packages/xarray/core/dataarray.py in to_masked_array(self, copy)
   2325             Masked where invalid values (nan or inf) occur.
   2326         """"""
-> 2327         values = self.values  # only compute lazy arrays once
   2328         isnull = pd.isnull(values)
   2329         return np.ma.MaskedArray(data=values, mask=isnull, copy=copy)

/glade/work/abanihi/softwares/miniconda3/envs/rapids/lib/python3.7/site-packages/xarray/core/dataarray.py in values(self)
    556     def values(self) -> np.ndarray:
    557         """"""The array's data as a numpy.ndarray""""""
--> 558         return self.variable.values
    559
    560     @values.setter

/glade/work/abanihi/softwares/miniconda3/envs/rapids/lib/python3.7/site-packages/xarray/core/variable.py in values(self)
    444     def values(self):
    445         """"""The variable's data as a numpy.ndarray""""""
--> 446         return _as_array_or_item(self._data)
    447
    448     @values.setter

/glade/work/abanihi/softwares/miniconda3/envs/rapids/lib/python3.7/site-packages/xarray/core/variable.py in _as_array_or_item(data)
    247     TODO: remove this (replace with np.asarray) once these issues are fixed
    248     """"""
--> 249     data = np.asarray(data)
    250     if data.ndim == 0:
    251         if data.dtype.kind == ""M"":

/glade/work/abanihi/softwares/miniconda3/envs/rapids/lib/python3.7/site-packages/numpy/core/_asarray.py in asarray(a, dtype, order)
     83
     84     """"""
---> 85     return array(a, dtype, copy=False, order=order)
     86
     87

ValueError: object __array__ method not producing an array
```
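
A minimal NumPy-only sketch of the failure mode described above, under stated assumptions: `DeviceOnlyArray` is a hypothetical stand-in for a CuPy array that refuses implicit conversion in `__array__` and only converts through an explicit `.get()`. The `np.asarray` call mirrors the `_as_array_or_item` frame in the traceback, which is what `.values` (and therefore `.plot()`) goes through.

```python
import numpy as np


class DeviceOnlyArray:
    '''Hypothetical stand-in for a GPU array that refuses implicit NumPy conversion.'''

    def __init__(self, data):
        # Pretend this buffer lives on the device
        self._buffer = np.asarray(data)

    def __array__(self, dtype=None, copy=None):
        # Implicit conversion (np.asarray, np.array, ...) is refused, like CuPy does
        raise TypeError('Implicit conversion to a NumPy array is not allowed; use .get()')

    def get(self):
        # Explicit device-to-host transfer returns a real ndarray
        return self._buffer.copy()


gpu_like = DeviceOnlyArray([[1.0, 2.0], [3.0, 4.0]])

# What the .values / plotting path effectively does: np.asarray(data)
try:
    np.asarray(gpu_like)
    implicit_ok = True
except TypeError as err:
    implicit_ok = False
    print('implicit conversion refused:', err)

# Explicit conversion works and gives plotting libraries a plain ndarray
host = gpu_like.get()
print(type(host).__name__, host.sum())
```

With real CuPy, `cupy.asnumpy(arr)` or `arr.get()` performs the explicit transfer, so one possible workaround (an assumption, not something xarray does for you) is to plot a host copy such as `da.copy(data=da.data.get())`.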
","{""total_count"": 1, ""+1"": 1, ""-1"": 0, ""laugh"": 0, ""hooray"": 0, ""confused"": 0, ""heart"": 0, ""rocket"": 0, ""eyes"": 0}",,482543307
https://github.com/pydata/xarray/issues/3232#issuecomment-606230158,https://api.github.com/repos/pydata/xarray/issues/3232,606230158,MDEyOklzc3VlQ29tbWVudDYwNjIzMDE1OA==,2443309,2020-03-30T20:27:32Z,2020-03-30T20:27:32Z,MEMBER,@jacobtomlinson gave CuPy a go a few months back. I seem to remember that he ran into a few problems but it would be good to get those documented here.,"{""total_count"": 0, ""+1"": 0, ""-1"": 0, ""laugh"": 0, ""hooray"": 0, ""confused"": 0, ""heart"": 0, ""rocket"": 0, ""eyes"": 0}",,482543307
https://github.com/pydata/xarray/issues/3232#issuecomment-606228143,https://api.github.com/repos/pydata/xarray/issues/3232,606228143,MDEyOklzc3VlQ29tbWVudDYwNjIyODE0Mw==,2448579,2020-03-30T20:24:08Z,2020-03-30T20:24:08Z,MEMBER,"Just chiming in quickly. I think there's definitely interest in doing this through NEP-18. It looks like CuPy has implemented `__array_function__` (https://docs-cupy.chainer.org/en/stable/reference/interoperability.html), so many things may ""just work"". There was some work earlier on plugging in `pydata/sparse`, and there is some ongoing work to plug in `pint`. With both these efforts, a lot of xarray's code should be ""backend-agnostic"", but it's not perfect.

Have you tried creating `DataArrays` with `cupy` arrays yet? I would just try things and see what works vs. what doesn't. Practically, our approach so far has been to add a number of xfailed tests (`test_sparse.py` and `test_units.py`) and slowly start fixing them.
So that's one way to proceed if you're up for it.","{""total_count"": 1, ""+1"": 1, ""-1"": 0, ""laugh"": 0, ""hooray"": 0, ""confused"": 0, ""heart"": 0, ""rocket"": 0, ""eyes"": 0}",,482543307
https://github.com/pydata/xarray/issues/3232#issuecomment-524420000,https://api.github.com/repos/pydata/xarray/issues/3232,524420000,MDEyOklzc3VlQ29tbWVudDUyNDQyMDAwMA==,1217238,2019-08-23T18:38:19Z,2019-08-23T18:38:19Z,MEMBER,"I have not thought too much about these yet. But I agree that they will probably require backend-specific logic to do efficiently.

On Fri, Aug 23, 2019 at 12:13 PM firdaus janoos wrote:

> While it is pretty straightforward to implement a lot of standard xarray
> operations with a pytorch / Jax backend (since they just fallback on native
> functions) - it will be interesting to think about how to implement rolling
> operations / expanding / exponential window in a way that is both efficient
> and maintains differentiability.
>
> Expanding and exponential window operations would be easy to do leveraging
> RNN semantics - but doing rolling using convolutions is going to be very
> inefficient.
>
> Do you have any thoughts on this?","{""total_count"": 0, ""+1"": 0, ""-1"": 0, ""laugh"": 0, ""hooray"": 0, ""confused"": 0, ""heart"": 0, ""rocket"": 0, ""eyes"": 0}",,482543307
https://github.com/pydata/xarray/issues/3232#issuecomment-524403160,https://api.github.com/repos/pydata/xarray/issues/3232,524403160,MDEyOklzc3VlQ29tbWVudDUyNDQwMzE2MA==,1217238,2019-08-23T17:45:54Z,2019-08-23T17:45:54Z,MEMBER,"Within a `jit`-compiled function, JAX's execution speed should be quite competitive on GPUs. It uses the XLA compiler, which was recently enabled by default in TensorFlow. For data loading and deep learning algorithms, take a look at the examples in the `notebooks` directory in the JAX repo.
The APIs for deep learning in JAX are still undergoing rapid development, so they are not quite as stable/usable as pytorch or keras yet, but they are quite capable. See `jax.experimental.stax` and [`tensor2tensor.trax`](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/trax) for examples.","{""total_count"": 1, ""+1"": 1, ""-1"": 0, ""laugh"": 0, ""hooray"": 0, ""confused"": 0, ""heart"": 0, ""rocket"": 0, ""eyes"": 0}",,482543307
https://github.com/pydata/xarray/issues/3232#issuecomment-522884516,https://api.github.com/repos/pydata/xarray/issues/3232,522884516,MDEyOklzc3VlQ29tbWVudDUyMjg4NDUxNg==,1217238,2019-08-20T07:07:18Z,2019-08-20T07:07:18Z,MEMBER,"> Implementing it has some backwards compat concerns as well, because people may be relying on `np.somefunc(some_torch_tensor)` to be coerced to `ndarray`.

Yes, this is a concern for JAX as well. This is a definite downside of reusing NumPy's existing namespace. It turns out even xarray was relying on this behavior with dask in at least one edge case: https://github.com/pydata/xarray/issues/3215","{""total_count"": 0, ""+1"": 0, ""-1"": 0, ""laugh"": 0, ""hooray"": 0, ""confused"": 0, ""heart"": 0, ""rocket"": 0, ""eyes"": 0}",,482543307
https://github.com/pydata/xarray/issues/3232#issuecomment-522820303,https://api.github.com/repos/pydata/xarray/issues/3232,522820303,MDEyOklzc3VlQ29tbWVudDUyMjgyMDMwMw==,1217238,2019-08-20T01:55:46Z,2019-08-20T01:55:46Z,MEMBER,"If pytorch implements overrides of NumPy's API via the [`__array_function__` protocol](https://www.numpy.org/neps/nep-0018-array-function-protocol.html), then this could work with minimal effort. We are already using this to support [sparse arrays](https://sparse.pydata.org/en/latest/) (this isn't an official release yet, but the functionality works in the development version). I think there has been some discussion about this, but I don't know the current status (CC @rgommers).
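
A minimal sketch of the `__array_function__` dispatch-table pattern from NEP 18, using only NumPy. `LoggingArray`, `HANDLED_FUNCTIONS` and `implements` are hypothetical names; a pytorch translation layer would call the matching torch function where this toy calls back into NumPy.

```python
import numpy as np

# Hypothetical registry of NumPy functions this duck array knows how to handle
HANDLED_FUNCTIONS = {}


def implements(numpy_function):
    '''Register an implementation of numpy_function for LoggingArray.'''
    def decorator(func):
        HANDLED_FUNCTIONS[numpy_function] = func
        return func
    return decorator


class LoggingArray:
    '''Toy duck array: wraps an ndarray and records which NumPy functions dispatched to it.'''

    def __init__(self, data):
        self.data = np.asarray(data)
        self.calls = []

    def __array_function__(self, func, types, args, kwargs):
        # Defer for functions without a registered implementation
        if func not in HANDLED_FUNCTIONS:
            return NotImplemented
        # Also defer if another, unknown array type takes part in the call
        if not all(issubclass(t, LoggingArray) for t in types):
            return NotImplemented
        self.calls.append(func)
        return HANDLED_FUNCTIONS[func](*args, **kwargs)


@implements(np.mean)
def mean(arr, axis=None, **kwargs):
    # The translation layer: a torch backend would call e.g. torch.mean(..., dim=axis) here
    return LoggingArray(np.mean(arr.data, axis=axis, **kwargs))


arr = LoggingArray([[1.0, 2.0], [3.0, 4.0]])

# Dispatches to mean() above instead of coercing arr to an ndarray
result = np.mean(arr, axis=0)
print(type(result).__name__, result.data)

# np.median is not registered, so NumPy raises instead of silently coercing
try:
    np.median(arr)
    median_refused = False
except TypeError as err:
    median_refused = True
    print('np.median:', err)
```

Returning `NotImplemented` for unregistered functions makes NumPy raise a `TypeError` rather than coerce to `ndarray`, which is why adopting the protocol can break callers that relied on `np.somefunc(some_torch_tensor)` returning an `ndarray`.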
The biggest challenge for pytorch would be defining the translation layer that implements NumPy's API. Personally, I think the most viable way to achieve seamless integration with deep learning libraries would be to support integration with [JAX](https://github.com/google/jax), which already implements NumPy's API almost exactly. I have an [experimental pull request](https://github.com/google/jax/pull/611) adding `__array_function__` to JAX, but it still needs a bit of work to finish it up, e.g., we probably want to hide this behind a flag at first.","{""total_count"": 0, ""+1"": 0, ""-1"": 0, ""laugh"": 0, ""hooray"": 0, ""confused"": 0, ""heart"": 0, ""rocket"": 0, ""eyes"": 0}",,482543307