
Commit a412434

fix: boundary forcings (#388)
## Description

Fixes #387

## What problem does this change solve?

Boundary forcings have been broken for the umpteenth time. We fix them and also introduce a new integration test to ensure they won't be broken so easily again.

## Additional notes

We also introduce support for tests with checks on multiple outputs.

By opening this pull request, I affirm that all authors agree to the [Contributor License Agreement](https://github.com/ecmwf/codex/blob/main/Legal/contributor_license_agreement.md).
1 parent 05a0739 commit a412434

6 files changed: +119 -15 lines changed

src/anemoi/inference/forcings.py

Lines changed: 1 addition & 1 deletion
@@ -281,7 +281,7 @@ def load_forcings_array(self, dates: list[Date], current_state: State) -> FloatArray:
         The loaded forcings as a numpy array.
         """
         data = self._state_to_numpy(
-            self.input.load_forcings_state(variables=self.variables, dates=dates, current_state=current_state),
+            self.input.load_forcings_state(dates=dates, current_state=current_state),
             self.variables,
             dates,
         )
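For context, the variables selected here are still those of the forcings object; only the call into the input changed, since the input created for these forcings already knows what to retrieve. A minimal sketch of the selection step that follows, assuming a state of the form `{"fields": {name: array}}` (the exact state layout is an assumption for illustration, not the library's documented API):

```python
# Hedged sketch: select and order the loaded fields by the forcings' own
# variable list; the {"fields": {...}} layout is assumed for illustration.
import numpy as np

def state_to_numpy(state: dict, variables: list[str]) -> np.ndarray:
    fields = state["fields"]
    return np.stack([np.asarray(fields[name]) for name in variables])

state = {"fields": {"2t": np.zeros(10), "msl": np.ones(10)}}
print(state_to_numpy(state, ["2t", "msl"]).shape)  # (2, 10)
```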

src/anemoi/inference/runner.py

Lines changed: 1 addition & 1 deletion
@@ -332,7 +332,7 @@ def create_dynamic_forcings_inputs(self, input_state: State) -> list[Forcings]:
 
     def create_boundary_forcings_inputs(self, input_state: State) -> list[Forcings]:
 
-        if not self.checkpoint.has_supporting_array("boundary"):
+        if not self.checkpoint.has_supporting_array("output_mask"):
             return []
 
         result = []
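The gate now keys on the `output_mask` supporting array rather than a `boundary` array. As the new check below does, the boundary region can be recovered as the complement of that mask; a small illustrative sketch (the mask values are made up):

```python
# Illustrative only: the boundary points are the complement of the output mask
# stored in the checkpoint's supporting arrays.
import numpy as np

output_mask = np.array([True, True, True, False, False])  # True where the model's own output is kept
boundary_mask = ~output_mask                               # True where boundary forcings are imposed
print(boundary_mask)  # [False False False  True  True]
```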

src/anemoi/inference/runners/default.py

Lines changed: 1 addition & 1 deletion
@@ -287,7 +287,7 @@ def create_boundary_forcings_input(self) -> Input:
         Input
             The created boundary forcings input.
         """
-        variables = self.variables.retrieved_boundary_forcings_variables()
+        variables = self.variables.retrieved_prognostic_variables()
         config = self._input_forcings("boundary_forcings", "-boundary", "forcings", "input") if variables else "empty"
         input = create_input(self, config, variables=variables, purpose="boundary_forcings")
         LOG.info("Boundary forcings input: %s", input)

src/anemoi/inference/testing/checks.py

Lines changed: 83 additions & 0 deletions
@@ -197,6 +197,7 @@ def check_cutout_with_xarray(
         for var in checkpoint.prognostic_variables:
             assert var in ds.data_vars, f"Variable {var} not found in output file."
             ref_idx = ref_ds.name_to_index[var]
+            # loop through time dimension
             for data in ds[var]:
                 assert np.allclose(
                     data.values, ref_ds[0, ref_idx, 0, :]
@@ -206,6 +207,88 @@ def check_cutout_with_xarray(
         raise NotImplementedError("Reference file check is not implemented yet.")
 
 
+@testing_registry.register("check_boundary_forcings_with_xarray")
+def check_boundary_forcings_with_xarray(
+    *,
+    file: Path,
+    checkpoint: "Checkpoint",
+    reference_dataset={},
+    reference_file=None,
+    **kwargs,
+) -> None:
+    LOG.info(f"Checking boundary forcings: {file}")
+
+    # get boundary mask from checkpoint
+    supporting_arrays = checkpoint.supporting_arrays
+    LOG.info(f"Supporting arrays in checkpoint: {supporting_arrays.keys()}")
+    if "output_mask" not in supporting_arrays:
+        LOG.warning("Boundary forcings check is trivial. Consider removing from test config.")
+        return
+    else:
+        boundary_mask = ~supporting_arrays["output_mask"]
+
+    import numpy as np
+    import xarray as xr
+
+    ds = xr.open_dataset(file)
+
+    # check that the boundary mask is compatible with the output
+    n_grid = len(ds["latitude"].values)
+    n_mask = len(boundary_mask)
+    assert (
+        n_grid == n_mask
+    ), f"Number of grid points ({n_grid}) does not match size of output mask in checkpoint ({n_mask})."
+    dates = ds["time"].astype("datetime64[s]").values
+    freq = dates[1] - dates[0]
+    if reference_dataset:
+        from anemoi.datasets import open_dataset
+
+        ref_ds = open_dataset(**reference_dataset, start=dates[0])
+        ref_freq = np.timedelta64(ref_ds.frequency)
+        ref_dates = ref_ds.dates.astype("datetime64[s]")
+        step = freq // ref_freq
+
+        # make sure all needed dates are present and that we step through them consistently
+        assert set(dates[:-1]).issubset(ref_dates), f"Reference dataset is missing dates {set(dates) - set(ref_dates)}"
+        assert step == freq / ref_freq, f"Frequency mismatch between output ({freq}) and reference ({ref_freq})"
+        LOG.info(f"Inference output has a timestep that is {step} times that of the reference dataset.")
+
+        if ref_ds.shape[2] != 1:
+            raise NotImplementedError("Support for ensembles is not implemented yet.")
+        ref_values = ref_ds[:, :, 0, :]
+
+        # make sure we have the reference dataset on the output grid
+        lats = ref_ds.latitudes
+        lons = ref_ds.longitudes
+        if "grid_indices" in supporting_arrays:
+            LOG.info("Using grid indices for boundary forcings check.")
+            grid_indices = supporting_arrays["grid_indices"]
+            ref_values = ref_values[:, :, grid_indices]
+            lats = lats[grid_indices]
+            lons = lons[grid_indices]
+        assert np.allclose(lats, ds.latitude.values), "Latitudes don't match between output and reference."
+        assert np.allclose(lons, ds.longitude.values), "Longitudes don't match between output and reference."
+
+        # check boundary forcings:
+        # each inference step takes us from input i to output i;
+        # boundary forcings are applied to output i in the creation of input i+1;
+        # the current mock inference model simply passes the input through, so output i+1 == input i+1;
+        # the boundary forcing applied to output i (reference dataset at i) thus appears directly in output i+1
+        for var in checkpoint.prognostic_variables:
+            assert var in ds.data_vars, f"Variable {var} not found in output file."
+            ref_idx = ref_ds.name_to_index[var]
+            for i in range(len(dates) - 1):
+                out = ds[var].isel(time=i + 1).values
+                forcing = ref_values[i * step, ref_idx]
+                assert np.allclose(
+                    out[boundary_mask], forcing[boundary_mask]
+                ), f"Boundary forcing for variable {var} does not match reference data at {ref_dates[i*step]}."
+
+    elif reference_file:
+        # check against a reference file, implement when needed
+        raise NotImplementedError("Reference file check is not implemented yet.")
+
+
 @testing_registry.register("check_file_exist")
 def check_file_exist(*, file: Path, **kwargs) -> None:
     LOG.info(f"Checking file exists: {file}")
Lines changed: 21 additions & 7 deletions
@@ -1,14 +1,26 @@
 - name: dataset-in-netcdf-out
   input: null # anemoi-datasets will download the zarr.zip at runtime
-  output: output.nc
+  output:
+    - lam_output.nc
+    - full_output.nc
   checks:
     - check_cutout_with_xarray:
+        file: ${output:0}
         mask: 'lam_0'
         reference_date: 2020-01-02T00:00:00
         reference_dataset:
           dataset: ${s3:}/aifs-ea-an-oper-0001-mars-o48-2020-6h.zarr.zip
           area: [70, -55, 10, 70]
+    - check_boundary_forcings_with_xarray:
+        file: ${output:1}
+        reference_dataset:
+          dataset:
+            cutout:
+              - dataset: ${s3:}/aifs-ea-an-oper-0001-mars-o48-2020-6h.zarr.zip
+                area: [70, -55, 10, 70]
+              - dataset: ${s3:}/aifs-ea-an-oper-0001-mars-o32-2020-6h.zarr.zip
     - check_with_xarray:
+        file: ${output:0}
         check_accum: tp
         check_nans: true
   inference_config:
@@ -22,11 +34,13 @@
             area: [70, -55, 10, 70]
           - dataset: ${s3:}/aifs-ea-an-oper-0001-mars-o32-2020-6h.zarr.zip
     output:
-      netcdf:
-        post_processors:
-          - extract_mask:
-              mask: lam_0/cutout_mask
-              as_slice: true
-        path: ${output:}
+      tee:
+        - netcdf:
+            post_processors:
+              - extract_mask:
+                  mask: lam_0/cutout_mask
+                  as_slice: true
+            path: ${output:0}
+        - netcdf: ${output:1}
     post_processors:
       - accumulate_from_start_of_forecast
tests/integration/test_integration.py

Lines changed: 12 additions & 5 deletions
@@ -69,17 +69,22 @@ def _markers(config: DictConfig):
 
 class Setup(NamedTuple):
     config: OmegaConf
-    output: Path
+    output: list[Path]
 
 
 @pytest.fixture(params=MODEL_CONFIGS)
 def test_setup(request, get_test_data: GetTestData, tmp_path: Path) -> Setup:
     model, config = request.param
     input = config.input
-    output = tmp_path / config.output
+    output = config.output
     inference_config = config.inference_config
     s3_path = f"anemoi-integration-tests/inference/{model}"
 
+    # set output path(s)
+    if not isinstance(output, (list, ListConfig)):
+        output = [output]
+    output = [tmp_path / file_name for file_name in output]
+
     # download input file(s)
     if not input:
         input = []
@@ -116,7 +121,7 @@ def load_array(name):
 
     # substitute inference config with real paths
     OmegaConf.register_new_resolver("input", lambda i=0: str(input_data[i]), replace=True)
-    OmegaConf.register_new_resolver("output", lambda: str(output), replace=True)
+    OmegaConf.register_new_resolver("output", lambda i=0: str(output[i]), replace=True)
     OmegaConf.register_new_resolver("checkpoint", lambda: str(checkpoint_path), replace=True)
     OmegaConf.register_new_resolver("s3", lambda: str(f"{TEST_DATA_URL}{s3_path}"), replace=True)
     OmegaConf.register_new_resolver("sys.prefix", lambda: sys.prefix, replace=True)
@@ -144,7 +149,8 @@ def test_integration(test_setup: Setup, tmp_path: Path) -> None:
     runner = create_runner(config)
     runner.execute()
 
-    assert (test_setup.output).exists(), "Output file was not created."
+    for file in test_setup.output:
+        assert file.exists(), f"Output file was not created: {file}."
 
     checkpoint_output_variables = _typed_variables_output(runner._checkpoint)
     LOG.info(f"Checkpoint output variables: {checkpoint_output_variables}")
@@ -159,9 +165,10 @@ def test_integration(test_setup: Setup, tmp_path: Path) -> None:
            VariableFromMarsVocabulary(var, {"param": var}) for var in expected_variables_config
        ] or checkpoint_output_variables
 
+        file = kwargs.pop("file", test_setup.output[0])
        testing_registry.create(
            check,
-            file=test_setup.output,
+            file=file,
            expected_variables=expected_variables,
            checkpoint=runner._checkpoint,
            **kwargs,