Comment on [pydata/xarray#4325](https://github.com/pydata/xarray/issues/4325#issuecomment-1507201606) (2023-04-13):

I think I may have found a way to make the variance/standard deviation calculation more memory efficient, but I don't know enough about writing the sort of code that would be needed for a PR. I wrote out the variance calculation using only the rolling reductions that have already been optimised. Derived from:

$$ var = \frac{1}{n} \sum_{i=1}^{n} (x_i - \mu)^2 $$

$$ var = \frac{1}{n} \left( (x_1 - \mu)^2 + (x_2 - \mu)^2 + (x_3 - \mu)^2 + \dots \right) $$

$$ var = \frac{1}{n} \left( x_1^2 - 2x_1\mu + \mu^2 \;+\; x_2^2 - 2x_2\mu + \mu^2 \;+\; x_3^2 - 2x_3\mu + \mu^2 + \dots \right) $$

$$ var = \frac{1}{n} \left( \sum_{i=1}^{n} x_i^2 - 2\mu \sum_{i=1}^{n} x_i + n\mu^2 \right) $$

I coded this up, and it peaks at roughly 10% of the memory of the current `.var()` implementation:

```python
%load_ext memory_profiler

import numpy as np
import xarray as xr

temp = xr.DataArray(np.random.randint(0, 10, (5000, 500)), dims=("x", "y"))


def new_var(da, x=10, y=20):
    # Define the re-used rolling objects
    roll = da.rolling(x=x, y=y)
    mean = roll.mean()
    count = roll.count()

    # First term: sum of squared values
    term1 = (da**2).rolling(x=x, y=y).sum()
    # Second term: cross term sum
    term2 = -2 * mean * roll.sum()
    # Third term: 'sum' of squared means
    term3 = count * mean**2

    # Combine into the variance
    return (term1 + term2 + term3) / count


def old_var(da, x=10, y=20):
    return da.rolling(x=x, y=y).var()


%memit new_var(temp)
%memit old_var(temp)
```

```
peak memory: 429.77 MiB, increment: 134.92 MiB
peak memory: 5064.07 MiB, increment: 4768.45 MiB
```

I wanted to double-check that the calculation is working correctly (`var_n` and `var_o` hold the results of the two implementations above):

```python
var_n = new_var(temp)
var_o = old_var(temp)

print((var_o.where(~np.isnan(var_o), 0) == var_n.where(~np.isnan(var_n), 0)).all().values)
print(np.allclose(var_o, var_n, equal_nan=True))
```

```
False
True
```

I think the difference here is just floating-point error (exact equality fails, but the values agree within `np.allclose` tolerances), but maybe someone who knows how to check that in more detail could have a look. If the approach works, the standard deviation can be implemented trivially on top of it, as sketched below.
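A minimal sketch of that last point, assuming the `new_var`, `temp`, and `np` definitions from the snippet above (the `new_std` helper is hypothetical, and it assumes `Rolling.std` uses the same degrees-of-freedom convention as `Rolling.var`):

```python
def new_std(da, x=10, y=20):
    # Standard deviation is the square root of the variance,
    # so it inherits the memory savings of new_var directly.
    return np.sqrt(new_var(da, x=x, y=y))


# Compare against the built-in rolling standard deviation
# on the same test array, ignoring the NaN edge regions.
std_n = new_std(temp)
std_o = temp.rolling(x=10, y=20).std()
print(np.allclose(std_o, std_n, equal_nan=True))
```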