@@ -155,7 +155,7 @@ def __init__(
155155 method : str = "standard" ,
156156 center : bool = True ,
157157 log_scale : Union [bool , float ] = False ,
158- log_zero_value : float = 0.0 ,
158+ log_zero_value : float = - np . inf ,
159159 coerce_positive : Union [float , bool ] = None ,
160160 eps : float = 1e-8 ,
161161 ):
@@ -167,13 +167,14 @@ def __init__(
167167 (scale using quantiles 0.25-0.75). Defaults to "standard".
168168 center (bool, optional): If to center the output to zero. Defaults to True.
169169 log_scale (bool, optional): If to take log of values. Defaults to False. Defaults to False.
170- log_zero_value (float, optional): Value to map 0 to for ``log_scale=True`` or in softplus. Defaults to 0.0
170+ log_zero_value (float, optional): Value to map 0 to for ``log_scale=True`` or in softplus. Defaults to -inf.
171171 coerce_positive (Union[bool, float, str], optional): If to coerce output to positive. Valid values:
172172 * None, i.e. is automatically determined and might change to True if all values are >= 0 (Default).
173173 * True, i.e. output is clamped at 0.
174174 * False, i.e. values are not coerced
175175 * float, i.e. softmax is applied with beta = coerce_positive.
176- eps (float, optional): Number for numerical stability of calcualtions. Defaults to 1e-8.
176+ eps (float, optional): Number for numerical stability of calcualtions.
177+ Defaults to 1e-8. For count data, 1.0 is recommended.
177178 """
178179 self .method = method
179180 assert method in ["standard" , "robust" ], f"method has invalid value { method } "
@@ -202,7 +203,7 @@ def get_parameters(self, *args, **kwargs) -> torch.Tensor:
202203 Returns:
203204 torch.Tensor: First element is center of data and second is scale
204205 """
205- return torch .tensor ([ self .center_ , self .scale_ ] )
206+ return torch .stack ([ torch . as_tensor ( self .center_ ), torch . as_tensor ( self .scale_ )], dim = - 1 )
206207
207208 def _preprocess_y (self , y : Union [pd .Series , np .ndarray , torch .Tensor ]) -> Union [np .ndarray , torch .Tensor ]:
208209 """
@@ -213,14 +214,11 @@ def _preprocess_y(self, y: Union[pd.Series, np.ndarray, torch.Tensor]) -> Union[
213214 Returns:
214215 Union[np.ndarray, torch.Tensor]: return rescaled series with type depending on input type
215216 """
216- if self .coerce_positive is None and not self .log_scale :
217- self .coerce_positive = (y >= 0 ).all ()
218-
219217 if self .log_scale :
220218 if isinstance (y , torch .Tensor ):
221- y = torch .log (y + self .log_zero_value )
219+ y = torch .log (y + self .log_zero_value + self . eps )
222220 else :
223- y = np .log (y + self .log_zero_value )
221+ y = np .log (y + self .log_zero_value + self . eps )
224222 return y
225223
226224 def fit (self , y : Union [pd .Series , np .ndarray , torch .Tensor ]):
@@ -233,53 +231,77 @@ def fit(self, y: Union[pd.Series, np.ndarray, torch.Tensor]):
233231 Returns:
234232 TorchNormalizer: self
235233 """
234+ if self .coerce_positive is None and not self .log_scale :
235+ self .coerce_positive = (y >= 0 ).all ()
236236 y = self ._preprocess_y (y )
237237
238238 if self .method == "standard" :
239239 if isinstance (y , torch .Tensor ):
240- self .center_ = torch .mean (y )
241- self .scale_ = torch .std (y ) / (self .center_ + self .eps )
240+ self .center_ = torch .mean (y , dim = - 1 ) + self .eps
241+ self .scale_ = torch .std (y , dim = - 1 ) + self .eps
242+ elif isinstance (y , np .ndarray ):
243+ self .center_ = np .mean (y , axis = - 1 ) + self .eps
244+ self .scale_ = np .std (y , axis = - 1 ) + self .eps
242245 else :
243- self .center_ = np .mean (y )
244- self .scale_ = np .std (y ) / ( self . center_ + self .eps )
246+ self .center_ = np .mean (y ) + self . eps
247+ self .scale_ = np .std (y ) + self .eps
245248
246249 elif self .method == "robust" :
247250 if isinstance (y , torch .Tensor ):
248- self .center_ = torch .median (y )
249- q_75 = y .kthvalue (int (len (y ) * 0.75 )).values
250- q_25 = y .kthvalue (int (len (y ) * 0.25 )).values
251+ self .center_ = torch .median (y , dim = - 1 ).values + self .eps
252+ q_75 = y .kthvalue (int (len (y ) * 0.75 ), dim = - 1 ).values
253+ q_25 = y .kthvalue (int (len (y ) * 0.25 ), dim = - 1 ).values
254+ elif isinstance (y , np .ndarray ):
255+ self .center_ = np .median (y , axis = - 1 ) + self .eps
256+ q_75 = np .percentiley (y , 75 , axis = - 1 )
257+ q_25 = np .percentiley (y , 25 , axis = - 1 )
251258 else :
252- self .center_ = np .median (y )
259+ self .center_ = np .median (y ) + self . eps
253260 q_75 = np .percentiley (y , 75 )
254261 q_25 = np .percentiley (y , 25 )
255- self .scale_ = (q_75 - q_25 ) / (self .center_ + self .eps ) / 2.0
262+ self .scale_ = (q_75 - q_25 ) / 2.0 + self .eps
263+ if not self .center :
264+ self .scale_ = self .center_
265+ if isinstance (y , torch .Tensor ):
266+ self .center_ = torch .zeros_like (self .center_ )
267+ else :
268+ self .center_ = np .zeros_like (self .center_ )
256269 return self
257270
258271 def transform (
259- self , y : Union [pd .Series , np .ndarray , torch .Tensor ], return_norm : bool = False
272+ self ,
273+ y : Union [pd .Series , np .ndarray , torch .Tensor ],
274+ return_norm : bool = False ,
275+ target_scale : torch .Tensor = None ,
260276 ) -> Union [Tuple [Union [np .ndarray , torch .Tensor ], np .ndarray ], Union [np .ndarray , torch .Tensor ]]:
261277 """
262278 Rescale data.
263279
264280 Args:
265281 y (Union[pd.Series, np.ndarray, torch.Tensor]): input data
266282 return_norm (bool, optional): [description]. Defaults to False.
283+ target_scale (torch.Tensor): target scale to use instead of fitted center and scale
267284
268285 Returns:
269286 Union[Tuple[Union[np.ndarray, torch.Tensor], np.ndarray], Union[np.ndarray, torch.Tensor]]: rescaled
270287 data with type depending on input type. returns second element if ``return_norm=True``
271288 """
272- if self .log_scale :
273- if isinstance (y , torch .Tensor ):
274- y = (y + self .log_zero_value + self .eps ).log ()
275- else :
276- y = np .log (y + self .log_zero_value + self .eps )
277- if self .center :
278- y = (y / (self .center_ + self .eps ) - 1 ) / (self .scale_ + self .eps )
279- else :
280- y = y / (self .center_ + self .eps )
289+ y = self ._preprocess_y (y )
290+ # get center and scale
291+ if target_scale is None :
292+ target_scale = self .get_parameters ().numpy ()[None , :]
293+ center = target_scale [..., 0 ]
294+ scale = target_scale [..., 1 ]
295+ if y .ndim > center .ndim : # multiple batches -> expand size
296+ center = center .view (* center .size (), * (1 ,) * (y .ndim - center .ndim ))
297+ scale = scale .view (* scale .size (), * (1 ,) * (y .ndim - scale .ndim ))
298+
299+ # transform
300+ y = (y - center ) / scale
301+
302+ # return with center and scale or without
281303 if return_norm :
282- return y , self . get_parameters (). numpy ()[ None , :]
304+ return y , target_scale
283305 else :
284306 return y
285307
@@ -303,6 +325,8 @@ def __call__(self, data: Dict[str, torch.Tensor]) -> torch.Tensor:
303325 data (Dict[str, torch.Tensor]): Dictionary with entries
304326 * prediction: data to de-scale
305327 * target_scale: center and scale of data
328+ scale_only (bool): if to only scale prediction and not center it (even if `self.center is True`).
329+ Defaults to False.
306330
307331 Returns:
308332 torch.Tensor: de-scaled data
@@ -315,10 +339,8 @@ def __call__(self, data: Dict[str, torch.Tensor]) -> torch.Tensor:
315339 norm = norm .unsqueeze (- 1 )
316340
317341 # transform
318- if self .center :
319- y_normed = (data ["prediction" ] * norm [:, 1 , None ] + 1 ) * norm [:, 0 , None ]
320- else :
321- y_normed = data ["prediction" ] * norm [:, 0 , None ]
342+ y_normed = data ["prediction" ] * norm [:, 1 , None ] + norm [:, 0 , None ]
343+
322344 if self .log_scale :
323345 y_normed = (y_normed .exp () - self .log_zero_value ).clamp_min (0.0 )
324346 elif isinstance (self .coerce_positive , bool ) and self .coerce_positive :
@@ -379,7 +401,8 @@ def __init__(
379401 * True, i.e. output is clamped at 0.
380402 * False, i.e. values are not coerced
381403 * float, i.e. softmax is applied with beta = coerce_positive.
382- eps (float, optional): Number for numerical stability of calcualtions. Defaults to 1e-8.
404+ eps (float, optional): Number for numerical stability of calcualtions.
405+ Defaults to 1e-8. For count data, 1.0 is recommended.
383406 """
384407 self .groups = groups
385408 self .scale_by_group = scale_by_group
@@ -403,24 +426,31 @@ def fit(self, y: pd.Series, X: pd.DataFrame):
403426 Returns:
404427 self
405428 """
429+ if self .coerce_positive is None and not self .log_scale :
430+ self .coerce_positive = (y >= 0 ).all ()
406431 y = self ._preprocess_y (y )
407432 if len (self .groups ) == 0 :
408433 assert not self .scale_by_group , "No groups are defined, i.e. `scale_by_group=[]`"
409434 if self .method == "standard" :
410- mean = np .mean (y )
411- self .norm_ = mean , np .std (y ) / (mean + self .eps )
435+ self .norm_ = [np .mean (y ) + self .eps , np .std (y ) + self .eps ] # center and scale
412436 else :
413437 quantiles = np .quantile (y , [0.25 , 0.5 , 0.75 ])
414- self .norm_ = quantiles [1 ], (quantiles [2 ] - quantiles [0 ]) / (quantiles [1 ] + self .eps )
438+ self .norm_ = [
439+ quantiles [1 ] + self .eps ,
440+ (quantiles [2 ] - quantiles [0 ]) / 2.0 + self .eps ,
441+ ] # center and scale
442+ if not self .center :
443+ self .norm_ [1 ] = self .norm_ [0 ]
444+ self .norm_ [0 ] = 0.0
415445
416446 elif self .scale_by_group :
417447 if self .method == "standard" :
418448 self .norm_ = {
419449 g : X [[g ]]
420450 .assign (y = y )
421451 .groupby (g , observed = True )
422- .agg (mean = ("y" , "mean" ), scale = ("y" , "std" ))
423- .assign (scale = lambda x : x . scale / ( x [ "mean" ] + self .eps ) )
452+ .agg (center = ("y" , "mean" ), scale = ("y" , "std" ))
453+ .assign (center = lambda x : x [ "center" ] + self . eps , scale = lambda x : x . scale + self .eps )
424454 for g in self .groups
425455 }
426456 else :
@@ -431,12 +461,20 @@ def fit(self, y: pd.Series, X: pd.DataFrame):
431461 .y .quantile ([0.25 , 0.5 , 0.75 ])
432462 .unstack (- 1 )
433463 .assign (
434- median = lambda x : x [0.5 ] + self .eps ,
435- scale = lambda x : (x [0.75 ] - x [0.25 ] + self . eps ) / ( x [ 0.5 ] + self .eps ) ,
436- )[["median " , "scale" ]]
464+ center = lambda x : x [0.5 ] + self .eps ,
465+ scale = lambda x : (x [0.75 ] - x [0.25 ]) / 2.0 + self .eps ,
466+ )[["center " , "scale" ]]
437467 for g in self .groups
438468 }
439469 # calculate missings
470+ if not self .center : # swap center and scale
471+
472+ def swap_parameters (norm ):
473+ norm ["scale" ] = norm ["center" ]
474+ norm ["center" ] = 0.0
475+ return norm
476+
477+ self .norm = {g : swap_parameters (norm ) for g , norm in self .norm_ .items ()}
440478 self .missing_ = {group : scales .median ().to_dict () for group , scales in self .norm_ .items ()}
441479
442480 else :
@@ -445,8 +483,8 @@ def fit(self, y: pd.Series, X: pd.DataFrame):
445483 X [self .groups ]
446484 .assign (y = y )
447485 .groupby (self .groups , observed = True )
448- .agg (mean = ("y" , "mean" ), scale = ("y" , "std" ))
449- .assign (scale = lambda x : x . scale / ( x [ "mean" ] + self .eps ) )
486+ .agg (center = ("y" , "mean" ), scale = ("y" , "std" ))
487+ .assign (center = lambda x : x [ "center" ] + self . eps , scale = lambda x : x . scale + self .eps )
450488 )
451489 else :
452490 self .norm_ = (
@@ -456,10 +494,13 @@ def fit(self, y: pd.Series, X: pd.DataFrame):
456494 .y .quantile ([0.25 , 0.5 , 0.75 ])
457495 .unstack (- 1 )
458496 .assign (
459- median = lambda x : x [0.5 ] + self .eps ,
460- scale = lambda x : (x [0.75 ] - x [0.25 ] + self . eps ) / ( x [ 0.5 ] + self .eps ) / 2.0 ,
461- )[["median " , "scale" ]]
497+ center = lambda x : x [0.5 ] + self .eps ,
498+ scale = lambda x : (x [0.75 ] - x [0.25 ]) / 2.0 + self .eps ,
499+ )[["center " , "scale" ]]
462500 )
501+ if not self .center : # swap center and scale
502+ self .norm_ ["scale" ] = self .norm_ ["center" ]
503+ self .norm_ ["center" ] = 0.0
463504 self .missing_ = self .norm_ .median ().to_dict ()
464505 return self
465506
@@ -471,10 +512,7 @@ def names(self) -> List[str]:
471512 Returns:
472513 List[str]: list of names
473514 """
474- if self .method == "standard" :
475- return ["mean" , "scale" ]
476- else :
477- return ["median" , "scale" ]
515+ return ["center" , "scale" ]
478516
479517 def fit_transform (
480518 self , y : pd .Series , X : pd .DataFrame , return_norm : bool = False
@@ -495,12 +533,12 @@ def fit_transform(
495533
496534 def inverse_transform (self , y : pd .Series , X : pd .DataFrame ):
497535 """
498- Rescaling data to original scale - not implemented.
536+ Rescaling data to original scale - not implemented - call class with target scale instead .
499537 """
500538 raise NotImplementedError ()
501539
502540 def transform (
503- self , y : pd .Series , X : pd .DataFrame , return_norm : bool = False
541+ self , y : pd .Series , X : pd .DataFrame = None , return_norm : bool = False , target_scale : torch . Tensor = None
504542 ) -> Union [np .ndarray , Tuple [np .ndarray , np .ndarray ]]:
505543 """
506544 Scale input data.
@@ -509,21 +547,16 @@ def transform(
509547 y (pd.Series): data to scale
510548 X (pd.DataFrame): dataframe with ``groups`` columns
511549 return_norm (bool, optional): If to return . Defaults to False.
550+ target_scale (torch.Tensor): target scale to use instead of fitted center and scale
512551
513552 Returns:
514553 Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]: Scaled data, if ``return_norm=True``, returns also scales
515554 as second element
516555 """
517- norm = self .get_norm (X )
518- y = self ._preprocess_y (y )
519- if self .center :
520- y_normed = (y / (norm [:, 0 ] + self .eps ) - 1 ) / (norm [:, 1 ] + self .eps )
521- else :
522- y_normed = y / (norm [:, 0 ] + self .eps )
523- if return_norm :
524- return y_normed , norm
525- else :
526- return y_normed
556+ if target_scale is None :
557+ assert X is not None , "either target_scale or X has to be passed"
558+ target_scale = self .get_norm (X )
559+ return super ().transform (y = y , return_norm = return_norm , target_scale = target_scale )
527560
528561 def get_parameters (self , groups : Union [torch .Tensor , list , tuple ], group_names : List [str ] = None ) -> np .ndarray :
529562 """
0 commit comments