Skip to content

Commit 00c64c0

Browse files
authored
fix(input cutout): CutoutInput improvements and upgrades (#355)
A couple improvements to the cutout input: - Remove `pre_processors` from top level cutout - Add `_input` to cutout states - Alter `CutoutInput` to take a list to provide an explicit ordering Closes #356
1 parent 76b57b9 commit 00c64c0

File tree

7 files changed

+54
-72
lines changed

7 files changed

+54
-72
lines changed

docs/inference/configs/inputs.rst

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ consistently with what is done in ``anemoi-datasets``, see `here
150150

151151
The ``cutout`` input nests the ``private_attributes`` of the sources'
152152
states so may prevent usage of some keys. To restore these, use the
153-
``extract_source`` postprocessor.
153+
``extract_from_state`` postprocessor.
154154

155155
To extract regions from different sources within the ``cutout`` input,
156156
your checkpoint must contain the cutout masks as supporting arrays. You
@@ -175,11 +175,13 @@ running ``anemoi-inference patch <your_checkpoint>``.
175175
176176
input:
177177
cutout:
178-
lam_0:
178+
- lam_0:
179179
mars: {}
180180
mask: null
181181
182-
An example configuration for the ``cutout`` input is shown below:
182+
An example configuration for the ``cutout`` input is shown below. The
183+
sources can be provided as a list of positional arguments, with each
184+
source specified as a mapping from source name to source configuration:
183185

184186
.. literalinclude:: yaml/inputs_11.yaml
185187
:language: yaml
@@ -188,11 +190,11 @@ The different sources are specified exactly as you would for a single
188190
source, as shown in the previous sections.
189191

190192
An easy way to then extract regions from the predicted state is to use
191-
the ``extract_source`` postprocessor, which will subset the state to the
192-
specified source. For example, to extract the ``lam_0`` source from the
193-
state, you can use the following configuration:
193+
the ``extract_from_state`` postprocessor, which will subset the state to
194+
the specified source. For example, to extract the ``lam_0`` source from
195+
the state, you can use the following configuration:
194196

195197
.. code:: yaml
196198
197199
post_processors:
198-
- extract_source: 'lam_0'
200+
- extract_from_state: 'lam_0'
Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
input:
22
cutout:
3-
lam_0:
4-
grib: <path_to_your_grib_file>
5-
global:
6-
grib: <path_to_your_grib_file>
3+
- lam_0:
4+
grib: <path_to_your_grib_file>
5+
- global:
6+
grib: <path_to_your_grib_file>

src/anemoi/inference/inputs/cutout.py

Lines changed: 24 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -11,31 +11,18 @@
1111
import logging
1212
from collections import defaultdict
1313
from collections.abc import Iterable
14-
from collections.abc import Mapping
1514

1615
import numpy as np
1716

1817
from anemoi.inference.input import Input
1918
from anemoi.inference.inputs import create_input
2019
from anemoi.inference.inputs import input_registry
2120
from anemoi.inference.types import Date
22-
from anemoi.inference.types import ProcessorConfig
2321
from anemoi.inference.types import State
2422

2523
LOG = logging.getLogger(__name__)
2624

2725

28-
def contains_key(obj, key: str) -> bool:
29-
"""Recursively check if `key` exists anywhere in a nested config (dict/DotDict/lists)."""
30-
if isinstance(obj, Mapping):
31-
if key in obj:
32-
return True
33-
return any(contains_key(v, key) for v in obj.values())
34-
if isinstance(obj, (list, tuple, set)):
35-
return any(contains_key(v, key) for v in obj)
36-
return False
37-
38-
3926
def _mask_and_combine_states(
4027
existing_state: State,
4128
new_state: State,
@@ -112,58 +99,48 @@ def _extract_and_add_private_attributes(
11299
class Cutout(Input):
113100
"""Combines one or more LAMs into a global source using cutouts."""
114101

115-
# TODO: Does this need an ordering?
116-
117102
def __init__(
118103
self,
119104
context,
120-
*,
121-
variables: list[str] | None = None,
122-
pre_processors: list[ProcessorConfig] | None = None,
123-
purpose: str | None = None,
124-
**sources: dict[str, dict],
105+
*args: dict[str, dict],
106+
sources: list[dict[str, dict]] | None = None,
107+
**kwargs,
125108
):
126109
"""Create a cutout input from a list of sources.
127110
128111
Parameters
129112
----------
130113
context : dict
131114
The context runner.
132-
sources : dict of sources
133-
A dictionary of sources to combine.
134-
variables : list[str] | None
135-
List of variables to be handled by the input, or None for a sensible default variables.
136-
pre_processors : Optional[List[ProcessorConfig]], default None
137-
Pre-processors to apply to the input. Note that pre_processors are applied to each sub-input.
138-
purpose : Optional[str]
139-
The purpose of the input.
115+
sources : list[dict[str, dict]]
116+
List of sources / inputs to combine, the order defines the order in which they are combined.
140117
"""
118+
if any(x in kwargs for x in ["lam_0", "global"]): # Capture common update issues
119+
raise KeyError(
120+
"Cutout input has changed to set the sub-inputs as a list, if using the config, prefix each input with `-` to update."
121+
)
122+
123+
super().__init__(context, pre_processors=None, **kwargs)
141124

142-
super().__init__(context, variables=variables, pre_processors=pre_processors, purpose=purpose)
125+
if not sources:
126+
sources = []
127+
sources = [*args, *sources]
143128

144129
self.sources: dict[str, Input] = {}
145130
self.masks: dict[str, np.ndarray | slice] = {}
146131

147-
for src, cfg in sources.items():
132+
for inp in sources:
133+
if not isinstance(inp, dict) or len(inp) != 1:
134+
raise ValueError("Each source in cutout inputs must be a dict with a single key-value pair.")
135+
src, cfg = next(iter(inp.items()))
136+
148137
if isinstance(cfg, str):
149138
mask = f"{src}/cutout_mask"
150139
else:
151140
cfg = cfg.copy()
152141
mask = cfg.pop("mask", f"{src}/cutout_mask")
153142

154-
if contains_key(cfg, "pre_processors"):
155-
combined_pre_processors = (pre_processors or []).extend(cfg.get("pre_processors", []))
156-
self.sources[src] = create_input(
157-
context, cfg, variables=variables, pre_processors=combined_pre_processors, purpose=purpose
158-
)
159-
else:
160-
self.sources[src] = create_input(
161-
context,
162-
cfg,
163-
variables=variables,
164-
purpose=purpose,
165-
pre_processors=pre_processors,
166-
)
143+
self.sources[src] = create_input(context, cfg, variables=self.variables, purpose=self.purpose)
167144

168145
if isinstance(mask, str):
169146
self.masks[src] = self.sources[src].checkpoint.load_supporting_array(mask)
@@ -236,6 +213,8 @@ def create_input_state(self, *, date: Date | None, **kwargs) -> State:
236213
_private_attributes["_mask"] = _mask_private_attributes
237214

238215
combined_state.update(_private_attributes)
216+
combined_state["_input"] = self
217+
239218
return combined_state
240219

241220
def load_forcings_state(self, *, dates: list[Date], current_state: State) -> State:
@@ -264,4 +243,6 @@ def load_forcings_state(self, *, dates: list[Date], current_state: State) -> Sta
264243
combined_fields = _mask_and_combine_states(combined_fields, source_state, source_mask, source_state.keys())
265244

266245
current_state["fields"] |= combined_fields
246+
current_state["_input"] = self
247+
267248
return current_state

src/anemoi/inference/inputs/dataset.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,8 @@ def create_input_state(self, *, date: Date | None = None, **kwargs) -> State:
153153
if self.context.trace:
154154
self.context.trace.from_input(variable, self)
155155

156+
input_state["_input"] = self
157+
156158
return input_state
157159

158160
def load_forcings_state(self, *, dates: list[Date], current_state: State) -> State:

src/anemoi/inference/runners/default.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -448,7 +448,7 @@ def _combine_states(self, *states: dict[str, Any]) -> dict[str, Any]:
448448
if not np.array_equal(combined[key], value):
449449
raise ValueError(
450450
f"Key '{key}' has different array values in the states: "
451-
f"{combined[key]} and {value}."
451+
f"{combined[key]} ({combined[key].shape}) and {value} ({value.shape})."
452452
f" Input: {first_input} vs {this_input}."
453453
)
454454
continue

tests/integration/meteoswiss-sgm-cosmo/config.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
write_initial_state: false
2424
input:
2525
cutout:
26-
lam_0:
26+
- lam_0:
2727
grib:
2828
path: ${input:0}
2929
namer:
@@ -60,7 +60,7 @@
6060
- lsm
6161
- - shortName: TOT_PREC
6262
- tp
63-
global:
63+
- global:
6464
grib: ${input:1}
6565
output:
6666
grib:

tests/unit/inputs/test_cutout.py

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -67,11 +67,11 @@ def runner() -> None:
6767
def test_cutout_no_mask(runner: Runner):
6868
from anemoi.inference.inputs.cutout import Cutout
6969

70-
cutout_config = {
71-
"lam": {"mask": None, "dummy": {}},
72-
"global": {"mask": None, "dummy": {}},
73-
}
74-
cutout_input = Cutout(runner, variables=["2t"], **cutout_config)
70+
cutout_config = [
71+
{"lam": {"mask": None, "dummy": {}}},
72+
{"global": {"mask": None, "dummy": {}}},
73+
]
74+
cutout_input = Cutout(runner, variables=["2t"], sources=cutout_config)
7575
input_state = cutout_input.create_input_state(date=datetime.datetime.fromisoformat("2020-01-01T00:00"))
7676
number_of_grid_points = runner.checkpoint.number_of_grid_points
7777

@@ -89,11 +89,11 @@ def test_cutout_no_mask(runner: Runner):
8989
def test_cutout_with_slice(runner: Runner):
9090
from anemoi.inference.inputs.cutout import Cutout
9191

92-
cutout_config = {
93-
"lam": {"mask": slice(0, 10), "dummy": {}},
94-
"global": {"mask": slice(10, 25), "dummy": {}},
95-
}
96-
cutout_input = Cutout(runner, variables=["2t"], **cutout_config)
92+
cutout_config = [
93+
{"lam": {"mask": slice(0, 10), "dummy": {}}},
94+
{"global": {"mask": slice(10, 25), "dummy": {}}},
95+
]
96+
cutout_input = Cutout(runner, variables=["2t"], sources=cutout_config)
9797
assert list(cutout_input.sources.keys()) == ["lam", "global"]
9898

9999
input_state = cutout_input.create_input_state(date=datetime.datetime.fromisoformat("2020-01-01T00:00"))
@@ -120,11 +120,8 @@ def test_cutout_with_array(runner: Runner):
120120
global_mask = np.zeros(number_of_grid_points, dtype=bool)
121121
global_mask[10:25] = True
122122

123-
cutout_config = {
124-
"lam": {"mask": lam_mask, "dummy": {}},
125-
"global": {"mask": global_mask, "dummy": {}},
126-
}
127-
cutout_input = Cutout(runner, variables=["2t"], **cutout_config)
123+
cutout_config = [{"lam": {"mask": lam_mask, "dummy": {}}}, {"global": {"mask": global_mask, "dummy": {}}}]
124+
cutout_input = Cutout(runner, variables=["2t"], sources=cutout_config)
128125
input_state = cutout_input.create_input_state(date=datetime.datetime.fromisoformat("2020-01-01T00:00"))
129126

130127
assert "_mask" in input_state

0 commit comments

Comments
 (0)