Skip to content

Commit ff222e2

Browse files
authored
fix(metadata): Patching when receiving entry is not a dict (#338)
## Description Fixing an edge case in the `external_graph` runner when used with a LAM model. When using this runner with the `graph_dataset` option, it will patch the `dataloader.dataset` entry: https://github.com/ecmwf/anemoi-inference/blob/196aa55a57204c4bcb823e7b6546d79505f5eaee/src/anemoi/inference/runners/external_graph.py#L189-L199 But if the existing `dataloader.dataset` entry in the metadata happens to be a list of datasets, the patching crashes. When patching an existing entry, the patch function did not check if the receiving entry is also a dictionary. In the case of a list on the receiving side, it would try to index the list with the incoming dictionary keys. This PR fixes that. Also adds a test for this function. ***As a contributor to the Anemoi framework, please ensure that your changes include unit tests, updates to any affected dependencies and documentation, and have been tested in a parallel setting (i.e., with multiple GPUs). As a reviewer, you are also responsible for verifying these aspects and requesting changes if they are not adequately addressed. For guidelines about those please refer to https://anemoi.readthedocs.io/en/latest/*** By opening this pull request, I affirm that all authors agree to the [Contributor License Agreement.](https://github.com/ecmwf/codex/blob/main/Legal/contributor_license_agreement.md)
1 parent c119c59 commit ff222e2

File tree

2 files changed

+55
-1
lines changed

2 files changed

+55
-1
lines changed

src/anemoi/inference/metadata.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1129,7 +1129,7 @@ def patch(self, patch: dict) -> None:
11291129
def merge(main: dict[str, Any], patch: dict[str, Any]) -> None:
11301130

11311131
for k, v in patch.items():
1132-
if isinstance(v, dict):
1132+
if isinstance(v, dict) and isinstance(main.get(k, {}), dict):
11331133
if k not in main:
11341134
main[k] = {}
11351135
merge(main[k], v)

tests/unit/test_metadata.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
# (C) Copyright 2025 Anemoi contributors.
2+
#
3+
# This software is licensed under the terms of the Apache Licence Version 2.0
4+
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
5+
#
6+
# In applying this licence, ECMWF does not waive the privileges and immunities
7+
# granted to it by virtue of its status as an intergovernmental organisation
8+
# nor does it submit to any jurisdiction.
9+
10+
import pytest
11+
12+
from anemoi.inference.metadata import Metadata
13+
from anemoi.inference.testing.mock_checkpoint import mock_load_metadata
14+
15+
16+
@pytest.mark.parametrize(
17+
"initial, patch, expected",
18+
[
19+
(
20+
{"config": {"dataloader": {"dataset": "abc"}}},
21+
{"config": {"dataloader": {"something": {"else": "123"}}}},
22+
{"dataset": "abc", "something": {"else": "123"}},
23+
),
24+
(
25+
{"config": {"dataloader": [{"dataset": "abc"}, {"dataset": "xyz"}]}},
26+
{"config": {"dataloader": {"cutout": [{"dataset": "123"}, {"dataset": "456"}]}}},
27+
{"cutout": [{"dataset": "123"}, {"dataset": "456"}]},
28+
),
29+
(
30+
{"config": {"dataloader": "abc"}},
31+
{"config": {"dataloader": "xyz"}},
32+
"xyz",
33+
),
34+
],
35+
)
36+
def test_patch(initial, patch, expected):
37+
metadata = Metadata(initial)
38+
assert metadata._metadata == initial
39+
40+
metadata.patch(patch)
41+
assert metadata._metadata["config"]["dataloader"] == expected
42+
43+
44+
def test_constant_fields_patch():
45+
model_metadata = mock_load_metadata("unit/checkpoints/atmos.json", supporting_arrays=False)
46+
metadata = Metadata(model_metadata)
47+
48+
fields = ["z", "sdor", "slor", "lsm"]
49+
metadata.patch({"dataset": {"constant_fields": fields}})
50+
assert metadata._metadata["dataset"]["constant_fields"] == fields
51+
52+
# check that the rest of the metadata is still the same after patching
53+
metadata._metadata["dataset"].pop("constant_fields")
54+
assert model_metadata == metadata._metadata

0 commit comments

Comments
 (0)