Skip to content

Commit 213ba81

Browse files
authored
Update torch_forecasting_model.py (#4)
1 parent a184762 commit 213ba81

File tree

1 file changed

+1
-6
lines changed

1 file changed

+1
-6
lines changed

darts/models/forecasting/torch_forecasting_model.py

+1-6
Original file line numberDiff line numberDiff line change
@@ -2048,7 +2048,6 @@ def export_onnx(self, path: Optional[str] = None, **onnx_kwargs) -> None:
20482048
logger=logger,
20492049
)
20502050

2051-
dim_component = self.past_covariate_series.n_components
20522051
(
20532052
past_target,
20542053
past_covariates,
@@ -2057,13 +2056,9 @@ def export_onnx(self, path: Optional[str] = None, **onnx_kwargs) -> None:
20572056
# I think these have to do with future covariates (which isn't supported in Dlinear)
20582057
) = [torch.Tensor(x).unsqueeze(0) if x is not None else None for x in self.train_sample]
20592058

2060-
n_past_covs = (
2061-
past_covariates.shape[dim_component] if past_covariates is not None else 0
2062-
)
2063-
20642059
input_past = torch.cat(
20652060
[ds for ds in [past_target, past_covariates] if ds is not None],
2066-
dim=dim_component,
2061+
dim=2, # Shape is (1, lookback_size, no. of variates (in either target or series))
20672062
)
20682063

20692064
input_sample = [input_past.float(), static_covariates.float() if static_covariates is not None else None]

0 commit comments

Comments
 (0)