diff --git a/.gitignore b/.gitignore index f3777cde2..031461ff0 100644 --- a/.gitignore +++ b/.gitignore @@ -141,3 +141,13 @@ _version.py *.to_upload tempCodeRunnerFile.python Untitled-*.py +*.zip +*.json +*.db +*.tgz +_api/ +trace.txt +?/ +*.prof +prof/ +*.gz diff --git a/pyproject.toml b/pyproject.toml index bfcb3c5cb..44df9260c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -51,9 +51,12 @@ dependencies = [ "anemoi-transform>=0.1.10", "anemoi-utils[provenance]>=0.4.32", "cfunits", + "glom", + "jsonschema", "numcodecs<0.16", # Until we move to zarr3 "numpy", "pyyaml", + "ruamel-yaml", "semantic-version", "tqdm", "zarr<=2.18.4", diff --git a/src/anemoi/datasets/commands/recipe/__init__.py b/src/anemoi/datasets/commands/recipe/__init__.py new file mode 100644 index 000000000..45400806c --- /dev/null +++ b/src/anemoi/datasets/commands/recipe/__init__.py @@ -0,0 +1,93 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + + +import argparse +import logging +import sys +from typing import Any + +import yaml + +from anemoi.datasets.create import validate_config + +from .. import Command +from .format import format_recipe +from .migrate import migrate_recipe + +LOG = logging.getLogger(__name__) + + +class Recipe(Command): + def add_arguments(self, command_parser: Any) -> None: + """Add arguments to the command parser. + + Parameters + ---------- + command_parser : Any + Command parser object. + """ + + command_parser.add_argument("--validate", action="store_true", help="Validate recipe.") + command_parser.add_argument("--format", action="store_true", help="Format the recipe.") + command_parser.add_argument("--migrate", action="store_true", help="Migrate the recipe to the latest version.") + + group = command_parser.add_mutually_exclusive_group() + group.add_argument("--inplace", action="store_true", help="Overwrite the recipe file in place.") + group.add_argument("--output", type=str, help="Output file path for the converted recipe.") + + command_parser.add_argument( + "path", + help="Path to recipe.", + ) + + def run(self, args: Any) -> None: + + if not args.validate and not args.format and not args.migrate: + args.validate = True + + with open(args.path) as file: + config = yaml.safe_load(file) + + assert isinstance(config, dict) + + if args.validate: + if args.inplace and (not args.format and not args.migrate): + argparse.ArgumentError(None, "--inplace is not supported with --validate.") + + if args.output and (not args.format and not args.migrate): + argparse.ArgumentError(None, "--output is not supported with --validate.") + + validate_config(config) + LOG.info(f"{args.path}: Recipe is valid.") + return + + if args.migrate: + config = migrate_recipe(args, config) + if config is None: + LOG.info(f"{args.path}: No changes needed.") + return + + args.format = True + + if args.format: + formatted = format_recipe(args, config) + assert "dates" in formatted + f = sys.stdout + if args.output: + f = open(args.output, "w") + + if args.inplace: + f = open(args.path, "w") + + print(formatted, file=f) + f.close() + + +command = Recipe diff --git a/src/anemoi/datasets/commands/recipe/format.py b/src/anemoi/datasets/commands/recipe/format.py new file mode 
100644 index 000000000..872060981 --- /dev/null +++ b/src/anemoi/datasets/commands/recipe/format.py @@ -0,0 +1,55 @@ +# (C) Copyright 2025 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + + +import datetime +import logging + +from ...dumper import yaml_dump + +LOG = logging.getLogger(__name__) + + +def make_dates(config): + if isinstance(config, dict): + return {k: make_dates(v) for k, v in config.items()} + if isinstance(config, list): + return [make_dates(v) for v in config] + if isinstance(config, str): + try: + return datetime.datetime.fromisoformat(config) + except ValueError: + return config + return config + + +ORDER = ( + "name", + "description", + "dataset_status", + "licence", + "attribution", + "env", + "dates", + "common", + "data_sources", + "input", + "output", + "statistics", + "build", + "platform", +) + + +def format_recipe(args, config: dict) -> str: + + config = make_dates(config) + assert config + + return yaml_dump(config, order=ORDER) diff --git a/src/anemoi/datasets/commands/recipe/migrate.py b/src/anemoi/datasets/commands/recipe/migrate.py new file mode 100644 index 000000000..03da61fbc --- /dev/null +++ b/src/anemoi/datasets/commands/recipe/migrate.py @@ -0,0 +1,555 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. 
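As a quick illustration of the new formatting helper, here is a minimal sketch with a hypothetical recipe fragment; the import path simply follows the new module layout introduced above:

```python
# Minimal sketch: make_dates() recursively converts ISO-formatted strings
# into datetime objects so that yaml_dump() can emit proper YAML timestamps;
# non-date strings are returned unchanged. The recipe fragment is hypothetical.
import datetime

from anemoi.datasets.commands.recipe.format import make_dates

config = {
    "description": "example recipe",  # not a date, left as a plain string
    "dates": {"start": "2020-01-01T00:00:00", "end": "2020-12-31T18:00:00"},
}

converted = make_dates(config)
assert isinstance(converted["dates"]["start"], datetime.datetime)
assert converted["description"] == "example recipe"
```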
+ + +import logging +import sys +from collections.abc import Sequence +from typing import Any + +from glom import assign +from glom import delete +from glom import glom + +from anemoi.datasets.create import validate_config +from anemoi.datasets.dumper import yaml_dump + +LOG = logging.getLogger(__name__) + + +def find_paths(data, target_key=None, target_value=None, *path): + + matches = [] + + if isinstance(data, dict): + for k, v in data.items(): + if (target_key is not None and k == target_key) or (target_value is not None and v == target_value): + matches.append(list(path) + [k]) + matches.extend(find_paths(v, target_key, target_value, *path, k)) + elif isinstance(data, Sequence) and not isinstance(data, (str, bytes)): + for i, item in enumerate(data): + matches.extend(find_paths(item, target_key, target_value, *path, str(i))) + return matches + + +def find_chevrons(data, *path): + + matches = [] + + if isinstance(data, dict): + for k, v in data.items(): + if k == "<<": + matches.append(list(path) + [k]) + matches.extend(find_chevrons(v, *path, k)) + elif isinstance(data, list): + for i, item in enumerate(data): + matches.extend(find_chevrons(item, *path, str(i))) + return matches + + +def find_paths_in_substrees(path, obj, cur_path=None): + if cur_path is None: + cur_path = [] + matches = [] + try: + glom(obj, path) # just to check existence + matches.append(cur_path + path.split(".")) + except Exception: + pass + + if isinstance(obj, dict): + for k, v in obj.items(): + matches.extend(find_paths_in_substrees(path, v, cur_path + [k])) + elif isinstance(obj, list): + for i, v in enumerate(obj): + matches.extend(find_paths_in_substrees(path, v, cur_path + [str(i)])) + return matches + + +MIGRATE = { + "output.statistics_end": "statistics.end", + "has_nans": "statistics.allow_nans", + "loop.dates.group_by": "build.group_by", + "loop.0.dates.group_by": "build.group_by", + "loop.dates": "dates", + "loop.0.dates": "dates", + "copyright": "attribution", + "dates.<<": "dates", + "options.group_by": "build.group_by", + "loops.0.loop_a.dates": "dates", + "loop.0.loop_a.dates": "dates", + "dates.stop": "dates.end", + "dates.group_by": "build.group_by", + "include.mars": "data_sources.mars.mars", + "ensemble_dimension": "build.ensemble_dimension", + "flatten_grid": "build.flatten_grid", +} + +DELETE = [ + "purpose", + # "input.join.0.label", + "status", + "common", + "config_format_version", + "aliases", + # "platform", + "loops.0.loop_a.applies_to", + "loop.0.loop_a.applies_to", + "dataset_status", + "alias", + "resources", + "input.dates.<<", + "input.dates.join.0.label.name", +] + + +SOURCES = { + "oper-accumulations": "accumulations", + "era5-accumulations": "accumulations", + "ensemble-perturbations": "recentre", + "ensemble_perturbations": "recentre", + "perturbations": "recentre", + "custom-regrid": "regrid", +} + +MARKER = object() + + +def _delete(config, path): + x = glom(config, path, default=MARKER) + if x is MARKER: + return + delete(config, path) + + +def _move(config, path, new_path, result): + x = glom(config, path, default=MARKER) + if x is MARKER: + return + delete(result, path) + assign(result, new_path, x, missing=dict) + + +def _fix_input_0(config): + if isinstance(config["input"], dict): + return + + input = config["input"] + new_input = [] + + blocks = {} + first = None + for block in input: + assert isinstance(block, dict), block + + assert len(block) == 1, block + + block_name, values = list(block.items())[0] + + if "kwargs" in values: + inherit = values.pop("inherit", 
None) + assert len(values) == 1, values + values = values["kwargs"] + values.pop("date", None) + source_name = values.pop("name", None) + + if inherit is not None: + if inherit.startswith("$"): + inherit = inherit[1:] + inherited = blocks[inherit].copy() + inherited.update(values) + values = inherited + + if first is None: + first = source_name + + blocks[block_name] = values.copy() + + new_input.append({SOURCES.get(source_name, source_name): values.copy()}) + else: + assert False, f"Block {block_name} does not have 'kwargs': {values}" + + blocks[block_name] = values.copy() + + config["input"] = dict(join=new_input) + + +def _fix_input_1(result, config): + if isinstance(config["input"], dict): + return + + input = config["input"] + join = [] + for k in input: + assert isinstance(k, dict) + assert len(k) == 1, f"Input key {k} is not a string: {input}" + name, values = list(k.items())[0] + join.append(values) + + result["input"] = {"join": join} + config["input"] = result["input"].copy() + + +def remove_empties(config: dict) -> None: + """Remove empty dictionaries and lists from the config.""" + if isinstance(config, dict): + keys_to_delete = [k for k, v in config.items() if v in (None, {}, [], [{}])] + + for k in keys_to_delete: + del config[k] + + for k, v in config.items(): + remove_empties(v) + + if isinstance(config, list): + for item in config: + remove_empties(item) + + +def _fix_loops(result: dict, config: dict) -> None: + if "loops" not in config: + return + + input = config["input"] + loops = config["loops"] + + assert isinstance(loops, list), loops + assert isinstance(input, list), input + + entries = {} + dates_block = None + for loop in loops: + assert isinstance(loop, dict), loop + assert len(loop) == 1, loop + loop = list(loop.values())[0] + applies_to = loop["applies_to"] + dates = loop["dates"] + assert isinstance(applies_to, list), (applies_to, loop) + for a in applies_to: + entries[a] = dates.copy() + + if "start" in dates: + start = dates["start"] + else: + start = max(dates["values"]) + + if "end" in dates or "stop" in dates: + end = dates.get("end", dates.get("stop")) + else: + end = min(dates["values"]) + + if dates_block is None: + dates_block = { + "start": start, + "end": end, + } + + if "frequency" in dates: + if "frequency" not in dates_block: + dates_block["frequency"] = dates["frequency"] + else: + assert dates_block["frequency"] == dates["frequency"], (dates_block["frequency"], dates["frequency"]) + + dates_block["start"] = min(dates_block["start"], start) + dates_block["end"] = max(dates_block["end"], end) + + concat = [] + result["input"] = {"concat": concat} + + print("Found loops:", entries) + + for block in input: + assert isinstance(block, dict), block + assert len(block) == 1, block + name, values = list(block.items())[0] + assert name in entries, f"Loop {name} not found in loops: {list(entries.keys())}" + dates = entries[name].copy() + + assert "kwargs" not in values + + concat.append(dict(dates=dates, **values)) + + d = concat[0]["dates"] + if all(c["dates"] == d for c in concat): + join = [] + for c in concat: + del c["dates"] + join.append(c) + result["input"] = {"join": join} + + del config["loops"] + config["input"] = result["input"].copy() + config["dates"] = dates_block.copy() + del result["loops"] + result["dates"] = dates_block + + +def _fix_other(result: dict, config: dict) -> None: + paths = find_paths(config, target_key="source_or_dataset", target_value="$previous_data") + for p in paths: + print(f"Fixing {'.'.join(p)}") + assign(result, 
".".join(p[:-1] + ["template"]), "${input.join.0.mars}", missing=dict) + delete(result, ".".join(p)) + + paths = find_paths(config, target_key="date", target_value="$dates") + for p in paths: + delete(result, ".".join(p)) + + +def _fix_join(result: dict, config: dict) -> None: + print("Fixing join...") + input = config["input"] + if "dates" in input and "join" in input["dates"]: + result["input"]["join"] = input["dates"]["join"] + config["input"]["join"] = input["dates"]["join"].copy() + + if "join" not in input: + return + + join = input["join"] + new_join = [] + for j in join: + assert isinstance(j, dict) + assert len(j) == 1 + + key, values = list(j.items())[0] + + if key not in ("label", "source"): + return + + assert isinstance(values, dict), f"Join values for {key} should be a dict: {values}" + if key == "label": + j = values + j.pop("name") + key, values = list(j.items())[0] + + print(values) + source_name = values.pop("name", "mars") + new_join.append( + { + SOURCES.get(source_name, source_name): values, + } + ) + + result["input"] = {"join": new_join} + config["input"] = result["input"].copy() + + +def _fix_sources(config: dict, what) -> None: + + input = config["input"] + if what not in input: + return + + join = input[what] + new_join = [] + for j in join: + assert isinstance(j, dict) + assert len(j) == 1, j + + key, values = list(j.items())[0] + + key = SOURCES.get(key, key) + + new_join.append( + { + key: values, + } + ) + + config["input"][what] = new_join + config["input"][what] = new_join.copy() + + +def _assign(config, path, value): + print(f"Assign {path} {value}") + assign(config, path, value) + + +def _fix_chevrons(result: dict, config: dict) -> None: + print("Fixing chevrons...") + paths = find_chevrons(config) + for p in paths: + a = glom(config, ".".join(p)) + b = glom(config, ".".join(p[:-1])) + delete(result, ".".join(p)) + a.update(b) + assign(result, ".".join(p[:-1]), a) + + +def _fix_some(config: dict) -> None: + + paths = find_paths_in_substrees("label.function", config) + for p in paths: + parent = glom(config, ".".join(p[:-2])) + node = glom(config, ".".join(p[:-1])) + assert node + _assign(config, ".".join(p[:-2]), node) + + paths = find_paths_in_substrees("constants.source_or_dataset", config) + for p in paths: + node = glom(config, ".".join(p[:-1])) + node["template"] = node.pop("source_or_dataset") + if node["template"] == "$previous_data": + node["template"] = "${input.join.0.mars}" + paths = find_paths_in_substrees("constants.template", config) + for p in paths: + node = glom(config, ".".join(p[:-1])) + if node["template"] == "$pl_data": + node["template"] = "${input.join.0.mars}" + for d in ("date", "dates", "time"): + paths = find_paths_in_substrees(d, config) + for p in paths: + if len(p) > 1: + node = glom(config, ".".join(p[:-1])) + if isinstance(node, dict) and isinstance(node[d], str) and node[d].startswith("$"): + del node[d] + + paths = find_paths_in_substrees("source.<<", config) + for p in paths: + parent = glom(config, ".".join(p[:-2])) + node = glom(config, ".".join(p[:-1])) + node.update(node.pop("<<")) + parent[node.pop("name")] = node + assert len(parent) == 2 + del parent["source"] + + paths = find_paths_in_substrees("label.mars", config) + for p in paths: + parent = glom(config, ".".join(p[:-2])) + node = glom(config, ".".join(p[:-1])) + assert node + assign(config, ".".join(p[:-2]), node) + + paths = find_paths_in_substrees("input.dates.join", config) + for p in paths: + node = glom(config, ".".join(p)) + config["input"]["join"] = 
node + del config["input"]["dates"] + + paths = find_paths_in_substrees("source.name", config) + for p in paths: + parent = glom(config, ".".join(p[:-2])) + node = glom(config, ".".join(p[:-1])) + name = node.pop("name") + assign(config, ".".join(p[:-2]), {name: node}) + + paths = find_paths_in_substrees("function.name", config) + for p in paths: + parent = glom(config, ".".join(p[:-2])) + node = glom(config, ".".join(p[:-1])) + name = node.pop("name") + assert node + assign(config, ".".join(p[:-2]), {name: node}) + + +def _migrate(config: dict, n) -> dict: + + result = config.copy() + + _fix_input_0(result) + # _fix_loops(result, config) + # _fix_input_1(result, config) + # _fix_join(result, config) + # _fix_chevrons(result, config) + # _fix_other(result, config) + + for k, v in MIGRATE.items(): + _move(config, k, v, result) + + _fix_some(result) + _fix_sources(result, "join") + + for k in DELETE: + _delete(result, k) + + remove_empties(result) + + return result + + +def migrate(old: dict) -> dict: + + for i in range(10): + new = _migrate(old, i) + if new == old: + return new + old = new + + return new + + +def has_key(config, key: str) -> bool: + if isinstance(config, dict): + if key in config: + return True + for k, v in config.items(): + if has_key(v, key): + return True + if isinstance(config, list): + for item in config: + if has_key(item, key): + return True + return False + + +def has_value(config, value: str) -> bool: + if isinstance(config, dict): + for k, v in config.items(): + if v == value: + return True + if has_value(v, value): + return True + + if isinstance(config, list): + for item in config: + if item == value: + return True + if has_value(item, value): + return True + return config == value + + +def check(config): + + try: + + validate_config(config) + assert config.get("input", {}) + assert config.get("dates", {}) + assert not has_key(config, "label") + assert not has_key(config, "kwargs") + assert not has_value(config, "$previous_data") + assert not has_value(config, "$pl_data") + assert not has_value(config, "$dates") + assert not has_key(config, "inherit") + assert not has_key(config, "source_or_dataset") + assert not has_key(config, "<<") + + for n in SOURCES.keys(): + assert not has_key(config, n), f"Source {n} found in config. Please update to {SOURCES[n]}." + + except Exception as e: + print("Validation failed:") + print(e) + print(yaml_dump(config)) + sys.exit(1) + + +def migrate_recipe(args: Any, config) -> None: + + print(f"Migrating {args.path}") + + migrated = migrate(config) + + check(migrated) + if migrated == config: + return None + + return migrated diff --git a/src/anemoi/datasets/create/__init__.py b/src/anemoi/datasets/create/__init__.py index d8ca4c023..acaf3807d 100644 --- a/src/anemoi/datasets/create/__init__.py +++ b/src/anemoi/datasets/create/__init__.py @@ -865,7 +865,7 @@ def _run(self) -> None: # assert isinstance(group[0], datetime.datetime), type(group[0]) LOG.debug(f"Building data for group {igroup}/{self.n_groups}") - result = self.input.select(group_of_dates=group) + result = self.input.select(argument=group) assert result.group_of_dates == group, (len(result.group_of_dates), len(group), group) # There are several groups. 
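To make the migration helpers above concrete, here is a minimal sketch of driving `migrate()` directly on a hypothetical old-style recipe fragment; only the key moves listed in the `MIGRATE` table are claimed here:

```python
# Minimal sketch with a hypothetical old-style recipe fragment.
# migrate() reapplies _migrate() until the configuration stops changing,
# moving deprecated keys according to the MIGRATE table.
from anemoi.datasets.commands.recipe.migrate import migrate

old = {
    "dates": {"start": "2020-01-01", "stop": "2020-12-31", "group_by": "monthly"},
    "input": {"join": [{"mars": {"param": ["2t"], "levtype": "sfc"}}]},
}

new = migrate(old)
# "dates.stop" is renamed to "dates.end", and "dates.group_by"
# is moved to "build.group_by"; the input block is left untouched.
print(new["dates"]["end"], new["build"]["group_by"])
```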
@@ -1617,3 +1617,44 @@ def creator_factory(name: str, trace: str | None = None, **kwargs: Any) -> Any: )[name] LOG.debug(f"Creating {cls.__name__} with {kwargs}") return cls(**kwargs) + + +def validate_config(config: Any) -> None: + + import json + + import jsonschema + + def _tidy(d): + if isinstance(d, dict): + return {k: _tidy(v) for k, v in d.items()} + + if isinstance(d, list): + return [_tidy(v) for v in d if v is not None] + + # jsonschema does not support datetime.date + if isinstance(d, datetime.datetime): + return d.isoformat() + + if isinstance(d, datetime.date): + return d.isoformat() + + return d + + # https://json-schema.org + + with open( + os.path.join( + os.path.dirname(os.path.dirname(__file__)), + "schemas", + "recipe.json", + ) + ) as f: + schema = json.load(f) + + try: + jsonschema.validate(instance=_tidy(config), schema=schema) + except jsonschema.exceptions.ValidationError as e: + LOG.error("❌ Config validation failed (jsonschema):") + LOG.error(e.message) + raise diff --git a/src/anemoi/datasets/create/config.py b/src/anemoi/datasets/create/config.py index 72738012f..f0312170d 100644 --- a/src/anemoi/datasets/create/config.py +++ b/src/anemoi/datasets/create/config.py @@ -280,6 +280,8 @@ def __init__(self, config: dict, *args, **kwargs): self.output.order_by = normalize_order_by(self.output.order_by) + self.setdefault("dates", Config()) + self.dates["group_by"] = self.build.group_by ########### diff --git a/src/anemoi/datasets/create/input/__init__.py b/src/anemoi/datasets/create/input/__init__.py index 2e1c18a90..e30ecefb5 100644 --- a/src/anemoi/datasets/create/input/__init__.py +++ b/src/anemoi/datasets/create/input/__init__.py @@ -1,4 +1,4 @@ -# (C) Copyright 2024 Anemoi contributors. +# (C) Copyright 2024-2025 Anemoi contributors. # # This software is licensed under the terms of the Apache Licence Version 2.0 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. @@ -7,21 +7,15 @@ # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. -import logging from copy import deepcopy +from functools import cached_property +from typing import TYPE_CHECKING from typing import Any -from anemoi.datasets.dates.groups import GroupOfDates +from anemoi.datasets.create.input.context.field import FieldContext -from .trace import trace_select - -LOG = logging.getLogger(__name__) - - -class Context: - """Context for building input data.""" - - pass +if TYPE_CHECKING: + from anemoi.datasets.create.input.action import Recipe class InputBuilder: @@ -34,72 +28,58 @@ def __init__(self, config: dict, data_sources: dict | list, **kwargs: Any) -> No ---------- config : dict Configuration dictionary. - data_sources : Union[dict, list] + data_sources : dict Data sources. **kwargs : Any Additional keyword arguments. """ self.kwargs = kwargs + self.config = deepcopy(config) + self.data_sources = deepcopy(dict(data_sources=data_sources)) - config = deepcopy(config) - if data_sources: - config = dict( - data_sources=dict( - sources=data_sources, - input=config, - ) - ) - self.config = config - self.action_path = ["input"] - - @trace_select - def select(self, group_of_dates: GroupOfDates) -> Any: - """Select data based on the group of dates. - - Parameters - ---------- - group_of_dates : GroupOfDates - Group of dates to select data for. - - Returns - ------- - Any - Selected data. 
- """ - from .action import ActionContext + @cached_property + def action(self) -> "Recipe": + """Returns the action object based on the configuration.""" + from .action import Recipe from .action import action_factory - """This changes the context.""" - context = ActionContext(**self.kwargs) - action = action_factory(self.config, context, self.action_path) - return action.select(group_of_dates) - - def __repr__(self) -> str: - """Return a string representation of the InputBuilder. - - Returns - ------- - str - String representation. - """ - from .action import ActionContext - from .action import action_factory + sources = action_factory(self.data_sources, "data_sources") + input = action_factory(self.config, "input") - context = ActionContext(**self.kwargs) - a = action_factory(self.config, context, self.action_path) - return repr(a) + return Recipe(input, sources) - def _trace_select(self, group_of_dates: GroupOfDates) -> str: - """Trace the select operation. + def select(self, argument) -> Any: + """Select data based on the group of dates. Parameters ---------- - group_of_dates : GroupOfDates + argument : GroupOfDates Group of dates to select data for. Returns ------- - str - Trace string. + Any + Selected data. """ - return f"InputBuilder({group_of_dates})" + context = FieldContext(argument, **self.kwargs) + return context.create_result(self.action(context, argument)) + + +def build_input(config: dict, data_sources: dict | list, **kwargs: Any) -> InputBuilder: + """Build an InputBuilder instance. + + Parameters + ---------- + config : dict + Configuration dictionary. + data_sources : Union[dict, list] + Data sources. + **kwargs : Any + Additional keyword arguments. + + Returns + ------- + InputBuilder + An instance of InputBuilder. + """ + return InputBuilder(config, data_sources, **kwargs) diff --git a/src/anemoi/datasets/create/input/action.py b/src/anemoi/datasets/create/input/action.py index 1a7eca5c6..7808ae717 100644 --- a/src/anemoi/datasets/create/input/action.py +++ b/src/anemoi/datasets/create/input/action.py @@ -7,251 +7,311 @@ # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. -import json import logging -from copy import deepcopy -from typing import Any -from earthkit.data.core.order import build_remapping - -from ...dates.groups import GroupOfDates -from .context import Context -from .template import substitute +from anemoi.datasets.dates import DatesProvider LOG = logging.getLogger(__name__) class Action: - """Represents an action to be performed within a given context. - - Attributes - ---------- - context : ActionContext - The context in which the action exists. - kwargs : Dict[str, Any] - Additional keyword arguments. - args : Any - Additional positional arguments. - action_path : List[str] - The action path. + """An "Action" represents a single operation described in the yaml configuration, e.g. a source, a filter, + pipe, join, etc. + + See :ref:`operations` for more details. + """ - def __init__( - self, context: "ActionContext", action_path: list[str], /, *args: Any, **kwargs: dict[str, Any] - ) -> None: - """Initialize an Action instance. - - Parameters - ---------- - context : ActionContext - The context in which the action exists. - action_path : List[str] - The action path. - args : Any - Additional positional arguments. - kwargs : Dict[str, Any] - Additional keyword arguments. 
- """ - if "args" in kwargs and "kwargs" in kwargs: - """We have: - args = [] - kwargs = {args: [...], kwargs: {...}} - move the content of kwargs to args and kwargs. - """ - assert len(kwargs) == 2, (args, kwargs) - assert not args, (args, kwargs) - args = kwargs.pop("args") - kwargs = kwargs.pop("kwargs") - - assert isinstance(context, ActionContext), type(context) - self.context = context - self.kwargs = kwargs - self.args = args - self.action_path = action_path - - @classmethod - def _short_str(cls, x: str) -> str: - """Shorten the string representation if it exceeds 1000 characters. - - Parameters - ---------- - x : str - The string to shorten. - - Returns - ------- - str - The shortened string. - """ - x = str(x) - if len(x) < 1000: - return x - return x[:1000] + "..." - - def _repr(self, *args: Any, _indent_: str = "\n", _inline_: str = "", **kwargs: Any) -> str: - """Generate a string representation of the Action instance. - - Parameters - ---------- - args : Any - Additional positional arguments. - _indent_ : str, optional - The indentation string, by default "\n". - _inline_ : str, optional - The inline string, by default "". - kwargs : Any - Additional keyword arguments. - - Returns - ------- - str - The string representation. - """ - more = ",".join([str(a)[:5000] for a in args]) - more += ",".join([f"{k}={v}"[:5000] for k, v in kwargs.items()]) - - more = more[:5000] - txt = f"{self.__class__.__name__}: {_inline_}{_indent_}{more}" - if _indent_: - txt = txt.replace("\n", "\n ") - return txt - - def __repr__(self) -> str: - """Return the string representation of the Action instance. - - Returns - ------- - str - The string representation. - """ - return self._repr() - - def select(self, dates: object, **kwargs: Any) -> None: - """Select dates for the action. - - Parameters - ---------- - dates : object - The dates to select. - kwargs : Any - Additional keyword arguments. - """ - self._raise_not_implemented() - - def _raise_not_implemented(self) -> None: - """Raise a NotImplementedError indicating the method is not implemented.""" - raise NotImplementedError(f"Not implemented in {self.__class__.__name__}") - - def _trace_select(self, group_of_dates: GroupOfDates) -> str: - """Trace the selection of a group of dates. - - Parameters - ---------- - group_of_dates : GroupOfDates - The group of dates to trace. - - Returns - ------- - str - The trace string. - """ - return f"{self.__class__.__name__}({group_of_dates})" - - -class ActionContext(Context): - """Represents the context in which an action is performed. - - Attributes - ---------- - order_by : str - The order by criteria. - flatten_grid : bool - Whether to flatten the grid. - remapping : Dict[str, Any] - The remapping configuration. - use_grib_paramid : bool - Whether to use GRIB parameter ID. + def __init__(self, config, *path): + self.config = config + self.path = path + assert path[0] in ( + "input", + "data_sources", + ), f"{self.__class__.__name__}: path must start with 'input' or 'data_sources': {path}" + + +class Concat(Action): + """The Concat contruct is used to concat different actions that are responsible + for delivery fields for different dates. + + See :ref:`building-concat` for more details. + + .. block-code:: yaml + + input: + concat: + - dates: + start: 2023-01-01 + end: 2023-01-31 + frequency: 1d + action: # some action + ... 
+ + - dates: + start: 2023-02-01 + end: 2023-02-28 + frequency: 1d + action: # some action + """ - def __init__(self, /, order_by: str, flatten_grid: bool, remapping: dict[str, Any], use_grib_paramid: bool) -> None: - """Initialize an ActionContext instance. - - Parameters - ---------- - order_by : str - The order by criteria. - flatten_grid : bool - Whether to flatten the grid. - remapping : Dict[str, Any] - The remapping configuration. - use_grib_paramid : bool - Whether to use GRIB parameter ID. - """ - super().__init__() - self.order_by = order_by - self.flatten_grid = flatten_grid - self.remapping = build_remapping(remapping) - self.use_grib_paramid = use_grib_paramid - - -def action_factory(config: dict[str, Any], context: ActionContext, action_path: list[str]) -> Action: - """Factory function to create an Action instance based on the configuration. - - Parameters - ---------- - config : Dict[str, Any] - The action configuration. - context : ActionContext - The context in which the action exists. - action_path : List[str] - The action path. - - Returns - ------- - Action - The created Action instance. + def __init__(self, config, *path): + super().__init__(config, *path, "concat") + + assert isinstance(config, list), f"Value must be a dict {list}" + + self.choices = [] + + for i, item in enumerate(config): + + dates = item["dates"] + filtering_dates = DatesProvider.from_config(**dates) + action = action_factory({k: v for k, v in item.items() if k != "dates"}, *self.path, str(i)) + self.choices.append((filtering_dates, action)) + + def __repr__(self): + return f"Concat({self.choices})" + + def __call__(self, context, argument): + + results = context.empty_result() + + for filtering_dates, action in self.choices: + dates = context.matching_dates(filtering_dates, argument) + if len(dates) == 0: + continue + results += action(context, dates) + + return context.register(results, self.path) + + +class Join(Action): + """Implement the join operation to combine results from multiple actions. + + See :ref:`building-join` for more details. + + .. block-code:: yaml + + input: + join: + - grib: + ... + + - netcdf: # some other action + ... + """ - from .concat import ConcatAction - from .data_sources import DataSourcesAction - from .function import FunctionAction - from .join import JoinAction - from .pipe import PipeAction - from .repeated_dates import RepeatedDatesAction - - # from .data_sources import DataSourcesAction - - assert isinstance(context, Context), (type, context) - if not isinstance(config, dict): - raise ValueError(f"Invalid input config {config}") - if len(config) != 1: - print(json.dumps(config, indent=2, default=str)) - raise ValueError(f"Invalid input config. 
Expecting dict with only one key, got {list(config.keys())}") - - config = deepcopy(config) - key = list(config.keys())[0] - - if isinstance(config[key], list): - args, kwargs = config[key], {} - elif isinstance(config[key], dict): - args, kwargs = [], config[key] - else: - raise ValueError(f"Invalid input config {config[key]} ({type(config[key])}") - - cls = { - "data_sources": DataSourcesAction, - "data-sources": DataSourcesAction, - "concat": ConcatAction, - "join": JoinAction, - "pipe": PipeAction, - "function": FunctionAction, - "repeated_dates": RepeatedDatesAction, - "repeated-dates": RepeatedDatesAction, - }.get(key) - - if cls is None: - from ..sources import create_source - - source = create_source(None, substitute(context, config)) - return FunctionAction(context, action_path + [key], key, source) - - return cls(context, action_path + [key], *args, **kwargs) + + def __init__(self, config, *path): + super().__init__(config, *path, "join") + + assert isinstance(config, list), f"Value of Join Action must be a list, got: {config}" + + self.actions = [action_factory(item, *self.path, str(i)) for i, item in enumerate(config)] + + def __repr__(self): + return f"Join({self.actions})" + + def __call__(self, context, argument): + results = context.empty_result() + + for action in self.actions: + results += action(context, argument) + + return context.register(results, self.path) + + +class Pipe(Action): + """Implement the pipe operation to chain results from a + source through multiple filters. + + See :ref:`building-pipe` for more details. + + .. block-code:: yaml + + input: + pipe: + - grib: + ... + + - rename: + ... + + """ + + def __init__(self, config, *path): + assert isinstance(config, list), f"Value of Pipe Action must be a list, got {config}" + super().__init__(config, *path, "pipe") + self.actions = [action_factory(item, *self.path, str(i)) for i, item in enumerate(config)] + + def __repr__(self): + return f"Pipe({self.actions})" + + def __call__(self, context, argument): + result = context.empty_result() + + for i, action in enumerate(self.actions): + if i == 0: + result = action(context, argument) + else: + result = action(context, result) + + return context.register(result, self.path) + + +class Function(Action): + """Base class for sources and filters.""" + + def __init__(self, config, *path): + super().__init__(config, *path, self.name) + + def __call__(self, context, argument): + + config = context.resolve(self.config) # Substitute the ${} variables in the config + + config["_type"] = self.name # Find a better way to do this + + source = self.create_object(context, config) + + return context.register(self.call_object(context, source, argument), self.path) + + +class DatasetSourceMixin: + """Mixin class for sources defined in anemoi-datasets""" + + def create_object(self, context, config): + from anemoi.datasets.create.sources import create_source as create_datasets_source + + return create_datasets_source(context, config) + + def call_object(self, context, source, argument): + return source.execute(context.source_argument(argument)) + + +class TransformSourceMixin: + """Mixin class for sources defined in anemoi-transform""" + + def create_object(self, context, config): + from anemoi.transform.sources import create_source as create_transform_source + + return create_transform_source(context, config) + + +class TransformFilterMixin: + """Mixin class for filters defined in anemoi-transform""" + + def create_object(self, context, config): + from anemoi.transform.filters import 
create_filter as create_transform_filter + + def call_object(self, context, filter, argument): + return filter.forward(context.filter_argument(argument)) + + +class FilterFunction(Function): + """Action to call a filter on the argument (e.g. rename, regrid, etc.).""" + + def __call__(self, context, argument): + return self.call(context, argument, context.filter_argument) + + +def _make_name(name, what): + name = name.replace("_", "-") + name = "".join(x.title() for x in name.split("-")) + return name + what.title() + + +def new_source(name, mixin): + return type( + _make_name(name, "source"), + (Function, mixin), + {"name": name}, + ) + + +def new_filter(name, mixin): + return type( + _make_name(name, "filter"), + (Function, mixin), + {"name": name}, + ) + + +class DataSources(Action): + """Action that holds the named data_sources of a recipe and registers their results.""" + + def __init__(self, config, *path): + super().__init__(config, *path) + assert isinstance(config, (dict, list)), f"Invalid config type: {type(config)}" + if isinstance(config, dict): + self.sources = {k: action_factory(v, *path, k) for k, v in config.items()} + else: + self.sources = {i: action_factory(v, *path, str(i)) for i, v in enumerate(config)} + + def __call__(self, context, argument): + for name, source in self.sources.items(): + context.register(source(context, argument), self.path + (name,)) + + +class Recipe(Action): + """Action that represents a recipe (i.e. a sequence of data_sources and input).""" + + def __init__(self, input, data_sources): + self.input = input + self.data_sources = data_sources + + def __call__(self, context, argument): + # Load data_sources + self.data_sources(context, argument) + return self.input(context, argument) + + +KLASS = { + "concat": Concat, + "join": Join, + "pipe": Pipe, + "data-sources": DataSources, +} + + +LEN_KLASS = len(KLASS) + + +def make(key, config, *path): + + if LEN_KLASS == len(KLASS): + + # Load plugins + from anemoi.transform.filters import filter_registry as transform_filter_registry + from anemoi.transform.sources import source_registry as transform_source_registry + + from anemoi.datasets.create.sources import source_registry as dataset_source_registry + + # Register sources, local first + for name in dataset_source_registry.registered: + if name not in KLASS: + KLASS[name.replace("_", "-")] = new_source(name, DatasetSourceMixin) + + for name in transform_source_registry.registered: + if name not in KLASS: + KLASS[name.replace("_", "-")] = new_source(name, TransformSourceMixin) + + # Register filters + for name in transform_filter_registry.registered: + if name not in KLASS: + KLASS[name.replace("_", "-")] = new_filter(name, TransformFilterMixin) + + return KLASS[key.replace("_", "-")](config, *path) + + +def action_factory(data, *path): + + assert len(path) > 0, f"Path must contain at least one element {path}" + assert path[0] in ("input", "data_sources") + + assert isinstance(data, dict), f"Input data must be a dictionary, got {type(data)}" + assert len(data) == 1, f"Input data must contain exactly one key-value pair {data} {'.'.join(x for x in path)}" + + key, value = next(iter(data.items())) + return make(key, value, *path) diff --git a/src/anemoi/datasets/create/input/concat.py b/src/anemoi/datasets/create/input/concat.py deleted file mode 100644 index 90dbac15a..000000000 --- a/src/anemoi/datasets/create/input/concat.py +++ /dev/null @@ -1,161 +0,0 @@ -# (C) Copyright 2024 Anemoi contributors. 
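For orientation, a minimal sketch of how the rewritten factory is expected to be driven; the `mars` requests below are a hypothetical recipe fragment and assume a `mars` source is registered in one of the source registries loaded by `make()`:

```python
# Minimal sketch of building an action tree from a recipe "input" block.
# The "mars" requests are hypothetical and assume a "mars" source is
# available in one of the registries loaded lazily by make().
from anemoi.datasets.create.input.action import action_factory

config = {
    "join": [
        {"mars": {"param": ["2t"], "levtype": "sfc"}},
        {"mars": {"param": ["q", "t"], "levtype": "pl", "level": [500, 850]}},
    ]
}

# The first element of the path must be "input" or "data_sources".
action = action_factory(config, "input")
print(action)  # Join([...]) -- each child is a dynamically created source action
```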
-# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. - -import logging -from copy import deepcopy -from functools import cached_property -from typing import Any - -from earthkit.data import FieldList - -from anemoi.datasets.dates import DatesProvider - -from ...dates.groups import GroupOfDates -from .action import Action -from .action import action_factory -from .empty import EmptyResult -from .misc import _tidy -from .misc import assert_fieldlist -from .result import Result -from .template import notify_result -from .trace import trace_datasource -from .trace import trace_select - -LOG = logging.getLogger(__name__) - - -class ConcatResult(Result): - """Represents the result of concatenating multiple results.""" - - def __init__( - self, - context: object, - action_path: list[str], - group_of_dates: GroupOfDates, - results: list[Result], - **kwargs: Any, - ) -> None: - """Initializes a ConcatResult instance. - - Parameters - ---------- - context : object - The context object. - action_path : List[str] - The action path. - group_of_dates : GroupOfDates - The group of dates. - results : List[Result] - The list of results. - kwargs : Any - Additional keyword arguments. - """ - super().__init__(context, action_path, group_of_dates) - self.results = [r for r in results if not r.empty] - - @cached_property - @assert_fieldlist - @notify_result - @trace_datasource - def datasource(self) -> FieldList: - """Returns the concatenated datasource from all results.""" - ds = EmptyResult(self.context, self.action_path, self.group_of_dates).datasource - for i in self.results: - ds += i.datasource - return _tidy(ds) - - @property - def variables(self) -> list[str]: - """Returns the list of variables, ensuring all results have the same variables.""" - variables = None - for f in self.results: - if f.empty: - continue - if variables is None: - variables = f.variables - assert variables == f.variables, (variables, f.variables) - assert variables is not None, self.results - return variables - - def __repr__(self) -> str: - """Returns a string representation of the ConcatResult instance. - - Returns - ------- - str - A string representation of the ConcatResult instance. - """ - content = "\n".join([str(i) for i in self.results]) - return self._repr(content) - - -class ConcatAction(Action): - """Represents an action that concatenates multiple actions based on their dates.""" - - def __init__(self, context: object, action_path: list[str], *configs: dict[str, Any]) -> None: - """Initializes a ConcatAction instance. - - Parameters - ---------- - context : object - The context object. - action_path : List[str] - The action path. - configs : Dict[str, Any] - The configuration dictionaries. 
- """ - super().__init__(context, action_path, *configs) - parts = [] - for i, cfg in enumerate(configs): - if "dates" not in cfg: - raise ValueError(f"Missing 'dates' in {cfg}") - cfg = deepcopy(cfg) - dates_cfg = cfg.pop("dates") - assert isinstance(dates_cfg, dict), dates_cfg - filtering_dates = DatesProvider.from_config(**dates_cfg) - action = action_factory(cfg, context, action_path + [str(i)]) - parts.append((filtering_dates, action)) - self.parts = parts - - def __repr__(self) -> str: - """Returns a string representation of the ConcatAction instance. - - Returns - ------- - str - A string representation of the ConcatAction instance. - """ - content = "\n".join([str(i) for i in self.parts]) - return self._repr(content) - - @trace_select - def select(self, group_of_dates: GroupOfDates) -> ConcatResult | EmptyResult: - """Selects the concatenated result for the given group of dates. - - Parameters - ---------- - group_of_dates : GroupOfDates - The group of dates. - - Returns - ------- - Union[ConcatResult, EmptyResult] - The concatenated result or an empty result. - """ - from anemoi.datasets.dates.groups import GroupOfDates - - results = [] - for filtering_dates, action in self.parts: - newdates = GroupOfDates(sorted(set(group_of_dates) & set(filtering_dates)), group_of_dates.provider) - if newdates: - results.append(action.select(newdates)) - if not results: - return EmptyResult(self.context, self.action_path, group_of_dates) - - return ConcatResult(self.context, self.action_path, group_of_dates, results) diff --git a/src/anemoi/datasets/create/input/context.py b/src/anemoi/datasets/create/input/context.py deleted file mode 100644 index 5b17afa68..000000000 --- a/src/anemoi/datasets/create/input/context.py +++ /dev/null @@ -1,86 +0,0 @@ -# (C) Copyright 2024 Anemoi contributors. -# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. - -import logging -import textwrap -from typing import Any - -from anemoi.utils.humanize import plural - -from .trace import step -from .trace import trace - -LOG = logging.getLogger(__name__) - - -class Context: - """Class to handle the build context in the dataset creation process.""" - - def __init__(self) -> None: - """Initializes a Context instance.""" - # used_references is a set of reference paths that will be needed - self.used_references = set() - # results is a dictionary of reference path -> obj - self.results = {} - - def will_need_reference(self, key: list | tuple) -> None: - """Marks a reference as needed. - - Parameters - ---------- - key : Union[List, Tuple] - The reference key. - """ - assert isinstance(key, (list, tuple)), key - key = tuple(key) - self.used_references.add(key) - - def notify_result(self, key: list | tuple, result: Any) -> None: - """Notifies that a result is available for a reference. - - Parameters - ---------- - key : Union[List, Tuple] - The reference key. - result : Any - The result object. 
- """ - trace( - "🎯", - step(key), - "notify result", - textwrap.shorten(repr(result).replace(",", ", "), width=40), - plural(len(result), "field"), - ) - assert isinstance(key, (list, tuple)), key - key = tuple(key) - if key in self.used_references: - if key in self.results: - raise ValueError(f"Duplicate result {key}") - self.results[key] = result - - def get_result(self, key: list | tuple) -> Any: - """Retrieves the result for a given reference. - - Parameters - ---------- - key : Union[List, Tuple] - The reference key. - - Returns - ------- - Any - The result for the given reference. - """ - assert isinstance(key, (list, tuple)), key - key = tuple(key) - if key in self.results: - return self.results[key] - all_keys = sorted(list(self.results.keys())) - raise ValueError(f"Cannot find result {key} in {all_keys}") diff --git a/src/anemoi/datasets/create/input/context/__init__.py b/src/anemoi/datasets/create/input/context/__init__.py new file mode 100644 index 000000000..89df7a727 --- /dev/null +++ b/src/anemoi/datasets/create/input/context/__init__.py @@ -0,0 +1,71 @@ +# (C) Copyright 2025 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +import logging +from abc import ABC +from abc import abstractmethod +from typing import Any + +LOG = logging.getLogger(__name__) + + +class Context(ABC): + """Context for building input data.""" + + def __init__(self, /, argument: Any) -> None: + self.results = {} + self.cache = {} + self.argument = argument + + def trace(self, emoji, *message) -> None: + + print(f"{emoji}: {message}") + + def register(self, data: Any, path: list[str]) -> Any: + + if not path: + return data + + assert path[0] in ("input", "data_sources"), path + + LOG.info(f"Registering data at path: {path}") + self.results[tuple(path)] = data + return data + + def resolve(self, config): + config = config.copy() + + for key, value in list(config.items()): + if isinstance(value, str) and value.startswith("${") and value.endswith("}"): + path = tuple(value[2:-1].split(".")) + if path in self.results: + config[key] = self.results[path] + else: + LOG.warning(f"Path not found {path}") + for p in sorted(self.results): + LOG.info(f" Available paths: {p}") + raise KeyError(f"Path {path} not found in results: {self.results.keys()}") + + return config + + def create_source(self, config: Any, *path) -> Any: + from anemoi.datasets.create.input.action import action_factory + + if not isinstance(config, dict): + # It is already a result (e.g. ekd.FieldList), loaded from ${a.b.c} + # TODO: something more elegant + return lambda *args, **kwargs: config + + return action_factory(config, *path) + + @abstractmethod + def empty_result(self) -> Any: ... + + @abstractmethod + def create_result(self, data: Any) -> Any: ... diff --git a/src/anemoi/datasets/create/input/context/field.py b/src/anemoi/datasets/create/input/context/field.py new file mode 100644 index 000000000..1dd01340e --- /dev/null +++ b/src/anemoi/datasets/create/input/context/field.py @@ -0,0 +1,54 @@ +# (C) Copyright 2025 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. 
+# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + + +from typing import Any + +from earthkit.data.core.order import build_remapping + +from ..result.field import FieldResult +from . import Context + + +class FieldContext(Context): + + def __init__( + self, + /, + argument: Any, + order_by: str, + flatten_grid: bool, + remapping: dict[str, Any], + use_grib_paramid: bool, + ) -> None: + super().__init__(argument) + self.order_by = order_by + self.flatten_grid = flatten_grid + self.remapping = build_remapping(remapping) + self.use_grib_paramid = use_grib_paramid + self.partial_ok = False + + def empty_result(self) -> Any: + import earthkit.data as ekd + + return ekd.from_source("empty") + + def source_argument(self, argument: Any) -> Any: + return argument # .dates + + def filter_argument(self, argument: Any) -> Any: + return argument + + def create_result(self, data): + return FieldResult(self, data) + + def matching_dates(self, filtering_dates, group_of_dates: Any) -> Any: + from anemoi.datasets.dates.groups import GroupOfDates + + return GroupOfDates(sorted(set(group_of_dates) & set(filtering_dates)), group_of_dates.provider) diff --git a/src/anemoi/datasets/create/input/data_sources.py b/src/anemoi/datasets/create/input/data_sources.py index 09b411d6e..31bf3d8cc 100644 --- a/src/anemoi/datasets/create/input/data_sources.py +++ b/src/anemoi/datasets/create/input/data_sources.py @@ -17,7 +17,7 @@ from .action import Action from .action import action_factory from .misc import _tidy -from .result import Result +from .result.field import Result LOG = logging.getLogger(__name__) @@ -55,6 +55,7 @@ def __init__( self.sources = [action_factory(config, context, ["data_sources"] + [a_path]) for a_path, config in configs] self.input = action_factory(input, context, ["input"]) + self.names = [a_path for a_path, config in configs] def select(self, group_of_dates: GroupOfDates) -> "DataSourcesResult": """Selects the data sources result for the given group of dates. diff --git a/src/anemoi/datasets/create/input/empty.py b/src/anemoi/datasets/create/input/empty.py deleted file mode 100644 index 935a3e677..000000000 --- a/src/anemoi/datasets/create/input/empty.py +++ /dev/null @@ -1,53 +0,0 @@ -# (C) Copyright 2024 Anemoi contributors. -# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. - -import logging -from functools import cached_property - -from earthkit.data import FieldList - -from .misc import assert_fieldlist -from .result import Result -from .trace import trace_datasource - -LOG = logging.getLogger(__name__) - - -class EmptyResult(Result): - """Class to represent an empty result in the dataset creation process.""" - - empty = True - - def __init__(self, context: object, action_path: list, dates: object) -> None: - """Initializes an EmptyResult instance. - - Parameters - ---------- - context : object - The context object. - action_path : list - The action path. - dates : object - The dates object. 
- """ - super().__init__(context, action_path + ["empty"], dates) - - @cached_property - @assert_fieldlist - @trace_datasource - def datasource(self) -> FieldList: - """Returns an empty datasource.""" - from earthkit.data import from_source - - return from_source("empty") - - @property - def variables(self) -> list[str]: - """Returns an empty list of variables.""" - return [] diff --git a/src/anemoi/datasets/create/input/filter.py b/src/anemoi/datasets/create/input/filter.py deleted file mode 100644 index d6ea4d75c..000000000 --- a/src/anemoi/datasets/create/input/filter.py +++ /dev/null @@ -1,117 +0,0 @@ -# (C) Copyright 2024 Anemoi contributors. -# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. - -import logging -from functools import cached_property -from typing import Any - -from earthkit.data import FieldList - -from .function import FunctionContext -from .misc import _tidy -from .misc import assert_fieldlist -from .step import StepAction -from .step import StepResult -from .template import notify_result -from .trace import trace_datasource - -LOG = logging.getLogger(__name__) - - -class FilterStepResult(StepResult): - @property - @notify_result - @assert_fieldlist - @trace_datasource - def datasource(self) -> FieldList: - """Returns the filtered datasource.""" - ds: FieldList = self.upstream_result.datasource - ds = ds.sel(**self.action.kwargs) - return _tidy(ds) - - -class FilterStepAction(StepAction): - """Represents an action to filter a step result.""" - - result_class: type[FilterStepResult] = FilterStepResult - - -class StepFunctionResult(StepResult): - @cached_property - @assert_fieldlist - @notify_result - @trace_datasource - def datasource(self) -> FieldList: - """Returns the datasource after applying the function.""" - - self.action.filter.context = FunctionContext(self) - try: - return _tidy( - self.action.filter.execute( - self.upstream_result.datasource, - *self.action.args[1:], - **self.action.kwargs, - ) - ) - - except Exception: - LOG.error(f"Error in {self.action.name}", exc_info=True) - raise - - def _trace_datasource(self, *args: Any, **kwargs: Any) -> str: - """Traces the datasource for the given arguments. - - Parameters - ---------- - *args : Any - The arguments. - **kwargs : Any - The keyword arguments. - - Returns - ------- - str - A string representation of the traced datasource. - """ - return f"{self.action.name}({self.group_of_dates})" - - -class FunctionStepAction(StepAction): - """Represents an action to apply a function to a step result.""" - - result_class: type[StepFunctionResult] = StepFunctionResult - - def __init__( - self, - context: object, - action_path: list, - previous_step: StepAction, - name: str, - filter: Any, - *args: Any, - **kwargs: Any, - ) -> None: - """Initializes a FunctionStepAction instance. - - Parameters - ---------- - context : object - The context object. - action_path : list - The action path. - previous_step : StepAction - The previous step action. - *args : Any - Additional arguments. - **kwargs : Any - Additional keyword arguments. 
- """ - super().__init__(context, action_path, previous_step, *args, **kwargs) - self.name = name - self.filter = filter diff --git a/src/anemoi/datasets/create/input/function.py b/src/anemoi/datasets/create/input/function.py deleted file mode 100644 index 51353b8e6..000000000 --- a/src/anemoi/datasets/create/input/function.py +++ /dev/null @@ -1,232 +0,0 @@ -# (C) Copyright 2024 Anemoi contributors. -# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. - -import logging -from functools import cached_property -from typing import Any - -from earthkit.data import FieldList - -from ...dates.groups import GroupOfDates -from .action import Action -from .misc import _tidy -from .misc import assert_fieldlist -from .result import Result -from .template import notify_result -from .template import substitute -from .trace import trace -from .trace import trace_datasource -from .trace import trace_select - -LOG = logging.getLogger(__name__) - - -class FunctionContext: - """A FunctionContext is passed to all functions, it will be used to pass information - to the functions from the other actions and filters and results. - """ - - def __init__(self, owner: Result) -> None: - """Initializes a FunctionContext instance. - - Parameters - ---------- - owner : object - The owner object. - """ - self.owner = owner - self.use_grib_paramid: bool = owner.context.use_grib_paramid - - def trace(self, emoji: str, *args: Any) -> None: - """Traces the given arguments with an emoji. - - Parameters - ---------- - emoji : str - The emoji to use. - *args : Any - The arguments to trace. - """ - trace(emoji, *args) - - def info(self, *args: Any, **kwargs: Any) -> None: - """Logs an info message. - - Parameters - ---------- - *args : Any - The arguments for the log message. - **kwargs : Any - The keyword arguments for the log message. - """ - LOG.info(*args, **kwargs) - - @property - def dates_provider(self) -> object: - """Returns the dates provider.""" - return self.owner.group_of_dates.provider - - @property - def partial_ok(self) -> bool: - """Returns whether partial results are acceptable.""" - return self.owner.group_of_dates.partial_ok - - def get_result(self, *args, **kwargs) -> Any: - return self.owner.context.get_result(*args, **kwargs) - - -class FunctionAction(Action): - """Represents an action that executes a function. - - Attributes - ---------- - name : str - The name of the function. - """ - - def __init__(self, context: object, action_path: list, _name: str, source, **kwargs: dict[str, Any]) -> None: - """Initializes a FunctionAction instance. - - Parameters - ---------- - context : object - The context object. - action_path : list - The action path. - _name : str - The name of the function. - **kwargs : Dict[str, Any] - Additional keyword arguments. - """ - super().__init__(context, action_path, **kwargs) - self.name: str = _name - self.source = source - - @trace_select - def select(self, group_of_dates: GroupOfDates) -> "FunctionResult": - """Selects the function result for the given group of dates. - - Parameters - ---------- - group_of_dates : GroupOfDates - The group of dates. - - Returns - ------- - FunctionResult - The function result instance. 
- """ - return FunctionResult(self.context, self.action_path, group_of_dates, action=self) - - def __repr__(self) -> str: - """Returns a string representation of the FunctionAction instance.""" - content: str = "" - content += ",".join([self._short_str(a) for a in self.args]) - content += " ".join([self._short_str(f"{k}={v}") for k, v in self.kwargs.items()]) - content = self._short_str(content) - return self._repr(_inline_=content, _indent_=" ") - - def _trace_select(self, group_of_dates: GroupOfDates) -> str: - """Traces the selection of the function for the given group of dates. - - Parameters - ---------- - group_of_dates : GroupOfDates - The group of dates. - - Returns - ------- - str - The trace string. - """ - return f"{self.name}({group_of_dates})" - - -class FunctionResult(Result): - """Represents the result of executing a function. - - Attributes - ---------- - action : Action - The action instance. - args : tuple - The positional arguments for the function. - kwargs : dict - The keyword arguments for the function. - """ - - def __init__(self, context: object, action_path: list, group_of_dates: GroupOfDates, action: Action) -> None: - """Initializes a FunctionResult instance. - - Parameters - ---------- - context : object - The context object. - action_path : list - The action path. - group_of_dates : GroupOfDates - The group of dates. - action : Action - The action instance. - """ - super().__init__(context, action_path, group_of_dates) - assert isinstance(action, Action), type(action) - self.action: Action = action - - self.args, self.kwargs = substitute(context, (self.action.args, self.action.kwargs)) - - def _trace_datasource(self, *args: Any, **kwargs: Any) -> str: - """Traces the datasource for the given arguments. - - Parameters - ---------- - *args : Any - The arguments. - **kwargs : Any - The keyword arguments. - - Returns - ------- - str - The trace string. - """ - return f"{self.action.name}({self.group_of_dates})" - - @cached_property - @assert_fieldlist - @notify_result - @trace_datasource - def datasource(self) -> FieldList: - """Returns the datasource for the function result.""" - # args, kwargs = resolve(self.context, (self.args, self.kwargs)) - self.action.source.context = FunctionContext(self) - - return _tidy( - self.action.source.execute( - list(self.group_of_dates), # Will provide a list of datetime objects - ) - ) - - def __repr__(self) -> str: - """Returns a string representation of the FunctionResult instance.""" - try: - return f"{self.action.name}({self.group_of_dates})" - except Exception: - return f"{self.__class__.__name__}(unitialised)" - - @property - def function(self) -> None: - """Raises NotImplementedError as this property is not implemented. - - Raises - ------ - NotImplementedError - Always raised. - """ - raise NotImplementedError(f"Not implemented in {self.__class__.__name__}") diff --git a/src/anemoi/datasets/create/input/join.py b/src/anemoi/datasets/create/input/join.py deleted file mode 100644 index 122612c59..000000000 --- a/src/anemoi/datasets/create/input/join.py +++ /dev/null @@ -1,129 +0,0 @@ -# (C) Copyright 2024 Anemoi contributors. -# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. 
- -import logging -from functools import cached_property -from typing import Any - -from earthkit.data import FieldList - -from ...dates.groups import GroupOfDates -from .action import Action -from .action import action_factory -from .empty import EmptyResult -from .misc import _tidy -from .misc import assert_fieldlist -from .result import Result -from .template import notify_result -from .trace import trace_datasource -from .trace import trace_select - -LOG = logging.getLogger(__name__) - - -class JoinResult(Result): - """Represents a result that combines multiple results. - - Attributes - ---------- - context : object - The context object. - action_path : list - The action path. - group_of_dates : GroupOfDates - The group of dates. - results : List[Result] - The list of results. - """ - - def __init__( - self, context: object, action_path: list, group_of_dates: GroupOfDates, results: list[Result], **kwargs: Any - ) -> None: - """Initializes a JoinResult instance. - - Parameters - ---------- - context : object - The context object. - action_path : list - The action path. - group_of_dates : GroupOfDates - The group of dates. - results : List[Result] - The list of results. - """ - super().__init__(context, action_path, group_of_dates) - self.results: list[Result] = [r for r in results if not r.empty] - - @cached_property - @assert_fieldlist - @notify_result - @trace_datasource - def datasource(self) -> FieldList: - """Returns the combined datasource from all results.""" - ds: FieldList = EmptyResult(self.context, self.action_path, self.group_of_dates).datasource - for i in self.results: - ds += i.datasource - return _tidy(ds) - - def __repr__(self) -> str: - """Returns a string representation of the JoinResult instance.""" - content: str = "\n".join([str(i) for i in self.results]) - return self._repr(content) - - -class JoinAction(Action): - """Represents an action that combines multiple actions. - - Attributes - ---------- - context : object - The context object. - action_path : list - The action path. - actions : List[Action] - The list of actions. - """ - - def __init__(self, context: object, action_path: list, *configs: dict) -> None: - """Initializes a JoinAction instance. - - Parameters - ---------- - context : object - The context object. - action_path : list - The action path. - *configs : dict - The configuration dictionaries. - """ - super().__init__(context, action_path, *configs) - self.actions: list[Action] = [action_factory(c, context, action_path + [str(i)]) for i, c in enumerate(configs)] - - def __repr__(self) -> str: - """Returns a string representation of the JoinAction instance.""" - content: str = "\n".join([str(i) for i in self.actions]) - return self._repr(content) - - @trace_select - def select(self, group_of_dates: GroupOfDates) -> JoinResult: - """Selects the results for the given group of dates. - - Parameters - ---------- - group_of_dates : GroupOfDates - The group of dates. - - Returns - ------- - JoinResult - The combined result for the given group of dates. - """ - results: list[Result] = [a.select(group_of_dates) for a in self.actions] - return JoinResult(self.context, self.action_path, group_of_dates, results) diff --git a/src/anemoi/datasets/create/input/pipe.py b/src/anemoi/datasets/create/input/pipe.py deleted file mode 100644 index 6c9fea0df..000000000 --- a/src/anemoi/datasets/create/input/pipe.py +++ /dev/null @@ -1,66 +0,0 @@ -# (C) Copyright 2024 Anemoi contributors. 
-# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. - -import json -import logging -from typing import Any - -from .action import Action -from .action import action_factory -from .step import step_factory -from .trace import trace_select - -LOG = logging.getLogger(__name__) - - -class PipeAction(Action): - """A class to represent a pipeline of actions.""" - - def __init__(self, context: Any, action_path: list, *configs: dict) -> None: - """Initialize the PipeAction. - - Parameters - ---------- - context : Any - The context for the action. - action_path : list - The path of the action. - configs : dict - The configurations for the actions. - """ - super().__init__(context, action_path, *configs) - if len(configs) <= 1: - raise ValueError( - f"PipeAction requires at least two actions, got {len(configs)}\n{json.dumps(configs, indent=2)}" - ) - - current: Any = action_factory(configs[0], context, action_path + ["0"]) - for i, c in enumerate(configs[1:]): - current = step_factory(c, context, action_path + [str(i + 1)], previous_step=current) - self.last_step: Any = current - - @trace_select - def select(self, group_of_dates: Any) -> Any: - """Select data based on the group of dates. - - Parameters - ---------- - group_of_dates : Any - The group of dates to select data for. - - Returns - ------- - Any - The selected data. - """ - return self.last_step.select(group_of_dates) - - def __repr__(self) -> str: - """Return a string representation of the PipeAction.""" - return f"PipeAction({self.last_step})" diff --git a/src/anemoi/datasets/create/input/repeated_dates.py b/src/anemoi/datasets/create/input/repeated_dates.py index 3a358d343..ad46fe208 100644 --- a/src/anemoi/datasets/create/input/repeated_dates.py +++ b/src/anemoi/datasets/create/input/repeated_dates.py @@ -22,7 +22,7 @@ from .action import Action from .action import action_factory from .join import JoinResult -from .result import Result +from .result.field import Result from .trace import trace_select LOG = logging.getLogger(__name__) @@ -345,6 +345,8 @@ def __init__(self, context: Any, action_path: list[str], source: Any, mode: str, self.source: Any = action_factory(source, context, action_path + ["source"]) self.mapper: DateMapper = DateMapper.from_mode(mode, self.source, kwargs) + self.mode = mode + self.kwargs = kwargs @trace_select def select(self, group_of_dates: Any) -> JoinResult: diff --git a/src/anemoi/datasets/create/input/result/__init__.py b/src/anemoi/datasets/create/input/result/__init__.py new file mode 100644 index 000000000..03a00c51d --- /dev/null +++ b/src/anemoi/datasets/create/input/result/__init__.py @@ -0,0 +1,17 @@ +# (C) Copyright 2025 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. 
+ +import logging +from abc import ABC + +LOG = logging.getLogger(__name__) + + +class Result(ABC): + pass diff --git a/src/anemoi/datasets/create/input/result.py b/src/anemoi/datasets/create/input/result/field.py similarity index 87% rename from src/anemoi/datasets/create/input/result.py rename to src/anemoi/datasets/create/input/result/field.py index 043824050..083d2ffd7 100644 --- a/src/anemoi/datasets/create/input/result.py +++ b/src/anemoi/datasets/create/input/result/field.py @@ -22,9 +22,7 @@ from anemoi.utils.humanize import shorten_list from earthkit.data.core.order import build_remapping -from .action import ActionContext -from .trace import trace -from .trace import trace_datasource +from . import Result LOG = logging.getLogger(__name__) @@ -278,40 +276,22 @@ def sort(old_dic: DefaultDict[str, set]) -> dict[str, list[Any]]: return dict(param_level=params_levels, param_step=params_steps, area=area, grid=grid) -class Result: +class FieldResult(Result): """Class to represent the result of an action in the dataset creation process.""" empty: bool = False _coords_already_built: bool = False - def __init__(self, context: ActionContext, action_path: list[str], dates: Any) -> None: - """Initialize a Result instance. + def __init__(self, context: Any, datasource: Any) -> None: - Parameters - ---------- - context : ActionContext - The context in which the result exists. - action_path : list of str - The action path. - dates : Any - The dates associated with the result. - """ from anemoi.datasets.dates.groups import GroupOfDates - assert isinstance(dates, GroupOfDates), dates - - assert isinstance(context, ActionContext), type(context) - assert isinstance(action_path, list), action_path - self.context: Any = context - self.group_of_dates: Any = dates - self.action_path: list[str] = action_path - - @property - @trace_datasource - def datasource(self) -> Any: - """Retrieve the data source for the result.""" - self._raise_not_implemented() + self.datasource = datasource + self.group_of_dates = context.argument + assert isinstance( + self.group_of_dates, GroupOfDates + ), f"Expected group_of_dates to be a GroupOfDates, got {type(self.group_of_dates)}: {self.group_of_dates}" @property def data_request(self) -> dict[str, Any]: @@ -326,7 +306,7 @@ def get_cube(self) -> Any: Any The data cube. """ - trace("🧊", f"getting cube from {self.__class__.__name__}") + ds: Any = self.datasource remapping: Any = self.context.remapping @@ -519,66 +499,6 @@ def explain(self, ds: Any, *args: Any, remapping: Any, patches: Any) -> None: print() exit(1) - def _repr(self, *args: Any, _indent_: str = "\n", **kwargs: Any) -> str: - """Return the string representation of the Result instance. - - Parameters - ---------- - args : Any - Additional positional arguments. - _indent_ : str - Indentation string. - kwargs : Any - Additional keyword arguments. - - Returns - ------- - str - The string representation. - """ - more: str = ",".join([str(a)[:5000] for a in args]) - more += ",".join([f"{k}={v}"[:5000] for k, v in kwargs.items()]) - - dates: str = " no-dates" - if self.group_of_dates is not None: - dates = f" {len(self.group_of_dates)} dates" - dates += " (" - dates += "/".join(d.strftime("%Y-%m-%dT%H:%M") for d in self.group_of_dates) - if len(dates) > 100: - dates = dates[:100] + "..." 
- dates += ")" - - more = more[:5000] - txt: str = f"{self.__class__.__name__}:{dates}{_indent_}{more}" - if _indent_: - txt = txt.replace("\n", "\n ") - return txt - - def __repr__(self) -> str: - """Return the string representation of the Result instance.""" - return self._repr() - - def _raise_not_implemented(self) -> None: - """Raise a NotImplementedError indicating the method is not implemented.""" - raise NotImplementedError(f"Not implemented in {self.__class__.__name__}") - - def _trace_datasource(self, *args: Any, **kwargs: Any) -> str: - """Trace the data source for the result. - - Parameters - ---------- - args : Any - Additional positional arguments. - kwargs : Any - Additional keyword arguments. - - Returns - ------- - str - The trace string. - """ - return f"{self.__class__.__name__}({self.group_of_dates})" - def build_coords(self) -> None: """Build the coordinates for the result.""" if self._coords_already_built: diff --git a/src/anemoi/datasets/create/input/step.py b/src/anemoi/datasets/create/input/step.py deleted file mode 100644 index 031b44094..000000000 --- a/src/anemoi/datasets/create/input/step.py +++ /dev/null @@ -1,173 +0,0 @@ -# (C) Copyright 2024 Anemoi contributors. -# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. - -import logging -from copy import deepcopy -from typing import Any - -from .action import Action -from .action import ActionContext -from .context import Context -from .result import Result -from .template import notify_result -from .trace import trace_datasource -from .trace import trace_select - -LOG = logging.getLogger(__name__) - - -class StepResult(Result): - """Represents the result of a step in the data processing pipeline.""" - - def __init__( - self, context: Context, action_path: list[str], group_of_dates: Any, action: Action, upstream_result: Result - ) -> None: - """Initialize a StepResult instance. - - Parameters - ---------- - context - The context in which the step is executed. - action_path - The path of actions leading to this step. - group_of_dates - The group of dates associated with this step. - action - The action associated with this step. - upstream_result - The result of the upstream step. - """ - super().__init__(context, action_path, group_of_dates) - assert isinstance(upstream_result, Result), type(upstream_result) - self.upstream_result: Result = upstream_result - self.action: Action = action - - @property - @notify_result - @trace_datasource - def datasource(self) -> Any: - """Retrieve the datasource associated with this step result.""" - raise NotImplementedError(f"Not implemented in {self.__class__.__name__}") - - -class StepAction(Action): - """Represents an action that is part of a step in the data processing pipeline.""" - - result_class: type[StepResult] | None = None - - def __init__( - self, context: ActionContext, action_path: list[str], previous_step: Any, *args: Any, **kwargs: Any - ) -> None: - """Initialize a StepAction instance. - - Parameters - ---------- - context - The context in which the action is executed. - action_path - The path of actions leading to this step. - previous_step - The previous step in the pipeline. 
- """ - super().__init__(context, action_path, *args, **kwargs) - self.previous_step: Any = previous_step - - @trace_select - def select(self, group_of_dates: Any) -> StepResult: - """Select the result for a given group of dates. - - Parameters - ---------- - group_of_dates - The group of dates to select the result for. - - Returns - ------- - unknown - The result of the step. - """ - return self.result_class( - self.context, - self.action_path, - group_of_dates, - self, - self.previous_step.select(group_of_dates), - ) - - def __repr__(self) -> str: - """Return a string representation of the StepAction instance. - - Returns - ------- - unknown - String representation of the instance. - """ - return self._repr(self.previous_step, _inline_=str(self.kwargs)) - - -def step_factory(config: dict[str, Any], context: ActionContext, action_path: list[str], previous_step: Any) -> Any: - """Factory function to create a step action based on the given configuration. - - Parameters - ---------- - config - The configuration dictionary for the step. - context - The context in which the step is executed. - action_path - The path of actions leading to this step. - previous_step - The previous step in the pipeline. - - Returns - ------- - unknown - An instance of a step action. - """ - - from .filter import FilterStepAction - from .filter import FunctionStepAction - - assert isinstance(context, Context), (type, context) - if not isinstance(config, dict): - raise ValueError(f"Invalid input config {config}") - - config = deepcopy(config) - assert len(config) == 1, config - - key = list(config.keys())[0] - cls = dict( - filter=FilterStepAction, - # rename=RenameAction, - # remapping=RemappingAction, - ).get(key) - - if isinstance(config[key], list): - args, kwargs = config[key], {} - - if isinstance(config[key], dict): - args, kwargs = [], config[key] - - if isinstance(config[key], str): - args, kwargs = [config[key]], {} - - if cls is not None: - return cls(context, action_path, previous_step, *args, **kwargs) - - # Try filters from transform filter registry - from anemoi.transform.filters import filter_registry as transform_filter_registry - - if transform_filter_registry.is_registered(key): - from ..filter import TransformFilter - - return FunctionStepAction( - context, action_path + [key], previous_step, key, TransformFilter(context, key, config) - ) - - raise ValueError(f"Unknown step action `{key}`") diff --git a/src/anemoi/datasets/create/input/template.py b/src/anemoi/datasets/create/input/template.py deleted file mode 100644 index 5effc7d7f..000000000 --- a/src/anemoi/datasets/create/input/template.py +++ /dev/null @@ -1,161 +0,0 @@ -# (C) Copyright 2024 Anemoi contributors. -# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. - - -import logging -import re -from abc import ABC -from abc import abstractmethod -from collections.abc import Callable -from functools import wraps -from typing import Any - -from .context import Context - -LOG = logging.getLogger(__name__) - - -def notify_result(method: Callable[..., Any]) -> Callable[..., Any]: - """Decorator to notify the context of the result of the method call. - - Parameters - ---------- - method : Callable[..., Any] - The method to wrap. 
- - Returns - ------- - Callable[..., Any] - The wrapped method. - """ - - @wraps(method) - def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any: - result: Any = method(self, *args, **kwargs) - self.context.notify_result(self.action_path, result) - return result - - return wrapper - - -class Substitution(ABC): - """Abstract base class for substitutions in templates.""" - - @abstractmethod - def resolve(self, context: Context) -> Any: - """Resolve the substitution using the given context. - - Parameters - ---------- - context : Context - The context to use for resolution. - - Returns - ------- - Any - The resolved value. - """ - pass - - -class Reference(Substitution): - """A class to represent a reference to another value in the context.""" - - def __init__(self, context: Any, action_path: list[str]) -> None: - """Initialize a Reference instance. - - Parameters - ---------- - context : Any - The context in which the reference exists. - action_path : list of str - The action path to resolve. - """ - self.context: Any = context - self.action_path: list[str] = action_path - - def resolve(self, context: Context) -> Any: - """Resolve the reference using the given context. - - Parameters - ---------- - context : Context - The context to use for resolution. - - Returns - ------- - Any - The resolved value. - """ - return context.get_result(self.action_path) - - -def resolve(context: Context, x: Any) -> Any: - """Recursively resolve substitutions in the given structure using the context. - - Parameters - ---------- - context : Context - The context to use for resolution. - x : Union[tuple, list, dict, Substitution, Any] - The structure to resolve. - - Returns - ------- - Any - The resolved structure. - """ - if isinstance(x, tuple): - return tuple([resolve(context, y) for y in x]) - - if isinstance(x, list): - return [resolve(context, y) for y in x] - - if isinstance(x, dict): - return {k: resolve(context, v) for k, v in x.items()} - - if isinstance(x, Substitution): - return x.resolve(context) - - return x - - -def substitute(context: Context, x: Any) -> Any: - """Recursively substitute references in the given structure using the context. - - Parameters - ---------- - context : Context - The context to use for substitution. - x : Union[tuple, list, dict, str, Any] - The structure to substitute. - - Returns - ------- - Any - The substituted structure. 
- """ - if isinstance(x, tuple): - return tuple([substitute(context, y) for y in x]) - - if isinstance(x, list): - return [substitute(context, y) for y in x] - - if isinstance(x, dict): - return {k: substitute(context, v) for k, v in x.items()} - - if not isinstance(x, str): - return x - - if re.match(r"^\${[\.\w\-]+}$", x): - path = x[2:-1].split(".") - context.will_need_reference(path) - return Reference(context, path) - - return x diff --git a/src/anemoi/datasets/create/sources/constants.py b/src/anemoi/datasets/create/sources/constants.py index c7a5b0fbf..104f24863 100644 --- a/src/anemoi/datasets/create/sources/constants.py +++ b/src/anemoi/datasets/create/sources/constants.py @@ -45,7 +45,7 @@ def constants(context: Any, dates: list[str], template: dict[str, Any], param: s if len(template) == 0: raise ValueError("Forcings template is empty.") - return from_source("forcings", source_or_dataset=template, date=dates, param=param) + return from_source("forcings", source_or_dataset=template, date=list(dates), param=param) execute: Any = constants diff --git a/src/anemoi/datasets/create/sources/forcings.py b/src/anemoi/datasets/create/sources/forcings.py index 0c1d62da4..bbafaa465 100644 --- a/src/anemoi/datasets/create/sources/forcings.py +++ b/src/anemoi/datasets/create/sources/forcings.py @@ -35,7 +35,7 @@ def forcings(context: Any, dates: list[str], template: str, param: str) -> Any: Loaded forcing data. """ context.trace("✅", f"from_source(forcings, {template}, {param}") - return from_source("forcings", source_or_dataset=template, date=dates, param=param) + return from_source("forcings", source_or_dataset=template, date=list(dates), param=param) execute = forcings diff --git a/src/anemoi/datasets/create/sources/legacy.py b/src/anemoi/datasets/create/sources/legacy.py index 8639fa824..4dbd481cd 100644 --- a/src/anemoi/datasets/create/sources/legacy.py +++ b/src/anemoi/datasets/create/sources/legacy.py @@ -14,8 +14,6 @@ from collections.abc import Callable from typing import Any -from anemoi.datasets.create.input.template import resolve - from ..source import Source from . import source_registry @@ -74,7 +72,7 @@ def __call__(self, execute: Callable) -> Callable: def execute_wrapper(self, dates) -> Any: """Wrapper method to call the execute function.""" - args, kwargs = resolve(self.context, (self.args, self.kwargs)) + args, kwargs = self.args, self.kwargs try: return execute(self.context, dates, *args, **kwargs) diff --git a/src/anemoi/datasets/create/sources/repeated_dates.py b/src/anemoi/datasets/create/sources/repeated_dates.py new file mode 100644 index 000000000..b56537979 --- /dev/null +++ b/src/anemoi/datasets/create/sources/repeated_dates.py @@ -0,0 +1,306 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. 
+ + +import logging +from collections import defaultdict +from collections.abc import Generator +from typing import Any + +import numpy as np +from anemoi.transform.fields import new_field_with_valid_datetime +from anemoi.transform.fields import new_fieldlist_from_list +from anemoi.utils.dates import as_datetime +from anemoi.utils.dates import frequency_to_timedelta + +from anemoi.datasets.create.source import Source +from anemoi.datasets.create.sources import source_registry + +LOG = logging.getLogger(__name__) + + +class Action: + pass + + +class Result: + pass + + +class DateMapper: + """A factory class to create DateMapper instances based on the given mode.""" + + @staticmethod + def from_mode(mode: str, source: Any, config: dict[str, Any]) -> "DateMapper": + """Create a DateMapper instance based on the given mode. + + Parameters + ---------- + mode : str + The mode to use for the DateMapper. + source : Any + The data source. + config : dict + Configuration parameters. + + Returns + ------- + DateMapper + An instance of DateMapper. + """ + MODES: dict = dict( + closest=DateMapperClosest, + climatology=DateMapperClimatology, + constant=DateMapperConstant, + ) + + if mode not in MODES: + raise ValueError(f"Invalid mode for DateMapper: {mode}") + + return MODES[mode](source, **config) + + +class DateMapperClosest(DateMapper): + """A DateMapper implementation that maps dates to the closest available dates.""" + + def __init__(self, source: Any, frequency: str = "1h", maximum: str = "30d", skip_all_nans: bool = False) -> None: + """Initialize DateMapperClosest. + + Parameters + ---------- + source : Any + The data source. + frequency : str + Frequency of the dates. + maximum : str + Maximum time delta. + skip_all_nans : bool + Whether to skip all NaN values. + """ + self.source: Any = source + self.maximum: Any = frequency_to_timedelta(maximum) + self.frequency: Any = frequency_to_timedelta(frequency) + self.skip_all_nans: bool = skip_all_nans + self.tried: set[Any] = set() + self.found: set[Any] = set() + + def transform(self, group_of_dates: Any) -> Generator[tuple[Any, Any], None, None]: + """Transform the group of dates to the closest available dates. + + Parameters + ---------- + group_of_dates : Any + The group of dates to transform. + + Returns + ------- + Generator[Tuple[Any, Any], None, None] + Transformed dates. 
+ """ + from anemoi.datasets.dates.groups import GroupOfDates + + asked_dates = list(group_of_dates) + if not asked_dates: + return [] + + to_try = set() + for date in asked_dates: + start = date + while start >= date - self.maximum: + to_try.add(start) + start -= self.frequency + + end = date + while end <= date + self.maximum: + to_try.add(end) + end += self.frequency + + to_try = sorted(to_try - self.tried) + info = {k: "no-data" for k in to_try} + + if not to_try: + LOG.warning(f"No new dates to try for {group_of_dates} in {self.source}") + # return [] + + if to_try: + result = self.source.select( + GroupOfDates( + sorted(to_try), + group_of_dates.provider, + partial_ok=True, + ) + ) + + cnt = 0 + for f in result.datasource: + cnt += 1 + # We could keep the fields in a dictionary, but we don't want to keep the fields in memory + date = as_datetime(f.metadata("valid_datetime")) + + if self.skip_all_nans: + if np.isnan(f.to_numpy()).all(): + LOG.warning(f"Skipping {date} because all values are NaN") + info[date] = "all-nans" + continue + + info[date] = "ok" + self.found.add(date) + + if cnt == 0: + raise ValueError(f"No data found for {group_of_dates} in {self.source}") + + self.tried.update(to_try) + + if not self.found: + for k, v in info.items(): + LOG.warning(f"{k}: {v}") + + raise ValueError(f"No matching data found for {asked_dates} in {self.source}") + + new_dates = defaultdict(list) + + for date in asked_dates: + best = None + for found_date in sorted(self.found): + delta = abs(date - found_date) + # With < we prefer the first date + # With <= we prefer the last date + if best is None or delta <= best[0]: + best = delta, found_date + new_dates[best[1]].append(date) + + for date, dates in new_dates.items(): + yield ( + GroupOfDates([date], group_of_dates.provider), + GroupOfDates(dates, group_of_dates.provider), + ) + + +class DateMapperClimatology(DateMapper): + """A DateMapper implementation that maps dates to specified climatology dates.""" + + def __init__(self, source: Any, year: int, day: int, hour: int | None = None) -> None: + """Initialize DateMapperClimatology. + + Parameters + ---------- + source : Any + The data source. + year : int + The year to map to. + day : int + The day to map to. + hour : Optional[int] + The hour to map to. + """ + self.year: int = year + self.day: int = day + self.hour: int | None = hour + + def transform(self, group_of_dates: Any) -> Generator[tuple[Any, Any], None, None]: + """Transform the group of dates to the specified climatology dates. + + Parameters + ---------- + group_of_dates : Any + The group of dates to transform. + + Returns + ------- + Generator[Tuple[Any, Any], None, None] + Transformed dates. + """ + from anemoi.datasets.dates.groups import GroupOfDates + + dates = list(group_of_dates) + if not dates: + return [] + + new_dates = defaultdict(list) + for date in dates: + new_date = date.replace(year=self.year, day=self.day) + if self.hour is not None: + new_date = new_date.replace(hour=self.hour, minute=0, second=0) + new_dates[new_date].append(date) + + for date, dates in new_dates.items(): + yield ( + GroupOfDates([date], group_of_dates.provider), + GroupOfDates(dates, group_of_dates.provider), + ) + + +class DateMapperConstant(DateMapper): + """A DateMapper implementation that maps dates to a constant date.""" + + def __init__(self, source: Any, date: Any | None = None) -> None: + """Initialize DateMapperConstant. + + Parameters + ---------- + source : Any + The data source. 
+ date : Optional[Any] + The constant date to map to. + """ + self.source: Any = source + self.date: Any | None = date + + def transform(self, group_of_dates: Any) -> tuple[Any, Any]: + """Transform the group of dates to a constant date. + + Parameters + ---------- + group_of_dates : Any + The group of dates to transform. + + Returns + ------- + Tuple[Any, Any] + Transformed dates. + """ + from anemoi.datasets.dates.groups import GroupOfDates + + if self.date is None: + return [ + ( + GroupOfDates([], group_of_dates.provider), + group_of_dates, + ) + ] + + return [ + ( + GroupOfDates([self.date], group_of_dates.provider), + group_of_dates, + ) + ] + + +@source_registry.register("repeated_dates") +class RepeatedDatesSource(Source): + + def __init__(self, context, source: Any, mode: str, **kwargs) -> None: + # assert False, (context, source, mode, kwargs) + super().__init__(context, **kwargs) + self.mapper = DateMapper.from_mode(mode, source, kwargs) + self.source = source + + def execute(self, group_of_dates): + source = self.context.create_source(self.source, "data_sources", str(id(self))) + + result = [] + for one_date_group, many_dates_group in self.mapper.transform(group_of_dates): + print(f"one_date_group: {one_date_group}, many_dates_group: {many_dates_group}") + source_results = source(self.context, one_date_group) + for field in source_results: + for date in many_dates_group: + result.append(new_field_with_valid_datetime(field, date)) + + return new_fieldlist_from_list(result) diff --git a/src/anemoi/datasets/dates/__init__.py b/src/anemoi/datasets/dates/__init__.py index 8c4122760..bc6dacafd 100644 --- a/src/anemoi/datasets/dates/__init__.py +++ b/src/anemoi/datasets/dates/__init__.py @@ -58,6 +58,8 @@ def extend(x: str | list[Any] | tuple[Any, ...]) -> Iterator[datetime.datetime]: class DatesProvider: """Base class for date generation. + Examples + -------- >>> DatesProvider.from_config(**{"start": "2023-01-01 00:00", "end": "2023-01-02 00:00", "frequency": "1d"}).values [datetime.datetime(2023, 1, 1, 0, 0), datetime.datetime(2023, 1, 2, 0, 0)] diff --git a/src/anemoi/datasets/dumper.py b/src/anemoi/datasets/dumper.py new file mode 100644 index 000000000..18c8d34d4 --- /dev/null +++ b/src/anemoi/datasets/dumper.py @@ -0,0 +1,76 @@ +# (C) Copyright 2025 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. 
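The `yaml_dump` helper defined below renders a mapping as YAML with UTC timestamps, block-style multiline strings, and inline scalar lists, optionally reordering the top-level keys. A usage sketch; the recipe content and key order are assumptions for illustration:

```python
# Illustrative only: dump a small recipe-like mapping with a preferred key order.
from anemoi.datasets.dumper import yaml_dump

recipe = {
    "input": {"join": []},
    "description": "A test dataset\nbuilt for illustration",  # emitted with "|" block style
    "dates": {"start": "2020-01-01", "end": "2020-12-31"},
}

text = yaml_dump(recipe, order=["description", "dates", "input"])
print(text)  # when no stream is given, the YAML text is returned as a string
```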
+ +import datetime +import io +import logging + +import ruamel.yaml + +LOG = logging.getLogger(__name__) + + +def represent_date(dumper, data): + + if isinstance(data, datetime.datetime): + if data.tzinfo is None: + data = data.replace(tzinfo=datetime.timezone.utc) + data = data.astimezone(datetime.timezone.utc) + iso_str = data.replace(tzinfo=None).isoformat(timespec="seconds") + "Z" + else: + iso_str = data.isoformat() + + return dumper.represent_scalar("tag:yaml.org,2002:timestamp", iso_str) + + +# --- Represent multiline strings with | style --- +def represent_multiline_str(dumper, data): + if "\n" in data: + return dumper.represent_scalar("tag:yaml.org,2002:str", data.strip(), style="|") + return dumper.represent_scalar("tag:yaml.org,2002:str", data) + + +# --- Represent short lists inline (flow style) --- +def represent_inline_list(dumper, data): + + if not all(isinstance(i, (str, int, float, bool, type(None))) for i in data): + return dumper.represent_sequence("tag:yaml.org,2002:seq", data) + + return dumper.represent_sequence("tag:yaml.org,2002:seq", data, flow_style=True) + + +def yaml_dump(obj, order=None, stream=None, **kwargs): + + if order: + + def _ordering(k): + return order.index(k) if k in order else len(order) + + obj = {k: v for k, v in sorted(obj.items(), key=lambda item: _ordering(item[0]))} + + yaml = ruamel.yaml.YAML() + yaml.width = 120 # wrap long flow sequences + + yaml.Representer.add_representer(datetime.date, represent_date) + yaml.Representer.add_representer(datetime.datetime, represent_date) + yaml.Representer.add_representer(str, represent_multiline_str) + yaml.Representer.add_representer(list, represent_inline_list) + + data = ruamel.yaml.comments.CommentedMap() + for i, (k, v) in enumerate(obj.items()): + data[k] = v + if i > 0: + data.yaml_set_comment_before_after_key(key=k, before="\n") + + if stream: + yaml.dump(data, stream=stream, **kwargs) + + stream = io.StringIO() + yaml.dump(data, stream=stream, **kwargs) + return stream.getvalue() diff --git a/src/anemoi/datasets/schemas/recipe.json b/src/anemoi/datasets/schemas/recipe.json new file mode 100644 index 000000000..3c02bfd64 --- /dev/null +++ b/src/anemoi/datasets/schemas/recipe.json @@ -0,0 +1,131 @@ +{ + "$schema": "https://json-schema.org/draft/2020-12/schema", + "type": "object", + "$id": "https://ecmwf.int/anemoi-datasets-recipe.schema.json", + "title": "Product", + "description": "Anemoi datasets recipe configuration", + "additionalProperties": false, + "$defs": { + "source-or-filter": { + "type": "object", + "minProperties": 1, + "maxProperties": 1 + }, + "pipe": { + "type": "array", + "items": { + "$ref": "#/$defs/input-object" + } + }, + "join": { + "type": "array", + "items": { + "$ref": "#/$defs/input-object" + } + }, + "concat": { + "type": "array", + "items": { + "type": "object", + "minProperties": 2, + "maxProperties": 2, + "required": [ + "dates" + ] + } + }, + "input-object": { + "oneOf": [ + { + "$ref": "#/$defs/pipe" + }, + { + "$ref": "#/$defs/join" + }, + { + "$ref": "#/$defs/concat" + }, + { + "$ref": "#/$defs/source-or-filter" + } + ] + } + }, + "properties": { + "env": { + "type": "object" + }, + "description": { + "type": "string" + }, + "name": { + "type": "string" + }, + "licence": { + "type": "string" + }, + "attribution": { + "type": "string" + }, + "dates": { + "type": "object", + "required": [ + "start", + "end" + ], + "properties": { + "start": { + "type": "string", + "format": "date" + }, + "end": { + "type": "string", + "format": "date" + }, + "frequency": { + 
"type": [ + "integer", + "string" + ] + }, + "group_by": { + "type": [ + "integer", + "string" + ] + } + } + }, + "input": { + "$ref": "#/$defs/input-object" + }, + "data_sources": { + "type": "object", + "patternProperties": { + "^[a-zA-Z_][a-zA-Z0-9_]*$": { + "$ref": "#/$defs/input-object" + } + }, + "additionalProperties": false + }, + "output": { + "type": "object" + }, + "statistics": { + "type": "object" + }, + "build": { + "type": "object" + }, + "common": { + "type": "object" + }, + "platform": { + "type": "object" + } + }, + "required": [ + "dates", + "input" + ] +} diff --git a/tests/create/test_create.py b/tests/create/test_create.py index 193c9a26a..dd3f37864 100755 --- a/tests/create/test_create.py +++ b/tests/create/test_create.py @@ -14,6 +14,8 @@ from unittest.mock import patch import pytest +from anemoi.transform.filter import Filter +from anemoi.transform.filters import filter_registry from anemoi.utils.testing import GetTestArchive from anemoi.utils.testing import GetTestData from anemoi.utils.testing import skip_if_offline @@ -32,6 +34,18 @@ assert NAMES, "No yaml files found in " + HERE +# Used by pipe.yaml +@filter_registry.register("filter") +class TestFilter(Filter): + + def __init__(self, **kwargs): + + self.kwargs = kwargs + + def forward(self, data): + return data.sel(**self.kwargs) + + @pytest.fixture def load_source(get_test_data: GetTestData) -> LoadSource: return LoadSource(get_test_data)