diff --git a/training/src/anemoi/training/train/tasks/base.py b/training/src/anemoi/training/train/tasks/base.py
index 45883c88d..bf9b3d207 100644
--- a/training/src/anemoi/training/train/tasks/base.py
+++ b/training/src/anemoi/training/train/tasks/base.py
@@ -748,7 +748,7 @@ def _normalize_batch(self, batch: torch.Tensor) -> torch.Tensor:
         Normalized batch
         """
         for dataset_name in batch:
-            batch[dataset_name] = self.model.pre_processors[dataset_name](batch[dataset_name]) # normalized in-place
+            batch[dataset_name] = self.model.pre_processors[dataset_name](batch[dataset_name])  # normalized in-place
         return batch
 
     def _prepare_loss_scalers(self) -> None:
diff --git a/training/src/anemoi/training/train/train.py b/training/src/anemoi/training/train/train.py
index 169ca1a3a..023040b68 100644
--- a/training/src/anemoi/training/train/train.py
+++ b/training/src/anemoi/training/train/train.py
@@ -10,6 +10,7 @@
 
 import datetime
 import logging
+import os
 from functools import cached_property
 from pathlib import Path
 from typing import Any
@@ -42,6 +43,34 @@
 
 LOGGER = logging.getLogger(__name__)
 
+if "TMP_ANEMOI_DEV_DEBUG" in os.environ:
+    import hydra.utils
+    hydra_instantiate = hydra.utils.instantiate  # keep a handle on the original
+    def instantiate_wrapper(config: Any, *args, **kwargs) -> Any:
+        config_keys = "+".join(k for k in config if k != "_target_")
+        try:
+            target = config._target_
+        except Exception:  # noqa
+            target = ""
+        try:
+            return hydra_instantiate(config, *args, **kwargs)
+        except Exception as e:
+            config_keys = str(target) + "+" + config_keys
+
+            def show_if_dict(v):
+                if hasattr(v, "keys"):  # duck-typed check for dict-likes
+                    return v.__class__.__name__ + f"({'+'.join(v.keys())})"
+                return v.__class__.__name__
+
+            _args = ",".join(show_if_dict(a) for a in args)
+            _kwargs = ",".join(f"{k}={show_if_dict(v)}" for k, v in kwargs.items())
+            e.add_note(f"This exception happened when doing instantiate for config: {config}")
+            e.add_note(f"🆕 instantiate({config_keys},{_args},{_kwargs})")
+            raise e
+
+    hydra.utils.instantiate = instantiate_wrapper
+    instantiate = hydra.utils.instantiate
+
 
 class AnemoiTrainer:
     """Utility class for training the model."""
@@ -260,8 +289,9 @@ def model(self) -> pl.LightningModule:
         model.data_indices = self.data_indices
         # check data indices in original checkpoint and current data indices are the same
         self.data_indices.compare_variables(
-            model._ckpt_model_name_to_index, self.data_indices.name_to_index,
-        )  # TODO for multi dataset
+            model._ckpt_model_name_to_index,
+            self.data_indices.name_to_index,
+        )  # TODO: for multi dataset
 
         if hasattr(self.config.training, "submodules_to_freeze"):
             # Freeze the chosen model weights
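
Reviewer note, not part of the patch: the TMP_ANEMOI_DEV_DEBUG block relies on BaseException.add_note (PEP 678), so it only works on Python >= 3.11. Below is a minimal standalone sketch of the same wrapping pattern, with hypothetical names (_debug_instantiate, _original_instantiate), assuming only hydra-core is installed:

    import hydra.utils

    _original_instantiate = hydra.utils.instantiate


    def _debug_instantiate(config, *args, **kwargs):
        """Delegate to hydra.utils.instantiate, annotating any failure."""
        try:
            return _original_instantiate(config, *args, **kwargs)
        except Exception as e:
            # add_note (Python 3.11+) attaches context to the traceback
            # without changing the exception type seen by callers.
            e.add_note(f"instantiate failed for config: {config}")
            raise


    hydra.utils.instantiate = _debug_instantiate

    # datetime.date needs year/month/day, so instantiation fails; the
    # printed traceback now ends with the note naming the offending config.
    hydra.utils.instantiate({"_target_": "datetime.date"})

On Python < 3.11 the add_note calls themselves raise AttributeError, which is worth keeping in mind for a debug-only code path.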