@@ -38,13 +38,18 @@ def __init__(
3838 ):
3939 super ().__init__ (name , input_variable_names )
4040 self ._operation_type : MultiArrayOperationType = operation_type
41+ self ._ignore_nan = ignore_nan
4142 self ._operations = self ._create_operations ()
4243
4344 @property
4445 def operation_type (self ) -> MultiArrayOperationType :
4546 """Name of the rule"""
4647 return self ._operation_type
4748
49+ @property
50+ def ignore_nan (self ) -> bool :
51+ return self ._ignore_nan
52+
4853 def validate (self , logger : ILogger ) -> bool :
4954 if self ._operation_type not in self ._operations :
5055
@@ -90,17 +95,30 @@ def execute(
9095 return result_variable
9196
9297 def _create_operations (self ) -> dict [MultiArrayOperationType , Callable ]:
93- return {
94- MultiArrayOperationType .MULTIPLY : lambda npa : _np .prod (npa , axis = 0 ),
95- MultiArrayOperationType .MIN : lambda npa : _np .min (npa , axis = 0 ),
96- MultiArrayOperationType .MAX : lambda npa : _np .max (npa , axis = 0 ),
97- MultiArrayOperationType .AVERAGE : lambda npa : _np .average (npa , axis = 0 ),
98- MultiArrayOperationType .MEDIAN : lambda npa : _np .median (npa , axis = 0 ),
99- MultiArrayOperationType .ADD : lambda npa : _np .sum (npa , axis = 0 ),
100- MultiArrayOperationType .SUBTRACT : lambda npa : _np .subtract (
101- npa [0 ], _np .sum (npa [1 :], axis = 0 )
102- ),
103- }
98+ if self .ignore_nan :
99+ return {
100+ MultiArrayOperationType .MULTIPLY : lambda npa : _np .prod (npa , axis = 0 ),
101+ MultiArrayOperationType .MIN : lambda npa : _np .nanmin (npa , axis = 0 ),
102+ MultiArrayOperationType .MAX : lambda npa : _np .nanmax (npa , axis = 0 ),
103+ MultiArrayOperationType .AVERAGE : lambda npa : _np .nanmean (npa , axis = 0 ),
104+ MultiArrayOperationType .MEDIAN : lambda npa : _np .nanmedian (npa , axis = 0 ),
105+ MultiArrayOperationType .ADD : lambda npa : _np .nansum (npa , axis = 0 ),
106+ MultiArrayOperationType .SUBTRACT : lambda npa : _np .subtract (
107+ npa [0 ], _np .nansum (npa [1 :], axis = 0 )
108+ ),
109+ }
110+ else :
111+ return {
112+ MultiArrayOperationType .MULTIPLY : lambda npa : _np .prod (npa , axis = 0 ),
113+ MultiArrayOperationType .MIN : lambda npa : _np .min (npa , axis = 0 ),
114+ MultiArrayOperationType .MAX : lambda npa : _np .max (npa , axis = 0 ),
115+ MultiArrayOperationType .AVERAGE : lambda npa : _np .average (npa , axis = 0 ),
116+ MultiArrayOperationType .MEDIAN : lambda npa : _np .median (npa , axis = 0 ),
117+ MultiArrayOperationType .ADD : lambda npa : _np .sum (npa , axis = 0 ),
118+ MultiArrayOperationType .SUBTRACT : lambda npa : _np .subtract (
119+ npa [0 ], _np .sum (npa [1 :], axis = 0 )
120+ ),
121+ }
104122
105123 def _check_dimensions (self , np_arrays : List [_np .ndarray ]) -> bool :
106124 """Brief check if all the arrays to be combined have the
0 commit comments