From 6364b25c65c11e766dd760110063492b928c0c25 Mon Sep 17 00:00:00 2001
From: unknown
Date: Sun, 4 Aug 2024 21:25:10 +0200
Subject: [PATCH 1/2] Compatibility with newer optuna version

---
 .../temporal_fusion_transformer/tuning.py | 19 ++++++++++---------
 1 file changed, 10 insertions(+), 9 deletions(-)

diff --git a/pytorch_forecasting/models/temporal_fusion_transformer/tuning.py b/pytorch_forecasting/models/temporal_fusion_transformer/tuning.py
index 1344c0b78..80103c17a 100644
--- a/pytorch_forecasting/models/temporal_fusion_transformer/tuning.py
+++ b/pytorch_forecasting/models/temporal_fusion_transformer/tuning.py
@@ -1,6 +1,7 @@
 """
 Hyperparameters can be efficiently tuned with `optuna <https://optuna.readthedocs.io/>`_.
 """
+
 import copy
 import logging
 import os
@@ -25,9 +26,13 @@
 optuna_logger = logging.getLogger("optuna")
 
 
-# need to inherit from callback for this to work
-class PyTorchLightningPruningCallbackAdjusted(PyTorchLightningPruningCallback, pl.Callback):
-    pass
+# In older optuna versions, PyTorchLightningPruningCallback did not inherit from pl.Callback.
+# In newer versions it does, and inheriting from it together with pl.Callback again raises a TypeError.
+if not issubclass(PyTorchLightningPruningCallback, pl.Callback):
+    class PyTorchLightningPruningCallbackAdjusted(PyTorchLightningPruningCallback, pl.Callback):
+        pass
+else:
+    PyTorchLightningPruningCallbackAdjusted = PyTorchLightningPruningCallback
 
 
 def optimize_hyperparameters(
@@ -108,16 +113,12 @@
     optuna_verbose = logging_level[verbose]
     optuna.logging.set_verbosity(optuna_verbose)
 
-    loss = kwargs.get(
-        "loss", QuantileLoss()
-    )  # need a deepcopy of loss as it will otherwise propagate from one trial to the next
+    loss = kwargs.get("loss", QuantileLoss())  # need a deepcopy of loss as it will otherwise propagate from one trial to the next
 
     # create objective function
     def objective(trial: optuna.Trial) -> float:
         # Filenames for each trial must be made unique in order to access each checkpoint.
-        checkpoint_callback = ModelCheckpoint(
-            dirpath=os.path.join(model_path, "trial_{}".format(trial.number)), filename="{epoch}", monitor="val_loss"
-        )
+        checkpoint_callback = ModelCheckpoint(dirpath=os.path.join(model_path, "trial_{}".format(trial.number)), filename="{epoch}", monitor="val_loss")
         learning_rate_callback = LearningRateMonitor()
         logger = TensorBoardLogger(log_dir, name="optuna", version=trial.number)

From 9d5eb777ab48cafd7180e00bccf7f9479d579791 Mon Sep 17 00:00:00 2001
From: unknown
Date: Sun, 4 Aug 2024 21:25:52 +0200
Subject: [PATCH 2/2] Fix typo

---
 examples/stallion.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/examples/stallion.py b/examples/stallion.py
index 55584cf5c..958511ba7 100644
--- a/examples/stallion.py
+++ b/examples/stallion.py
@@ -90,7 +90,7 @@
 
 
 # save datasets
-training.save("t raining.pkl")
+training.save("training.pkl")
 validation.save("validation.pkl")
 
 early_stop_callback = EarlyStopping(monitor="val_loss", min_delta=1e-4, patience=10, verbose=False, mode="min")
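
Note on patch 1: the fix is an import-time version shim around optuna's pruning callback. Below is a minimal standalone sketch of the same pattern; the two import lines are illustrative assumptions (tuning.py resolves pl and PyTorchLightningPruningCallback through its own imports, and newer optuna setups may ship the callback via the separate optuna-integration package):

    import pytorch_lightning as pl
    from optuna.integration import PyTorchLightningPruningCallback

    if not issubclass(PyTorchLightningPruningCallback, pl.Callback):
        # Older optuna: the pruning callback is not a pl.Callback, so graft
        # pl.Callback on for Lightning's Trainer to accept it in callbacks=[...].
        class PyTorchLightningPruningCallbackAdjusted(PyTorchLightningPruningCallback, pl.Callback):
            pass
    else:
        # Newer optuna: the class already subclasses pl.Callback; per the patch
        # comment, re-inheriting raises a TypeError, so reuse the class as-is.
        PyTorchLightningPruningCallbackAdjusted = PyTorchLightningPruningCallback

    # Usage inside an objective, e.g.:
    #     trainer = pl.Trainer(callbacks=[PyTorchLightningPruningCallbackAdjusted(trial, monitor="val_loss")])

Either branch yields the single name PyTorchLightningPruningCallbackAdjusted, so the rest of tuning.py stays agnostic about which optuna version is installed.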