training/src/anemoi/training/diagnostics/logger.py (112 changes: 66 additions & 46 deletions)
@@ -11,24 +11,35 @@
import logging
import os
from pathlib import Path
from typing import Any

import pytorch_lightning as pl
from omegaconf import DictConfig
from omegaconf import OmegaConf

from anemoi.training.schemas.base_schema import BaseSchema
from anemoi.training.schemas.base_schema import convert_to_omegaconf

LOGGER = logging.getLogger(__name__)


def get_mlflow_logger(config: BaseSchema) -> None:
if not config.diagnostics.log.mlflow.enabled:
def get_mlflow_logger(
diagnostics_config: Any,
run_id: Any,
fork_run_id: Any,
paths: Any,
config: Any,
**kwargs,
) -> None:
del kwargs

if not diagnostics_config.log.mlflow.enabled:
LOGGER.debug("MLFlow logging is disabled.")
return None

save_dir = paths.logs.mlflow
artifact_save_dir = paths.get("plots")

# 35 retries allow for 1 hour of server downtime
http_max_retries = config.diagnostics.log.mlflow.http_max_retries
http_max_retries = diagnostics_config.log.mlflow.http_max_retries

os.environ["MLFLOW_HTTP_REQUEST_MAX_RETRIES"] = str(http_max_retries)
os.environ["_MLFLOW_HTTP_REQUEST_MAX_RETRIES_LIMIT"] = str(http_max_retries + 1)
@@ -40,26 +51,23 @@ def get_mlflow_logger(config: BaseSchema) -> None:
from anemoi.training.diagnostics.mlflow import MAX_PARAMS_LENGTH
from anemoi.training.diagnostics.mlflow.logger import AnemoiMLflowLogger

resumed = config.training.run_id is not None
forked = config.training.fork_run_id is not None

save_dir = config.hardware.paths.logs.mlflow
resumed = run_id is not None
forked = fork_run_id is not None

offline = config.diagnostics.log.mlflow.offline
offline = diagnostics_config.log.mlflow.offline
if not offline:
tracking_uri = config.diagnostics.log.mlflow.tracking_uri
tracking_uri = diagnostics_config.log.mlflow.tracking_uri
LOGGER.info("AnemoiMLFlow logging to %s", tracking_uri)
else:
tracking_uri = None

if (resumed or forked) and (offline): # when resuming or forking offline -
# tracking_uri = ${hardware.paths.logs.mlflow}
tracking_uri = str(save_dir)
# create directory if it does not exist
Path(config.hardware.paths.logs.mlflow).mkdir(parents=True, exist_ok=True)
Path(save_dir).mkdir(parents=True, exist_ok=True)

log_hyperparams = True
if resumed and not config.diagnostics.log.mlflow.on_resume_create_child:
if resumed and not diagnostics_config.log.mlflow.on_resume_create_child:
LOGGER.info(
(
"Resuming run without creating child run - MLFlow logs will not update the"
@@ -70,42 +78,42 @@ def get_mlflow_logger(config: BaseSchema) -> None:
)
log_hyperparams = False

max_params_length = getattr(config.diagnostics.log.mlflow, "max_params_length", MAX_PARAMS_LENGTH)
max_params_length = getattr(diagnostics_config.log.mlflow, "max_params_length", MAX_PARAMS_LENGTH)
LOGGER.info("Maximum number of params allowed to be logged is: %s", max_params_length)
log_model = getattr(config.diagnostics.log.mlflow, "log_model", LOG_MODEL)
log_model = getattr(diagnostics_config.log.mlflow, "log_model", LOG_MODEL)

logger = AnemoiMLflowLogger(
experiment_name=config.diagnostics.log.mlflow.experiment_name,
project_name=config.diagnostics.log.mlflow.project_name,
experiment_name=diagnostics_config.log.mlflow.experiment_name,
project_name=diagnostics_config.log.mlflow.project_name,
tracking_uri=tracking_uri,
save_dir=save_dir,
run_name=config.diagnostics.log.mlflow.run_name,
run_id=config.training.run_id,
fork_run_id=config.training.fork_run_id,
run_name=diagnostics_config.log.mlflow.run_name,
run_id=run_id,
fork_run_id=fork_run_id,
log_model=log_model,
offline=offline,
resumed=resumed,
forked=forked,
log_hyperparams=log_hyperparams,
authentication=config.diagnostics.log.mlflow.authentication,
on_resume_create_child=config.diagnostics.log.mlflow.on_resume_create_child,
authentication=diagnostics_config.log.mlflow.authentication,
on_resume_create_child=diagnostics_config.log.mlflow.on_resume_create_child,
max_params_length=max_params_length,
)
config_params = OmegaConf.to_container(convert_to_omegaconf(config), resolve=True)
logger.log_hyperparams(
config_params,
expand_keys=config.diagnostics.log.mlflow.expand_hyperparams,
expand_keys=diagnostics_config.log.mlflow.expand_hyperparams,
)

if config.diagnostics.log.mlflow.terminal:
logger.log_terminal_output(artifact_save_dir=config.hardware.paths.plots)
if config.diagnostics.log.mlflow.system:
if diagnostics_config.log.mlflow.terminal:
logger.log_terminal_output(artifact_save_dir=artifact_save_dir)
if diagnostics_config.log.mlflow.system:
logger.log_system_metrics()

return logger


def get_tensorboard_logger(config: DictConfig) -> pl.loggers.TensorBoardLogger | None:
def get_tensorboard_logger(diagnostics_config: Any, paths: Any, **kwargs) -> pl.loggers.TensorBoardLogger | None:
"""Setup TensorBoard experiment logger.

Parameters
@@ -119,19 +127,27 @@ def get_tensorboard_logger(config: DictConfig) -> pl.loggers.TensorBoardLogger | None:
Logger object, or None

"""
if not config.diagnostics.log.tensorboard.enabled:
del kwargs

if not diagnostics_config.log.tensorboard.enabled:
LOGGER.debug("Tensorboard logging is disabled.")
return None

save_dir = paths.logs.tensorboard

from pytorch_lightning.loggers import TensorBoardLogger

return TensorBoardLogger(
save_dir=config.hardware.paths.logs.tensorboard,
log_graph=False,
)
return TensorBoardLogger(save_dir=save_dir, log_graph=False)


def get_wandb_logger(config: DictConfig, model: pl.LightningModule) -> pl.loggers.WandbLogger | None:
def get_wandb_logger(
diagnostics_config: Any,
run_id: Any,
paths: Any,
model: pl.LightningModule,
config: Any,
**kwargs,
) -> pl.loggers.WandbLogger | None:
"""Setup Weights & Biases experiment logger.

Parameters
@@ -152,33 +168,37 @@ def get_wandb_logger(config: DictConfig, model: pl.LightningModule) -> pl.loggers.WandbLogger | None:
If `wandb` is not installed

"""
if not config.diagnostics.log.wandb.enabled:
del kwargs

if not diagnostics_config.log.wandb.enabled:
LOGGER.debug("Weights & Biases logging is disabled.")
return None

save_dir = paths.logs.wandb

try:
from pytorch_lightning.loggers.wandb import WandbLogger
except ImportError as err:
msg = "To activate W&B logging, please install `wandb` as an optional dependency."
raise ImportError(msg) from err

logger = WandbLogger(
project=config.diagnostics.log.wandb.project,
entity=config.diagnostics.log.wandb.entity,
id=config.training.run_id,
save_dir=config.hardware.paths.logs.wandb,
offline=config.diagnostics.log.wandb.offline,
log_model=config.diagnostics.log.wandb.log_model,
resume=config.training.run_id is not None,
project=diagnostics_config.log.wandb.project,
entity=diagnostics_config.log.wandb.entity,
id=run_id,
save_dir=save_dir,
offline=diagnostics_config.log.wandb.offline,
log_model=diagnostics_config.log.wandb.log_model,
resume=run_id is not None,
)
logger.log_hyperparams(OmegaConf.to_container(config, resolve=True))
if config.diagnostics.log.wandb.gradients or config.diagnostics.log.wandb.parameters:
if config.diagnostics.log.wandb.gradients and config.diagnostics.log.wandb.parameters:
if diagnostics_config.log.wandb.gradients or diagnostics_config.log.wandb.parameters:
if diagnostics_config.log.wandb.gradients and diagnostics_config.log.wandb.parameters:
log_ = "all"
elif config.diagnostics.log.wandb.gradients:
elif diagnostics_config.log.wandb.gradients:
log_ = "gradients"
else:
log_ = "parameters"
logger.watch(model, log=log_, log_freq=config.diagnostics.log.interval, log_graph=False)
logger.watch(model, log=log_, log_freq=diagnostics_config.log.interval, log_graph=False)

return logger
training/src/anemoi/training/train/train.py (37 changes: 16 additions & 21 deletions)
@@ -271,21 +271,6 @@ def run_id(self) -> str:

return str(uuid.uuid4())

@cached_property
def wandb_logger(self) -> pl.loggers.WandbLogger:
"""WandB logger."""
return get_wandb_logger(self.config, self.model)

@cached_property
def mlflow_logger(self) -> pl.loggers.MLFlowLogger:
"""Mlflow logger."""
return get_mlflow_logger(self.config)

@cached_property
def tensorboard_logger(self) -> pl.loggers.TensorBoardLogger:
"""TensorBoard logger."""
return get_tensorboard_logger(self.config)

def _get_warm_start_checkpoint(self) -> Path | None:
"""Returns the warm start checkpoint path if specified."""
warm_start_dir = getattr(self.config.hardware.paths, "warm_start", None) # avoid breaking change
@@ -380,16 +365,26 @@ def profiler(self) -> PyTorchProfiler | None:

@cached_property
def loggers(self) -> list:
diagnostics_config = self.config.diagnostics

kwargs = {
"diagnostics_config": self.diagnostics_config,
"run_id": self.my_config.run_id,
"fork_run_id": self.my_config.fork_run_id,
"paths": self.config.hardware.paths,
"model": self.model_task,
"config": self.config,
}
loggers = []
if self.config.diagnostics.log.wandb.enabled:
if diagnostics_config.log.wandb.enabled:
LOGGER.info("W&B logger enabled")
loggers.append(self.wandb_logger)
if self.config.diagnostics.log.tensorboard.enabled:
loggers.append(get_wandb_logger(**kwargs))
if diagnostics_config.log.tensorboard.enabled:
LOGGER.info("TensorBoard logger enabled")
loggers.append(self.tensorboard_logger)
if self.config.diagnostics.log.mlflow.enabled:
loggers.append(get_tensorboard_logger(**kwargs))
if diagnostics_config.log.mlflow.enabled:
LOGGER.info("MLFlow logger enabled")
loggers.append(self.mlflow_logger)
loggers.append(get_mlflow_logger(**kwargs))
return loggers

@cached_property
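The refactor replaces the single config-object parameter with explicit keyword arguments, so the trainer can build one kwargs dict and hand it to every factory; each factory keeps only the keys it needs and discards the rest through `**kwargs`. Below is a minimal sketch of that calling convention, assuming `SimpleNamespace` stand-ins for the real config schema and with all three backends disabled, so every factory returns `None` without side effects.

```python
# Sketch only: SimpleNamespace stands in for the real config schema objects, and all
# three backends are disabled here so every factory returns None without side effects.
from types import SimpleNamespace as NS

from anemoi.training.diagnostics.logger import (
    get_mlflow_logger,
    get_tensorboard_logger,
    get_wandb_logger,
)

diagnostics = NS(
    log=NS(
        mlflow=NS(enabled=False),
        tensorboard=NS(enabled=False),
        wandb=NS(enabled=False),
    ),
)
paths = NS(logs=NS(mlflow="logs/mlflow", tensorboard="logs/tensorboard", wandb="logs/wandb"))

kwargs = {
    "diagnostics_config": diagnostics,
    "run_id": None,        # not resuming
    "fork_run_id": None,   # not forking
    "paths": paths,
    "model": None,         # only the W&B factory uses this
    "config": None,        # full config, logged as hyperparams when a backend is enabled
}

# Each factory reads only the keys it needs and drops the rest via **kwargs.
loggers = [
    logger
    for factory in (get_wandb_logger, get_tensorboard_logger, get_mlflow_logger)
    if (logger := factory(**kwargs)) is not None
]
print(loggers)  # [] because every backend is disabled in this sketch
```

In the PR itself, the `loggers` property still checks each `enabled` flag before calling a factory; filtering the `None` returns, as above, is just a shorter way to exercise the same interface.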