issue_comments: 769789746
This data as json
html_url | issue_url | id | node_id | user | created_at | updated_at | author_association | body | reactions | performed_via_github_app | issue |
---|---|---|---|---|---|---|---|---|---|---|---|
https://github.com/pydata/xarray/issues/3232#issuecomment-769789746 | https://api.github.com/repos/pydata/xarray/issues/3232 | 769789746 | MDEyOklzc3VlQ29tbWVudDc2OTc4OTc0Ng== | 14808389 | 2021-01-29T12:57:37Z | 2021-01-29T15:22:01Z | MEMBER |
For now, I guess we can remove it using Not sure if that's the best way, but that would look like this:

<tt>pytorch</tt> wrapper class

```python
In [13]: import numpy as np
    ...: import torch
    ...: from typing import Tuple
    ...: import xarray as xr
    ...: import functools
    ...:
    ...: def wrap_torch(f):
    ...:     @functools.wraps(f)
    ...:     def wrapper(*args, **kwargs):
    ...:         # TODO: use a dict comprehension if there are functions that rely on the order of the parameters
    ...:         if "axis" in kwargs:
    ...:             kwargs["dim"] = kwargs.pop("axis") # torch calls that parameter 'dim' instead of 'axis'
    ...:
    ...:         return f(*args, **kwargs)
    ...:
    ...:     return wrapper
    ...:
    ...: class DTypeWrapper:
    ...:     def __init__(self, dtype):
    ...:         self.dtype = dtype
    ...:         if dtype.is_complex:
    ...:             self.kind = "c"
    ...:         elif dtype.is_floating_point:
    ...:             self.kind = "f"
    ...:         else:
    ...:             # I don't know pytorch at all, so falling back to "i" might not be the best choice
    ...:             self.kind = "i"
    ...:
    ...:     def __getattr__(self, name):
    ...:         return getattr(self.dtype, name)
    ...:
    ...:     def __repr__(self):
    ...:         return repr(self.dtype)
    ...:
    ...: IMPLEMENTED_FUNCTIONS = {
    ...:     np.mean: wrap_torch(torch.mean),
    ...:     np.nanmean: wrap_torch(torch.mean), # not sure if pytorch has a separate nanmean function
    ...: }
    ...:
    ...: class XArrayTensor(torch.Tensor):
    ...:     def __new__(cls, data=None, requires_grad=False):
    ...:         if data is None:
    ...:             data = torch.Tensor()
    ...:         return torch.Tensor._make_subclass(cls, data, requires_grad)
    ...:
    ...:     def __init__(self, data=None, dims: Tuple[str] = None):
    ...:         self.dims = dims
    ...:
    ...:     def __array_function__(self, func, types, args, kwargs):
    ...:         if func not in IMPLEMENTED_FUNCTIONS or any(not issubclass(t, torch.Tensor) for t in types):
    ...:             return NotImplemented
    ...:         return IMPLEMENTED_FUNCTIONS[func](*args, **kwargs)
    ...:
    ...:     def __array_ufunc__(self, func, types, args, kwargs):
    ...:         if func not in IMPLEMENTED_FUNCTIONS or any(not issubclass(t, torch.Tensor) for t in types):
    ...:             return NotImplementedError
    ...:         return IMPLEMENTED_FUNCTIONS[func](*args, **kwargs)
    ...:
    ...:     def __getattribute__(self, name):
    ...:         if name == "values":
    ...:             raise AttributeError(
    ...:                 "'values' has been removed for compatibility with xarray."
    ...:                 " To access it, use `torch.Tensor(tensor).values()`."
    ...:             )
    ...:         return object.__getattribute__(self, name)
    ...:
    ...:     @property
    ...:     def shape(self):
    ...:         return tuple(super().shape)
    ...:
    ...:     @property
    ...:     def dtype(self):
    ...:         return DTypeWrapper(super().dtype)
    ...:
    ...: tensor = XArrayTensor(torch.rand(3, 2))
    ...: display(tensor)
    ...: display(tensor.shape)
    ...: display(tensor.dtype)
    ...: display(tensor.ndim)
    ...:
    ...: da = xr.DataArray(tensor, coords={"a": ["a1", "a2", "a3"], "b": ["b1", "b2"]}, dims=["a", "b"])
    ...: display(da)
    ...: display(da.data)
    ...: display(da.mean(dim="a"))
```

with that, I can execute For
no, it won't be because this is fragile: any new method of Instead, we recommend extension packages (extending xarray), so with a hypothetical |
{ "total_count": 0, "+1": 0, "-1": 0, "laugh": 0, "hooray": 0, "confused": 0, "heart": 0, "rocket": 0, "eyes": 0 } |
482543307 |