issues: 1974350560

id: 1974350560
node_id: I_kwDOAMm_X851rjLg
number: 8402
title: `where` dtype upcast with numpy 2
user: 1828519
state: open
locked: 0
comments: 10
created_at: 2023-11-02T14:12:49Z
updated_at: 2024-04-15T19:18:49Z
author_association: CONTRIBUTOR

What happened?

I'm testing my code with numpy 2.0 and the current main branches of xarray and dask, and ran into a change that I guess is expected given the way xarray does things. I want to confirm that, since it could be unexpected for many users.

Calling `DataArray.where` on an integer array narrower than 64 bits, with a Python int as the new value, upcasts the result to 64-bit integers (the dtype numpy gives a Python int). With older versions of numpy the array's dtype was preserved. As far as I can tell the relevant xarray code hasn't changed, so this seems to be more about numpy making its type promotion more consistent.

The main problem seems to come down to:

https://github.com/pydata/xarray/blob/d933578ebdc4105a456bada4864f8ffffd7a2ced/xarray/core/duck_array_ops.py#L218

That line converts my scalar int input to a numpy array. Without that array conversion, numpy behaves as I expect. See the MCVE for the xarray-specific example, but here's the numpy equivalent:

```python
import numpy as np

a = np.zeros((2, 2), dtype=np.uint16)

# what I'm intending to do with my xarray `data_arr.where(cond, 2)`
np.where(a != 0, a, 2).dtype
# dtype('uint16')

# equivalent to what xarray does:
np.where(a != 0, a, np.asarray(2)).dtype
# dtype('int64')

# workaround, cast my scalar to a specific numpy type
np.where(a != 0, a, np.asarray(np.uint16(2))).dtype
# dtype('uint16')
```

From a numpy point of view, the second `where` call makes sense: two arrays have to be promoted to a common dtype before they can be combined. But from an xarray user's point of view, I'm passing a scalar, so I expect the same result as the first `where` call above.
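To make the difference between the two calls concrete, here is a small sketch that asks numpy directly which dtype it would pick in each case, using `np.result_type` and `np.promote_types`. The variable names are only illustrative, and the printed values depend on the installed numpy version, so none are asserted in the comments.

```python
import numpy as np

a = np.zeros((2, 2), dtype=np.uint16)

# Case 1: array combined with a bare Python int.
# For a small value such as 2, the array's dtype is kept.
scalar_case = np.result_type(a, 2)

# Case 2: array combined with a 0-d array built from the same int.
# Under numpy 2 the 0-d array's own dtype takes part in promotion;
# numpy 1.x treated a 0-d array like a scalar and looked at its value.
array_case = np.result_type(a, np.asarray(2))

# For comparison: promotion based on the two dtypes alone, ignoring values.
dtype_only = np.promote_types(a.dtype, np.asarray(2).dtype)

print(np.__version__, scalar_case, array_case, dtype_only)
```

If `array_case` matches `dtype_only`, the installed numpy is promoting the 0-d array by dtype, which is the behaviour described above for numpy 2.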

What did you expect to happen?

See above.

Minimal Complete Verifiable Example

```Python
import xarray as xr
import numpy as np

data_arr = xr.DataArray(np.array([1, 2], dtype=np.uint16))
print(data_arr.where(data_arr == 2, 3).dtype)
# int64
```
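As a numpy-only sketch of the workaround mentioned earlier (casting the scalar to the array's dtype before calling `where`), the helper below mirrors the MCVE data. `where_keep_dtype` is a hypothetical name, not part of numpy or xarray, and the range check is an added assumption so that out-of-range values raise instead of wrapping.

```python
import numpy as np


def where_keep_dtype(cond, data, other):
    """Hypothetical helper: np.where that keeps data's integer dtype for int scalars."""
    data = np.asarray(data)
    is_int_scalar = isinstance(other, (int, np.integer)) and not isinstance(other, bool)
    if is_int_scalar and np.issubdtype(data.dtype, np.integer):
        info = np.iinfo(data.dtype)
        # Assumption: refuse values the dtype cannot represent rather than wrap.
        if not (info.min <= other <= info.max):
            raise OverflowError(f"{other} does not fit in {data.dtype}")
        # Cast the scalar to the array's own scalar type before combining.
        other = data.dtype.type(other)
    return np.where(cond, data, other)


arr = np.array([1, 2], dtype=np.uint16)
result = where_keep_dtype(arr == 2, arr, 3)
print(result.dtype)
```

Because the scalar is cast to the array's dtype up front, both operands already share a dtype and there is nothing left for numpy to promote, whichever numpy version is installed.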

MVCE confirmation

  • [X] Minimal example — the example is as focused as reasonably possible to demonstrate the underlying issue in xarray.
  • [X] Complete example — the example is self-contained, including all data and the text of any traceback.
  • [X] Verifiable example — the example copy & pastes into an IPython prompt or Binder notebook, returning the result.
  • [X] New issue — a search of GitHub Issues suggests this is not a duplicate.
  • [X] Recent environment — the issue occurs with the latest version of xarray and its dependencies.

Relevant log output

No response

Anything else we need to know?

Numpy 1.x preserves the dtype.

```python
In [1]: import numpy as np

In [2]: np.asarray(2).dtype
Out[2]: dtype('int64')

In [3]: a = np.zeros((2, 2), dtype=np.uint16)

In [4]: np.where(a != 0, a, np.asarray(2)).dtype
Out[4]: dtype('uint16')

In [5]: np.where(a != 0, a, np.asarray(np.uint16(2))).dtype
Out[5]: dtype('uint16')
```

Environment

```
INSTALLED VERSIONS
------------------
commit: None
python: 3.11.4 | packaged by conda-forge | (main, Jun 10 2023, 18:08:17) [GCC 12.2.0]
python-bits: 64
OS: Linux
OS-release: 6.4.6-76060406-generic
machine: x86_64
processor: x86_64
byteorder: little
LC_ALL: None
LANG: en_US.UTF-8
LOCALE: ('en_US', 'UTF-8')
libhdf5: 1.14.2
libnetcdf: 4.9.2
xarray: 2023.10.2.dev21+gfcdc8102
pandas: 2.2.0.dev0+495.gecf449b503
numpy: 2.0.0.dev0+git20231031.42c33f3
scipy: 1.12.0.dev0+1903.18d0a2f
netCDF4: 1.6.5
pydap: None
h5netcdf: 1.2.0
h5py: 3.10.0
Nio: None
zarr: 2.16.1
cftime: 1.6.3
nc_time_axis: None
PseudoNetCDF: None
iris: None
bottleneck: 1.3.7.post0.dev7
dask: 2023.10.1+4.g91098a63
distributed: 2023.10.1+5.g76dd8003
matplotlib: 3.9.0.dev0
cartopy: None
seaborn: None
numbagg: None
fsspec: 2023.6.0
cupy: None
pint: 0.22
sparse: None
flox: None
numpy_groupies: None
setuptools: 68.0.0
pip: 23.2.1
conda: None
pytest: 7.4.0
mypy: None
IPython: 8.14.0
sphinx: 7.1.2
```
reactions:
{
    "url": "https://api.github.com/repos/pydata/xarray/issues/8402/reactions",
    "total_count": 0,
    "+1": 0,
    "-1": 0,
    "laugh": 0,
    "hooray": 0,
    "confused": 0,
    "heart": 0,
    "rocket": 0,
    "eyes": 0
}
repo: 13221727
type: issue
