Skip to content

Commit f76dd9f

Browse files
committed
fix save train result bug
1 parent 7dc02c8 commit f76dd9f

File tree

2 files changed

+1
-8
lines changed

2 files changed

+1
-8
lines changed

paddlets/models/forecasting/dl/PatchTST.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -367,12 +367,6 @@ def _update_fit_params(
367367
"known_cov_dim": 0,
368368
"observed_cov_dim": 0
369369
}
370-
if train_tsdataset[0].get_known_cov() is not None:
371-
fit_params["known_cov_dim"] = train_tsdataset[0].get_known_cov(
372-
).data.shape[1]
373-
if train_tsdataset[0].get_observed_cov() is not None:
374-
fit_params["observed_cov_dim"] = train_tsdataset[
375-
0].get_observed_cov().data.shape[1]
376370
return fit_params
377371

378372
def _init_network(self) -> paddle.nn.Layer:

paddlets/utils/utils.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -535,8 +535,7 @@ def update_train_results(save_path, score, model_name="", done_flag=True):
535535
train_results["models"]["best"]["score"] = score
536536
for tag in save_model_tag:
537537
train_results["models"]["best"][
538-
tag] = "" if tag != "pdparams" else os.path.join("best_model",
539-
"model.pdparams")
538+
tag] = "" if tag != "pdparams" else "best_accuracy.pdparams.tar"
540539
for tag in save_inference_tag:
541540
train_results["models"]["best"][tag] = os.path.join(
542541
"inference", f"inference.{tag}"

0 commit comments

Comments
 (0)