Empty file added .gitkeep
64 changes: 64 additions & 0 deletions training/docs/user-guide/configuring.rst
@@ -420,3 +420,67 @@ error reported is not very intuitive and indeed hides the real issue. We
will work on improving this in future releases, but in the meantime we
recommend double-checking the schemas and the config files to make sure
they are correctly defined.

********************
Checkpoint Loading
********************

Anemoi Training can initialize model weights from saved checkpoints
through a flexible checkpoint-loading system that supports multiple
sources and loading strategies.

Configuration
=============

To configure checkpoint loading, add a ``checkpoint_loading`` section to
your training config:

.. code:: yaml

   training:
     checkpoint_loading:
       source: "/path/to/checkpoint.ckpt"
       loader_type: "weights_only"
       strict: true

Available Options
=================

- **source**: Path or URL to checkpoint file (local, S3, HTTP, GCS,
Azure)
- **loader_type**: Strategy for loading ("weights_only",
"transfer_learning", "standard")
- **strict**: Whether to require exact parameter matching (optional)
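
These options live under ``training.checkpoint_loading`` in the
(OmegaConf-based) training config and can also be inspected
programmatically once the config is loaded. The snippet below is a
minimal, illustrative sketch; the values are placeholders:

.. code:: python

   from omegaconf import OmegaConf

   # Illustrative only: placeholder values for the checkpoint_loading section.
   config = OmegaConf.create(
       {
           "training": {
               "checkpoint_loading": {
                   "source": "/path/to/checkpoint.ckpt",
                   "loader_type": "weights_only",
                   "strict": True,
               }
           }
       }
   )

   ckpt_cfg = config.training.checkpoint_loading
   print(ckpt_cfg.source, ckpt_cfg.loader_type, ckpt_cfg.strict)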

Loader Types
============

**weights_only**
   Load only model weights, ignoring optimizer and scheduler states.

**transfer_learning**
   Load weights with size-mismatch handling for cross-model transfer.

**standard**
   Full Lightning checkpoint loading (weights + optimizer + scheduler).
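
In plain PyTorch/Lightning terms, the three strategies roughly
correspond to the following. This is an illustrative sketch with a
placeholder model and checkpoint path, not the Anemoi implementation:

.. code:: python

   import torch

   # Placeholder model and checkpoint path, for illustration only.
   model = torch.nn.Linear(8, 4)
   checkpoint = torch.load("checkpoint.ckpt", map_location="cpu")
   state_dict = checkpoint.get("state_dict", checkpoint)

   # weights_only: restore the model weights, ignore optimizer/scheduler state.
   model.load_state_dict(state_dict, strict=True)

   # transfer_learning: keep only parameters whose names and shapes match.
   target = model.state_dict()
   compatible = {
       name: tensor
       for name, tensor in state_dict.items()
       if name in target and tensor.shape == target[name].shape
   }
   model.load_state_dict(compatible, strict=False)

   # standard: let Lightning restore weights, optimizer and scheduler, e.g.
   # trainer.fit(model, ckpt_path="checkpoint.ckpt")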

Remote Sources
==============

The system supports loading from multiple remote sources:

.. code:: yaml

   training:
     checkpoint_loading:
       # S3 checkpoint
       source: "s3://my-bucket/models/checkpoint.ckpt"

       # HTTP checkpoint
       # source: "https://example.com/models/checkpoint.ckpt"

       # Google Cloud Storage
       # source: "gs://my-bucket/models/checkpoint.ckpt"

       # Azure Blob Storage
       # source: "azure://account.blob.core.windows.net/container/checkpoint.ckpt"
103 changes: 52 additions & 51 deletions training/docs/user-guide/training.rst
@@ -281,64 +281,65 @@ specific point they can do this by setting
``config.hardware.files.warm_start`` to be the checkpoint they want to
restart from.

*******************
Transfer Learning
*******************
********************
Checkpoint Loading
********************

Transfer learning allows the model to reuse knowledge from a previously
trained checkpoint. This is particularly useful when the new task is
related to the old one, enabling faster convergence and often improving
model performance.
Anemoi Training supports multi-source checkpoint loading with automatic format detection.

To enable transfer learning, set the config.training.transfer_learning
flag to True in the configuration file.
Quick Start
===========

.. code:: yaml

   training:
     # start the training from a checkpoint of a previous run
     fork_run_id: ...
     load_weights_only: True
     transfer_learning: True

When this flag is active and a checkpoint path is specified in
config.hardware.files.warm_start or self.last_checkpoint, the system
loads the pre-trained weights using the `transfer_learning_loading`
function. This approach ensures only compatible weights are loaded and
mismatched layers are handled appropriately.

For example, transfer learning might be used to adapt a weather
forecasting model trained on one geographic region to another region
with similar characteristics.

****************
Model Freezing
****************

Model freezing is a technique where specific parts (submodules) of a
model are excluded from training. This is useful when certain parts of
the model have been sufficiently trained or should remain unchanged for
the current task.

To specify which submodules to freeze, use the
config.training.submodules_to_freeze field in the configuration. List
the names of submodules to be frozen. During model initialization, these
submodules will have their parameters frozen, ensuring they are not
updated during training.

For example with the following configuration, the processor will be
frozen and only the encoder and decoder will be trained:
   checkpoint_loading:
     source: "s3://bucket/model.ckpt"      # Any source
     loader_type: "transfer_learning"      # Strategy
     strict: false                         # Allow mismatches

Sources
=======

+----------+--------------------------------------+-------------------+
| Source   | URL Format                           | Dependency        |
+==========+======================================+===================+
| Local    | ``/path/to/file.ckpt``               | None              |
+----------+--------------------------------------+-------------------+
| S3       | ``s3://bucket/file.ckpt``            | ``boto3``         |
+----------+--------------------------------------+-------------------+
| HTTP     | ``https://url/file.ckpt``            | None              |
+----------+--------------------------------------+-------------------+
| GCS      | ``gs://bucket/file.ckpt``            | ``google-cloud``  |
+----------+--------------------------------------+-------------------+
| Azure    | ``azure://account.../file.ckpt``     | ``azure-storage`` |
+----------+--------------------------------------+-------------------+
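
The per-source dependency is only needed for the scheme you actually
use. As an illustration (not the actual Anemoi check), the URL scheme
can be mapped to the optional package from the table above:

.. code:: python

   from typing import Optional
   from urllib.parse import urlparse

   # Illustrative helper: local paths and HTTP(S) URLs need no extra package.
   DEPENDENCIES = {
       "s3": "boto3",
       "gs": "google-cloud",
       "azure": "azure-storage",
   }

   def required_dependency(source: str) -> Optional[str]:
       return DEPENDENCIES.get(urlparse(source).scheme)

   assert required_dependency("s3://bucket/file.ckpt") == "boto3"
   assert required_dependency("/path/to/file.ckpt") is None
   assert required_dependency("https://url/file.ckpt") is None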

Loader Types
============

- **weights_only**: Model weights only
- **transfer_learning**: Handle size mismatches
- **standard**: Full checkpoint (weights + optimizer)

.. code:: yaml

   training:
     # start the training from a checkpoint of a previous run
     fork_run_id: ...
     load_weights_only: True

     submodules_to_freeze:
       - processor

Freezing can be particularly beneficial in scenarios such as fine-tuning
when only specific components (e.g., the encoder, the decoder) need to
adapt to a new task while keeping others (e.g., the processor) fixed.

*******************
Transfer Learning
*******************

Transfer learning allows automatic handling of size mismatches between models:

.. code:: yaml

   training:
     checkpoint_loading:
       source: "/path/to/pretrained.ckpt"
       loader_type: "transfer_learning"
       strict: false            # allow parameter mismatches
       skip_mismatched: true    # skip layers with shape mismatches

The transfer learning system automatically (see the sketch after this list):

- Identifies parameter shape mismatches
- Logs which parameters are skipped due to mismatches
- Loads only compatible weights
- Preserves data indices for validation
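
A minimal sketch of what this mismatch handling amounts to
(illustrative only, not the actual ``transfer_learning`` loader):

.. code:: python

   import logging

   import torch

   LOGGER = logging.getLogger(__name__)

   def filter_compatible_weights(
       model: torch.nn.Module,
       state_dict: dict[str, torch.Tensor],
   ) -> dict[str, torch.Tensor]:
       """Keep only parameters whose names and shapes match the target model,
       logging everything that is skipped."""
       target = model.state_dict()
       compatible = {}
       for name, tensor in state_dict.items():
           if name not in target:
               LOGGER.info("Skipping %s: not present in target model", name)
           elif tensor.shape != target[name].shape:
               LOGGER.info(
                   "Skipping %s: shape mismatch %s vs %s",
                   name,
                   tuple(tensor.shape),
                   tuple(target[name].shape),
               )
           else:
               compatible[name] = tensor
       return compatible
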
@@ -0,0 +1,3 @@
source: null # Path or URL to checkpoint file
loader_type: "standard" # Full Lightning checkpoint loading
strict: true # Require exact parameter matching
@@ -0,0 +1,4 @@
source: null # Path or URL to checkpoint file
loader_type: "transfer_learning" # Handle size mismatches
strict: false # Allow parameter mismatches
skip_mismatched: true # Skip parameters with shape mismatches
@@ -0,0 +1,3 @@
source: null # Path or URL to checkpoint file
loader_type: "weights_only" # Load only model weights
strict: true # Require exact parameter matching
2 changes: 2 additions & 0 deletions training/src/anemoi/training/schemas/training.py
@@ -263,5 +263,7 @@ class TrainingSchema(BaseModel):
    "Configuration of the pressure level scaler applied in the loss computation."
    metrics: list[str]
    "List of metrics"
    checkpoint_loading: Union[dict, None] = Field(default=None)
    "Checkpoint loading configuration for initializing model weights from saved checkpoints."
    node_loss_weights: NodeLossWeightsSchema
    "Node loss weights configuration."
42 changes: 40 additions & 2 deletions training/src/anemoi/training/train/train.py
@@ -167,8 +167,13 @@ def model(self) -> GraphForecaster:

        model = GraphForecaster(**kwargs)

        # Load the model weights
        if self.load_weights_only:
        # Load checkpoint weights if configured
        model = self._load_checkpoint_if_configured(model)

        # Legacy checkpoint loading (for compatibility)
        if self.load_weights_only and not (
            hasattr(self.config.training, "checkpoint_loading") and self.config.training.checkpoint_loading
        ):
            if hasattr(self.config.training, "transfer_learning"):
                # Sanify the checkpoint for transfer learning
                if self.config.training.transfer_learning:
@@ -191,6 +196,39 @@ def model(self) -> GraphForecaster:

        return model

    def _load_checkpoint_if_configured(self, model: torch.nn.Module) -> torch.nn.Module:
        """Load checkpoint weights if checkpoint_loading is configured."""
        if not hasattr(self.config.training, "checkpoint_loading") or not self.config.training.checkpoint_loading:
            return model

        checkpoint_config = self.config.training.checkpoint_loading

        if not checkpoint_config.source:
            LOGGER.warning("checkpoint_loading configured but no source specified")
            return model

        from anemoi.training.utils.model_loading import load_model_from_checkpoint

        LOGGER.info(
            "Loading checkpoint from %s using %s loader",
            checkpoint_config.source,
            checkpoint_config.loader_type,
        )

        # Extract parameters from checkpoint config
        loader_kwargs = {}
        if hasattr(checkpoint_config, "strict"):
            loader_kwargs["strict"] = checkpoint_config.strict
        if hasattr(checkpoint_config, "skip_mismatched"):
            loader_kwargs["skip_mismatched"] = checkpoint_config.skip_mismatched

        return load_model_from_checkpoint(
            model=model,
            checkpoint_source=checkpoint_config.source,
            loader_type=checkpoint_config.loader_type,
            **loader_kwargs,
        )

    @rank_zero_only
    def _get_mlflow_run_id(self) -> str:
        run_id = self.mlflow_logger.run_id