I am using fastjmd95 to infer potential density from CMIP6 models. I recently ran into performance issues in a complicated workflow, and I think I can trace some of them back to the step involving fastjmd95.
Here is a small example that reproduces the issue:
# Load a single model from the CMIP archive
import xarray as xr
import gcsfs
from fastjmd95 import jmd95numba

# Anonymous (read-only) access to the public CMIP6 bucket on Google Cloud Storage
gcs = gcsfs.GCSFileSystem(token='anon')
so = xr.open_zarr(gcs.get_mapper('gs://cmip6/CMIP/NCAR/CESM2/historical/r1i1p1f1/Omon/so/gn/'), consolidated=True).so
thetao = xr.open_zarr(gcs.get_mapper('gs://cmip6/CMIP/NCAR/CESM2/historical/r1i1p1f1/Omon/thetao/gn/'), consolidated=True).thetao

# Calculate sigma0 following the tutorial notebook (https://nbviewer.jupyter.org/github/xgcm/fastjmd95/blob/master/doc/fastjmd95_tutorial.ipynb)
# rho is applied lazily, chunk by chunk, to salinity and potential temperature at pressure 0
sigma_0 = xr.apply_ufunc(
    jmd95numba.rho, so, thetao, 0, dask='parallelized', output_dtypes=[so.dtype]
) - 1000
I then ran some tests on a Google Cloud deployment (a dask cluster with 5 workers).
When I trigger a computation on a variable that is read directly from storage (so.mean().load()), everything works fine: memory usage stays low and the task stream is dense.
But when I try the same with the derived variable (sigma_0.mean().load()), things look really ugly: memory fills up almost immediately and the workers start spilling to disk. From the Progress pane, it looks as if dask is trying to load a large part of the dataset into memory before the rho calculation is applied.

To me it looks like the scheduler is going wide on the task graph rather than deep. Wouldn't going deep free up memory sooner?
I am not experienced enough to diagnose what is going on inside dask, so any tips would be much appreciated.
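As a rough illustration of the wide-versus-deep intuition, here is a toy model in plain Python. It is not how dask's scheduler actually works and not a diagnosis of this issue. It assumes a simplified graph in which each chunk i has two load tasks (so_i, thetao_i), one sigma_i task combining them, and a small per-chunk partial reduction feeding a final mean. Full chunks count as one memory unit, partial results as zero. All task names and the memory accounting are hypothetical. The sketch compares peak memory when every input is read first (wide) against finishing each chunk before reading the next (deep).

```python
from collections import defaultdict


def build_graph(n_chunks):
    """Toy graph: sigma_i = rho(so_i, thetao_i), partial_i = chunk-wise mean of sigma_i,
    and one final 'mean' task that combines all partials."""
    deps = {}
    for i in range(n_chunks):
        deps[f"so_{i}"] = []
        deps[f"thetao_{i}"] = []
        deps[f"sigma_{i}"] = [f"so_{i}", f"thetao_{i}"]
        deps[f"partial_{i}"] = [f"sigma_{i}"]
    deps["mean"] = [f"partial_{i}" for i in range(n_chunks)]
    return deps


def task_size(task):
    """Full array chunks cost 1 unit; partial sums and the final mean are negligible."""
    return 0 if task.startswith(("partial_", "mean")) else 1


def peak_memory(order, deps):
    """Run tasks in the given order, freeing each result once all its dependents have run."""
    remaining = defaultdict(int)  # number of dependents still to run, per task
    for parents in deps.values():
        for p in parents:
            remaining[p] += 1
    in_memory = set()
    current = peak = 0
    for task in order:
        missing = [p for p in deps[task] if p not in in_memory]
        if missing:
            raise ValueError(f"{task} scheduled before its inputs {missing}")
        in_memory.add(task)
        current += task_size(task)
        peak = max(peak, current)  # inputs and output coexist while the task runs
        for p in deps[task]:
            remaining[p] -= 1
            if remaining[p] == 0:
                in_memory.remove(p)
                current -= task_size(p)
    return peak


def wide_order(n):
    """Breadth-first: read every input chunk before computing anything."""
    loads = [t for i in range(n) for t in (f"so_{i}", f"thetao_{i}")]
    sigmas = [f"sigma_{i}" for i in range(n)]
    partials = [f"partial_{i}" for i in range(n)]
    return loads + sigmas + partials + ["mean"]


def deep_order(n):
    """Depth-first: finish each chunk (load, rho, partial mean) before reading the next."""
    order = []
    for i in range(n):
        order += [f"so_{i}", f"thetao_{i}", f"sigma_{i}", f"partial_{i}"]
    return order + ["mean"]


if __name__ == "__main__":
    n = 10
    graph = build_graph(n)
    print("wide peak:", peak_memory(wide_order(n), graph))  # 21 chunk units
    print("deep peak:", peak_memory(deep_order(n), graph))  # 3 chunk units
```

In this toy, the wide ordering's peak grows as 2n + 1 chunks while the deep ordering stays at 3 regardless of n. Whether the real scheduler follows something like the wide ordering for this graph is not established by the toy itself.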