Skip to content
Merged
Show file tree
Hide file tree
Changes from 73 commits
Commits
Show all changes
74 commits
Select commit Hold shift + click to select a range
fdf767e
Changes to pyproject
OpheliaMiralles Jul 25, 2025
922a80a
WIP
OpheliaMiralles Aug 12, 2025
7fbee60
Set reference date to the first date
OpheliaMiralles Aug 12, 2025
a385c38
Precommit
OpheliaMiralles Aug 12, 2025
53605a7
Replace str with torch.device
OpheliaMiralles Aug 12, 2025
6391edd
Cleanup
OpheliaMiralles Aug 12, 2025
f40e077
Precommit
OpheliaMiralles Aug 12, 2025
cb44806
Reset permissions
OpheliaMiralles Aug 13, 2025
09135af
Add extract pre-processors
OpheliaMiralles Aug 13, 2025
add4ed6
Precommit
OpheliaMiralles Aug 13, 2025
4b0e2cc
WIP
OpheliaMiralles Aug 12, 2025
176a5a7
Reset permissions
OpheliaMiralles Aug 13, 2025
02df7df
Remove debugging changes
OpheliaMiralles Aug 15, 2025
06728e6
Fix cutout for interpolator
OpheliaMiralles Aug 19, 2025
fabf63f
Default date change
OpheliaMiralles Aug 29, 2025
79651de
Allow interpolator to run from files
OpheliaMiralles Sep 12, 2025
7749194
Sharding
OpheliaMiralles Sep 19, 2025
737e536
Re-add mps support
OpheliaMiralles Sep 26, 2025
bc963bd
Rebase on main/remove include_forcings
OpheliaMiralles Sep 26, 2025
f8e9a92
Refactor
OpheliaMiralles Oct 7, 2025
6c11c1b
Refactor
OpheliaMiralles Oct 8, 2025
cafe185
Run precommit
OpheliaMiralles Oct 8, 2025
2a43dad
Merge remote-tracking branch 'origin/main' into fix/interp_files_tent…
OpheliaMiralles Oct 8, 2025
9d22efe
Update src/anemoi/inference/runners/interpolator.py
OpheliaMiralles Oct 8, 2025
d52d44c
Update src/anemoi/inference/runners/interpolator.py
OpheliaMiralles Oct 8, 2025
85fa086
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 8, 2025
2cdcc8c
State_fields outside loop
OpheliaMiralles Oct 8, 2025
faf115c
Update src/anemoi/inference/pre_processors/extract.py
OpheliaMiralles Oct 8, 2025
8147130
Update src/anemoi/inference/runner.py
OpheliaMiralles Oct 8, 2025
bc83d13
Update src/anemoi/inference/runner.py
OpheliaMiralles Oct 8, 2025
8bb7795
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 8, 2025
488f17a
Fix reference date
OpheliaMiralles Oct 8, 2025
7292df4
In case no date in config
OpheliaMiralles Oct 10, 2025
b519661
Merge branch 'main' into fix/interp_files
OpheliaMiralles Oct 13, 2025
3799cf8
Update pyproject.toml
OpheliaMiralles Oct 13, 2025
aff2e9f
Update src/anemoi/inference/config/__init__.py
OpheliaMiralles Oct 13, 2025
3178d7d
Update src/anemoi/inference/pre_processors/forward_transform_filter.py
OpheliaMiralles Oct 13, 2025
6fc3c78
Revert slurm compat changes and propagate kwargs
OpheliaMiralles Oct 13, 2025
20fa81a
Update src/anemoi/inference/pre_processors/extract.py
OpheliaMiralles Oct 13, 2025
b35ba77
Clarify fields type in ekd
OpheliaMiralles Oct 13, 2025
64e584f
Update src/anemoi/inference/runner.py
OpheliaMiralles Oct 13, 2025
8a614e9
Remove unnecessary changes
OpheliaMiralles Oct 13, 2025
71f3dd7
Precommit
OpheliaMiralles Oct 13, 2025
367c540
Merge branch 'main' into fix/interp_files
OpheliaMiralles Oct 13, 2025
abbd3f3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 13, 2025
f2a38e6
Allowing extra forcing is actually needed
OpheliaMiralles Oct 13, 2025
ea4259b
Fix cutout preprocesser thing
OpheliaMiralles Oct 13, 2025
fcfdc2a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 14, 2025
6216e55
Apparently deleted cutout
OpheliaMiralles Oct 14, 2025
cc0bae6
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 14, 2025
10a37b2
Fix cutout
OpheliaMiralles Oct 14, 2025
bd0cfe2
Review
OpheliaMiralles Oct 14, 2025
87b5855
Update src/anemoi/inference/config/__init__.py
OpheliaMiralles Oct 14, 2025
2e650f3
Add ref_date_index to gribfile
OpheliaMiralles Oct 15, 2025
9d83ca8
Update src/anemoi/inference/runners/interpolator.py
OpheliaMiralles Oct 15, 2025
6a1bae8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 15, 2025
d5582ab
Remove ekd_field from state
OpheliaMiralles Oct 15, 2025
a5281e5
Unify pre and post processor extract
OpheliaMiralles Oct 15, 2025
ee50495
Revert config changw
OpheliaMiralles Oct 15, 2025
1ae7819
Update src/anemoi/inference/runner.py
OpheliaMiralles Oct 15, 2025
48bc543
Update src/anemoi/inference/runners/interpolator.py
OpheliaMiralles Oct 15, 2025
e334ec5
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 15, 2025
687cc69
Update src/anemoi/inference/runners/interpolator.py
OpheliaMiralles Oct 15, 2025
318eb37
Update src/anemoi/inference/inputs/gribfile.py
OpheliaMiralles Oct 15, 2025
6bf6c7d
Update src/anemoi/inference/inputs/ekd.py
OpheliaMiralles Oct 15, 2025
84b0e7f
Update src/anemoi/inference/inputs/ekd.py
OpheliaMiralles Oct 15, 2025
84e514f
De-complexify, update doc str
OpheliaMiralles Oct 15, 2025
63b5408
Update docstr
OpheliaMiralles Oct 15, 2025
8b5a18d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 15, 2025
46c51e3
Revert change on forcing
OpheliaMiralles Oct 15, 2025
340344e
Precommit
OpheliaMiralles Oct 15, 2025
04111bc
Merge branch 'main' into fix/interp_files
OpheliaMiralles Oct 15, 2025
c8362c3
Delete uv.lock
HCookie Oct 15, 2025
7706a17
Merge branch 'main' into fix/interp_files
OpheliaMiralles Oct 15, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/anemoi/inference/input.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,13 +113,15 @@ def __repr__(self) -> str:
return f"{self.__class__.__name__}({self.purpose})"

