
Commit fa1399a

provide interface for serializing taskgraphs to/from disk
1 parent 22c8bd8 commit fa1399a

4 files changed (+158, -44 lines)


Diff for: src/coffea/dataset_tools/__init__.py

+8 -1

@@ -1,4 +1,9 @@
-from coffea.dataset_tools.apply_processor import apply_to_dataset, apply_to_fileset
+from coffea.dataset_tools.apply_processor import (
+    apply_to_dataset,
+    apply_to_fileset,
+    load_taskgraph,
+    save_taskgraph,
+)
 from coffea.dataset_tools.manipulations import (
     filter_files,
     get_failed_steps_for_dataset,
@@ -14,6 +19,8 @@
     "preprocess",
     "apply_to_dataset",
     "apply_to_fileset",
+    "save_taskgraph",
+    "load_taskgraph",
     "max_chunks",
     "slice_chunks",
     "filter_files",

Diff for: src/coffea/dataset_tools/apply_processor.py

+91 -32

@@ -17,7 +17,7 @@
 )
 from coffea.nanoevents import BaseSchema, NanoAODSchema, NanoEventsFactory
 from coffea.processor import ProcessorABC
-from coffea.util import decompress_form
+from coffea.util import decompress_form, load, save

 DaskOutputBaseType = Union[
     dask.base.DaskMethodsMixin,
@@ -48,8 +48,6 @@ def _pack_meta_to_wire(*collections):
             attrs=unpacked[i]._meta.attrs,
         )
     packed_out = repacker(output)
-    if len(packed_out) == 1:
-        return packed_out[0]
     return packed_out


@@ -68,21 +66,13 @@ def _unpack_meta_from_wire(*collections):
             attrs=unpacked[i]._meta.attrs,
         )
     packed_out = repacker(output)
-    if len(packed_out) == 1:
-        return packed_out[0]
     return packed_out


-def _apply_analysis_wire(analysis, events_and_maybe_report_wire):
-    events = _unpack_meta_from_wire(events_and_maybe_report_wire)
-    report = None
-    if isinstance(events, tuple):
-        events, report = events
+def _apply_analysis_wire(analysis, events_wire):
+    (events,) = _unpack_meta_from_wire(events_wire)
     events._meta.attrs["@original_array"] = events
-
     out = analysis(events)
-    if report is not None:
-        return _pack_meta_to_wire(out, report)
     return _pack_meta_to_wire(out)


@@ -145,16 +135,14 @@ def apply_to_dataset(

     out = None
     if parallelize_with_dask:
-        if not isinstance(events_and_maybe_report, tuple):
-            events_and_maybe_report = (events_and_maybe_report,)
-        wired_events = _pack_meta_to_wire(*events_and_maybe_report)
+        (wired_events,) = _pack_meta_to_wire(events)
         out = dask.delayed(partial(_apply_analysis_wire, analysis, wired_events))()
     else:
         out = analysis(events)

     if report is not None:
-        return out, report
-    return out
+        return events, out, report
+    return events, out


 def apply_to_fileset(
@@ -184,11 +172,14 @@ def apply_to_fileset(

     Returns
     -------
+    events: dict[str, dask_awkward.Array]
+        The NanoEvents objects the analysis function was applied to.
     out : dict[str, DaskOutputType]
         The output of the analysis workflow applied to the datasets, keyed by dataset name.
     report : dask_awkward.Array, optional
         The file access report for running the analysis on the input dataset. Needs to be computed in simultaneously with the analysis to be accurate.
     """
+    events = {}
     out = {}
     analyses_to_compute = {}
     report = {}
@@ -206,24 +197,92 @@ def apply_to_fileset(
             parallelize_with_dask,
         )
         if parallelize_with_dask:
-            analyses_to_compute[name] = dataset_out
-        elif isinstance(dataset_out, tuple):
-            out[name], report[name] = dataset_out
+            if len(dataset_out) == 3:
+                events[name], analyses_to_compute[name], report[name] = dataset_out
+            elif len(dataset_out) == 2:
+                events[name], analyses_to_compute[name] = dataset_out
+            else:
+                raise ValueError(
+                    "apply_to_dataset only returns (events, outputs) or (events, outputs, reports)"
+                )
+        elif isinstance(dataset_out, tuple) and len(dataset_out) == 3:
+            events[name], out[name], report[name] = dataset_out
+        elif isinstance(dataset_out, tuple) and len(dataset_out) == 2:
+            events[name], out[name] = dataset_out
         else:
-            out[name] = dataset_out
+            raise ValueError(
+                "apply_to_dataset only returns (events, outputs) or (events, outputs, reports)"
+            )

     if parallelize_with_dask:
         (calculated_graphs,) = dask.compute(analyses_to_compute, scheduler=scheduler)
         for name, dataset_out_wire in calculated_graphs.items():
-            to_unwire = dataset_out_wire
-            if not isinstance(dataset_out_wire, tuple):
-                to_unwire = (dataset_out_wire,)
-            dataset_out = _unpack_meta_from_wire(*to_unwire)
-            if isinstance(dataset_out, tuple):
-                out[name], report[name] = dataset_out
-            else:
-                out[name] = dataset_out
+            (out[name],) = _unpack_meta_from_wire(*dataset_out_wire)

     if len(report) > 0:
-        return out, report
-    return out
+        return events, out, report
+    return events, out
+
+
+def save_taskgraph(filename, events, *data_products, optimize_graph=False):
+    """
+    Save a task graph and its originating nanoevents to a file
+    Parameters
+    ----------
+        filename: str
+            Where to save the resulting serialized taskgraph and nanoevents.
+            Suggested postfix ".hlg", after dask's HighLevelGraph object.
+        events: dict[str, dask_awkward.Array]
+            A dictionary of nanoevents objects.
+        data_products: dict[str, DaskOutputBaseType]
+            The data products resulting from applying an analysis to
+            a NanoEvents object. This may include report objects.
+        optimize_graph: bool, default False
+            Whether or not to save the task graph in its optimized form.
+
+    Returns
+    -------
+    """
+    (events_wire,) = _pack_meta_to_wire(events)
+
+    if len(data_products) == 0:
+        raise ValueError(
+            "You must supply at least one analysis data product to save a task graph!"
+        )
+
+    data_products_out = data_products
+    if optimize_graph:
+        data_products_out = dask.optimize(data_products)
+
+    data_products_wire = _pack_meta_to_wire(*data_products_out)
+
+    save(
+        {
+            "events": events_wire,
+            "data_products": data_products_wire,
+            "optimized": optimize_graph,
+        },
+        filename,
+    )
+
+
+def load_taskgraph(filename):
+    """
+    Load a task graph and its originating nanoevents from a file.
+    Parameters
+    ----------
+        filename: str
+            The file from which to load the task graph.
+    Returns
+    _______
+    """
+    graph_information_wire = load(filename)
+
+    (events,) = _unpack_meta_from_wire(graph_information_wire["events"])
+    (data_products,) = _unpack_meta_from_wire(*graph_information_wire["data_products"])
+    optimized = graph_information_wire["optimized"]
+
+    for dataset_name in events:
+        events[dataset_name]._meta.attrs["@original_array"] = events[dataset_name]
+
+    return events, data_products, optimized
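
Taken together, apply_to_fileset now hands back the per-dataset NanoEvents next to the analysis outputs, and the still-lazy task graph can be round-tripped through disk. A usage sketch, assuming a hypothetical ProcessorABC subclass MyProcessor and an already-preprocessed preprocessed_fileset (both placeholders, not part of this commit):

import dask

from coffea.dataset_tools import apply_to_fileset, load_taskgraph, save_taskgraph
from coffea.nanoevents import NanoAODSchema

# apply_to_fileset now returns the NanoEvents it built, keyed by dataset name,
# ahead of the analysis outputs (and a report dict when one is requested).
events, outputs = apply_to_fileset(
    MyProcessor(),         # hypothetical analysis processor
    preprocessed_fileset,  # hypothetical output of preprocess()
    schemaclass=NanoAODSchema,
)

# Persist the un-computed task graph together with its originating events.
save_taskgraph("analysis.hlg", events, outputs)

# Later, possibly in another process: restore and compute.
events, data_products, was_optimized = load_taskgraph("analysis.hlg")
(computed,) = dask.compute(data_products)  # computed: dict of per-dataset results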

Diff for: src/coffea/util.py

+5 -6

@@ -36,21 +36,20 @@
 import lz4.frame


-def load(filename):
+def load(filename, mode="rb"):
     """Load a coffea file from disk"""
-    with lz4.frame.open(filename) as fin:
+    with lz4.frame.open(filename, mode) as fin:
         output = cloudpickle.load(fin)
     return output


-def save(output, filename):
+def save(output, filename, mode="wb"):
     """Save a coffea object or collection thereof to disk

     This function can accept any picklable object. Suggested suffix: ``.coffea``
     """
-    with lz4.frame.open(filename, "wb") as fout:
-        thepickle = cloudpickle.dumps(output)
-        fout.write(thepickle)
+    with lz4.frame.open(filename, mode) as fout:
+        cloudpickle.dump(output, fout)


 def _hex(string):
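
The widened signatures keep the previous defaults, so existing call sites are unaffected; save now streams through cloudpickle.dump rather than materializing the whole pickle in memory first. A minimal round-trip sketch (the file name is illustrative):

from coffea.util import load, save

save({"counts": [1, 2, 3]}, "example.coffea")  # lz4-framed pickle, mode="wb" by default
restored = load("example.coffea")              # mode="rb" by default
assert restored == {"counts": [1, 2, 3]}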

Diff for: tests/test_dataset_tools.py

+54 -5

@@ -7,9 +7,11 @@
     apply_to_fileset,
     filter_files,
     get_failed_steps_for_fileset,
+    load_taskgraph,
     max_chunks,
     max_files,
     preprocess,
+    save_taskgraph,
     slice_chunks,
     slice_files,
 )
@@ -202,7 +204,7 @@ def test_apply_to_fileset(proc_and_schema, delayed_taskgraph_calc):
     proc, schemaclass = proc_and_schema

     with Client() as _:
-        to_compute = apply_to_fileset(
+        _, to_compute = apply_to_fileset(
             proc(),
             _runnable_result,
             schemaclass=schemaclass,
@@ -215,7 +217,7 @@ def test_apply_to_fileset(proc_and_schema, delayed_taskgraph_calc):
         assert out["Data"]["cutflow"]["Data_pt"] == 84
         assert out["Data"]["cutflow"]["Data_mass"] == 66

-        to_compute = apply_to_fileset(
+        _, to_compute = apply_to_fileset(
             proc(),
             max_chunks(_runnable_result, 1),
             schemaclass=schemaclass,
@@ -240,7 +242,7 @@ def test_apply_to_fileset_hinted_form():
             save_form=True,
         )

-        to_compute = apply_to_fileset(
+        _, to_compute = apply_to_fileset(
             NanoEventsProcessor(),
             dataset_runnable,
             schemaclass=NanoAODSchema,
@@ -445,14 +447,14 @@ def test_slice_chunks():
 @pytest.mark.parametrize("delayed_taskgraph_calc", [True, False])
 def test_recover_failed_chunks(delayed_taskgraph_calc):
     with Client() as _:
-        to_compute = apply_to_fileset(
+        _, to_compute, reports = apply_to_fileset(
             NanoEventsProcessor(),
             _starting_fileset_with_steps,
             schemaclass=NanoAODSchema,
             uproot_options={"allow_read_errors_with_report": True},
             parallelize_with_dask=delayed_taskgraph_calc,
         )
-        out, reports = dask.compute(*to_compute)
+        out, reports = dask.compute(to_compute, reports)

         failed_fset = get_failed_steps_for_fileset(_starting_fileset_with_steps, reports)
         assert failed_fset == {
@@ -474,3 +476,50 @@ def test_recover_failed_chunks(delayed_taskgraph_calc):
                 }
             }
         }
+
+
+@pytest.mark.parametrize(
+    "proc_and_schema",
+    [(NanoTestProcessor, BaseSchema), (NanoEventsProcessor, NanoAODSchema)],
+)
+@pytest.mark.parametrize(
+    "with_report",
+    [True, False],
+)
+def test_task_graph_serialization(proc_and_schema, with_report):
+    proc, schemaclass = proc_and_schema
+
+    with Client() as _:
+        output = apply_to_fileset(
+            proc(),
+            _runnable_result,
+            schemaclass=schemaclass,
+            parallelize_with_dask=False,
+            uproot_options={"allow_read_errors_with_report": with_report},
+        )
+
+        events = output[0]
+        to_compute = output[1:]
+
+        save_taskgraph(
+            "./test_task_graph_serialization.hlg",
+            events,
+            to_compute,
+            optimize_graph=False,
+        )
+
+        _, to_compute_serdes, is_optimized = load_taskgraph(
+            "./test_task_graph_serialization.hlg"
+        )
+
+        print(to_compute_serdes)
+
+        if len(to_compute_serdes) > 1:
+            (out, _) = dask.compute(*to_compute_serdes)
+        else:
+            (out,) = dask.compute(*to_compute_serdes)
+
+        assert out["ZJets"]["cutflow"]["ZJets_pt"] == 18
+        assert out["ZJets"]["cutflow"]["ZJets_mass"] == 6
+        assert out["Data"]["cutflow"]["Data_pt"] == 84
+        assert out["Data"]["cutflow"]["Data_mass"] == 66
