Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion training/src/anemoi/training/train/tasks/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -748,7 +748,7 @@ def _normalize_batch(self, batch: torch.Tensor) -> torch.Tensor:
Normalized batch
"""
for dataset_name in batch:
batch[dataset_name] = self.model.pre_processors[dataset_name](batch[dataset_name]) # normalized in-place
batch[dataset_name] = self.model.pre_processors[dataset_name](batch[dataset_name]) # normalized in-place
return batch

def _prepare_loss_scalers(self) -> None:
Expand Down
34 changes: 32 additions & 2 deletions training/src/anemoi/training/train/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import datetime
import logging
import os
from functools import cached_property
from pathlib import Path
from typing import Any
Expand Down Expand Up @@ -42,6 +43,34 @@

LOGGER = logging.getLogger(__name__)

if "TMP_ANEMOI_DEV_DEBUG" in os.environ:
    # Dev-only debugging aid: monkey-patch Hydra's ``instantiate`` so that any
    # instantiation failure gets annotated with the offending config's target,
    # keys, and a summary of the extra args/kwargs.
    # NOTE(review): relies on ``BaseException.add_note`` — Python 3.11+ only.
    from hydra.utils import instantiate as hydra_instantiate

    def instanciate_wrapper(config: Any, *args, **kwargs) -> Any:
        """Call ``hydra.utils.instantiate`` and annotate any raised exception.

        Parameters
        ----------
        config : Any
            Hydra/OmegaConf config node to instantiate (expected to carry a
            ``_target_`` entry, but handled gracefully if it does not).
        *args, **kwargs
            Forwarded verbatim to ``hydra.utils.instantiate``.

        Returns
        -------
        Any
            Whatever ``hydra.utils.instantiate`` returns.

        Raises
        ------
        Exception
            Re-raises the original instantiation error, with debug notes added.
        """
        try:
            return hydra_instantiate(config, *args, **kwargs)
        except Exception as e:
            # Build the debug summary only on failure, so a misbehaving
            # (e.g. non-iterable) config can never break the success path.
            try:
                target = config._target_
            except Exception:  # noqa: BLE001 - best-effort metadata only
                target = "<no _target_>"
            try:
                config_keys = "+".join(k for k in config if k != "_target_")
            except Exception:  # noqa: BLE001 - best-effort metadata only
                config_keys = "<non-iterable config>"
            config_keys = f"{target}+{config_keys}"

            def show_if_dict(v: Any) -> str:
                # Summarize mappings by their keys, everything else by type name.
                if hasattr(v, "keys"):  # duck-typed mapping check
                    return v.__class__.__name__ + f"({'+'.join(v.keys())})"
                return v.__class__.__name__

            _args = ",".join(show_if_dict(a) for a in args)
            _kwargs = ",".join(f"{k}={show_if_dict(v)}" for k, v in kwargs.items())
            e.add_note(f"This exception happened when doing instantiate for config: {config}")
            e.add_note(f"🆕 instantiate({config_keys},{_args},{_kwargs})")
            raise

    hydra.utils.instantiate = instanciate_wrapper
    instantiate = hydra.utils.instantiate


class AnemoiTrainer:
"""Utility class for training the model."""
Expand Down Expand Up @@ -260,8 +289,9 @@ def model(self) -> pl.LightningModule:
model.data_indices = self.data_indices
# check data indices in original checkpoint and current data indices are the same
self.data_indices.compare_variables(
model._ckpt_model_name_to_index, self.data_indices.name_to_index,
) # TODO for multi dataset
model._ckpt_model_name_to_index,
self.data_indices.name_to_index,
) # T_O_D_O_: for multi dataset

if hasattr(self.config.training, "submodules_to_freeze"):
# Freeze the chosen model weights
Expand Down
Loading