|
| 1 | +# (C) Copyright 2025 Anemoi contributors. |
| 2 | +# |
| 3 | +# This software is licensed under the terms of the Apache Licence Version 2.0 |
| 4 | +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. |
| 5 | +# |
| 6 | +# In applying this licence, ECMWF does not waive the privileges and immunities |
| 7 | +# granted to it by virtue of its status as an intergovernmental organisation |
| 8 | +# nor does it submit to any jurisdiction. |
| 9 | + |
| 10 | +import pytest |
| 11 | + |
| 12 | +from anemoi.inference.metadata import Metadata |
| 13 | +from anemoi.inference.testing.mock_checkpoint import mock_load_metadata |
| 14 | + |
| 15 | + |
| 16 | +@pytest.mark.parametrize( |
| 17 | + "initial, patch, expected", |
| 18 | + [ |
| 19 | + ( |
| 20 | + {"config": {"dataloader": {"dataset": "abc"}}}, |
| 21 | + {"config": {"dataloader": {"something": {"else": "123"}}}}, |
| 22 | + {"dataset": "abc", "something": {"else": "123"}}, |
| 23 | + ), |
| 24 | + ( |
| 25 | + {"config": {"dataloader": [{"dataset": "abc"}, {"dataset": "xyz"}]}}, |
| 26 | + {"config": {"dataloader": {"cutout": [{"dataset": "123"}, {"dataset": "456"}]}}}, |
| 27 | + {"cutout": [{"dataset": "123"}, {"dataset": "456"}]}, |
| 28 | + ), |
| 29 | + ( |
| 30 | + {"config": {"dataloader": "abc"}}, |
| 31 | + {"config": {"dataloader": "xyz"}}, |
| 32 | + "xyz", |
| 33 | + ), |
| 34 | + ], |
| 35 | +) |
| 36 | +def test_patch(initial, patch, expected): |
| 37 | + metadata = Metadata(initial) |
| 38 | + assert metadata._metadata == initial |
| 39 | + |
| 40 | + metadata.patch(patch) |
| 41 | + assert metadata._metadata["config"]["dataloader"] == expected |
| 42 | + |
| 43 | + |
| 44 | +def test_constant_fields_patch(): |
| 45 | + model_metadata = mock_load_metadata("unit/checkpoints/atmos.json", supporting_arrays=False) |
| 46 | + metadata = Metadata(model_metadata) |
| 47 | + |
| 48 | + fields = ["z", "sdor", "slor", "lsm"] |
| 49 | + metadata.patch({"dataset": {"constant_fields": fields}}) |
| 50 | + assert metadata._metadata["dataset"]["constant_fields"] == fields |
| 51 | + |
| 52 | + # check that the rest of the metadata is still the same after patching |
| 53 | + metadata._metadata["dataset"].pop("constant_fields") |
| 54 | + assert model_metadata == metadata._metadata |
0 commit comments