Skip to content

Commit 5f138d9

Browse files
committed
fix save train result bug
1 parent 7dc02c8 commit 5f138d9

File tree

3 files changed

+8
-14
lines changed

3 files changed

+8
-14
lines changed

paddlets/models/forecasting/dl/PatchTST.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -367,12 +367,6 @@ def _update_fit_params(
367367
"known_cov_dim": 0,
368368
"observed_cov_dim": 0
369369
}
370-
if train_tsdataset[0].get_known_cov() is not None:
371-
fit_params["known_cov_dim"] = train_tsdataset[0].get_known_cov(
372-
).data.shape[1]
373-
if train_tsdataset[0].get_observed_cov() is not None:
374-
fit_params["observed_cov_dim"] = train_tsdataset[
375-
0].get_observed_cov().data.shape[1]
376370
return fit_params
377371

378372
def _init_network(self) -> paddle.nn.Layer:

paddlets/models/forecasting/dl/RLinear.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -221,12 +221,12 @@ def _update_fit_params(
221221
"known_cov_dim": 0,
222222
"observed_cov_dim": 0
223223
}
224-
if train_tsdataset[0].get_known_cov() is not None:
225-
fit_params["known_cov_dim"] = train_tsdataset[0].get_known_cov(
226-
).data.shape[1]
227-
if train_tsdataset[0].get_observed_cov() is not None:
228-
fit_params["observed_cov_dim"] = train_tsdataset[
229-
0].get_observed_cov().data.shape[1]
224+
#if train_tsdataset[0].get_known_cov() is not None:
225+
# fit_params["known_cov_dim"] = train_tsdataset[0].get_known_cov(
226+
# ).data.shape[1]
227+
#if train_tsdataset[0].get_observed_cov() is not None:
228+
# fit_params["observed_cov_dim"] = train_tsdataset[
229+
# 0].get_observed_cov().data.shape[1]
230230
return fit_params
231231

232232
def _init_network(self) -> paddle.nn.Layer:

paddlets/utils/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -535,12 +535,12 @@ def update_train_results(save_path, score, model_name="", done_flag=True):
535535
train_results["models"]["best"]["score"] = score
536536
for tag in save_model_tag:
537537
train_results["models"]["best"][
538-
tag] = "" if tag != "pdparams" else os.path.join("best_model",
539-
"model.pdparams")
538+
tag] = "" if tag != "pdparams" else "best_accuracy.pdparams.tar"
540539
for tag in save_inference_tag:
541540
train_results["models"]["best"][tag] = os.path.join(
542541
"inference", f"inference.{tag}"
543542
if tag != "inference_config" else "inference.yml")
543+
train_results["models"]["best"]["pdiparams"] = "best_accuracy.pdparams.tar"
544544

545545
train_results = convert_and_remove_types(train_results)
546546
with open(train_results_path, "w") as fp:

0 commit comments

Comments
 (0)