issue_comments: 768529007

html_url: https://github.com/pydata/xarray/issues/3232#issuecomment-768529007
issue_url: https://api.github.com/repos/pydata/xarray/issues/3232
id: 768529007
node_id: MDEyOklzc3VlQ29tbWVudDc2ODUyOTAwNw==
user: 19956442
created_at: 2021-01-27T19:39:32Z
updated_at: 2021-01-29T22:37:28Z
author_association: NONE

I've made some modest progress, but it raises a few questions. I've defined this simple `torch.Tensor` subclass, which meets the duck array criteria:

```
from typing import Optional, Tuple

import torch

# Maps NumPy functions/ufuncs (e.g. np.mean) to API-compatible torch implementations.
IMPLEMENTED_FUNCTIONS = {}


class XArrayTensor(torch.Tensor):
    def __new__(cls, data=None, dims: Optional[Tuple[str, ...]] = None, requires_grad=False):
        if data is None:
            data = torch.Tensor()
        return torch.Tensor._make_subclass(cls, data, requires_grad)

    def __init__(self, data=None, dims: Optional[Tuple[str, ...]] = None, requires_grad=False):
        self.dims = dims

    def __array_function__(self, func, types, args, kwargs):
        # Defer for unregistered functions or when a non-tensor type is involved.
        if func not in IMPLEMENTED_FUNCTIONS or not all(issubclass(t, torch.Tensor) for t in types):
            return NotImplemented
        return IMPLEMENTED_FUNCTIONS[func](*args, **kwargs)

    def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
        # Only plain calls such as np.add(a, b) are handled, not reduce/accumulate.
        if method != "__call__" or ufunc not in IMPLEMENTED_FUNCTIONS:
            return NotImplemented
        return IMPLEMENTED_FUNCTIONS[ufunc](*inputs, **kwargs)
```

where `IMPLEMENTED_FUNCTIONS` maps NumPy functions to API-compatible tensor operations (similar in style to this).

I added a `torch_array_type` to `pycompat.py`, which allows a `DataArray`'s `.data` attribute to persist as an `XArrayTensor`:

```
import torch
import xarray as xr

xr_tsr = XArrayTensor(torch.rand(3, 2))

data_array = xr.DataArray(
    xr_tsr,
    coords=dict(a=["a1", "a2", "a3"], b=["b1", "b1"]),
    dims=["a", "b"],
    name="dummy",
    attrs={"grad": xr_tsr.grad},
)
print(type(data_array.data))  # --> yields 'xarray_tensor.XArrayTensor'
```
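Whether `.data` is kept as-is depends on xarray recognising the object as a duck array. As I understand xarray's `is_duck_array` helper from around this time, the criteria amount to a few attribute checks; the function below is a hypothetical restatement of them, not xarray's code. Since `torch.Tensor` already exposes `ndim`, `shape` and `dtype`, the missing pieces are `__array_function__` and `__array_ufunc__`, which is why the subclass defines both.

```python
import numpy as np


def looks_like_duck_array(value) -> bool:
    """Hypothetical restatement of the attribute checks behind a duck array test."""
    if isinstance(value, np.ndarray):
        return True
    return all(
        hasattr(value, attr)
        for attr in ("ndim", "shape", "dtype", "__array_function__", "__array_ufunc__")
    )


class OnlyArrayFunction:
    """Has shape metadata and __array_function__, but no __array_ufunc__."""

    ndim = 1
    shape = (3,)
    dtype = np.dtype("float64")

    def __array_function__(self, func, types, args, kwargs):
        return NotImplemented


print(looks_like_duck_array(np.zeros(3)))          # True
print(looks_like_duck_array(OnlyArrayFunction()))  # False
```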

The problem I'm running into is that when I call an operation like `np.mean(data_array)`, it gets dispatched to functions in `duck_array_ops.py`, and those are what I'd like to override.
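One plausible reason the call lands in `duck_array_ops.py` before the tensor is ever involved: when `np.mean` receives an object that is not an `ndarray` and does not implement `__array_function__`, NumPy falls back to calling that object's own `.mean(axis=..., dtype=..., out=...)`. Assuming `DataArray` does not define `__array_function__` (my understanding, not checked against every version), `np.mean(data_array)` becomes `data_array.mean(...)`, and xarray's reduction machinery takes over from there. The hypothetical `RecordingArray` below shows the NumPy side of this.

```python
import numpy as np


class RecordingArray:
    """Hypothetical array-like with its own .mean but no __array_function__."""

    def __init__(self):
        self.calls = []

    def mean(self, axis=None, dtype=None, out=None, **kwargs):
        # Record exactly which arguments NumPy forwards to this method.
        self.calls.append({"axis": axis, "dtype": dtype, "out": out, **kwargs})
        return "RecordingArray.mean was called"


arr = RecordingArray()
print(np.mean(arr))  # RecordingArray.mean was called
print(arr.calls)     # [{'axis': None, 'dtype': None, 'out': None}]
```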

Also, I'd like to confirm something. If the API matching were complete, would the following be possible?

```
some_sum = data_array.sum()
some_sum.backward()
data_array.grad  # --> provides the gradient
```

I'm starting to suspect not, because that would require `data_array` to be both a `DataArray` and a `torch.Tensor`. It seems that what I'm actually enabling is for `DataArray.data` to be a `torch.Tensor`.
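If that reading is right, the container does not need to be a tensor for gradients to work: `backward()` is called on the wrapped tensor and the gradient is read from the leaf tensor's `.grad`. The sketch below uses a hypothetical `LabeledTensor` as a stand-in for `DataArray` (it is not xarray). For the real case it assumes the reductions keep returning the tensor with its autograd graph intact, which is exactly what complete API matching would have to guarantee; the calls would then be `some_sum.data.backward()` and `data_array.data.grad`.

```python
import torch


class LabeledTensor:
    """Hypothetical stand-in for a DataArray: dims plus a tensor stored in `.data`."""

    def __init__(self, data: torch.Tensor, dims):
        self.data = data
        self.dims = tuple(dims)

    def sum(self) -> "LabeledTensor":
        # Reduce over all dimensions; the result wraps a 0-d tensor with no dims.
        return LabeledTensor(self.data.sum(), dims=())


leaf = torch.rand(3, 2, requires_grad=True)
labeled = LabeledTensor(leaf, dims=("a", "b"))

some_sum = labeled.sum()
some_sum.data.backward()  # backward() belongs to the tensor, not the container

print(labeled.data.grad)  # d(sum)/d(leaf): a 3x2 tensor of ones
```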

{
    "total_count": 2,
    "+1": 0,
    "-1": 0,
    "laugh": 0,
    "hooray": 0,
    "confused": 0,
    "heart": 2,
    "rocket": 0,
    "eyes": 0
}
issue: 482543307