Skip to content

Commit 14c0235

Browse files
fix(logger): Bugs in AzureMLFlowLogger from #646 (#685)
Finally got around to trying our work in #646, and encountered these two issues - a stale reference to `AnemoiMLFlowLogger` - `getattr` -> `dict.get` I'll try to add a test or two that catches these. Co-authored-by: Ana Prieto Nemesio <[email protected]>
1 parent 24d162c commit 14c0235

File tree

2 files changed

+5
-6
lines changed

2 files changed

+5
-6
lines changed

training/src/anemoi/training/diagnostics/callbacks/checkpoint.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -122,9 +122,9 @@ def tracker_metadata(self, trainer: pl.Trainer) -> dict:
122122
if self.config.diagnostics.log.mlflow.enabled:
123123
self._tracker_name = "mlflow"
124124

125-
from anemoi.training.diagnostics.mlflow.logger import AnemoiMLflowLogger
125+
from anemoi.training.diagnostics.mlflow.logger import BaseAnemoiMLflowLogger
126126

127-
mlflow_logger = next(logger for logger in trainer.loggers if isinstance(logger, AnemoiMLflowLogger))
127+
mlflow_logger = next(logger for logger in trainer.loggers if isinstance(logger, BaseAnemoiMLflowLogger))
128128
run_id = mlflow_logger.run_id
129129
run = mlflow_logger._mlflow_client.get_run(run_id)
130130

training/src/anemoi/training/diagnostics/logger.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,12 +30,11 @@ def get_mlflow_logger(config: BaseSchema) -> None:
3030
del logger_config["enabled"]
3131

3232
# backward compatibility to not break configs
33-
logger_config["_target_"] = getattr(
34-
logger_config,
35-
"_target",
33+
logger_config["_target_"] = logger_config.get(
34+
"_target_",
3635
"anemoi.training.diagnostics.mlflow.logger.AnemoiMLflowLogger",
3736
)
38-
logger_config["save_dir"] = getattr(logger_config, "save_dir", str(config.hardware.paths.logs.mlflow))
37+
logger_config["save_dir"] = logger_config.get("save_dir", str(config.hardware.paths.logs.mlflow))
3938

4039
logger = instantiate(
4140
logger_config,

0 commit comments

Comments
 (0)