|
11 | 11 | import logging |
12 | 12 | from collections import defaultdict |
13 | 13 | from collections.abc import Iterable |
14 | | -from collections.abc import Mapping |
15 | 14 |
|
16 | 15 | import numpy as np |
17 | 16 |
|
18 | 17 | from anemoi.inference.input import Input |
19 | 18 | from anemoi.inference.inputs import create_input |
20 | 19 | from anemoi.inference.inputs import input_registry |
21 | 20 | from anemoi.inference.types import Date |
22 | | -from anemoi.inference.types import ProcessorConfig |
23 | 21 | from anemoi.inference.types import State |
24 | 22 |
|
25 | 23 | LOG = logging.getLogger(__name__) |
26 | 24 |
|
27 | 25 |
|
28 | | -def contains_key(obj, key: str) -> bool: |
29 | | - """Recursively check if `key` exists anywhere in a nested config (dict/DotDict/lists).""" |
30 | | - if isinstance(obj, Mapping): |
31 | | - if key in obj: |
32 | | - return True |
33 | | - return any(contains_key(v, key) for v in obj.values()) |
34 | | - if isinstance(obj, (list, tuple, set)): |
35 | | - return any(contains_key(v, key) for v in obj) |
36 | | - return False |
37 | | - |
38 | | - |
39 | 26 | def _mask_and_combine_states( |
40 | 27 | existing_state: State, |
41 | 28 | new_state: State, |
@@ -112,58 +99,48 @@ def _extract_and_add_private_attributes( |
112 | 99 | class Cutout(Input): |
113 | 100 | """Combines one or more LAMs into a global source using cutouts.""" |
114 | 101 |
|
115 | | - # TODO: Does this need an ordering? |
116 | | - |
117 | 102 | def __init__( |
118 | 103 | self, |
119 | 104 | context, |
120 | | - *, |
121 | | - variables: list[str] | None = None, |
122 | | - pre_processors: list[ProcessorConfig] | None = None, |
123 | | - purpose: str | None = None, |
124 | | - **sources: dict[str, dict], |
| 105 | + *args: dict[str, dict], |
| 106 | + sources: list[dict[str, dict]] | None = None, |
| 107 | + **kwargs, |
125 | 108 | ): |
126 | 109 | """Create a cutout input from a list of sources. |
127 | 110 |
|
128 | 111 | Parameters |
129 | 112 | ---------- |
130 | 113 | context : dict |
131 | 114 | The context runner. |
132 | | - sources : dict of sources |
133 | | - A dictionary of sources to combine. |
134 | | - variables : list[str] | None |
135 | | - List of variables to be handled by the input, or None for a sensible default variables. |
136 | | - pre_processors : Optional[List[ProcessorConfig]], default None |
137 | | - Pre-processors to apply to the input. Note that pre_processors are applied to each sub-input. |
138 | | - purpose : Optional[str] |
139 | | - The purpose of the input. |
| 115 | + sources : list[dict[str, dict]] |
| 116 | + List of sources / inputs to combine, the order defines the order in which they are combined. |
140 | 117 | """ |
| 118 | + if any(x in kwargs for x in ["lam_0", "global"]): # Capture common update issues |
| 119 | + raise KeyError( |
| 120 | + "Cutout input has changed to set the sub-inputs as a list, if using the config, prefix each input with `-` to update." |
| 121 | + ) |
| 122 | + |
| 123 | + super().__init__(context, pre_processors=None, **kwargs) |
141 | 124 |
|
142 | | - super().__init__(context, variables=variables, pre_processors=pre_processors, purpose=purpose) |
| 125 | + if not sources: |
| 126 | + sources = [] |
| 127 | + sources = [*args, *sources] |
143 | 128 |
|
144 | 129 | self.sources: dict[str, Input] = {} |
145 | 130 | self.masks: dict[str, np.ndarray | slice] = {} |
146 | 131 |
|
147 | | - for src, cfg in sources.items(): |
| 132 | + for inp in sources: |
| 133 | + if not isinstance(inp, dict) or len(inp) != 1: |
| 134 | + raise ValueError("Each source in cutout inputs must be a dict with a single key-value pair.") |
| 135 | + src, cfg = next(iter(inp.items())) |
| 136 | + |
148 | 137 | if isinstance(cfg, str): |
149 | 138 | mask = f"{src}/cutout_mask" |
150 | 139 | else: |
151 | 140 | cfg = cfg.copy() |
152 | 141 | mask = cfg.pop("mask", f"{src}/cutout_mask") |
153 | 142 |
|
154 | | - if contains_key(cfg, "pre_processors"): |
155 | | - combined_pre_processors = (pre_processors or []).extend(cfg.get("pre_processors", [])) |
156 | | - self.sources[src] = create_input( |
157 | | - context, cfg, variables=variables, pre_processors=combined_pre_processors, purpose=purpose |
158 | | - ) |
159 | | - else: |
160 | | - self.sources[src] = create_input( |
161 | | - context, |
162 | | - cfg, |
163 | | - variables=variables, |
164 | | - purpose=purpose, |
165 | | - pre_processors=pre_processors, |
166 | | - ) |
| 143 | + self.sources[src] = create_input(context, cfg, variables=self.variables, purpose=self.purpose) |
167 | 144 |
|
168 | 145 | if isinstance(mask, str): |
169 | 146 | self.masks[src] = self.sources[src].checkpoint.load_supporting_array(mask) |
@@ -236,6 +213,8 @@ def create_input_state(self, *, date: Date | None, **kwargs) -> State: |
236 | 213 | _private_attributes["_mask"] = _mask_private_attributes |
237 | 214 |
|
238 | 215 | combined_state.update(_private_attributes) |
| 216 | + combined_state["_input"] = self |
| 217 | + |
239 | 218 | return combined_state |
240 | 219 |
|
241 | 220 | def load_forcings_state(self, *, dates: list[Date], current_state: State) -> State: |
@@ -264,4 +243,6 @@ def load_forcings_state(self, *, dates: list[Date], current_state: State) -> Sta |
264 | 243 | combined_fields = _mask_and_combine_states(combined_fields, source_state, source_mask, source_state.keys()) |
265 | 244 |
|
266 | 245 | current_state["fields"] |= combined_fields |
| 246 | + current_state["_input"] = self |
| 247 | + |
267 | 248 | return current_state |
0 commit comments