diff --git a/src/anemoi/datasets/create/__init__.py b/src/anemoi/datasets/create/__init__.py index 5600cb254..2e593532e 100644 --- a/src/anemoi/datasets/create/__init__.py +++ b/src/anemoi/datasets/create/__init__.py @@ -256,8 +256,7 @@ def check_name( resolution: str, dates: list[datetime.datetime], frequency: datetime.timedelta, - raise_exception: bool = True, - is_test: bool = False, + raise_exception: bool = False, ) -> None: """Check the name of the dataset. @@ -271,15 +270,13 @@ def check_name( The frequency of the dataset. raise_exception : bool, optional Whether to raise an exception if the name is invalid. - is_test : bool, optional - Whether this is a test. """ basename, _ = os.path.splitext(os.path.basename(self.path)) try: DatasetName(basename, resolution, dates[0], dates[-1], frequency).raise_if_not_valid() except Exception as e: - if raise_exception and not is_test: - raise e + if raise_exception: + raise else: LOG.warning(f"Dataset name error: {e}") @@ -577,7 +574,6 @@ def __init__( use_threads: bool = False, statistics_temp_dir: str | None = None, progress: Any = None, - test: bool = False, cache: str | None = None, **kwargs: Any, ): @@ -599,8 +595,6 @@ def __init__( The directory for temporary statistics. progress : Any, optional The progress indicator. - test : bool, optional - Whether this is a test. cache : Optional[str], optional The cache directory. """ @@ -613,9 +607,8 @@ def __init__( self.use_threads = use_threads self.statistics_temp_dir = statistics_temp_dir self.progress = progress - self.test = test - self.main_config = loader_config(config, is_test=test) + self.main_config = loader_config(config) # self.registry.delete() ?? self.tmp_statistics.delete() @@ -748,7 +741,6 @@ def _run(self) -> int: self.dataset.check_name( raise_exception=self.check_name, - is_test=self.test, resolution=resolution, dates=dates, frequency=frequency, diff --git a/src/anemoi/datasets/create/config.py b/src/anemoi/datasets/create/config.py index bbeaee83a..2b55d673c 100644 --- a/src/anemoi/datasets/create/config.py +++ b/src/anemoi/datasets/create/config.py @@ -18,8 +18,6 @@ from anemoi.utils.config import load_any_dict_format from earthkit.data.core.order import normalize_order_by -from anemoi.datasets.dates.groups import Groups - LOG = logging.getLogger(__name__) @@ -340,63 +338,13 @@ def _prepare_serialisation(o: Any) -> Any: return str(o) -def set_to_test_mode(cfg: dict) -> None: - """Modifies the configuration to run in test mode. - - Parameters - ---------- - cfg : dict - The configuration dictionary. - """ - NUMBER_OF_DATES = 4 - - LOG.warning(f"Running in test mode. Changing the list of dates to use only {NUMBER_OF_DATES}.") - groups = Groups(**LoadersConfig(cfg).dates) - - dates = groups.provider.values - cfg["dates"] = dict( - start=dates[0], - end=dates[NUMBER_OF_DATES - 1], - frequency=groups.provider.frequency, - group_by=NUMBER_OF_DATES, - ) - - num_ensembles = count_ensembles(cfg) - - def set_element_to_test(obj): - if isinstance(obj, (list, tuple)): - for v in obj: - set_element_to_test(v) - return - if isinstance(obj, (dict, DotDict)): - if "grid" in obj and num_ensembles > 1: - previous = obj["grid"] - obj["grid"] = "20./20." - LOG.warning(f"Running in test mode. Setting grid to {obj['grid']} instead of {previous}") - if "number" in obj and num_ensembles > 1: - if isinstance(obj["number"], (list, tuple)): - previous = obj["number"] - obj["number"] = previous[0:3] - LOG.warning(f"Running in test mode. Setting number to {obj['number']} instead of {previous}") - for k, v in obj.items(): - set_element_to_test(v) - if "constants" in obj: - constants = obj["constants"] - if "param" in constants and isinstance(constants["param"], list): - constants["param"] = ["cos_latitude"] - - set_element_to_test(cfg) - - -def loader_config(config: dict, is_test: bool = False) -> LoadersConfig: +def loader_config(config: dict) -> LoadersConfig: """Loads and validates the configuration for dataset loaders. Parameters ---------- config : dict The configuration dictionary. - is_test : bool, optional - Whether to run in test mode. Defaults to False. Returns ------- @@ -404,8 +352,6 @@ def loader_config(config: dict, is_test: bool = False) -> LoadersConfig: The validated configuration object. """ config = Config(config) - if is_test: - set_to_test_mode(config) obj = LoadersConfig(config) # yaml round trip to check that serialisation works as expected @@ -426,6 +372,9 @@ def loader_config(config: dict, is_test: bool = False) -> LoadersConfig: LOG.info(f"Setting env variable {k}={v}") os.environ[k] = str(v) + # Used by pytest only + # copy.pop('checks', None) + return copy diff --git a/src/anemoi/datasets/create/sources/xarray_support/__init__.py b/src/anemoi/datasets/create/sources/xarray_support/__init__.py index 8e3cebc08..48a507bb2 100644 --- a/src/anemoi/datasets/create/sources/xarray_support/__init__.py +++ b/src/anemoi/datasets/create/sources/xarray_support/__init__.py @@ -97,6 +97,7 @@ def load_one( if isinstance(dataset, xr.Dataset): data = dataset else: + print(f"Opening dataset {dataset} with options {options}") data = xr.open_dataset(dataset, **options) fs = XarrayFieldList.from_xarray(data, flavour=flavour, patch=patch) diff --git a/tests/create/accumulate.yaml b/tests/create/accumulate.yaml new file mode 100644 index 000000000..156dbaaac --- /dev/null +++ b/tests/create/accumulate.yaml @@ -0,0 +1,30 @@ +dates: + start: 2021-01-10 18:00:00 + #start: 2021-01-10 19:00:00 + end: 2021-01-12 12:00:00 + frequency: 6h + +input: + accumulations2: + expver: "0001" + class: ea + + stream: oper + #stream: enda + + grid: 20./20. + #grid: o96 + levtype: sfc + param: [ tp, cp ] + # accumulation_period: [0, 6] + accumulation_period: 24 + +checks: +- values: + variable: tp + minimum: 0.0 + maximum: 0.06885338 +- values: + variable: cp + minimum: 0.0 + maximum: 0.038482666 diff --git a/tests/create/accumulation.yaml b/tests/create/accumulation.yaml index 2483a111d..718586db0 100644 --- a/tests/create/accumulation.yaml +++ b/tests/create/accumulation.yaml @@ -5,16 +5,26 @@ dates: frequency: 6h input: - accumulations: - expver: "0001" - class: ea + accumulations: + expver: "0001" + class: ea - stream: oper - #stream: enda + stream: oper + #stream: enda - grid: 20./20. - #grid: o96 - levtype: sfc - param: [ tp , cp] - # accumulation_period: [0, 6] - accumulation_period: 24 + grid: 20./20. + #grid: o96 + levtype: sfc + param: [ tp, cp ] + # accumulation_period: [0, 6] + accumulation_period: 24 + +checks: +- values: + variable: tp + minimum: 0.0 + maximum: 0.06885338 +- values: + variable: cp + minimum: 0.0 + maximum: 0.038482666 diff --git a/tests/create/concat.yaml b/tests/create/concat.yaml index 4df851c23..233c4144b 100644 --- a/tests/create/concat.yaml +++ b/tests/create/concat.yaml @@ -11,22 +11,25 @@ common: levtype: sfc stream: oper type: an - param: [2t] + param: [ 2t ] input: concat: - - dates: - start: 2020-12-30 00:00:00 - end: 2021-01-01 12:00:00 - frequency: 12h - mars: - <<: *mars_request - - dates: - start: 2021-01-02 00:00:00 - end: 2021-01-03 12:00:00 - frequency: 12h - mars: - <<: *mars_request + - dates: + start: 2020-12-30 00:00:00 + end: 2021-01-01 12:00:00 + frequency: 12h + mars: + <<: *mars_request + - dates: + start: 2021-01-02 00:00:00 + end: 2021-01-03 12:00:00 + frequency: 12h + mars: + <<: *mars_request statistics: end: 2021 + +checks: + - compare_to_reference: {} diff --git a/tests/create/join.yaml b/tests/create/join.yaml index 6e96a271e..058722c22 100644 --- a/tests/create/join.yaml +++ b/tests/create/join.yaml @@ -41,3 +41,6 @@ naming_scheme: "{param}_{levelist}{level_units}_{accumultion_period}" statistics: end: 2021 + +checks: + - compare_to_reference: {} diff --git a/tests/create/missing.yaml b/tests/create/missing.yaml index 95e26d4c0..e8428bb08 100644 --- a/tests/create/missing.yaml +++ b/tests/create/missing.yaml @@ -27,3 +27,6 @@ input: statistics: end: 2021-01-02 + +checks: + - compare_to_reference: {} diff --git a/tests/create/nan.yaml b/tests/create/nan.yaml index f4f135143..b2560ba18 100644 --- a/tests/create/nan.yaml +++ b/tests/create/nan.yaml @@ -16,3 +16,6 @@ input: statistics: end: 2020 allow_nans: [sst] + +checks: + - compare_to_reference: {} diff --git a/tests/create/pipe.yaml b/tests/create/pipe.yaml index ee10bc1ba..00d5b6738 100644 --- a/tests/create/pipe.yaml +++ b/tests/create/pipe.yaml @@ -40,3 +40,6 @@ input: statistics: end: 2021 + +checks: + - compare_to_reference: {} diff --git a/tests/create/recentre.yaml b/tests/create/recentre.yaml index 1203d51fb..9d766c8b2 100644 --- a/tests/create/recentre.yaml +++ b/tests/create/recentre.yaml @@ -84,3 +84,8 @@ input: - sin_julian_day - sin_local_time - insolation + + +slow_test: true +checks: + - none: {} diff --git a/tests/create/regrid.yaml b/tests/create/regrid.yaml index 00528648a..7b3612989 100644 --- a/tests/create/regrid.yaml +++ b/tests/create/regrid.yaml @@ -28,3 +28,6 @@ input: method: nearest in_grid: o32 out_grid: o48 + +checks: + - none: {} diff --git a/tests/create/repeated-dates.yaml b/tests/create/repeated-dates.yaml new file mode 100644 index 000000000..4a97929b6 --- /dev/null +++ b/tests/create/repeated-dates.yaml @@ -0,0 +1,31 @@ +dates: + start: 2020-12-30 00:00:00 + end: 2021-01-03 12:00:00 + frequency: 12h + +input: + repeated_dates: + mode: constant + source: + mars: + expver: "0001" + class: ea + grid: 20./20. + levtype: sfc + stream: oper + type: an + param: [ 2t ] + date: 1990-01-01 + +checks: +- dates: + - 2020-12-30 00:00:00 + - 2020-12-30 12:00:00 + - 2020-12-31 00:00:00 + - 2020-12-31 12:00:00 + - 2021-01-01 00:00:00 + - 2021-01-01 12:00:00 + - 2021-01-02 00:00:00 + - 2021-01-02 12:00:00 + - 2021-01-03 00:00:00 + - 2021-01-03 12:00:00 diff --git a/tests/create/s3-winds.yaml b/tests/create/s3-winds.yaml new file mode 100644 index 000000000..9227e594d --- /dev/null +++ b/tests/create/s3-winds.yaml @@ -0,0 +1,37 @@ +name: test_s3_winds_dataset +description: Test creation of a dataset from S3-hosted wind data +attribution: DMI +license: CC-BY-4.0 + +dates: + start: 2020-01-01T00:00:00 + end: 2020-01-01T12:00:00 + frequency: 12h + +input: + join: + - pipe: + - xarray-zarr: + url: s3://dmi-danra-05/single_levels.zarr + options: + storage_options: + anon: true + param: + - u10m + - v10m + - uv-to-ddff: + u_component: u10m + v_component: v10m + wind_speed: ws + wind_direction: wdir + convention: meteo + radians: false +build: + variable_naming: "{param}_{pressure}_{altitude}" + +slow_test: true +checks: +- variables: [ wdir, ws ] +- dates: + - 2020-01-01T00:00:00 + - 2020-01-01T12:00:00 diff --git a/tests/create/test_create.py b/tests/create/test_create.py index dd3f37864..59d09b2a2 100755 --- a/tests/create/test_create.py +++ b/tests/create/test_create.py @@ -14,24 +14,31 @@ from unittest.mock import patch import pytest +import yaml from anemoi.transform.filter import Filter from anemoi.transform.filters import filter_registry from anemoi.utils.testing import GetTestArchive from anemoi.utils.testing import GetTestData from anemoi.utils.testing import skip_if_offline -from .utils.compare import Comparer +from .utils.checks import check_dataset from .utils.create import create_dataset from .utils.mock_sources import LoadSource HERE = os.path.dirname(__file__) # find_yamls -NAMES = sorted([os.path.basename(path).split(".")[0] for path in glob.glob(os.path.join(HERE, "*.yaml"))]) -SKIP = ["recentre"] -SKIP += ["accumulation"] # test not in s3 yet -SKIP += ["regrid"] -NAMES = [name for name in NAMES if name not in SKIP] -assert NAMES, "No yaml files found in " + HERE + +NAMES = [] +for path in glob.glob(os.path.join(HERE, "*.yaml")): + name, _ = os.path.splitext(os.path.basename(path)) + with open(path) as f: + conf = yaml.safe_load(f) + if conf.get("skip_test", False): + continue + if conf.get("slow_test", False): + NAMES.append(pytest.param(name, marks=pytest.mark.slow)) + continue + NAMES.append(name) # Used by pipe.yaml @@ -73,14 +80,10 @@ def test_run(name: str, get_test_archive: GetTestArchive, load_source: LoadSourc with patch("earthkit.data.from_source", load_source): config = os.path.join(HERE, name + ".yaml") output = os.path.join(HERE, name + ".zarr") - is_test = False - - create_dataset(config=config, output=output, delta=["12h"], is_test=is_test) - directory = get_test_archive(f"anemoi-datasets/create/mock-mars/{name}.zarr.tgz") - reference = os.path.join(directory, name + ".zarr") + create_dataset(config=config, output=output, delta=["12h"]) - Comparer(output_path=output, reference_path=reference).compare() + check_dataset(name, config, output, get_test_archive) if __name__ == "__main__": diff --git a/tests/create/utils/checks.py b/tests/create/utils/checks.py new file mode 100644 index 000000000..bd715e75b --- /dev/null +++ b/tests/create/utils/checks.py @@ -0,0 +1,338 @@ +# (C) Copyright 2025- Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +import os +from typing import Any + +import numpy as np +import yaml +from anemoi.utils.dates import frequency_to_timedelta + +from anemoi.datasets import open_dataset +from anemoi.datasets.data.stores import open_zarr + + +class _Check: + + def __init__( + self, + name: str, + config_path: str, + dataset_path: str, + get_test_archive: callable, + **kwargs: Any, + ) -> None: + self.name = name + self.config_path = config_path + self.dataset_path = dataset_path + self.get_test_archive = get_test_archive + self.kwargs = kwargs + + +class CompareToReferenceCheck(_Check): + """Class to compare datasets and their metadata.""" + + def __init__( + self, + name: str, + config_path: str, + dataset_path: str, + get_test_archive: callable, + ignore_keys=None, + **kwargs: Any, + ) -> None: + """Initialize the Comparer instance. + + Parameters + ---------- + output_path : str, optional + The path to the output dataset. + reference_path : str, optional + The path to the reference dataset. + """ + + super().__init__(name, config_path, dataset_path, get_test_archive, **kwargs) + + directory = get_test_archive(f"anemoi-datasets/create/mock-mars/{name}.zarr.tgz") + reference = os.path.join(directory, name + ".zarr") + + self.reference_path = reference + print(f"Comparing {self.dataset_path} and {self.reference_path}") + + self.z_output = open_zarr(self.dataset_path) + self.z_reference = open_zarr(self.reference_path) + + self.z_reference["data"] + self.ds_output = open_dataset(self.dataset_path) + self.ds_reference = open_dataset(self.reference_path) + self.ignore_keys = set(ignore_keys) if ignore_keys else set() + + self.ignore_keys.update( + { + "metadata.latest_write_timestamp", + "metadata.uuid", + "metadata.provenance_load", + "metadata.total_size", + "metadata.history", + "metadata.recipe.checks", + } + ) + + def compare_datasets(self, a: object, b: object) -> None: + """Compare two datasets. + + Parameters + ---------- + a : object + The first dataset. + b : object + The second dataset. + + Raises + ------ + AssertionError + If the datasets do not match. + """ + assert a.shape == b.shape, (a.shape, b.shape) + assert (a.dates == b.dates).all(), (a.dates, b.dates) + for a_, b_ in zip(a.variables, b.variables): + assert a_ == b_, (a, b) + assert a.missing == b.missing, "Missing are different" + + for i_date, date in zip(range(a.shape[0]), a.dates): + + if i_date in a.missing: + continue + + for i_param in range(a.shape[1]): + param = a.variables[i_param] + assert param == b.variables[i_param], ( + date, + param, + a.variables[i_param], + b.variables[i_param], + ) + a_ = a[i_date, i_param] + b_ = b[i_date, i_param] + assert a.shape == b.shape, (date, param, a.shape, b.shape) + + a_nans = np.isnan(a_) + b_nans = np.isnan(b_) + assert np.all(a_nans == b_nans), (date, param, "nans are different") + + a_ = np.where(a_nans, 0, a_) + b_ = np.where(b_nans, 0, b_) + + delta = a_ - b_ + max_delta = np.max(np.abs(delta)) + abs_error = np.abs(a_ - b_) + rel_error = np.abs(a_ - b_) / (np.abs(b_) + 1e-10) # Avoid division by zero + assert max_delta == 0.0, (date, param, a_, b_, a_ - b_, max_delta, np.max(abs_error), np.max(rel_error)) + + def compare_statistics(self, ds1: object, ds2: object) -> None: + """Compare the statistics of two datasets. + + Parameters + ---------- + ds1 : object + The first dataset. + ds2 : object + The second dataset. + + Raises + ------ + AssertionError + If the statistics do not match. + """ + vars1 = ds1.variables + vars2 = ds2.variables + assert len(vars1) == len(vars2) + for v1, v2 in zip(vars1, vars2): + idx1 = ds1.name_to_index[v1] + idx2 = ds2.name_to_index[v2] + assert (ds1.statistics["mean"][idx1] == ds2.statistics["mean"][idx2]).all() + assert (ds1.statistics["stdev"][idx1] == ds2.statistics["stdev"][idx2]).all() + assert (ds1.statistics["maximum"][idx1] == ds2.statistics["maximum"][idx2]).all() + assert (ds1.statistics["minimum"][idx1] == ds2.statistics["minimum"][idx2]).all() + + def compare_dot_zattrs(self, a: dict, b: dict, errors: list, *path) -> None: + """Compare the attributes of two Zarr datasets.""" + + name = ".".join(path) + if name in self.ignore_keys: + return + + if type(a) is not type(b): + msg = f"❌ {name} type mismatch actual != expected : {a} ({type(a)}) != {b} ({type(b)})" + errors.append(msg) + return + + if isinstance(a, dict): + a_keys = list(a.keys()) + b_keys = list(b.keys()) + for k in set(a_keys) | set(b_keys): + + name = ".".join(path + (k,)) + + if name in self.ignore_keys: + continue + + if k not in a_keys: + errors.append(f"❌ {name} : missing key (only in reference)") + continue + + if k not in b_keys: + errors.append(f"❌ {name} : additional key (missing in reference)") + continue + + self.compare_dot_zattrs(a[k], b[k], errors, *path, k) + + return + + if isinstance(a, list): + if len(a) != len(b): + errors.append(f"❌ {path} : lengths are different {len(a)} != {len(b)}") + return + + for i, (v, w) in enumerate(zip(a, b)): + self.compare_dot_zattrs(v, w, errors, *path, str(i)) + + return + + try: + a, b = frequency_to_timedelta(a), frequency_to_timedelta(b) + except Exception: + pass + + if a != b: + msg = f"❌ {name} actual != expected : {a} != {b}" + errors.append(msg) + + def run(self) -> None: + """Compare the output dataset with the reference dataset. + + Raises + ------ + AssertionError + If the datasets or their metadata do not match. + """ + errors = [] + self.compare_dot_zattrs(dict(self.z_output.attrs), dict(self.z_reference.attrs), errors, "metadata") + + if errors: + print() + print("Comparison failed") + print("\n".join(errors)) + + print() + print("⚠️ To update the reference data, run this:") + print("cd " + os.path.dirname(self.dataset_path)) + base = os.path.basename(self.dataset_path) + print(f"tar zcf {base}.tgz {base}") + print(f"scp {base}.tgz data@anemoi.ecmwf.int:public/anemoi-datasets/create/mock-mars/") + print() + raise AssertionError("\n".join(errors)) + + self.compare_datasets(self.ds_output, self.ds_reference) + self.compare_statistics(self.ds_output, self.ds_reference) + # do not compare tendencies statistics yet, as we don't know yet if they should stay + + +class NoneCheck(_Check): + + def run(self) -> None: + pass + + +class _ItemCheck(_Check): + + def normalise(self, item: Any) -> Any: + """Normalize the item for comparison.""" + return item + + def run(self) -> None: + item = self.normalise(self.kwargs.get(self.item_name)) + + ds = open_dataset(self.dataset_path) + dataset_item = self.normalise(getattr(ds, self.item_name)) + + assert item == dataset_item, (item, dataset_item) + + +class VariablesCheck(_ItemCheck): + """Check for variables presence in the dataset.""" + + item_name = "variables" + + +class DatesCheck(_ItemCheck): + """Check for dates presence in the dataset.""" + + item_name = "dates" + + def normalise(self, item): + return [np.datetime64(v) for v in item] + + +class ValuesCheck(_Check): + def __init__( + self, + name: str, + config_path: str, + dataset_path: str, + get_test_archive: callable, + variable: str, + maximum: float, + minimum: float, + **kwargs: Any, + ) -> None: + super().__init__(name, config_path, dataset_path, get_test_archive, **kwargs) + self.variable = variable + self.maximum = maximum + self.minimum = minimum + + def run(self) -> None: + ds = open_dataset(self.dataset_path) + idx = ds.name_to_index[self.variable] + data = ds.data[:, idx] + + actual_max = np.nanmax(data) + actual_min = np.nanmin(data) + + assert actual_max == self.maximum, (self.variable, actual_max, self.maximum) + assert actual_min == self.minimum, (self.variable, actual_min, self.minimum) + + +def check_dataset(name: str, config_path: str, dataset_path: str, get_test_archive: callable) -> None: + """Check the created dataset against a set of checks.""" + + config = yaml.safe_load(open(config_path)) + checks = config.get("checks") + + if not checks: + raise ValueError(f"No checks defined in {config_path}") + + for c in checks: + + if isinstance(c, str): + check = c + kwargs = {} + else: + check, kwargs = next(iter(c.items())) + if not isinstance(kwargs, dict): + kwargs = {check: kwargs} + + check = "".join(word.capitalize() for word in check.split("_")) + "Check" + + if check not in globals(): + raise ValueError(f"Check {check} not implemented") + + print(f"Running check: {check} with args: {kwargs}") + check = globals()[check](name, config_path, dataset_path, get_test_archive, **kwargs) + check.run() diff --git a/tests/create/utils/compare.py b/tests/create/utils/compare.py deleted file mode 100644 index 56b7d0f82..000000000 --- a/tests/create/utils/compare.py +++ /dev/null @@ -1,227 +0,0 @@ -# (C) Copyright 2025- Anemoi contributors. -# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. - -import os - -import numpy as np -from anemoi.utils.dates import frequency_to_timedelta - -from anemoi.datasets import open_dataset -from anemoi.datasets.data.stores import open_zarr - - -class Comparer: - """Class to compare datasets and their metadata. - - Parameters - ---------- - output_path : str, optional - The path to the output dataset. - reference_path : str, optional - The path to the reference dataset. - """ - - def __init__(self, output_path: str = None, reference_path: str = None) -> None: - """Initialize the Comparer instance. - - Parameters - ---------- - output_path : str, optional - The path to the output dataset. - reference_path : str, optional - The path to the reference dataset. - """ - self.output_path = output_path - self.reference_path = reference_path - print(f"Comparing {self.output_path} and {self.reference_path}") - - self.z_output = open_zarr(self.output_path) - self.z_reference = open_zarr(self.reference_path) - - self.z_reference["data"] - self.ds_output = open_dataset(self.output_path) - self.ds_reference = open_dataset(self.reference_path) - - @staticmethod - def compare_datasets(a: object, b: object) -> None: - """Compare two datasets. - - Parameters - ---------- - a : object - The first dataset. - b : object - The second dataset. - - Raises - ------ - AssertionError - If the datasets do not match. - """ - assert a.shape == b.shape, (a.shape, b.shape) - assert (a.dates == b.dates).all(), (a.dates, b.dates) - for a_, b_ in zip(a.variables, b.variables): - assert a_ == b_, (a, b) - assert a.missing == b.missing, "Missing are different" - - for i_date, date in zip(range(a.shape[0]), a.dates): - if i_date in a.missing: - continue - for i_param in range(a.shape[1]): - param = a.variables[i_param] - assert param == b.variables[i_param], ( - date, - param, - a.variables[i_param], - b.variables[i_param], - ) - a_ = a[i_date, i_param] - b_ = b[i_date, i_param] - assert a.shape == b.shape, (date, param, a.shape, b.shape) - - a_nans = np.isnan(a_) - b_nans = np.isnan(b_) - assert np.all(a_nans == b_nans), (date, param, "nans are different") - - a_ = np.where(a_nans, 0, a_) - b_ = np.where(b_nans, 0, b_) - - delta = a_ - b_ - max_delta = np.max(np.abs(delta)) - abs_error = np.abs(a_ - b_) - rel_error = np.abs(a_ - b_) / (np.abs(b_) + 1e-10) # Avoid division by zero - assert max_delta == 0.0, (date, param, a_, b_, a_ - b_, max_delta, np.max(abs_error), np.max(rel_error)) - - @staticmethod - def compare_statistics(ds1: object, ds2: object) -> None: - """Compare the statistics of two datasets. - - Parameters - ---------- - ds1 : object - The first dataset. - ds2 : object - The second dataset. - - Raises - ------ - AssertionError - If the statistics do not match. - """ - vars1 = ds1.variables - vars2 = ds2.variables - assert len(vars1) == len(vars2) - for v1, v2 in zip(vars1, vars2): - idx1 = ds1.name_to_index[v1] - idx2 = ds2.name_to_index[v2] - assert (ds1.statistics["mean"][idx1] == ds2.statistics["mean"][idx2]).all() - assert (ds1.statistics["stdev"][idx1] == ds2.statistics["stdev"][idx2]).all() - assert (ds1.statistics["maximum"][idx1] == ds2.statistics["maximum"][idx2]).all() - assert (ds1.statistics["minimum"][idx1] == ds2.statistics["minimum"][idx2]).all() - - @staticmethod - def compare_dot_zattrs(a: dict, b: dict, path: str, errors: list) -> None: - """Compare the attributes of two Zarr datasets. - - Parameters - ---------- - a : dict - The attributes of the first dataset. - b : dict - The attributes of the second dataset. - path : str - The current path in the attribute hierarchy. - errors : list - The list to store error messages. - """ - if isinstance(a, dict): - a_keys = list(a.keys()) - b_keys = list(b.keys()) - for k in set(a_keys) | set(b_keys): - if k not in a_keys: - errors.append(f"❌ {path}.{k} : missing key (only in reference)") - continue - if k not in b_keys: - errors.append(f"❌ {path}.{k} : additional key (missing in reference)") - continue - - if k in [ - "timestamp", - "uuid", - "latest_write_timestamp", - "history", - "provenance", - "provenance_load", - "description", - "config_path", - "total_size", - ]: - if type(a[k]) is not type(b[k]): - errors.append(f"❌ {path}.{k} : type differs {type(a[k])} != {type(b[k])}") - continue - - Comparer.compare_dot_zattrs(a[k], b[k], f"{path}.{k}", errors) - - return - - if isinstance(a, list): - if len(a) != len(b): - errors.append(f"❌ {path} : lengths are different {len(a)} != {len(b)}") - return - - for i, (v, w) in enumerate(zip(a, b)): - Comparer.compare_dot_zattrs(v, w, f"{path}.{i}", errors) - - return - - if type(a) is not type(b): - msg = f"❌ {path} actual != expected : {a} ({type(a)}) != {b} ({type(b)})" - errors.append(msg) - return - - # convert a and b from frequency strings to timedeltas when : - # - path ends with .period.0 or .period.n were n is int - # - key is "frequency" - if (path.split(".")[-2] == "period" and path.split(".")[-1].isdigit()) or (path.split(".")[-1] == "frequency"): - a = frequency_to_timedelta(a) - b = frequency_to_timedelta(b) - - if a != b: - msg = f"❌ {path} actual != expected : {a} != {b}" - errors.append(msg) - - def compare(self) -> None: - """Compare the output dataset with the reference dataset. - - Raises - ------ - AssertionError - If the datasets or their metadata do not match. - """ - errors = [] - self.compare_dot_zattrs(dict(self.z_output.attrs), dict(self.z_reference.attrs), "metadata", errors) - if errors: - print("Comparison failed") - print("\n".join(errors)) - - if errors: - print() - - print() - print("⚠️ To update the reference data, run this:") - print("cd " + os.path.dirname(self.output_path)) - base = os.path.basename(self.output_path) - print(f"tar zcf {base}.tgz {base}") - print(f"scp {base}.tgz data@anemoi.ecmwf.int:public/anemoi-datasets/create/mock-mars/") - print() - raise AssertionError("Comparison failed") - - self.compare_datasets(self.ds_output, self.ds_reference) - self.compare_statistics(self.ds_output, self.ds_reference) - # do not compare tendencies statistics yet, as we don't know yet if they should stay diff --git a/tests/create/utils/create.py b/tests/create/utils/create.py index a57022bab..a773975fd 100644 --- a/tests/create/utils/create.py +++ b/tests/create/utils/create.py @@ -24,7 +24,6 @@ def create_dataset( config: str | dict[str, Any], output: str | None, delta: list[str] | None = None, - is_test: bool = False, ) -> str: """Create a dataset based on the provided configuration. @@ -36,8 +35,6 @@ def create_dataset( The output path for the dataset. If None, a temporary directory will be created. delta : Optional[List[str]], optional List of delta for secondary statistics, by default None. - is_test : bool, optional - Flag indicating if the dataset creation is for testing purposes, by default False. Returns ------- @@ -52,7 +49,7 @@ def create_dataset( if output is None: output = tempfile.mkdtemp(suffix=".zarr") - creator_factory("init", config=config, path=output, overwrite=True, test=is_test).run() + creator_factory("init", config=config, path=output, overwrite=True, test=True).run() creator_factory("load", path=output).run() creator_factory("finalise", path=output).run() creator_factory("patch", path=output).run()