From 4c4b3cb0cedb5cbcb9802006a38b199aeb5b5684 Mon Sep 17 00:00:00 2001 From: Harrison Cook Date: Thu, 25 Sep 2025 15:54:31 +0000 Subject: [PATCH 1/9] Initial command and docs --- docs/cli/introduction.rst | 1 + docs/cli/redefine.rst | 160 +++++++++++ docs/index.rst | 1 + src/anemoi/inference/commands/redefine.py | 329 ++++++++++++++++++++++ 4 files changed, 491 insertions(+) create mode 100644 docs/cli/redefine.rst create mode 100644 src/anemoi/inference/commands/redefine.py diff --git a/docs/cli/introduction.rst b/docs/cli/introduction.rst index 1fa81f73..3c33ca42 100644 --- a/docs/cli/introduction.rst +++ b/docs/cli/introduction.rst @@ -19,3 +19,4 @@ The commands are: - :ref:`Validate Command ` - :ref:`Patch Command ` - :ref:`Requests Command ` +- :ref:`Redefine Command ` diff --git a/docs/cli/redefine.rst b/docs/cli/redefine.rst new file mode 100644 index 00000000..cfca1305 --- /dev/null +++ b/docs/cli/redefine.rst @@ -0,0 +1,160 @@ +.. _redefine-command: + +Redefine Command +=============== + +With this command, you can redefine the graph of a checkpoint file. +This is useful when you want to change / reconfigure the local-domain of a model, or rebuild with a new graph. + +We should caution that such transfer of the model from one graph to +another is not guaranteed to lead to good results. Still, it is a +powerful tool to explore generalisability of the model or to test +performance before starting fine tuning through transfer learning. + +This will create a new checkpoint file with the updated graph, and optionally save the graph to a file. + +Subcommands allow for a graph to be made from a lat/lon coordinate file, bounding box, or from a defined graph config. + +********* + Usage +********* + +.. code-block:: bash + + % anemoi-inference redefine --help + + Redefine the graph of a checkpoint file. + + positional arguments: + path Path to the checkpoint. 
+ + options: + -h, --help show this help message and exit + -g GRAPH, --graph GRAPH + Path to graph file to use + -y GRAPH_CONFIG, --graph_config GRAPH_CONFIG + Path to graph config to use + -ll LATLON, --latlon LATLON + Path to coordinate npy, should be of shape (N, 2) with latitudes and longitudes. + -c COORDS COORDS COORDS COORDS COORDS, --coords COORDS COORDS COORDS COORDS COORDS + Coordinates, (North West South East Resolution). + -gr GLOBAL_RESOLUTION, --global_resolution GLOBAL_RESOLUTION + Global grid resolution required with --coords, (e.g. n320, o96). + --save-graph SAVE_GRAPH + Path to save the updated graph. + --output OUTPUT Path to save the updated checkpoint. + + +********* +Examples +********* + +Here are some examples of how to use the `redefine` command: + +#. Using a graph file: + + .. code-block:: bash + + anemoi-inference redefine path/to/checkpoint --graph path/to/graph + +#. Using a graph configuration: + + .. code-block:: bash + + anemoi-inference redefine path/to/checkpoint --graph_config path/to/graph_config + + .. note:: + The configuration of the existing graph can be found using: + + .. code-block:: bash + + anemoi-inference metadata path/to/checkpoint -get config.graph ----yaml + +#. Using latitude/longitude coordinates: + This lat lon file should be a numpy file of shape (N, 2) with latitudes and longitudes. + + It can be easily made from a list of coordinates as follows: + + .. code-block:: python + + import numpy as np + coords = np.array(np.meshgrid(latitudes, longitudes)).T.reshape(-1, 2) + np.save('path/to/latlon.npy', coords) + + Once created, + + .. code-block:: bash + + anemoi-inference redefine path/to/checkpoint --latlon path/to/latlon.npy + +#. Using bounding box coordinates: + + .. code-block:: bash + + anemoi-inference redefine path/to/checkpoint --coords North West South East Resolution + + i.e. + + .. 
code-block:: bash + + anemoi-inference redefine path/to/checkpoint --coords 30.0 -10.0 20.0 0.0 0.1/0.1 --global_resolution n320 + + +All examples can optionally save the updated graph and checkpoint using the `--save-graph` and `--output` options. + +*************************** +Complete Inference Example +*************************** + +For this example we will redefine a checkpoint using a bounding box and then run inference + + +Redefine the checkpoint +----------------------- + +.. code-block:: bash + + anemoi-inference redefine path/to/checkpoint --coords 30.0 -10.0 20.0 0.0 0.1/0.1 --global_resolution n320 --save-graph path/to/updated_graph --output path/to/updated_checkpoint + +Create the inference config +--------------------------- + +If you have an input file of the expected shape handy use it in place of the input block, here we will show +how to use MARS to handle the regridding. + +.. note:: + Using the `anemoi-plugins-ecmwf-inference `_ package, preprocessors are available which can handle the regridding for you from other sources. + +.. code-block:: yaml + + checkpoint: path/to/updated_checkpoint + date: -2 + + input: + cutout: + lam_0: + mars: + grid: 0.1/0.1 # RESOLUTION WE SET + area: 30.0/-10.0/20.0/0.0 # BOUNDING BOX WE SET, N W S E + global: + mars: + grid: n320 # GLOBAL RESOLUTION WE SET + + +Run inference +----------------- + +.. code-block:: bash + + anemoi-inference run path/to/updated_checkpoint + + +********** +Reference +********** + +.. argparse:: + :module: anemoi.inference.__main__ + :func: create_parser + :prog: anemoi-inference + :path: redefine diff --git a/docs/index.rst b/docs/index.rst index ff24cab3..c64a3b8b 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -153,6 +153,7 @@ You may also have to install pandoc on MacOS: cli/inspect cli/patch cli/requests + cli/redefine .. 
toctree:: :maxdepth: 1 diff --git a/src/anemoi/inference/commands/redefine.py b/src/anemoi/inference/commands/redefine.py new file mode 100644 index 00000000..d0cf2906 --- /dev/null +++ b/src/anemoi/inference/commands/redefine.py @@ -0,0 +1,329 @@ +# (C) Copyright 2025- Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + + +import logging +from argparse import ArgumentParser +from argparse import Namespace +from copy import deepcopy +from pathlib import Path +from typing import TYPE_CHECKING + +from . import Command + +LOG = logging.getLogger(__name__) + +if TYPE_CHECKING: + import numpy as np + from torch_geometric.data import HeteroData + + +def format_namespace_as_str(namespace: Namespace) -> str: + """Format an argparse Namespace object as command-line arguments.""" + args = [] + + for key, value in vars(namespace).items(): + if key == "command": + continue + if value is None: + continue + + # Convert underscores to hyphens for command line format + arg_name = f"--{key.replace('_', '-')}" + + if isinstance(value, bool): + if value: + args.append(arg_name) + elif isinstance(value, list): + args.append(f"{arg_name} {' '.join(map(str, value))}") + else: + args.extend([arg_name, str(value)]) + + return " ".join(args) + + +def update_state_dict( + model, + external_state_dict, + keywords: list[str] | None = None, + ignore_mismatched_layers=False, + ignore_additional_layers=False, +): + """Update the model's state_dict with entries from an external state_dict. 
Only entries whose keys contain the specified keywords are considered.""" + + LOG.info("Updating model state dictionary.") + + keywords = keywords or [] + + # select relevant part of external_state_dict + reduced_state_dict = {k: v for k, v in external_state_dict.items() if any(kw in k for kw in keywords)} + model_state_dict = model.state_dict() + + # check layers and their shapes + for key in list(reduced_state_dict): + if key not in model_state_dict: + if ignore_additional_layers: + LOG.info("Skipping injection of %s, which is not in the model.", key) + del reduced_state_dict[key] + else: + raise AssertionError(f"Layer {key} not in model. Consider setting 'ignore_additional_layers = True'.") + elif reduced_state_dict[key].shape != model_state_dict[key].shape: + if ignore_mismatched_layers: + LOG.info("Skipping injection of %s due to shape mismatch.", key) + LOG.info("Model shape: %s", model_state_dict[key].shape) + LOG.info("Provided shape: %s", reduced_state_dict[key].shape) + del reduced_state_dict[key] + else: + raise AssertionError(f"Mismatch in shape of {key}. Consider setting 'ignore_mismatched_layers = True'.") + + model.load_state_dict(reduced_state_dict, strict=False) + return model + + +class RedefineCmd(Command): + """Redefine the graph of a checkpoint file.""" + + def add_arguments(self, command_parser: ArgumentParser) -> None: + """Add arguments to the command parser. + + Parameters + ---------- + command_parser : ArgumentParser + The argument parser to which the arguments will be added. + """ + command_parser.description = "Redefine the graph of a checkpoint file." 
+ command_parser.add_argument("path", help="Path to the checkpoint.") + + group = command_parser.add_mutually_exclusive_group(required=True) + + group.add_argument("-g", "--graph", type=Path, help="Path to graph file to use") + group.add_argument("-y", "--graph_config", type=Path, help="Path to graph config to use") + group.add_argument( + "-ll", + "--latlon", + type=Path, + help="Path to coordinate npy, should be of shape (N, 2) with latitudes and longitudes.", + ) + group.add_argument("-c", "--coords", type=str, help="Coordinates, (North West South East Resolution).", nargs=5) + + command_parser.add_argument( + "-gr", + "--global_resolution", + type=str, + help="Global grid resolution required with --coords, (e.g. n320, o96).", + ) + + command_parser.add_argument("--save-graph", type=str, help="Path to save the updated graph.", default=None) + command_parser.add_argument("--output", type=str, help="Path to save the updated checkpoint.", default=None) + + def _get_coordinates(self, args: Namespace) -> tuple["np.ndarray", "np.ndarray"]: + """Get coordinates from command line arguments. + + Either from files or from coords which are extracted from a MARS request. 
+ """ + import numpy as np + + if args.latlon is not None: + latlon = np.load(args.latlon) + return latlon[:, 0], latlon[:, 1] + + elif args.coords is not None: + import earthkit.data as ekd + + area = [args.coords[0], args.coords[1], args.coords[2], args.coords[3]] + + resolution = str(args.coords[4]) + if resolution.isdigit(): + resolution = f"{resolution}/{resolution}" + + ds = ekd.from_source( + "mars", + { + "AREA": area, + "GRID": f"{resolution}", + "param": "2t", + "date": -2, + "stream": "oper", + "type": "an", + "levtype": "sfc", + }, + ) + return ds[0].grid_points() # type: ignore + raise ValueError("No valid coordinates found.") + + def _combine_nodes( + self, latitudes: "np.ndarray", longitudes: "np.ndarray", global_grid: str + ) -> tuple["np.ndarray", "np.ndarray", "np.ndarray", "np.ndarray"]: + """Combine lat/lon nodes with global grid if specified. + + Returns lats, lons, local_mask, global_mask + """ + import numpy as np + from anemoi.datasets.grids import cutout_mask + from anemoi.utils.grids import grids + + global_points = grids(global_grid) + + global_removal_mask = cutout_mask( + latitudes, longitudes, global_points["latitudes"], global_points["longitudes"] + ) + lats = np.concatenate([latitudes, global_points["latitudes"][global_removal_mask]]) + lons = np.concatenate([longitudes, global_points["longitudes"][global_removal_mask]]) + local_mask = np.array([True] * len(latitudes) + [False] * sum(global_removal_mask), dtype=bool) + + return lats, lons, local_mask, global_removal_mask + + def _make_data_graph( + self, + lats: "np.ndarray", + lons: "np.ndarray", + local_mask: "np.ndarray", + global_mask: "np.ndarray", + *, + mask_attr_name: str = "cutout", + attrs, + ) -> "HeteroData": + """Make a data graph with the given lat/lon nodes and attributes.""" + import torch + from anemoi.graphs.nodes import LatLonNodes + from torch_geometric.data import HeteroData + + graph = LatLonNodes(lats, lons, name="data").update_graph(HeteroData(), 
attrs_config=attrs) + graph["data"][mask_attr_name] = torch.from_numpy(local_mask) + return graph + + def _make_graph_from_coordinates( + self, args: Namespace, metadata: dict, supporting_arrays: dict + ) -> tuple[dict, dict, "HeteroData"]: + """Make a graph from coordinates given in args.""" + import numpy as np + + if args.global_resolution is None: + raise ValueError("Global resolution must be specified when generating graph from coordinates.") + + local_lats, local_lons = self._get_coordinates(args) + LOG.info("Coordinates loaded. Number of local nodes: %d", len(local_lats)) + lats, lons, local_mask, global_mask = self._combine_nodes(local_lats, local_lons, args.global_resolution) + + graph_config = deepcopy(metadata["config"]["graph"]) + data_graph = graph_config["nodes"].pop("data") + + from anemoi.graphs.create import GraphCreator + from anemoi.utils.config import DotDict + + creator = GraphCreator(DotDict(graph_config)) + + LOG.info("Updating graph...") + LOG.debug("Using %r", graph_config) + + def nested_get(d, keys, default=None): + for key in keys: + d = d.get(key, {}) + return d or default + + mask_attr_name = nested_get(graph_config, ["nodes", "hidden", "node_builder", "mask_attr_name"], "cutout") + + data_graph = self._make_data_graph( + lats, lons, local_mask, global_mask, mask_attr_name=mask_attr_name, attrs=data_graph.get("attrs", None) + ) + LOG.info("Created data graph with %d nodes.", data_graph.num_nodes) + graph = creator.update_graph(data_graph) + + supporting_arrays[f"global/{mask_attr_name}"] = global_mask + supporting_arrays[f"lam_0/{mask_attr_name}"] = np.array([True] * len(local_lats)) + + supporting_arrays["latitudes"] = lats + supporting_arrays["longitudes"] = lons + supporting_arrays["grid_indices"] = np.ones(global_mask.shape, dtype=np.int64) + + return metadata, supporting_arrays, graph + + def _update_checkpoint(self, model, metadata, graph: "HeteroData"): + from anemoi.utils.config import DotDict + + state_dict_ckpt = 
deepcopy(model.state_dict()) + + # rebuild the model with the new graph + model.graph_data = graph + model.config = DotDict(metadata).config + model._build_model() + + # reinstate the weights, biases and normalizer from the checkpoint + # reinstating the normalizer is necessary for checkpoints that were created + # using transfer learning, where the statistics as stored in the checkpoint + # do not match the statistics used to build the normalizer in the checkpoint. + model_instance = update_state_dict(model, state_dict_ckpt, keywords=["bias", "weight", "processors.normalizer"]) + + return model_instance + + def _check_imports(self): + """Check if required packages are installed.""" + required_packages = ["anemoi.datasets", "anemoi.graphs", "anemoi.models"] + from importlib.util import find_spec + + for package in required_packages: + if find_spec(package) is None: + raise ImportError(f"{package!r} is required for this command.") + + def run(self, args: Namespace) -> None: + """Run the redefine command. + + Parameters + ---------- + args : Namespace + The arguments passed to the command. 
+ """ + self._check_imports() + + import torch + from anemoi.utils.checkpoints import load_metadata + from anemoi.utils.checkpoints import save_metadata + + path = Path(args.path) + + metadata, supporting_arrays = load_metadata(str(path), supporting_arrays=True) + + metadata.setdefault("history", []) + metadata["history"].append(f"anemoi-inference redefine {format_namespace_as_str(args)}") + + if args.graph is not None: + LOG.info("Loading graph from %s", args.graph) + graph = torch.load(args.graph) + else: + if args.graph_config is not None: + from anemoi.graphs.create import GraphCreator + from torch_geometric.data import HeteroData + + graph = GraphCreator(args.graph_config).update_graph(HeteroData()) + else: + LOG.info("Generating graph from coordinates...") + metadata, supporting_arrays, graph = self._make_graph_from_coordinates( + args, metadata, supporting_arrays + ) + + if args.save_graph is not None: + torch.save(graph, args.save_graph) + LOG.info("Saved updated graph to %s", args.save_graph) + + LOG.info("Updating checkpoint...") + + model = torch.load(str(path), weights_only=False, map_location=torch.device("cpu")) + model = self._update_checkpoint(model, metadata, graph) + model_path = args.output if args.output is not None else f"{path.stem}_updated{path.suffix}" + + torch.save(model, model_path) + + save_metadata( + model_path, + metadata=metadata, + supporting_arrays=supporting_arrays, + ) + + +command = RedefineCmd From db7627b4cd609b0a14301f7f2edecef7ae275ccb Mon Sep 17 00:00:00 2001 From: Harrison Cook Date: Tue, 30 Sep 2025 08:01:15 +0000 Subject: [PATCH 2/9] Move to seperate utils module --- src/anemoi/inference/commands/redefine.py | 258 +++------------- .../inference/runners/external_graph.py | 95 ++---- src/anemoi/inference/utils/__init__.py | 8 + src/anemoi/inference/utils/redefine.py | 280 ++++++++++++++++++ 4 files changed, 359 insertions(+), 282 deletions(-) create mode 100644 src/anemoi/inference/utils/__init__.py create mode 100644 
src/anemoi/inference/utils/redefine.py diff --git a/src/anemoi/inference/commands/redefine.py b/src/anemoi/inference/commands/redefine.py index d0cf2906..0f159986 100644 --- a/src/anemoi/inference/commands/redefine.py +++ b/src/anemoi/inference/commands/redefine.py @@ -11,17 +11,21 @@ import logging from argparse import ArgumentParser from argparse import Namespace -from copy import deepcopy from pathlib import Path -from typing import TYPE_CHECKING from . import Command LOG = logging.getLogger(__name__) -if TYPE_CHECKING: - import numpy as np - from torch_geometric.data import HeteroData + +def check_redefine_imports(): + """Check if required packages are installed.""" + required_packages = ["anemoi.datasets", "anemoi.graphs", "anemoi.models"] + from importlib.util import find_spec + + for package in required_packages: + if find_spec(package) is None: + raise ImportError(f"{package!r} is required for this command.") def format_namespace_as_str(namespace: Namespace) -> str: @@ -48,44 +52,6 @@ def format_namespace_as_str(namespace: Namespace) -> str: return " ".join(args) -def update_state_dict( - model, - external_state_dict, - keywords: list[str] | None = None, - ignore_mismatched_layers=False, - ignore_additional_layers=False, -): - """Update the model's state_dict with entries from an external state_dict. Only entries whose keys contain the specified keywords are considered.""" - - LOG.info("Updating model state dictionary.") - - keywords = keywords or [] - - # select relevant part of external_state_dict - reduced_state_dict = {k: v for k, v in external_state_dict.items() if any(kw in k for kw in keywords)} - model_state_dict = model.state_dict() - - # check layers and their shapes - for key in list(reduced_state_dict): - if key not in model_state_dict: - if ignore_additional_layers: - LOG.info("Skipping injection of %s, which is not in the model.", key) - del reduced_state_dict[key] - else: - raise AssertionError(f"Layer {key} not in model. 
Consider setting 'ignore_additional_layers = True'.") - elif reduced_state_dict[key].shape != model_state_dict[key].shape: - if ignore_mismatched_layers: - LOG.info("Skipping injection of %s due to shape mismatch.", key) - LOG.info("Model shape: %s", model_state_dict[key].shape) - LOG.info("Provided shape: %s", reduced_state_dict[key].shape) - del reduced_state_dict[key] - else: - raise AssertionError(f"Mismatch in shape of {key}. Consider setting 'ignore_mismatched_layers = True'.") - - model.load_state_dict(reduced_state_dict, strict=False) - return model - - class RedefineCmd(Command): """Redefine the graph of a checkpoint file.""" @@ -97,7 +63,7 @@ def add_arguments(self, command_parser: ArgumentParser) -> None: command_parser : ArgumentParser The argument parser to which the arguments will be added. """ - command_parser.description = "Redefine the graph of a checkpoint file." + command_parser.description = "Redefine the graph of a checkpoint file. If using coordinate specifications, assumes the input to the local domain is already regridded." command_parser.add_argument("path", help="Path to the checkpoint.") group = command_parser.add_mutually_exclusive_group(required=True) @@ -122,155 +88,6 @@ def add_arguments(self, command_parser: ArgumentParser) -> None: command_parser.add_argument("--save-graph", type=str, help="Path to save the updated graph.", default=None) command_parser.add_argument("--output", type=str, help="Path to save the updated checkpoint.", default=None) - def _get_coordinates(self, args: Namespace) -> tuple["np.ndarray", "np.ndarray"]: - """Get coordinates from command line arguments. - - Either from files or from coords which are extracted from a MARS request. 
- """ - import numpy as np - - if args.latlon is not None: - latlon = np.load(args.latlon) - return latlon[:, 0], latlon[:, 1] - - elif args.coords is not None: - import earthkit.data as ekd - - area = [args.coords[0], args.coords[1], args.coords[2], args.coords[3]] - - resolution = str(args.coords[4]) - if resolution.isdigit(): - resolution = f"{resolution}/{resolution}" - - ds = ekd.from_source( - "mars", - { - "AREA": area, - "GRID": f"{resolution}", - "param": "2t", - "date": -2, - "stream": "oper", - "type": "an", - "levtype": "sfc", - }, - ) - return ds[0].grid_points() # type: ignore - raise ValueError("No valid coordinates found.") - - def _combine_nodes( - self, latitudes: "np.ndarray", longitudes: "np.ndarray", global_grid: str - ) -> tuple["np.ndarray", "np.ndarray", "np.ndarray", "np.ndarray"]: - """Combine lat/lon nodes with global grid if specified. - - Returns lats, lons, local_mask, global_mask - """ - import numpy as np - from anemoi.datasets.grids import cutout_mask - from anemoi.utils.grids import grids - - global_points = grids(global_grid) - - global_removal_mask = cutout_mask( - latitudes, longitudes, global_points["latitudes"], global_points["longitudes"] - ) - lats = np.concatenate([latitudes, global_points["latitudes"][global_removal_mask]]) - lons = np.concatenate([longitudes, global_points["longitudes"][global_removal_mask]]) - local_mask = np.array([True] * len(latitudes) + [False] * sum(global_removal_mask), dtype=bool) - - return lats, lons, local_mask, global_removal_mask - - def _make_data_graph( - self, - lats: "np.ndarray", - lons: "np.ndarray", - local_mask: "np.ndarray", - global_mask: "np.ndarray", - *, - mask_attr_name: str = "cutout", - attrs, - ) -> "HeteroData": - """Make a data graph with the given lat/lon nodes and attributes.""" - import torch - from anemoi.graphs.nodes import LatLonNodes - from torch_geometric.data import HeteroData - - graph = LatLonNodes(lats, lons, name="data").update_graph(HeteroData(), 
attrs_config=attrs) - graph["data"][mask_attr_name] = torch.from_numpy(local_mask) - return graph - - def _make_graph_from_coordinates( - self, args: Namespace, metadata: dict, supporting_arrays: dict - ) -> tuple[dict, dict, "HeteroData"]: - """Make a graph from coordinates given in args.""" - import numpy as np - - if args.global_resolution is None: - raise ValueError("Global resolution must be specified when generating graph from coordinates.") - - local_lats, local_lons = self._get_coordinates(args) - LOG.info("Coordinates loaded. Number of local nodes: %d", len(local_lats)) - lats, lons, local_mask, global_mask = self._combine_nodes(local_lats, local_lons, args.global_resolution) - - graph_config = deepcopy(metadata["config"]["graph"]) - data_graph = graph_config["nodes"].pop("data") - - from anemoi.graphs.create import GraphCreator - from anemoi.utils.config import DotDict - - creator = GraphCreator(DotDict(graph_config)) - - LOG.info("Updating graph...") - LOG.debug("Using %r", graph_config) - - def nested_get(d, keys, default=None): - for key in keys: - d = d.get(key, {}) - return d or default - - mask_attr_name = nested_get(graph_config, ["nodes", "hidden", "node_builder", "mask_attr_name"], "cutout") - - data_graph = self._make_data_graph( - lats, lons, local_mask, global_mask, mask_attr_name=mask_attr_name, attrs=data_graph.get("attrs", None) - ) - LOG.info("Created data graph with %d nodes.", data_graph.num_nodes) - graph = creator.update_graph(data_graph) - - supporting_arrays[f"global/{mask_attr_name}"] = global_mask - supporting_arrays[f"lam_0/{mask_attr_name}"] = np.array([True] * len(local_lats)) - - supporting_arrays["latitudes"] = lats - supporting_arrays["longitudes"] = lons - supporting_arrays["grid_indices"] = np.ones(global_mask.shape, dtype=np.int64) - - return metadata, supporting_arrays, graph - - def _update_checkpoint(self, model, metadata, graph: "HeteroData"): - from anemoi.utils.config import DotDict - - state_dict_ckpt = 
deepcopy(model.state_dict()) - - # rebuild the model with the new graph - model.graph_data = graph - model.config = DotDict(metadata).config - model._build_model() - - # reinstate the weights, biases and normalizer from the checkpoint - # reinstating the normalizer is necessary for checkpoints that were created - # using transfer learning, where the statistics as stored in the checkpoint - # do not match the statistics used to build the normalizer in the checkpoint. - model_instance = update_state_dict(model, state_dict_ckpt, keywords=["bias", "weight", "processors.normalizer"]) - - return model_instance - - def _check_imports(self): - """Check if required packages are installed.""" - required_packages = ["anemoi.datasets", "anemoi.graphs", "anemoi.models"] - from importlib.util import find_spec - - for package in required_packages: - if find_spec(package) is None: - raise ImportError(f"{package!r} is required for this command.") - def run(self, args: Namespace) -> None: """Run the redefine command. @@ -279,7 +96,14 @@ def run(self, args: Namespace) -> None: args : Namespace The arguments passed to the command. 
""" - self._check_imports() + from anemoi.inference.utils.redefine import create_graph_from_config + from anemoi.inference.utils.redefine import get_coordinates_from_file + from anemoi.inference.utils.redefine import get_coordinates_from_mars_request + from anemoi.inference.utils.redefine import load_graph_from_file + from anemoi.inference.utils.redefine import make_graph_from_coordinates + from anemoi.inference.utils.redefine import update_checkpoint + + check_redefine_imports() import torch from anemoi.utils.checkpoints import load_metadata @@ -287,36 +111,46 @@ def run(self, args: Namespace) -> None: path = Path(args.path) + # Load checkpoint metadata and supporting arrays metadata, supporting_arrays = load_metadata(str(path), supporting_arrays=True) + # Add command to history metadata.setdefault("history", []) metadata["history"].append(f"anemoi-inference redefine {format_namespace_as_str(args)}") + # Create or load the graph if args.graph is not None: - LOG.info("Loading graph from %s", args.graph) - graph = torch.load(args.graph) + graph = load_graph_from_file(args.graph) + elif args.graph_config is not None: + graph = create_graph_from_config(args.graph_config) else: - if args.graph_config is not None: - from anemoi.graphs.create import GraphCreator - from torch_geometric.data import HeteroData - - graph = GraphCreator(args.graph_config).update_graph(HeteroData()) + # Generate graph from coordinates + LOG.info("Generating graph from coordinates...") + + # Get coordinates based on input type + if args.latlon is not None: + local_lats, local_lons = get_coordinates_from_file(args.latlon) + elif args.coords is not None: + local_lats, local_lons = get_coordinates_from_mars_request(args.coords) else: - LOG.info("Generating graph from coordinates...") - metadata, supporting_arrays, graph = self._make_graph_from_coordinates( - args, metadata, supporting_arrays - ) + raise ValueError("No valid coordinates found.") - if args.save_graph is not None: - torch.save(graph, 
args.save_graph) - LOG.info("Saved updated graph to %s", args.save_graph) + metadata, supporting_arrays, graph = make_graph_from_coordinates( + local_lats, local_lons, args.global_resolution, metadata, supporting_arrays + ) - LOG.info("Updating checkpoint...") + # Save graph if requested + if args.save_graph is not None: + torch.save(graph, args.save_graph) + LOG.info("Saved updated graph to %s", args.save_graph) + # Update checkpoint + LOG.info("Updating checkpoint...") model = torch.load(str(path), weights_only=False, map_location=torch.device("cpu")) - model = self._update_checkpoint(model, metadata, graph) - model_path = args.output if args.output is not None else f"{path.stem}_updated{path.suffix}" + model = update_checkpoint(model, metadata, graph) + # Save updated checkpoint + model_path = args.output if args.output is not None else f"{path.stem}_updated{path.suffix}" torch.save(model, model_path) save_metadata( @@ -325,5 +159,7 @@ def run(self, args: Namespace) -> None: supporting_arrays=supporting_arrays, ) + LOG.info("Updated checkpoint saved to %s", model_path) + command = RedefineCmd diff --git a/src/anemoi/inference/runners/external_graph.py b/src/anemoi/inference/runners/external_graph.py index 4ee175ad..7a457037 100644 --- a/src/anemoi/inference/runners/external_graph.py +++ b/src/anemoi/inference/runners/external_graph.py @@ -1,4 +1,4 @@ -# (C) Copyright 2024 Anemoi contributors. +# (C) Copyright 2024- Anemoi contributors. # # This software is licensed under the terms of the Apache Licence Version 2.0 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. @@ -12,7 +12,7 @@ import logging import os -from copy import deepcopy +from contextlib import contextmanager from functools import cached_property from typing import Any from typing import Literal @@ -23,59 +23,11 @@ from ..decorators import main_argument from ..runners.default import DefaultRunner +from ..utils.redefine import update_checkpoint from . 
import runner_registry LOG = logging.getLogger(__name__) -# Possibly move the function(s) below to anemoi-models or anemoi-utils since it could be used in transfer learning. - - -def contains_any(key, specifications): - contained = False - for specification in specifications: - if specification in key: - contained = True - break - return contained - - -def update_state_dict( - model, external_state_dict, keywords="", ignore_mismatched_layers=False, ignore_additional_layers=False -): - """Update the model's stated_dict with entries from an external state_dict. Only entries whose keys contain the specified keywords are considered.""" - - LOG.info("Updating model state dictionary.") - - if isinstance(keywords, str): - keywords = [keywords] - - # select relevant part of external_state_dict - reduced_state_dict = {k: v for k, v in external_state_dict.items() if contains_any(k, keywords)} - model_state_dict = model.state_dict() - - # check layers and their shapes - for key in list(reduced_state_dict): - if key not in model_state_dict: - if ignore_additional_layers: - LOG.info("Skipping injection of %s, which is not in the model.", key) - del reduced_state_dict[key] - else: - raise AssertionError(f"Layer {key} not in model. Consider setting 'ignore_additional_layers = True'.") - elif reduced_state_dict[key].shape != model_state_dict[key].shape: - if ignore_mismatched_layers: - LOG.info("Skipping injection of %s due to shape mismatch.", key) - LOG.info("Model shape: %s", model_state_dict[key].shape) - LOG.info("Provided shape: %s", reduced_state_dict[key].shape) - del reduced_state_dict[key] - else: - raise AssertionError( - "Mismatch in shape of %s. 
Consider setting 'ignore_mismatched_layers = True'.", key - ) - - # update - model.load_state_dict(reduced_state_dict, strict=False) - return model - def _get_supporting_arrays_from_graph(update_supporting_arrays: dict[str, str], graph: Any) -> dict: """Update the supporting arrays from the graph data.""" @@ -258,35 +210,36 @@ def __init__( @cached_property def graph(self): - + """Get the external graph from file.""" graph_path = self.graph_path + assert os.path.isfile( graph_path ), f"No graph found at {graph_path}. An external graph needs to be specified in the config file for this runner." + LOG.info("Loading external graph from path %s.", graph_path) return torch.load(graph_path, map_location="cpu", weights_only=False) + def on_device(self, device: str = "cpu"): + """Temporarily reassign the device of the runner""" + + @contextmanager + def _device_manager(runner: ExternalGraphRunner, device: str): # type: ignore + original_device = runner.device + try: + runner.device = device + yield + finally: + runner.device = original_device + + return _device_manager(self, device) + @cached_property def model(self): # load the model from the checkpoint - device = self.device - self.device = "cpu" - model_instance = super().model - state_dict_ckpt = deepcopy(model_instance.state_dict()) - - # rebuild the model with the new graph - model_instance.graph_data = self.graph - model_instance.config = self.checkpoint._metadata._config - model_instance._build_model() - - # reinstate the weights, biases and normalizer from the checkpoint - # reinstating the normalizer is necessary for checkpoints that were created - # using transfer learning, where the statistics as stored in the checkpoint - # do not match the statistics used to build the normalizer in the checkpoint. 
- model_instance = update_state_dict( - model_instance, state_dict_ckpt, keywords=["bias", "weight", "processors.normalizer"] - ) + with self.on_device("cpu"): + model = update_checkpoint(super().model, self.checkpoint._metadata, self.graph) LOG.info("Successfully built model with external graph and reassigned model weights!") - self.device = device - return model_instance.to(self.device) + + return model.to(self.device) diff --git a/src/anemoi/inference/utils/__init__.py b/src/anemoi/inference/utils/__init__.py new file mode 100644 index 00000000..c6149c4e --- /dev/null +++ b/src/anemoi/inference/utils/__init__.py @@ -0,0 +1,8 @@ +# (C) Copyright 2025- Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. diff --git a/src/anemoi/inference/utils/redefine.py b/src/anemoi/inference/utils/redefine.py new file mode 100644 index 00000000..a2556e91 --- /dev/null +++ b/src/anemoi/inference/utils/redefine.py @@ -0,0 +1,280 @@ +# (C) Copyright 2025- Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. 
+ + +import logging +from copy import deepcopy +from pathlib import Path +from typing import TYPE_CHECKING + +LOG = logging.getLogger(__name__) + +if TYPE_CHECKING: + import numpy as np + from torch_geometric.data import HeteroData + + +def update_state_dict( + model, + external_state_dict, + keywords: list[str] | None = None, + ignore_mismatched_layers=False, + ignore_additional_layers=False, +): + """Update the model's state_dict with entries from an external state_dict. Only entries whose keys contain the specified keywords are considered.""" + + LOG.info("Updating model state dictionary.") + + keywords = keywords or [] + + # select relevant part of external_state_dict + reduced_state_dict = {k: v for k, v in external_state_dict.items() if any(kw in k for kw in keywords)} + model_state_dict = model.state_dict() + + # check layers and their shapes + for key in list(reduced_state_dict): + if key not in model_state_dict: + if ignore_additional_layers: + LOG.info("Skipping injection of %s, which is not in the model.", key) + del reduced_state_dict[key] + else: + raise AssertionError(f"Layer {key} not in model. Consider setting 'ignore_additional_layers = True'.") + elif reduced_state_dict[key].shape != model_state_dict[key].shape: + if ignore_mismatched_layers: + LOG.info("Skipping injection of %s due to shape mismatch.", key) + LOG.info("Model shape: %s", model_state_dict[key].shape) + LOG.info("Provided shape: %s", reduced_state_dict[key].shape) + del reduced_state_dict[key] + else: + raise AssertionError(f"Mismatch in shape of {key}. Consider setting 'ignore_mismatched_layers = True'.") + + model.load_state_dict(reduced_state_dict, strict=False) + return model + + +def get_coordinates_from_file(latlon_path: Path) -> tuple["np.ndarray", "np.ndarray"]: + """Get coordinates from a numpy file. + + Parameters + ---------- + latlon_path : Path + Path to coordinate npy, should be of shape (N, 2) with latitudes and longitudes. 
+ + Returns + ------- + tuple[np.ndarray, np.ndarray] + Latitudes and longitudes arrays + """ + import numpy as np + + latlon = np.load(latlon_path) + return latlon[:, 0], latlon[:, 1] + + +COORDINATE = tuple[float, float, float, float, float] + + +def get_coordinates_from_mars_request(coords: COORDINATE) -> tuple["np.ndarray", "np.ndarray"]: + """Get coordinates from MARS request parameters. + + Parameters + ---------- + coords : COORDINATE + Coordinates (North West South East Resolution) + + Returns + ------- + tuple[np.ndarray, np.ndarray] + Latitudes and longitudes arrays + """ + import earthkit.data as ekd + + area = [coords[0], coords[1], coords[2], coords[3]] + + resolution = str(coords[4]) + if resolution.replace(".", "", 1).isdigit(): + resolution = f"{resolution}/{resolution}" + + ds = ekd.from_source( + "mars", + { + "AREA": area, + "GRID": f"{resolution}", + "param": "2t", + "date": -2, + "stream": "oper", + "type": "an", + "levtype": "sfc", + }, + ) + return ds[0].grid_points() # type: ignore + + +def combine_nodes_with_global_grid( + latitudes: "np.ndarray", longitudes: "np.ndarray", global_grid: str +) -> tuple["np.ndarray", "np.ndarray", "np.ndarray", "np.ndarray"]: + """Combine lat/lon nodes with global grid if specified. 
+ + Returns lats, lons, local_mask, global_mask + """ + import numpy as np + from anemoi.datasets.grids import cutout_mask + from anemoi.utils.grids import grids + + global_points = grids(global_grid) + + global_removal_mask = cutout_mask(latitudes, longitudes, global_points["latitudes"], global_points["longitudes"]) + lats = np.concatenate([latitudes, global_points["latitudes"][global_removal_mask]]) + lons = np.concatenate([longitudes, global_points["longitudes"][global_removal_mask]]) + local_mask = np.array([True] * len(latitudes) + [False] * sum(global_removal_mask), dtype=bool) + + return lats, lons, local_mask, global_removal_mask + + +def make_data_graph( + lats: "np.ndarray", + lons: "np.ndarray", + local_mask: "np.ndarray", + global_mask: "np.ndarray", + *, + mask_attr_name: str = "cutout", + attrs, +) -> "HeteroData": + """Make a data graph with the given lat/lon nodes and attributes.""" + import torch + from anemoi.graphs.nodes import LatLonNodes + from torch_geometric.data import HeteroData + + graph = LatLonNodes(lats, lons, name="data").update_graph(HeteroData(), attrs_config=attrs) + graph["data"][mask_attr_name] = torch.from_numpy(local_mask) + return graph + + +def make_graph_from_coordinates( + local_lats: "np.ndarray", local_lons: "np.ndarray", global_resolution: str, metadata: dict, supporting_arrays: dict +) -> tuple[dict, dict, "HeteroData"]: + """Make a graph from coordinates. + + Parameters + ---------- + local_lats : np.ndarray + Local latitude coordinates + local_lons : np.ndarray + Local longitude coordinates + global_resolution : str + Global grid resolution (e.g. 
n320, o96) + metadata : dict + Checkpoint metadata + supporting_arrays : dict + Supporting arrays from checkpoint + + Returns + ------- + tuple[dict, dict, HeteroData] + Updated metadata, supporting arrays, and graph + """ + import numpy as np + + if global_resolution is None: + raise ValueError("Global resolution must be specified when generating graph from coordinates.") + + LOG.info("Coordinates loaded. Number of local nodes: %d", len(local_lats)) + lats, lons, local_mask, global_mask = combine_nodes_with_global_grid(local_lats, local_lons, global_resolution) + + graph_config = deepcopy(metadata["config"]["graph"]) + data_graph = graph_config["nodes"].pop("data") + + from anemoi.graphs.create import GraphCreator + from anemoi.utils.config import DotDict + + creator = GraphCreator(DotDict(graph_config)) + + LOG.info("Updating graph...") + LOG.debug("Using %r", graph_config) + + def nested_get(d, keys, default=None): + for key in keys: + d = d.get(key, {}) + return d or default + + mask_attr_name = nested_get(graph_config, ["nodes", "hidden", "node_builder", "mask_attr_name"], "cutout") + + data_graph = make_data_graph( + lats, lons, local_mask, global_mask, mask_attr_name=mask_attr_name, attrs=data_graph.get("attrs", None) + ) + + LOG.info("Created data graph with %d nodes.", data_graph.num_nodes) + graph = creator.update_graph(data_graph) + + supporting_arrays[f"global/{mask_attr_name}"] = global_mask + supporting_arrays[f"lam_0/{mask_attr_name}"] = np.array([True] * len(local_lats)) + + supporting_arrays["latitudes"] = lats + supporting_arrays["longitudes"] = lons + supporting_arrays["grid_indices"] = np.ones(local_mask.shape, dtype=np.int64) + + return metadata, supporting_arrays, graph + + +def update_checkpoint(model, metadata: dict, graph: "HeteroData"): + """Update checkpoint with new graph and update state dict.""" + from anemoi.utils.config import DotDict + + state_dict_ckpt = deepcopy(model.state_dict()) + + # rebuild the model with the new graph + 
model.graph_data = graph + model.config = DotDict(metadata).config + model._build_model() + + # reinstate the weights, biases and normalizer from the checkpoint + # reinstating the normalizer is necessary for checkpoints that were created + # using transfer learning, where the statistics as stored in the checkpoint + # do not match the statistics used to build the normalizer in the checkpoint. + model_instance = update_state_dict(model, state_dict_ckpt, keywords=["bias", "weight", "processors.normalizer"]) + + return model_instance + + +def load_graph_from_file(graph_path: Path) -> "HeteroData": + """Load graph from file. + + Parameters + ---------- + graph_path : Path + Path to graph file + + Returns + ------- + HeteroData + Loaded graph + """ + import torch + + LOG.info("Loading graph from %s", graph_path) + return torch.load(graph_path, weights_only=False, map_location=torch.device("cpu")) + + +def create_graph_from_config(graph_config_path: Path) -> "HeteroData": + """Create graph from configuration file. 
+ + Parameters + ---------- + graph_config_path : Path + Path to graph configuration file + + Returns + ------- + HeteroData + Created graph + """ + from anemoi.graphs.create import GraphCreator + from torch_geometric.data import HeteroData + + return GraphCreator(graph_config_path).update_graph(HeteroData()) From d95dad4c33e4156edcc3ee4dde2fc83368d341aa Mon Sep 17 00:00:00 2001 From: Harrison Cook Date: Tue, 30 Sep 2025 13:11:29 +0100 Subject: [PATCH 3/9] Apply suggestions from code review Co-authored-by: Gert Mertes <13658335+gmertes@users.noreply.github.com> --- src/anemoi/inference/commands/redefine.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/anemoi/inference/commands/redefine.py b/src/anemoi/inference/commands/redefine.py index 0f159986..36066f63 100644 --- a/src/anemoi/inference/commands/redefine.py +++ b/src/anemoi/inference/commands/redefine.py @@ -69,7 +69,7 @@ def add_arguments(self, command_parser: ArgumentParser) -> None: group = command_parser.add_mutually_exclusive_group(required=True) group.add_argument("-g", "--graph", type=Path, help="Path to graph file to use") - group.add_argument("-y", "--graph_config", type=Path, help="Path to graph config to use") + group.add_argument("-y", "--graph-config", type=Path, help="Path to graph config to use") group.add_argument( "-ll", "--latlon", @@ -80,7 +80,7 @@ def add_arguments(self, command_parser: ArgumentParser) -> None: command_parser.add_argument( "-gr", - "--global_resolution", + "--global-resolution", type=str, help="Global grid resolution required with --coords, (e.g. 
n320, o96).", ) From 71a8f25356cabc580498371adcf4685d4d98a34b Mon Sep 17 00:00:00 2001 From: Harrison Cook Date: Tue, 30 Sep 2025 12:18:43 +0000 Subject: [PATCH 4/9] Rename to redefine_graph --- docs/cli/introduction.rst | 2 +- docs/cli/{redefine.rst => redefine_graph.rst} | 24 +++++++++---------- docs/index.rst | 2 +- .../{redefine.py => redefine_graph.py} | 20 ++++++++-------- .../inference/runners/external_graph.py | 2 +- .../utils/{redefine.py => redefine_graph.py} | 0 6 files changed, 25 insertions(+), 25 deletions(-) rename docs/cli/{redefine.rst => redefine_graph.rst} (81%) rename src/anemoi/inference/commands/{redefine.py => redefine_graph.py} (88%) rename src/anemoi/inference/utils/{redefine.py => redefine_graph.py} (100%) diff --git a/docs/cli/introduction.rst b/docs/cli/introduction.rst index 3c33ca42..be902f98 100644 --- a/docs/cli/introduction.rst +++ b/docs/cli/introduction.rst @@ -19,4 +19,4 @@ The commands are: - :ref:`Validate Command ` - :ref:`Patch Command ` - :ref:`Requests Command ` -- :ref:`Redefine Command ` +- :ref:`redefine_graph Command ` diff --git a/docs/cli/redefine.rst b/docs/cli/redefine_graph.rst similarity index 81% rename from docs/cli/redefine.rst rename to docs/cli/redefine_graph.rst index cfca1305..2688b9fc 100644 --- a/docs/cli/redefine.rst +++ b/docs/cli/redefine_graph.rst @@ -1,7 +1,7 @@ -.. _redefine-command: +.. _redefine_graph-command: -Redefine Command -=============== +Redefine Graph Command +====================== With this command, you can redefine the graph of a checkpoint file. This is useful when you want to change / reconfigure the local-domain of a model, or rebuild with a new graph. @@ -21,7 +21,7 @@ Subcommands allow for a graph to be made from a lat/lon coordinate file, boundin .. code-block:: bash - % anemoi-inference redefine --help + % anemoi-inference redefine_graph --help Redefine the graph of a checkpoint file. 
@@ -49,19 +49,19 @@ Subcommands allow for a graph to be made from a lat/lon coordinate file, boundin Examples ********* -Here are some examples of how to use the `redefine` command: +Here are some examples of how to use the `redefine_graph` command: #. Using a graph file: .. code-block:: bash - anemoi-inference redefine path/to/checkpoint --graph path/to/graph + anemoi-inference redefine_graph path/to/checkpoint --graph path/to/graph #. Using a graph configuration: .. code-block:: bash - anemoi-inference redefine path/to/checkpoint --graph_config path/to/graph_config + anemoi-inference redefine_graph path/to/checkpoint --graph-config path/to/graph_config .. note:: The configuration of the existing graph can be found using: @@ -85,19 +85,19 @@ Here are some examples of how to use the `redefine` command: .. code-block:: bash - anemoi-inference redefine path/to/checkpoint --latlon path/to/latlon.npy + anemoi-inference redefine_graph path/to/checkpoint --latlon path/to/latlon.npy #. Using bounding box coordinates: .. code-block:: bash - anemoi-inference redefine path/to/checkpoint --coords North West South East Resolution + anemoi-inference redefine_graph path/to/checkpoint --coords North West South East Resolution i.e. .. code-block:: bash - anemoi-inference redefine path/to/checkpoint --coords 30.0 -10.0 20.0 0.0 0.1/0.1 --global_resolution n320 + anemoi-inference redefine_graph path/to/checkpoint --coords 30.0 -10.0 20.0 0.0 0.1/0.1 --global-resolution n320 All examples can optionally save the updated graph and checkpoint using the `--save-graph` and `--output` options. @@ -114,7 +114,7 @@ Redefine the checkpoint .. 
code-block:: bash - anemoi-inference redefine path/to/checkpoint --coords 30.0 -10.0 20.0 0.0 0.1/0.1 --global_resolution n320 --save-graph path/to/updated_graph --output path/to/updated_checkpoint + anemoi-inference redefine_graph path/to/checkpoint --coords 30.0 -10.0 20.0 0.0 0.1/0.1 --global-resolution n320 --save-graph path/to/updated_graph --output path/to/updated_checkpoint Create the inference config --------------------------- @@ -157,4 +157,4 @@ Reference :module: anemoi.inference.__main__ :func: create_parser :prog: anemoi-inference - :path: redefine + :path: redefine_graph diff --git a/docs/index.rst b/docs/index.rst index c64a3b8b..353428a4 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -153,7 +153,7 @@ You may also have to install pandoc on MacOS: cli/inspect cli/patch cli/requests - cli/redefine + cli/redefine_graph .. toctree:: :maxdepth: 1 diff --git a/src/anemoi/inference/commands/redefine.py b/src/anemoi/inference/commands/redefine_graph.py similarity index 88% rename from src/anemoi/inference/commands/redefine.py rename to src/anemoi/inference/commands/redefine_graph.py index 36066f63..899ec694 100644 --- a/src/anemoi/inference/commands/redefine.py +++ b/src/anemoi/inference/commands/redefine_graph.py @@ -52,7 +52,7 @@ def format_namespace_as_str(namespace: Namespace) -> str: return " ".join(args) -class RedefineCmd(Command): +class RedefineGraphCmd(Command): """Redefine the graph of a checkpoint file.""" def add_arguments(self, command_parser: ArgumentParser) -> None: @@ -89,19 +89,19 @@ def add_arguments(self, command_parser: ArgumentParser) -> None: command_parser.add_argument("--output", type=str, help="Path to save the updated checkpoint.", default=None) def run(self, args: Namespace) -> None: - """Run the redefine command. + """Run the redefine_graph command. Parameters ---------- args : Namespace The arguments passed to the command. 
""" - from anemoi.inference.utils.redefine import create_graph_from_config - from anemoi.inference.utils.redefine import get_coordinates_from_file - from anemoi.inference.utils.redefine import get_coordinates_from_mars_request - from anemoi.inference.utils.redefine import load_graph_from_file - from anemoi.inference.utils.redefine import make_graph_from_coordinates - from anemoi.inference.utils.redefine import update_checkpoint + from anemoi.inference.utils.redefine_graph import create_graph_from_config + from anemoi.inference.utils.redefine_graph import get_coordinates_from_file + from anemoi.inference.utils.redefine_graph import get_coordinates_from_mars_request + from anemoi.inference.utils.redefine_graph import load_graph_from_file + from anemoi.inference.utils.redefine_graph import make_graph_from_coordinates + from anemoi.inference.utils.redefine_graph import update_checkpoint check_redefine_imports() @@ -116,7 +116,7 @@ def run(self, args: Namespace) -> None: # Add command to history metadata.setdefault("history", []) - metadata["history"].append(f"anemoi-inference redefine {format_namespace_as_str(args)}") + metadata["history"].append(f"anemoi-inference redefine_graph {format_namespace_as_str(args)}") # Create or load the graph if args.graph is not None: @@ -162,4 +162,4 @@ def run(self, args: Namespace) -> None: LOG.info("Updated checkpoint saved to %s", model_path) -command = RedefineCmd +command = RedefineGraphCmd diff --git a/src/anemoi/inference/runners/external_graph.py b/src/anemoi/inference/runners/external_graph.py index 7a457037..3a926fd1 100644 --- a/src/anemoi/inference/runners/external_graph.py +++ b/src/anemoi/inference/runners/external_graph.py @@ -23,7 +23,7 @@ from ..decorators import main_argument from ..runners.default import DefaultRunner -from ..utils.redefine import update_checkpoint +from ..utils.redefine_graph import update_checkpoint from . 
import runner_registry LOG = logging.getLogger(__name__) diff --git a/src/anemoi/inference/utils/redefine.py b/src/anemoi/inference/utils/redefine_graph.py similarity index 100% rename from src/anemoi/inference/utils/redefine.py rename to src/anemoi/inference/utils/redefine_graph.py From a589127ab627f3bf5c0f3176d57c6dadd18c52b6 Mon Sep 17 00:00:00 2001 From: Harrison Cook Date: Wed, 1 Oct 2025 11:45:57 +0000 Subject: [PATCH 5/9] _ jail --- docs/cli/redefine_graph.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/cli/redefine_graph.rst b/docs/cli/redefine_graph.rst index 2688b9fc..97b5d9a3 100644 --- a/docs/cli/redefine_graph.rst +++ b/docs/cli/redefine_graph.rst @@ -32,13 +32,13 @@ Subcommands allow for a graph to be made from a lat/lon coordinate file, boundin -h, --help show this help message and exit -g GRAPH, --graph GRAPH Path to graph file to use - -y GRAPH_CONFIG, --graph_config GRAPH_CONFIG + -y GRAPH_CONFIG, --graph-config GRAPH_CONFIG Path to graph config to use -ll LATLON, --latlon LATLON Path to coordinate npy, should be of shape (N, 2) with latitudes and longitudes. -c COORDS COORDS COORDS COORDS COORDS, --coords COORDS COORDS COORDS COORDS COORDS Coordinates, (North West South East Resolution). - -gr GLOBAL_RESOLUTION, --global_resolution GLOBAL_RESOLUTION + -gr GLOBAL_RESOLUTION, --global-resolution GLOBAL_RESOLUTION Global grid resolution required with --coords, (e.g. n320, o96). --save-graph SAVE_GRAPH Path to save the updated graph. 
From 751fe89d850ba3fc12b1c77c0d04966a822b66ef Mon Sep 17 00:00:00 2001 From: Harrison Cook Date: Wed, 1 Oct 2025 17:04:07 +0000 Subject: [PATCH 6/9] Update data_graph_attributes retrieval --- src/anemoi/inference/utils/redefine_graph.py | 24 ++++++++++++++++---- 1 file changed, 19 insertions(+), 5 deletions(-) diff --git a/src/anemoi/inference/utils/redefine_graph.py b/src/anemoi/inference/utils/redefine_graph.py index a2556e91..692c57ea 100644 --- a/src/anemoi/inference/utils/redefine_graph.py +++ b/src/anemoi/inference/utils/redefine_graph.py @@ -142,17 +142,18 @@ def make_data_graph( lons: "np.ndarray", local_mask: "np.ndarray", global_mask: "np.ndarray", + reference_node_name: str = "data", *, - mask_attr_name: str = "cutout", - attrs, + mask_attr_name: str = "cutout_mask", + attrs: dict | None = None, ) -> "HeteroData": """Make a data graph with the given lat/lon nodes and attributes.""" import torch from anemoi.graphs.nodes import LatLonNodes from torch_geometric.data import HeteroData - graph = LatLonNodes(lats, lons, name="data").update_graph(HeteroData(), attrs_config=attrs) - graph["data"][mask_attr_name] = torch.from_numpy(local_mask) + graph = LatLonNodes(lats, lons, name=reference_node_name).update_graph(HeteroData(), attrs_config=attrs) # type: ignore + graph[reference_node_name][mask_attr_name] = torch.from_numpy(local_mask).unsqueeze(1) return graph @@ -205,8 +206,21 @@ def nested_get(d, keys, default=None): mask_attr_name = nested_get(graph_config, ["nodes", "hidden", "node_builder", "mask_attr_name"], "cutout") + data_graph_attributes = None + # if mask_attr_name in data_graph.get("attributes", {}): + # data_graph_attributes = {mask_attr_name: data_graph["attributes"][mask_attr_name]} + + LOG.info("Found mask attribute name: %r", mask_attr_name) + # LOG.info("Found data graph attributes: %s", data_graph_attributes) + data_graph = make_data_graph( - lats, lons, local_mask, global_mask, mask_attr_name=mask_attr_name, attrs=data_graph.get("attrs", 
None) + lats, + lons, + local_mask, + global_mask, + reference_node_name="data", + mask_attr_name=mask_attr_name, + attrs=data_graph_attributes, ) LOG.info("Created data graph with %d nodes.", data_graph.num_nodes) From 1cdb28ab4b7685158797c9031d00c2d0028ab9a3 Mon Sep 17 00:00:00 2001 From: Harrison Cook Date: Wed, 1 Oct 2025 17:36:45 +0000 Subject: [PATCH 7/9] Add clean to graph construction --- src/anemoi/inference/commands/redefine_graph.py | 2 +- src/anemoi/inference/utils/redefine_graph.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/anemoi/inference/commands/redefine_graph.py b/src/anemoi/inference/commands/redefine_graph.py index 899ec694..2637d3c4 100644 --- a/src/anemoi/inference/commands/redefine_graph.py +++ b/src/anemoi/inference/commands/redefine_graph.py @@ -146,7 +146,7 @@ def run(self, args: Namespace) -> None: # Update checkpoint LOG.info("Updating checkpoint...") - model = torch.load(str(path), weights_only=False, map_location=torch.device("cpu")) + model = torch.load(path, weights_only=False, map_location=torch.device("cpu")) model = update_checkpoint(model, metadata, graph) # Save updated checkpoint diff --git a/src/anemoi/inference/utils/redefine_graph.py b/src/anemoi/inference/utils/redefine_graph.py index 692c57ea..b1de5593 100644 --- a/src/anemoi/inference/utils/redefine_graph.py +++ b/src/anemoi/inference/utils/redefine_graph.py @@ -224,7 +224,7 @@ def nested_get(d, keys, default=None): ) LOG.info("Created data graph with %d nodes.", data_graph.num_nodes) - graph = creator.update_graph(data_graph) + graph = creator.clean(creator.update_graph(data_graph)) supporting_arrays[f"global/{mask_attr_name}"] = global_mask supporting_arrays[f"lam_0/{mask_attr_name}"] = np.array([True] * len(local_lats)) From f7b3e6b3dc923ad59c1e13402606aef00d881eca Mon Sep 17 00:00:00 2001 From: Harrison Cook Date: Thu, 2 Oct 2025 09:31:16 +0000 Subject: [PATCH 8/9] Wacky types --- src/anemoi/inference/utils/redefine_graph.py | 16 
+++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/src/anemoi/inference/utils/redefine_graph.py b/src/anemoi/inference/utils/redefine_graph.py index b1de5593..bc7242de 100644 --- a/src/anemoi/inference/utils/redefine_graph.py +++ b/src/anemoi/inference/utils/redefine_graph.py @@ -12,6 +12,7 @@ from copy import deepcopy from pathlib import Path from typing import TYPE_CHECKING +from typing import NamedTuple LOG = logging.getLogger(__name__) @@ -77,15 +78,20 @@ def get_coordinates_from_file(latlon_path: Path) -> tuple["np.ndarray", "np.ndar return latlon[:, 0], latlon[:, 1] -COORDINATE = tuple[float, float, float, float, float] +class Coordinate(NamedTuple): + north: float + west: float + south: float + east: float + resolution: float -def get_coordinates_from_mars_request(coords: COORDINATE) -> tuple["np.ndarray", "np.ndarray"]: +def get_coordinates_from_mars_request(coords: Coordinate) -> tuple["np.ndarray", "np.ndarray"]: """Get coordinates from MARS request parameters. 
Parameters ---------- - coords : COORDINATE + coords : Coordinate Coordinates (North West South East Resolution) Returns @@ -95,9 +101,9 @@ def get_coordinates_from_mars_request(coords: COORDINATE) -> tuple["np.ndarray", """ import earthkit.data as ekd - area = [coords[0], coords[1], coords[2], coords[3]] + area = [coords.north, coords.west, coords.south, coords.east] - resolution = str(coords[4]) + resolution = str(coords.resolution) if resolution.replace(".", "", 1).isdigit(): resolution = f"{resolution}/{resolution}" From 1845c6a51a21c17ad73a991487870ae5d559a35e Mon Sep 17 00:00:00 2001 From: Harrison Cook Date: Mon, 6 Oct 2025 14:27:05 +0000 Subject: [PATCH 9/9] Fix issue with dict on ExternalGraph runner --- src/anemoi/inference/metadata.py | 10 ++++++++++ src/anemoi/inference/runners/external_graph.py | 3 ++- 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/src/anemoi/inference/metadata.py b/src/anemoi/inference/metadata.py index e0b752ca..7491daac 100644 --- a/src/anemoi/inference/metadata.py +++ b/src/anemoi/inference/metadata.py @@ -95,6 +95,16 @@ def __init__(self, metadata: dict[str, Any], supporting_arrays: dict[str, FloatA self._supporting_arrays = supporting_arrays self._variables_categories = None + def to_dict(self) -> dict[str, Any]: + """Convert the Metadata object to a dictionary. + + Returns + ------- + dict + A copy of the metadata dictionary. 
+ """ + return dict(self._metadata).copy() + @property def _indices(self) -> DotDict: """Return the data indices.""" diff --git a/src/anemoi/inference/runners/external_graph.py b/src/anemoi/inference/runners/external_graph.py index 3a926fd1..fec6dffe 100644 --- a/src/anemoi/inference/runners/external_graph.py +++ b/src/anemoi/inference/runners/external_graph.py @@ -238,7 +238,8 @@ def _device_manager(runner: ExternalGraphRunner, device: str): # type: ignore def model(self): # load the model from the checkpoint with self.on_device("cpu"): - model = update_checkpoint(super().model, self.checkpoint._metadata, self.graph) + metadata = self.checkpoint._metadata.to_dict() + model = update_checkpoint(super().model, metadata, self.graph) LOG.info("Successfully built model with external graph and reassigned model weights!")