Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions src/anemoi/inference/commands/inspect.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
from collections.abc import Callable
from typing import Any

import rich

from ..checkpoint import Checkpoint
from . import Command

Expand All @@ -31,6 +33,9 @@ def add_arguments(self, command_parser: ArgumentParser) -> None:
The argument parser to which the arguments will be added.
"""
command_parser.add_argument("path", help="Path to the checkpoint.")
command_parser.add_argument(
"--origins", action="store_true", help="Print the origins of the variables in the checkpoint."
)
command_parser.add_argument(
"--validate", action="store_true", help="Validate the current virtual environment against the checkpoint"
)
Expand All @@ -56,6 +61,10 @@ def run(self, args: Namespace) -> None:
c.validate_environment()
return

if args.origins:
self.origins(c, args)
return

def _(f: Callable[[], Any]) -> Any:
"""Wrapper function to handle exceptions.

Expand Down Expand Up @@ -90,5 +99,19 @@ def _(f: Callable[[], Any]) -> Any:
print(" ", json.dumps(value, indent=4, default=str))
print()

def origins(self, c: Checkpoint, args: Namespace) -> None:
from anemoi.datasets import open_dataset
from anemoi.transform.origins import make_origins

open_dataset_args, open_dataset_kwargs = c.open_dataset_args_kwargs(use_original_paths=False)
ds = open_dataset(*open_dataset_args, **open_dataset_kwargs)
result = {}
for p in ds.components():
o = make_origins(p.origins(compressed=True), p.dataset_name)
result.update(o)

print("Origins:")
rich.print(result)


command = InspectCmd
Loading