html_url,issue_url,id,node_id,user,created_at,updated_at,author_association,body,reactions,performed_via_github_app,issue
https://github.com/pydata/xarray/issues/3232#issuecomment-851581057,https://api.github.com/repos/pydata/xarray/issues/3232,851581057,MDEyOklzc3VlQ29tbWVudDg1MTU4MTA1Nw==,14808389,2021-05-31T16:12:35Z,2021-06-01T20:01:07Z,MEMBER,"Changing the `xarray` internals is not much work: we need `xarray.core.utils.is_duck_array` to return true if the object defines either `__array_namespace__` or both `__array_ufunc__` and `__array_function__` (or all three), and we'd need a short test demonstrating that objects implementing only `__array_namespace__` survive unchanged when wrapped by an `xarray` object (e.g. something like `isinstance(xr.DataArray(pytorch_object).mean().data, torch.Tensor)`). We might still be a bit too early with this, though: the PR that adds `__array_namespace__` to `numpy` has not been merged into `numpy:main` yet.","{""total_count"": 1, ""+1"": 1, ""-1"": 0, ""laugh"": 0, ""hooray"": 0, ""confused"": 0, ""heart"": 0, ""rocket"": 0, ""eyes"": 0}",,482543307
https://github.com/pydata/xarray/issues/3232#issuecomment-851426576,https://api.github.com/repos/pydata/xarray/issues/3232,851426576,MDEyOklzc3VlQ29tbWVudDg1MTQyNjU3Ng==,14808389,2021-05-31T11:32:05Z,2021-05-31T11:32:05Z,MEMBER,"I don't, unfortunately (though there's the partial example in https://github.com/pydata/xarray/issues/3232#issuecomment-769789746). There is nothing usable right now, but the `pytorch` maintainers are currently looking into supporting `__array_namespace__` (NEP47). Once there has been sufficient progress in both [`numpy`](https://github.com/numpy/numpy/pull/18585) and [`pytorch`](https://github.com/pytorch/pytorch/issues/58743), we won't have to change much in xarray (i.e. allowing `__array_namespace__` as an alternative to `__array_ufunc__` / `__array_function__` for duck arrays) to make this work without any wrapper code.
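
A minimal sketch of the relaxed duck-array check described above, written in plain Python so it runs without `numpy`. It assumes the existing `is_duck_array` also requires `ndim`, `shape` and `dtype`; the function name `is_duck_array_sketch` and the dummy classes are hypothetical and only illustrate the logic, they are not xarray's actual implementation:

```python
from typing import Any


def is_duck_array_sketch(value: Any) -> bool:
    # Hypothetical stand-in for xarray.core.utils.is_duck_array (not xarray's code).
    # Assumed baseline: every duck array exposes ndim, shape and dtype.
    has_array_attrs = all(hasattr(value, attr) for attr in ('ndim', 'shape', 'dtype'))
    # Existing requirement: both NEP18 (__array_function__) and NEP13 (__array_ufunc__).
    has_numpy_protocols = hasattr(value, '__array_function__') and hasattr(value, '__array_ufunc__')
    # Relaxed requirement: NEP47 (__array_namespace__) is accepted on its own.
    has_array_api = hasattr(value, '__array_namespace__')
    return has_array_attrs and (has_numpy_protocols or has_array_api)


class _ArrayAttrs:
    # Hypothetical dummy carrying only the basic array attributes.
    ndim = 1
    shape = (3,)
    dtype = 'float32'


class NamespaceOnly(_ArrayAttrs):
    # Implements only NEP47, as a future pytorch tensor might.
    def __array_namespace__(self, *, api_version=None):
        return None


class FunctionOnly(_ArrayAttrs):
    # Implements __array_function__ but not __array_ufunc__.
    def __array_function__(self, func, types, args, kwargs):
        return NotImplemented


print(is_duck_array_sketch(NamespaceOnly()))  # True
print(is_duck_array_sketch(FunctionOnly()))   # False
```

Under this sketch, an object that only implements `__array_namespace__` passes, while one that has `__array_function__` but no `__array_ufunc__` is still rejected.
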
You (or anyone interested) might still want to maintain a ""pytorch-xarray"" convenience library to allow something like `arr.torch.grad(dim=""x"")`.","{""total_count"": 0, ""+1"": 0, ""-1"": 0, ""laugh"": 0, ""hooray"": 0, ""confused"": 0, ""heart"": 0, ""rocket"": 0, ""eyes"": 0}",,482543307
https://github.com/pydata/xarray/issues/3232#issuecomment-786599239,https://api.github.com/repos/pydata/xarray/issues/3232,786599239,MDEyOklzc3VlQ29tbWVudDc4NjU5OTIzOQ==,14808389,2021-02-26T11:47:55Z,2021-02-26T11:48:09Z,MEMBER,"@Duane321: with `xarray>=0.17.0` you should be able to remove the `__getattribute__` trick.","{""total_count"": 0, ""+1"": 0, ""-1"": 0, ""laugh"": 0, ""hooray"": 0, ""confused"": 0, ""heart"": 0, ""rocket"": 0, ""eyes"": 0}",,482543307
https://github.com/pydata/xarray/issues/3232#issuecomment-771066618,https://api.github.com/repos/pydata/xarray/issues/3232,771066618,MDEyOklzc3VlQ29tbWVudDc3MTA2NjYxOA==,14808389,2021-02-01T18:34:00Z,2021-02-01T23:39:51Z,MEMBER,"I can't reproduce that:
```python
In [4]: da.loc[""a1""]
Out[4]:
tensor([0.4793, 0.7493], dtype=torch.float32)
Coordinates:
    a
```

I added a `torch_array_type` to `pycompat.py`

`torch.Tensor` defines `values`, so the issue is this: https://github.com/pydata/xarray/blob/8cc34cb412ba89ebca12fc84f76a9e452628f1bc/xarray/core/variable.py#L221

@shoyer, any ideas?

For now, I guess we can remove it using `__getattribute__`. With that, you will have to cast the data first if you want to access `torch.Tensor.values`:
```python
torch.Tensor(tensor).values()
```

Not sure if that's the best way, but that would look like this:
pytorch wrapper class
```python
import functools
from typing import Optional, Tuple

import numpy as np
import torch
import xarray as xr
from IPython.display import display  # rich output in IPython/Jupyter; use print() elsewhere


def wrap_torch(f):
    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        # TODO: use a dict comprehension if there are functions that rely on the order of the parameters
        if ""axis"" in kwargs:
            kwargs[""dim""] = kwargs.pop(""axis"")  # torch calls that parameter 'dim' instead of 'axis'

        return f(*args, **kwargs)

    return wrapper


class DTypeWrapper:
    def __init__(self, dtype):
        self.dtype = dtype
        if dtype.is_complex:
            self.kind = ""c""
        elif dtype.is_floating_point:
            self.kind = ""f""
        else:
            # I don't know pytorch at all, so falling back to ""i"" might not be the best choice
            self.kind = ""i""

    def __getattr__(self, name):
        return getattr(self.dtype, name)

    def __repr__(self):
        return repr(self.dtype)


IMPLEMENTED_FUNCTIONS = {
    np.mean: wrap_torch(torch.mean),
    np.nanmean: wrap_torch(torch.mean),  # not sure if pytorch has a separate nanmean function
}


class XArrayTensor(torch.Tensor):
    def __new__(cls, data=None, dims=None, requires_grad=False):
        if data is None:
            data = torch.Tensor()
        return torch.Tensor._make_subclass(cls, data, requires_grad)

    def __init__(self, data=None, dims: Optional[Tuple[str, ...]] = None, requires_grad=False):
        self.dims = dims

    def __array_function__(self, func, types, args, kwargs):
        if func not in IMPLEMENTED_FUNCTIONS or any(not issubclass(t, torch.Tensor) for t in types):
            return NotImplemented
        return IMPLEMENTED_FUNCTIONS[func](*args, **kwargs)

    def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
        # NEP13 signature; return NotImplemented (not NotImplementedError) to defer to other types
        if method != ""__call__"" or ufunc not in IMPLEMENTED_FUNCTIONS:
            return NotImplemented
        return IMPLEMENTED_FUNCTIONS[ufunc](*inputs, **kwargs)

    def __getattribute__(self, name):
        if name == ""values"":
            raise AttributeError(
                ""'values' has been removed for compatibility with xarray.""
                "" To access it, use `torch.Tensor(tensor).values()`.""
            )
        return object.__getattribute__(self, name)

    @property
    def shape(self):
        return tuple(super().shape)

    @property
    def dtype(self):
        return DTypeWrapper(super().dtype)


tensor = XArrayTensor(torch.rand(3, 2))
display(tensor)
display(tensor.shape)
display(tensor.dtype)
display(tensor.ndim)

da = xr.DataArray(tensor, coords={""a"": [""a1"", ""a2"", ""a3""], ""b"": [""b1"", ""b2""]}, dims=[""a"", ""b""])
display(da)
display(da.data)
display(da.mean(dim=""a""))
```
With that, I can execute `mean` and get back a `torch.Tensor` wrapped by a `DataArray` without modifying the `xarray` code.

For a list of features that do not support duck arrays yet, see [Working with numpy-like arrays](https://xarray.pydata.org/en/stable/duckarrays.html) (that list should be fairly complete, but if you think something is missing, please open a new issue). For `np.mean(da)`: be aware that `DataArray` does not define `__array_function__` yet (see #3917), and that even once it does, you would have to write `np.mean(da, axis=0)` instead of `da.mean(dim=""a"")`.

> If the API matching were complete, would the following be possible?

No, it won't be, because this is fragile: any new method of `DataArray` could shadow the methods of the wrapped object. Also, without tight integration `xarray` does not know what to do with the result, so you would always get back the underlying data instead of a new `DataArray`. Instead, we recommend extension packages ([extending xarray](https://xarray.pydata.org/en/stable/internals.html#extending-xarray)): with a hypothetical `xarray-pytorch` library you would write `some_sum.torch.backward()` instead of `some_sum.backward()`. That is a bit more work, but it also gives you a lot more control. For an example, see [pint-xarray](https://github.com/xarray-contrib/pint-xarray).","{""total_count"": 0, ""+1"": 0, ""-1"": 0, ""laugh"": 0, ""hooray"": 0, ""confused"": 0, ""heart"": 0, ""rocket"": 0, ""eyes"": 0}",,482543307
https://github.com/pydata/xarray/issues/3232#issuecomment-766470557,https://api.github.com/repos/pydata/xarray/issues/3232,766470557,MDEyOklzc3VlQ29tbWVudDc2NjQ3MDU1Nw==,14808389,2021-01-25T00:33:35Z,2021-01-25T00:33:35Z,MEMBER,"> Looks like you need to patch that internally just a bit, probably adding pytorch to NON_NUMPY_SUPPORTED_ARRAY_TYPES.

Defining `__array_function__` (and the other attributes listed in the [docs](https://xarray.pydata.org/en/latest/internals.html)) should be enough: https://github.com/pydata/xarray/blob/a0c71c1508f34345ad7eef244cdbbe224e031c1b/xarray/core/variable.py#L232-L235","{""total_count"": 1, ""+1"": 1, ""-1"": 0, ""laugh"": 0, ""hooray"": 0, ""confused"": 0, ""heart"": 0, ""rocket"": 0, ""eyes"": 0}",,482543307