Skip to content

Commit ec1006c

Browse files
committed
get _ensure_systematics daskified
1 parent 1a5c7ca commit ec1006c

File tree

1 file changed

+36
-8
lines changed
  • src/coffea/nanoevents/methods

1 file changed

+36
-8
lines changed

Diff for: src/coffea/nanoevents/methods/base.py

+36-8
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,12 @@
1818
class _ClassMethodFn:
1919
def __init__(self, attr: str, **kwargs: Any) -> None:
2020
self.attr = attr
21+
self.kwargs = kwargs
2122

2223
def __call__(self, coll: awkward.Array, *args: Any, **kwargs: Any) -> awkward.Array:
23-
return getattr(coll, self.attr)(*args, **kwargs)
24+
allkwargs = self.kwargs
25+
allkwargs.update(kwargs)
26+
return getattr(coll, self.attr)(*args, **allkwargs)
2427

2528

2629
@awkward.mixin_class(behavior)
@@ -36,12 +39,35 @@ def add_kind(cls, kind: str):
3639
"""
3740
cls._systematic_kinds.add(kind)
3841

39-
def _ensure_systematics(self):
42+
def _ensure_systematics(self, _dask_array_=None):
4043
"""
4144
Make sure that the parent object always has a field called '__systematics__'.
4245
"""
4346
if "__systematics__" not in awkward.fields(self):
44-
self["__systematics__"] = {}
47+
if _dask_array_ is not None:
48+
x = awkward.Array(
49+
awkward.Array([{}]).layout.to_typetracer(forget_length=True)
50+
)
51+
_dask_array_._meta["__systematics__"] = x
52+
53+
def add_systematics_hack(array):
54+
if awkward.backend(array) == "typetracer":
55+
array["__systematics__"] = x
56+
return array
57+
array["__systematics__"] = {}
58+
return array
59+
60+
temp = dask_awkward.map_partitions(
61+
add_systematics_hack,
62+
_dask_array_,
63+
label="ensure-systematics",
64+
meta=_dask_array_._meta,
65+
)
66+
_dask_array_._meta = temp._meta
67+
_dask_array_._dask = temp._dask
68+
_dask_array_._name = temp._name
69+
else:
70+
self["__systematics__"] = {}
4571

4672
@property
4773
def systematics(self):
@@ -109,11 +135,13 @@ def add_systematic(
109135
print("vf ", varying_function)
110136
print("da ", _dask_array_, type(_dask_array_))
111137
_dask_array_.map_partitions(
112-
_ClassMethodFn("add_systematic"),
113-
name,
114-
kind,
115-
what,
116-
varying_function,
138+
_ClassMethodFn(
139+
"add_systematic",
140+
name=name,
141+
kind=kind,
142+
varying_function=varying_function,
143+
),
144+
what=what,
117145
)
118146

119147
self._ensure_systematics()

0 commit comments

Comments
 (0)