diff --git a/.gitkeep b/.gitkeep
new file mode 100644
index 000000000..e69de29bb
diff --git a/training/docs/user-guide/configuring.rst b/training/docs/user-guide/configuring.rst
index c28432b52..b86e524e7 100644
--- a/training/docs/user-guide/configuring.rst
+++ b/training/docs/user-guide/configuring.rst
@@ -420,3 +420,67 @@ error reported is not very intuitive and indeed hides the real issue.
 We will work on improving this on future releases, but mean time we
 recommend to double check the schemas and the config files to make sure
 they are correctly defined.
+
+********************
+ Checkpoint Loading
+********************
+
+Anemoi Training supports flexible checkpoint loading to initialize model
+weights from saved checkpoints. This system supports multiple sources
+and loading strategies.
+
+Configuration
+=============
+
+To configure checkpoint loading, add a ``checkpoint_loading`` section to
+your training config:
+
+.. code:: yaml
+
+   training:
+      checkpoint_loading:
+         source: "/path/to/checkpoint.ckpt"
+         loader_type: "weights_only"
+         strict: true
+
+Available Options
+=================
+
+- **source**: Path or URL to checkpoint file (local, S3, HTTP, GCS,
+  Azure)
+- **loader_type**: Strategy for loading ("weights_only",
+  "transfer_learning", "standard")
+- **strict**: Whether to require exact parameter matching (optional)
+
+Loader Types
+============
+
+**weights_only**
+   Load only model weights, ignoring optimizer and scheduler states
+
+**transfer_learning**
+   Load weights with size mismatch handling for cross-model transfer
+
+**standard**
+   Full Lightning checkpoint loading (weights + optimizer + scheduler)
+
+Remote Sources
+==============
+
+The system supports loading from multiple remote sources (use one
+``source`` per configuration; the alternatives are shown commented out):
+
+.. code:: yaml
+
+   training:
+      checkpoint_loading:
+         # S3 checkpoint
+         source: "s3://my-bucket/models/checkpoint.ckpt"
+
+         # HTTP checkpoint
+         # source: "https://example.com/models/checkpoint.ckpt"
+
+         # Google Cloud Storage
+         # source: "gs://my-bucket/models/checkpoint.ckpt"
+
+         # Azure Blob Storage
+         # source: "azure://account.blob.core.windows.net/container/checkpoint.ckpt"
diff --git a/training/docs/user-guide/training.rst b/training/docs/user-guide/training.rst
index a365c8318..966ea6fba 100644
--- a/training/docs/user-guide/training.rst
+++ b/training/docs/user-guide/training.rst
@@ -281,64 +281,65 @@ specific point they can do this by setting
 ``config.hardware.files.warm_start`` to be the checkpoint they want to
 restart from..
 
-*******************
- Transfer Learning
-*******************
+********************
+ Checkpoint Loading
+********************
 
-Transfer learning allows the model to reuse knowledge from a previously
-trained checkpoint. This is particularly useful when the new task is
-related to the old one, enabling faster convergence and often improving
-model performance.
+Anemoi Training supports loading checkpoints from multiple sources with
+automatic source detection.
 
-To enable transfer learning, set the config.training.transfer_learning
-flag to True in the configuration file.
+Quick Start
+===========
 
 .. code:: yaml
 
    training:
-      # start the training from a checkpoint of a previous run
-      fork_run_id: ...
-      load_weights_only: True
-      transfer_learning: True
-
-When this flag is active and a checkpoint path is specified in
-config.hardware.files.warm_start or self.last_checkpoint, the system
-loads the pre-trained weights using the `transfer_learning_loading`
-function. This approach ensures only compatible weights are loaded and
-mismatched layers are handled appropriately.
-
-For example, transfer learning might be used to adapt a weather
-forecasting model trained on one geographic region to another region
-with similar characteristics.
-
-****************
- Model Freezing
-****************
-
-Model freezing is a technique where specific parts (submodules) of a
-model are excluded from training. This is useful when certain parts of
-the model have been sufficiently trained or should remain unchanged for
-the current task.
-
-To specify which submodules to freeze, use the
-config.training.submodules_to_freeze field in the configuration. List
-the names of submodules to be frozen. During model initialization, these
-submodules will have their parameters frozen, ensuring they are not
-updated during training.
-
-For example with the following configuration, the processor will be
-frozen and only the encoder and decoder will be trained:
+      checkpoint_loading:
+         source: "s3://bucket/model.ckpt"  # Any source
+         loader_type: "transfer_learning"  # Strategy
+         strict: false  # Allow mismatches
+
+Sources
+=======
+
++----------+--------------------------------------+---------------------------+
+| Source   | URL Format                           | Dependency                |
++==========+======================================+===========================+
+| Local    | ``/path/to/file.ckpt``               | None                      |
++----------+--------------------------------------+---------------------------+
+| S3       | ``s3://bucket/file.ckpt``            | ``boto3``                 |
++----------+--------------------------------------+---------------------------+
+| HTTP     | ``https://url/file.ckpt``            | None                      |
++----------+--------------------------------------+---------------------------+
+| GCS      | ``gs://bucket/file.ckpt``            | ``google-cloud-storage``  |
++----------+--------------------------------------+---------------------------+
+| Azure    | ``azure://account.../file.ckpt``     | ``azure-storage-blob``    |
++----------+--------------------------------------+---------------------------+
+
+Loader Types
+============
+
+- **weights_only**: Model weights only
+- **transfer_learning**: Handle size mismatches
+- **standard**: Full checkpoint (weights + optimizer)
 
-.. code:: yaml
+*******************
+ Transfer Learning
+*******************
 
-   training:
-      # start the training from a checkpoint of a previous run
-      fork_run_id: ...
-      load_weights_only: True
+Transfer learning reuses weights from a previously trained checkpoint and
+automatically handles size mismatches between models:
 
-      submodules_to_freeze:
-      - processor
+.. code:: yaml
 
-Freezing can be particularly beneficial in scenarios such as fine-tuning
-when only specific components (e.g., the encoder, the decoder) need to
-adapt to a new task while keeping others (e.g., the processor) fixed.
+   training:
+      checkpoint_loading:
+         source: "/path/to/pretrained.ckpt"
+         loader_type: "transfer_learning"
+         strict: false  # allow parameter mismatches
+         skip_mismatched: true  # skip layers with shape mismatches
+
+The transfer learning system automatically:
+
+- Identifies parameter shape mismatches
+- Logs which parameters are skipped due to mismatches
+- Loads only compatible weights
+- Preserves data indices for validation
diff --git a/training/src/anemoi/training/config/training/checkpoint_loading/standard.yml b/training/src/anemoi/training/config/training/checkpoint_loading/standard.yml
new file mode 100644
index 000000000..d9c66e8d4
--- /dev/null
+++ b/training/src/anemoi/training/config/training/checkpoint_loading/standard.yml
@@ -0,0 +1,3 @@
+source: null  # Path or URL to checkpoint file
+loader_type: "standard"  # Full Lightning checkpoint loading
+strict: true  # Require exact parameter matching
diff --git a/training/src/anemoi/training/config/training/checkpoint_loading/transfer_learning.yml b/training/src/anemoi/training/config/training/checkpoint_loading/transfer_learning.yml
new file mode 100644
index 000000000..b4f84235f
--- /dev/null
+++ b/training/src/anemoi/training/config/training/checkpoint_loading/transfer_learning.yml
@@ -0,0 +1,4 @@
+source: null  # Path or URL to checkpoint file
+loader_type: "transfer_learning"  # Handle size mismatches
+strict: false  # Allow parameter mismatches
+skip_mismatched: true  # Skip parameters with shape mismatches
diff --git a/training/src/anemoi/training/config/training/checkpoint_loading/weights_only.yml b/training/src/anemoi/training/config/training/checkpoint_loading/weights_only.yml
new file mode 100644
index 000000000..eaba51ea2
--- /dev/null
+++ b/training/src/anemoi/training/config/training/checkpoint_loading/weights_only.yml
@@ -0,0 +1,3 @@
+source: null  # Path or URL to checkpoint file
+loader_type: "weights_only"  # Load only model weights
+strict: true  # Require exact parameter matching
diff --git a/training/src/anemoi/training/schemas/training.py b/training/src/anemoi/training/schemas/training.py
index f56d5dd70..0a858b3db 100644
--- a/training/src/anemoi/training/schemas/training.py
+++ b/training/src/anemoi/training/schemas/training.py
@@ -263,5 +263,7 @@ class TrainingSchema(BaseModel):
     "Configuration of the pressure level scaler apllied in the loss computation."
     metrics: list[str]
     "List of metrics"
+    checkpoint_loading: Union[dict, None] = Field(default=None)
+    "Checkpoint loading configuration for initializing model weights from saved checkpoints."
     node_loss_weights: NodeLossWeightsSchema
     "Node loss weights configuration."
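The schema above deliberately types ``checkpoint_loading`` as a plain ``dict``. As an illustrative sketch only (not part of this change; ``CheckpointLoadingSchema`` is a hypothetical name), the documented options could be expressed as a typed model whose defaults mirror the ``standard`` preset:

.. code-block:: python

   # Hypothetical typed equivalent of the checkpoint_loading dict accepted by
   # TrainingSchema; field names follow the documented options above.
   from typing import Literal, Optional

   from pydantic import BaseModel


   class CheckpointLoadingSchema(BaseModel):
       source: Optional[str] = None
       "Path or URL to the checkpoint file (local, S3, HTTP, GCS, Azure)."

       loader_type: Literal["standard", "transfer_learning", "weights_only"] = "standard"
       "Strategy used to load the checkpoint."

       strict: bool = True
       "Require exact parameter-name matching."

       skip_mismatched: bool = False
       "Skip parameters whose shapes do not match (transfer learning only)."

Validating a configured block would then reduce to ``CheckpointLoadingSchema(**config.training.checkpoint_loading)``.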
diff --git a/training/src/anemoi/training/train/train.py b/training/src/anemoi/training/train/train.py
index 2f29a181b..8873e1bfd 100644
--- a/training/src/anemoi/training/train/train.py
+++ b/training/src/anemoi/training/train/train.py
@@ -167,8 +167,13 @@ def model(self) -> GraphForecaster:
 
         model = GraphForecaster(**kwargs)
 
-        # Load the model weights
-        if self.load_weights_only:
+        # Load checkpoint weights if configured
+        model = self._load_checkpoint_if_configured(model)
+
+        # Legacy checkpoint loading (for compatibility)
+        if self.load_weights_only and not (
+            hasattr(self.config.training, "checkpoint_loading") and self.config.training.checkpoint_loading
+        ):
             if hasattr(self.config.training, "transfer_learning"):
                 # Sanify the checkpoint for transfer learning
                 if self.config.training.transfer_learning:
@@ -191,6 +196,39 @@ def model(self) -> GraphForecaster:
 
         return model
 
+    def _load_checkpoint_if_configured(self, model: torch.nn.Module) -> torch.nn.Module:
+        """Load checkpoint weights if checkpoint_loading is configured."""
+        if not hasattr(self.config.training, "checkpoint_loading") or not self.config.training.checkpoint_loading:
+            return model
+
+        checkpoint_config = self.config.training.checkpoint_loading
+
+        if not checkpoint_config.source:
+            LOGGER.warning("checkpoint_loading configured but no source specified")
+            return model
+
+        from anemoi.training.utils.model_loading import load_model_from_checkpoint
+
+        LOGGER.info(
+            "Loading checkpoint from %s using %s loader",
+            checkpoint_config.source,
+            checkpoint_config.loader_type,
+        )
+
+        # Extract parameters from checkpoint config
+        loader_kwargs = {}
+        if hasattr(checkpoint_config, "strict"):
+            loader_kwargs["strict"] = checkpoint_config.strict
+        if hasattr(checkpoint_config, "skip_mismatched"):
+            loader_kwargs["skip_mismatched"] = checkpoint_config.skip_mismatched
+
+        return load_model_from_checkpoint(
+            model=model,
+            checkpoint_source=checkpoint_config.source,
+            loader_type=checkpoint_config.loader_type,
+            **loader_kwargs,
+        )
+
     @rank_zero_only
     def _get_mlflow_run_id(self) -> str:
         run_id = self.mlflow_logger.run_id
diff --git a/training/src/anemoi/training/utils/checkpoint_loaders.py b/training/src/anemoi/training/utils/checkpoint_loaders.py
new file mode 100644
index 000000000..6badf58a1
--- /dev/null
+++ b/training/src/anemoi/training/utils/checkpoint_loaders.py
@@ -0,0 +1,1271 @@
+# (C) Copyright 2025 Anemoi contributors.
+#
+# This software is licensed under the terms of the Apache Licence Version 2.0
+# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
+#
+# In applying this licence, ECMWF does not waive the privileges and immunities
+# granted to it by virtue of its status as an intergovernmental organisation
+# nor does it submit to any jurisdiction.
+
+"""Checkpoint Loading System for Anemoi Training.
+
+This module provides a flexible and extensible checkpoint loading system that supports
+loading PyTorch model checkpoints from various sources including local filesystem,
+S3, HTTP/HTTPS, Google Cloud Storage, and Azure Blob Storage.
+ +Key Features +============ + +* **Multi-source Support**: Load checkpoints from local files, cloud storage, or web URLs +* **Automatic Source Detection**: Loaders automatically detect compatible sources based on URL schemes +* **Registry Pattern**: Extensible loader registry for custom checkpoint sources +* **Error Handling**: Robust error handling with informative messages for debugging +* **Cloud Integration**: Built-in support for major cloud storage providers + +Architecture +============ + +The system uses a registry pattern with abstract base classes: + +1. **CheckpointLoader**: Abstract base class defining the loader interface +2. **Concrete Loaders**: Specific implementations for different sources (local, remote) +3. **CheckpointLoaderRegistry**: Central registry that manages loaders and routes requests +4. **Global Registry**: Pre-configured registry instance ready for immediate use + +Supported Sources +================= + +* **Local Files**: Any file system path (``/path/to/checkpoint.ckpt``) +* **HTTP/HTTPS**: Web URLs (``https://example.com/model.ckpt``) +* **Amazon S3**: S3 URLs (``s3://bucket/path/model.ckpt``) +* **Google Cloud Storage**: GCS URLs (``gs://bucket/path/model.ckpt``) +* **Azure Blob Storage**: Azure URLs (``azure://account.blob.core.windows.net/container/model.ckpt``) + +Basic Usage +=========== + +.. code-block:: python + + from anemoi.training.utils.checkpoint_loaders import load_checkpoint_from_source + + # Load from local file + checkpoint = load_checkpoint_from_source("/path/to/model.ckpt") + + # Load from S3 + checkpoint = load_checkpoint_from_source("s3://my-bucket/models/checkpoint.ckpt") + + # Load from HTTP + checkpoint = load_checkpoint_from_source("https://example.com/model.ckpt") + + # The checkpoint is a standard PyTorch dictionary + model.load_state_dict(checkpoint["state_dict"]) + +Advanced Usage +============== + +.. code-block:: python + + from anemoi.training.utils.checkpoint_loaders import ( + CheckpointLoaderRegistry, + RemoteCheckpointLoader, + checkpoint_registry + ) + + # Use the global registry directly + loader = checkpoint_registry.get_loader("s3://my-bucket/model.ckpt") + checkpoint = loader.load_checkpoint("s3://my-bucket/model.ckpt") + + # Create a custom registry + custom_registry = CheckpointLoaderRegistry() + custom_registry.register(RemoteCheckpointLoader()) + checkpoint = custom_registry.load_checkpoint("https://example.com/model.ckpt") + +Extending the System +==================== + +To add support for custom checkpoint sources: + +.. code-block:: python + + class CustomCheckpointLoader(CheckpointLoader): + def supports_source(self, source: str | Path) -> bool: + return str(source).startswith("custom://") + + def load_checkpoint(self, source: str | Path) -> dict: + # Custom loading logic + return custom_load_function(source) + + # Register with the global registry + checkpoint_registry.register(CustomCheckpointLoader()) + +Cloud Provider Setup +==================== + +For cloud storage access, ensure proper authentication: + +**AWS S3**: + Configure AWS credentials via AWS CLI, environment variables, or IAM roles. + Requires ``boto3`` package. + +**Google Cloud Storage**: + Set up Google Cloud authentication via service account or gcloud CLI. + Requires ``google-cloud-storage`` package. + +**Azure Blob Storage**: + Configure Azure credentials via Azure CLI or environment variables. + Requires ``azure-storage-blob`` package. 
+ +Error Handling +============== + +The system provides detailed error messages for common issues: + +* ``FileNotFoundError``: Local checkpoint file not found +* ``ValueError``: Unsupported URL scheme or no compatible loader found +* ``ImportError``: Required cloud storage library not installed +* Network errors: Connection timeouts, authentication failures + +Integration with Training +========================= + +This module integrates with the model loading system: + +.. code-block:: python + + # In training configuration + training: + checkpoint_loading: + source: "s3://my-bucket/pretrained.ckpt" + loader_type: "transfer_learning" + + # The training system automatically uses this module to fetch the checkpoint + +See Also +-------- +* :mod:`anemoi.training.utils.model_loading`: Model weight loading strategies +* :mod:`anemoi.training.train.modify`: Model modification system +* :mod:`anemoi.training.train.train`: Main training pipeline integration + +Notes +----- +* Remote checkpoints are downloaded to temporary files and cleaned up automatically +* Large checkpoints may take time to download; consider network bandwidth +* Cloud storage credentials must be properly configured for remote access +* The system respects PyTorch's ``weights_only=False`` for full checkpoint loading +""" + +from __future__ import annotations + +import logging +from abc import ABC +from abc import abstractmethod +from pathlib import Path +from urllib.parse import urlparse + +LOGGER = logging.getLogger(__name__) + + +class CheckpointLoader(ABC): + """Abstract base class for loading model checkpoints from various sources. + + This class defines the interface that all checkpoint loaders must implement. + Concrete implementations handle specific source types (local files, remote URLs, etc.) + and provide the actual loading logic. + + The loader system uses a capability-based approach where each loader declares + which sources it can handle via the ``supports_source`` method, and the registry + automatically routes requests to appropriate loaders. + + Design Principles + ================= + + * **Single Responsibility**: Each loader handles one type of source + * **Capability Declaration**: Loaders explicitly declare supported sources + * **Consistent Interface**: All loaders return standard PyTorch checkpoint dictionaries + * **Error Transparency**: Loaders provide informative error messages for debugging + + Implementing Custom Loaders + ============================ + + To create a custom checkpoint loader: + + .. 
code-block:: python + + class MyCustomLoader(CheckpointLoader): + def supports_source(self, source: str | Path) -> bool: + # Return True if this loader can handle the source + return str(source).startswith("myprotocol://") + + def load_checkpoint(self, source: str | Path) -> dict: + # Implement your loading logic + # Must return a PyTorch checkpoint dictionary + return torch.load(processed_source, weights_only=False) + + Error Handling Guidelines + ========================= + + Implementations should raise appropriate exceptions: + + * ``FileNotFoundError``: When the source doesn't exist + * ``ValueError``: For malformed sources or unsupported formats + * ``ImportError``: When required dependencies are missing + * ``ConnectionError``: For network-related issues + + See Also + -------- + * :class:`LocalCheckpointLoader`: Loader for local filesystem checkpoints + * :class:`RemoteCheckpointLoader`: Loader for remote cloud/web checkpoints + * :class:`CheckpointLoaderRegistry`: Registry for managing multiple loaders + """ + + @abstractmethod + def load_checkpoint(self, source: str | Path) -> dict: + """Load a PyTorch checkpoint from the specified source. + + This method must be implemented by all concrete checkpoint loaders. + It should handle the specific loading logic for the loader's supported + source types and return a standard PyTorch checkpoint dictionary. + + The returned dictionary should contain at minimum a ``state_dict`` key + with the model parameters, and may include additional keys like + ``hyper_parameters``, ``optimizer_state_dict``, ``lr_scheduler_state_dict``, etc. + + Parameters + ---------- + source : str | Path + The checkpoint source to load from. This can be a local file path, + a remote URL, or any other identifier that the loader understands. + The format depends on the specific loader implementation. + + Returns + ------- + dict + A PyTorch checkpoint dictionary containing the loaded model state + and any additional metadata. The dictionary structure follows + PyTorch Lightning conventions: + + - ``state_dict``: Model parameters and buffers + - ``hyper_parameters``: Training hyperparameters (optional) + - ``optimizer_state_dict``: Optimizer state (optional) + - ``lr_scheduler_state_dict``: Learning rate scheduler state (optional) + - ``epoch``: Training epoch (optional) + - ``global_step``: Training step (optional) + + Raises + ------ + FileNotFoundError + If the checkpoint source doesn't exist or cannot be accessed. + ValueError + If the source format is invalid or the checkpoint data is corrupted. + ImportError + If required dependencies for loading from this source are not installed. + ConnectionError + If there are network issues accessing remote sources. + RuntimeError + If the checkpoint loading fails for any other reason. + + Examples + -------- + + .. code-block:: python + + loader = SomeCheckpointLoader() + checkpoint = loader.load_checkpoint("/path/to/model.ckpt") + + # Access the model state dict + model.load_state_dict(checkpoint["state_dict"]) + + # Access hyperparameters if available + if "hyper_parameters" in checkpoint: + config = checkpoint["hyper_parameters"] + """ + ... + + @abstractmethod + def supports_source(self, source: str | Path) -> bool: + """Check if this loader can handle the specified checkpoint source. + + This method determines whether the loader is capable of loading from + the given source. It's used by the registry system to automatically + select the appropriate loader for each source. 
+ + Implementations should be fast and lightweight since this method + may be called multiple times during loader selection. + + Parameters + ---------- + source : str | Path + The checkpoint source to evaluate. This could be a file path, + URL, or any other source identifier. + + Returns + ------- + bool + True if this loader can handle the source, False otherwise. + + Notes + ----- + The method should be conservative in its checks - only return True + if the loader can definitely handle the source. It's better to + return False and let another loader handle it than to return True + and fail during loading. + + Examples + -------- + + .. code-block:: python + + # Local file loader + def supports_source(self, source): + return Path(source).exists() + + # S3 loader + def supports_source(self, source): + return str(source).startswith("s3://") + + # HTTP loader + def supports_source(self, source): + parsed = urlparse(str(source)) + return parsed.scheme in {"http", "https"} + """ + ... + + +class LocalCheckpointLoader(CheckpointLoader): + """Checkpoint loader for local filesystem access. + + This loader handles loading checkpoints from the local filesystem. It supports + any valid file system path and automatically detects when a source is a local + path versus a remote URL. + + Features + ======== + + * **Path Resolution**: Handles both string paths and pathlib.Path objects + * **Existence Checking**: Validates file existence before attempting to load + * **Cross-platform Support**: Works with Windows, Linux, and macOS file paths + * **Memory Efficiency**: Loads checkpoints directly to CPU to avoid GPU memory issues + + Usage Examples + ============== + + .. code-block:: python + + loader = LocalCheckpointLoader() + + # Check if loader supports a path + if loader.supports_source("/path/to/model.ckpt"): + checkpoint = loader.load_checkpoint("/path/to/model.ckpt") + + # Works with pathlib.Path objects too + from pathlib import Path + checkpoint_path = Path("/models/checkpoint.ckpt") + if loader.supports_source(checkpoint_path): + checkpoint = loader.load_checkpoint(checkpoint_path) + + Performance Considerations + ========================== + + * Large checkpoints load directly into CPU memory first + * File I/O performance depends on storage type (SSD vs HDD) + * Network-attached storage may be slower than local drives + + Error Handling + ============== + + Common errors and their meanings: + + * ``FileNotFoundError``: The checkpoint file doesn't exist at the specified path + * ``PermissionError``: Insufficient permissions to read the file + * ``OSError``: File system errors (corrupted file, I/O errors) + * ``RuntimeError``: PyTorch loading errors (corrupted checkpoint format) + """ + + def supports_source(self, source: str | Path) -> bool: + """Check if the source is a local file system path. + + This method determines if the given source represents a local file system + path rather than a remote URL. It uses heuristics to distinguish between + local paths and remote URLs. + + The method considers a source to be local if: + 1. It's already a pathlib.Path object + 2. It's a string that can be converted to a valid path + 3. The path exists on the local filesystem OR has no URL scheme + + Parameters + ---------- + source : str | Path + The source to check. Can be a file path string or Path object. + + Returns + ------- + bool + True if this loader can handle the source (i.e., it's a local path), + False if it appears to be a remote URL or invalid path. + + Examples + -------- + + .. 
code-block:: python + + loader = LocalCheckpointLoader() + + # These return True + loader.supports_source("/path/to/file.ckpt") # Unix path + loader.supports_source("C:\\models\\file.ckpt") # Windows path + loader.supports_source(Path("/models/file.ckpt")) # pathlib.Path + + # These return False + loader.supports_source("https://example.com/file.ckpt") # HTTP URL + loader.supports_source("s3://bucket/file.ckpt") # S3 URL + """ + if isinstance(source, Path): + return True + try: + path = Path(source) + return path.exists() or not urlparse(str(source)).scheme + except (ValueError, OSError): + return False + + def load_checkpoint(self, source: str | Path) -> dict: + """Load a PyTorch checkpoint from the local filesystem. + + This method loads a checkpoint file from the local filesystem using + PyTorch's standard loading mechanism. The checkpoint is loaded to CPU + memory first to avoid GPU memory issues. + + Parameters + ---------- + source : str | Path + The local file path to the checkpoint. Can be a string path or + pathlib.Path object. + + Returns + ------- + dict + The loaded PyTorch checkpoint dictionary containing model weights + and any additional metadata saved during training. + + Raises + ------ + FileNotFoundError + If the checkpoint file doesn't exist at the specified path. + PermissionError + If there are insufficient permissions to read the file. + RuntimeError + If PyTorch fails to load the checkpoint (e.g., corrupted file). + OSError + If there are file system I/O errors. + + Examples + -------- + + .. code-block:: python + + loader = LocalCheckpointLoader() + + # Load checkpoint from string path + checkpoint = loader.load_checkpoint("/path/to/model.ckpt") + model.load_state_dict(checkpoint["state_dict"]) + + # Load checkpoint from pathlib.Path + from pathlib import Path + checkpoint_path = Path("/models/latest.ckpt") + checkpoint = loader.load_checkpoint(checkpoint_path) + + Notes + ----- + * The checkpoint is loaded with ``weights_only=False`` to support full + PyTorch Lightning checkpoints with metadata + * Loading is done with ``map_location="cpu"`` to ensure compatibility + across different hardware configurations + * Large checkpoints may take time to load depending on storage speed + """ + import torch + + path = Path(source) + if not path.exists(): + msg = f"Checkpoint not found: {path}" + raise FileNotFoundError(msg) + + LOGGER.info("Loading checkpoint from local path: %s", path) + return torch.load(path, weights_only=False, map_location="cpu") + + +class RemoteCheckpointLoader(CheckpointLoader): + """Checkpoint loader for remote sources including cloud storage and web URLs. + + This loader handles downloading and loading checkpoints from various remote + sources including cloud storage providers (AWS S3, Google Cloud Storage, + Azure Blob Storage) and web servers (HTTP/HTTPS). 
+ + Supported Protocols + =================== + + * **HTTP/HTTPS**: Standard web servers (``https://example.com/model.ckpt``) + * **Amazon S3**: S3 buckets (``s3://bucket-name/path/to/checkpoint.ckpt``) + * **Google Cloud Storage**: GCS buckets (``gs://bucket-name/path/to/checkpoint.ckpt``) + * **Azure Blob Storage**: Azure containers (``azure://account.blob.core.windows.net/container/file.ckpt``) + + Authentication Requirements + =========================== + + **AWS S3**: Requires proper AWS credentials configured via: + - AWS CLI (``aws configure``) + - Environment variables (``AWS_ACCESS_KEY_ID``, ``AWS_SECRET_ACCESS_KEY``) + - IAM roles (for EC2 instances) + - Requires ``boto3`` package + + **Google Cloud Storage**: Requires Google Cloud authentication via: + - Service account key file + - gcloud CLI authentication + - Application default credentials + - Requires ``google-cloud-storage`` package + + **Azure Blob Storage**: Requires Azure credentials via: + - Azure CLI authentication + - Environment variables (``AZURE_STORAGE_CONNECTION_STRING``) + - Managed identity (for Azure resources) + - Requires ``azure-storage-blob`` package + + Features + ======== + + * **Automatic Cleanup**: Downloads to temporary files that are automatically cleaned up + * **Streaming Support**: Efficient handling of large checkpoint files + * **Error Recovery**: Detailed error messages for troubleshooting authentication and connectivity issues + * **Cloud Provider Integration**: Native integration with major cloud storage APIs + + Usage Examples + ============== + + .. code-block:: python + + loader = RemoteCheckpointLoader() + + # Load from S3 + if loader.supports_source("s3://my-bucket/models/checkpoint.ckpt"): + checkpoint = loader.load_checkpoint("s3://my-bucket/models/checkpoint.ckpt") + + # Load from HTTP + if loader.supports_source("https://example.com/public/model.ckpt"): + checkpoint = loader.load_checkpoint("https://example.com/public/model.ckpt") + + # Load from Google Cloud Storage + if loader.supports_source("gs://my-gcs-bucket/checkpoints/model.ckpt"): + checkpoint = loader.load_checkpoint("gs://my-gcs-bucket/checkpoints/model.ckpt") + + Performance Considerations + ========================== + + * **Network Bandwidth**: Download speed depends on internet connection and provider + * **File Size**: Large checkpoints (>1GB) may take significant time to download + * **Temporary Storage**: Requires sufficient local disk space for temporary files + * **Cloud Costs**: Downloads from cloud storage may incur egress charges + + Error Handling + ============== + + Common errors and their meanings: + + * ``ImportError``: Required cloud storage library not installed + * ``ConnectionError``: Network connectivity issues + * ``AuthenticationError``: Invalid or missing cloud credentials + * ``FileNotFoundError``: Checkpoint doesn't exist at the specified URL + * ``PermissionError``: Insufficient permissions to access the resource + + Troubleshooting + =============== + + **S3 Access Issues**: + Verify AWS credentials with: ``aws sts get-caller-identity`` + + **GCS Access Issues**: + Verify credentials with: ``gcloud auth list`` + + **Azure Access Issues**: + Verify credentials with: ``az account show`` + + **Network Issues**: + Test connectivity with curl or wget to the URL + """ + + def supports_source(self, source: str | Path) -> bool: + """Check if the source is a supported remote URL. + + This method determines if the given source is a remote URL that this + loader can handle. 
It checks the URL scheme against the list of supported + protocols. + + Supported URL schemes: + - ``http://`` and ``https://`` for web servers + - ``s3://`` for Amazon S3 + - ``gs://`` and ``gcs://`` for Google Cloud Storage + - ``azure://`` and ``az://`` for Azure Blob Storage + + Parameters + ---------- + source : str | Path + The source to check. Should be a URL string for remote sources. + Path objects are not supported by this loader. + + Returns + ------- + bool + True if the source is a supported remote URL, False otherwise. + + Examples + -------- + + .. code-block:: python + + loader = RemoteCheckpointLoader() + + # These return True + loader.supports_source("https://example.com/model.ckpt") + loader.supports_source("s3://bucket/path/checkpoint.ckpt") + loader.supports_source("gs://bucket/models/checkpoint.ckpt") + loader.supports_source("azure://account.blob.core.windows.net/container/model.ckpt") + + # These return False + loader.supports_source("/local/path/model.ckpt") + loader.supports_source("ftp://server/model.ckpt") # Unsupported protocol + loader.supports_source(Path("/local/model.ckpt")) # Local Path object + """ + if isinstance(source, Path): + return False + try: + parsed = urlparse(str(source)) + return bool(parsed.scheme and parsed.scheme in {"http", "https", "s3", "gs", "azure"}) + except (ValueError, OSError): + return False + + def load_checkpoint(self, source: str | Path) -> dict: + """Load a PyTorch checkpoint from a remote source. + + This method downloads the checkpoint from the remote source to a temporary + file and then loads it using PyTorch's standard loading mechanism. The + temporary file is automatically cleaned up after loading. + + The loading process involves: + 1. Creating a secure temporary file + 2. Downloading the checkpoint to the temporary file + 3. Loading the checkpoint with PyTorch + 4. Cleaning up the temporary file + + Parameters + ---------- + source : str | Path + The remote URL to download the checkpoint from. Must be a supported + URL scheme (http, https, s3, gs, azure). + + Returns + ------- + dict + The loaded PyTorch checkpoint dictionary containing model weights + and any additional metadata. + + Raises + ------ + ImportError + If the required cloud storage library is not installed (e.g., boto3 + for S3, google-cloud-storage for GCS). + ConnectionError + If there are network connectivity issues during download. + FileNotFoundError + If the checkpoint doesn't exist at the specified remote location. + PermissionError + If there are insufficient permissions to access the remote resource. + ValueError + If the URL scheme is unsupported or malformed. + RuntimeError + If PyTorch fails to load the downloaded checkpoint. + + Examples + -------- + + .. 
code-block:: python + + loader = RemoteCheckpointLoader() + + # Load from S3 (requires boto3 and AWS credentials) + checkpoint = loader.load_checkpoint("s3://my-bucket/models/checkpoint.ckpt") + model.load_state_dict(checkpoint["state_dict"]) + + # Load from HTTPS + checkpoint = loader.load_checkpoint("https://example.com/public/model.ckpt") + + # Load from Google Cloud Storage (requires google-cloud-storage) + checkpoint = loader.load_checkpoint("gs://my-gcs-bucket/models/checkpoint.ckpt") + + Notes + ----- + * Large checkpoints may take significant time to download + * Temporary files are stored in the system's temporary directory + * The method requires sufficient local disk space for the checkpoint + * Cloud storage access may incur data transfer costs + * All downloads are performed to CPU memory first for compatibility + """ + import tempfile + from pathlib import Path + + import torch + + LOGGER.info("Loading checkpoint from remote source: %s", source) + + # Create temporary file to download checkpoint + with tempfile.NamedTemporaryFile(suffix=".ckpt", delete=False) as tmp_file: + tmp_path = Path(tmp_file.name) + + try: + self._download_checkpoint(str(source), tmp_path) + return torch.load(tmp_path, weights_only=False, map_location="cpu") + finally: + # Clean up temporary file + if tmp_path.exists(): + tmp_path.unlink() + + def _download_checkpoint(self, source: str, dest_path: Path) -> None: + """Download checkpoint from remote source to local temporary file. + + This internal method routes the download request to the appropriate + protocol-specific download method based on the URL scheme. + + Parameters + ---------- + source : str + The remote URL to download from. + dest_path : Path + The local file path to download to. + + Raises + ------ + ValueError + If the URL scheme is not supported. + """ + parsed = urlparse(source) + + if parsed.scheme in {"http", "https"}: + self._download_http(source, dest_path) + elif parsed.scheme == "s3": + self._download_s3(source, dest_path) + elif parsed.scheme in {"gs", "gcs"}: + self._download_gcs(source, dest_path) + elif parsed.scheme in {"azure", "az"}: + self._download_azure(source, dest_path) + else: + msg = f"Unsupported remote scheme: {parsed.scheme}" + raise ValueError(msg) + + def _download_http(self, url: str, dest_path: Path) -> None: + """Download checkpoint from HTTP/HTTPS server. + + Uses urllib.request to download files from web servers. This method + supports both HTTP and HTTPS protocols and handles redirects automatically. + + Parameters + ---------- + url : str + The HTTP/HTTPS URL to download from. + dest_path : Path + The local file path to save the downloaded file. + + Raises + ------ + ConnectionError + If there are network connectivity issues. + FileNotFoundError + If the URL returns a 404 or similar error. + PermissionError + If the server returns a 403 or similar authorization error. + + Notes + ----- + This method uses urllib.request.urlretrieve which may not be suitable + for very large files due to memory usage. For production use with large + checkpoints, consider implementing streaming download. + """ + import urllib.request + + urllib.request.urlretrieve(url, dest_path) # noqa: S310 + + def _download_s3(self, s3_url: str, dest_path: Path) -> None: + """Download checkpoint from Amazon S3. + + Uses boto3 to download files from S3 buckets. Requires proper AWS + credentials to be configured. 
+ + Parameters + ---------- + s3_url : str + The S3 URL in format: s3://bucket-name/path/to/file + dest_path : Path + The local file path to save the downloaded file. + + Raises + ------ + ImportError + If boto3 is not installed. + FileNotFoundError + If the S3 object doesn't exist. + PermissionError + If there are insufficient S3 permissions. + ConnectionError + If there are network or AWS service issues. + + Notes + ----- + * Requires AWS credentials configured via AWS CLI, environment variables, or IAM roles + * Large files are downloaded efficiently using boto3's streaming capabilities + * Data transfer charges may apply depending on your AWS configuration + """ + try: + import boto3 + except ImportError as e: + msg = "boto3 required for S3 downloads. Install with: pip install boto3" + raise ImportError(msg) from e + + parsed = urlparse(s3_url) + bucket = parsed.netloc + key = parsed.path.lstrip("/") + + s3_client = boto3.client("s3") + s3_client.download_file(bucket, key, str(dest_path)) + + def _download_gcs(self, gcs_url: str, dest_path: Path) -> None: + """Download checkpoint from Google Cloud Storage. + + Uses google-cloud-storage client library to download files from GCS buckets. + Requires proper Google Cloud authentication to be configured. + + Parameters + ---------- + gcs_url : str + The GCS URL in format: gs://bucket-name/path/to/file + dest_path : Path + The local file path to save the downloaded file. + + Raises + ------ + ImportError + If google-cloud-storage is not installed. + FileNotFoundError + If the GCS object doesn't exist. + PermissionError + If there are insufficient GCS permissions. + ConnectionError + If there are network or GCS service issues. + + Notes + ----- + * Requires Google Cloud credentials via service account, gcloud CLI, or application default credentials + * Large files are downloaded efficiently using GCS client streaming + * Data transfer charges may apply depending on your GCS configuration + """ + try: + from google.cloud import storage + except ImportError as e: + msg = "google-cloud-storage required for GCS downloads. Install with: pip install google-cloud-storage" + raise ImportError(msg) from e + + parsed = urlparse(gcs_url) + bucket_name = parsed.netloc + blob_name = parsed.path.lstrip("/") + + client = storage.Client() + bucket = client.bucket(bucket_name) + blob = bucket.blob(blob_name) + blob.download_to_filename(str(dest_path)) + + def _download_azure(self, azure_url: str, dest_path: Path) -> None: + """Download checkpoint from Azure Blob Storage. + + Uses azure-storage-blob client library to download files from Azure + blob containers. Requires proper Azure authentication to be configured. + + Parameters + ---------- + azure_url : str + The Azure URL in format: azure://account.blob.core.windows.net/container/path/to/file + dest_path : Path + The local file path to save the downloaded file. + + Raises + ------ + ImportError + If azure-storage-blob is not installed. + FileNotFoundError + If the Azure blob doesn't exist. + PermissionError + If there are insufficient Azure permissions. + ConnectionError + If there are network or Azure service issues. 
+ + Notes + ----- + * Requires Azure credentials via Azure CLI, environment variables, or managed identity + * Large files are downloaded efficiently using Azure client streaming + * Data transfer charges may apply depending on your Azure configuration + * URL format: azure://storageaccount.blob.core.windows.net/container/blob/path + """ + try: + from azure.storage.blob import BlobServiceClient + except ImportError as e: + msg = "azure-storage-blob required for Azure downloads. Install with: pip install azure-storage-blob" + raise ImportError(msg) from e + + # Parse Azure URL format: azure://account.blob.core.windows.net/container/blob + parsed = urlparse(azure_url) + account_url = f"https://{parsed.netloc}" + container_name = parsed.path.split("/")[1] + blob_name = "/".join(parsed.path.split("/")[2:]) + + blob_service_client = BlobServiceClient(account_url=account_url) + blob_client = blob_service_client.get_blob_client(container=container_name, blob=blob_name) + + dest_path.write_bytes(blob_client.download_blob().readall()) + + +class CheckpointLoaderRegistry: + """Central registry for managing checkpoint loaders with automatic routing. + + This registry maintains a collection of checkpoint loaders and automatically + selects the appropriate loader for each checkpoint source. It provides a + unified interface for loading checkpoints from any supported source. + + The registry uses a first-match strategy where loaders are checked in the + order they were registered. The first loader that reports it can handle + a source via ``supports_source()`` will be used to load that checkpoint. + + Features + ======== + + * **Automatic Selection**: Automatically chooses the right loader for each source + * **Extensible**: Easy registration of custom loaders for new source types + * **Pre-configured**: Comes with default loaders for common sources + * **Error Handling**: Clear error messages when no suitable loader is found + + Default Loaders + =============== + + The registry comes pre-configured with: + 1. ``RemoteCheckpointLoader`` - for S3, HTTP, GCS, Azure sources + 2. ``LocalCheckpointLoader`` - for local filesystem paths + + Usage Examples + ============== + + .. code-block:: python + + # Use the global registry + from anemoi.training.utils.checkpoint_loaders import checkpoint_registry + + checkpoint = checkpoint_registry.load_checkpoint("/path/to/local.ckpt") + checkpoint = checkpoint_registry.load_checkpoint("s3://bucket/remote.ckpt") + + # Create a custom registry + custom_registry = CheckpointLoaderRegistry() + custom_registry.register(MyCustomLoader()) + checkpoint = custom_registry.load_checkpoint("custom://source") + + Custom Loaders + ============== + + To add support for custom checkpoint sources: + + .. 
code-block:: python + + class FTPCheckpointLoader(CheckpointLoader): + def supports_source(self, source): + return str(source).startswith("ftp://") + + def load_checkpoint(self, source): + # Custom FTP loading logic + return load_from_ftp(source) + + # Register the custom loader + checkpoint_registry.register(FTPCheckpointLoader()) + + # Now FTP URLs are supported + checkpoint = checkpoint_registry.load_checkpoint("ftp://server/model.ckpt") + + Performance Considerations + ========================== + + * Loader selection is O(n) where n is the number of registered loaders + * For best performance, register more specific loaders first + * The ``supports_source`` method should be lightweight + + Thread Safety + ============= + + The registry is not thread-safe for registration operations. If you need + to register loaders from multiple threads, use appropriate synchronization. + Loading operations are thread-safe as long as individual loaders are thread-safe. + """ + + def __init__(self) -> None: + """Initialize the registry with default checkpoint loaders. + + The registry is pre-configured with loaders for the most common + checkpoint sources: + - RemoteCheckpointLoader for cloud storage and HTTP + - LocalCheckpointLoader for local filesystem + """ + self._loaders: list[CheckpointLoader] = [] + # Register default loaders - order matters for selection priority + self.register(RemoteCheckpointLoader()) + self.register(LocalCheckpointLoader()) + + def register(self, loader: CheckpointLoader) -> None: + """Register a new checkpoint loader with the registry. + + Loaders are checked in registration order when selecting a loader + for a source. Register more specific loaders before more general ones + for optimal performance. + + Parameters + ---------- + loader : CheckpointLoader + The checkpoint loader instance to register. + + Examples + -------- + + .. code-block:: python + + registry = CheckpointLoaderRegistry() + + # Register custom loaders + registry.register(FTPCheckpointLoader()) + registry.register(DatabaseCheckpointLoader()) + + # More specific loaders should be registered first + registry.register(SpecialS3Loader()) # Handles special S3 cases + registry.register(GeneralS3Loader()) # Handles general S3 cases + """ + self._loaders.append(loader) + + def get_loader(self, source: str | Path) -> CheckpointLoader: + """Get the appropriate loader for the specified checkpoint source. + + This method iterates through registered loaders in registration order + and returns the first loader that reports it can handle the source. + + Parameters + ---------- + source : str | Path + The checkpoint source to find a loader for. + + Returns + ------- + CheckpointLoader + The first loader that can handle the specified source. + + Raises + ------ + ValueError + If no registered loader can handle the source. + + Examples + -------- + + .. code-block:: python + + registry = CheckpointLoaderRegistry() + + # Get loader for local file + loader = registry.get_loader("/path/to/checkpoint.ckpt") + assert isinstance(loader, LocalCheckpointLoader) + + # Get loader for S3 URL + loader = registry.get_loader("s3://bucket/checkpoint.ckpt") + assert isinstance(loader, RemoteCheckpointLoader) + """ + for loader in self._loaders: + if loader.supports_source(source): + return loader + msg = f"No loader found for source: {source}" + raise ValueError(msg) + + def load_checkpoint(self, source: str | Path) -> dict: + """Load a checkpoint using the appropriate loader. 
+ + This is a convenience method that combines loader selection and + checkpoint loading in one call. It automatically selects the + appropriate loader and uses it to load the checkpoint. + + Parameters + ---------- + source : str | Path + The checkpoint source to load from. + + Returns + ------- + dict + The loaded PyTorch checkpoint dictionary. + + Raises + ------ + ValueError + If no loader can handle the source. + FileNotFoundError + If the checkpoint doesn't exist. + ImportError + If required dependencies are missing. + ConnectionError + If there are network issues for remote sources. + + Examples + -------- + + .. code-block:: python + + registry = CheckpointLoaderRegistry() + + # Load from any supported source + checkpoint = registry.load_checkpoint("/local/model.ckpt") + checkpoint = registry.load_checkpoint("s3://bucket/model.ckpt") + checkpoint = registry.load_checkpoint("https://example.com/model.ckpt") + + # All return the same format + model.load_state_dict(checkpoint["state_dict"]) + """ + loader = self.get_loader(source) + return loader.load_checkpoint(source) + + +# Global registry instance pre-configured with default loaders +#: CheckpointLoaderRegistry: Global checkpoint loader registry. +#: +#: This registry comes pre-configured with loaders for the most common +#: checkpoint sources and is used by :func:`load_checkpoint_from_source`. +#: Custom loaders can be registered with this global instance: +#: +#: .. code-block:: python +#: +#: from anemoi.training.utils.checkpoint_loaders import checkpoint_registry +#: checkpoint_registry.register(MyCustomLoader()) +checkpoint_registry = CheckpointLoaderRegistry() + + +def load_checkpoint_from_source(source: str | Path) -> dict: + """Load a PyTorch checkpoint from any supported source using the global registry. + + This is the main entry point for loading checkpoints in the Anemoi training + system. It provides a simple, unified interface for loading checkpoints from + any supported source type. + + The function automatically detects the source type and selects the appropriate + loader from the global registry. It supports local files, cloud storage, and + web URLs without requiring the caller to know which specific loader to use. + + Supported Sources + ================= + + * **Local files**: ``/path/to/checkpoint.ckpt``, ``./models/checkpoint.ckpt`` + * **HTTP/HTTPS**: ``https://example.com/models/checkpoint.ckpt`` + * **Amazon S3**: ``s3://my-bucket/models/checkpoint.ckpt`` + * **Google Cloud Storage**: ``gs://my-bucket/models/checkpoint.ckpt`` + * **Azure Blob Storage**: ``azure://account.blob.core.windows.net/container/checkpoint.ckpt`` + + Parameters + ---------- + source : str | Path + The checkpoint source to load from. Can be a local file path, remote URL, + or cloud storage URL. Both string paths and pathlib.Path objects are supported + for local files. + + Returns + ------- + dict + A PyTorch checkpoint dictionary containing the model state and metadata. + The dictionary typically includes: + + - ``state_dict``: Model parameters and buffers + - ``hyper_parameters``: Training configuration (optional) + - ``optimizer_state_dict``: Optimizer state (optional) + - ``lr_scheduler_state_dict``: Learning rate scheduler state (optional) + - ``epoch``: Training epoch number (optional) + - ``global_step``: Training step number (optional) + + Raises + ------ + ValueError + If the source format is unsupported or no loader can handle it. + FileNotFoundError + If the checkpoint file or URL doesn't exist. 
+ ImportError + If required cloud storage libraries are not installed (e.g., boto3 for S3). + ConnectionError + If there are network issues accessing remote sources. + PermissionError + If there are insufficient permissions to access the source. + RuntimeError + If PyTorch fails to load the checkpoint data. + + Examples + -------- + Basic usage with different source types: + + .. code-block:: python + + from anemoi.training.utils.checkpoint_loaders import load_checkpoint_from_source + + # Load from local file + checkpoint = load_checkpoint_from_source("/path/to/model.ckpt") + model.load_state_dict(checkpoint["state_dict"]) + + # Load from S3 (requires boto3 and AWS credentials) + checkpoint = load_checkpoint_from_source("s3://my-bucket/models/pretrained.ckpt") + + # Load from HTTPS + checkpoint = load_checkpoint_from_source("https://example.com/public/model.ckpt") + + # Load from Google Cloud Storage (requires google-cloud-storage) + checkpoint = load_checkpoint_from_source("gs://my-gcs-bucket/checkpoints/model.ckpt") + + Integration with model loading: + + .. code-block:: python + + # Direct model loading + checkpoint = load_checkpoint_from_source(source_path) + if "state_dict" in checkpoint: + model.load_state_dict(checkpoint["state_dict"]) + + # Access hyperparameters if available + if "hyper_parameters" in checkpoint: + config = checkpoint["hyper_parameters"] + print(f"Model was trained for {checkpoint.get('epoch', 'unknown')} epochs") + + Error handling: + + .. code-block:: python + + try: + checkpoint = load_checkpoint_from_source("s3://bucket/model.ckpt") + except ImportError: + print("boto3 not installed - cannot load from S3") + except FileNotFoundError: + print("Checkpoint not found at specified location") + except ConnectionError: + print("Network issues - check connectivity and credentials") + + Notes + ----- + * This function uses the global ``checkpoint_registry`` which comes pre-configured + with loaders for common source types + * For cloud storage, ensure appropriate authentication is configured + * Large checkpoints may take time to download from remote sources + * All checkpoints are loaded to CPU memory first for compatibility + * The function is thread-safe for loading operations + + See Also + -------- + * :class:`CheckpointLoaderRegistry`: For custom loader registration + * :class:`CheckpointLoader`: For implementing custom loaders + * :mod:`anemoi.training.utils.model_loading`: For model-specific loading strategies + """ + return checkpoint_registry.load_checkpoint(source) diff --git a/training/src/anemoi/training/utils/model_loading.py b/training/src/anemoi/training/utils/model_loading.py new file mode 100644 index 000000000..b07f89dbf --- /dev/null +++ b/training/src/anemoi/training/utils/model_loading.py @@ -0,0 +1,1482 @@ +# (C) Copyright 2025 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +"""Model Loading System for Anemoi Training. + +This module provides a flexible model loading system that supports different +strategies for loading PyTorch model weights from checkpoints. It handles +various scenarios including standard loading, transfer learning with size +mismatches, and weights-only loading. 
+ +Key Features +============ + +* **Multiple Loading Strategies**: Support for standard, transfer learning, and weights-only loading +* **Size Mismatch Handling**: Intelligent handling of parameter size mismatches during transfer learning +* **Metadata Preservation**: Maintains important model metadata and hyperparameters +* **Registry Pattern**: Extensible loader registry for custom loading strategies +* **Integration**: Seamless integration with the checkpoint loading system + +Architecture +============ + +The system uses a strategy pattern with abstract base classes: + +1. **ModelLoader**: Abstract base class defining the loader interface +2. **Concrete Loaders**: Specific loading strategies (standard, transfer learning, weights-only) +3. **ModelLoaderRegistry**: Central registry managing different loading strategies +4. **Global Registry**: Pre-configured registry instance ready for immediate use + +Supported Loading Strategies +============================ + +* **Standard**: Direct loading for exact architecture matches +* **Transfer Learning**: Loading with intelligent size mismatch handling +* **Weights Only**: Loading only model parameters, skipping optimizer states + +Basic Usage +=========== + +.. code-block:: python + + from anemoi.training.utils.model_loading import load_model_from_checkpoint + + # Standard loading (exact architecture match required) + model = load_model_from_checkpoint( + model=my_model, + checkpoint_source="/path/to/checkpoint.ckpt", + loader_type="standard" + ) + + # Transfer learning (handles size mismatches) + model = load_model_from_checkpoint( + model=my_model, + checkpoint_source="s3://bucket/pretrained.ckpt", + loader_type="transfer_learning" + ) + + # Weights only (ignores optimizer state) + model = load_model_from_checkpoint( + model=my_model, + checkpoint_source="https://example.com/weights.ckpt", + loader_type="weights_only" + ) + +Advanced Usage +============== + +.. code-block:: python + + from anemoi.training.utils.model_loading import ( + ModelLoaderRegistry, + TransferLearningModelLoader, + model_loader_registry + ) + + # Use registry directly + loader = model_loader_registry.get_loader("transfer_learning") + model = loader.load_model_weights( + model=my_model, + checkpoint_source="s3://bucket/model.ckpt", + strict=False, + skip_mismatched=True + ) + + # Custom registry + custom_registry = ModelLoaderRegistry() + custom_registry.register("my_loader", MyCustomLoader()) + +Extending the System +==================== + +To add custom loading strategies: + +.. code-block:: python + + class QuantizedModelLoader(ModelLoader): + def load_model_weights(self, model, checkpoint_source, **kwargs): + # Custom quantized loading logic + checkpoint = load_checkpoint_from_source(checkpoint_source) + # Apply quantization during loading + return quantize_and_load(model, checkpoint) + + # Register with the global registry + model_loader_registry.register("quantized", QuantizedModelLoader()) + +Transfer Learning Details +========================= + +The transfer learning loader provides several options: + +* **Size Mismatch Handling**: Automatically skips parameters with incompatible shapes +* **Selective Loading**: Option to skip parameters not present in the target model +* **Logging**: Detailed logging of skipped parameters for debugging + +.. 
code-block:: python + + # Transfer learning with custom options + model = load_model_from_checkpoint( + model=target_model, + checkpoint_source="pretrained.ckpt", + loader_type="transfer_learning", + strict=False, # Allow missing keys + skip_mismatched=True # Skip size mismatches + ) + +Integration with Training +========================= + +This module integrates with the training pipeline: + +.. code-block:: python + + # In training configuration + training: + checkpoint_loading: + source: "s3://bucket/pretrained.ckpt" + loader_type: "transfer_learning" + strict: false + skip_mismatched: true + + # The training system uses this module to load model weights + +See Also +-------- +* :mod:`anemoi.training.utils.checkpoint_loaders`: Checkpoint source handling +* :mod:`anemoi.training.train.modify`: Model modification system +* :mod:`anemoi.training.train.train`: Main training pipeline integration + +Notes +----- +* All loaders preserve model metadata when available +* Transfer learning loader provides detailed logging for debugging +* The system is designed to be thread-safe for loading operations +* Custom loaders should handle errors gracefully and provide informative messages +""" + +from __future__ import annotations + +import logging +from abc import ABC +from abc import abstractmethod +from typing import TYPE_CHECKING + +from anemoi.training.utils.checkpoint_loaders import load_checkpoint_from_source + +if TYPE_CHECKING: + from pathlib import Path + + import torch + +LOGGER = logging.getLogger(__name__) + + +class ModelLoader(ABC): + """Abstract base class for loading PyTorch model weights from checkpoints. + + This class defines the interface that all model loaders must implement. + Concrete implementations handle different loading strategies such as + standard loading, transfer learning, or specialized loading requirements. + + The loader system provides a unified interface for loading model weights + while allowing different strategies for handling various scenarios like + architecture mismatches, selective parameter loading, or custom preprocessing. + + Design Principles + ================= + + * **Strategy Pattern**: Each loader implements a specific loading strategy + * **Consistent Interface**: All loaders use the same method signature + * **Error Handling**: Loaders provide clear error messages for debugging + * **Metadata Preservation**: Important model metadata is preserved when possible + + Implementing Custom Loaders + ============================ + + To create a custom model loader: + + .. 
code-block:: python + + class MyCustomLoader(ModelLoader): + def load_model_weights(self, model, checkpoint_source, strict=True, **kwargs): + # Load checkpoint + checkpoint = load_checkpoint_from_source(checkpoint_source) + + # Custom processing logic + processed_state_dict = self.preprocess_weights(checkpoint["state_dict"]) + + # Load into model + model.load_state_dict(processed_state_dict, strict=strict) + + return model + + Error Handling Guidelines + ========================= + + Implementations should raise appropriate exceptions: + + * ``ValueError``: For invalid checkpoint format or missing required keys + * ``RuntimeError``: For PyTorch loading errors or shape mismatches + * ``FileNotFoundError``: When checkpoint source cannot be accessed + * ``KeyError``: When required checkpoint keys are missing + + See Also + -------- + * :class:`StandardModelLoader`: Standard PyTorch Lightning checkpoint loading + * :class:`TransferLearningModelLoader`: Loading with size mismatch handling + * :class:`WeightsOnlyModelLoader`: Loading only model parameters + * :class:`ModelLoaderRegistry`: Registry for managing multiple loaders + """ + + @abstractmethod + def load_model_weights( + self, + model: torch.nn.Module, + checkpoint_source: str | Path, + strict: bool = True, + **kwargs, + ) -> torch.nn.Module: + """Load model weights from a checkpoint source into the target model. + + This method must be implemented by all concrete model loaders. + It should handle loading weights from the checkpoint into the provided + model using the specific strategy implemented by the loader. + + The method is responsible for: + 1. Loading the checkpoint data via the checkpoint loading system + 2. Extracting and processing the model state dictionary + 3. Loading the processed weights into the target model + 4. Preserving any relevant metadata + 5. Returning the model with loaded weights + + Parameters + ---------- + model : torch.nn.Module + The target PyTorch model to load weights into. The model should + already be instantiated with the desired architecture. + checkpoint_source : str | Path + The checkpoint source to load from. This can be a local file path, + remote URL, or cloud storage URL. The source will be handled by + the checkpoint loading system. + strict : bool, optional + Whether to strictly enforce that the keys in the checkpoint's + state_dict match the keys returned by the model's state_dict(). + Default is True. When False, allows loading with missing keys. + **kwargs + Additional keyword arguments specific to the loader implementation. + Common options include: + + - ``skip_mismatched`` (bool): Skip parameters with size mismatches + - ``prefix`` (str): Add prefix to checkpoint parameter names + - ``map_location`` (str): Device to map tensors to during loading + + Returns + ------- + torch.nn.Module + The input model with loaded weights. The same object is returned + for chaining, but the model's parameters are updated in-place. + + Raises + ------ + ValueError + If the checkpoint format is invalid or missing required keys like 'state_dict'. + RuntimeError + If PyTorch fails to load the state dict due to shape mismatches or other issues. + FileNotFoundError + If the checkpoint source cannot be accessed or doesn't exist. + KeyError + If required keys are missing from the checkpoint dictionary. + ImportError + If required dependencies for accessing the checkpoint source are missing. + + Examples + -------- + + .. 
code-block:: python + + # Basic usage + loader = SomeModelLoader() + loaded_model = loader.load_model_weights( + model=my_model, + checkpoint_source="/path/to/checkpoint.ckpt" + ) + + # With custom options + loaded_model = loader.load_model_weights( + model=my_model, + checkpoint_source="s3://bucket/checkpoint.ckpt", + strict=False, + skip_mismatched=True + ) + + Notes + ----- + * The model is modified in-place; the returned model is the same object + * Implementations should preserve model metadata when available in checkpoints + * Logging should be used to provide feedback on the loading process + * The method should be thread-safe for concurrent loading operations + """ + ... + + +class StandardModelLoader(ModelLoader): + """Standard model loader for PyTorch Lightning checkpoints. + + This loader handles the most common case of loading model weights from + PyTorch Lightning checkpoints where the source and target models have + identical architectures. It performs direct state dict loading with + optional strict mode control. + + Use Cases + ========= + + * **Resuming Training**: Loading from a checkpoint to continue training + * **Model Evaluation**: Loading trained weights for inference + * **Exact Architecture Match**: When source and target models are identical + * **Production Deployment**: Loading stable, validated model weights + + Features + ======== + + * **Strict Validation**: Enforces exact parameter key matching by default + * **Metadata Preservation**: Maintains model metadata and hyperparameters + * **Data Indices Support**: Preserves data indexing information for compatibility + * **Lightning Integration**: Full support for PyTorch Lightning checkpoint format + + Behavior + ======== + + 1. Loads checkpoint from any supported source (local, S3, HTTP, etc.) + 2. Validates that 'state_dict' key exists in checkpoint + 3. Performs direct state dict loading into the model + 4. Preserves metadata including data indices if available + 5. Logs successful loading for debugging + + Usage Examples + ============== + + .. code-block:: python + + loader = StandardModelLoader() + + # Standard loading with strict validation + model = loader.load_model_weights( + model=my_model, + checkpoint_source="/path/to/checkpoint.ckpt", + strict=True + ) + + # Relaxed loading (allows missing keys) + model = loader.load_model_weights( + model=my_model, + checkpoint_source="s3://bucket/checkpoint.ckpt", + strict=False + ) + + Error Handling + ============== + + The loader will raise errors for: + + * Missing 'state_dict' in checkpoint + * Parameter shape mismatches (in strict mode) + * Missing parameters in model (in strict mode) + * Unexpected parameters in checkpoint (in strict mode) + + When Not to Use + =============== + + This loader is not suitable for: + + * Transfer learning between different architectures + * Loading from models with different parameter shapes + * Selective parameter loading + * Custom preprocessing of checkpoint data + + For these scenarios, use :class:`TransferLearningModelLoader` instead. + """ + + def load_model_weights( + self, + model: torch.nn.Module, + checkpoint_source: str | Path, + strict: bool = True, + **kwargs, # noqa: ARG002 + ) -> torch.nn.Module: + """Load PyTorch Lightning checkpoint weights using standard loading. + + This method performs direct state dict loading from a PyTorch Lightning + checkpoint. It expects the checkpoint to contain a 'state_dict' key with + model parameters that exactly match the target model architecture. 
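
If a strict load fails, a quick pre-flight comparison of parameter keys often pinpoints the problem. This is a hedged sketch, not part of the loader itself; the checkpoint path and the placeholder model are hypothetical, and it only assumes the ``load_checkpoint_from_source`` helper defined in this package:

.. code-block:: python

    from torch import nn

    from anemoi.training.utils.checkpoint_loaders import load_checkpoint_from_source

    model = nn.Linear(4, 2)  # placeholder for the real target model

    # Hypothetical local path; any supported source would work the same way
    checkpoint = load_checkpoint_from_source("/path/to/checkpoint.ckpt")
    ckpt_keys = set(checkpoint["state_dict"])
    model_keys = set(model.state_dict())

    print("Missing from checkpoint:", sorted(model_keys - ckpt_keys))
    print("Unexpected in checkpoint:", sorted(ckpt_keys - model_keys))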
+ + The loading process: + 1. Loads checkpoint from the specified source + 2. Validates checkpoint format and required keys + 3. Extracts the state dictionary + 4. Loads state dict into model with optional strict validation + 5. Preserves important metadata for model compatibility + + Parameters + ---------- + model : torch.nn.Module + The target model to load weights into. Must have the same architecture + as the model that created the checkpoint. + checkpoint_source : str | Path + Path or URL to the checkpoint file. Supports local files, S3, HTTP, + Google Cloud Storage, and Azure Blob Storage. + strict : bool, optional + Whether to strictly enforce that the keys in the checkpoint state_dict + match the keys returned by the model's state_dict(). Default is True. + When False, allows loading with missing keys but still fails on + unexpected keys or shape mismatches. + **kwargs + Additional keyword arguments. Currently unused but maintained for + interface compatibility. + + Returns + ------- + torch.nn.Module + The input model with loaded weights. The model is modified in-place. + + Raises + ------ + ValueError + If the checkpoint doesn't contain a 'state_dict' key. + RuntimeError + If PyTorch fails to load the state dict due to architecture mismatch, + shape mismatch, or missing/unexpected parameters (when strict=True). + FileNotFoundError + If the checkpoint source cannot be accessed. + KeyError + If required checkpoint structure is malformed. + + Examples + -------- + + .. code-block:: python + + loader = StandardModelLoader() + + # Load from local file with strict validation + model = loader.load_model_weights( + model=my_lightning_model, + checkpoint_source="./checkpoints/epoch_10.ckpt", + strict=True + ) + + # Load from S3 with relaxed validation + model = loader.load_model_weights( + model=my_lightning_model, + checkpoint_source="s3://models/pretrained.ckpt", + strict=False + ) + + # The model now has loaded weights + predictions = model(input_data) + + Notes + ----- + * This loader preserves PyTorch Lightning metadata including hyperparameters + * Data indices are maintained for models that use them + * The model's training state is not affected (remains in same mode) + * Loading is performed on CPU first, then moved to model's device + """ + checkpoint = load_checkpoint_from_source(checkpoint_source) + + if "state_dict" not in checkpoint: + msg = f"No 'state_dict' found in checkpoint from {checkpoint_source}" + raise ValueError(msg) + + state_dict = checkpoint["state_dict"] + model.load_state_dict(state_dict, strict=strict) + + # Preserve checkpoint metadata for compatibility + if hasattr(model, "_ckpt_model_name_to_index") and "hyper_parameters" in checkpoint: + hyper_params = checkpoint["hyper_parameters"] + if "data_indices" in hyper_params: + model._ckpt_model_name_to_index = hyper_params["data_indices"].name_to_index + + LOGGER.info("Loaded model weights from %s", checkpoint_source) + return model + + +class TransferLearningModelLoader(ModelLoader): + """Model loader for transfer learning with intelligent size mismatch handling. + + This loader is specifically designed for transfer learning scenarios where + the source and target models may have different architectures. It provides + intelligent handling of parameter mismatches including size differences + and missing/extra parameters. 
+ + Key Features + ============ + + * **Size Mismatch Handling**: Automatically skips parameters with incompatible shapes + * **Selective Loading**: Only loads parameters that exist in both source and target + * **Detailed Logging**: Comprehensive logging of skipped parameters for debugging + * **Flexible Validation**: Supports both strict and relaxed loading modes + * **Metadata Preservation**: Maintains compatibility metadata when possible + + Use Cases + ========= + + * **Domain Adaptation**: Adapting models trained on one domain to another + * **Architecture Changes**: Loading weights when model architecture has been modified + * **Fine-tuning**: Loading pretrained weights for fine-tuning on new tasks + * **Partial Loading**: Loading only compatible parameters from larger models + * **Model Evolution**: Migrating weights between model versions + + Transfer Learning Strategies + ============================ + + **Conservative Strategy** (``skip_mismatched=True``): + Only loads parameters with exact shape matches. Safest approach that + guarantees no loading errors but may skip useful parameters. + + **Aggressive Strategy** (``skip_mismatched=False``): + Attempts to load all available parameters. May fail on shape mismatches + but can be useful for debugging or when you know shapes are compatible. + + Behavior Details + ================ + + When ``skip_mismatched=True`` (default): + 1. Compares each parameter's shape between checkpoint and model + 2. Only includes parameters with exact shape matches + 3. Logs skipped parameters with shape information + 4. Results in partial but safe parameter loading + + When ``skip_mismatched=False``: + 1. Includes all parameters present in both checkpoint and model + 2. May fail during PyTorch loading if shapes don't match + 3. Useful for debugging or when manual shape verification is done + + Usage Examples + ============== + + .. code-block:: python + + loader = TransferLearningModelLoader() + + # Conservative transfer learning (recommended) + model = loader.load_model_weights( + model=target_model, + checkpoint_source="s3://bucket/pretrained.ckpt", + strict=False, # Allow missing keys + skip_mismatched=True # Skip size mismatches (safe) + ) + + # Aggressive transfer learning (for debugging) + model = loader.load_model_weights( + model=target_model, + checkpoint_source="pretrained_model.ckpt", + strict=False, + skip_mismatched=False # May fail on mismatches + ) + + Common Scenarios + ================ + + **Scenario 1: Different Output Dimensions** + Source model trained for 100 classes, target for 10 classes. + The loader will skip the final classification layer but load + all feature extraction layers. + + **Scenario 2: Added Layers** + Target model has additional layers not in source model. + Only shared layers are loaded; new layers retain initialization. + + **Scenario 3: Removed Layers** + Source model has layers not in target model. + Extra parameters are ignored; target model remains unchanged. + + Debugging Transfer Learning + =========================== + + The loader provides detailed logging to help debug transfer learning: + + * Lists all skipped parameters with shape information + * Identifies parameters not present in target model + * Reports successful parameter transfers + * Logs final transfer learning status + + Monitor logs to ensure expected parameters are being transferred. 
+ + Performance Considerations + ========================== + + * Parameter comparison is O(n) where n is number of checkpoint parameters + * Shape checking adds minimal overhead + * Logging can be verbose for models with many mismatched parameters + * Memory usage scales with checkpoint size during filtering + """ + + def load_model_weights( + self, + model: torch.nn.Module, + checkpoint_source: str | Path, + strict: bool = False, + skip_mismatched: bool = True, + **kwargs, # noqa: ARG002 + ) -> torch.nn.Module: + """Load model weights with transfer learning compatibility. + + This method performs intelligent transfer learning by comparing parameter + shapes between the checkpoint and target model, selectively loading only + compatible parameters while providing detailed feedback on skipped parameters. + + The transfer learning process: + 1. Loads checkpoint from any supported source + 2. Extracts source model's state dictionary + 3. Compares parameter shapes with target model + 4. Filters parameters based on compatibility and settings + 5. Loads compatible parameters into target model + 6. Preserves metadata and sets transfer learning flags + + Parameters + ---------- + model : torch.nn.Module + The target model to load weights into. Can have a different architecture + from the source model in the checkpoint. + checkpoint_source : str | Path + Path or URL to the checkpoint file containing source model weights. + Supports all checkpoint loading sources (local, S3, HTTP, etc.). + strict : bool, optional + Whether to strictly enforce parameter key matching. Default is False + to allow missing keys in transfer learning scenarios. Setting to True + will cause failures if any expected parameters are missing. + skip_mismatched : bool, optional + Whether to skip parameters with shape mismatches. Default is True. + When True, only parameters with exact shape matches are loaded. + When False, attempts to load all available parameters (may fail + during PyTorch loading if shapes are incompatible). + **kwargs + Additional keyword arguments. Currently unused but maintained for + interface compatibility. + + Returns + ------- + torch.nn.Module + The input model with compatible weights loaded from the checkpoint. + The model is modified in-place and marked as having initialized weights. + + Raises + ------ + ValueError + If the checkpoint doesn't contain a 'state_dict' key. + RuntimeError + If PyTorch fails to load filtered parameters due to remaining + incompatibilities (rare when skip_mismatched=True). + FileNotFoundError + If the checkpoint source cannot be accessed. + + Examples + -------- + + .. code-block:: python + + loader = TransferLearningModelLoader() + + # Safe transfer learning (recommended approach) + target_model = loader.load_model_weights( + model=my_model, + checkpoint_source="s3://bucket/pretrained_resnet.ckpt", + strict=False, # Allow missing keys + skip_mismatched=True # Skip incompatible shapes + ) + + # Check if weights were initialized + if hasattr(target_model, 'weights_initialized'): + print(f"Transfer learning completed: {target_model.weights_initialized}") + + # View detailed logs to see what was transferred + # Logs will show: + # - "Skipping parameter with size mismatch: layer.weight" + # - "Skipping parameter not in model: old_layer.bias" + # - "Transfer learning applied successfully" + + Advanced usage with custom filtering: + + .. 
code-block:: python + + # For models where you know shapes are compatible + model = loader.load_model_weights( + model=compatible_model, + checkpoint_source="./checkpoints/similar_arch.ckpt", + strict=True, # Enforce exact match + skip_mismatched=False # Trust shape compatibility + ) + + Notes + ----- + * The model's ``weights_initialized`` attribute is set to True after loading + * Comprehensive logging helps debug transfer learning issues + * Metadata preservation maintains model compatibility when possible + * The method handles both PyTorch and PyTorch Lightning checkpoints + * Parameter filtering reduces memory usage by creating smaller state dicts + """ + checkpoint = load_checkpoint_from_source(checkpoint_source) + + if "state_dict" not in checkpoint: + msg = f"No 'state_dict' found in checkpoint from {checkpoint_source}" + raise ValueError(msg) + + state_dict = checkpoint["state_dict"] + model_state_dict = model.state_dict() + + if skip_mismatched: + # Filter out layers with size mismatch + filtered_state_dict = {} + for key, value in state_dict.items(): + if key in model_state_dict: + if value.shape == model_state_dict[key].shape: + filtered_state_dict[key] = value + else: + LOGGER.info("Skipping parameter with size mismatch: %s", key) + LOGGER.info("Checkpoint shape: %s, Model shape: %s", value.shape, model_state_dict[key].shape) + else: + LOGGER.info("Skipping parameter not in model: %s", key) + + state_dict = filtered_state_dict + + model.load_state_dict(state_dict, strict=strict) + + # Preserve checkpoint metadata for compatibility + if hasattr(model, "_ckpt_model_name_to_index") and "hyper_parameters" in checkpoint: + hyper_params = checkpoint["hyper_parameters"] + if "data_indices" in hyper_params: + model._ckpt_model_name_to_index = hyper_params["data_indices"].name_to_index + + model.weights_initialized = True + LOGGER.info("Transfer learning applied successfully from %s", checkpoint_source) + return model + + +class WeightsOnlyModelLoader(ModelLoader): + """Model loader that exclusively loads model weights, ignoring optimizer states and metadata. + + This loader is designed for scenarios where only the model parameters are needed, + such as inference, evaluation, or when starting fresh training with pretrained weights. + It ignores optimizer states, learning rate scheduler states, and other training metadata. 
+ + Key Features + ============ + + * **Weights Focus**: Loads only model parameters, ignoring training state + * **Clean Loading**: No optimizer state or scheduler state contamination + * **Mismatch Handling**: Uses robust transfer learning logic internally + * **Inference Ready**: Ideal for deploying models in production + * **Fresh Training**: Perfect for starting new training with pretrained weights + + Use Cases + ========= + + * **Model Inference**: Loading weights for evaluation or prediction + * **Production Deployment**: Clean model loading without training artifacts + * **Fresh Training Start**: Using pretrained weights for new training run + * **Model Analysis**: Loading weights for model inspection or analysis + * **Cross-framework Transfer**: Loading weights when changing training frameworks + + Implementation Details + ====================== + + This loader leverages the :class:`TransferLearningModelLoader` internally + with specific settings optimized for weights-only loading: + + * Uses ``skip_mismatched=False`` for more permissive loading + * Inherits robust mismatch handling and logging + * Maintains the same error handling and debugging features + * Preserves essential model metadata while ignoring training state + + Comparison with Other Loaders + ============================= + + **vs StandardModelLoader**: + WeightsOnlyModelLoader is more robust for architecture differences + and focuses purely on model parameters. + + **vs TransferLearningModelLoader**: + WeightsOnlyModelLoader uses different default settings optimized + for clean weight loading rather than transfer learning scenarios. + + Usage Examples + ============== + + .. code-block:: python + + loader = WeightsOnlyModelLoader() + + # Load weights for inference + inference_model = loader.load_model_weights( + model=my_model, + checkpoint_source="s3://bucket/trained_model.ckpt", + strict=True # Expect exact architecture match + ) + inference_model.eval() # Set to evaluation mode + + # Load weights for fresh training start + fresh_model = loader.load_model_weights( + model=new_model, + checkpoint_source="./pretrained/weights.ckpt", + strict=False # Allow some flexibility + ) + + # The model has no optimizer state or training metadata + # Perfect for starting new training or inference + + When to Use + =========== + + **Ideal for**: + - Model deployment and inference + - Starting fresh training with pretrained weights + - Model evaluation and analysis + - Cross-framework model migration + - Clean model state without training artifacts + + **Not ideal for**: + - Resuming training (use StandardModelLoader) + - When you need optimizer state + - When training metadata is important + + Performance Considerations + ========================== + + * Slightly more overhead due to internal transfer learning logic + * Memory efficient as it ignores optimizer and scheduler states + * Same loading performance as TransferLearningModelLoader + * Clean model state reduces memory usage in production + """ + + def load_model_weights( + self, + model: torch.nn.Module, + checkpoint_source: str | Path, + strict: bool = True, + **kwargs, + ) -> torch.nn.Module: + """Load only model weights from checkpoint, ignoring training state. + + This method loads exclusively model parameters from a checkpoint, + ignoring optimizer states, learning rate scheduler states, and other + training metadata. It uses the transfer learning loader internally + for robust handling of potential architecture differences. + + The loading process: + 1. 
Delegates to TransferLearningModelLoader with weights-optimized settings + 2. Uses ``skip_mismatched=False`` for more permissive loading + 3. Inherits all mismatch handling and logging capabilities + 4. Returns model ready for inference or fresh training + + Parameters + ---------- + model : torch.nn.Module + The target model to load weights into. Should have compatible + architecture with the checkpoint model. + checkpoint_source : str | Path + Path or URL to the checkpoint file. Supports all checkpoint + loading sources (local, S3, HTTP, GCS, Azure). + strict : bool, optional + Whether to strictly enforce parameter key matching. Default is True. + For weights-only loading, this typically should be True unless + you expect architecture differences. + **kwargs + Additional keyword arguments passed to the underlying transfer + learning loader. Common options: + + - Custom parameters specific to your use case + - Any other TransferLearningModelLoader options + + Returns + ------- + torch.nn.Module + The input model with loaded weights, ready for inference or fresh + training. The model contains no optimizer state or training metadata. + + Raises + ------ + ValueError + If the checkpoint doesn't contain a 'state_dict' key. + RuntimeError + If PyTorch fails to load the weights due to architecture mismatch. + FileNotFoundError + If the checkpoint source cannot be accessed. + + Examples + -------- + + .. code-block:: python + + loader = WeightsOnlyModelLoader() + + # Load for inference (strict matching) + model = loader.load_model_weights( + model=inference_model, + checkpoint_source="s3://models/production.ckpt", + strict=True + ) + model.eval() # Ready for inference + + # Load for fresh training (more flexible) + model = loader.load_model_weights( + model=training_model, + checkpoint_source="./pretrained/backbone.ckpt", + strict=False + ) + # Model ready for new training without old optimizer state + + Production deployment example: + + .. code-block:: python + + # Load clean weights for production inference + production_model = MyModel(config) + loader = WeightsOnlyModelLoader() + + production_model = loader.load_model_weights( + model=production_model, + checkpoint_source="s3://production/model_v2.ckpt" + ) + + # Model is clean and ready for inference + production_model.eval() + with torch.no_grad(): + predictions = production_model(input_batch) + + Notes + ----- + * Internally uses TransferLearningModelLoader with ``skip_mismatched=False`` + * Inherits all robust error handling and logging from transfer learning loader + * Results in clean model state without optimizer or scheduler artifacts + * The loaded model will have ``weights_initialized=True`` attribute set + * Perfect for inference deployments where training state is not needed + """ + # For weights-only loading, we use the TransferLearningModelLoader + # as it has better handling of potential mismatches + loader = TransferLearningModelLoader() + return loader.load_model_weights( + model=model, + checkpoint_source=checkpoint_source, + strict=strict, + skip_mismatched=False, + **kwargs, + ) + + +class ModelLoaderRegistry: + """Central registry for managing model loading strategies. + + This registry maintains a collection of model loaders and provides a + unified interface for loading model weights using different strategies. + It supports built-in loaders for common scenarios and allows registration + of custom loaders for specialized requirements. 
+ + The registry uses a name-based lookup system where each loader is + associated with a descriptive string identifier. This allows easy + selection of loading strategies through configuration or runtime decisions. + + Features + ======== + + * **Strategy Selection**: Choose loading strategy by name + * **Extensible**: Easy registration of custom loaders + * **Pre-configured**: Comes with loaders for common use cases + * **Error Handling**: Clear error messages for unknown loader types + * **Unified Interface**: Consistent API across all loading strategies + + Built-in Loaders + ================ + + The registry comes pre-configured with three essential loaders: + + * **"standard"**: :class:`StandardModelLoader` - Direct Lightning checkpoint loading + * **"transfer_learning"**: :class:`TransferLearningModelLoader` - Size mismatch handling + * **"weights_only"**: :class:`WeightsOnlyModelLoader` - Clean weight loading + + Usage Examples + ============== + + .. code-block:: python + + from anemoi.training.utils.model_loading import model_loader_registry + + # Use built-in loaders + model = model_loader_registry.load_model_weights( + model=my_model, + checkpoint_source="s3://bucket/checkpoint.ckpt", + loader_type="transfer_learning", + strict=False + ) + + # Get specific loader + loader = model_loader_registry.get_loader("weights_only") + model = loader.load_model_weights(model, checkpoint_source) + + Custom Loaders + ============== + + Register custom loaders for specialized scenarios: + + .. code-block:: python + + class QuantizedModelLoader(ModelLoader): + def load_model_weights(self, model, checkpoint_source, **kwargs): + # Custom quantized loading logic + return quantize_and_load(model, checkpoint_source) + + # Register custom loader + model_loader_registry.register("quantized", QuantizedModelLoader()) + + # Use custom loader + model = model_loader_registry.load_model_weights( + model=my_model, + checkpoint_source="./checkpoints/quantized.ckpt", + loader_type="quantized" + ) + + Configuration Integration + ========================= + + The registry integrates with training configuration: + + .. code-block:: python + + # In training config + training: + checkpoint_loading: + source: "s3://bucket/pretrained.ckpt" + loader_type: "transfer_learning" # Registry key + strict: false + skip_mismatched: true + + # The training system uses the registry + model = model_loader_registry.load_model_weights( + model=model, + checkpoint_source=config.checkpoint_loading.source, + loader_type=config.checkpoint_loading.loader_type, + **config.checkpoint_loading + ) + + Thread Safety + ============= + + The registry is thread-safe for read operations (getting loaders, loading models) + but not for write operations (registering new loaders). Register all custom + loaders during initialization before concurrent usage. + + Performance Considerations + ========================== + + * Loader lookup is O(1) dictionary access + * Loader instances are reused across calls + * No overhead for unused loaders + * Custom loaders should be lightweight to instantiate + """ + + def __init__(self) -> None: + """Initialize the registry with built-in model loaders. 
+ + The registry comes pre-configured with loaders for the most common + model loading scenarios: + - "standard": StandardModelLoader for exact architecture matches + - "transfer_learning": TransferLearningModelLoader for architecture differences + - "weights_only": WeightsOnlyModelLoader for clean weight loading + """ + self._loaders: dict[str, ModelLoader] = {} + # Register default loaders with descriptive names + self.register("standard", StandardModelLoader()) + self.register("transfer_learning", TransferLearningModelLoader()) + self.register("weights_only", WeightsOnlyModelLoader()) + + def register(self, name: str, loader: ModelLoader) -> None: + """Register a model loader with the registry. + + Associates a model loader instance with a string identifier for + later retrieval. The name should be descriptive and unique. + + Parameters + ---------- + name : str + Unique identifier for the loader. Should be descriptive and + follow naming conventions (lowercase, underscores for spaces). + Examples: "quantized", "pruned", "custom_transfer". + loader : ModelLoader + The model loader instance to register. Must implement the + ModelLoader interface. + + Raises + ------ + TypeError + If the loader doesn't implement the ModelLoader interface. + + Examples + -------- + + .. code-block:: python + + registry = ModelLoaderRegistry() + + # Register custom loaders + registry.register("quantized", QuantizedModelLoader()) + registry.register("pruned", PrunedModelLoader()) + registry.register("distilled", DistilledModelLoader()) + + # Names should be descriptive + registry.register("bert_to_gpt", CrossArchitectureLoader()) + + Notes + ----- + * Loader names are case-sensitive + * Registering with an existing name overwrites the previous loader + * Loaders are stored by reference, not copied + """ + self._loaders[name] = loader + + def get_loader(self, name: str) -> ModelLoader: + """Retrieve a registered model loader by name. + + Returns the model loader instance associated with the given name. + The loader can then be used directly for custom loading scenarios. + + Parameters + ---------- + name : str + The name of the registered loader to retrieve. + + Returns + ------- + ModelLoader + The model loader instance associated with the name. + + Raises + ------ + ValueError + If no loader is registered with the given name. The error + message includes available loader names for debugging. + + Examples + -------- + + .. code-block:: python + + registry = ModelLoaderRegistry() + + # Get built-in loaders + standard_loader = registry.get_loader("standard") + transfer_loader = registry.get_loader("transfer_learning") + + # Use loader directly + model = standard_loader.load_model_weights( + model=my_model, + checkpoint_source="./checkpoint.ckpt" + ) + + # Handle unknown loaders + try: + loader = registry.get_loader("nonexistent") + except ValueError as e: + print(f"Error: {e}") # Shows available loaders + """ + if name not in self._loaders: + msg = f"Unknown loader: {name}. Available: {list(self._loaders.keys())}" + raise ValueError(msg) + return self._loaders[name] + + def load_model_weights( + self, + model: torch.nn.Module, + checkpoint_source: str | Path, + loader_type: str = "standard", + **kwargs, + ) -> torch.nn.Module: + """Load model weights using a specified loader strategy. + + This is the main entry point for loading model weights through the + registry. It combines loader selection and weight loading in one + convenient method call. 
+ + Parameters + ---------- + model : torch.nn.Module + The target model to load weights into. + checkpoint_source : str | Path + Path or URL to the checkpoint file. Supports all checkpoint + loading sources (local, S3, HTTP, GCS, Azure). + loader_type : str, optional + Name of the registered loader to use. Default is "standard". + Built-in options: "standard", "transfer_learning", "weights_only". + **kwargs + Additional keyword arguments passed to the selected loader. + Options depend on the specific loader being used. + + Returns + ------- + torch.nn.Module + The input model with loaded weights from the checkpoint. + + Raises + ------ + ValueError + If the loader_type is not registered in the registry. + FileNotFoundError + If the checkpoint source cannot be accessed. + RuntimeError + If the selected loader fails to load the weights. + + Examples + -------- + + .. code-block:: python + + registry = ModelLoaderRegistry() + + # Standard loading + model = registry.load_model_weights( + model=my_model, + checkpoint_source="./checkpoint.ckpt", + loader_type="standard" + ) + + # Transfer learning with options + model = registry.load_model_weights( + model=target_model, + checkpoint_source="s3://bucket/pretrained.ckpt", + loader_type="transfer_learning", + strict=False, + skip_mismatched=True + ) + + # Weights-only loading + model = registry.load_model_weights( + model=inference_model, + checkpoint_source="https://models.com/weights.ckpt", + loader_type="weights_only" + ) + + Integration example: + + .. code-block:: python + + # Function for loading from config + def load_from_config(model, config): + return model_loader_registry.load_model_weights( + model=model, + checkpoint_source=config.source, + loader_type=config.get("loader_type", "standard"), + **config.get("loader_options", {}) + ) + + Notes + ----- + * This method is thread-safe for concurrent loading operations + * The selected loader handles all checkpoint source types automatically + * Error messages include available loader types for easy debugging + * All loader-specific options are passed through via **kwargs + """ + loader = self.get_loader(loader_type) + return loader.load_model_weights(model, checkpoint_source, **kwargs) + + +# Global registry instance pre-configured with built-in loaders +#: ModelLoaderRegistry: Global model loader registry. +#: +#: This registry comes pre-configured with loaders for the most common +#: model loading scenarios and is used by :func:`load_model_from_checkpoint`. +#: Custom loaders can be registered with this global instance: +#: +#: .. code-block:: python +#: +#: from anemoi.training.utils.model_loading import model_loader_registry +#: model_loader_registry.register("my_custom", MyCustomLoader()) +#: +#: Available built-in loaders: +#: +#: * ``"standard"``: :class:`StandardModelLoader` - Direct PyTorch Lightning loading +#: * ``"transfer_learning"``: :class:`TransferLearningModelLoader` - Mismatch handling +#: * ``"weights_only"``: :class:`WeightsOnlyModelLoader` - Clean weight loading +model_loader_registry = ModelLoaderRegistry() + + +def load_model_from_checkpoint( + model: torch.nn.Module, + checkpoint_source: str | Path, + loader_type: str = "standard", + **kwargs, +) -> torch.nn.Module: + """Load PyTorch model weights from checkpoint using the global registry. + + This is the main entry point for loading model weights in the Anemoi training + system. 
It provides a simple, unified interface for loading model weights using + different strategies without requiring direct interaction with the registry. + + The function automatically handles checkpoint source detection and uses the + global model loader registry to select and execute the appropriate loading + strategy based on the specified loader type. + + Supported Loading Strategies + ============================ + + * **"standard"**: Direct loading for exact architecture matches (PyTorch Lightning) + * **"transfer_learning"**: Intelligent loading with size mismatch handling + * **"weights_only"**: Clean loading of model parameters without training state + + Checkpoint Source Support + ========================= + + * **Local files**: ``/path/to/checkpoint.ckpt``, ``./models/checkpoint.ckpt`` + * **HTTP/HTTPS**: ``https://example.com/models/checkpoint.ckpt`` + * **Amazon S3**: ``s3://bucket/models/checkpoint.ckpt`` + * **Google Cloud Storage**: ``gs://bucket/models/checkpoint.ckpt`` + * **Azure Blob Storage**: ``azure://account.blob.core.windows.net/container/checkpoint.ckpt`` + + Parameters + ---------- + model : torch.nn.Module + The target PyTorch model to load weights into. The model should already + be instantiated with the desired architecture. + checkpoint_source : str | Path + Path or URL to the checkpoint file. Supports local files and remote + sources through the checkpoint loading system. Both string paths and + pathlib.Path objects are supported for local files. + loader_type : str, optional + The loading strategy to use. Default is "standard". Available options: + + - ``"standard"``: For exact architecture matches (PyTorch Lightning) + - ``"transfer_learning"``: For handling architecture differences + - ``"weights_only"``: For loading only model parameters + + **kwargs + Additional keyword arguments passed to the selected loader. + Common options include: + + - ``strict`` (bool): Whether to enforce exact parameter key matching + - ``skip_mismatched`` (bool): Skip parameters with shape mismatches (transfer learning) + - Custom loader-specific parameters + + Returns + ------- + torch.nn.Module + The input model with loaded weights from the checkpoint. The same model + object is returned for method chaining, but parameters are updated in-place. + + Raises + ------ + ValueError + If the loader_type is not recognized or checkpoint format is invalid. + FileNotFoundError + If the checkpoint source cannot be accessed or doesn't exist. + RuntimeError + If PyTorch fails to load weights due to architecture mismatch or corruption. + ImportError + If required cloud storage libraries are missing for remote sources. + + Examples + -------- + Basic usage with different loading strategies: + + .. code-block:: python + + from anemoi.training.utils.model_loading import load_model_from_checkpoint + + # Standard loading (exact architecture match) + model = load_model_from_checkpoint( + model=my_lightning_model, + checkpoint_source="./checkpoints/epoch_10.ckpt", + loader_type="standard" + ) + + # Transfer learning (handles mismatches) + model = load_model_from_checkpoint( + model=target_model, + checkpoint_source="s3://bucket/pretrained.ckpt", + loader_type="transfer_learning", + strict=False, + skip_mismatched=True + ) + + # Weights-only loading (clean inference) + model = load_model_from_checkpoint( + model=inference_model, + checkpoint_source="https://models.com/production.ckpt", + loader_type="weights_only" + ) + + Advanced usage with custom options: + + .. 
code-block:: python + + # Transfer learning with detailed control + model = load_model_from_checkpoint( + model=custom_model, + checkpoint_source="gs://bucket/research_model.ckpt", + loader_type="transfer_learning", + strict=False, # Allow missing keys + skip_mismatched=True # Skip size mismatches + ) + + # Check if transfer learning was successful + if hasattr(model, 'weights_initialized'): + print(f"Weights initialized: {model.weights_initialized}") + + Integration with training pipeline: + + .. code-block:: python + + # Function for config-driven loading + def load_from_config(model, checkpoint_config): + return load_model_from_checkpoint( + model=model, + checkpoint_source=checkpoint_config.source, + loader_type=checkpoint_config.get("type", "standard"), + **checkpoint_config.get("options", {}) + ) + + # Usage in training setup + if config.checkpoint_loading.enabled: + model = load_from_config(model, config.checkpoint_loading) + + Error handling example: + + .. code-block:: python + + try: + model = load_model_from_checkpoint( + model=model, + checkpoint_source="s3://bucket/model.ckpt", + loader_type="transfer_learning" + ) + except ValueError as e: + print(f"Invalid loader type or checkpoint format: {e}") + except FileNotFoundError as e: + print(f"Checkpoint not found: {e}") + except ImportError as e: + print(f"Missing dependencies for cloud storage: {e}") + + Notes + ----- + * This function uses the global ``model_loader_registry`` which comes + pre-configured with standard loaders + * All checkpoint sources are handled automatically by the checkpoint loading system + * The function is thread-safe for concurrent loading operations + * Custom loaders can be registered with the global registry for specialized needs + * Loading performance depends on checkpoint size and source type (local vs remote) + + See Also + -------- + * :class:`ModelLoaderRegistry`: For registering custom loaders + * :class:`ModelLoader`: For implementing custom loading strategies + * :func:`anemoi.training.utils.checkpoint_loaders.load_checkpoint_from_source`: For checkpoint access + * :mod:`anemoi.training.train.modify`: For model modification after loading + """ + return model_loader_registry.load_model_weights( + model=model, + checkpoint_source=checkpoint_source, + loader_type=loader_type, + **kwargs, + ) diff --git a/training/tests/config/test_checkpoint_loading_configs.py b/training/tests/config/test_checkpoint_loading_configs.py new file mode 100644 index 000000000..dee68827f --- /dev/null +++ b/training/tests/config/test_checkpoint_loading_configs.py @@ -0,0 +1,197 @@ +# (C) Copyright 2025 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. 
+ +from __future__ import annotations + +from pathlib import Path + +import pytest +import yaml +from hydra import compose +from hydra import initialize_config_dir +from omegaconf import DictConfig + +from anemoi.training.config import CONFIG_PATH + + +class TestCheckpointLoadingConfigs: + """Test checkpoint_loading configuration templates.""" + + def test_config_templates_exist(self): + """Test that all expected config templates exist.""" + config_dir = Path(CONFIG_PATH) / "training" / "checkpoint_loading" + expected_templates = ["weights_only.yml", "transfer_learning.yml", "standard.yml"] + + for template in expected_templates: + template_path = config_dir / template + assert template_path.exists(), f"Missing config template: {template}" + + def test_weights_only_config_structure(self): + """Test weights_only.yml config structure.""" + config_dir = Path(CONFIG_PATH) / "training" / "checkpoint_loading" + config_path = config_dir / "weights_only.yml" + + with open(config_path) as f: + config = yaml.safe_load(f) + + assert "source" in config + assert "loader_type" in config + assert config["loader_type"] == "weights_only" + assert "strict" in config + + def test_transfer_learning_config_structure(self): + """Test transfer_learning.yml config structure.""" + config_dir = Path(CONFIG_PATH) / "training" / "checkpoint_loading" + config_path = config_dir / "transfer_learning.yml" + + with open(config_path) as f: + config = yaml.safe_load(f) + + assert "source" in config + assert "loader_type" in config + assert config["loader_type"] == "transfer_learning" + assert "strict" in config + assert "skip_mismatched" in config + + # Transfer learning should have appropriate defaults + assert config["strict"] is False + assert config["skip_mismatched"] is True + + def test_standard_config_structure(self): + """Test standard.yml config structure.""" + config_dir = Path(CONFIG_PATH) / "training" / "checkpoint_loading" + config_path = config_dir / "standard.yml" + + with open(config_path) as f: + config = yaml.safe_load(f) + + assert "source" in config + assert "loader_type" in config + assert config["loader_type"] == "standard" + assert "strict" in config + + def test_config_yaml_validity(self): + """Test that all config templates are valid YAML.""" + config_dir = Path(CONFIG_PATH) / "training" / "checkpoint_loading" + + for config_file in config_dir.glob("*.yml"): + with open(config_file) as f: + try: + yaml.safe_load(f) + except yaml.YAMLError: + pytest.fail(f"Invalid YAML in {config_file}") + + def test_hydra_config_loading(self): + """Test that configs can be loaded through Hydra.""" + config_dir = Path(CONFIG_PATH) + + # Test loading each checkpoint loading config through Hydra + checkpoint_configs = [ + "checkpoint_loading=weights_only", + "checkpoint_loading=transfer_learning", + "checkpoint_loading=standard", + ] + + with initialize_config_dir(config_dir=str(config_dir), version_base=None): + for config_override in checkpoint_configs: + try: + # Try to compose with minimal overrides + cfg = compose(overrides=[config_override]) + assert "checkpoint_loading" in cfg + assert cfg.checkpoint_loading.loader_type is not None + except Exception as e: + pytest.fail(f"Failed to load config {config_override}: {e}") + + def test_config_parameter_types(self): + """Test that config parameters have correct types.""" + config_dir = Path(CONFIG_PATH) / "training" / "checkpoint_loading" + + for config_file in config_dir.glob("*.yml"): + with open(config_file) as f: + config = yaml.safe_load(f) + + # source should be 
string or null + assert config["source"] is None or isinstance(config["source"], str) + + # loader_type should be string + assert isinstance(config["loader_type"], str) + assert config["loader_type"] in ["weights_only", "transfer_learning", "standard"] + + # strict should be boolean if present + if "strict" in config: + assert isinstance(config["strict"], bool) + + # skip_mismatched should be boolean if present + if "skip_mismatched" in config: + assert isinstance(config["skip_mismatched"], bool) + + def test_config_templates_with_actual_paths(self): + """Test config templates with actual checkpoint paths.""" + test_configs = { + "weights_only": {"source": "/path/to/checkpoint.ckpt", "loader_type": "weights_only", "strict": True}, + "transfer_learning": { + "source": "s3://bucket/checkpoint.ckpt", + "loader_type": "transfer_learning", + "strict": False, + "skip_mismatched": True, + }, + "standard": {"source": "https://example.com/checkpoint.ckpt", "loader_type": "standard", "strict": True}, + } + + for config_name, expected_structure in test_configs.items(): + # Verify our templates match expected structure + config_dir = Path(CONFIG_PATH) / "training" / "checkpoint_loading" + config_path = config_dir / f"{config_name}.yml" + + with open(config_path) as f: + template_config = yaml.safe_load(f) + + # Check that all expected keys exist + for key in expected_structure: + assert key in template_config + + def test_config_documentation_completeness(self): + """Test that configs have appropriate documentation/comments.""" + config_dir = Path(CONFIG_PATH) / "training" / "checkpoint_loading" + + for config_file in config_dir.glob("*.yml"): + with open(config_file) as f: + content = f.read() + + # Should have comments explaining the parameters + assert "#" in content, f"Config {config_file.name} should have documentation comments" + + # Key parameters should be documented + if "source" in content: + assert ( + "Path" in content or "URL" in content + ), f"Config {config_file.name} should document source parameter" + + def test_config_integration_with_training_schema(self): + """Test that config templates work with training schema validation.""" + from anemoi.training.schemas.base_schema import BaseSchema + + config_dir = Path(CONFIG_PATH) / "training" / "checkpoint_loading" + + for config_file in config_dir.glob("*.yml"): + with open(config_file) as f: + checkpoint_config = yaml.safe_load(f) + + # Create minimal training config with checkpoint loading + training_config = DictConfig({"training": {"checkpoint_loading": checkpoint_config}}) + + # Should not raise validation error for checkpoint_loading field + try: + config = BaseSchema(**training_config) + assert config.training.checkpoint_loading is not None + except Exception as e: + # If it fails, it shouldn't be due to checkpoint_loading structure + error_str = str(e) + if "checkpoint_loading" in error_str: + pytest.fail(f"Config template {config_file.name} failed schema validation: {e}") diff --git a/training/tests/schemas/__init__.py b/training/tests/schemas/__init__.py new file mode 100644 index 000000000..9fc775e54 --- /dev/null +++ b/training/tests/schemas/__init__.py @@ -0,0 +1,8 @@ +# (C) Copyright 2025 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. 
+# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. diff --git a/training/tests/schemas/test_checkpoint_loading_schema.py b/training/tests/schemas/test_checkpoint_loading_schema.py new file mode 100644 index 000000000..eaf5ddc8c --- /dev/null +++ b/training/tests/schemas/test_checkpoint_loading_schema.py @@ -0,0 +1,154 @@ +# (C) Copyright 2025 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +from __future__ import annotations + +from omegaconf import DictConfig +from pydantic import ValidationError + +from anemoi.training.schemas.base_schema import BaseSchema + + +class TestCheckpointLoadingSchema: + """Test checkpoint_loading field validation in training schema.""" + + def test_checkpoint_loading_field_exists(self): + """Test that checkpoint_loading field exists in schema.""" + config_dict = { + "training": {"checkpoint_loading": {"source": "/path/to/checkpoint.ckpt", "loader_type": "weights_only"}}, + } + + # Should not raise validation error + config = BaseSchema(**DictConfig(config_dict)) + assert hasattr(config.training, "checkpoint_loading") + assert config.training.checkpoint_loading is not None + + def test_checkpoint_loading_none_by_default(self): + """Test that checkpoint_loading defaults to None.""" + config_dict = {"training": {}} + + config = BaseSchema(**DictConfig(config_dict)) + assert config.training.checkpoint_loading is None + + def test_checkpoint_loading_valid_config(self): + """Test valid checkpoint_loading configurations.""" + valid_configs = [ + # Basic weights_only + {"source": "/local/path/checkpoint.ckpt", "loader_type": "weights_only", "strict": True}, + # Transfer learning with options + { + "source": "s3://bucket/checkpoint.ckpt", + "loader_type": "transfer_learning", + "strict": False, + "skip_mismatched": True, + }, + # Standard loading + {"source": "https://example.com/checkpoint.ckpt", "loader_type": "standard"}, + # Cloud sources + {"source": "gs://bucket/checkpoint.ckpt", "loader_type": "weights_only"}, + { + "source": "azure://account.blob.core.windows.net/container/checkpoint.ckpt", + "loader_type": "transfer_learning", + }, + ] + + for checkpoint_config in valid_configs: + config_dict = {"training": {"checkpoint_loading": checkpoint_config}} + + # Should not raise validation error + config = BaseSchema(**DictConfig(config_dict)) + assert config.training.checkpoint_loading is not None + + def test_checkpoint_loading_dict_type(self): + """Test that checkpoint_loading accepts dict type.""" + config_dict = { + "training": { + "checkpoint_loading": { + "source": "/path/to/checkpoint.ckpt", + "loader_type": "weights_only", + "custom_param": "custom_value", # Additional parameters should be allowed + }, + }, + } + + config = BaseSchema(**DictConfig(config_dict)) + assert isinstance(config.training.checkpoint_loading, (dict, DictConfig)) + + def test_checkpoint_loading_empty_dict(self): + """Test that empty dict is valid for checkpoint_loading.""" + config_dict = {"training": {"checkpoint_loading": {}}} + + # Should not raise validation error - empty dict is valid + config = 
BaseSchema(**DictConfig(config_dict)) + assert config.training.checkpoint_loading == {} + + def test_checkpoint_loading_none_explicit(self): + """Test explicitly setting checkpoint_loading to None.""" + config_dict = {"training": {"checkpoint_loading": None}} + + config = BaseSchema(**DictConfig(config_dict)) + assert config.training.checkpoint_loading is None + + def test_checkpoint_loading_with_minimal_training_config(self): + """Test checkpoint_loading with minimal training configuration.""" + # Create a minimal training config that would normally be valid + config_dict = { + "training": { + "model_task": "anemoi.training.train.forecaster.GraphForecaster", + "checkpoint_loading": {"source": "/path/to/checkpoint.ckpt", "loader_type": "weights_only"}, + }, + } + + # This might not fully validate due to missing required fields, + # but checkpoint_loading field itself should be parsed correctly + try: + config = BaseSchema(**DictConfig(config_dict)) + assert config.training.checkpoint_loading is not None + except ValidationError as e: + # If validation fails, it shouldn't be due to checkpoint_loading field + error_str = str(e) + assert "checkpoint_loading" not in error_str + + def test_checkpoint_loading_with_complex_config(self): + """Test checkpoint_loading works with complex nested configuration.""" + config_dict = { + "training": { + "checkpoint_loading": { + "source": "s3://my-bucket/experiments/run-123/checkpoint-best.ckpt", + "loader_type": "transfer_learning", + "strict": False, + "skip_mismatched": True, + "additional_options": {"nested": "value", "list": [1, 2, 3]}, + }, + }, + } + + config = BaseSchema(**DictConfig(config_dict)) + checkpoint_config = config.training.checkpoint_loading + + assert checkpoint_config["source"] == "s3://my-bucket/experiments/run-123/checkpoint-best.ckpt" + assert checkpoint_config["loader_type"] == "transfer_learning" + assert checkpoint_config["strict"] is False + assert checkpoint_config["skip_mismatched"] is True + assert "additional_options" in checkpoint_config + + def test_checkpoint_loading_field_description(self): + """Test that the field has proper description/documentation.""" + from anemoi.training.schemas.training import BaseTrainingSchema + + # Check that the field exists in the schema + fields = BaseTrainingSchema.model_fields + assert "checkpoint_loading" in fields + + # Check field configuration + field_info = fields["checkpoint_loading"] + assert field_info.default is None # Should default to None + + # Field should accept Union[dict, None] + # The actual validation depends on pydantic implementation details diff --git a/training/tests/train/test_checkpoint_loading_integration.py b/training/tests/train/test_checkpoint_loading_integration.py new file mode 100644 index 000000000..40da310d2 --- /dev/null +++ b/training/tests/train/test_checkpoint_loading_integration.py @@ -0,0 +1,262 @@ +# (C) Copyright 2025 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. 
+ +from __future__ import annotations + +from pathlib import Path +from unittest.mock import MagicMock +from unittest.mock import patch + +import torch +import torch.nn as nn +from omegaconf import DictConfig + +from anemoi.training.train.train import AnemoiTrainer + + +class SimpleMockModel(nn.Module): + """Simple model for testing.""" + + def __init__(self): + super().__init__() + self.linear = nn.Linear(10, 1) + + def forward(self, x): + return self.linear(x) + + +class TestCheckpointLoadingIntegration: + """Integration tests for checkpoint loading in the training pipeline.""" + + def create_test_checkpoint(self, tmp_path: Path) -> Path: + """Create a test checkpoint file.""" + model = SimpleMockModel() + checkpoint_path = tmp_path / "test_checkpoint.ckpt" + torch.save( + { + "state_dict": model.state_dict(), + "optimizer": {}, + "lr_scheduler": {}, + "epoch": 1, + }, + checkpoint_path, + ) + return checkpoint_path + + @patch("anemoi.training.utils.model_loading.load_model_from_checkpoint") + def test_load_checkpoint_if_configured_with_valid_config(self, mock_load_model, tmp_path): + """Test _load_checkpoint_if_configured with valid configuration.""" + # Create test checkpoint + checkpoint_path = self.create_test_checkpoint(tmp_path) + + # Mock trainer with minimal config + config = DictConfig( + { + "training": { + "checkpoint_loading": { + "source": str(checkpoint_path), + "loader_type": "weights_only", + "strict": True, + }, + }, + }, + ) + + trainer = AnemoiTrainer.__new__(AnemoiTrainer) + trainer.config = config + + # Create mock model + mock_model = SimpleMockModel() + mock_load_model.return_value = mock_model + + # Test the method + result = trainer._load_checkpoint_if_configured(mock_model) + + # Verify load_model_from_checkpoint was called + mock_load_model.assert_called_once_with( + model=mock_model, + checkpoint_source=str(checkpoint_path), + loader_type="weights_only", + strict=True, + ) + assert result == mock_model + + def test_load_checkpoint_if_configured_no_config(self): + """Test _load_checkpoint_if_configured with no checkpoint_loading config.""" + config = DictConfig({"training": {}}) + + trainer = AnemoiTrainer.__new__(AnemoiTrainer) + trainer.config = config + + mock_model = SimpleMockModel() + result = trainer._load_checkpoint_if_configured(mock_model) + + # Should return original model unchanged + assert result is mock_model + + def test_load_checkpoint_if_configured_no_source(self): + """Test _load_checkpoint_if_configured with config but no source.""" + config = DictConfig({"training": {"checkpoint_loading": {"loader_type": "weights_only"}}}) + + trainer = AnemoiTrainer.__new__(AnemoiTrainer) + trainer.config = config + + mock_model = SimpleMockModel() + + with patch("anemoi.training.train.train.LOGGER") as mock_logger: + result = trainer._load_checkpoint_if_configured(mock_model) + + # Should log warning and return original model + mock_logger.warning.assert_called_once() + assert result is mock_model + + @patch("anemoi.training.utils.model_loading.load_model_from_checkpoint") + def test_load_checkpoint_if_configured_with_optional_params(self, mock_load_model, tmp_path): + """Test _load_checkpoint_if_configured with optional parameters.""" + checkpoint_path = self.create_test_checkpoint(tmp_path) + + config = DictConfig( + { + "training": { + "checkpoint_loading": { + "source": str(checkpoint_path), + "loader_type": "transfer_learning", + "strict": False, + "skip_mismatched": True, + }, + }, + }, + ) + + trainer = AnemoiTrainer.__new__(AnemoiTrainer) + 
trainer.config = config + + mock_model = SimpleMockModel() + mock_load_model.return_value = mock_model + + trainer._load_checkpoint_if_configured(mock_model) + + # Verify all parameters were passed + mock_load_model.assert_called_once_with( + model=mock_model, + checkpoint_source=str(checkpoint_path), + loader_type="transfer_learning", + strict=False, + skip_mismatched=True, + ) + + def test_load_checkpoint_if_configured_missing_loader_type(self, tmp_path): + """Test behavior with missing loader_type.""" + checkpoint_path = self.create_test_checkpoint(tmp_path) + + config = DictConfig( + { + "training": { + "checkpoint_loading": { + "source": str(checkpoint_path), + # missing loader_type + }, + }, + }, + ) + + trainer = AnemoiTrainer.__new__(AnemoiTrainer) + trainer.config = config + + mock_model = SimpleMockModel() + + # Should handle missing loader_type gracefully + with patch("anemoi.training.utils.model_loading.load_model_from_checkpoint") as mock_load_model: + trainer._load_checkpoint_if_configured(mock_model) + + # Should still be called, with None loader_type + mock_load_model.assert_called_once() + call_kwargs = mock_load_model.call_args.kwargs + assert call_kwargs["checkpoint_source"] == str(checkpoint_path) + + @patch("anemoi.training.utils.model_loading.load_model_from_checkpoint") + def test_load_checkpoint_integration_in_model_property(self, mock_load_model, tmp_path): + """Test checkpoint loading integration in model property.""" + checkpoint_path = self.create_test_checkpoint(tmp_path) + + # Create minimal config with checkpoint loading + config = DictConfig( + {"training": {"checkpoint_loading": {"source": str(checkpoint_path), "loader_type": "weights_only"}}}, + ) + + # Mock the trainer's dependencies + with patch.multiple( + AnemoiTrainer, + config=config, + data_indices=MagicMock(), + graph_data=MagicMock(), + metadata=MagicMock(), + datamodule=MagicMock(), + supporting_arrays=MagicMock(), + load_weights_only=False, # Ensure legacy loading doesn't interfere + ): + trainer = AnemoiTrainer.__new__(AnemoiTrainer) + trainer.config = config + trainer.data_indices = MagicMock() + trainer.graph_data = MagicMock() + trainer.metadata = MagicMock() + trainer.datamodule = MagicMock() + trainer.supporting_arrays = MagicMock() + trainer.load_weights_only = False + + mock_model = SimpleMockModel() + mock_load_model.return_value = mock_model + + # Mock GraphForecaster construction + with patch("anemoi.training.train.train.GraphForecaster") as mock_forecaster: + mock_forecaster.return_value = mock_model + + # Call model property + result = trainer.model + + # Verify checkpoint loading was called + mock_load_model.assert_called_once() + assert result == mock_model + + def test_legacy_and_new_checkpoint_loading_priority(self, tmp_path): + """Test that new checkpoint loading takes precedence over legacy.""" + checkpoint_path = self.create_test_checkpoint(tmp_path) + + config = DictConfig( + {"training": {"checkpoint_loading": {"source": str(checkpoint_path), "loader_type": "weights_only"}}}, + ) + + with patch.multiple( + AnemoiTrainer, + config=config, + data_indices=MagicMock(), + graph_data=MagicMock(), + metadata=MagicMock(), + datamodule=MagicMock(), + supporting_arrays=MagicMock(), + load_weights_only=True, # Legacy loading would normally activate + last_checkpoint=str(checkpoint_path), + ): + trainer = AnemoiTrainer.__new__(AnemoiTrainer) + trainer.config = config + trainer.load_weights_only = True + trainer.last_checkpoint = str(checkpoint_path) + + with 
patch("anemoi.training.utils.model_loading.load_model_from_checkpoint") as mock_new_load: + with patch("anemoi.training.train.train.GraphForecaster") as mock_forecaster: + mock_model = SimpleMockModel() + mock_forecaster.return_value = mock_model + mock_new_load.return_value = mock_model + + trainer.model + + # New checkpoint loading should be called + mock_new_load.assert_called_once() + # Legacy loading should not happen due to the condition + mock_forecaster.load_from_checkpoint.assert_not_called() diff --git a/training/tests/utils/test_checkpoint_loaders.py b/training/tests/utils/test_checkpoint_loaders.py new file mode 100644 index 000000000..74c996b3e --- /dev/null +++ b/training/tests/utils/test_checkpoint_loaders.py @@ -0,0 +1,319 @@ +# (C) Copyright 2025 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +"""Comprehensive test suite for the Checkpoint Loading System. + +This module tests the extensible checkpoint loading infrastructure that enables +loading PyTorch model checkpoints from various sources including local filesystem, +cloud storage (S3, GCS, Azure), and HTTP/HTTPS endpoints. + +Test Coverage +============= + +1. **LocalCheckpointLoader**: Tests for local filesystem checkpoint loading +2. **RemoteCheckpointLoader**: Tests for remote URL and cloud storage loading +3. **CheckpointLoaderRegistry**: Tests for loader registration and selection +4. **Integration Tests**: End-to-end tests with multiple loaders +5. **Error Handling**: Tests for various failure scenarios + +Key Testing Scenarios +===================== + +- Source type detection (local vs remote) +- File existence validation +- Network error handling +- Cloud authentication simulation +- Registry pattern functionality +- Fallback mechanisms + +Test Organization +================= + +- TestLocalCheckpointLoader: Local filesystem operations +- TestRemoteCheckpointLoader: Remote source operations +- TestCheckpointLoaderRegistry: Registry management +- TestIntegration: Combined loader workflows + +Testing Principles +================== + +- Mock external dependencies (network, cloud APIs) +- Use temporary files for local testing +- Validate error messages for debugging +- Test both success and failure paths +- Ensure proper resource cleanup +""" + +import tempfile +from pathlib import Path +from unittest.mock import Mock +from unittest.mock import patch + +import pytest +import torch + +from anemoi.training.utils.checkpoint_loaders import CheckpointLoaderRegistry +from anemoi.training.utils.checkpoint_loaders import LocalCheckpointLoader +from anemoi.training.utils.checkpoint_loaders import RemoteCheckpointLoader +from anemoi.training.utils.checkpoint_loaders import load_checkpoint_from_source + + +class TestLocalCheckpointLoader: + """Tests for the LocalCheckpointLoader implementation. + + The LocalCheckpointLoader is responsible for loading checkpoints from + the local filesystem. 
These tests verify: + + - Correct source type detection (local paths vs URLs) + - File loading functionality + - Error handling for missing files + - Path object and string path compatibility + - Cross-platform path handling + """ + + def test_supports_local_path(self) -> None: + """Test that various local path formats are correctly identified. + + Validates that the loader recognizes: + - pathlib.Path objects + - Absolute string paths + - Relative string paths + - Simple filenames + + This ensures the loader can handle all common ways users might + specify local checkpoint paths. + """ + loader = LocalCheckpointLoader() + + # Test Path object + assert loader.supports_source(Path("/some/path")) + + # Test string path (should be supported even if doesn't exist) + assert loader.supports_source("/some/local/path") + + # Test non-URL string + assert loader.supports_source("model.ckpt") + + def test_does_not_support_urls(self) -> None: + """Test that remote URLs are correctly rejected by the local loader. + + Ensures proper separation of concerns - the local loader should + not attempt to handle remote sources, leaving them for the + RemoteCheckpointLoader to handle. + """ + loader = LocalCheckpointLoader() + + assert not loader.supports_source("http://example.com/model.ckpt") + assert not loader.supports_source("https://example.com/model.ckpt") + assert not loader.supports_source("s3://bucket/model.ckpt") + + def test_load_existing_checkpoint(self) -> None: + """Test successful loading of a valid checkpoint file. + + This test: + 1. Creates a temporary checkpoint with known data + 2. Loads it using the LocalCheckpointLoader + 3. Verifies the loaded data matches the original + 4. Ensures proper cleanup of temporary files + + This validates the core functionality of checkpoint loading + and ensures data integrity is maintained. + """ + loader = LocalCheckpointLoader() + + # Create a temporary checkpoint file + test_data = {"model_state": {"layer.weight": torch.randn(3, 3)}} + + with tempfile.NamedTemporaryFile(suffix=".ckpt", delete=False) as tmp_file: + torch.save(test_data, tmp_file.name) + tmp_path = Path(tmp_file.name) + + try: + loaded_data = loader.load_checkpoint(tmp_path) + assert "model_state" in loaded_data + assert "layer.weight" in loaded_data["model_state"] + finally: + tmp_path.unlink() + + def test_load_nonexistent_checkpoint(self) -> None: + """Test proper error handling when checkpoint file doesn't exist. + + Validates that: + - FileNotFoundError is raised with appropriate message + - No system crashes or undefined behavior occurs + - Error message helps users identify the problem + + This is important for user experience and debugging. 
+ """ + loader = LocalCheckpointLoader() + + with pytest.raises(FileNotFoundError): + loader.load_checkpoint("/nonexistent/path/model.ckpt") + + +class TestRemoteCheckpointLoader: + """Test RemoteCheckpointLoader functionality.""" + + def test_supports_remote_urls(self) -> None: + """Test that remote URLs are supported.""" + loader = RemoteCheckpointLoader() + + assert loader.supports_source("http://example.com/model.ckpt") + assert loader.supports_source("https://example.com/model.ckpt") + assert loader.supports_source("s3://bucket/model.ckpt") + assert loader.supports_source("gs://bucket/model.ckpt") + assert loader.supports_source("azure://account.blob.core.windows.net/container/model.ckpt") + + def test_does_not_support_local_paths(self) -> None: + """Test that local paths are not supported by remote loader.""" + loader = RemoteCheckpointLoader() + + assert not loader.supports_source(Path("/local/path")) + assert not loader.supports_source("/local/path") + assert not loader.supports_source("model.ckpt") + + def test_download_http(self) -> None: + """Test HTTP download functionality.""" + loader = RemoteCheckpointLoader() + + with tempfile.NamedTemporaryFile(suffix=".ckpt", delete=False) as tmp_file: + tmp_path = Path(tmp_file.name) + + try: + with patch("urllib.request.urlretrieve") as mock_retrieve: + loader._download_http("http://example.com/model.ckpt", tmp_path) + mock_retrieve.assert_called_once_with("http://example.com/model.ckpt", tmp_path) + finally: + if tmp_path.exists(): + tmp_path.unlink() + + def test_download_s3(self) -> None: + """Test S3 download functionality.""" + loader = RemoteCheckpointLoader() + + with tempfile.NamedTemporaryFile(suffix=".ckpt", delete=False) as tmp_file: + tmp_path = Path(tmp_file.name) + + try: + with patch("boto3.client") as mock_boto3: + mock_client = Mock() + mock_boto3.return_value = mock_client + + loader._download_s3("s3://my-bucket/path/to/model.ckpt", tmp_path) + + mock_boto3.assert_called_once_with("s3") + mock_client.download_file.assert_called_once_with("my-bucket", "path/to/model.ckpt", str(tmp_path)) + finally: + if tmp_path.exists(): + tmp_path.unlink() + + def test_download_s3_missing_dependency(self) -> None: + """Test S3 download with missing boto3 dependency.""" + loader = RemoteCheckpointLoader() + + with tempfile.NamedTemporaryFile(suffix=".ckpt", delete=False) as tmp_file: + tmp_path = Path(tmp_file.name) + + try: + with ( + patch("boto3.client", side_effect=ImportError("No module named 'boto3'")), + pytest.raises(ImportError, match="boto3 required for S3 downloads"), + ): + loader._download_s3("s3://bucket/model.ckpt", tmp_path) + finally: + if tmp_path.exists(): + tmp_path.unlink() + + def test_unsupported_scheme(self) -> None: + """Test error handling for unsupported URL schemes.""" + loader = RemoteCheckpointLoader() + + with tempfile.NamedTemporaryFile(suffix=".ckpt", delete=False) as tmp_file: + tmp_path = Path(tmp_file.name) + + try: + with pytest.raises(ValueError, match="Unsupported remote scheme"): + loader._download_checkpoint("ftp://example.com/model.ckpt", tmp_path) + finally: + if tmp_path.exists(): + tmp_path.unlink() + + +class TestCheckpointLoaderRegistry: + """Test CheckpointLoaderRegistry functionality.""" + + def test_registry_has_default_loaders(self) -> None: + """Test that registry comes with default loaders.""" + registry = CheckpointLoaderRegistry() + + # Should have at least the default loaders + assert len(registry._loaders) >= 2 + + # Test that it can find appropriate loaders + local_loader = 
registry.get_loader("/local/path") + assert isinstance(local_loader, LocalCheckpointLoader) + + remote_loader = registry.get_loader("http://example.com/model.ckpt") + assert isinstance(remote_loader, RemoteCheckpointLoader) + + def test_register_custom_loader(self) -> None: + """Test registering a custom loader.""" + registry = CheckpointLoaderRegistry() + custom_loader = Mock() + custom_loader.supports_source.return_value = False + + registry.register(custom_loader) + assert custom_loader in registry._loaders + + def test_no_loader_found(self) -> None: + """Test error handling when no loader supports the source.""" + registry = CheckpointLoaderRegistry() + + # Mock all loaders to return False for supports_source + for loader in registry._loaders: + loader.supports_source = Mock(return_value=False) + + with pytest.raises(ValueError, match="No loader found for source"): + registry.get_loader("unsupported://source") + + def test_load_checkpoint_delegates_to_loader(self) -> None: + """Test that load_checkpoint delegates to the appropriate loader.""" + registry = CheckpointLoaderRegistry() + + # Create a mock loader that supports our test source + mock_loader = Mock() + mock_loader.supports_source.return_value = True + mock_loader.load_checkpoint.return_value = {"test": "data"} + + # Replace the loaders with our mock + registry._loaders = [mock_loader] + + result = registry.load_checkpoint("test://source") + + mock_loader.supports_source.assert_called_once_with("test://source") + mock_loader.load_checkpoint.assert_called_once_with("test://source") + assert result == {"test": "data"} + + +class TestLoadCheckpointFromSource: + """Test the convenience function.""" + + def test_load_checkpoint_from_source_function(self) -> None: + """Test the main convenience function.""" + test_data = {"model": "data"} + + with tempfile.NamedTemporaryFile(suffix=".ckpt", delete=False) as tmp_file: + torch.save(test_data, tmp_file.name) + tmp_path = Path(tmp_file.name) + + try: + loaded_data = load_checkpoint_from_source(tmp_path) + assert loaded_data == test_data + finally: + tmp_path.unlink() diff --git a/training/tests/utils/test_model_loading.py b/training/tests/utils/test_model_loading.py new file mode 100644 index 000000000..fe806fced --- /dev/null +++ b/training/tests/utils/test_model_loading.py @@ -0,0 +1,247 @@ +# (C) Copyright 2025 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +"""Comprehensive test suite for the Model Loading System. + +This module tests the model loading strategies that handle different scenarios +for loading PyTorch model weights from checkpoints, including standard loading, +transfer learning with architecture mismatches, and weights-only loading. + +Test Coverage +============= + +1. **StandardModelLoader**: Tests for exact architecture match loading +2. **TransferLearningModelLoader**: Tests for flexible loading with mismatches +3. **WeightsOnlyModelLoader**: Tests for parameter-only loading +4. **ModelLoaderRegistry**: Tests for strategy registration and selection +5. 
**Integration**: End-to-end tests combining loaders with checkpoint sources
+
+Key Testing Scenarios
+=====================
+
+- Exact architecture matches
+- Parameter size mismatches
+- Missing layers in checkpoint
+- Extra layers in checkpoint
+- Metadata preservation
+- Hyperparameter handling
+- Error conditions and recovery
+
+Test Organization
+=================
+
+- DummyModel: Simple test model with known architecture
+- TestStandardModelLoader: Standard loading strategy tests
+- TestTransferLearningModelLoader: Transfer learning tests
+- TestModelLoaderRegistry: Registry pattern tests
+- TestIntegration: Combined workflow tests
+
+Testing Principles
+==================
+
+- Use dummy models with known architectures
+- Mock checkpoint loading to isolate logic
+- Test both successful and error cases
+- Validate metadata preservation
+- Ensure proper state dict handling
+"""
+
+import tempfile
+from pathlib import Path
+from unittest.mock import Mock
+from unittest.mock import patch
+
+import pytest
+import torch
+import torch.nn as nn
+
+from anemoi.training.utils.model_loading import ModelLoaderRegistry
+from anemoi.training.utils.model_loading import StandardModelLoader
+from anemoi.training.utils.model_loading import TransferLearningModelLoader
+from anemoi.training.utils.model_loading import load_model_from_checkpoint
+
+
+class DummyModel(nn.Module):
+    """Simple model for testing."""
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.layer1 = nn.Linear(10, 5)
+        self.layer2 = nn.Linear(5, 1)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.layer2(self.layer1(x))
+
+
+class TestStandardModelLoader:
+    """Test StandardModelLoader functionality."""
+
+    def test_load_standard_checkpoint(self) -> None:
+        """Test loading a standard Lightning checkpoint."""
+        loader = StandardModelLoader()
+        model = DummyModel()
+
+        # Create checkpoint with state_dict
+        state_dict = model.state_dict()
+        checkpoint = {
+            "state_dict": state_dict,
+            "hyper_parameters": {"data_indices": Mock(name_to_index={"var1": 0, "var2": 1})},
+        }
+
+        with tempfile.NamedTemporaryFile(suffix=".pt", delete=False) as tmp_file:
+            torch.save(checkpoint, tmp_file.name)
+            tmp_path = Path(tmp_file.name)
+
+        try:
+            with patch("anemoi.training.utils.model_loading.load_checkpoint_from_source") as mock_load:
+                mock_load.return_value = checkpoint
+                loaded_model = loader.load_model_weights(model, tmp_path)
+
+                assert hasattr(loaded_model, "_ckpt_model_name_to_index")
+                assert loaded_model._ckpt_model_name_to_index == {"var1": 0, "var2": 1}
+        finally:
+            tmp_path.unlink()
+
+    def test_load_checkpoint_missing_state_dict(self) -> None:
+        """Test error handling when checkpoint has no state_dict."""
+        loader = StandardModelLoader()
+        model = DummyModel()
+
+        with patch("anemoi.training.utils.model_loading.load_checkpoint_from_source") as mock_load:
+            mock_load.return_value = {"some_other_key": "value"}
+
+            with pytest.raises(ValueError, match="No 'state_dict' found"):
+                loader.load_model_weights(model, "/fake/path.pt")
+
+
+class TestTransferLearningModelLoader:
+    """Test TransferLearningModelLoader functionality."""
+
+    def test_load_with_size_mismatch(self) -> None:
+        """Test loading with size mismatches (should skip mismatched layers)."""
+        loader = TransferLearningModelLoader()
+        model = DummyModel()
+
+        # Create checkpoint with mismatched layer size
+        mismatched_state_dict = {
+            "layer1.weight": torch.randn(5, 10),  # Correct size (out_features, in_features)
+            "layer1.bias": torch.randn(5),  # Correct size
+            "layer2.weight": torch.randn(1, 20),  # Wrong size (should be 1, 5)
+            "layer2.bias": torch.randn(1),  # Correct size
+        }
+
+        checkpoint = {
+            "state_dict": mismatched_state_dict,
+            "hyper_parameters": {"data_indices": Mock(name_to_index={"var1": 0, "var2": 1})},
+        }
+
+        with patch("anemoi.training.utils.model_loading.load_checkpoint_from_source") as mock_load:
+            mock_load.return_value = checkpoint
+            loaded_model = loader.load_model_weights(model, "/fake/path.pt", skip_mismatched=True)
+
+            # Check that model has weights_initialized flag
+            assert hasattr(loaded_model, "weights_initialized")
+            assert loaded_model.weights_initialized is True
+
+    def test_load_without_skip_mismatched(self) -> None:
+        """Test loading without skipping mismatched layers (should fail)."""
+        loader = TransferLearningModelLoader()
+        model = DummyModel()
+
+        # Create checkpoint with mismatched layer size
+        mismatched_state_dict = {
+            "layer1.weight": torch.randn(5, 10),  # Correct size (out_features, in_features)
+            "layer2.weight": torch.randn(1, 20),  # Wrong size
+        }
+
+        checkpoint = {"state_dict": mismatched_state_dict}
+
+        with patch("anemoi.training.utils.model_loading.load_checkpoint_from_source") as mock_load:
+            mock_load.return_value = checkpoint
+
+            with pytest.raises(RuntimeError):  # PyTorch raises RuntimeError for size mismatch
+                loader.load_model_weights(model, "/fake/path.pt", skip_mismatched=False, strict=True)
+
+
+class TestModelLoaderRegistry:
+    """Test ModelLoaderRegistry functionality."""
+
+    def test_registry_has_default_loaders(self) -> None:
+        """Test that registry comes with default loaders."""
+        registry = ModelLoaderRegistry()
+
+        assert "standard" in registry._loaders
+        assert "transfer_learning" in registry._loaders
+        assert "weights_only" in registry._loaders
+
+    def test_get_loader_by_name(self) -> None:
+        """Test retrieving loaders by name."""
+        registry = ModelLoaderRegistry()
+
+        standard_loader = registry.get_loader("standard")
+        assert isinstance(standard_loader, StandardModelLoader)
+
+        transfer_loader = registry.get_loader("transfer_learning")
+        assert isinstance(transfer_loader, TransferLearningModelLoader)
+
+    def test_unknown_loader_error(self) -> None:
+        """Test error handling for unknown loader names."""
+        registry = ModelLoaderRegistry()
+
+        with pytest.raises(ValueError, match="Unknown loader"):
+            registry.get_loader("nonexistent_loader")
+
+    def test_register_custom_loader(self) -> None:
+        """Test registering a custom loader."""
+        registry = ModelLoaderRegistry()
+        custom_loader = Mock()
+
+        registry.register("custom", custom_loader)
+        assert registry.get_loader("custom") is custom_loader
+
+
+class TestLoadModelFromCheckpoint:
+    """Test the convenience function."""
+
+    def test_load_model_from_checkpoint_function(self) -> None:
+        """Test the main convenience function."""
+        model = DummyModel()
+
+        # Create a simple checkpoint
+        state_dict = model.state_dict()
+        checkpoint = {"state_dict": state_dict}
+
+        with patch("anemoi.training.utils.model_loading.load_checkpoint_from_source") as mock_load:
+            mock_load.return_value = checkpoint
+
+            loaded_model = load_model_from_checkpoint(
+                model=model,
+                checkpoint_source="/fake/path.pt",
+                loader_type="standard",
+            )
+
+            assert loaded_model is model  # Should return the same model instance
+
+    def test_load_with_transfer_learning(self) -> None:
+        """Test loading with transfer learning type."""
+        model = DummyModel()
+
+        checkpoint = {"state_dict": model.state_dict(), "hyper_parameters": {"data_indices": Mock(name_to_index={})}}
+
+        with 
patch("anemoi.training.utils.model_loading.load_checkpoint_from_source") as mock_load: + mock_load.return_value = checkpoint + + loaded_model = load_model_from_checkpoint( + model=model, + checkpoint_source="/fake/path.pt", + loader_type="transfer_learning", + ) + + assert hasattr(loaded_model, "weights_initialized") + assert loaded_model.weights_initialized is True