issue_comments: 175175494

html_url: https://github.com/pydata/xarray/issues/723#issuecomment-175175494
issue_url: https://api.github.com/repos/pydata/xarray/issues/723
id: 175175494
node_id: MDEyOklzc3VlQ29tbWVudDE3NTE3NTQ5NA==
user: 15167171
created_at: 2016-01-26T18:53:02Z
updated_at: 2016-01-26T18:53:02Z
author_association: NONE

Looks like it can perform a tensor dot on both dask-backed and plain (NumPy-backed) xarrays! However, dask apparently has not implemented tensordot over multiple axes, and the function also fails when taking a tensor dot between a dask-backed xarray and a NumPy-backed one. Neither of these cases worries me too much; hopefully they don't worry you either.
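Written out as sum-products, with A standing for `da` (dims x_trans, y_trans, imgID) and B for `dm` (dims models, y_trans, imgID) from the example, contracting over imgID alone and over imgID together with y_trans gives the following. The index letters are shorthand introduced here, not names from the code.

$$
P_{x,y,m,y'} = \sum_{i} A_{x,y,i}\, B_{m,y',i}
\qquad
Q_{x,m} = \sum_{i} \sum_{y} A_{x,y,i}\, B_{m,y,i}
$$

P keeps y (from `da`) and y' (from `dm`) as separate axes, so its shape is 6 × 5 × 20 × 5; Q sums over both shared dimensions and has shape 6 × 20.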

```python
from xarray import align, DataArray

# note: using private imports (e.g., from xarray.core) is definitely discouraged!
# this is not guaranteed to work in future versions of xarray
from xarray.core.ops import _dask_or_eager_func


def tensordot(a, b, dims):
    if not (isinstance(a, DataArray) and isinstance(b, DataArray)):
        raise ValueError('tensordot requires two DataArray arguments')

    # keep only the coordinate labels the two arrays share
    a, b = align(a, b, join='inner', copy=False)

    # positions of the contracted dimension(s) in each array
    axes = (a.get_axis_num(dims), b.get_axis_num(dims))
    # uses dask.array.tensordot for dask data, numpy.tensordot otherwise
    f = _dask_or_eager_func('tensordot', n_array_args=2)
    new_data = f(a.data, b.data, axes=axes)

    if isinstance(dims, str):
        dims = [dims]

    new_coords = a.coords.merge(b.coords).drop(dims)

    # drop the dims you are performing the sum product over
    new_dims = ([d for d in a.dims if d not in dims] +
                [d for d in b.dims if d not in dims])

    return DataArray(new_data, new_coords, new_dims)


import xarray as xr
import numpy as np

x_trans = np.linspace(-3, 3, 6)
y_trans = np.linspace(-3, 3, 5)
imgID = range(4)
da = xr.DataArray(np.ones((6, 5, 4)),
                  coords=[x_trans, y_trans, imgID],
                  dims=['x_trans', 'y_trans', 'imgID'])

models = range(20)
dm = xr.DataArray(np.ones((20, 5, 4)),
                  coords=[models, y_trans, imgID],
                  dims=['models', 'y_trans', 'imgID'])

# xarray tensordot
# (both inputs keep 'y_trans', so the result lists 'y_trans' twice)
proj_a = tensordot(da, dm, 'imgID')

# dask xarray tensor dot
da = da.chunk()
dm = dm.chunk()
proj_b = tensordot(da, dm, 'imgID')

# errors (these cases failed at the time of writing)

# multiple dims
proj_c = tensordot(da, dm, ['imgID', 'y_trans'])

# mixed types
da = da.chunk()
dm = dm.load()
proj_d = tensordot(da, dm, 'imgID')
```
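As a reference for what `proj_a` and `proj_c` should contain, the same contractions can be checked with plain NumPy, whose `np.tensordot` does accept multiple axes. This is a minimal sketch that assumes only NumPy and reuses the example's shapes; the variable names `a`, `b`, `single` and `multi` are illustrative, and `np.einsum` serves as an independent cross-check of the sum-products.

```python
import numpy as np

# Same shapes as the example: (x_trans, y_trans, imgID) and (models, y_trans, imgID)
a = np.ones((6, 5, 4))
b = np.ones((20, 5, 4))

# Contract over imgID only: axis 2 of a with axis 2 of b.
# Remaining axes come out in order: a's (x_trans, y_trans), then b's (models, y_trans).
single = np.tensordot(a, b, axes=(2, 2))

# Contract over imgID and y_trans together.
# The axis lists are paired element-wise: a's axis 2 with b's axis 2, a's axis 1 with b's axis 1.
multi = np.tensordot(a, b, axes=([2, 1], [2, 1]))

# The same sum-products written explicitly with einsum, as a cross-check.
single_ref = np.einsum('xyi,mzi->xymz', a, b)
multi_ref = np.einsum('xyi,myi->xm', a, b)

print(single.shape, multi.shape)  # (6, 5, 20, 5) (6, 20)
print(single[0, 0, 0, 0], multi[0, 0])  # 4.0 20.0
```

With all-ones inputs, every entry of the single-dimension result equals the length of imgID (4), and every entry of the two-dimension result equals 4 × 5 = 20.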

{
    "total_count": 0,
    "+1": 0,
    "-1": 0,
    "laugh": 0,
    "hooray": 0,
    "confused": 0,
    "heart": 0,
    "rocket": 0,
    "eyes": 0
}
issue: 128687346