Skip to content

Commit fabba0a

Browse files
Fix/inverse transform for static covariate with single category across series (#2710)
* fix: bug when several static covariates have only one category * feat: added corresponding tests * update changelog * update changelog --------- Co-authored-by: dennisbader <[email protected]>
1 parent 62dca8f commit fabba0a

File tree

3 files changed

+39
-18
lines changed

3 files changed

+39
-18
lines changed

CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ but cannot always guarantee backwards compatibility. Changes that may **break co
2121
**Fixed**
2222

2323
- 🔴 / 🟢 Fixed a bug which raised an error when loading torch models that were saved with Darts versions < 0.33.0. This is a breaking change and models saved with version 0.33.0 will not be loadable anymore. [#2692](https://github.com/unit8co/darts/pull/2692) by [Dennis Bader](https://github.com/dennisbader).
24+
- Fixed a bug in `StaticCovariatesTransformer` which raised an error when trying to inverse transform one-hot encoded categorical static covariates with identical values across time-series. Each one-hot encoded column of a categorical static covariate is now named `{covariate_name}_{category_name}`, regardless of the number of categories. [#2710](https://github.com/unit8co/darts/pull/2710) by [Antoine Madrona](https://github.com/madtoinou)
2425
- Fixed a bug in `13-TFT-examples.ipynb` where two calls to `TimeSeries.from_series()` were not providing `series` but `pd.Index`. The method calls were changed to `TimeSeries.from_values()`. [#2719](https://github.com/unit8co/darts/pull/2719) by [Jules Authier](https://github.com/authierj)
2526

2627
**Dependencies**

darts/dataprocessing/transformers/static_covariates_transformer.py

+12-18
Original file line numberDiff line numberDiff line change
@@ -300,12 +300,8 @@ def _create_category_mappings(
300300
for col, categories in zip(cols_cat, transformer_cat.categories_):
301301
col_map_cat_i = []
302302
for cat in categories:
303-
col_map_cat_i.append(cat)
304-
if len(categories) > 1:
305-
cat_col_name = str(col) + "_" + str(cat)
306-
inv_col_map_cat[cat_col_name] = [col]
307-
else:
308-
inv_col_map_cat[cat] = [col]
303+
col_map_cat_i.append(str(col) + "_" + str(cat))
304+
inv_col_map_cat[str(col) + "_" + str(cat)] = [col]
309305
col_map_cat[col] = col_map_cat_i
310306
# If we don't have any categorical static covariates, don't need to generate mapping:
311307
else:
@@ -393,16 +389,6 @@ def _transform_static_covs(
393389
series, mask_num, mask_cat
394390
)
395391

396-
# Transform static covs:
397-
tr_out_num, tr_out_cat = None, None
398-
if mask_num.any():
399-
tr_out_num = getattr(transformer_num, method)(vals_num)
400-
if mask_cat.any():
401-
tr_out_cat = getattr(transformer_cat, method)(vals_cat)
402-
# sparse one hot encoding to dense array
403-
if isinstance(tr_out_cat, csr_matrix):
404-
tr_out_cat = tr_out_cat.toarray()
405-
406392
# quick check if everything is in order
407393
n_vals_cat_cols = 0 if vals_cat is None else vals_cat.shape[1]
408394
if (method == "inverse_transform") and (n_vals_cat_cols != n_cat_cols):
@@ -413,6 +399,16 @@ def _transform_static_covs(
413399
logger,
414400
)
415401

402+
# Transform static covs:
403+
tr_out_num, tr_out_cat = None, None
404+
if mask_num.any():
405+
tr_out_num = getattr(transformer_num, method)(vals_num)
406+
if mask_cat.any():
407+
tr_out_cat = getattr(transformer_cat, method)(vals_cat)
408+
# sparse one hot encoding to dense array
409+
if isinstance(tr_out_cat, csr_matrix):
410+
tr_out_cat = tr_out_cat.toarray()
411+
416412
series = StaticCovariatesTransformer._add_back_static_covs(
417413
series, tr_out_num, tr_out_cat, mask_num, mask_cat, col_map_cat
418414
)
@@ -458,8 +454,6 @@ def _add_back_static_covs(
458454
elif is_cat: # categorical transformed column
459455
# covers one to one feature map (ordinal/label encoding) and one to multi feature (one hot encoding)
460456
for col_name in col_map_cat[col]:
461-
if len(col_map_cat[col]) > 1:
462-
col_name = str(col) + "_" + str(col_name)
463457
if col_name not in static_cov_columns:
464458
data[col_name] = vals_cat[:, idx_cat]
465459
static_cov_columns.append(col_name)

darts/tests/dataprocessing/transformers/test_static_covariates_transformer.py

+26
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,32 @@ def test_scaling_multi_series(self):
184184
series_recovered_multi[1].static_covariates
185185
)
186186

187+
def test_zero_cardinality_multi_series(self):
188+
"""Check that inverse-transform works as expected when OneHotEncoder is used on several series with
189+
identical static covariates categories and values.
190+
"""
191+
ts1 = self.series.with_static_covariates(
192+
pd.Series({
193+
"cov_a": "foo",
194+
"cov_b": "foo",
195+
"cov_c": "foo",
196+
})
197+
)
198+
ts2 = self.series.with_static_covariates(
199+
pd.Series({
200+
"cov_a": "foo",
201+
"cov_b": "foo",
202+
"cov_c": "bar",
203+
})
204+
)
205+
206+
transformer = StaticCovariatesTransformer(transformer_cat=OneHotEncoder())
207+
transformer.fit([ts1, ts2])
208+
ts1_enc, ts2_enc = transformer.transform([ts1, ts2])
209+
ts1_inv, ts2_inv = transformer.inverse_transform([ts1_enc, ts2_enc])
210+
pd.testing.assert_frame_equal(ts1_inv.static_covariates, ts1.static_covariates)
211+
pd.testing.assert_frame_equal(ts2_inv.static_covariates, ts2.static_covariates)
212+
187213
def helper_test_scaling(self, series, scaler, test_values):
188214
series_tr = scaler.fit_transform(series)
189215
assert all([

0 commit comments

Comments
 (0)