Description
Is your feature request related to a problem?
E3SM Diags performs diagnostic runs in parallel at the per-variable level, based on the number of assigned workers (the `num_workers` parameter). It uses a function called `_run_with_dask()`, which passes the list of diagnostic runs to `dask.bag`. With the "processes" scheduler, `dask.bag` is essentially a wrapper around Python's `multiprocessing`.
e3sm_diags/e3sm_diags/e3sm_diags_driver.py
Lines 334 to 376 in e2fcc3a
```python
def _run_with_dask(parameters: List[CoreParameter]) -> List[CoreParameter]:
    """Run diagnostics with the parameters in parallel using Dask.

    This function passes ``run_diag`` to ``dask.bag.map``, which gets executed
    in parallel with ``.compute``.

    The first CoreParameter object's `num_workers` attribute is used to set
    the number of workers for ``.compute``.

    Parameters
    ----------
    parameters : List[CoreParameter]
        The list of CoreParameter objects to run diagnostics on.

    Returns
    -------
    List[CoreParameter]
        The list of CoreParameter objects with results from the diagnostic run.

    Notes
    -----
    https://docs.dask.org/en/stable/generated/dask.bag.map.html
    https://docs.dask.org/en/stable/generated/dask.dataframe.DataFrame.compute.html
    """
    bag = db.from_sequence(parameters)

    config = {"scheduler": "processes", "multiprocessing.context": "fork"}

    num_workers = getattr(parameters[0], "num_workers", None)
    if num_workers is None:
        raise ValueError(
            "The `num_workers` attribute is required for multiprocessing but it is not "
            "defined on the CoreParameter object. Set this attribute and try running "
            "again."
        )

    with dask.config.set(config):
        results = bag.map(CoreParameter._run_diag).compute(num_workers=num_workers)

    # `results` becomes a list of lists of parameters so it needs to be
    # collapsed a level.
    collapsed_results = _collapse_results(results)

    return collapsed_results
```
On each worker, e3sm_diags loads datasets into memory via `.load()`. Loading data into memory presents the following issues:
- It assumes the data fits into memory, so it fails on large datasets that exceed memory limits.
- There is a multiprocessing scheduler conflict between the `dask.bag` scheduler and the one created automatically by Xarray/Dask with `open_mfdataset()`. As a result, datasets can't be chunked for parallel operations downstream.
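A contrived sketch of the conflicting pattern described above; the file globs and the `ts` variable are hypothetical, and the point is only that each forked worker builds a second, chunked Dask graph via `open_mfdataset()` that competes with the outer multiprocessing scheduler:

```python
import dask
import dask.bag as db
import xarray as xr


def run_diag(paths):
    # Inside a forked worker, this creates lazy, chunked Dask arrays whose
    # scheduler competes with the outer dask.bag "processes" scheduler,
    # which can lead to the timeouts and resource locking described below.
    ds = xr.open_mfdataset(paths, chunks={"time": 12})
    return float(ds["ts"].mean())  # triggers a nested compute


if __name__ == "__main__":
    bag = db.from_sequence([["sim1_*.nc"], ["sim2_*.nc"]])
    with dask.config.set(scheduler="processes"):
        results = bag.map(run_diag).compute(num_workers=2)
```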
Example of where .load() is called:
e3sm_diags/e3sm_diags/driver/utils/dataset_xr.py
Lines 1518 to 1527 in e2fcc3a
```python
    def _subset_vars_and_load(self, ds: xr.Dataset, var: str | List[str]) -> xr.Dataset:
        """Subset for variables needed for processing and load into memory.

        Subsetting the dataset reduces its memory footprint. Loading is
        necessary because there seems to be an issue with `open_mfdataset()`
        and using the multiprocessing scheduler defined in e3sm_diags,
        resulting in timeouts and resource locking. To avoid this, we load the
        multi-file dataset into memory before performing downstream operations.

        Source: https://github.com/pydata/xarray/issues/3781
```
Describe the solution you'd like
We want to use Dask more effectively to handle datasets much larger than memory by chunking them and working on small pieces at a time. Dask can schedule operations efficiently to maximize throughput, and on HPC systems it can distribute work across nodes, making large workflows much faster.
In order to do that, we must:
1. Replace `dask.bag` with Dask's distributed cluster (related discussion). A minimal sketch of this change is included after this list.
Asking ChatGPT: Should I replace dask.bag with a dask distributed scheduler?
>
> Yes, replacing `dask.bag` with **Dask Distributed** is a good idea because it provides better **process isolation**, **fault tolerance**, and **MPI compatibility**. Here’s why:
>
> ---
>
> ### **Why Use Dask Distributed Instead of `dask.bag` with Multiprocessing?**
> #### **1. Avoids `multiprocessing`-related Crashes**
> - `dask.bag` with `multiprocessing` relies on `concurrent.futures.ProcessPoolExecutor`, which can lead to **MPI conflicts** and the `BrokenProcessPool` error.
> - Dask Distributed runs **each worker in its own separate process**, avoiding `fork()` issues.
>
> #### **2. Better Handling of MPI**
> - Dask Distributed **isolates workers properly**, preventing unwanted MPI reinitialization.
> - Unlike Python’s `multiprocessing`, Dask Distributed doesn’t inherit an already-initialized MPI environment across processes.
>
> #### **3. More Robust and Scalable**
> - Supports **dynamic task scheduling**, while `dask.bag` is more static.
> - Handles **worker failures gracefully** (e.g., if a process crashes, the scheduler can restart it).
>
> #### **4. More Control Over Worker Processes**
> - With Dask Distributed, you can:
> - Specify how many workers to use.
> - Run a scheduler on a different machine.
> - Monitor task execution with **Dask’s dashboard**.
>
> ---
>
> ### **How to Replace `dask.bag` with Dask Distributed**
> #### **Step 1: Start a Dask Distributed Client**
> Modify your script to use `dask.distributed` instead of `dask.bag` with multiprocessing:
> ```python
> from dask.distributed import Client
>
> if __name__ == "__main__":
> client = Client() # Automatically sets up a distributed scheduler
> print(client) # Shows connection info
>
> # Run your Dask computation here...
> ```
> - This automatically **creates a local cluster** with multiple workers.
>
> #### **Step 2: Convert `dask.bag` to Use `dask.delayed` or Dask DataFrame**
> Instead of:
> ```python
> import dask.bag as db
> b = db.from_sequence(data).map(my_function)
> result = b.compute()
> ```
> Use:
> ```python
> import dask
> from dask import delayed
>
> tasks = [delayed(my_function)(item) for item in data]
> result = dask.compute(*tasks)
> ```
> - This approach **distributes the workload across the Dask cluster**.
>
> #### **Step 3: Use `client.submit` for Parallel Execution**
> For better control, use `client.submit`:
> ```python
> futures = client.map(my_function, data)
> results = client.gather(futures) # Collect results
> ```
> - This avoids the limitations of `dask.bag`.
>
> ---
>
> ### **Summary: Should You Replace `dask.bag`?**
> ✅ **Yes**, because:
> - Dask Distributed is **more stable with MPI**.
> - It provides **better process isolation**.
> - It supports **dynamic, fault-tolerant task execution**.
>
> Would you like a more detailed example with your specific use case? 🚀
2. Figure out heuristics for dynamic and programmatic chunking
We need dynamic chunking because e3sm_diags reads input files of varying sizes. To decide dynamically and programmatically whether to chunk when using `xarray.open_mfdataset()`, we can use a heuristic based on dataset size, memory usage, and available resources (a sketch is included after this list).
3. Update functions that require `.load()` to operate lazily.
4. Run `.compute()` at the end of the computation chain.
5. Measure performance compared to v3.0.0.
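For item 1, here is a minimal sketch of what `_run_with_dask()` could become with a local `dask.distributed` cluster. The function name `_run_with_distributed` is hypothetical; `CoreParameter._run_diag` and `_collapse_results` are the existing helpers shown above, and the cluster settings are illustrative only, not a final design:

```python
from typing import List

from dask.distributed import Client, LocalCluster


def _run_with_distributed(parameters: List[CoreParameter]) -> List[CoreParameter]:
    """Sketch: run diagnostics in parallel on a local distributed cluster."""
    num_workers = getattr(parameters[0], "num_workers", None)
    if num_workers is None:
        raise ValueError("The `num_workers` attribute is required but is not defined.")

    # One single-threaded process per worker stays close to the current
    # "processes" scheduler, while adding the dashboard, fault tolerance,
    # and per-worker memory limits that dask.distributed provides.
    cluster = LocalCluster(n_workers=num_workers, threads_per_worker=1)
    client = Client(cluster)

    try:
        futures = client.map(CoreParameter._run_diag, parameters)
        results = client.gather(futures)
    finally:
        client.close()
        cluster.close()

    # Same post-processing as the current dask.bag implementation.
    return _collapse_results(results)
```

For items 2-4, here is a sketch of a size-based heuristic that decides between loading eagerly and staying lazy/chunked. The helper name, the `memory_fraction` threshold, and the use of `psutil` are assumptions for illustration:

```python
import os
from glob import glob

import psutil  # assumption: used only to query available memory
import xarray as xr


def open_with_heuristic_chunking(path_glob: str, memory_fraction: float = 0.25) -> xr.Dataset:
    """Open a multi-file dataset, chunking only when it is large relative to RAM."""
    paths = sorted(glob(path_glob))
    total_bytes = sum(os.path.getsize(p) for p in paths)
    available_bytes = psutil.virtual_memory().available

    if total_bytes <= memory_fraction * available_bytes:
        # Small enough to fit comfortably in memory: load eagerly, which
        # mirrors the current `.load()` behavior.
        return xr.open_mfdataset(paths).load()

    # Otherwise stay lazy and let Dask pick chunk sizes; downstream code
    # defers work until a single `.compute()` at the end of the chain.
    return xr.open_mfdataset(paths, chunks="auto")
```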
Describe alternatives you've considered
No response
Additional context
There are some MPI issues in #933 to be aware of.