id,node_id,number,title,user,state,locked,assignee,milestone,comments,created_at,updated_at,closed_at,author_association,active_lock_reason,draft,pull_request,body,reactions,performed_via_github_app,state_reason,repo,type
561921094,MDU6SXNzdWU1NjE5MjEwOTQ=,3762,xarray groupby/map fails to parallelize,6491058,closed,1,,,4,2020-02-07T23:20:59Z,2023-09-15T15:52:42Z,2023-09-15T15:52:41Z,NONE,,,,"#### MCVE Code Sample

```python
import sys
import math
import logging
import dask.distributed  # the submodule must be imported for dask.distributed.Client
import xarray
import numpy

logger = logging.getLogger('main')

if __name__ == '__main__':
    logging.basicConfig(
        stream=sys.stdout,
        format='%(asctime)s %(levelname)-8s %(message)s',
        level=logging.INFO,
        datefmt='%Y-%m-%d %H:%M:%S')

    logger.info('Starting dask client')
    client = dask.distributed.Client()

    SIZE = 100000
    SONAR_BINS = 2000
    time = range(0, SIZE)
    upper_limit = numpy.random.randint(0, 10, (SIZE))
    lower_limit = numpy.random.randint(20, 30, (SIZE))
    sonar_data = numpy.random.randint(0, 255, (SIZE, SONAR_BINS))

    channel = xarray.Dataset({
            'upper_limit': (['time'], upper_limit, {'units': 'depth meters'}),
            'lower_limit': (['time'],  lower_limit, {'units': 'depth meters'}),
            'data': (['time', 'depth_bin'], sonar_data, {'units': 'amplitude'}),
        },
        coords={
            'depth_bin': (['depth_bin'], range(0,SONAR_BINS)),
            'time': (['time'], time)
        })

    logger.info('Get the overall min/max sonar range to normalize to (the adjusted range)')
    adjusted_min, adjusted_max = channel.upper_limit.min().values.item(), channel.lower_limit.max().values.item()
    adjusted_min = math.floor(adjusted_min)
    adjusted_max = math.ceil(adjusted_max)
    logger.info('adjusted_min: %s, adjusted_max: %s', adjusted_min, adjusted_max)

    bin_count = len(channel.depth_bin)
    logger.info('bin_count: %s', bin_count)

    adjusted_depth_per_bin = (adjusted_max - adjusted_min) / bin_count
    logger.info('adjusted_depth_per_bin: %s', adjusted_depth_per_bin)

    adjusted_bin_depths = [adjusted_min + (j * adjusted_depth_per_bin) for j in range(0, bin_count)]
    logger.info('adjusted_bin_depths[0]: %s ... [-1]: %s', adjusted_bin_depths[0], adjusted_bin_depths[-1])

    def Interp(ds):
        # Ideally, instead of interp we would use some kind of downsample-and-shift,
        # but that doesn't exist in xarray and interp is good enough for the moment

        # Progress logging, added only for debugging
        t = ds.time.values.item()
        if (t % 100) == 0:
            total = len(channel.time)
            perc = 100.0 * t / total
            logger.info('%s : %s of %s', perc, t, total)

        unadjusted_depth_amplitudes = ds.data
        unadjusted_min = ds.upper_limit.values.item()
        unadjusted_max = ds.lower_limit.values.item()
        unadjusted_depth_per_bin = (unadjusted_max - unadjusted_min) / bin_count

        # Fractional bin index in this row's (unadjusted) data for each adjusted bin j
        index_mapping = [((adjusted_min + (j * adjusted_depth_per_bin)) - unadjusted_min) / unadjusted_depth_per_bin
                         for j in range(0, bin_count)]
        adjusted_depth_amplitudes = unadjusted_depth_amplitudes.interp(coords={'depth_bin':index_mapping}, method='linear', assume_sorted=True)
        adjusted_depth_amplitudes = adjusted_depth_amplitudes.rename({'depth_bin':'depth'}).assign_coords({'depth':adjusted_bin_depths})

        #logger.info('%s, \n\tunadjusted_depth_amplitudes.values:%s\n\tunadjusted_min:%s\n\tunadjusted_max:%s\n\tunadjusted_depth_per_bin:%s\n\tindex_mapping:%s\n\tadjusted_depth_amplitudes:%s\n\tadjusted_depth_amplitudes.values:%s\n\n', ds, unadjusted_depth_amplitudes.values, unadjusted_min, unadjusted_max, unadjusted_depth_per_bin, index_mapping, adjusted_depth_amplitudes, adjusted_depth_amplitudes.values)
        return adjusted_depth_amplitudes

    # Split into chunks so the work could be performed in parallel.
    # This doesn't parallelize anything and only slows things down a lot.
    #logger.info('chunk')
    #channel = channel.chunk({'time':100})

    logger.info('groupby')
    g = channel.groupby('time')

    logger.info('do interp')
    normalized_depth_data = g.map(Interp)

    logger.info('done')
```
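For reference, here is the `index_mapping` line in `Interp` restated as a formula (my own notation, not from xarray). With $N$ = `bin_count`, $u_t$ and $l_t$ the `upper_limit` and `lower_limit` at time $t$, and $z^{a}_{\min}$, $z^{a}_{\max}$ = `adjusted_min`, `adjusted_max`, adjusted bin $j$ at time $t$ is the linear interpolation of `data[t, :]` at the fractional bin index

```math
i_{t,j} = \frac{z^{a}_{\min} + j\,\Delta^{a} - u_t}{\Delta_t},
\qquad
\Delta^{a} = \frac{z^{a}_{\max} - z^{a}_{\min}}{N},
\qquad
\Delta_t = \frac{l_t - u_t}{N},
\qquad
j = 0, \dots, N-1
```

Since $i_{t,j}$ depends on $t$ only through $u_t$ and $l_t$, each time step can be computed independently of all the others.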

#### Expected Output
I am fairly new to xarray, but I feel xarray could execute this example better than it currently does. As far as I can tell, each `map` call of the custom function above is independent of the others, so the calls should be parallelizable. I expected xarray to chunk the data in the backend and run the calls in parallel on dask. Instead, it is VERY slow even in the single-threaded case, and it doesn't seem to parallelize at all.

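On my hardware, each `map` call takes roughly 5 ms without the `chunk` call and roughly 70 ms with the (commented-out) `chunk` call in the code. With `SIZE = 100000` time steps, that is roughly 500 s (over 8 minutes) without chunking and roughly 7,000 s (almost 2 hours) with it.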
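One way to see how much of the per-call cost is per-group overhead is to express the same computation without `groupby`. Below is a minimal sketch, assuming plain NumPy and the same linear interpolation in bin-index space as `Interp`, that processes every time step at once. `normalize_depths` is a hypothetical name, not an xarray API, and I am assuming out-of-range points should become NaN, which is my understanding of what `interp` does by default.

```python
import numpy as np


def normalize_depths(sonar_data, upper_limit, lower_limit, adjusted_bin_depths):
    # Hypothetical vectorized equivalent of Interp for all time steps at once.
    # sonar_data: (n_time, n_bins); upper_limit, lower_limit: (n_time,);
    # adjusted_bin_depths: (n_out,). Returns an (n_time, n_out) float array.
    data = np.asarray(sonar_data, dtype=float)
    n_bins = data.shape[1]
    upper = np.asarray(upper_limit, dtype=float)[:, None]            # (n_time, 1)
    lower = np.asarray(lower_limit, dtype=float)[:, None]            # (n_time, 1)
    depths = np.asarray(adjusted_bin_depths, dtype=float)[None, :]   # (1, n_out)

    # unadjusted_depth_per_bin for every row at once
    depth_per_bin = (lower - upper) / n_bins
    # index_mapping for every row at once: fractional bin index of each adjusted depth
    idx = (depths - upper) / depth_per_bin                           # (n_time, n_out)

    # Left neighbour of each fractional index, clipped so i0 + 1 is always a valid bin
    i0 = np.clip(np.floor(idx), 0, n_bins - 2).astype(np.intp)
    frac = idx - i0
    left = np.take_along_axis(data, i0, axis=1)
    right = np.take_along_axis(data, i0 + 1, axis=1)
    out = left + frac * (right - left)

    # Assumption: points outside the original bin range become NaN
    out[(idx < 0) | (idx > n_bins - 1)] = np.nan
    return out
```

The result could then be wrapped back into a `DataArray` with dims `('time', 'depth')` and `adjusted_bin_depths` as the `depth` coordinate.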


#### Problem Description
The single-threaded performance is super slow, and the computation also fails to parallelize across the cores on my machine.
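What I expected `channel.chunk({'time': 100})` followed by `map` to do amounts to split/apply/combine along `time`. A minimal sketch of that structure, using the standard library `concurrent.futures.ThreadPoolExecutor` as a stand-in for dask (this is not what xarray does internally); `map_time_chunks` is a hypothetical helper and the row-wise function in the usage lines is only an illustration:

```python
from concurrent.futures import ThreadPoolExecutor

import numpy as np


def map_time_chunks(func, arrays, chunk_size=100, max_workers=4):
    # Hypothetical helper: split every array in `arrays` along axis 0 (time)
    # into blocks of `chunk_size` rows, call `func` on each block in a thread
    # pool, and concatenate the per-block results back along axis 0.
    n_time = len(arrays[0])

    def run_block(start):
        stop = start + chunk_size
        return func(*(a[start:stop] for a in arrays))

    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        # Executor.map yields results in input order, so blocks stay in time order
        blocks = list(pool.map(run_block, range(0, n_time, chunk_size)))
    return np.concatenate(blocks, axis=0)


# Usage: a row-wise function (scale each time step by its own factor)
rng = np.random.default_rng(0)
data = rng.integers(0, 255, (1000, 20)).astype(float)
scale = rng.uniform(1.0, 2.0, 1000)
result = map_time_chunks(lambda d, s: d / s[:, None], [data, scale], chunk_size=100)
```

Python threads only overlap while NumPy releases the GIL, so this shows the structure rather than guaranteeing a speedup; a process pool or dask would be the more realistic backend.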

For more background on what I am trying to do, I also asked a Stack Overflow question about how to reorganize the code to improve performance. However, I feel the current behavior is a performance bug (assuming I haven't done something completely wrong in the code).

https://stackoverflow.com/questions/60103317/can-the-performance-of-using-xarray-groupby-map-be-improved


#### Output of ``xr.show_versions()``
<details>
xarray.show_versions()

INSTALLED VERSIONS
------------------
commit: None
python: 3.7.6 | packaged by conda-forge | (default, Jan  7 2020, 21:48:41) [MSC v.1916 64 bit (AMD64)]
python-bits: 64
OS: Windows
OS-release: 10
machine: AMD64
processor: Intel64 Family 6 Model 142 Stepping 10, GenuineIntel
byteorder: little
LC_ALL: None
LANG: None
LOCALE: None.None
libhdf5: 1.10.4
libnetcdf: 4.6.1

xarray: 0.14.1
pandas: 0.25.3
numpy: 1.17.3
scipy: 1.3.1
netCDF4: 1.4.2
pydap: None
h5netcdf: None
h5py: None
Nio: None
zarr: None
cftime: 1.0.4.2
nc_time_axis: None
PseudoNetCDF: None
rasterio: 1.1.2
cfgrib: None
iris: None
bottleneck: None
dask: 2.9.1
distributed: 2.9.1
matplotlib: 3.1.1
cartopy: 0.17.0
seaborn: None
numbagg: None
setuptools: 44.0.0.post20200102
pip: 19.3.1
conda: None
pytest: None
IPython: 7.11.1
sphinx: None
</details>
","{""url"": ""https://api.github.com/repos/pydata/xarray/issues/3762/reactions"", ""total_count"": 1, ""+1"": 1, ""-1"": 0, ""laugh"": 0, ""hooray"": 0, ""confused"": 0, ""heart"": 0, ""rocket"": 0, ""eyes"": 0}",,completed,13221727,issue