@@ -353,6 +353,24 @@ def _set_target_normalizer(self, data: pd.DataFrame):
353353 self .target_normalizer , (TorchNormalizer , NaNLabelEncoder )
354354 ), f"target_normalizer has to be either None or of class TorchNormalizer but found { self .target_normalizer } "
355355
356+ @property
357+ def _group_ids_mapping (self ) -> Dict [str , str ]:
358+ """
359+ Mapping of group id names to the names of the encoded columns used to identify series in the dataset -
360+ group ids can also be used for the target normalizer.
361+ The series-identifying encoding can change from training to validation and test dataset while the one used by the target normalizer must not.
362+ """
363+ return {name : f"__group_id__{ name } " for name in self .group_ids }
364+
365+ @property
366+ def _group_ids (self ) -> List [str ]:
367+ """
368+ Names of the derived group id columns (the values of the group id mapping) used to identify series in the dataset.
369+
370+ See :py:meth:`~TimeSeriesDataSet._group_ids_mapping` for details.
371+ """
372+ return list (self ._group_ids_mapping .values ())
373+
356374 def _validate_data (self , data : pd .DataFrame ):
357375 """
358376 Validate that data will not cause hiccups later on.
@@ -403,9 +421,13 @@ def _preprocess_data(self, data: pd.DataFrame) -> pd.DataFrame:
403421 Returns:
404422 pd.DataFrame: pre-processed dataframe
405423 """
424+ # encode group ids - this encoding is fitted on the data passed in and stored under a separate ``__group_id__`` name
425+ for name , group_name in self ._group_ids_mapping .items ():
426+ self .categorical_encoders [group_name ] = NaNLabelEncoder ().fit (data [name ].to_numpy ().reshape (- 1 ))
427+ data [group_name ] = self .transform_values (name , data [name ], inverse = False , group_id = True )
406428
407429 # encode categoricals
408- for name in set (self .categoricals + self .group_ids ):
430+ for name in set (self .group_ids + self .categoricals ):
409431 allow_nans = name in self .dropout_categoricals
410432 if name in self .variable_groups : # fit groups
411433 columns = self .variable_groups [name ]
@@ -430,7 +452,7 @@ def _preprocess_data(self, data: pd.DataFrame) -> pd.DataFrame:
430452 self .categorical_encoders [name ] = self .categorical_encoders [name ].fit (data [name ])
431453
432454 # encode them
433- for name in set (self .flat_categoricals + self .group_ids ):
455+ for name in set (self .group_ids + self .flat_categoricals ):
434456 data [name ] = self .transform_values (name , data [name ], inverse = False )
435457
436458 # save special variables
@@ -515,7 +537,12 @@ def _preprocess_data(self, data: pd.DataFrame) -> pd.DataFrame:
515537 return data
516538
517539 def transform_values (
518- self , name : str , values : Union [pd .Series , torch .Tensor , np .ndarray ], data : pd .DataFrame = None , inverse = False
540+ self ,
541+ name : str ,
542+ values : Union [pd .Series , torch .Tensor , np .ndarray ],
543+ data : pd .DataFrame = None ,
544+ inverse = False ,
545+ group_id : bool = False ,
519546 ) -> np .ndarray :
520547 """
521548 Scale and encode values.
@@ -526,12 +553,16 @@ def transform_values(
526553 data (pd.DataFrame, optional): extra data used for scaling (e.g. dataframe with groups columns).
527554 Defaults to None.
528555 inverse (bool, optional): if to conduct inverse transformation. Defaults to False.
556+ group_id (bool, optional): If the passed name refers to a group id (different encoders are used for these).
557+ Defaults to False.
529558
530559 Returns:
531560 np.ndarray: (de/en)coded/(de)scaled values
532561 """
562+ if group_id :
563+ name = self ._group_ids_mapping [name ]
533564 # remaining categories
534- if name in set (self .flat_categoricals + self .group_ids ):
565+ if name in set (self .flat_categoricals + self .group_ids + self . _group_ids ):
535566 name = self .variable_to_group_mapping .get (name , name ) # map name to encoder
536567 encoder = self .categorical_encoders [name ]
537568 if encoder is None :
@@ -575,7 +606,7 @@ def _data_to_tensors(self, data: pd.DataFrame) -> Dict[str, torch.Tensor]:
575606 time index
576607 """
577608
578- index = torch .tensor (data [self .group_ids ].to_numpy (np .long ), dtype = torch .long )
609+ index = torch .tensor (data [self ._group_ids ].to_numpy (np .long ), dtype = torch .long )
579610 time = torch .tensor (data ["__time_idx__" ].to_numpy (np .long ), dtype = torch .long )
580611
581612 categorical = torch .tensor (data [self .flat_categoricals ].to_numpy (np .long ), dtype = torch .long )
@@ -735,7 +766,7 @@ def _construct_index(self, data: pd.DataFrame, predict_mode: bool) -> pd.DataFra
735766 Returns:
736767 pd.DataFrame: index dataframe
737768 """
738- g = data .groupby (self .group_ids , observed = True )
769+ g = data .groupby (self ._group_ids , observed = True )
739770
740771 df_index_first = g ["__time_idx__" ].transform ("nth" , 0 ).to_frame ("time_first" )
741772 df_index_last = g ["__time_idx__" ].transform ("nth" , - 1 ).to_frame ("time_last" )
@@ -797,10 +828,10 @@ def _construct_index(self, data: pd.DataFrame, predict_mode: bool) -> pd.DataFra
797828
798829 # check that all groups/series have at least one entry in the index
799830 if not group_ids .isin (df_index .group_id ).all ():
800- missing_groups = data .loc [~ group_ids .isin (df_index .group_id ), self .group_ids ].drop_duplicates ()
831+ missing_groups = data .loc [~ group_ids .isin (df_index .group_id ), self ._group_ids ].drop_duplicates ()
801832 # decode values
802833 for name in missing_groups .columns :
803- missing_groups [name ] = self .transform_values (name , missing_groups [name ], inverse = True )
834+ missing_groups [name ] = self .transform_values (name , missing_groups [name ], inverse = True , group_id = True )
804835 warnings .warn (
805836 "Min encoder length and/or min_prediction_idx and/or min prediction length is too large for "
806837 f"{ len (missing_groups )} series/groups which therefore are not present in the dataset index. "
@@ -1210,7 +1241,7 @@ def x_to_index(self, x: Dict[str, torch.Tensor]) -> pd.DataFrame:
12101241 for id in self .group_ids :
12111242 index_data [id ] = x ["groups" ][:, self .group_ids .index (id )].cpu ()
12121243 # decode if possible
1213- index_data [id ] = self .transform_values (id , index_data [id ], inverse = True )
1244+ index_data [id ] = self .transform_values (id , index_data [id ], inverse = True , group_id = True )
12141245 index = pd .DataFrame (index_data )
12151246 return index
12161247
0 commit comments