Skip to content

Commit b986740

Browse files
committed
option to ignore NaN in CombineResultsRule
1 parent a46b203 commit b986740

2 files changed

Lines changed: 30 additions & 12 deletions

File tree

decoimpact/business/entities/rules/combine_results_rule.py

Lines changed: 29 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -38,13 +38,18 @@ def __init__(
3838
):
3939
super().__init__(name, input_variable_names)
4040
self._operation_type: MultiArrayOperationType = operation_type
41+
self._ignore_nan = ignore_nan
4142
self._operations = self._create_operations()
4243

4344
@property
4445
def operation_type(self) -> MultiArrayOperationType:
4546
"""Name of the rule"""
4647
return self._operation_type
4748

49+
@property
50+
def ignore_nan(self) -> bool:
51+
return self._ignore_nan
52+
4853
def validate(self, logger: ILogger) -> bool:
4954
if self._operation_type not in self._operations:
5055

@@ -90,17 +95,30 @@ def execute(
9095
return result_variable
9196

9297
def _create_operations(self) -> dict[MultiArrayOperationType, Callable]:
93-
return {
94-
MultiArrayOperationType.MULTIPLY: lambda npa: _np.prod(npa, axis=0),
95-
MultiArrayOperationType.MIN: lambda npa: _np.min(npa, axis=0),
96-
MultiArrayOperationType.MAX: lambda npa: _np.max(npa, axis=0),
97-
MultiArrayOperationType.AVERAGE: lambda npa: _np.average(npa, axis=0),
98-
MultiArrayOperationType.MEDIAN: lambda npa: _np.median(npa, axis=0),
99-
MultiArrayOperationType.ADD: lambda npa: _np.sum(npa, axis=0),
100-
MultiArrayOperationType.SUBTRACT: lambda npa: _np.subtract(
101-
npa[0], _np.sum(npa[1:], axis=0)
102-
),
103-
}
98+
if self.ignore_nan:
99+
return {
100+
MultiArrayOperationType.MULTIPLY: lambda npa: _np.prod(npa, axis=0),
101+
MultiArrayOperationType.MIN: lambda npa: _np.nanmin(npa, axis=0),
102+
MultiArrayOperationType.MAX: lambda npa: _np.nanmax(npa, axis=0),
103+
MultiArrayOperationType.AVERAGE: lambda npa: _np.nanmean(npa, axis=0),
104+
MultiArrayOperationType.MEDIAN: lambda npa: _np.nanmedian(npa, axis=0),
105+
MultiArrayOperationType.ADD: lambda npa: _np.nansum(npa, axis=0),
106+
MultiArrayOperationType.SUBTRACT: lambda npa: _np.subtract(
107+
npa[0], _np.nansum(npa[1:], axis=0)
108+
),
109+
}
110+
else:
111+
return {
112+
MultiArrayOperationType.MULTIPLY: lambda npa: _np.prod(npa, axis=0),
113+
MultiArrayOperationType.MIN: lambda npa: _np.min(npa, axis=0),
114+
MultiArrayOperationType.MAX: lambda npa: _np.max(npa, axis=0),
115+
MultiArrayOperationType.AVERAGE: lambda npa: _np.average(npa, axis=0),
116+
MultiArrayOperationType.MEDIAN: lambda npa: _np.median(npa, axis=0),
117+
MultiArrayOperationType.ADD: lambda npa: _np.sum(npa, axis=0),
118+
MultiArrayOperationType.SUBTRACT: lambda npa: _np.subtract(
119+
npa[0], _np.sum(npa[1:], axis=0)
120+
),
121+
}
104122

105123
def _check_dimensions(self, np_arrays: List[_np.ndarray]) -> bool:
106124
"""Brief check if all the arrays to be combined have the

tests/business/entities/rules/test_combine_results_rule.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,7 @@ def test_all_operations_incl_nan(
193193
(MultiArrayOperationType.MAX, [20, 12, 24]),
194194
(MultiArrayOperationType.MULTIPLY, [_np.nan, 420, 432]),
195195
(MultiArrayOperationType.AVERAGE, [12, 8, 11]),
196-
(MultiArrayOperationType.MEDIAN, [20, 7, 6]),
196+
(MultiArrayOperationType.MEDIAN, [12, 7, 6]),
197197
(MultiArrayOperationType.ADD, [24, 24, 33]),
198198
(MultiArrayOperationType.SUBTRACT, [16, -10, -27]),
199199
],

0 commit comments

Comments
 (0)