@@ -17,7 +17,7 @@
 )
 from coffea.nanoevents import BaseSchema, NanoAODSchema, NanoEventsFactory
 from coffea.processor import ProcessorABC
-from coffea.util import decompress_form
+from coffea.util import decompress_form, load, save

 DaskOutputBaseType = Union[
     dask.base.DaskMethodsMixin,
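For context on the new imports: `coffea.util.save` and `coffea.util.load` are coffea's existing helpers for pickling analysis objects to disk, and the `save_taskgraph`/`load_taskgraph` functions added below build on them. A minimal sketch of their behavior (file name and payload are illustrative):

```python
from coffea.util import load, save

# save(output, filename) pickles an arbitrary payload; load(filename) reads it back.
save({"cutflow": {"all": 1000, "selected": 42}}, "products.coffea")
payload = load("products.coffea")
assert payload["cutflow"]["selected"] == 42
```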
@@ -48,8 +48,6 @@ def _pack_meta_to_wire(*collections):
                 attrs=unpacked[i]._meta.attrs,
             )
     packed_out = repacker(output)
-    if len(packed_out) == 1:
-        return packed_out[0]
     return packed_out


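With the length-one unwrap gone, `_pack_meta_to_wire` always returns the repacked sequence of collections, so call sites destructure explicitly instead of branching on the return type. A sketch of the resulting calling convention (`array_a` and `array_b` stand in for dask-awkward collections):

```python
# One collection in: a one-element sequence out, unpacked at the call site.
(wired,) = _pack_meta_to_wire(array_a)

# Several collections follow the same code path; no isinstance or length checks.
wired_a, wired_b = _pack_meta_to_wire(array_a, array_b)
```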
@@ -68,21 +66,13 @@ def _unpack_meta_from_wire(*collections):
                 attrs=unpacked[i]._meta.attrs,
             )
     packed_out = repacker(output)
-    if len(packed_out) == 1:
-        return packed_out[0]
     return packed_out


-def _apply_analysis_wire(analysis, events_and_maybe_report_wire):
-    events = _unpack_meta_from_wire(events_and_maybe_report_wire)
-    report = None
-    if isinstance(events, tuple):
-        events, report = events
+def _apply_analysis_wire(analysis, events_wire):
+    (events,) = _unpack_meta_from_wire(events_wire)
     events._meta.attrs["@original_array"] = events
-
     out = analysis(events)
-    if report is not None:
-        return _pack_meta_to_wire(out, report)
     return _pack_meta_to_wire(out)


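`_apply_analysis_wire` now expects exactly one wired events collection and no longer threads an optional report through the wire format. A hedged round-trip sketch, where `my_analysis` and the `Muon` field are placeholders:

```python
import dask_awkward

def my_analysis(events):
    # Placeholder selection: keep events with at least two muons.
    return events[dask_awkward.num(events.Muon, axis=1) >= 2]

(wired_events,) = _pack_meta_to_wire(events)  # make ._meta safe to pickle
out_wire = _apply_analysis_wire(my_analysis, wired_events)
(out,) = _unpack_meta_from_wire(*out_wire)    # rebuild usable ._meta on receipt
```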
@@ -145,16 +135,14 @@ def apply_to_dataset(

     out = None
     if parallelize_with_dask:
-        if not isinstance(events_and_maybe_report, tuple):
-            events_and_maybe_report = (events_and_maybe_report,)
-        wired_events = _pack_meta_to_wire(*events_and_maybe_report)
+        (wired_events,) = _pack_meta_to_wire(events)
         out = dask.delayed(partial(_apply_analysis_wire, analysis, wired_events))()
     else:
         out = analysis(events)

     if report is not None:
-        return out, report
-    return out
+        return events, out, report
+    return events, out


 def apply_to_fileset(
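Since `apply_to_dataset` now always returns the input events first, two-tuple unpacking at call sites needs a corresponding update. A hedged usage sketch (the dataset spec and the uproot option name are assumptions drawn from coffea's dataset tools, not from this diff):

```python
events, out, report = apply_to_dataset(
    my_analysis,
    dataset,  # placeholder dataset spec, e.g. {"files": {...}}
    uproot_options={"allow_read_errors_with_report": True},  # assumed option name
)
# The report is only meaningful when computed together with the output.
computed_out, computed_report = dask.compute(out, report)
```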
@@ -184,11 +172,14 @@ def apply_to_fileset(

     Returns
     -------
+    events: dict[str, dask_awkward.Array]
+        The NanoEvents objects the analysis function was applied to.
     out : dict[str, DaskOutputType]
         The output of the analysis workflow applied to the datasets, keyed by dataset name.
     report : dask_awkward.Array, optional
         The file access report for running the analysis on the input dataset. Needs to be computed simultaneously with the analysis to be accurate.
     """
+    events = {}
     out = {}
     analyses_to_compute = {}
     report = {}
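The same contract now applies fileset-wide, with every return value keyed by dataset name. A hedged example of the updated call (fileset construction elided; option name assumed as above):

```python
events, out, report = apply_to_fileset(
    my_analysis,
    fileset,  # placeholder: {"dataset_name": {"files": {...}}, ...}
    uproot_options={"allow_read_errors_with_report": True},  # assumed option name
)
computed_out, computed_report = dask.compute(out, report)  # dicts keyed by dataset
```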
@@ -206,24 +197,92 @@ def apply_to_fileset(
             parallelize_with_dask,
         )
         if parallelize_with_dask:
-            analyses_to_compute[name] = dataset_out
-        elif isinstance(dataset_out, tuple):
-            out[name], report[name] = dataset_out
+            if len(dataset_out) == 3:
+                events[name], analyses_to_compute[name], report[name] = dataset_out
+            elif len(dataset_out) == 2:
+                events[name], analyses_to_compute[name] = dataset_out
+            else:
+                raise ValueError(
+                    "apply_to_dataset only returns (events, outputs) or (events, outputs, reports)"
+                )
+        elif isinstance(dataset_out, tuple) and len(dataset_out) == 3:
+            events[name], out[name], report[name] = dataset_out
+        elif isinstance(dataset_out, tuple) and len(dataset_out) == 2:
+            events[name], out[name] = dataset_out
         else:
-            out[name] = dataset_out
+            raise ValueError(
+                "apply_to_dataset only returns (events, outputs) or (events, outputs, reports)"
+            )

     if parallelize_with_dask:
         (calculated_graphs,) = dask.compute(analyses_to_compute, scheduler=scheduler)
         for name, dataset_out_wire in calculated_graphs.items():
-            to_unwire = dataset_out_wire
-            if not isinstance(dataset_out_wire, tuple):
-                to_unwire = (dataset_out_wire,)
-            dataset_out = _unpack_meta_from_wire(*to_unwire)
-            if isinstance(dataset_out, tuple):
-                out[name], report[name] = dataset_out
-            else:
-                out[name] = dataset_out
+            (out[name],) = _unpack_meta_from_wire(*dataset_out_wire)

     if len(report) > 0:
-        return out, report
-    return out
+        return events, out, report
+    return events, out
+
+
+def save_taskgraph(filename, events, *data_products, optimize_graph=False):
+    """
+    Save a task graph and its originating nanoevents to a file.
+    Parameters
+    ----------
+    filename: str
+        Where to save the resulting serialized task graph and nanoevents.
+        Suggested suffix: ".hlg", after dask's HighLevelGraph object.
+    events: dict[str, dask_awkward.Array]
+        A dictionary of nanoevents objects.
+    data_products: dict[str, DaskOutputBaseType]
+        The data products resulting from applying an analysis to
+        a NanoEvents object. This may include report objects.
+    optimize_graph: bool, default False
+        Whether or not to save the task graph in its optimized form.
+
+    Returns
+    -------
+    """
+    (events_wire,) = _pack_meta_to_wire(events)
+
+    if len(data_products) == 0:
+        raise ValueError(
+            "You must supply at least one analysis data product to save a task graph!"
+        )
+
+    data_products_out = data_products
+    if optimize_graph:
+        data_products_out = dask.optimize(data_products)
+
+    data_products_wire = _pack_meta_to_wire(*data_products_out)
+
+    save(
+        {
+            "events": events_wire,
+            "data_products": data_products_wire,
+            "optimized": optimize_graph,
+        },
+        filename,
+    )
+
+
+def load_taskgraph(filename):
+    """
+    Load a task graph and its originating nanoevents from a file.
+    Parameters
+    ----------
+    filename: str
+        The file from which to load the task graph.
+    Returns
+    -------
+    """
+    graph_information_wire = load(filename)
+
+    (events,) = _unpack_meta_from_wire(graph_information_wire["events"])
+    (data_products,) = _unpack_meta_from_wire(*graph_information_wire["data_products"])
+    optimized = graph_information_wire["optimized"]
+
+    for dataset_name in events:
+        events[dataset_name]._meta.attrs["@original_array"] = events[dataset_name]
+
+    return events, data_products, optimized
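Putting the pieces together, a hedged end-to-end sketch of the save/load round trip (the fileset, the analysis, and the import path for the new helpers are assumptions; the ".hlg" suffix follows the docstring's suggestion):

```python
import dask
from coffea.dataset_tools import apply_to_fileset  # assumed export location
# Assumed module path for the helpers added in this diff:
from coffea.dataset_tools.apply_processor import load_taskgraph, save_taskgraph

events, out = apply_to_fileset(my_analysis, fileset)  # placeholders as above

# Persist the task graph together with its originating NanoEvents.
save_taskgraph("analysis.hlg", events, out)

# Later, possibly in another process: rebuild and execute the saved graph.
events, out, optimized = load_taskgraph("analysis.hlg")
(result,) = dask.compute(out)
```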