diff --git a/src/anemoi/inference/commands/inspect.py b/src/anemoi/inference/commands/inspect.py index 5ef6ea36..3925aa1d 100644 --- a/src/anemoi/inference/commands/inspect.py +++ b/src/anemoi/inference/commands/inspect.py @@ -15,6 +15,8 @@ from collections.abc import Callable from typing import Any +import rich + from ..checkpoint import Checkpoint from . import Command @@ -31,6 +33,9 @@ def add_arguments(self, command_parser: ArgumentParser) -> None: The argument parser to which the arguments will be added. """ command_parser.add_argument("path", help="Path to the checkpoint.") + command_parser.add_argument( + "--origins", action="store_true", help="Print the origins of the variables in the checkpoint." + ) command_parser.add_argument( "--validate", action="store_true", help="Validate the current virtual environment against the checkpoint" ) @@ -56,6 +61,10 @@ def run(self, args: Namespace) -> None: c.validate_environment() return + if args.origins: + self.origins(c, args) + return + def _(f: Callable[[], Any]) -> Any: """Wrapper function to handle exceptions. @@ -90,5 +99,19 @@ def _(f: Callable[[], Any]) -> Any: print(" ", json.dumps(value, indent=4, default=str)) print() + def origins(self, c: Checkpoint, args: Namespace) -> None: + from anemoi.datasets import open_dataset + from anemoi.transform.origins import make_origins + + open_dataset_args, open_dataset_kwargs = c.open_dataset_args_kwargs(use_original_paths=False) + ds = open_dataset(*open_dataset_args, **open_dataset_kwargs) + result = {} + for p in ds.components(): + o = make_origins(p.origins(compressed=True), p.dataset_name) + result.update(o) + + print("Origins:") + rich.print(result) + command = InspectCmd