Skip to content

Commit 4ad23e4

Browse files
committed
ressurect eager-mode systematics handling
1 parent 13ae5e1 commit 4ad23e4

File tree

4 files changed

+31
-28
lines changed

4 files changed

+31
-28
lines changed

Diff for: coffea/nanoevents/methods/base.py

+9-5
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def _ensure_systematics(self):
4040
Make sure that the parent object always has a field called '__systematics__'.
4141
"""
4242
if "__systematics__" not in awkward.fields(self):
43-
self["__systematics__"] = {}
43+
self["__systematics__"] = awkward.Array(len(self) * [{}])
4444

4545
@property
4646
def systematics(self):
@@ -122,23 +122,27 @@ def add_systematic(
122122
if what == "weight" and "__ones__" not in awkward.fields(
123123
flat["__systematics__"]
124124
):
125-
flat["__systematics__", "__ones__"] = numpy.ones(
126-
len(flat), dtype=numpy.float32
127-
)
125+
fields = awkward.fields(flat["__systematics__"])
126+
as_dict = {field: flat["__systematics__", field] for field in fields}
127+
as_dict["__ones__"] = numpy.ones(len(flat), dtype=numpy.float32)
128+
flat["__systematics__"] = awkward.zip(as_dict, depth_limit=1)
128129

129130
rendered_type = flat.layout.parameters["__record__"]
130131
as_syst_type = awkward.with_parameter(flat, "__record__", kind)
131132
as_syst_type._build_variations(name, what, varying_function)
132133
variations = as_syst_type.describe_variations()
133134

134-
flat["__systematics__", name] = awkward.zip(
135+
fields = awkward.fields(flat["__systematics__"])
136+
as_dict = {field: flat["__systematics__", field] for field in fields}
137+
as_dict[name] = awkward.zip(
135138
{
136139
v: getattr(as_syst_type, v)(name, what, rendered_type)
137140
for v in variations
138141
},
139142
depth_limit=1,
140143
with_name=f"{name}Systematics",
141144
)
145+
flat["__systematics__"] = awkward.zip(as_dict, depth_limit=1)
142146

143147
self["__systematics__"] = wrap(flat["__systematics__"])
144148
self.behavior[("__typestr__", f"{name}Systematics")] = f"{kind}"

Diff for: coffea/nanoevents/methods/systematics/UpDownSystematic.py

+17-15
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,14 @@ def _build_variations(self, name, what, varying_function, *args, **kwargs):
1616
self[what] if what != "weight" else self["__systematics__", "__ones__"]
1717
)
1818

19-
self["__systematics__", f"__{name}__"] = awkward.virtual(
20-
varying_function,
21-
args=(whatarray, *args),
22-
kwargs=kwargs,
23-
length=len(whatarray),
19+
fields = awkward.fields(self["__systematics__"])
20+
as_dict = {field: self["__systematics__", field] for field in fields}
21+
as_dict[f"__{name}__"] = varying_function(
22+
whatarray,
23+
*args,
24+
**kwargs,
2425
)
26+
self["__systematics__"] = awkward.zip(as_dict, depth_limit=1)
2527

2628
def describe_variations(self):
2729
"""Show the map of variation names to indices."""
@@ -53,20 +55,20 @@ def get_variation(self, name, what, astype, updown):
5355

5456
def up(self, name, what, astype):
5557
"""Return the "up" variation of this observable."""
56-
return awkward.virtual(
57-
self.get_variation,
58-
args=(name, what, astype, "up"),
59-
length=len(self),
60-
parameters=self[what].layout.parameters if what != "weight" else None,
58+
return self.get_variation(
59+
name,
60+
what,
61+
astype,
62+
"up",
6163
)
6264

6365
def down(self, name, what, astype):
6466
"""Return the "down" variation of this observable."""
65-
return awkward.virtual(
66-
self.get_variation,
67-
args=(name, what, astype, "down"),
68-
length=len(self),
69-
parameters=self[what].layout.parameters if what != "weight" else None,
67+
return self.get_variation(
68+
name,
69+
what,
70+
astype,
71+
"down",
7072
)
7173

7274

Diff for: coffea/nanoevents/schemas/nanoaod.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -291,7 +291,7 @@ def _build_collections(self, field_names, input_contents):
291291
output[name].setdefault("parameters", {})
292292
output[name]["parameters"].update({"collection_name": name})
293293

294-
return output.keys(), output.values()
294+
return list(output.keys()), list(output.values())
295295

296296
@property
297297
def behavior(self):

Diff for: coffea/util.py

+4-7
Original file line numberDiff line numberDiff line change
@@ -163,15 +163,12 @@ def deprecate(exception, version, date=None):
163163

164164
# re-nest a record array into a ListArray
165165
def awkward_rewrap(arr, like_what, gfunc):
166-
behavior = awkward._util.behaviorof(like_what)
167166
func = partial(gfunc, data=arr.layout)
168-
layout = awkward.operations.convert.to_layout(like_what)
169-
newlayout = awkward._util.recursively_apply(layout, func)
170-
return awkward._util.wrap(newlayout, behavior=behavior)
167+
return awkward.transform(func, like_what, behavior=like_what.behavior)
171168

172169

173170
# we're gonna assume that the first record array we encounter is the flattened data
174-
def rewrap_recordarray(layout, depth, data):
175-
if isinstance(layout, awkward.layout.RecordArray):
176-
return lambda: data
171+
def rewrap_recordarray(layout, depth, data, **kwargs):
172+
if isinstance(layout, awkward.contents.RecordArray):
173+
return data
177174
return None

0 commit comments

Comments
 (0)