@abstractmethod
def create_input_state(self, *, date: Date | None) -> State:
def create_input_state(self, *, date: Date | None, **kwargs) -> State:
"""Create the input state dictionary.

Parameters
----------
date : Optional[Date]
The date for which to create the input state.
**kwargs : Any
Additional keyword arguments.

Returns
-------
Expand Down
5 changes: 4 additions & 1 deletion src/anemoi/inference/inputs/cds.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,13 +138,15 @@ def __init__(
self.dataset = dataset
self.kwargs = kwargs

def create_input_state(self, *, date: Date | None) -> State:
def create_input_state(self, *, date: Date | None, **kwargs) -> State:
"""Create the input state for the given date.

Parameters
----------
date : Optional[Date]
The date for which to create the input state.
**kwargs : Any
Additional keyword arguments.

Returns
-------
Expand All @@ -162,6 +164,7 @@ def create_input_state(self, *, date: Date | None) -> State:
),
variables=self.variables,
date=date,
**kwargs,
)

def retrieve(self, variables: list[str], dates: list[Date]) -> Any:
Expand Down
30 changes: 23 additions & 7 deletions src/anemoi/inference/inputs/cutout.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,20 +11,31 @@
import logging
from collections import defaultdict
from collections.abc import Iterable
from collections.abc import Mapping

import numpy as np

from anemoi.inference.input import Input
from anemoi.inference.inputs import create_input
from anemoi.inference.inputs import input_registry
from anemoi.inference.types import Date
from anemoi.inference.types import ProcessorConfig
from anemoi.inference.types import State

from . import input_registry

LOG = logging.getLogger(__name__)


def contains_key(obj, key: str) -> bool:
    """Recursively check if `key` exists anywhere in a nested config.

    Only mapping *keys* are matched; values are never compared against
    `key`, so a string equal to `key` stored as a value or list item does
    not count as a hit.

    Parameters
    ----------
    obj : Any
        The object to search. Mappings are searched by key and then by
        value; lists, tuples and sets are searched element by element;
        anything else (including strings) is treated as a leaf.
    key : str
        The key to look for.

    Returns
    -------
    bool
        True if `key` is a key of `obj` or of any mapping nested inside
        it, False otherwise.
    """
    if isinstance(obj, Mapping):
        # Cheap direct hit first, before descending into the values.
        if key in obj:
            return True
        return any(contains_key(v, key) for v in obj.values())
    if isinstance(obj, (list, tuple, set)):
        return any(contains_key(v, key) for v in obj)
    # Scalars (str, int, None, ...) cannot contain keys.
    return False


