html_url,issue_url,id,node_id,user,created_at,updated_at,author_association,body,reactions,performed_via_github_app,issue
https://github.com/pydata/xarray/issues/3232#issuecomment-851581057,https://api.github.com/repos/pydata/xarray/issues/3232,851581057,MDEyOklzc3VlQ29tbWVudDg1MTU4MTA1Nw==,14808389,2021-05-31T16:12:35Z,2021-06-01T20:01:07Z,MEMBER,"Changing the `xarray` internals is not too much work: we need `xarray.core.utils.is_duck_array` to return `True` if the object defines either `__array_namespace__` or both `__array_ufunc__` and `__array_function__` (or all three), and we'd need a short test demonstrating that objects implementing only `__array_namespace__` survive unchanged when wrapped by an `xarray` object (i.e. something like `isinstance(xr.DataArray(pytorch_object).mean().data, torch.Tensor)`).
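A minimal sketch of the relaxed check, assuming it keeps the attribute-based test on `ndim`, `shape` and `dtype` and only changes which protocols are accepted. The name `is_duck_array_sketch` and the two test classes are hypothetical; this is not xarray's actual implementation:

```python
def is_duck_array_sketch(value) -> bool:
    # Every duck array still needs the basic array metadata
    if not all(hasattr(value, attr) for attr in ('ndim', 'shape', 'dtype')):
        return False
    # NEP 47: the array API namespace alone is sufficient ...
    if hasattr(value, '__array_namespace__'):
        return True
    # ... otherwise require both the NEP 13 and NEP 18 protocols
    return hasattr(value, '__array_ufunc__') and hasattr(value, '__array_function__')


class OnlyNamespace:
    # Hypothetical stand-in for an array that implements only NEP 47
    ndim = 1
    shape = (3,)
    dtype = 'float32'

    def __array_namespace__(self, *, api_version=None):
        return None


class OnlyUfunc:
    # Implements just one of the two older protocols, so it is rejected
    ndim = 1
    shape = (3,)
    dtype = 'float32'

    def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
        return NotImplemented
```

With this, `is_duck_array_sketch(OnlyNamespace())` is `True` while `is_duck_array_sketch(OnlyUfunc())` is `False`.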
We might still be a bit too early with this, though: the PR which adds `__array_namespace__` to `numpy` has not been merged into `numpy:main` yet.","{""total_count"": 1, ""+1"": 1, ""-1"": 0, ""laugh"": 0, ""hooray"": 0, ""confused"": 0, ""heart"": 0, ""rocket"": 0, ""eyes"": 0}",,482543307
https://github.com/pydata/xarray/issues/3232#issuecomment-851426576,https://api.github.com/repos/pydata/xarray/issues/3232,851426576,MDEyOklzc3VlQ29tbWVudDg1MTQyNjU3Ng==,14808389,2021-05-31T11:32:05Z,2021-05-31T11:32:05Z,MEMBER,"I don't, unfortunately (there's the partial example in https://github.com/pydata/xarray/issues/3232#issuecomment-769789746, though).
This isn't usable right now, but the `pytorch` maintainers are currently looking into supporting `__array_namespace__` (NEP 47). Once there has been sufficient progress in both [`numpy`](https://github.com/numpy/numpy/pull/18585) and [`pytorch`](https://github.com/pytorch/pytorch/issues/58743), we won't have to change much in xarray (mainly accepting `__array_namespace__` as an alternative to `__array_ufunc__` / `__array_function__` for duck arrays) to make this work without any wrapper code.
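Roughly, NEP 47 lets a library ask the array for its namespace instead of dispatching through `__array_function__` / `__array_ufunc__`. A minimal sketch, where `ToyArray` and `namespace_mean` are hypothetical names and the namespace is a stand-in built with `types.SimpleNamespace` (a real array would return its library's array API module):

```python
from types import SimpleNamespace


def namespace_mean(x, axis=None):
    # Ask the array object for its array API namespace (NEP 47) and call
    # that namespace's function, so the result stays the library's own type
    xp = x.__array_namespace__()
    return xp.mean(x, axis=axis)


class ToyArray:
    # Hypothetical 1-D array type that only implements __array_namespace__
    def __init__(self, items):
        self.items = list(items)

    def __array_namespace__(self, *, api_version=None):
        return SimpleNamespace(mean=_toy_mean)


def _toy_mean(x, axis=None):
    # Toy mean for the namespace above; returns a new one-element ToyArray
    return ToyArray([sum(x.items) / len(x.items)])


result = namespace_mean(ToyArray([1.0, 2.0, 3.0]))
print(type(result).__name__, result.items)  # ToyArray [2.0]
```

The result is still a `ToyArray`, which is the property the test described above would check for `torch.Tensor`.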
You (or anyone interested) might still want to maintain a ""pytorch-xarray"" convenience library to allow something like `arr.torch.grad(dim=""x"")`.","{""total_count"": 0, ""+1"": 0, ""-1"": 0, ""laugh"": 0, ""hooray"": 0, ""confused"": 0, ""heart"": 0, ""rocket"": 0, ""eyes"": 0}",,482543307
https://github.com/pydata/xarray/issues/3232#issuecomment-786599239,https://api.github.com/repos/pydata/xarray/issues/3232,786599239,MDEyOklzc3VlQ29tbWVudDc4NjU5OTIzOQ==,14808389,2021-02-26T11:47:55Z,2021-02-26T11:48:09Z,MEMBER,@Duane321: with `xarray>=0.17.0` you should be able to remove the `__getattribute__` trick.,"{""total_count"": 0, ""+1"": 0, ""-1"": 0, ""laugh"": 0, ""hooray"": 0, ""confused"": 0, ""heart"": 0, ""rocket"": 0, ""eyes"": 0}",,482543307
https://github.com/pydata/xarray/issues/3232#issuecomment-771066618,https://api.github.com/repos/pydata/xarray/issues/3232,771066618,MDEyOklzc3VlQ29tbWVudDc3MTA2NjYxOA==,14808389,2021-02-01T18:34:00Z,2021-02-01T23:39:51Z,MEMBER,"I can't reproduce that:
```python
In [4]: da.loc[""a1""]
Out[4]:
tensor([0.4793, 0.7493], dtype=torch.float32)
Coordinates:
    a
```
I added a `torch_array_type` to `pycompat.py`.
`torch.Tensor` defines a `values` method, so the problem is this line: https://github.com/pydata/xarray/blob/8cc34cb412ba89ebca12fc84f76a9e452628f1bc/xarray/core/variable.py#L221
@shoyer, any ideas?
For now, I guess we can hide it using `__getattribute__`. With that, you will have to cast the data first if you want to access `torch.Tensor.values`:
```python
torch.Tensor(tensor).values()
```
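To see why hiding `values` helps: assuming the linked line unwraps its input with something like `getattr(data, 'values', data)`, a tensor's bound `values` method would be picked up instead of the tensor itself. A pure-Python sketch, with hypothetical `TensorLike` / `HiddenValues` classes standing in for `torch.Tensor` and the wrapper, and a hypothetical `unwrap` standing in for that line:

```python
def unwrap(data):
    # Assumed behaviour of the linked xarray line: prefer `.values` if present
    return getattr(data, 'values', data)


class TensorLike:
    # Stand-in for torch.Tensor, which has a `values()` method
    def values(self):
        return 'sparse values'


class HiddenValues(TensorLike):
    # Same trick as the wrapper: make `values` look absent on instances
    def __getattribute__(self, name):
        if name == 'values':
            raise AttributeError(name)
        return object.__getattribute__(self, name)


plain = TensorLike()
hidden = HiddenValues()
print(callable(unwrap(plain)))  # True: the bound method, not the array
print(unwrap(hidden) is hidden)  # True: getattr falls back to the object
```

`getattr` with a default swallows the `AttributeError`, and the method is still reachable through the class (`TensorLike.values(hidden)`), which mirrors the cast described above.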
Not sure if that's the best way, but a complete pytorch wrapper class using this trick would look like this:
```python
import functools
from typing import Optional, Tuple

import numpy as np
import torch
import xarray as xr


def wrap_torch(f):
    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        # TODO: use a dict comprehension if there are functions that rely on the order of the parameters
        if ""axis"" in kwargs:
            kwargs[""dim""] = kwargs.pop(""axis"")  # torch calls that parameter 'dim' instead of 'axis'

        return f(*args, **kwargs)

    return wrapper


class DTypeWrapper:
    # exposes a numpy-style `kind` on top of a torch.dtype
    def __init__(self, dtype):
        self.dtype = dtype
        if dtype.is_complex:
            self.kind = ""c""
        elif dtype.is_floating_point:
            self.kind = ""f""
        else:
            # I don't know pytorch at all, so falling back to ""i"" might not be the best choice
            self.kind = ""i""

    def __getattr__(self, name):
        return getattr(self.dtype, name)

    def __repr__(self):
        return repr(self.dtype)


IMPLEMENTED_FUNCTIONS = {
    np.mean: wrap_torch(torch.mean),
    np.nanmean: wrap_torch(torch.mean),  # not sure if pytorch has a separate nanmean function
}


class XArrayTensor(torch.Tensor):
    def __new__(cls, data=None, requires_grad=False):
        if data is None:
            data = torch.Tensor()
        return torch.Tensor._make_subclass(cls, data, requires_grad)

    def __init__(self, data=None, dims: Optional[Tuple[str, ...]] = None):
        self.dims = dims

    def __array_function__(self, func, types, args, kwargs):
        if func not in IMPLEMENTED_FUNCTIONS or any(not issubclass(t, torch.Tensor) for t in types):
            return NotImplemented
        return IMPLEMENTED_FUNCTIONS[func](*args, **kwargs)

    def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
        # numpy calls this as (ufunc, method, *inputs, **kwargs) and expects
        # NotImplemented (not NotImplementedError) for unsupported ufuncs
        if method != ""__call__"" or ufunc not in IMPLEMENTED_FUNCTIONS:
            return NotImplemented
        return IMPLEMENTED_FUNCTIONS[ufunc](*inputs, **kwargs)

    def __getattribute__(self, name):
        # xarray unwraps objects with a `values` attribute, so hide it
        if name == ""values"":
            raise AttributeError(
                ""'values' has been removed for compatibility with xarray.""
                "" To access it, use `torch.Tensor(tensor).values()`.""
            )
        return object.__getattribute__(self, name)

    @property
    def shape(self):
        return tuple(super().shape)

    @property
    def dtype(self):
        return DTypeWrapper(super().dtype)


tensor = XArrayTensor(torch.rand(3, 2))
print(tensor)
print(tensor.shape)
print(tensor.dtype)
print(tensor.ndim)

da = xr.DataArray(tensor, coords={""a"": [""a1"", ""a2"", ""a3""], ""b"": [""b1"", ""b2""]}, dims=[""a"", ""b""])
print(da)
print(da.data)
print(da.mean(dim=""a""))
```
With that, I can call `mean` and get back a `torch.Tensor` wrapped by a `DataArray` without modifying the `xarray` code. For a list of features that don't support duck arrays yet, see [Working with numpy-like arrays](https://xarray.pydata.org/en/stable/duckarrays.html) (that list should be pretty complete, but if you think something is missing, please open a new issue).
For `np.mean(da)`: be aware that `DataArray` does not define `__array_function__` yet (see #3917), and that even once it does, `np.mean` only understands positional axes, so you would have to write `np.mean(da, axis=0)` instead of `da.mean(dim=""a"")`.
> If the API matching were complete, would the following be possible?
No, it won't, because this is fragile: any new method of `DataArray` could shadow a method of the wrapped object. Also, without tight integration, `xarray` does not know what to do with the result, so you would always get back the underlying data instead of a new `DataArray`.
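A toy illustration of both problems, using a hypothetical `ForwardingWrapper` that forwards unknown attributes to the wrapped object via `__getattr__` (this is not how `DataArray` works, just the approach being asked about):

```python
class Payload:
    # Stand-in for the wrapped array (e.g. a torch.Tensor)
    def backward(self):
        return 'payload.backward'

    def double(self):
        return Payload()


class ForwardingWrapper:
    # Hypothetical wrapper that forwards unknown attributes to `data`
    def __init__(self, data):
        self.data = data

    def __getattr__(self, name):
        # Only called when normal lookup on the wrapper fails
        return getattr(self.data, name)


class NewerForwardingWrapper(ForwardingWrapper):
    # A later version of the wrapper that happens to add its own `backward`
    def backward(self):
        return 'wrapper.backward'


old = ForwardingWrapper(Payload())
new = NewerForwardingWrapper(Payload())
print(old.backward())  # payload.backward
print(new.backward())  # wrapper.backward: the payload's method is shadowed
print(type(old.double()).__name__)  # Payload: the result is no longer wrapped
```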
Instead, we recommend extension packages ([extending xarray](https://xarray.pydata.org/en/stable/internals.html#extending-xarray)), so with a hypothetical `xarray-pytorch` library you would write `some_sum.torch.backward()` instead of `some_sum.backward()`. That is a bit more work, but it also gives you a lot more control. For an example, see [pint-xarray](https://github.com/xarray-contrib/pint-xarray).","{""total_count"": 0, ""+1"": 0, ""-1"": 0, ""laugh"": 0, ""hooray"": 0, ""confused"": 0, ""heart"": 0, ""rocket"": 0, ""eyes"": 0}",,482543307
https://github.com/pydata/xarray/issues/3232#issuecomment-766470557,https://api.github.com/repos/pydata/xarray/issues/3232,766470557,MDEyOklzc3VlQ29tbWVudDc2NjQ3MDU1Nw==,14808389,2021-01-25T00:33:35Z,2021-01-25T00:33:35Z,MEMBER,"> Looks like you need to patch that internally just a bit, probably adding pytorch to NON_NUMPY_SUPPORTED_ARRAY_TYPES.
defining `__array_function__` (and the other properties listed in the [docs](https://xarray.pydata.org/en/latest/internals.html)) should be enough: https://github.com/pydata/xarray/blob/a0c71c1508f34345ad7eef244cdbbe224e031c1b/xarray/core/variable.py#L232-L235
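For reference, a skeleton with the attributes in question, assuming the list in the docs is `ndim`, `shape`, `dtype`, `__array_function__` and `__array_ufunc__`. `MinimalDuckArray` is a hypothetical name, and its protocol methods simply decline every call:

```python
class MinimalDuckArray:
    # Hypothetical skeleton; a real duck array would store and compute data
    def __init__(self, shape, dtype):
        self._shape = tuple(shape)
        self._dtype = dtype

    @property
    def shape(self):
        return self._shape

    @property
    def ndim(self):
        return len(self._shape)

    @property
    def dtype(self):
        return self._dtype

    def __array_function__(self, func, types, args, kwargs):
        # NEP 18: return NotImplemented for functions we don't handle
        return NotImplemented

    def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
        # NEP 13: same for ufuncs
        return NotImplemented


REQUIRED = ('ndim', 'shape', 'dtype', '__array_function__', '__array_ufunc__')
arr = MinimalDuckArray((3, 2), 'float64')
print(all(hasattr(arr, attr) for attr in REQUIRED))  # True
```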
","{""total_count"": 1, ""+1"": 1, ""-1"": 0, ""laugh"": 0, ""hooray"": 0, ""confused"": 0, ""heart"": 0, ""rocket"": 0, ""eyes"": 0}",,482543307