Skip to content

Commit b8112ff

Browse files
author
Jan Beitner
committed
Re-encode group ids by dataset to identify series
1 parent c7574bc commit b8112ff

File tree

1 file changed

+40
-9
lines changed

1 file changed

+40
-9
lines changed

pytorch_forecasting/data/timeseries.py

Lines changed: 40 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -353,6 +353,24 @@ def _set_target_normalizer(self, data: pd.DataFrame):
353353
self.target_normalizer, (TorchNormalizer, NaNLabelEncoder)
354354
), f"target_normalizer has to be either None or of class TorchNormalizer but found {self.target_normalizer}"
355355

356+
@property
357+
def _group_ids_mapping(self) -> Dict[str, str]:
358+
"""
359+
Mapping of group id names to group ids used to identify series in dataset -
360+
group ids can also be used for target normalizer.
361+
The former can change from training to validation and test dataset while the latter must not.
362+
"""
363+
return {name: f"__group_id__{name}" for name in self.group_ids}
364+
365+
@property
366+
def _group_ids(self) -> List[str]:
367+
"""
368+
Group ids used to identify series in dataset.
369+
370+
See :py:meth:`~TimeSeriesDataSet._group_ids_mapping` for details.
371+
"""
372+
return list(self._group_ids_mapping.values())
373+
356374
def _validate_data(self, data: pd.DataFrame):
357375
"""
358376
Validate that data will not cause hiccups later on.
@@ -403,9 +421,13 @@ def _preprocess_data(self, data: pd.DataFrame) -> pd.DataFrame:
403421
Returns:
404422
pd.DataFrame: pre-processed dataframe
405423
"""
424+
# encode group ids - this encoding is fitted on this dataset and is only used to identify series
425+
for name, group_name in self._group_ids_mapping.items():
426+
self.categorical_encoders[group_name] = NaNLabelEncoder().fit(data[name].to_numpy().reshape(-1))
427+
data[group_name] = self.transform_values(name, data[name], inverse=False, group_id=True)
406428

407429
# encode categoricals
408-
for name in set(self.categoricals + self.group_ids):
430+
for name in set(self.group_ids + self.categoricals):
409431
allow_nans = name in self.dropout_categoricals
410432
if name in self.variable_groups: # fit groups
411433
columns = self.variable_groups[name]
@@ -430,7 +452,7 @@ def _preprocess_data(self, data: pd.DataFrame) -> pd.DataFrame:
430452
self.categorical_encoders[name] = self.categorical_encoders[name].fit(data[name])
431453

432454
# encode them
433-
for name in set(self.flat_categoricals + self.group_ids):
455+
for name in set(self.group_ids + self.flat_categoricals):
434456
data[name] = self.transform_values(name, data[name], inverse=False)
435457

436458
# save special variables
@@ -515,7 +537,12 @@ def _preprocess_data(self, data: pd.DataFrame) -> pd.DataFrame:
515537
return data
516538

517539
def transform_values(
518-
self, name: str, values: Union[pd.Series, torch.Tensor, np.ndarray], data: pd.DataFrame = None, inverse=False
540+
self,
541+
name: str,
542+
values: Union[pd.Series, torch.Tensor, np.ndarray],
543+
data: pd.DataFrame = None,
544+
inverse=False,
545+
group_id: bool = False,
519546
) -> np.ndarray:
520547
"""
521548
Scale and encode values.
@@ -526,12 +553,16 @@ def transform_values(
526553
data (pd.DataFrame, optional): extra data used for scaling (e.g. dataframe with groups columns).
527554
Defaults to None.
528555
inverse (bool, optional): if to conduct inverse transformation. Defaults to False.
556+
group_id (bool, optional): If the passed name refers to a group id (different encoders are used for these).
557+
Defaults to False.
529558
530559
Returns:
531560
np.ndarray: (de/en)coded/(de)scaled values
532561
"""
562+
if group_id:
563+
name = self._group_ids_mapping[name]
533564
# remaining categories
534-
if name in set(self.flat_categoricals + self.group_ids):
565+
if name in set(self.flat_categoricals + self.group_ids + self._group_ids):
535566
name = self.variable_to_group_mapping.get(name, name) # map name to encoder
536567
encoder = self.categorical_encoders[name]
537568
if encoder is None:
@@ -575,7 +606,7 @@ def _data_to_tensors(self, data: pd.DataFrame) -> Dict[str, torch.Tensor]:
575606
time index
576607
"""
577608

578-
index = torch.tensor(data[self.group_ids].to_numpy(np.long), dtype=torch.long)
609+
index = torch.tensor(data[self._group_ids].to_numpy(np.long), dtype=torch.long)
579610
time = torch.tensor(data["__time_idx__"].to_numpy(np.long), dtype=torch.long)
580611

581612
categorical = torch.tensor(data[self.flat_categoricals].to_numpy(np.long), dtype=torch.long)
@@ -735,7 +766,7 @@ def _construct_index(self, data: pd.DataFrame, predict_mode: bool) -> pd.DataFra
735766
Returns:
736767
pd.DataFrame: index dataframe
737768
"""
738-
g = data.groupby(self.group_ids, observed=True)
769+
g = data.groupby(self._group_ids, observed=True)
739770

740771
df_index_first = g["__time_idx__"].transform("nth", 0).to_frame("time_first")
741772
df_index_last = g["__time_idx__"].transform("nth", -1).to_frame("time_last")
@@ -797,10 +828,10 @@ def _construct_index(self, data: pd.DataFrame, predict_mode: bool) -> pd.DataFra
797828

798829
# check that all groups/series have at least one entry in the index
799830
if not group_ids.isin(df_index.group_id).all():
800-
missing_groups = data.loc[~group_ids.isin(df_index.group_id), self.group_ids].drop_duplicates()
831+
missing_groups = data.loc[~group_ids.isin(df_index.group_id), self._group_ids].drop_duplicates()
801832
# decode values
802833
for name in missing_groups.columns:
803-
missing_groups[name] = self.transform_values(name, missing_groups[name], inverse=True)
834+
missing_groups[name] = self.transform_values(name, missing_groups[name], inverse=True, group_id=True)
804835
warnings.warn(
805836
"Min encoder length and/or min_prediction_idx and/or min prediction length is too large for "
806837
f"{len(missing_groups)} series/groups which therefore are not present in the dataset index. "
@@ -1210,7 +1241,7 @@ def x_to_index(self, x: Dict[str, torch.Tensor]) -> pd.DataFrame:
12101241
for id in self.group_ids:
12111242
index_data[id] = x["groups"][:, self.group_ids.index(id)].cpu()
12121243
# decode if possible
1213-
index_data[id] = self.transform_values(id, index_data[id], inverse=True)
1244+
index_data[id] = self.transform_values(id, index_data[id], inverse=True, group_id=True)
12141245
index = pd.DataFrame(index_data)
12151246
return index
12161247

0 commit comments

Comments
 (0)