Skip to content

Commit ef0436d

Browse files
committed
Enhance SympyTransform to handle scalar results and empty symbols in transformations
1 parent e932b52 commit ef0436d

1 file changed

Lines changed: 12 additions & 4 deletions

File tree

AFL/double_agent/Preprocessor.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -782,20 +782,28 @@ def __init__(
782782
def calculate(self, dataset: xr.Dataset) -> Self:
783783
"""Apply this `PipelineOp` to the supplied `xarray.dataset`"""
784784
data = dataset[self.input_variable].transpose(self.sample_dim, self.component_dim)
785+
n = data.sizes[self.sample_dim]
785786

786787
# need to construct a dict of arrays
787788
comps = {k: v.squeeze().values for k, v in data.groupby(self.component_dim, squeeze=False)}
788789

790+
789791
# apply transform
790792
new_comps = xr.Dataset()
791793
for name, transform in self.transforms.items():
792794
transform = sympy.sympify(transform)
793795
symbols = list(transform.free_symbols)
794796
lam = sympy.lambdify(symbols, transform)
795-
new_comps[name] = (
796-
(self.sample_dim,),
797-
listify(lam(**{k.name: comps[k.name] for k in symbols})),
798-
)
797+
798+
if symbols:
799+
result = lam(**{k.name: comps[k.name] for k in symbols})
800+
else:
801+
result = transform.evalf()
802+
803+
if np.isscalar(result):
804+
result = np.full(n, float(result))
805+
806+
new_comps[name] = ((self.sample_dim,), listify(result))
799807

800808
new_comps = new_comps.to_array(self.transform_dim).transpose(..., self.transform_dim)
801809

0 commit comments

Comments
 (0)