issues: 1525546857


| field | value |
|---|---|
| id | 1525546857 |
| node_id | I_kwDOAMm_X85a7f9p |
| number | 7429 |
| title | Training on xarray files leads to CPU memory leak (PyTorch) |
| user | 7348840 |
| state | closed |
| locked | 0 |
| comments | 2 |
| created_at | 2023-01-09T12:57:23Z |
| updated_at | 2023-01-13T13:17:43Z |
| closed_at | 2023-01-13T13:17:42Z |
| author_association | NONE |
| state_reason | completed |
| repo | 13221727 |
| type | issue |

What happened?

Description

At each training batch, CPU memory increases a little, until I run out of memory (total RAM: 376 GB). After that, I cannot even SSH into the machine where the Jupyter notebook was served (nor see any errors it caused). I cannot understand where or why memory is being cached forever on the CPU: my training is done on the GPU.

I have written minimal reproduction code, which I share below. Two dataset versions are tested: one uses xarray files, the other numpy .npz files. The bug is reproduced only with the xarray dataset.

Setup

The cluster I use is managed by SLURM and sits on a Lustre filesystem. Training is performed on an NVIDIA GPU. My Python 3.8.15 comes from an official Debian Bullseye Docker image, which I pulled and now run via Singularity (a necessity, as my cluster does not allow Docker directly).

Dependencies

  • numpy==1.22.4 (also tested on 1.23.5)
  • xarray==2022.11.0 (also tested on 2022.12.0)
  • h5netcdf==1.0.2 (also tested on 1.1.0)
  • torch==1.13.0 (also tested on 1.13.1)

Current workaround

One workaround we use is loading data with more workers. Memory is then forced to be freed when an epoch ends, presumably because the worker processes are terminated. By also limiting how many batches are trained per epoch, I can keep the leak under control for a while (a sketch follows below).
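For reference, a minimal sketch of that workaround (the worker count of 4 is illustrative; `persistent_workers=False` is PyTorch's default, and it is what lets the worker processes, and whatever they have cached, be discarded at the end of each epoch):

```Python
import torch

# Hypothetical stand-in dataset; in practice this is the XarrayDataset below.
dataset = torch.utils.data.TensorDataset(torch.zeros(1000, 1))

# Workaround: move loading into worker processes; with the default
# persistent_workers=False they are torn down at the end of every epoch,
# releasing whatever memory they accumulated.
loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=64,
    num_workers=4,  # illustrative; > 0 puts the leak inside short-lived workers
    persistent_workers=False,
)
```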

This issue is based on: Unexpected eternal CPU RAM growth during training #16227

What did you expect to happen?

Reading xarray files should keep data in memory only until the data has been read and the reader is closed. For some reason, the data seems to remain cached somewhere.
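For context, a hedged sketch of how one might try to rule out xarray's in-memory caching (`cache=False` and the context-manager form are real xarray API; whether they actually avoid this leak is exactly what this issue is about — the path below refers to the file created by the reproduction code):

```Python
import xarray as xr

# Open without caching variable values as NumPy arrays, read eagerly, close.
with xr.open_dataset("data/data.nc", cache=False) as ds:
    x = ds["x"].values.copy()  # copy so nothing keeps referencing file buffers
# Leaving the `with` block closes the file; no data should remain cached.
```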

Minimal Complete Verifiable Example

```Python
# Imports

from pathlib import Path

import psutil
import numpy as np
import xarray as xr
import torch

data_dir = Path.cwd() / "data"


# Defining equivalent XArray and NPZ datasets

class BaseDataset(torch.utils.data.Dataset):

    def __init__(self,
                 data_dir=None,
                 transform=None,
                 shape=(1, 128, 128, 128),
                 size=1000):
        self.data_dir = Path(data_dir)
        self.transform = transform
        self.shape = shape
        self.size = size
        self.prepared = False

    def __len__(self):
        return self.size

    def get_fake_sample(self):
        x = np.random.normal(size=self.shape).astype(np.float32)
        y = (x > .7).astype(np.int8)
        return {"x": x, "y": y}


class XarrayDataset(BaseDataset):

    def __getitem__(self, idx):
        ds = xr.open_dataset(self.data_path)
        sample = {"x": torch.as_tensor(ds["x"].data),
                  "y": torch.as_tensor(ds["y"].data)}
        ds.close()
        if self.transform:
            sample = self.transform(sample)
        return sample

    @property
    def data_path(self):
        return self.data_dir / "data.nc"

    def prepare_data(self):
        if self.data_path.exists():
            return
        self.data_dir.mkdir(exist_ok=True)

        sample = self.get_fake_sample()
        ds = xr.Dataset({
            var: xr.DataArray(arr) for var, arr in sample.items()
        })
        ds.to_netcdf(self.data_path)
        ds.close()


class NpzDataset(BaseDataset):

    def __getitem__(self, idx):
        npz = np.load(self.data_path)
        sample = {"x": torch.as_tensor(npz["x"]),
                  "y": torch.as_tensor(npz["y"])}
        if self.transform:
            sample = self.transform(sample)
        return sample

    @property
    def data_path(self):
        return self.data_dir / "data.npz"

    def prepare_data(self):
        if self.data_path.exists():
            return
        self.data_dir.mkdir(exist_ok=True)

        sample = self.get_fake_sample()
        np.savez_compressed(self.data_path, **sample)


class ComplicatedTransform():

    def __init__(self, concat_operations=1):
        self.concat_operations = concat_operations

    def __call__(self, sample):
        x = sample["x"]
        for _ in range(self.concat_operations):
            x = torch.cat([x, x**2])
        sample["x"] = x
        return sample


# Prepare training

ChosenDataset = XarrayDataset    # leaks
# ChosenDataset = NpzDataset     # does not leak

max_epochs = 10
concat_operations = 4

dataset = ChosenDataset(
    data_dir=data_dir,
    transform=ComplicatedTransform(concat_operations),
)
dataset.prepare_data()

loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=64,
    num_workers=0,
    drop_last=True,
    shuffle=True,
)

loss_fn = torch.nn.CrossEntropyLoss()
model = torch.nn.Sequential(
    torch.nn.LazyConv3d(out_channels=1, kernel_size=1),
    torch.nn.Sigmoid(),
)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)


# Train

device = "cuda"

print(f"## {ChosenDataset.__name__}: {concat_operations=}", end="\n\n")
print("| epoch | memory (GB) |")
print("|-------|-------------|")

model = model.to(device)
for epoch in range(max_epochs):
    memory = psutil.Process().memory_info().rss / (1024 ** 3)  # resident set size, in GB
    print(f"| {epoch} | {memory:.3f} |")

    for batch in loader:
        X = batch["x"].to(device)
        Y = batch["y"].to(device)

        optimizer.zero_grad()
        Y_pred_proba = model(X)
        loss = loss_fn(Y_pred_proba, Y.to(torch.float16))

        loss.backward()
        optimizer.step()

        del X
        del Y
```
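As a follow-up experiment (not part of the reproduction above), a hedged drop-in variant of `XarrayDataset.__getitem__` that loads eagerly inside a context manager; this only encodes an assumption about where the retention might happen, not a confirmed fix:

```Python
def __getitem__(self, idx):
    # Hypothetical variant: load everything up front, then let the context
    # manager close the underlying file handle before tensors are built.
    with xr.open_dataset(self.data_path, cache=False) as ds:
        ds = ds.load()
    sample = {"x": torch.as_tensor(ds["x"].data),
              "y": torch.as_tensor(ds["y"].data)}
    if self.transform:
        sample = self.transform(sample)
    return sample
```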

MVCE confirmation

  • [x] Minimal example — the example is as focused as reasonably possible to demonstrate the underlying issue in xarray.
  • [x] Complete example — the example is self-contained, including all data and the text of any traceback.
  • [x] Verifiable example — the example copy & pastes into an IPython prompt or Binder notebook, returning the result.
  • [x] New issue — a search of GitHub Issues suggests this is not a duplicate.

Relevant log output

No response

Anything else we need to know?

Using the xarray dataset with concat_operations=1, I saw memory growth of ~16 GB per epoch; with concat_operations=4, ~30 GB per epoch. With no concat operations, there was no memory growth.

Using the NPZ dataset, I observed no RAM accumulation across epochs.
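To help attribute that growth, a small hedged sketch using Python's standard tracemalloc module (note it only sees allocations made through Python's allocator, so C-level buffers held by the HDF5/netCDF libraries would not appear):

```Python
import tracemalloc

tracemalloc.start()
before = tracemalloc.take_snapshot()
# ... run one training epoch here ...
after = tracemalloc.take_snapshot()
for stat in after.compare_to(before, "lineno")[:10]:
    print(stat)  # top 10 source lines by allocated-memory growth
```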

Environment

```
INSTALLED VERSIONS
------------------
commit: None
python: 3.10.6 (main, Aug 23 2022, 08:25:41) [GCC 10.2.1 20210110]
python-bits: 64
OS: Linux
OS-release: 3.10.0-1160.49.1.el7.x86_64
machine: x86_64
processor:
byteorder: little
LC_ALL: None
LANG: en_US.UTF-8
LOCALE: ('en_US', 'UTF-8')
libhdf5: 1.12.2
libnetcdf: 4.9.0
xarray: 2022.11.0
pandas: 1.4.3
numpy: 1.23.4
scipy: 1.8.1
netCDF4: 1.6.1
pydap: None
h5netcdf: 1.0.2
h5py: 3.7.0
Nio: None
zarr: None
cftime: 1.6.2
nc_time_axis: None
PseudoNetCDF: None
rasterio: None
cfgrib: None
iris: None
bottleneck: None
dask: 2022.10.2
distributed: None
matplotlib: 3.6.2
cartopy: None
seaborn: None
numbagg: None
fsspec: 2022.7.0
cupy: None
pint: None
sparse: None
flox: None
numpy_groupies: None
setuptools: 63.2.0
pip: 22.3.1
conda: None
pytest: None
IPython: 8.6.0
sphinx: None
```
