# dask.optimize on xarray objects

*pydata/xarray issue #3698 (closed as completed). Opened 2020-01-15, closed 2020-09-20, 5 comments.*

I am trying to call `dask.optimize` on an xarray object before the graph gets too big, but I get odd errors. Simple examples are below; all of them work if I remove the `dask.optimize` step. cc @mrocklin @shoyer

### This works with dask arrays

```python
a = dask.array.ones((10, 5), chunks=(1, 3))
a = dask.optimize(a)[0]
a.compute()
```

### It works when a DataArray is constructed from a dask array

```python
da = xr.DataArray(a)
da = dask.optimize(da)[0]
da.compute()
```

### But it fails when creating a DataArray from a numpy array and then chunking it :man_shrugging:

```python
da = xr.DataArray(a.compute()).chunk({"dim_0": 5})
da = dask.optimize(da)[0]
da.compute()
```

This fails with:

```python
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input> in <module>
      1 da = xr.DataArray(a.compute()).chunk({"dim_0": 5})
      2 da = dask.optimize(da)[0]
----> 3 da.compute()

~/python/xarray/xarray/core/dataarray.py in compute(self, **kwargs)
    838         """
    839         new = self.copy(deep=False)
--> 840         return new.load(**kwargs)
    841
    842     def persist(self, **kwargs) -> "DataArray":

~/python/xarray/xarray/core/dataarray.py in load(self, **kwargs)
    812         dask.array.compute
    813         """
--> 814         ds = self._to_temp_dataset().load(**kwargs)
    815         new = self._from_temp_dataset(ds)
    816         self._variable = new._variable

~/python/xarray/xarray/core/dataset.py in load(self, **kwargs)
    659
    660         # evaluate all the dask arrays simultaneously
--> 661         evaluated_data = da.compute(*lazy_data.values(), **kwargs)
    662
    663         for k, data in zip(lazy_data, evaluated_data):

~/miniconda3/envs/dcpy_updated/lib/python3.7/site-packages/dask/base.py in compute(*args, **kwargs)
    434     keys = [x.__dask_keys__() for x in collections]
    435     postcomputes = [x.__dask_postcompute__() for x in collections]
--> 436     results = schedule(dsk, keys, **kwargs)
    437     return repack([f(r, *a) for r, (f, a) in zip(results, postcomputes)])
    438

~/miniconda3/envs/dcpy_updated/lib/python3.7/site-packages/dask/threaded.py in get(dsk, result, cache, num_workers, pool, **kwargs)
     79         get_id=_thread_get_id,
     80         pack_exception=pack_exception,
---> 81         **kwargs
     82     )
     83

~/miniconda3/envs/dcpy_updated/lib/python3.7/site-packages/dask/local.py in get_async(apply_async, num_workers, dsk, result, cache, get_id, rerun_exceptions_locally, pack_exception, raise_exception, callbacks, dumps, loads, **kwargs)
    484                     _execute_task(task, data)  # Re-execute locally
    485                 else:
--> 486                     raise_exception(exc, tb)
    487             res, worker_id = loads(res_info)
    488             state["cache"][key] = res

~/miniconda3/envs/dcpy_updated/lib/python3.7/site-packages/dask/local.py in reraise(exc, tb)
    314     if exc.__traceback__ is not tb:
    315         raise exc.with_traceback(tb)
--> 316     raise exc
    317
    318

~/miniconda3/envs/dcpy_updated/lib/python3.7/site-packages/dask/local.py in execute_task(key, task_info, dumps, loads, get_id, pack_exception)
    220     try:
    221         task, data = loads(task_info)
--> 222         result = _execute_task(task, data)
    223         id = get_id()
    224         result = dumps((result, id))

~/miniconda3/envs/dcpy_updated/lib/python3.7/site-packages/dask/core.py in _execute_task(arg, cache, dsk)
    117         func, args = arg[0], arg[1:]
    118         args2 = [_execute_task(a, cache) for a in args]
--> 119         return func(*args2)
    120     elif not ishashable(arg):
    121         return arg

TypeError: string indices must be integers
```

### And a different error when rechunking a dask-backed DataArray

```python
da = xr.DataArray(a).chunk({"dim_0": 5})
da = dask.optimize(da)[0]
da.compute()
```

```python
---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
<ipython-input> in <module>
      1 da = xr.DataArray(a).chunk({"dim_0": 5})
      2 da = dask.optimize(da)[0]
----> 3 da.compute()

~/python/xarray/xarray/core/dataarray.py in compute(self, **kwargs)
    838         """
    839         new = self.copy(deep=False)
--> 840         return new.load(**kwargs)
    841
    842     def persist(self, **kwargs) -> "DataArray":

~/python/xarray/xarray/core/dataarray.py in load(self, **kwargs)
    812         dask.array.compute
    813         """
--> 814         ds = self._to_temp_dataset().load(**kwargs)
    815         new = self._from_temp_dataset(ds)
    816         self._variable = new._variable

~/python/xarray/xarray/core/dataset.py in load(self, **kwargs)
    659
    660         # evaluate all the dask arrays simultaneously
--> 661         evaluated_data = da.compute(*lazy_data.values(), **kwargs)
    662
    663         for k, data in zip(lazy_data, evaluated_data):

~/miniconda3/envs/dcpy_updated/lib/python3.7/site-packages/dask/base.py in compute(*args, **kwargs)
    434     keys = [x.__dask_keys__() for x in collections]
    435     postcomputes = [x.__dask_postcompute__() for x in collections]
--> 436     results = schedule(dsk, keys, **kwargs)
    437     return repack([f(r, *a) for r, (f, a) in zip(results, postcomputes)])
    438

~/miniconda3/envs/dcpy_updated/lib/python3.7/site-packages/dask/threaded.py in get(dsk, result, cache, num_workers, pool, **kwargs)
     79         get_id=_thread_get_id,
     80         pack_exception=pack_exception,
---> 81         **kwargs
     82     )
     83

~/miniconda3/envs/dcpy_updated/lib/python3.7/site-packages/dask/local.py in get_async(apply_async, num_workers, dsk, result, cache, get_id, rerun_exceptions_locally, pack_exception, raise_exception, callbacks, dumps, loads, **kwargs)
    484                     _execute_task(task, data)  # Re-execute locally
    485                 else:
--> 486                     raise_exception(exc, tb)
    487             res, worker_id = loads(res_info)
    488             state["cache"][key] = res

~/miniconda3/envs/dcpy_updated/lib/python3.7/site-packages/dask/local.py in reraise(exc, tb)
    314     if exc.__traceback__ is not tb:
    315         raise exc.with_traceback(tb)
--> 316     raise exc
    317
    318

~/miniconda3/envs/dcpy_updated/lib/python3.7/site-packages/dask/local.py in execute_task(key, task_info, dumps, loads, get_id, pack_exception)
    220     try:
    221         task, data = loads(task_info)
--> 222         result = _execute_task(task, data)
    223         id = get_id()
    224         result = dumps((result, id))

~/miniconda3/envs/dcpy_updated/lib/python3.7/site-packages/dask/core.py in _execute_task(arg, cache, dsk)
    117         func, args = arg[0], arg[1:]
    118         args2 = [_execute_task(a, cache) for a in args]
--> 119         return func(*args2)
    120     elif not ishashable(arg):
    121         return arg

~/miniconda3/envs/dcpy_updated/lib/python3.7/site-packages/dask/array/core.py in concatenate3(arrays)
   4305     if not ndim:
   4306         return arrays
-> 4307     chunks = chunks_from_arrays(arrays)
   4308     shape = tuple(map(sum, chunks))
   4309

~/miniconda3/envs/dcpy_updated/lib/python3.7/site-packages/dask/array/core.py in chunks_from_arrays(arrays)
   4085
   4086     while isinstance(arrays, (list, tuple)):
-> 4087         result.append(tuple([shape(deepfirst(a))[dim] for a in arrays]))
   4088         arrays = arrays[0]
   4089         dim += 1
~/miniconda3/envs/dcpy_updated/lib/python3.7/site-packages/dask/array/core.py in <listcomp>(.0)
   4085
   4086     while isinstance(arrays, (list, tuple)):
-> 4087         result.append(tuple([shape(deepfirst(a))[dim] for a in arrays]))
   4088         arrays = arrays[0]
   4089         dim += 1

IndexError: tuple index out of range
```
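*Editor's note, not part of the original report:* since the plain dask-array case above optimizes cleanly, one possible workaround is to run `dask.optimize` on the DataArray's underlying `.data` and re-wrap the result, leaving the coordinate/index graphs (which the failing cases seem to trip over) untouched. This is a minimal sketch, not verified against the xarray/dask versions in this report; whether it actually avoids the bug is exactly what this issue is about.

```python
import dask
import dask.array
import xarray as xr

# Reproduce the failing setup: a chunked, dask-backed DataArray.
a = dask.array.ones((10, 5), chunks=(1, 3))
da = xr.DataArray(a.compute()).chunk({"dim_0": 5})

# Sketch of a workaround: optimize only the underlying dask array
# (the case shown to work above), then re-wrap it so dims/coords/attrs
# are preserved via DataArray.copy(data=...).
optimized = dask.optimize(da.data)[0]
da_opt = da.copy(data=optimized)

da_opt.compute()
```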