@JesperDramsch (Member) commented Aug 6, 2025

Description

This PR introduces a checkpoint loading system that can load model checkpoints from local files, S3, HTTP/HTTPS, Google Cloud Storage, and Azure Blob Storage. The implementation is modular and extensible: it separates checkpoint retrieval from model loading strategies, and it adds error handling and validation.

What problem does this change solve?

This change addresses several key problems in the current checkpoint loading workflow:

New Feature: Adds multi-source checkpoint loading capabilities with support for:

  • Local filesystem: Standard file path loading
  • HTTP/HTTPS: Download from web URLs
  • Amazon S3: S3 bucket support with boto3
  • Google Cloud Storage: GCS support with google-cloud-storage
  • Azure Blob Storage: Azure support with azure-storage-blob
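
A minimal sketch of how a source string could be mapped to one of the backends listed above. The function name `detect_source`, the labels, and the `gs://` and `az://` schemes are assumptions for illustration; the PR's actual detection logic may differ.

```python
from urllib.parse import urlparse

# Hypothetical mapping from URL scheme to a source label. The labels and the
# gs:// and az:// schemes are assumptions, not taken from the PR.
SCHEME_TO_SOURCE = {
    "": "local",
    "file": "local",
    "http": "http",
    "https": "http",
    "s3": "s3",
    "gs": "gcs",
    "az": "azure",
}


def detect_source(source: str) -> str:
    """Return a source label for a checkpoint path or URL."""
    scheme = urlparse(source).scheme.lower()
    # A single-letter scheme is most likely a Windows drive letter (C:\...).
    if len(scheme) == 1:
        return "local"
    try:
        return SCHEME_TO_SOURCE[scheme]
    except KeyError:
        raise ValueError(f"Unsupported checkpoint source scheme: {scheme!r}") from None
```

Rejecting unknown schemes early keeps the failure close to the user's input rather than deep inside a download call.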

Enhanced Model Loading Strategies:

  • StandardModelLoader: Standard PyTorch Lightning checkpoint loading
  • TransferLearningModelLoader: Handle size mismatches and partial weight loading
  • WeightsOnlyModelLoader: Load only model weights, skip optimiser states
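
To illustrate what "handle size mismatches and partial weight loading" could mean in practice, here is a hedged sketch that splits a checkpoint's state dict into entries whose shapes match the target model and entries to skip. The function name `filter_compatible_weights` is hypothetical; it only assumes values expose a `.shape` attribute, as PyTorch tensors do.

```python
from __future__ import annotations

from typing import Any, Mapping


def filter_compatible_weights(
    checkpoint_state: Mapping[str, Any],
    model_state: Mapping[str, Any],
) -> tuple[dict[str, Any], list[str]]:
    """Split checkpoint entries into loadable weights and skipped keys.

    Works on any values exposing ``.shape`` (for example torch tensors).
    """
    compatible: dict[str, Any] = {}
    skipped: list[str] = []
    for name, value in checkpoint_state.items():
        target = model_state.get(name)
        # Keep a weight only if the model has the same key with the same shape.
        if target is not None and tuple(target.shape) == tuple(value.shape):
            compatible[name] = value
        else:
            skipped.append(name)
    return compatible, skipped
```

With PyTorch, the `compatible` dict could then be passed to `model.load_state_dict(compatible, strict=False)`, and `skipped` logged so that dropped layers are visible to the user.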

Architecture Improvements:

  • Separation of concerns: Checkpoint retrieval vs model loading logic
  • Extensibility: Easy to add new sources or loading strategies via registry pattern
  • Testability: Mock-friendly design with clear interfaces
  • Error handling: Informative error messages and graceful failures
  • Performance: Efficient temporary file handling and cleanup
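
The temporary file handling and cleanup mentioned above could be sketched with a context manager like the one below. `temporary_checkpoint` and its `fetch` callable are hypothetical names for illustration, not the PR's API.

```python
import os
import tempfile
from contextlib import contextmanager
from pathlib import Path
from typing import BinaryIO, Callable, Iterator


@contextmanager
def temporary_checkpoint(
    fetch: Callable[[BinaryIO], None], suffix: str = ".ckpt"
) -> Iterator[Path]:
    """Materialise a remote checkpoint in a temp file and always clean it up.

    ``fetch`` is a hypothetical callable that writes the checkpoint bytes
    into the open file handle (e.g. a wrapper around an S3 or HTTP download).
    """
    fd, name = tempfile.mkstemp(suffix=suffix)
    path = Path(name)
    try:
        with os.fdopen(fd, "wb") as handle:
            fetch(handle)
        yield path
    finally:
        # Runs on success and on failure, so no partial download is left behind.
        path.unlink(missing_ok=True)
```

The file is closed before the path is yielded, so the caller can reopen it with whatever deserialiser it needs.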

What issue or task does this change relate to?

Closes #458

Should be updated and merged after #410 / #422

Additional notes

Implementation Details

The PR introduces two main components:

CheckpointLoaders (checkpoint_loaders.py):

  • Abstract base class CheckpointLoader with pluggable implementations
  • LocalCheckpointLoader for filesystem access
  • RemoteCheckpointLoader with multi-cloud support
  • Automatic source detection and loader selection
  • Comprehensive error handling with clear messages
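
A rough sketch of the pluggable design described above. The class names `CheckpointLoader` and `LocalCheckpointLoader` come from the PR description, but the method names (`supports`, `load`), the `select_loader` helper, and the choice to return raw bytes are assumptions made to keep the example dependency-free.

```python
from __future__ import annotations

from abc import ABC, abstractmethod
from pathlib import Path
from typing import Sequence


class CheckpointLoader(ABC):
    """Interface for fetching checkpoint content from one kind of source."""

    @abstractmethod
    def supports(self, source: str) -> bool:
        """Return True if this loader can handle ``source``."""

    @abstractmethod
    def load(self, source: str) -> bytes:
        """Return the raw checkpoint content."""


class LocalCheckpointLoader(CheckpointLoader):
    def supports(self, source: str) -> bool:
        # Anything without a URL scheme separator is treated as a local path.
        return "://" not in source

    def load(self, source: str) -> bytes:
        path = Path(source)
        if not path.is_file():
            raise FileNotFoundError(f"Checkpoint not found: {path}")
        return path.read_bytes()


def select_loader(source: str, loaders: Sequence[CheckpointLoader]) -> CheckpointLoader:
    """Pick the first loader that claims support for ``source``."""
    for loader in loaders:
        if loader.supports(source):
            return loader
    raise ValueError(f"No checkpoint loader available for source: {source!r}")
```

A real loader would presumably deserialise the checkpoint (for example with `torch.load`) rather than return bytes.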

ModelLoading (model_loading.py):

  • Abstract base class ModelLoader for different loading strategies
  • Standard Lightning checkpoint compatibility
  • Transfer learning with size mismatch handling
  • Metadata preservation and data_indices validation
  • Registry pattern for extensibility
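
The registry pattern mentioned above might look roughly like the following. This `Registry` class is an illustrative assumption: the PR only shows `registry.register(name, instance)` being called, so the `get` and `names` methods and the duplicate check are guesses at a reasonable shape.

```python
from typing import Dict, Generic, List, TypeVar

T = TypeVar("T")


class Registry(Generic[T]):
    """Name-to-object lookup used to plug in loaders or strategies."""

    def __init__(self) -> None:
        self._items: Dict[str, T] = {}

    def register(self, name: str, item: T) -> None:
        # Refuse silent overwrites so two plugins cannot shadow each other.
        if name in self._items:
            raise ValueError(f"'{name}' is already registered")
        self._items[name] = item

    def get(self, name: str) -> T:
        try:
            return self._items[name]
        except KeyError:
            available = ", ".join(sorted(self._items)) or "<none>"
            raise KeyError(f"Unknown entry '{name}'. Available: {available}") from None

    def names(self) -> List[str]:
        return sorted(self._items)
```

Listing the available names in the error message is one way to provide the informative failures the PR aims for.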

Usage Examples

# Load from any supported source
checkpoint = load_checkpoint_from_source("s3://bucket/model.ckpt")
checkpoint = load_checkpoint_from_source("https://example.com/model.ckpt")
checkpoint = load_checkpoint_from_source("/local/path/model.ckpt")

# Load model with transfer learning
model = load_model_from_checkpoint(
    model=my_model,
    checkpoint_source="s3://bucket/pretrained.ckpt",
    loader_type="transfer_learning",
    skip_mismatched=True
)

# Register custom loaders
registry.register("custom_source", CustomCheckpointLoader())
registry.register("custom_strategy", CustomModelLoader())

Testing and Quality Assurance

  • 100% line coverage across all loaders and strategies
  • Integration tests for multi-source scenarios
  • Error handling tests for edge cases and failures
  • Mock-based testing to avoid external dependencies

# Run all checkpoint loading tests
pytest training/tests/utils/test_checkpoint_loaders.py -v
pytest training/tests/utils/test_model_loading.py -v

# Test with different sources (requires cloud credentials)
pytest training/tests/utils/ -k "remote" --cloud-tests
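
As an illustration of the mock-based approach listed above, a test can replace the network call so that no external service is needed. `download_bytes` is a hypothetical stand-in for an HTTP checkpoint download and is not a function from this PR.

```python
import io
import urllib.request
from unittest import mock


def download_bytes(url: str) -> bytes:
    """Hypothetical helper standing in for an HTTP checkpoint download."""
    with urllib.request.urlopen(url) as response:
        return response.read()


def test_download_bytes_without_network() -> None:
    # BytesIO supports the context-manager protocol, so it can play the response.
    fake_response = io.BytesIO(b"fake-checkpoint")
    with mock.patch("urllib.request.urlopen", return_value=fake_response) as urlopen:
        assert download_bytes("https://example.com/model.ckpt") == b"fake-checkpoint"
    urlopen.assert_called_once_with("https://example.com/model.ckpt")
```

The same idea should carry over to the cloud clients by patching the client constructor instead of `urlopen`.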

Compatibility and Dependencies

  • Dependencies: Optional cloud dependencies (boto3, google-cloud-storage, azure-storage-blob)
  • Type safety: Full type annotations and strict typing
  • Code quality: Passes all linting rules (ruff, black, isort)
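
One common way to keep cloud dependencies optional is to import them lazily and fail with an actionable message. The helper below is a sketch under that assumption; `require_optional` is a hypothetical name and not part of the PR.

```python
import importlib
from types import ModuleType


def require_optional(module_name: str, pip_name: str) -> ModuleType:
    """Import an optional dependency or raise an error that says how to install it."""
    try:
        return importlib.import_module(module_name)
    except ImportError as exc:
        raise ImportError(
            f"'{module_name}' is required for this checkpoint source. "
            f"Install it with: pip install {pip_name}"
        ) from exc


# Example (only evaluated when an S3 source is actually requested):
# boto3 = require_optional("boto3", "boto3")
```

Deferring the import to the point of use means users who only load local checkpoints never need the cloud packages installed.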

Future Extensions

This architecture makes it easy to add:

  • Additional cloud providers (OCI, DigitalOcean Spaces, etc.)
  • Custom authentication mechanisms
  • Caching and optimisation strategies
  • Checkpoint validation and integrity checks
  • Progress tracking for large downloads

As a contributor to the Anemoi framework, please ensure that your changes include unit tests, updates to any affected dependencies and documentation, and have been tested in a parallel setting (i.e., with multiple GPUs). As a reviewer, you are also responsible for verifying these aspects and requesting changes if they are not adequately addressed. For guidelines about those please refer to https://anemoi.readthedocs.io/en/latest/

By opening this pull request, I affirm that all authors agree to the Contributor License Agreement.


📚 Documentation preview 📚: https://anemoi-training--464.org.readthedocs.build/en/464/


📚 Documentation preview 📚: https://anemoi-graphs--464.org.readthedocs.build/en/464/


📚 Documentation preview 📚: https://anemoi-models--464.org.readthedocs.build/en/464/

Refactor restart of training [warm start, forked runs, restarts]

Closes #458
@github-project-automation github-project-automation bot moved this to To be triaged in Anemoi-dev Aug 6, 2025
@JesperDramsch JesperDramsch self-assigned this Aug 6, 2025
@JesperDramsch JesperDramsch added this to the Fine-Tuning milestone Aug 6, 2025
@ssmmnn11 (Member) commented Aug 6, 2025

Great stuff! Do you plan to have config / Hydra options, so we can start a model and then specify exactly where the checkpoint should come from? Or would this be out of scope?

@JesperDramsch (Member, Author)

Yes exactly, so it's much easier to say "hey I have this checkpoint in an s3 bucket" or similar.

…e loading system

- Add test suite for multi-source checkpoint loading (local, S3, HTTP, GCS, Azure)
- Add test suite for model loading strategies (standard, transfer learning, weights-only)
- Test error handling for network failures and missing files
- Test registry pattern functionality for both loaders
- Add extensive documentation explaining test organization and principles

The tests ensure robustness of the extensible checkpoint loading system across
different sources and loading strategies, with proper error handling and validation.

@mchantry (Member)

@JesperDramsch thanks for the contribution. Is there an active use case for the GCS loading? There is an active user base for Azure, so I think implementing the Azure loader would be a nice addition, but I wonder if you could save yourself some work by leaving GCS until a use case arises?

Development

Successfully merging this pull request may close these issues.

Checkpoint Acquisition Layer - Multi-source checkpoint loading (S3, HTTP, local, MLFlow)
