
issue_comments: 769789746


html_url: https://github.com/pydata/xarray/issues/3232#issuecomment-769789746
issue_url: https://api.github.com/repos/pydata/xarray/issues/3232
id: 769789746
node_id: MDEyOklzc3VlQ29tbWVudDc2OTc4OTc0Ng==
user: 14808389
created_at: 2021-01-29T12:57:37Z
updated_at: 2021-01-29T15:22:01Z
author_association: MEMBER
body:

I added a torch_array_type to pycompat.py

torch.Tensor defines values, so the issue is this: https://github.com/pydata/xarray/blob/8cc34cb412ba89ebca12fc84f76a9e452628f1bc/xarray/core/variable.py#L221 @shoyer, any ideas?

For now, I guess we can remove it using __getattribute__. With that, you will have to cast the data first if you want to access torch.Tensor.values: torch.Tensor(tensor).values()

Not sure if that's the best way, but that would look like this:

pytorch wrapper class

```python
import functools
from typing import Tuple

import numpy as np
import torch
import xarray as xr
from IPython.display import display  # already available inside IPython / Jupyter


def wrap_torch(f):
    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        # TODO: use a dict comprehension if there are functions that rely on the order of the parameters
        if "axis" in kwargs:
            kwargs["dim"] = kwargs.pop("axis")  # torch calls that parameter 'dim' instead of 'axis'

        return f(*args, **kwargs)

    return wrapper


class DTypeWrapper:
    def __init__(self, dtype):
        self.dtype = dtype
        if dtype.is_complex:
            self.kind = "c"
        elif dtype.is_floating_point:
            self.kind = "f"
        else:
            # I don't know pytorch at all, so falling back to "i" might not be the best choice
            self.kind = "i"

    def __getattr__(self, name):
        return getattr(self.dtype, name)

    def __repr__(self):
        return repr(self.dtype)


IMPLEMENTED_FUNCTIONS = {
    np.mean: wrap_torch(torch.mean),
    np.nanmean: wrap_torch(torch.mean),  # not sure if pytorch has a separate nanmean function
}


class XArrayTensor(torch.Tensor):
    def __new__(cls, data=None, requires_grad=False):
        if data is None:
            data = torch.Tensor()
        return torch.Tensor._make_subclass(cls, data, requires_grad)

    def __init__(self, data=None, dims: Tuple[str] = None):
        self.dims = dims

    def __array_function__(self, func, types, args, kwargs):
        if func not in IMPLEMENTED_FUNCTIONS or any(not issubclass(t, torch.Tensor) for t in types):
            return NotImplemented
        return IMPLEMENTED_FUNCTIONS[func](*args, **kwargs)

    def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
        # NumPy calls this hook as (ufunc, method, *inputs, **kwargs)
        if (
            ufunc not in IMPLEMENTED_FUNCTIONS
            or method != "__call__"
            or any(not isinstance(x, torch.Tensor) for x in inputs)
        ):
            return NotImplemented  # the singleton, not the NotImplementedError class
        return IMPLEMENTED_FUNCTIONS[ufunc](*inputs, **kwargs)

    def __getattribute__(self, name):
        # hide 'values' so that xarray does not mistake the tensor for a pandas-like object
        if name == "values":
            raise AttributeError(
                "'values' has been removed for compatibility with xarray."
                " To access it, use `torch.Tensor(tensor).values()`."
            )
        return object.__getattribute__(self, name)

    @property
    def shape(self):
        return tuple(super().shape)

    @property
    def dtype(self):
        return DTypeWrapper(super().dtype)


tensor = XArrayTensor(torch.rand(3, 2))
display(tensor)
display(tensor.shape)
display(tensor.dtype)
display(tensor.ndim)

da = xr.DataArray(tensor, coords={"a": ["a1", "a2", "a3"], "b": ["b1", "b2"]}, dims=["a", "b"])
display(da)
display(da.data)
display(da.mean(dim="a"))
```

With that, I can execute mean and get back a torch.Tensor wrapped by a DataArray without modifying the xarray code. For a list of features where duck arrays are not yet supported, see Working with numpy-like arrays (that list should be pretty complete, but if you think something is missing, please open a new issue).

For np.mean(da): be aware that DataArray does not define __array_function__ yet (see #3917), and that with np.mean you have to fall back to np.mean(da, axis=0) instead of da.mean(dim="a").
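As a rough illustration of the dispatch mechanism mentioned here, the sketch below shows how NumPy hands np.mean over to an object that defines __array_function__. The Duck class, the HANDLED registry and the implements decorator are hypothetical names made up for this example; they are not part of xarray or pytorch.

```python
import numpy as np

HANDLED = {}  # hypothetical registry: numpy function -> replacement


def implements(np_func):
    """Register a replacement for a numpy function (illustrative helper)."""
    def decorator(f):
        HANDLED[np_func] = f
        return f
    return decorator


class Duck:
    """Toy 1-D container that opts in to NumPy's __array_function__ protocol."""

    def __init__(self, items):
        self.items = list(items)

    def __array_function__(self, func, types, args, kwargs):
        # Decline anything we have not registered, or unknown argument types
        if func not in HANDLED or not all(issubclass(t, Duck) for t in types):
            return NotImplemented
        return HANDLED[func](*args, **kwargs)


@implements(np.mean)
def duck_mean(arr, axis=None):
    # 'axis' arrives exactly as the caller passed it to np.mean
    return sum(arr.items) / len(arr.items)


duck = Duck([1.0, 2.0, 3.0])
mean_plain = np.mean(duck)
mean_axis = np.mean(duck, axis=0)

try:
    np.sum(duck)  # not registered, so NumPy raises TypeError
    unhandled_raises = False
except TypeError:
    unhandled_raises = True
```

The keyword arrives under NumPy's name (axis), which is why the wrapper above has to rename it to dim before calling torch.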

If the API matching were complete, would the following be possible?

No, it won't be, because this is fragile: any new method of DataArray could shadow the methods of the wrapped object. Also, without tight integration, xarray does not know what to do with the result, so you would always get the underlying data instead of a new DataArray.
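Both problems can be reproduced in plain Python. The following is only a sketch of a generic forwarding wrapper; FakeTensor and ForwardingWrapper are hypothetical stand-ins and do not reflect how DataArray is implemented.

```python
class FakeTensor:
    """Hypothetical stand-in for the wrapped array type."""

    def backward(self):
        return "FakeTensor.backward"

    def mean(self):
        return "FakeTensor.mean"

    def double(self):
        return FakeTensor()  # returns a new, unwrapped object


class ForwardingWrapper:
    """Hypothetical container that forwards unknown attributes to its data."""

    def __init__(self, data):
        self._data = data

    def mean(self):
        # The wrapper's own method: it shadows FakeTensor.mean
        return "ForwardingWrapper.mean"

    def __getattr__(self, name):
        # Only reached when normal attribute lookup on the wrapper fails
        return getattr(self._data, name)


wrapped = ForwardingWrapper(FakeTensor())

forwarded = wrapped.backward()  # forwarded to the wrapped object
shadowed = wrapped.mean()       # the wrapper's method wins
result = wrapped.double()       # plain FakeTensor, the wrapper is lost
```

If the wrapper later grew its own backward method, the forwarded call would silently change meaning, and the result of double shows that forwarded calls return the underlying data rather than a new wrapper.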

Instead, we recommend extension packages (extending xarray), so with a hypothetical xarray-pytorch library you would write some_sum.torch.backward() instead of some_sum.backward(). That is a bit more work, but it also gives you a lot more control. For an example, see pint-xarray.
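A minimal sketch of what such an accessor could look like, using xarray's register_dataarray_accessor. The accessor name "torch", the TorchAccessor class and its backward method are assumptions for illustration only; no such package is implied to exist, and the example uses numpy-backed data so it runs without pytorch.

```python
import numpy as np
import xarray as xr


@xr.register_dataarray_accessor("torch")  # hypothetical accessor name
class TorchAccessor:
    """Illustrative accessor that forwards backward() to the wrapped data."""

    def __init__(self, da):
        self._da = da

    def backward(self, *args, **kwargs):
        data = self._da.data
        # Only objects exposing .backward (e.g. a torch.Tensor) are accepted
        if not hasattr(data, "backward"):
            raise TypeError(
                f"expected tensor-like data, got {type(data).__name__}"
            )
        return data.backward(*args, **kwargs)


da = xr.DataArray(np.ones((3, 2)), dims=["a", "b"])
some_sum = da.sum()

try:
    some_sum.torch.backward()  # numpy data has no backward()
    rejected = False
except TypeError:
    rejected = True
```

Because everything lives under the .torch namespace, new DataArray methods cannot collide with it, and the accessor decides explicitly what to return.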

reactions:
{
    "total_count": 0,
    "+1": 0,
    "-1": 0,
    "laugh": 0,
    "hooray": 0,
    "confused": 0,
    "heart": 0,
    "rocket": 0,
    "eyes": 0
}
issue: 482543307