Skip to content

Commit 0cb2248

Browse files
author
Jan Beitner
committed
Ensure encoding of group ids is only done if needed
1 parent b8112ff commit 0cb2248

File tree

3 files changed

+67
-5
lines changed

3 files changed

+67
-5
lines changed

pytorch_forecasting/data/timeseries.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -427,7 +427,13 @@ def _preprocess_data(self, data: pd.DataFrame) -> pd.DataFrame:
427427
data[group_name] = self.transform_values(name, data[name], inverse=False, group_id=True)
428428

429429
# encode categoricals
430-
for name in set(self.group_ids + self.categoricals):
430+
if isinstance(
431+
self.target_normalizer, GroupNormalizer
432+
): # if we use a group normalizer, group_ids must be encoded as well
433+
group_ids_to_encode = self.group_ids
434+
else:
435+
group_ids_to_encode = []
436+
for name in set(group_ids_to_encode + self.categoricals):
431437
allow_nans = name in self.dropout_categoricals
432438
if name in self.variable_groups: # fit groups
433439
columns = self.variable_groups[name]
@@ -452,7 +458,7 @@ def _preprocess_data(self, data: pd.DataFrame) -> pd.DataFrame:
452458
self.categorical_encoders[name] = self.categorical_encoders[name].fit(data[name])
453459

454460
# encode them
455-
for name in set(self.group_ids + self.flat_categoricals):
461+
for name in set(group_ids_to_encode + self.flat_categoricals):
456462
data[name] = self.transform_values(name, data[name], inverse=False)
457463

458464
# save special variables
@@ -494,6 +500,10 @@ def _preprocess_data(self, data: pd.DataFrame) -> pd.DataFrame:
494500
data[self.target], scales = self.target_normalizer.transform(data[self.target], data, return_norm=True)
495501
elif isinstance(self.target_normalizer, NaNLabelEncoder):
496502
data[self.target] = self.target_normalizer.transform(data[self.target])
503+
data["__target__"] = data[
504+
self.target
505+
] # overwrite target because it requires encoding (continuous targets should not be normalized)
506+
scales = "no target scales available for categorical target"
497507
else:
498508
data[self.target], scales = self.target_normalizer.transform(data[self.target], return_norm=True)
499509

@@ -510,6 +520,8 @@ def _preprocess_data(self, data: pd.DataFrame) -> pd.DataFrame:
510520

511521
if self.target in self.reals:
512522
self.scalers[self.target] = self.target_normalizer
523+
else:
524+
self.categorical_encoders[self.target] = self.target_normalizer
513525

514526
# rescale continuous variables apart from target
515527
for name in self.reals:
@@ -830,8 +842,8 @@ def _construct_index(self, data: pd.DataFrame, predict_mode: bool) -> pd.DataFra
830842
if not group_ids.isin(df_index.group_id).all():
831843
missing_groups = data.loc[~group_ids.isin(df_index.group_id), self._group_ids].drop_duplicates()
832844
# decode values
833-
for name in missing_groups.columns:
834-
missing_groups[name] = self.transform_values(name, missing_groups[name], inverse=True, group_id=True)
845+
for name, id in self._group_ids_mapping.items():
846+
missing_groups[id] = self.transform_values(name, missing_groups[id], inverse=True, group_id=True)
835847
warnings.warn(
836848
"Min encoder length and/or min_prediction_idx and/or min prediction length is too large for "
837849
f"{len(missing_groups)} series/groups which therefore are not present in the dataset index. "

tests/conftest.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,18 @@
1111
from pytorch_forecasting.data.examples import get_stallion_data # isort:skip
1212

1313

14+
# for vscode debugging: https://stackoverflow.com/a/62563106/14121677
15+
if os.getenv("_PYTEST_RAISE", "0") != "0":
16+
17+
@pytest.hookimpl(tryfirst=True)
18+
def pytest_exception_interact(call):
19+
raise call.excinfo.value
20+
21+
@pytest.hookimpl(tryfirst=True)
22+
def pytest_internalerror(excinfo):
23+
raise excinfo.value
24+
25+
1426
@pytest.fixture
1527
def test_data():
1628
data = get_stallion_data()

tests/test_data.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -373,10 +373,48 @@ def test_categorical_target(test_data):
373373
min_encoder_length=1,
374374
)
375375

376-
x, y = next(iter(dataset.to_dataloader()))
376+
_, y = next(iter(dataset.to_dataloader()))
377377
assert y.dtype is torch.long, "target must be of type long"
378378

379379

380380
def test_pickle(test_dataset):
381381
pickle.dumps(test_dataset)
382382
pickle.dumps(test_dataset.to_dataloader())
383+
384+
385+
@pytest.mark.parametrize(
386+
"kwargs",
387+
[
388+
{},
389+
dict(
390+
target_normalizer=GroupNormalizer(
391+
groups=["agency", "sku"], log_scale=True, scale_by_group=True, log_zero_value=1.0
392+
),
393+
),
394+
],
395+
)
396+
def test_new_group_ids(test_data, kwargs):
397+
"""Test for new group ids in dataset"""
398+
train_agency = test_data["agency"].iloc[0]
399+
train_dataset = TimeSeriesDataSet(
400+
test_data[lambda x: x.agency == train_agency],
401+
time_idx="time_idx",
402+
target="volume",
403+
group_ids=["agency", "sku"],
404+
max_encoder_length=5,
405+
max_prediction_length=2,
406+
min_prediction_length=1,
407+
min_encoder_length=1,
408+
categorical_encoders=dict(agency=NaNLabelEncoder(add_nan=True), sku=NaNLabelEncoder(add_nan=True)),
409+
**kwargs,
410+
)
411+
412+
# test sampling from training dataset
413+
next(iter(train_dataset.to_dataloader()))
414+
415+
# create test dataset with group ids that have not been observed before
416+
test_dataset = TimeSeriesDataSet.from_dataset(train_dataset, test_data)
417+
418+
# check that we can iterate through dataset without error
419+
for _ in iter(test_dataset.to_dataloader()):
420+
pass

0 commit comments

Comments
 (0)