examples/stallion.py (2 changes: 1 addition & 1 deletion)
@@ -90,7 +90,7 @@


# save datasets
training.save("t raining.pkl")
training.save("training.pkl")
validation.save("validation.pkl")

early_stop_callback = EarlyStopping(monitor="val_loss", min_delta=1e-4, patience=10, verbose=False, mode="min")
@@ -26,9 +26,13 @@
optuna_logger = logging.getLogger("optuna")


-# need to inherit from callback for this to work
-class PyTorchLightningPruningCallbackAdjusted(PyTorchLightningPruningCallback, pl.Callback):
-    pass
+# In older versions of optuna, PyTorchLightningPruningCallback did not inherit from pl.Callback.
+# In the newest versions it does, so inheriting from it again raises a TypeError due to the
+# invalid use of multiple inheritance.
+if not issubclass(PyTorchLightningPruningCallback, pl.Callback):
+    class PyTorchLightningPruningCallbackAdjusted(PyTorchLightningPruningCallback, pl.Callback):
+        pass
+else:
+    PyTorchLightningPruningCallbackAdjusted = PyTorchLightningPruningCallback


def optimize_hyperparameters(
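A note on the branch above: the decision is made once at import time, either defining a shim subclass or aliasing the optuna class directly. Below is a minimal, self-contained sketch of the same pattern, with hypothetical stand-in classes (Callback, PruningCallback, AdjustedPruningCallback) in place of the real optuna and pytorch_lightning types:

# Hypothetical stand-ins for pl.Callback and PyTorchLightningPruningCallback.
class Callback:
    pass

class PruningCallback:
    pass

# Reuse the class directly if it already derives from Callback;
# otherwise define a shim that mixes Callback in as an extra base.
if not issubclass(PruningCallback, Callback):
    class AdjustedPruningCallback(PruningCallback, Callback):
        pass
else:
    AdjustedPruningCallback = PruningCallback

# Holds on either branch, so downstream code can treat the adjusted
# class as an ordinary callback.
assert issubclass(AdjustedPruningCallback, Callback)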
@@ -109,16 +113,12 @@ def optimize_hyperparameters(
    optuna_verbose = logging_level[verbose]
    optuna.logging.set_verbosity(optuna_verbose)

-    loss = kwargs.get(
-        "loss", QuantileLoss()
-    )  # need a deepcopy of loss as it will otherwise propagate from one trial to the next
+    loss = kwargs.get("loss", QuantileLoss())  # need a deepcopy of loss as it will otherwise propagate from one trial to the next

    # create objective function
    def objective(trial: optuna.Trial) -> float:
        # Filenames for each trial must be made unique in order to access each checkpoint.
-        checkpoint_callback = ModelCheckpoint(
-            dirpath=os.path.join(model_path, "trial_{}".format(trial.number)), filename="{epoch}", monitor="val_loss"
-        )
+        checkpoint_callback = ModelCheckpoint(dirpath=os.path.join(model_path, "trial_{}".format(trial.number)), filename="{epoch}", monitor="val_loss")

        learning_rate_callback = LearningRateMonitor()
        logger = TensorBoardLogger(log_dir, name="optuna", version=trial.number)
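The two comments in this hunk carry the key constraints of the objective: the default loss must be deep-copied per trial so state cannot propagate between trials, and each trial's ModelCheckpoint writes to its own trial_{n} directory so checkpoints are not overwritten. A rough sketch of both points, assuming a hypothetical stateful loss class (the deepcopy itself happens in code not shown in this diff):

import copy
import os

class StatefulLoss:
    # Hypothetical stand-in for a stateful loss such as QuantileLoss; sharing
    # one instance would let state mutated in one trial leak into the next.
    def __init__(self):
        self.state = []

default_loss = StatefulLoss()

def run_trial(trial_number, model_path="optuna_test"):
    loss = copy.deepcopy(default_loss)  # per-trial copy, as the comment requires
    loss.state.append(trial_number)  # mutates the copy, not the shared default

    # Unique per-trial checkpoint directory, mirroring the dirpath above.
    checkpoint_dir = os.path.join(model_path, "trial_{}".format(trial_number))
    return loss, checkpoint_dir

run_trial(0)
run_trial(1)
assert default_loss.state == []  # the shared default is untouched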