def _mask_and_combine_states(
existing_state: State,
new_state: State,
Expand Down Expand Up @@ -138,9 +149,12 @@ def __init__(
cfg = cfg.copy()
mask = cfg.pop("mask", f"{src}/cutout_mask")

self.sources[src] = create_input(
context, cfg, variables=variables, pre_processors=pre_processors, purpose=purpose
)
if contains_key(cfg, "pre_processors"):
self.sources[src] = create_input(context, cfg, variables=variables, purpose=purpose)
else:
self.sources[src] = create_input(
context, cfg, variables=variables, purpose=purpose, pre_processors=pre_processors
)

if isinstance(mask, str):
self.masks[src] = self.sources[src].checkpoint.load_supporting_array(mask)
Expand All @@ -151,13 +165,15 @@ def __repr__(self):
"""Return a string representation of the Cutout object."""
return f"Cutout({self.sources})"

def create_input_state(self, *, date: Date | None) -> State:
def create_input_state(self, *, date: Date | None, **kwargs) -> State:
"""Create the input state for the given date.

Parameters
----------
date : Optional[Date]
The date for which to create the input state.
**kwargs : dict
Additional keyword arguments for the source input state creation.

Returns
-------
Expand All @@ -173,7 +189,7 @@ def create_input_state(self, *, date: Date | None) -> State:
combined_state = {}

for source in self.sources.keys():
source_state = self.sources[source].create_input_state(date=date)
source_state = self.sources[source].create_input_state(date=date, **kwargs)
source_mask = self.masks[source]

# Create the mask front padded with zeros
Expand Down
4 changes: 3 additions & 1 deletion src/anemoi/inference/inputs/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,13 +104,15 @@ def __repr__(self) -> str:
"""Return a string representation of the DatasetInput."""
return f"DatasetInput({self.open_dataset_args}, {self.open_dataset_kwargs})"

def create_input_state(self, *, date: Date | None = None) -> State:
def create_input_state(self, *, date: Date | None = None, **kwargs) -> State:
"""Create the input state for the given date.

Parameters
----------
date : Optional[Any]
The date for which to create the input state.
**kwargs : Any
Additional keyword arguments.

Returns
-------
Expand Down
68 changes: 39 additions & 29 deletions src/anemoi/inference/inputs/ekd.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,6 @@ def __init__(
----------
context : Any
The context in which the input is used.
pre_processors : Optional[List[ProcessorConfig]], default None
Pre-processors to apply to the input
namer : Optional[Union[Callable[[Any, Dict[str, Any]], str], Dict[str, Any]]]
Optional namer for the input.
"""
Expand All @@ -128,12 +126,12 @@ def __init__(
assert callable(self._namer), type(self._namer)

def _filter_and_sort(self, data: Any, *, dates: list[Any], title: str) -> Any:
"""Filter and sort the data.
"""Filter and sort the data (earthkit FieldList/FieldArray).

Parameters
----------
data : Any
The data to filter and sort.
The data to filter and sort (FieldList or FieldArray).
dates : List[Any]
The list of dates to select.
title : str
Expand All @@ -142,7 +140,7 @@ def _filter_and_sort(self, data: Any, *, dates: list[Any], title: str) -> Any:
Returns
-------
Any
The filtered and sorted data.
The filtered and sorted data (FieldArray).
"""

def _name(field: Any, _: Any, original_metadata: dict[str, Any]) -> str:
Expand All @@ -166,12 +164,12 @@ def _name(field: Any, _: Any, original_metadata: dict[str, Any]) -> str:
return data

def _find_variable(self, data: Any, name: str, **kwargs: Any) -> Any:
"""Find a variable in the data.
"""Find a variable in the data (earthkit FieldList/FieldArray selection).

Parameters
----------
data : Any
The data to search.
The data to search (FieldList or FieldArray).
name : str
The name of the variable to find.
**kwargs : Any
Expand All @@ -180,7 +178,7 @@ def _find_variable(self, data: Any, name: str, **kwargs: Any) -> Any:
Returns
-------
Any
The selected variable.
The selected variable (FieldArray subset).
"""

def _name(field: Any, _: Any, original_metadata: dict[str, Any]) -> str:
Expand All @@ -198,9 +196,17 @@ def _create_state(
longitudes: FloatArray | None = None,
dtype: DTypeLike = np.float32,
flatten: bool = True,
ref_date_index: int = -1,
) -> State:
"""Create a state from an ekd.FieldList.

Notes
-----
- The `fields` argument must be an earthkit FieldList (or FieldArray-compatible).
- This method intentionally converts state["fields"] from a FieldList to
a Dict[str, np.ndarray] with shape (len(dates), n_points).
- Pre-processors are run while state["fields"] is still a FieldList.

Parameters
----------
fields : ekd.FieldList
Expand All @@ -215,30 +221,17 @@ def _create_state(
The data type.
flatten : bool
Whether to flatten the data.
ref_date_index : int
The index of the reference date in the dates list.

Returns
-------
State
The created input state.
The created input state with state["fields"] as Dict[str, np.ndarray].
"""
fields = self.pre_process(fields)

dates = sorted([to_datetime(d) for d in dates])
date_to_index = {d.isoformat(): i for i, d in enumerate(dates)}

state = dict(date=dates[-1], latitudes=latitudes, longitudes=longitudes, fields=dict())

if len(fields) == 0:
LOG.warning("No input fields found for dates %s (%s)", [d.isoformat() for d in dates], self)
return state

state_fields = state["fields"]

fields = self._filter_and_sort(fields, dates=dates, title="Create input state")

if latitudes is None and longitudes is None:
try:
state["latitudes"], state["longitudes"] = fields[0].grid_points()
latitudes, longitudes = fields[0].grid_points()
LOG.info(
"%s: using `latitudes` and `longitudes` from the first input field",
self.__class__.__name__,
Expand All @@ -251,8 +244,6 @@ def _create_state(
latitudes = self.checkpoint.latitudes
longitudes = self.checkpoint.longitudes
if latitudes is not None and longitudes is not None:
state["latitudes"] = latitudes
state["longitudes"] = longitudes
LOG.info(
"%s: using `latitudes` and `longitudes` found in the checkpoint.",
self.__class__.__name__,
Expand All @@ -264,6 +255,21 @@ def _create_state(
)
raise e

state = dict(date=dates[ref_date_index], latitudes=latitudes, longitudes=longitudes, fields=fields)

# allow hooks to operate on the FieldList before conversion to numpy
state = self.pre_process(state)

fields = state["fields"]
state_fields = {}

if len(fields) == 0:
raise ValueError("No input fields provided")

dates = sorted([to_datetime(d) for d in dates])
date_to_index = {d.isoformat(): i for i, d in enumerate(dates)}
fields = self._filter_and_sort(fields, dates=dates, title="Create input state")

check = defaultdict(set)

n_points = fields[0].to_numpy(dtype=dtype, flatten=flatten).size
Expand Down Expand Up @@ -295,7 +301,7 @@ def _create_state(
raise ValueError(f"Duplicate dates for {name}")

check[name].add(date_idx)

state["fields"] = state_fields
for name, idx in check.items():
if len(idx) != len(dates):
LOG.error("Missing dates for %s: %s", name, idx)
Expand Down Expand Up @@ -326,6 +332,7 @@ def _create_input_state(
longitudes: FloatArray | None = None,
dtype: DTypeLike = np.float32,
flatten: bool = True,
ref_date_index: int = -1,
) -> State:
"""Create the input state.

Expand All @@ -345,6 +352,8 @@ def _create_input_state(
The data type.
flatten : bool
Whether to flatten the data.
ref_date_index : int
The index of the reference date in the dates list.

Returns
-------
Expand All @@ -366,6 +375,7 @@ def _create_input_state(
longitudes=longitudes,
dtype=dtype,
flatten=flatten,
ref_date_index=ref_date_index,
)

def _load_forcings_state(self, fields: ekd.FieldList, *, dates: list[Date], current_state: State) -> State:
Expand Down Expand Up @@ -397,7 +407,7 @@ def _load_forcings_state(self, fields: ekd.FieldList, *, dates: list[Date], curr
def set_private_attributes(self, state: State, fields: ekd.FieldList) -> None: # type: ignore
"""Set private attributes to the state.

Provides geography information if available retrieved from the fields.
Provides geography information if available retrieved from the fields (FieldList/FieldArray).
"""
geography_information = {}

Expand Down
4 changes: 3 additions & 1 deletion src/anemoi/inference/inputs/empty.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,13 +46,15 @@ def __init__(self, context: Context, **kwargs: Any) -> None:
super().__init__(context, **kwargs)
assert self.variables in (None, []), "EmptyInput should not have variables"

def create_input_state(self, *, date: Date | None) -> State:
def create_input_state(self, *, date: Date | None, **kwargs) -> State:
"""Create an empty input state.

Parameters
----------
date : Date or None
The date for the input state.
**kwargs : Any
Additional keyword arguments.

Returns
-------
Expand Down
4 changes: 2 additions & 2 deletions src/anemoi/inference/inputs/fdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,11 +55,11 @@ def __init__(
# NOTE: this is a temporary workaround for #191 thus not documented
self.param_id_map = kwargs.pop("param_id_map", {})

def create_input_state(self, *, date: Date | None) -> State:
def create_input_state(self, *, date: Date | None, **kwargs) -> State:
date = np.datetime64(date).astype(datetime.datetime)
dates = [date + h for h in self.checkpoint.lagged]
ds = self.retrieve(variables=self.variables, dates=dates)
res = self._create_input_state(ds, variables=None, date=date)
res = self._create_input_state(ds, variables=None, date=date, **kwargs)
return res

def load_forcings_state(self, *, dates: list[Date], current_state: State) -> State:
Expand Down
Loading
Loading