Skip to content

Commit 7af15c0

Browse files
committed
add in hooks for delayed calculation of task graph
1 parent 2b321ee commit 7af15c0

File tree

1 file changed

+72
-8
lines changed

1 file changed

+72
-8
lines changed

Diff for: src/coffea/dataset_tools/apply_processor.py

+72-8
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
from __future__ import annotations
22

33
import copy
4+
from functools import partial
45
from typing import Any, Callable, Dict, Hashable, List, Set, Tuple, Union
56

67
import awkward
78
import dask.base
9+
import dask.delayed
810
import dask_awkward
911

1012
from coffea.dataset_tools.preprocess import (
@@ -31,12 +33,66 @@
3133
GenericHEPAnalysis = Callable[[dask_awkward.Array], DaskOutputType]
3234

3335

36+
def _pack_meta_to_wire(*collections):
    """Concretize the ``_meta`` of dask-awkward collections for transport.

    Walks every leaf collection found by ``dask.base.unpack_collections`` and,
    for each dask-awkward Array/Record/Scalar, replaces its ``_meta`` with a
    length-zero awkward Array built from the same form (keeping ``behavior``
    and ``attrs``) — presumably so the collections can be serialized and sent
    "over the wire".  Returns the repacked collection(s); a single collection
    is returned unwrapped rather than as a 1-tuple.

    NOTE(review): the ``_meta`` swap mutates the unpacked leaves in place.
    """
    leaves, repack = dask.base.unpack_collections(*collections)

    dak_types = (dask_awkward.Array, dask_awkward.Record, dask_awkward.Scalar)
    prepared = list(leaves)
    for leaf in prepared:
        if isinstance(leaf, dak_types):
            # Swap the typetracer meta for a concrete, empty array of the
            # same form so the collection carries serializable metadata.
            leaf._meta = awkward.Array(
                leaf._meta.layout.form.length_zero_array(),
                behavior=leaf._meta.behavior,
                attrs=leaf._meta.attrs,
            )

    repacked = repack(prepared)
    return repacked[0] if len(repacked) == 1 else repacked
55+
56+
def _unpack_meta_from_wire(*collections):
    """Restore typetracer ``_meta`` on dask-awkward collections after transport.

    Inverse of ``_pack_meta_to_wire``: for each dask-awkward
    Array/Record/Scalar leaf, rebuilds ``_meta`` as a typetracer array
    (``to_typetracer(forget_length=True)``) from the concrete length-zero
    meta, preserving ``behavior`` and ``attrs``.  Returns the repacked
    collection(s); a single collection is returned unwrapped.

    NOTE(review): the ``_meta`` swap mutates the unpacked leaves in place.
    """
    leaves, repack = dask.base.unpack_collections(*collections)

    dak_types = (dask_awkward.Array, dask_awkward.Record, dask_awkward.Scalar)
    restored = list(leaves)
    for leaf in restored:
        if isinstance(leaf, dak_types):
            # Rebuild the lazy typetracer meta from the concrete placeholder.
            leaf._meta = awkward.Array(
                leaf._meta.layout.to_typetracer(forget_length=True),
                behavior=leaf._meta.behavior,
                attrs=leaf._meta.attrs,
            )

    repacked = repack(restored)
    return repacked[0] if len(repacked) == 1 else repacked
74+
75+
76+
def _apply_analysis(analysis, events_and_maybe_report):
77+
events = events_and_maybe_report
78+
report = None
79+
if isinstance(events_and_maybe_report, tuple):
80+
events, report = events_and_maybe_report
81+
82+
out = analysis(events)
83+
84+
if report is not None:
85+
return out, report
86+
return out
87+
88+
3489
def apply_to_dataset(
3590
data_manipulation: ProcessorABC | GenericHEPAnalysis,
3691
dataset: DatasetSpec | DatasetSpecOptional,
3792
schemaclass: BaseSchema = NanoAODSchema,
3893
metadata: dict[Hashable, Any] = {},
3994
uproot_options: dict[str, Any] = {},
95+
parallelize_with_dask: bool = False,
4096
) -> DaskOutputType | tuple[DaskOutputType, dask_awkward.Array]:
4197
"""
4298
Apply the supplied function or processor to the supplied dataset.
@@ -52,6 +108,8 @@ def apply_to_dataset(
52108
Metadata for the dataset that is accessible by the input analysis. Should also be dask-serializable.
53109
uproot_options: dict[str, Any], default {}
54110
Options to pass to uproot. Pass at least {"allow_read_errors_with_report": True} to turn on file access reports.
111+
parallelize_with_dask: bool, default False
112+
Create dask.delayed objects that will return the computable dask collections for the analysis when computed.
55113
56114
Returns
57115
-------
@@ -64,26 +122,32 @@ def apply_to_dataset(
64122
if maybe_base_form is not None:
65123
maybe_base_form = awkward.forms.from_json(decompress_form(maybe_base_form))
66124
files = dataset["files"]
67-
events = NanoEventsFactory.from_root(
125+
events_and_maybe_report = NanoEventsFactory.from_root(
68126
files,
69127
metadata=metadata,
70128
schemaclass=schemaclass,
71129
known_base_form=maybe_base_form,
72130
uproot_options=uproot_options,
73131
).events()
74132

75-
report = None
76-
if isinstance(events, tuple):
77-
events, report = events
78-
79-
out = None
133+
analysis = None
80134
if isinstance(data_manipulation, ProcessorABC):
81-
out = data_manipulation.process(events)
135+
analysis = data_manipulation.process
82136
elif isinstance(data_manipulation, Callable):
83-
out = data_manipulation(events)
137+
out = data_manipulation
84138
else:
85139
raise ValueError("data_manipulation must either be a ProcessorABC or Callable")
86140

141+
out = None
142+
if parallelize_with_dask:
143+
out = dask.delayed(partial(_apply_analysis, analysis, events_and_maybe_report))
144+
else:
145+
out = _apply_analysis(analysis, events_and_maybe_report)
146+
147+
report = None
148+
if isinstance(out, tuple):
149+
out, report = out
150+
87151
if report is not None:
88152
return out, report
89153
return out

0 commit comments

Comments
 (0)