diff --git a/training/src/anemoi/training/diagnostics/logger.py b/training/src/anemoi/training/diagnostics/logger.py index c266bf45c..b0cd5a8a5 100644 --- a/training/src/anemoi/training/diagnostics/logger.py +++ b/training/src/anemoi/training/diagnostics/logger.py @@ -11,24 +11,35 @@ import logging import os from pathlib import Path +from typing import Any import pytorch_lightning as pl -from omegaconf import DictConfig from omegaconf import OmegaConf -from anemoi.training.schemas.base_schema import BaseSchema from anemoi.training.schemas.base_schema import convert_to_omegaconf LOGGER = logging.getLogger(__name__) -def get_mlflow_logger(config: BaseSchema) -> None: - if not config.diagnostics.log.mlflow.enabled: +def get_mlflow_logger( + diagnostics_config: Any, + run_id: Any, + fork_run_id: Any, + paths: Any, + config: Any, + **kwargs, +) -> None: + del kwargs + + if not diagnostics_config.log.mlflow.enabled: LOGGER.debug("MLFlow logging is disabled.") return None + save_dir = paths.logs.mlflow + artifact_save_dir = paths.get("plots") + # 35 retries allow for 1 hour of server downtime - http_max_retries = config.diagnostics.log.mlflow.http_max_retries + http_max_retries = diagnostics_config.log.mlflow.http_max_retries os.environ["MLFLOW_HTTP_REQUEST_MAX_RETRIES"] = str(http_max_retries) os.environ["_MLFLOW_HTTP_REQUEST_MAX_RETRIES_LIMIT"] = str(http_max_retries + 1) @@ -40,26 +51,23 @@ def get_mlflow_logger(config: BaseSchema) -> None: from anemoi.training.diagnostics.mlflow import MAX_PARAMS_LENGTH from anemoi.training.diagnostics.mlflow.logger import AnemoiMLflowLogger - resumed = config.training.run_id is not None - forked = config.training.fork_run_id is not None - - save_dir = config.hardware.paths.logs.mlflow + resumed = run_id is not None + forked = fork_run_id is not None - offline = config.diagnostics.log.mlflow.offline + offline = diagnostics_config.log.mlflow.offline if not offline: - tracking_uri = config.diagnostics.log.mlflow.tracking_uri + 
tracking_uri = diagnostics_config.log.mlflow.tracking_uri LOGGER.info("AnemoiMLFlow logging to %s", tracking_uri) else: tracking_uri = None if (resumed or forked) and (offline): # when resuming or forking offline - - # tracking_uri = ${hardware.paths.logs.mlflow} tracking_uri = str(save_dir) # create directory if it does not exist - Path(config.hardware.paths.logs.mlflow).mkdir(parents=True, exist_ok=True) + Path(save_dir).mkdir(parents=True, exist_ok=True) log_hyperparams = True - if resumed and not config.diagnostics.log.mlflow.on_resume_create_child: + if resumed and not diagnostics_config.log.mlflow.on_resume_create_child: LOGGER.info( ( "Resuming run without creating child run - MLFlow logs will not update the" @@ -70,42 +78,42 @@ def get_mlflow_logger(config: BaseSchema) -> None: ) log_hyperparams = False - max_params_length = getattr(config.diagnostics.log.mlflow, "max_params_length", MAX_PARAMS_LENGTH) + max_params_length = getattr(diagnostics_config.log.mlflow, "max_params_length", MAX_PARAMS_LENGTH) LOGGER.info("Maximum number of params allowed to be logged is: %s", max_params_length) - log_model = getattr(config.diagnostics.log.mlflow, "log_model", LOG_MODEL) + log_model = getattr(diagnostics_config.log.mlflow, "log_model", LOG_MODEL) logger = AnemoiMLflowLogger( - experiment_name=config.diagnostics.log.mlflow.experiment_name, - project_name=config.diagnostics.log.mlflow.project_name, + experiment_name=diagnostics_config.log.mlflow.experiment_name, + project_name=diagnostics_config.log.mlflow.project_name, tracking_uri=tracking_uri, save_dir=save_dir, - run_name=config.diagnostics.log.mlflow.run_name, - run_id=config.training.run_id, - fork_run_id=config.training.fork_run_id, + run_name=diagnostics_config.log.mlflow.run_name, + run_id=run_id, + fork_run_id=fork_run_id, log_model=log_model, offline=offline, resumed=resumed, forked=forked, log_hyperparams=log_hyperparams, - authentication=config.diagnostics.log.mlflow.authentication, - 
on_resume_create_child=config.diagnostics.log.mlflow.on_resume_create_child, + authentication=diagnostics_config.log.mlflow.authentication, + on_resume_create_child=diagnostics_config.log.mlflow.on_resume_create_child, max_params_length=max_params_length, ) config_params = OmegaConf.to_container(convert_to_omegaconf(config), resolve=True) logger.log_hyperparams( config_params, - expand_keys=config.diagnostics.log.mlflow.expand_hyperparams, + expand_keys=diagnostics_config.log.mlflow.expand_hyperparams, ) - if config.diagnostics.log.mlflow.terminal: - logger.log_terminal_output(artifact_save_dir=config.hardware.paths.plots) - if config.diagnostics.log.mlflow.system: + if diagnostics_config.log.mlflow.terminal: + logger.log_terminal_output(artifact_save_dir=artifact_save_dir) + if diagnostics_config.log.mlflow.system: logger.log_system_metrics() return logger -def get_tensorboard_logger(config: DictConfig) -> pl.loggers.TensorBoardLogger | None: +def get_tensorboard_logger(diagnostics_config: Any, paths: Any, **kwargs) -> pl.loggers.TensorBoardLogger | None: """Setup TensorBoard experiment logger. 
Parameters @@ -119,19 +127,27 @@ def get_tensorboard_logger(config: DictConfig) -> pl.loggers.TensorBoardLogger | Logger object, or None """ - if not config.diagnostics.log.tensorboard.enabled: + del kwargs + + if not diagnostics_config.log.tensorboard.enabled: LOGGER.debug("Tensorboard logging is disabled.") return None + save_dir = paths.logs.tensorboard + from pytorch_lightning.loggers import TensorBoardLogger - return TensorBoardLogger( - save_dir=config.hardware.paths.logs.tensorboard, - log_graph=False, - ) + return TensorBoardLogger(save_dir=save_dir, log_graph=False) -def get_wandb_logger(config: DictConfig, model: pl.LightningModule) -> pl.loggers.WandbLogger | None: +def get_wandb_logger( + diagnostics_config: Any, + run_id: Any, + paths: Any, + model: pl.LightningModule, + config: Any, + **kwargs, +) -> pl.loggers.WandbLogger | None: """Setup Weights & Biases experiment logger. Parameters @@ -152,10 +168,14 @@ def get_wandb_logger(config: DictConfig, model: pl.LightningModule) -> pl.logger If `wandb` is not installed """ - if not config.diagnostics.log.wandb.enabled: + del kwargs + + if not diagnostics_config.log.wandb.enabled: LOGGER.debug("Weights & Biases logging is disabled.") return None + save_dir = paths.logs.wandb + try: from pytorch_lightning.loggers.wandb import WandbLogger except ImportError as err: @@ -163,22 +183,22 @@ def get_wandb_logger(config: DictConfig, model: pl.LightningModule) -> pl.logger raise ImportError(msg) from err logger = WandbLogger( - project=config.diagnostics.log.wandb.project, - entity=config.diagnostics.log.wandb.entity, - id=config.training.run_id, - save_dir=config.hardware.paths.logs.wandb, - offline=config.diagnostics.log.wandb.offline, - log_model=config.diagnostics.log.wandb.log_model, - resume=config.training.run_id is not None, + project=diagnostics_config.log.wandb.project, + entity=diagnostics_config.log.wandb.entity, + id=run_id, + save_dir=save_dir, + offline=diagnostics_config.log.wandb.offline, + 
log_model=diagnostics_config.log.wandb.log_model, + resume=run_id is not None, ) logger.log_hyperparams(OmegaConf.to_container(config, resolve=True)) - if config.diagnostics.log.wandb.gradients or config.diagnostics.log.wandb.parameters: - if config.diagnostics.log.wandb.gradients and config.diagnostics.log.wandb.parameters: + if diagnostics_config.log.wandb.gradients or diagnostics_config.log.wandb.parameters: + if diagnostics_config.log.wandb.gradients and diagnostics_config.log.wandb.parameters: log_ = "all" - elif config.diagnostics.log.wandb.gradients: + elif diagnostics_config.log.wandb.gradients: log_ = "gradients" else: log_ = "parameters" - logger.watch(model, log=log_, log_freq=config.diagnostics.log.interval, log_graph=False) + logger.watch(model, log=log_, log_freq=diagnostics_config.log.interval, log_graph=False) return logger diff --git a/training/src/anemoi/training/train/train.py b/training/src/anemoi/training/train/train.py index 7ae1bf6e5..e1a40c98c 100644 --- a/training/src/anemoi/training/train/train.py +++ b/training/src/anemoi/training/train/train.py @@ -271,21 +271,6 @@ def run_id(self) -> str: return str(uuid.uuid4()) - @cached_property - def wandb_logger(self) -> pl.loggers.WandbLogger: - """WandB logger.""" - return get_wandb_logger(self.config, self.model) - - @cached_property - def mlflow_logger(self) -> pl.loggers.MLFlowLogger: - """Mlflow logger.""" - return get_mlflow_logger(self.config) - - @cached_property - def tensorboard_logger(self) -> pl.loggers.TensorBoardLogger: - """TensorBoard logger.""" - return get_tensorboard_logger(self.config) - def _get_warm_start_checkpoint(self) -> Path | None: """Returns the warm start checkpoint path if specified.""" warm_start_dir = getattr(self.config.hardware.paths, "warm_start", None) # avoid breaking change @@ -380,16 +365,26 @@ def profiler(self) -> PyTorchProfiler | None: @cached_property def loggers(self) -> list: + diagnostics_config = self.config.diagnostics + + kwargs = { + 
"diagnostics_config": diagnostics_config, + "run_id": self.config.training.run_id, + "fork_run_id": self.config.training.fork_run_id, + "paths": self.config.hardware.paths, + "model": self.model_task, + "config": self.config, + } loggers = [] - if self.config.diagnostics.log.wandb.enabled: + if diagnostics_config.log.wandb.enabled: LOGGER.info("W&B logger enabled") - loggers.append(self.wandb_logger) - if self.config.diagnostics.log.tensorboard.enabled: + loggers.append(get_wandb_logger(**kwargs)) + if diagnostics_config.log.tensorboard.enabled: LOGGER.info("TensorBoard logger enabled") - loggers.append(self.tensorboard_logger) - if self.config.diagnostics.log.mlflow.enabled: + loggers.append(get_tensorboard_logger(**kwargs)) + if diagnostics_config.log.mlflow.enabled: LOGGER.info("MLFlow logger enabled") - loggers.append(self.mlflow_logger) + loggers.append(get_mlflow_logger(**kwargs)) return loggers @cached_property