Skip to content

Commit 6b9cda8

Browse files
committed
ressurect tests after tuple-out fix
1 parent e26f51f commit 6b9cda8

File tree

2 files changed

+24
-17
lines changed

2 files changed

+24
-17
lines changed

Diff for: src/coffea/dataset_tools/apply_processor.py

+14-11
Original file line numberDiff line numberDiff line change
@@ -139,16 +139,14 @@ def apply_to_dataset(
139139
out = None
140140
if parallelize_with_dask:
141141
(wired_events,) = _pack_meta_to_wire(events)
142-
out = (
143-
dask.delayed(
144-
lambda: lz4.frame.compress(
145-
cloudpickle.dumps(
146-
partial(_apply_analysis_wire, analysis, wired_events)()
147-
),
148-
compression_level=6,
149-
)
150-
)(),
151-
)
142+
out = dask.delayed(
143+
lambda: lz4.frame.compress(
144+
cloudpickle.dumps(
145+
partial(_apply_analysis_wire, analysis, wired_events)()
146+
),
147+
compression_level=6,
148+
)
149+
)()
152150
dask.base.function_cache.clear()
153151
else:
154152
out = analysis(events)
@@ -217,14 +215,15 @@ def apply_to_fileset(
217215
events[name], analyses_to_compute[name], report[name] = dataset_out
218216
elif len(dataset_out) == 2:
219217
events[name], analyses_to_compute[name] = dataset_out
218+
print(dataset_out)
220219
else:
221220
raise ValueError(
222221
"apply_to_dataset only returns (events, outputs) or (events, outputs, reports)"
223222
)
224223
elif isinstance(dataset_out, tuple) and len(dataset_out) == 3:
225224
events[name], out[name], report[name] = dataset_out
226225
elif isinstance(dataset_out, tuple) and len(dataset_out) == 2:
227-
events[name], out[name] = dataset_out[0]
226+
events[name], out[name] = dataset_out
228227
else:
229228
raise ValueError(
230229
"apply_to_dataset only returns (events, outputs) or (events, outputs, reports)"
@@ -238,6 +237,10 @@ def apply_to_fileset(
238237
)
239238
(out[name],) = _unpack_meta_from_wire(*dataset_out_wire)
240239

240+
for name in out:
241+
if isinstance(out[name], tuple) and len(out[name]) == 1:
242+
out[name] = out[name][0]
243+
241244
if len(report) > 0:
242245
return events, out, report
243246
return events, out

Diff for: tests/test_dataset_tools.py

+10-6
Original file line numberDiff line numberDiff line change
@@ -220,8 +220,8 @@ def test_tuple_data_manipulation_output(allow_read_errors_with_report):
220220

221221
if allow_read_errors_with_report:
222222
assert isinstance(out, tuple)
223-
assert len(out) == 2
224-
out, report = out
223+
assert len(out) == 3
224+
_, out, report = out
225225
assert isinstance(out, dict)
226226
assert isinstance(report, dict)
227227
assert out.keys() == {"ZJets", "Data"}
@@ -236,8 +236,10 @@ def test_tuple_data_manipulation_output(allow_read_errors_with_report):
236236
assert isinstance(report["ZJets"], dask_awkward.Array)
237237
assert isinstance(report["Data"], dask_awkward.Array)
238238
else:
239-
assert isinstance(out, dict)
239+
assert isinstance(out, tuple)
240240
assert len(out) == 2
241+
_, out = out
242+
assert isinstance(out, dict)
241243
assert out.keys() == {"ZJets", "Data"}
242244
assert isinstance(out["ZJets"], tuple)
243245
assert isinstance(out["Data"], tuple)
@@ -255,8 +257,8 @@ def test_tuple_data_manipulation_output(allow_read_errors_with_report):
255257

256258
if allow_read_errors_with_report:
257259
assert isinstance(out, tuple)
258-
assert len(out) == 2
259-
out, report = out
260+
assert len(out) == 3
261+
_, out, report = out
260262
assert isinstance(out, dict)
261263
assert isinstance(report, dict)
262264
assert out.keys() == {"ZJets", "Data"}
@@ -271,8 +273,10 @@ def test_tuple_data_manipulation_output(allow_read_errors_with_report):
271273
assert isinstance(report["ZJets"], dask_awkward.Array)
272274
assert isinstance(report["Data"], dask_awkward.Array)
273275
else:
274-
assert isinstance(out, dict)
276+
assert isinstance(out, tuple)
275277
assert len(out) == 2
278+
_, out = out
279+
assert isinstance(out, dict)
276280
assert out.keys() == {"ZJets", "Data"}
277281
assert isinstance(out["ZJets"], tuple)
278282
assert isinstance(out["Data"], tuple)

0 commit comments

Comments
 (0)