
Commit eaaeb53

Vt 514 dask usage via karabo is pit of success (#516)
- Made karabo a "pit of success"; added a parameter for Dask client creation indicating whether we run inside a container.
- Fixed missing type for mypy.
- Fixed line too long and improved description.
- Imported at the beginning.
- Fixed failing tests.
- Added test for dask.
- Improved documentation.
1 parent 19c379b commit eaaeb53

9 files changed

Lines changed: 417 additions & 151 deletions


doc/src/examples/example_structure.md

Lines changed: 61 additions & 0 deletions
@@ -17,6 +17,67 @@ Please look at the karabo.package documentation for specifics on the individual
![Image](../images/telescope.png)

## Parallel processing with Karabo
Karabo streamlines the setup of an environment for parallelization. Through its utility function `parallelize_with_dask`, Karabo nudges the user towards a seamless parallelization experience: by adhering to its format, users find themselves in a "pit of success" with parallel processing. This ensures efficient task distribution across multiple cores or even entire cluster nodes, which is especially valuable when handling large datasets or computationally demanding tasks.

### Points to Consider When Using `parallelize_with_dask` and Dask in General
When leveraging the `parallelize_with_dask` function for parallel processing in Karabo, users should keep in mind the following best practices:
1. **Avoid Infinite Tasks**: Ensure that the tasks you're parallelizing have a defined end. Infinite or extremely long-running tasks can clog the parallelization pipeline.
2. **Beware of Massive Tasks**: Large tasks can monopolize resources, potentially causing an imbalance in the workload distribution. It's often more efficient to break massive tasks into smaller, more manageable chunks.
3. **No Open h5 Connections**: Objects with open h5 connections are not picklable, which means they cannot be serialized and sent to other processes. If you need to pass such an object to a parallelized function, close the connection first, e.g. by calling `h5file.close()`, or by calling `.compute()` on the corresponding Karabo object (see the sketch after this list).

4. **Use `.compute()` on Dask Arrays**: Before passing Dask arrays to the function, call `.compute()` on them to realize their values. This avoids potential serialization issues and ensures efficient processing.

5. **Refer to Dask's Best Practices**: For a more comprehensive understanding and to avoid common pitfalls, consult [Dask's official best practices guide](https://docs.dask.org/en/stable/best-practices.html).

Following these guidelines will help ensure that you get the most out of Karabo's parallel processing capabilities.
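
Points 3 and 4 usually come down to materializing the data before it is shipped to the workers. The following is a minimal sketch of that pattern; the HDF5 file name, dataset name, and toy function are purely illustrative and not part of Karabo:

```python
import h5py

from karabo.util.dask import parallelize_with_dask

# Best practice 3: read what you need from HDF5 and close the handle before
# parallelizing; open h5 connections are not picklable.
with h5py.File("catalogue.h5", "r") as h5file:  # illustrative file/dataset names
    fluxes = h5file["flux"][...]  # plain NumPy array, safe to serialize

# Best practice 4: if an input is a Dask array, call .compute() on it first so
# that workers receive realized values instead of lazy task graphs.

def scale_flux(idx, fluxes, factor):
    # Receives only plain, picklable data.
    return fluxes[idx] * factor

results = parallelize_with_dask(scale_flux, range(len(fluxes)), fluxes=fluxes, factor=2.0)
```
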
### Parameters
- `iterate_function` (callable): The function applied to each element of the iterable. It receives the current element of the iterable as its first argument, followed by any specified positional and keyword arguments.
- `iterable` (iterable): The collection of elements over which `iterate_function` is applied.
- `args` (tuple): Positional arguments passed to `iterate_function` after the current element of the iterable.
- `kwargs` (dict): Keyword arguments passed to `iterate_function`.

### Returns
- tuple: A tuple containing the result of `iterate_function` for each element in the iterable. Results are gathered using Dask's `compute` function.

### Additional Notes
When working on a SLURM cluster, it is important to call `DaskHandler.setup()` at the beginning.
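
On SLURM that typically means something like the sketch below; `DaskHandler.setup()` and `parallelize_with_dask` come from `karabo.util.dask`, while the toy function is only illustrative:

```python
from karabo.util.dask import DaskHandler, parallelize_with_dask

# Call setup() first so the Dask scheduler and workers are started on the
# SLURM allocation before any parallel work is submitted.
DaskHandler.setup()

def square(x):
    return x * x

results = parallelize_with_dask(square, range(8))
print(results)  # tuple of results, one per element of the iterable
```
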
If `verbose` is passed in `kwargs` and set to `True`, progress messages are printed during processing.

The function internally uses the distributed scheduler of Dask.

Leverage the `parallelize_with_dask` utility in Karabo to harness the power of parallel processing and speed up your data-intensive operations.
### Function Signature
```python
def parallelize_with_dask(
    iterate_function: Callable[..., Any],
    iterable: Iterable[Any],
    *args: Any,
    **kwargs: Any,
) -> Union[Any, Tuple[Any, ...], List[Any]]:
    ...


# Example
def my_function(element, *args, **kwargs):
    # Do something with element
    return result


# The current element of the iterable is passed as the first argument to my_function.
parallelize_with_dask(my_function, my_iterable, *args, **kwargs)
# -> (result1, result2, result3, ...)
```
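
As a slightly more concrete (and purely illustrative) example of how positional and keyword arguments are forwarded to each call, assuming the signature documented above:

```python
from karabo.util.dask import parallelize_with_dask

def scale_and_shift(value, factor, offset=0.0):
    # `value` is the current element of the iterable,
    # `factor` arrives via *args and `offset` via **kwargs.
    return value * factor + offset

results = parallelize_with_dask(scale_and_shift, [1, 2, 3], 10, offset=0.5)
# -> (10.5, 20.5, 30.5)
```
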
## Use Karabo on a SLURM cluster
Karabo manages all available nodes through Dask, making their computational power conveniently accessible to the user. The `DaskHandler` class streamlines the creation of a Dask client and offers a user-friendly interface for interacting with it. The class exposes static variables which, when altered, modify the behavior of the Dask client.
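
In practice this looks roughly like the sketch below. `get_dask_client()` is used elsewhere in this commit; the static variable shown is a placeholder, so check `karabo.util.dask` for the settings available in your Karabo version:

```python
from karabo.util.dask import DaskHandler

# Adjust class-level (static) settings *before* the first client is created.
# NOTE: `n_workers_scheduler_node` is a placeholder name for illustration only.
DaskHandler.n_workers_scheduler_node = 1

client = DaskHandler.get_dask_client()  # created once and reused by Karabo
print(client)
```
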

docker/dev/Dockerfile

Lines changed: 1 addition & 0 deletions
```diff
@@ -1,6 +1,7 @@
 # LEGACY-FILE, has to be checked before usage
 # Create build container to not have copied filed in real container afterwards
 FROM --platform=amd64 continuumio/miniconda3:4.12.0 as build
+ARG IS_DOCKER_CONTAINER=true
 COPY environment.yaml environment.yaml
 COPY requirements.txt requirements.txt
```

docker/user/Dockerfile

Lines changed: 1 addition & 0 deletions
```diff
@@ -1,6 +1,7 @@
 # Create build container to not have copied filed in real container afterwards
 FROM nvidia/cuda:11.7.1-cudnn8-devel-ubuntu22.04 as build
 ARG KARABO_TAG
+ARG IS_DOCKER_CONTAINER=true
 RUN apt-get update && apt-get install -y git
 RUN git clone --branch ${KARABO_TAG} --depth=1 https://github.com/i4Ds/Karabo-Pipeline.git
```
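
The build argument only has an effect if it is surfaced to the runtime and checked when the Dask client is created. The commit does not show that part here; the sketch below is the generic pattern (exporting the build ARG as an ENV and reading it in Python), not the pipeline's actual code:

```python
import os

# Assumes the image also exports the build argument, e.g.
#   ENV IS_DOCKER_CONTAINER=${IS_DOCKER_CONTAINER}
# so that it is visible at runtime.
IS_DOCKER_CONTAINER = os.environ.get("IS_DOCKER_CONTAINER", "false").lower() == "true"

if IS_DOCKER_CONTAINER:
    # e.g. skip SLURM-specific setup and create a purely local Dask client
    print("Running inside the Karabo container")
```
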

environment.yaml

Lines changed: 2 additions & 0 deletions
```diff
@@ -10,6 +10,8 @@ dependencies:
   - bluebild
   - cuda-cudart
   - dask=2022.12.1
+  - dask-mpi
+  - mpi4py
   - distributed
   - eidos=1.1.0
   - healpy
```
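
`dask-mpi` and `mpi4py` allow a Dask cluster to be bootstrapped from an MPI launch (e.g. `srun` or `mpirun` on SLURM), presumably so Karabo can do this through `DaskHandler`. The snippet below shows plain `dask-mpi` usage, independent of Karabo's internal code:

```python
from dask.distributed import Client
from dask_mpi import initialize

# Under MPI, rank 0 becomes the scheduler, rank 1 runs this script as the
# client, and the remaining ranks become workers.
initialize()

client = Client()  # connects to the scheduler started by initialize()
print(client)
```
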

karabo/examples/HIIM_Img_Recovery.ipynb

Lines changed: 193 additions & 73 deletions
Large diffs are not rendered by default.

karabo/simulation/line_emission.py

Lines changed: 58 additions & 43 deletions
```diff
@@ -13,10 +13,6 @@
 from astropy.convolution import Gaussian2DKernel
 from astropy.io import fits
 from astropy.wcs import WCS
-
-# from dask.delayed import Delayed
-from dask import compute, delayed  # type: ignore[attr-defined]
-from dask.distributed import Client
 from numpy.typing import NDArray

 from karabo.imaging.imager import Imager
@@ -26,7 +22,9 @@
 from karabo.simulation.telescope import Telescope
 from karabo.simulation.visibility import Visibility
 from karabo.util._types import DirPathType, FilePathType, IntFloat, NPFloatLikeStrict
-from karabo.util.dask import DaskHandler
+
+# from dask.delayed import Delayed
+from karabo.util.dask import parallelize_with_dask
 from karabo.util.plotting_util import get_slices


@@ -226,7 +224,7 @@ def plot_scatter_recon(
     plt.savefig(outfile)


-def sky_slice(sky: SkyModel, z_min: np.float_, z_max: np.float_) -> SkyModel:
+def sky_slice(sky: SkyModel, z_min: IntFloat, z_max: IntFloat) -> SkyModel:
     """
     Extracting a slice from the sky which includes only sources between redshift z_min
     and z_max.
@@ -240,15 +238,7 @@ def sky_slice(sky: SkyModel, z_min: np.float_, z_max: np.float_) -> SkyModel:
     :return: Sky model only including the sources with redshifts between z_min and
         z_max.
     """
-    sky_bin = SkyModel.copy_sky(sky)
-    if sky_bin.sources is None:
-        raise TypeError("`sky.sources` is None which is not allowed.")
-
-    z_obs = sky_bin.sources[:, 13]
-    sky_bin_idx = np.where((z_obs > z_min) & (z_obs < z_max))
-    sky_bin.sources = sky_bin.sources[sky_bin_idx]
-
-    return sky_bin
+    return sky.filter_by_column(13, z_min, z_max)


 T = TypeVar("T", NDArray[np.float_], xr.DataArray, IntFloat)
@@ -462,8 +452,6 @@ def run_one_channel_simulation(
     path: FilePathType,
     sky: SkyModel,
     telescope: Telescope,
-    z_min: np.float_,
-    z_max: np.float_,
     freq_bin_start: float,
     freq_bin_width: float,
     ra_deg: IntFloat,
@@ -488,8 +476,6 @@
         of each source.
     :param telescope: Telescope used. If None, the MEERKAT telescope will be used as a
         default.
-    :param z_min: Smallest redshift in this bin.
-    :param z_max: Largest redshift in this bin.
     :param freq_bin_start: Starting frequency in this bin
         (i.e., largest frequency in the bin).
     :param freq_bin_width: Size of the sky frequency bin which is simulated.
@@ -513,15 +499,13 @@
     :return: Reconstruction of one bin slice of the sky and its header.
     """

-    sky_bin = sky_slice(sky, z_min, z_max)
-
     if verbose:
         print("Starting simulation...")

     freq_bin_middle = freq_bin_start - freq_bin_width / 2
     dirty_image, header = karabo_reconstruction(
         path,
-        sky=sky_bin,
+        sky=sky,
         telescope=telescope,
         ra_deg=ra_deg,
         dec_deg=dec_deg,
@@ -594,7 +578,6 @@ def line_emission_pointing(
     img_size: int = 4096,
     circle: bool = True,
     rascil: bool = True,
-    client: Optional[Client] = None,
     verbose: bool = False,
 ) -> Tuple[NDArray[np.float_], List[NDArray[np.float_]], fits.header.Header, np.float_]:
     """
@@ -629,7 +612,6 @@
     :param circle: If set to True, the pointing has a round shape of size cut.
     :param rascil: If True we use the Imager Rascil otherwise the Imager from Oskar is
         used.
-    :param client: Setting a dask client is optional.
     :param verbose: If True you get more print statements.
     :return: Total line emission reconstruction, 3D line emission reconstruction,
         Header of reconstruction and mean frequency.
@@ -647,18 +629,12 @@

     os.makedirs(outpath)

-    # Load sky into memory and close connection to h5
-    sky.compute()
-
     if sky.sources is None:
         raise TypeError(
             "`sources` None is not allowed! Please set them in"
             " the `SkyModel` before calling this function."
         )

-    if not client:
-        client = DaskHandler.get_dask_client()
-
     redshift_channel, freq_channel, freq_bin, freq_mid = freq_channels(
         z_obs=sky.sources[:, 13],
         channel_num=num_bins,
@@ -669,20 +645,39 @@
     n_jobs = num_bins
     print(f"Submitting {n_jobs} jobs to the cluster.")

-    delayed_results = []
+    # Load the sky into memory
+    sky.compute()

-    for bin_idx in range(num_bins):
-        if verbose:
-            print(
-                f"Channel {bin_idx} is being processed...\n"
-                "Extracting the corresponding frequency slice from the sky model..."
-            )
-        delayed_ = delayed(run_one_channel_simulation)(
+    # Helper function to parallelise with dask
+    def process_channel(  # type: ignore[no-untyped-def]
+        bin_idx,
+        outpath,
+        sky,
+        telescope,
+        redshift_channel,
+        freq_channel,
+        ra_deg,
+        dec_deg,
+        beam_type,
+        gaussian_fwhm,
+        gaussian_ref_freq,
+        start_time,
+        obs_length,
+        cut,
+        img_size,
+        circle,
+        rascil,
+        verbose,
+    ):
+        # Do the sky slicing here, so that less data is sent to each worker
+        z_min = redshift_channel[bin_idx]
+        z_max = redshift_channel[bin_idx + 1]
+        sky_bin = sky_slice(sky, z_min, z_max)
+
+        return run_one_channel_simulation(
             path=outpath / f"slice_{bin_idx}",
-            sky=sky,
+            sky=sky_bin,
             telescope=telescope,
-            z_min=redshift_channel[bin_idx],
-            z_max=redshift_channel[bin_idx + 1],
             freq_bin_start=freq_channel[bin_idx],
             freq_bin_width=freq_bin[bin_idx],
             ra_deg=ra_deg,
@@ -698,9 +693,29 @@
             rascil=rascil,
             verbose=verbose,
         )
-        delayed_results.append(delayed_)

-    result = compute(*delayed_results, scheduler="distributed")
+    result = parallelize_with_dask(
+        process_channel,
+        range(num_bins),
+        outpath=outpath,
+        sky=sky,
+        telescope=telescope,
+        redshift_channel=redshift_channel,
+        freq_channel=freq_channel,
+        ra_deg=ra_deg,
+        dec_deg=dec_deg,
+        beam_type=beam_type,
+        gaussian_fwhm=gaussian_fwhm,
+        gaussian_ref_freq=gaussian_ref_freq,
+        start_time=start_time,
+        obs_length=obs_length,
+        cut=cut,
+        img_size=img_size,
+        circle=circle,
+        rascil=rascil,
+        verbose=verbose,
+    )
+
     dirty_images = [x[0] for x in result]
     headers = [x[1] for x in result]
     header = headers[0]
```

karabo/simulation/sky_model.py

Lines changed: 19 additions & 34 deletions
```diff
@@ -618,29 +618,29 @@ def filter_by_radius_euclidean_flat_approximation(
         else:
             return copied_sky

-    def filter_by_flux(
+    def filter_by_column(
         self,
-        min_flux_jy: IntFloat,
-        max_flux_jy: IntFloat,
+        col_idx: int,
+        min_val: IntFloat,
+        max_val: IntFloat,
     ) -> SkyModel:
         """
-        Filters the sky using the Stokes-I-flux
-        Values outside the range are removed
+        Filters the sky based on a specific column index

-        :param min_flux_jy: Minimum flux in Jy
-        :param max_flux_jy: Maximum flux in Jy
+        :param col_idx: Column index to filter by
+        :param min_val: Minimum value for the column
+        :param max_val: Maximum value for the column
         :return sky: Filtered copy of the sky
         """
         copied_sky = SkyModel.copy_sky(self)
         if copied_sky.sources is None:
             raise KaraboSkyModelError(
-                "`sources` None is not allowed. "
-                + "Add sources before calling `filter_by_flux`."
+                "`sources` is None, add sources before filtering."
             )

         # Create mask
-        filter_mask = (copied_sky[:, 2] >= min_flux_jy) & (
-            copied_sky[:, 2] <= max_flux_jy
+        filter_mask = (copied_sky.sources[:, col_idx] >= min_val) & (
+            copied_sky.sources[:, col_idx] <= max_val
         )
         filter_mask = self.rechunk_array_based_on_self(filter_mask)
@@ -649,34 +649,19 @@ def filter_by_flux(

         return copied_sky

+    def filter_by_flux(
+        self,
+        min_flux_jy: IntFloat,
+        max_flux_jy: IntFloat,
+    ) -> SkyModel:
+        return self.filter_by_column(2, min_flux_jy, max_flux_jy)
+
     def filter_by_frequency(
         self,
         min_freq: IntFloat,
         max_freq: IntFloat,
     ) -> SkyModel:
-        """
-        Filters the sky using the reference frequency in Hz
-
-        :param min_freq: Minimum frequency in Hz
-        :param max_freq: Maximum frequency in Hz
-        :return sky: Filtered copy of the sky
-        """
-        copied_sky = SkyModel.copy_sky(self)
-        if copied_sky.sources is None:
-            raise KaraboSkyModelError(
-                "`sources` is None, add sources before calling `filter_by_frequency`."
-            )
-
-        # Create mask
-        filter_mask = (copied_sky.sources[:, 6] >= min_freq) & (
-            copied_sky.sources[:, 6] <= max_freq
-        )
-        filter_mask = self.rechunk_array_based_on_self(filter_mask)
-
-        # Apply the filter mask and drop the unmatched rows
-        copied_sky.sources = copied_sky.sources.where(filter_mask, drop=True)
-
-        return copied_sky
+        return self.filter_by_column(6, min_freq, max_freq)

     def get_wcs(self) -> WCS:
         """
```

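The new `filter_by_column` generalizes the flux and frequency filters, and `sky_slice` in `line_emission.py` now reuses it for the redshift column (index 13). A short illustrative sketch (the `SkyModel` instance is assumed to be already populated):

```python
from karabo.simulation.sky_model import SkyModel

def select_redshift_bin(sky: SkyModel, z_min: float, z_max: float) -> SkyModel:
    # Column 13 holds the observed redshift, mirroring what sky_slice() now does.
    return sky.filter_by_column(13, z_min, z_max)

# The existing filters are now thin wrappers around filter_by_column:
#   sky.filter_by_flux(0.01, 1.0)        == sky.filter_by_column(2, 0.01, 1.0)
#   sky.filter_by_frequency(7e8, 1.1e9)  == sky.filter_by_column(6, 7e8, 1.1e9)
```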