[Enhancement]: Explore dask distributed and chunking for improved performance and support for ultra high-resolution data #948

@tomvothecoder

Description

Is your feature request related to a problem?

E3SM Diags performs diagnostic runs in parallel at the per-variable level, based on the number of assigned workers (the num_workers parameter). It uses a function called _run_with_dask(), which passes the list of diagnostic runs to dask.bag. With the "processes" scheduler, dask.bag is essentially a wrapper around Python multiprocessing.

```python
def _run_with_dask(parameters: List[CoreParameter]) -> List[CoreParameter]:
    """Run diagnostics with the parameters in parallel using Dask.

    This function passes ``run_diag`` to ``dask.bag.map``, which gets executed
    in parallel with ``.compute``.

    The first CoreParameter object's `num_workers` attribute is used to set
    the number of workers for ``.compute``.

    Parameters
    ----------
    parameters : List[CoreParameter]
        The list of CoreParameter objects to run diagnostics on.

    Returns
    -------
    List[CoreParameter]
        The list of CoreParameter objects with results from the diagnostic run.

    Notes
    -----
    https://docs.dask.org/en/stable/generated/dask.bag.map.html
    https://docs.dask.org/en/stable/generated/dask.dataframe.DataFrame.compute.html
    """
    bag = db.from_sequence(parameters)

    config = {"scheduler": "processes", "multiprocessing.context": "fork"}

    num_workers = getattr(parameters[0], "num_workers", None)
    if num_workers is None:
        raise ValueError(
            "The `num_workers` attribute is required for multiprocessing but it is not "
            "defined on the CoreParameter object. Set this attribute and try running "
            "again."
        )

    with dask.config.set(config):
        results = bag.map(CoreParameter._run_diag).compute(num_workers=num_workers)

    # `results` becomes a list of lists of parameters so it needs to be
    # collapsed a level.
    collapsed_results = _collapse_results(results)

    return collapsed_results
```

On each worker, e3sm_diags loads datasets into memory via .load(). Loading data into memory presents the following issues:

  1. It assumes the data fits into memory, so it chokes on large datasets that exceed memory limits.
  2. There is a scheduler conflict between the dask.bag multiprocessing scheduler and the one created automatically by Xarray/Dask with open_mfdataset(). As a result, datasets can't be chunked for parallel operations downstream.

Example of where .load() is called:

```python
def _subset_vars_and_load(self, ds: xr.Dataset, var: str | List[str]) -> xr.Dataset:
    """Subset for variables needed for processing and load into memory.

    Subsetting the dataset reduces its memory footprint. Loading is
    necessary because there seems to be an issue with `open_mfdataset()`
    and using the multiprocessing scheduler defined in e3sm_diags,
    resulting in timeouts and resource locking. To avoid this, we load the
    multi-file dataset into memory before performing downstream operations.

    Source: https://github.com/pydata/xarray/issues/3781
    """
    ...
```


Describe the solution you'd like

We want to utilize Dask more effectively in order to handle datasets much larger than memory by chunking and working on small pieces at a time. Dask can schedule operations efficiently to maximize throughput. On HPCs, Dask can distribute work across nodes, making big workflows much faster.
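As a rough illustration of the target pattern, here is a minimal sketch of a lazy, chunked workflow (the file pattern, variable name, and chunk size are hypothetical, not e3sm_diags code):

```python
import xarray as xr

# Open the multi-file dataset lazily; `chunks` keeps each variable as a Dask
# array instead of reading it into memory (pattern and chunk size are made up).
ds = xr.open_mfdataset("ts_*.nc", chunks={"time": 12})

# Downstream operations only build a task graph; nothing is computed yet.
climo = ds["ts"].groupby("time.month").mean("time")

# Materialize the (much smaller) result once, at the end of the chain.
climo = climo.compute()
```

With chunking, Dask streams each piece through memory, so datasets larger than RAM can still be processed.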

In order to do that, we must:

  1. Replace dask.bag with Dask's distributed cluster (related discussion); a rough sketch adapted to e3sm_diags follows the quoted answer below.

Asking ChatGPT: Should I replace dask.bag with a dask distributed scheduler?

> 
> Yes, replacing `dask.bag` with **Dask Distributed** is a good idea because it provides better **process isolation**, **fault tolerance**, and **MPI compatibility**. Here’s why:
> 
> ---
> 
> ### **Why Use Dask Distributed Instead of `dask.bag` with Multiprocessing?**
> #### **1. Avoids `multiprocessing`-related Crashes**
> - `dask.bag` with `multiprocessing` relies on `concurrent.futures.ProcessPoolExecutor`, which can lead to **MPI conflicts** and the `BrokenProcessPool` error.
> - Dask Distributed runs **each worker in its own separate process**, avoiding `fork()` issues.
> 
> #### **2. Better Handling of MPI**
> - Dask Distributed **isolates workers properly**, preventing unwanted MPI reinitialization.
> - Unlike Python’s `multiprocessing`, Dask Distributed doesn’t inherit an already-initialized MPI environment across processes.
> 
> #### **3. More Robust and Scalable**
> - Supports **dynamic task scheduling**, while `dask.bag` is more static.
> - Handles **worker failures gracefully** (e.g., if a process crashes, the scheduler can restart it).
> 
> #### **4. More Control Over Worker Processes**
> - With Dask Distributed, you can:
>   - Specify how many workers to use.
>   - Run a scheduler on a different machine.
>   - Monitor task execution with **Dask’s dashboard**.
> 
> ---
> 
> ### **How to Replace `dask.bag` with Dask Distributed**
> #### **Step 1: Start a Dask Distributed Client**
> Modify your script to use `dask.distributed` instead of `dask.bag` with multiprocessing:
> ```python
> from dask.distributed import Client
> 
> if __name__ == "__main__":
>     client = Client()  # Automatically sets up a distributed scheduler
>     print(client)  # Shows connection info
> 
>     # Run your Dask computation here...
> ```
> - This automatically **creates a local cluster** with multiple workers.
> 
> #### **Step 2: Convert `dask.bag` to Use `dask.delayed` or Dask DataFrame**
> Instead of:
> ```python
> import dask.bag as db
> b = db.from_sequence(data).map(my_function)
> result = b.compute()
> ```
> Use:
> ```python
> import dask
> from dask import delayed
> 
> tasks = [delayed(my_function)(item) for item in data]
> result = dask.compute(*tasks)
> ```
> - This approach **distributes the workload across the Dask cluster**.
> 
> #### **Step 3: Use `client.submit`/`client.map` for Parallel Execution**
> For finer-grained control over individual tasks, use `client.submit`, or its batch form `client.map`:
> ```python
> futures = client.map(my_function, data)
> results = client.gather(futures)  # Collect results
> ```
> - This avoids the limitations of `dask.bag`.
> 
> ---
> 
> ### **Summary: Should You Replace `dask.bag`?**
> ✅ **Yes**, because:
> - Dask Distributed is **more stable with MPI**.
> - It provides **better process isolation**.
> - It supports **dynamic, fault-tolerant task execution**.
> 
> Would you like a more detailed example with your specific use case? 🚀
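
For reference, here is a minimal sketch of how _run_with_dask() could be adapted to dask.distributed in e3sm_diags. CoreParameter._run_diag and _collapse_results are the existing helpers shown above; the function name and cluster settings are assumptions for illustration, not the final design:

```python
from dask.distributed import Client, LocalCluster

def _run_with_distributed(parameters):
    """Hypothetical replacement for ``_run_with_dask`` using a distributed cluster."""
    # One single-threaded worker process per assigned worker (settings are assumed).
    cluster = LocalCluster(
        n_workers=parameters[0].num_workers, threads_per_worker=1
    )
    client = Client(cluster)

    try:
        # Submit one task per parameter object and collect the results.
        futures = client.map(CoreParameter._run_diag, parameters)
        results = client.gather(futures)
    finally:
        client.close()
        cluster.close()

    # As with dask.bag, each task returns a list of parameters, so flatten it.
    return _collapse_results(results)
```

The LocalCluster could later be swapped for a dask-jobqueue or dask-mpi cluster to spread workers across HPC nodes.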

  2. Figure out heuristics for dynamic and programmatic chunking (see the sketch after this list).

  • We need dynamic chunking because e3sm_diags reads input files of varying sizes.
  • To decide dynamically and programmatically whether to chunk when using xarray.open_mfdataset(), we can use a heuristic based on dataset size, memory usage, and available resources.

  3. Update functions that require .load() to operate lazily.
  4. Run .compute() at the end of the computation chain.
  5. Measure performance compared to v3.0.0.
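
Tying items 2-4 together, here is a sketch of one possible heuristic for deciding when to chunk, combined with lazy loading and a single .compute() at the end. The threshold, helper name, and psutil dependency are assumptions for illustration:

```python
import glob
import os

import psutil  # assumed dependency for querying available memory
import xarray as xr

def _open_with_heuristic_chunking(paths, memory_fraction=0.25):
    """Open files lazily, enabling chunking only when they may not fit in memory.

    ``memory_fraction`` (an assumed threshold) is the share of available memory
    the on-disk dataset may occupy before chunking is turned on.
    """
    total_bytes = sum(os.path.getsize(p) for p in paths)
    available_bytes = psutil.virtual_memory().available

    if total_bytes < memory_fraction * available_bytes:
        # Small enough to handle without chunking.
        return xr.open_mfdataset(paths)

    # Too large: let Dask choose chunk sizes per file.
    return xr.open_mfdataset(paths, chunks="auto")

# Hypothetical usage: stay lazy until the end of the computation chain.
ds = _open_with_heuristic_chunking(sorted(glob.glob("ts_*.nc")))
result = ds["ts"].mean("time").compute()
```

On-disk size is only a proxy for in-memory size (NetCDF files are often compressed), so any threshold would need tuning against real e3sm_diags inputs.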

Describe alternatives you've considered

No response

Additional context

There are some MPI issues in #933 to be aware of.
