feat: Add checkpoint update command #336
base: main
Changes from all commits
4c4b3cb
db7627b
d95dad4
71a8f25
a589127
751fe89
1cdb28a
f7b3e6b
1845c6a
79039a1
@@ -0,0 +1,160 @@
.. _redefine_graph-command:

Redefine Graph Command
======================

With this command, you can redefine the graph of a checkpoint file.
This is useful when you want to change or reconfigure the local domain of a model, or rebuild it with a new graph.

We should caution that such transfer of the model from one graph to
another is not guaranteed to lead to good results. Still, it is a
powerful tool to explore generalisability of the model or to test
performance before starting fine tuning through transfer learning.

This will create a new checkpoint file with the updated graph, and optionally save the graph to a file.

The graph can be loaded from an existing graph file, or built from a graph config, a lat/lon coordinate file, or a bounding box. These options are mutually exclusive and exactly one of them is required.

*********
Usage
*********

.. code-block:: bash

   % anemoi-inference redefine_graph --help

   Redefine the graph of a checkpoint file.

   positional arguments:
     path                  Path to the checkpoint.

   options:
     -h, --help            show this help message and exit
     -g GRAPH, --graph GRAPH
                           Path to graph file to use
     -y GRAPH_CONFIG, --graph-config GRAPH_CONFIG
                           Path to graph config to use
     -ll LATLON, --latlon LATLON
                           Path to coordinate npy, should be of shape (N, 2) with latitudes and longitudes.
     -c COORDS COORDS COORDS COORDS COORDS, --coords COORDS COORDS COORDS COORDS COORDS
                           Coordinates, (North West South East Resolution).
     -gr GLOBAL_RESOLUTION, --global-resolution GLOBAL_RESOLUTION
                           Global grid resolution required with --coords, (e.g. n320, o96).
     --save-graph SAVE_GRAPH
                           Path to save the updated graph.
     --output OUTPUT       Path to save the updated checkpoint.

*********
Examples
*********

Here are some examples of how to use the ``redefine_graph`` command:

#. Using a graph file:

   .. code-block:: bash

      anemoi-inference redefine_graph path/to/checkpoint --graph path/to/graph

#. Using a graph configuration:

   .. code-block:: bash

      anemoi-inference redefine_graph path/to/checkpoint --graph-config path/to/graph_config

   .. note::
      The configuration of the existing graph can be found using:

      .. code-block:: bash

         anemoi-inference metadata path/to/checkpoint --get config.graph --yaml

#. Using latitude/longitude coordinates:

   The lat/lon file should be a NumPy ``.npy`` file of shape (N, 2) with latitudes and longitudes.

   It can be made from one-dimensional sequences of latitudes and longitudes as follows:

   .. code-block:: python

      import numpy as np

      # latitudes and longitudes are 1D sequences defining the grid axes
      coords = np.array(np.meshgrid(latitudes, longitudes)).T.reshape(-1, 2)
      np.save('path/to/latlon.npy', coords)

   Once created, pass the file to the command:

   .. code-block:: bash

      anemoi-inference redefine_graph path/to/checkpoint --latlon path/to/latlon.npy

#. Using bounding box coordinates:

   .. code-block:: bash

      anemoi-inference redefine_graph path/to/checkpoint --coords North West South East Resolution

   For example:

   .. code-block:: bash

      anemoi-inference redefine_graph path/to/checkpoint --coords 30.0 -10.0 20.0 0.0 0.1/0.1 --global-resolution n320

All examples can optionally save the updated graph and checkpoint using the ``--save-graph`` and ``--output`` options. If ``--output`` is omitted, the updated checkpoint is still saved: it is written to the current working directory under the input file name with an ``_updated`` suffix (for example, ``checkpoint_updated.ckpt`` for ``checkpoint.ckpt``).

***************************
Complete Inference Example
***************************

For this example we will redefine a checkpoint using a bounding box and then run inference.

Redefine the checkpoint
-----------------------

.. code-block:: bash

   anemoi-inference redefine_graph path/to/checkpoint --coords 30.0 -10.0 20.0 0.0 0.1/0.1 --global-resolution n320 --save-graph path/to/updated_graph --output path/to/updated_checkpoint

Create the inference config
---------------------------

If you have an input file of the expected shape handy, use it in place of the input block. Here we show
how to use MARS to handle the regridding.

.. note::
   Using the `anemoi-plugins-ecmwf-inference <https://github.com/ecmwf/anemoi-plugins-ecmwf>`_ package, preprocessors are available which can handle the regridding for you from other sources.

.. code-block:: yaml

   checkpoint: path/to/updated_checkpoint
   date: -2

   input:
     cutout:
       lam_0:
         mars:
           grid: 0.1/0.1 # RESOLUTION WE SET
           area: 30.0/-10.0/20.0/0.0 # BOUNDING BOX WE SET, N W S E
       global:
         mars:
           grid: n320 # GLOBAL RESOLUTION WE SET

Run inference
-------------

.. code-block:: bash

   anemoi-inference run path/to/updated_checkpoint

**********
Reference
**********

.. argparse::
   :module: anemoi.inference.__main__
   :func: create_parser
   :prog: anemoi-inference
   :path: redefine_graph
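The documentation above passes the bounding box as ``--coords North West South East Resolution`` and then repeats the same values under the MARS ``area`` and ``grid`` keys of the inference config. A minimal sketch of that correspondence follows. The helper name ``coords_to_mars_request`` is hypothetical and is not part of anemoi-inference; how the package itself translates the values is not shown in this diff.

```python
def coords_to_mars_request(coords):
    """Map the five --coords values onto MARS-style ``area`` and ``grid`` keys.

    Hypothetical helper for illustration only; not part of anemoi-inference.
    """
    if len(coords) != 5:
        raise ValueError("Expected North, West, South, East and Resolution.")
    north, west, south, east, resolution = (str(c) for c in coords)
    # MARS areas are written North/West/South/East, as in the YAML example above.
    return {"area": "/".join([north, west, south, east]), "grid": resolution}


request = coords_to_mars_request(["30.0", "-10.0", "20.0", "0.0", "0.1/0.1"])
print(request)
```

Keeping the two in sync matters because the command description states that, with coordinate specifications, the input to the local domain is assumed to be already regridded.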
@@ -0,0 +1,165 @@
# (C) Copyright 2025- Anemoi contributors.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
#
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.


import logging
from argparse import ArgumentParser
from argparse import Namespace
from pathlib import Path

from . import Command

LOG = logging.getLogger(__name__)


def check_redefine_imports():
    """Check if required packages are installed."""
    required_packages = ["anemoi.datasets", "anemoi.graphs", "anemoi.models"]
Review comment: Would it be useful to add these packages to the optional dependencies in the

    from importlib.util import find_spec

    for package in required_packages:
        if find_spec(package) is None:
            raise ImportError(f"{package!r} is required for this command.")


def format_namespace_as_str(namespace: Namespace) -> str:
    """Format an argparse Namespace object as command-line arguments."""
    args = []

    for key, value in vars(namespace).items():
        if key == "command":
            continue
        if value is None:
            continue

        # Convert underscores to hyphens for command line format
        arg_name = f"--{key.replace('_', '-')}"

        if isinstance(value, bool):
            if value:
                args.append(arg_name)
        elif isinstance(value, list):
            args.append(f"{arg_name} {' '.join(map(str, value))}")
        else:
            args.extend([arg_name, str(value)])

    return " ".join(args)


class RedefineGraphCmd(Command):
    """Redefine the graph of a checkpoint file."""

    def add_arguments(self, command_parser: ArgumentParser) -> None:
        """Add arguments to the command parser.

        Parameters
        ----------
        command_parser : ArgumentParser
            The argument parser to which the arguments will be added.
        """
        command_parser.description = "Redefine the graph of a checkpoint file. If using coordinate specifications, assumes the input to the local domain is already regridded."
        command_parser.add_argument("path", help="Path to the checkpoint.")

        group = command_parser.add_mutually_exclusive_group(required=True)

        group.add_argument("-g", "--graph", type=Path, help="Path to graph file to use")
        group.add_argument("-y", "--graph-config", type=Path, help="Path to graph config to use")
        group.add_argument(
            "-ll",
            "--latlon",
            type=Path,
            help="Path to coordinate npy, should be of shape (N, 2) with latitudes and longitudes.",
        )
        group.add_argument("-c", "--coords", type=str, help="Coordinates, (North West South East Resolution).", nargs=5)

        command_parser.add_argument(
            "-gr",
            "--global-resolution",
            type=str,
            help="Global grid resolution required with --coords, (e.g. n320, o96).",
        )

        command_parser.add_argument("--save-graph", type=str, help="Path to save the updated graph.", default=None)
        command_parser.add_argument("--output", type=str, help="Path to save the updated checkpoint.", default=None)

    def run(self, args: Namespace) -> None:
        """Run the redefine_graph command.

        Parameters
        ----------
        args : Namespace
            The arguments passed to the command.
        """
        from anemoi.inference.utils.redefine_graph import create_graph_from_config
        from anemoi.inference.utils.redefine_graph import get_coordinates_from_file
        from anemoi.inference.utils.redefine_graph import get_coordinates_from_mars_request
        from anemoi.inference.utils.redefine_graph import load_graph_from_file
        from anemoi.inference.utils.redefine_graph import make_graph_from_coordinates
        from anemoi.inference.utils.redefine_graph import update_checkpoint

        check_redefine_imports()

        import torch
        from anemoi.utils.checkpoints import load_metadata
        from anemoi.utils.checkpoints import save_metadata

        path = Path(args.path)

        # Load checkpoint metadata and supporting arrays
        metadata, supporting_arrays = load_metadata(str(path), supporting_arrays=True)

        # Add command to history
        metadata.setdefault("history", [])
        metadata["history"].append(f"anemoi-inference redefine_graph {format_namespace_as_str(args)}")

        # Create or load the graph
        if args.graph is not None:
            graph = load_graph_from_file(args.graph)
        elif args.graph_config is not None:
            graph = create_graph_from_config(args.graph_config)
        else:
            # Generate graph from coordinates
            LOG.info("Generating graph from coordinates...")

            # Get coordinates based on input type
            if args.latlon is not None:
                local_lats, local_lons = get_coordinates_from_file(args.latlon)
            elif args.coords is not None:
                local_lats, local_lons = get_coordinates_from_mars_request(args.coords)
            else:
                raise ValueError("No valid coordinates found.")
Review comment: This won't be reached if argparse is working correctly, but one more check doesn't hurt.
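The argparse behaviour this comment relies on can be sketched with a reduced parser. This is an illustration under assumptions: the parser below keeps only two of the four source options and is not the real command parser. A required mutually exclusive group makes argparse exit with an error when no source is given or when two are given, which is why the ``else`` branch above should be unreachable.

```python
from argparse import ArgumentParser

# Reduced stand-in for the command parser: one positional, two exclusive sources.
parser = ArgumentParser(prog="redefine_graph")
parser.add_argument("path")
group = parser.add_mutually_exclusive_group(required=True)
group.add_argument("-g", "--graph")
group.add_argument("-ll", "--latlon")


def exits_with_error(argv):
    """Return True if argparse rejects argv (it calls sys.exit on errors)."""
    try:
        parser.parse_args(argv)
    except SystemExit:
        return True
    return False


no_source = exits_with_error(["ckpt"])
two_sources = exits_with_error(["ckpt", "-g", "g.pt", "-ll", "ll.npy"])
one_source = exits_with_error(["ckpt", "--latlon", "ll.npy"])
print(no_source, two_sources, one_source)
```

The explicit ``ValueError`` still guards against the function being called with a hand-built ``Namespace`` that bypasses the parser.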

            metadata, supporting_arrays, graph = make_graph_from_coordinates(
                local_lats, local_lons, args.global_resolution, metadata, supporting_arrays
            )

        # Save graph if requested
        if args.save_graph is not None:
            torch.save(graph, args.save_graph)
            LOG.info("Saved updated graph to %s", args.save_graph)

        # Update checkpoint
        LOG.info("Updating checkpoint...")
        model = torch.load(path, weights_only=False, map_location=torch.device("cpu"))
        model = update_checkpoint(model, metadata, graph)

        # Save updated checkpoint
        model_path = args.output if args.output is not None else f"{path.stem}_updated{path.suffix}"
        torch.save(model, model_path)

        save_metadata(
            model_path,
            metadata=metadata,
            supporting_arrays=supporting_arrays,
        )

        LOG.info("Updated checkpoint saved to %s", model_path)


command = RedefineGraphCmd
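To see what ends up in the checkpoint's ``history`` metadata, the helper ``format_namespace_as_str`` from the diff above can be run on a hand-built ``Namespace``. The sketch below copies the helper (lightly condensed) so that it runs on its own; the argument values are illustrative only.

```python
from argparse import Namespace


def format_namespace_as_str(namespace: Namespace) -> str:
    """Condensed copy of the helper from the diff, reproduced to run standalone."""
    args = []
    for key, value in vars(namespace).items():
        if key == "command" or value is None:
            continue
        # Convert underscores to hyphens for command line format
        arg_name = f"--{key.replace('_', '-')}"
        if isinstance(value, bool):
            if value:
                args.append(arg_name)
        elif isinstance(value, list):
            args.append(f"{arg_name} {' '.join(map(str, value))}")
        else:
            args.extend([arg_name, str(value)])
    return " ".join(args)


# Illustrative values matching the bounding-box example in the docs.
example = Namespace(
    command="redefine_graph",
    path="path/to/checkpoint",
    graph=None,
    graph_config=None,
    latlon=None,
    coords=["30.0", "-10.0", "20.0", "0.0", "0.1/0.1"],
    global_resolution="n320",
    save_graph=None,
    output=None,
)
history_entry = f"anemoi-inference redefine_graph {format_namespace_as_str(example)}"
print(history_entry)
```

Note that the helper treats every attribute as an option, so the positional ``path`` is recorded as ``--path path/to/checkpoint``. The history entry is therefore a record of the arguments rather than a command line that can necessarily be replayed verbatim.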
Why did you choose to make the output optional? Shouldn't it be a mandatory positional argument?
For arbitrary convenience: it adds a suffix by default and still saves the checkpoint.
Sounds reasonable. I would clarify this in the docs.
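The default discussed in this thread can be sketched in isolation. The function name ``resolve_output_path`` is hypothetical (the diff computes the value inline in ``run``), and the file names are made up for illustration.

```python
from pathlib import Path


def resolve_output_path(checkpoint, output=None):
    """Mirror the fallback in the diff: use --output if given, else add a suffix.

    Hypothetical helper name; the diff computes this inline in ``run``.
    """
    path = Path(checkpoint)
    return output if output is not None else f"{path.stem}_updated{path.suffix}"


default_path = resolve_output_path("models/inference-last.ckpt")
explicit_path = resolve_output_path("models/inference-last.ckpt", "out/new.ckpt")
print(default_path)
print(explicit_path)
```

Because only ``stem`` and ``suffix`` are used, the default drops the parent directory of the input checkpoint, so the updated file is written relative to the current working directory. That detail is probably worth stating in the docs alongside the suffix.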