html_url,issue_url,id,node_id,user,created_at,updated_at,author_association,body,reactions,performed_via_github_app,issue
https://github.com/pydata/xarray/issues/3232#issuecomment-524420000,https://api.github.com/repos/pydata/xarray/issues/3232,524420000,MDEyOklzc3VlQ29tbWVudDUyNDQyMDAwMA==,1217238,2019-08-23T18:38:19Z,2019-08-23T18:38:19Z,MEMBER,"I have not thought too much about these yet, but I agree that they will probably require backend-specific logic to do efficiently.
On Fri, Aug 23, 2019 at 12:13 PM firdaus janoos wrote:
> While it is pretty straightforward to implement a lot of standard xarray operations with a pytorch / JAX backend (since they just fall back on native functions), it will be interesting to think about how to implement rolling / expanding / exponential window operations in a way that is both efficient and maintains differentiability.
>
> Expanding and exponential window operations would be easy to do leveraging RNN semantics, but doing rolling using convolutions is going to be very inefficient.
>
> Do you have any thoughts on this?
","{""total_count"": 0, ""+1"": 0, ""-1"": 0, ""laugh"": 0, ""hooray"": 0, ""confused"": 0, ""heart"": 0, ""rocket"": 0, ""eyes"": 0}",,482543307
https://github.com/pydata/xarray/issues/3232#issuecomment-524403160,https://api.github.com/repos/pydata/xarray/issues/3232,524403160,MDEyOklzc3VlQ29tbWVudDUyNDQwMzE2MA==,1217238,2019-08-23T17:45:54Z,2019-08-23T17:45:54Z,MEMBER,"Within a `jit`-compiled function, JAX's execution speed should be quite competitive on GPUs. It uses the XLA compiler, which was recently enabled by default in TensorFlow.
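As a minimal sketch of what that looks like (a hypothetical example, not from this thread):

```python
# Minimal sketch: jit traces the Python function and compiles it with
# XLA, which can fuse the element-wise ops into a single kernel.
import jax
import jax.numpy as jnp

@jax.jit
def normalize(x):
    return (x - x.mean()) / x.std()

print(normalize(jnp.arange(5.0)))
```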
For data loading and deep learning algorithms, take a look at the examples in the `notebooks` directory in the JAX repo. The APIs for deep learning in JAX are still under rapid development, so they are not yet as stable or usable as pytorch or keras, but they are quite capable. See `jax.experimental.stax` and [`tensor2tensor.trax`](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/trax) for examples.","{""total_count"": 1, ""+1"": 1, ""-1"": 0, ""laugh"": 0, ""hooray"": 0, ""confused"": 0, ""heart"": 0, ""rocket"": 0, ""eyes"": 0}",,482543307
https://github.com/pydata/xarray/issues/3232#issuecomment-522884516,https://api.github.com/repos/pydata/xarray/issues/3232,522884516,MDEyOklzc3VlQ29tbWVudDUyMjg4NDUxNg==,1217238,2019-08-20T07:07:18Z,2019-08-20T07:07:18Z,MEMBER,"> Implementing it has some backwards compat concerns as well, because people may be relying on `np.somefunc(some_torch_tensor)` to be coerced to `ndarray`.
Yes, this is a concern for JAX as well. This is a definite downside of reusing NumPy's existing namespace.
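To make the concern concrete (a hypothetical sketch, assuming pytorch is installed; not code from this thread):

```python
# Hypothetical sketch of the coercion in question: a NumPy function
# applied to a torch tensor silently returns a plain ndarray today.
import numpy as np
import torch

t = torch.arange(3.0)
result = np.concatenate((t, t))
print(type(result))  # numpy.ndarray, not a torch tensor
```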
It turns out even xarray was relying on this behavior with dask in at least one edge case: https://github.com/pydata/xarray/issues/3215","{""total_count"": 0, ""+1"": 0, ""-1"": 0, ""laugh"": 0, ""hooray"": 0, ""confused"": 0, ""heart"": 0, ""rocket"": 0, ""eyes"": 0}",,482543307
https://github.com/pydata/xarray/issues/3232#issuecomment-522820303,https://api.github.com/repos/pydata/xarray/issues/3232,522820303,MDEyOklzc3VlQ29tbWVudDUyMjgyMDMwMw==,1217238,2019-08-20T01:55:46Z,2019-08-20T01:55:46Z,MEMBER,"If pytorch implements overrides of NumPy's API via the [`__array_function__` protocol](https://www.numpy.org/neps/nep-0018-array-function-protocol.html), then this could work with minimal effort. We are already using this to support [sparse arrays](https://sparse.pydata.org/en/latest/) (not in an official release yet, but the functionality is working in the development version).
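For reference, here is a toy sketch of how the protocol works (hypothetical names, not pytorch's or sparse's actual implementation; requires NumPy >= 1.17):

```python
# Toy sketch of the NEP 18 __array_function__ protocol: a minimal duck
# array that intercepts NumPy API calls instead of being coerced.
import numpy as np

HANDLED = {}

def implements(np_func):
    # register an override for a NumPy function (hypothetical helper)
    def decorator(impl):
        HANDLED[np_func] = impl
        return impl
    return decorator

class DuckArray:
    def __init__(self, data):
        self.data = np.asarray(data)

    def __array_function__(self, func, types, args, kwargs):
        if func not in HANDLED:
            return NotImplemented
        return HANDLED[func](*args, **kwargs)

@implements(np.sum)
def _sum(arr, **kwargs):
    return DuckArray(np.sum(arr.data, **kwargs))

print(np.sum(DuckArray([1, 2, 3])).data)  # dispatches to _sum, prints 6
```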
I think there has been some discussion about this, but I don't know the current status (CC @rgommers). The biggest challenge for pytorch would be defining the translation layer that implements NumPy's API.
Personally, I think the most viable way to achieve seamless integration with deep learning libraries would be to support [JAX](https://github.com/google/jax), which already implements NumPy's API almost exactly. I have an [experimental pull request](https://github.com/google/jax/pull/611) adding `__array_function__` to JAX, but it still needs a bit of work to finish, e.g., we probably want to hide it behind a flag at first.","{""total_count"": 0, ""+1"": 0, ""-1"": 0, ""laugh"": 0, ""hooray"": 0, ""confused"": 0, ""heart"": 0, ""rocket"": 0, ""eyes"": 0}",,482543307