diff --git a/docs/inference/configs/grib-output.rst b/docs/inference/configs/grib-output.rst index 482a427c1..cc3e3a289 100644 --- a/docs/inference/configs/grib-output.rst +++ b/docs/inference/configs/grib-output.rst @@ -8,9 +8,11 @@ The GRIB output will be more efficient in conjunction with a :ref:`grib-input`. The GRIB output will use its input as template for - encoding GRIB messages. If the input is not a GRIB file, the output - will still attempt to encode the data as GRIB messages, but this is - not always possible. + encoding GRIB messages. + + If the input is not a GRIB file, the output will still attempt to + encode the data as GRIB messages, but this is not always possible. In + such cases, custom grib :ref:`templates` need to be provided. .. note:: @@ -21,35 +23,309 @@ The ``grib`` output can be used to specify many more options: .. literalinclude:: yaml/grib-output_1.yaml -*********** - encoding: -*********** +.. _encoding: + +********** + encoding +********** A dictionary of key/value pairs to add to the encoding of the GRIB messages. -***************** - check_encoding: -***************** +**************** + check_encoding +**************** A boolean to check that the GRIB messages have been encoded correctly. +.. _templates: + *********** - template: + templates *********** -If the input is not a GRIB file, the output can use an ``input`` source -to find similar fields to act as a template for encoding the GRIB. +anemoi-inference comes with a minimal set of built-in GRIB templates it +uses to write GRIB output. When you are running a model that runs on a +different grid and/or area than the built-in templates, or you need to +encode messages with a local definition, you may need to provide custom +templates. + +Templates are configured by specifying :ref:`template providers +` in the ``templates`` option of the grib output, for +example: + +.. code:: yaml + + output: + grib: + path: output.grib + templates: + - : + +The ```` can be omitted for providers that don't take any +options, be a string for providers that only take one option, or be a +dictionary for providers that take more than one option. + +Multiple providers can be specified, and they will be tried in the order +they are listed. If a provider cannot provide a template for a variable, +the next provider will be tried. When all providers have been exhausted, +an error will be raised. + +If the ``templates`` option is omitted and no template providers are +specified, the default providers are ``input`` and ``builtin`` (in that +order). + +.. warning:: + + The default providers are **not** automatically enabled once you + specify any providers yourself. If you have custom providers, you + need to explicitly include the ``input`` and/or ``builtin`` providers + if you want to use them. For most use-cases with custom templates, + it's recommended to also enable the ``input`` provider. + +.. admonition:: What is a GRIB template? + + A template is a GRIB message that is used as a blueprint for encoding + other GRIB messages. The template provides the necessary metadata + (like grid definition, area, etc) for the output. + + It's good practice to null the data values of custom templates to + keep them lightweight. This can be done with eccodes' `grib_set + `_: + + .. code:: bash + + grib_set -d 0 template.grib nulled-template.grib + +.. _providers: + +******************** + Template providers +******************** + +The following template providers are available: + +``input`` +========= + +Use the messages from the GRIB input as templates. Only works with a +GRIB input. + +.. code:: yaml + + output: + grib: + templates: + - input + +By default, only prognostic variables are taken from the input. +Diagnostic variables will throw an error if no other template provider +can provide a template. A fallback mapping can optionally be provided to +map output variables to input variables: + +.. code:: yaml + + output: + grib: + templates: + - input: + tp: 2t + cp: 2t + +In this example, if the output variables ``tp`` and ``cp`` are missing +from the input, the input template for ``2t`` will be used instead. + +.. note:: + + The fallback mapping only applies to output variables that are + missing from the input state. If the input contains an output + variable, its template will always be used. + +``builtin`` +=========== + +Use the built-in templates that come with the package. By default these +will be used for diagnostic variables, or if the input is not GRIB. Only +a limited number of grids are `included +`_, +and only at a global area. + +.. code:: yaml + + output: + grib: + templates: + - builtin + +``file`` +======== + +Load templates from a specified GRIB file. + +.. code:: yaml + + output: + grib: + templates: + - file: /path/to/template.grib + +By default, only the first message in the file will be used as template +for **all** output variables. This behaviour can be changed with the +following options: + +- ``path`` the path to the GRIB file + +- ``mode`` how to select a message from the grib file to use as + template. Can be one of: + + - ``first`` (default) use the first message in the file + - ``last`` use the last message in the file + - ``auto`` select variable from the file matching the output + variable name + +- ``variables`` the output variable name(s) for which to use this + template file (list or string). If empty, applies to all variables. + +.. tip:: + + A recommended use-case when using the GRIB input, is to use the + ``input`` provider to cover prognostic variables, and use a ``file`` + provider in auto-mode for diagnostic variables: + + .. code:: yaml + + output: + grib: + templates: + - input + - file: + path: /path/to/file-with-diagnostic-variables.grib + mode: auto + +``samples`` +=========== + +Load templates from specified GRIB files based on rules matched against +a variable's metadata. + +This provider takes a list of samples, each consisting of a dictionary +of matching rules and a path to a GRIB file. Whenever an output template +is requested, the sample's rules are checked against the output +variable's metadata. If all ``key:value`` pairs in a sample's rule match +the corresponding pair in the variable's metadata, the sample file is +selected. + +.. warning:: + + Only **the first message** in the sample GRIB file is used as + template. + +.. code:: yaml + + output: + grib: + templates: + - samples: + - - { matching rules 1 } + - /path/to/template1.grib + - - { matching rules 2 } + - /path/to/template2.grib + # etc + +A practical use-case is to provide templates for different grids and/or +levtypes. For example, if you are running models on both the N320 and +O96 grids, you could provide templates like this: + +.. code:: yaml + + output: + grib: + templates: + - samples: + - - { grid: N320, levtype: pl } + - /path/to/template-n320-pl.grib + - - { grid: N320 } + - /path/to/template-n320-sfc.grib + - - { grid: O96, levtype: pl } + - /path/to/template-o96-pl.grib + - - { grid: O96 } + - /path/to/template-o96-sfc.grib + +Note that the sfc (surface) template doesn't have a levtype rule. The +sfc template will be used for all variables that are not pl (pressure +level). It is also possible to have an empty rule ``{}`` as a catch-all, +but care needs to be taken as it may lead to incorrect encoding if the +grids do not match. + +The following keys are available for use in the matching rules: + ++---------------------------+-------------------------------------------------------------------------+ +| Rule Key | Details | ++===========================+=========================================================================+ +| ``grid`` | Grid definition (e.g. N320, 0.25, etc) | ++---------------------------+-------------------------------------------------------------------------+ +| ``area`` | Area as [north, west, south, east] | ++---------------------------+-------------------------------------------------------------------------+ +| ``time_processing`` | One of ``accumulation``, ``average``, ``maximum``, ``minimum``, | +| | ``instantaneous`` | ++---------------------------+-------------------------------------------------------------------------+ +| ``number_of_grid_points`` | As integer, from checkpoint `grid_indices` or shape of the last dataset | ++---------------------------+-------------------------------------------------------------------------+ +| The variable's MARS keys | Keys like ``levtype``, ``param``, etc. | ++---------------------------+-------------------------------------------------------------------------+ +| GRIB encoding | Any key specified under :ref:`output.grib.encoding ` | ++---------------------------+-------------------------------------------------------------------------+ + +This information is taken from the checkpoint metadata. If unsure what +the values are for a given checkpoint, running a forecast with +``verbosity: 2`` will show the lookup rules in the log. It is also shown +in the error message when a template cannot be found for a variable. + +.. note:: + + Samples are evaluated in the order they are listed and when all keys + in the sample rules match, the sample is selected. This makes the + order of samples important: more specific rules should be listed + first, and more general last. + +.. tip:: + + The sample file path can be a format string with rules. The above + example can be rewritten as: + + .. code:: yaml + + output: + grib: + templates: + - samples: + - - levtype: pl + - /path/to/template-{grid}-pl.grib + - - {} + - /path/to/template-{grid}-sfc.grib + + Or even shorter: + + .. code:: yaml + + output: + grib: + templates: + - samples: + - - {} + - /path/to/template-{grid}-{levtype}.grib + + But notice the subtle difference: in the first example, the sfc file + is used for *all* non-pl variables. In the second example, a file + must exist for each individual levtype. -- ``source``: An input source to use as template for encoding the GRIB. +.. tip:: -- ``date``: The to use when looking for the template (default is the - date at the output field to encode). + Samples can also be provided via a separate YAML file: -- ``reuse``: A boolean to reuse the template for all fields (default is - ``false``). If `true`, the template fetch for the first field will be - used for all fields, irrespective of the variable to encode. + .. code:: yaml -- ``archive_requests``: This is a private feature used to generate the - necessary information to archive the result of the run into ECMWF's - MARS archive. + output: + grib: + templates: + - samples: /path/to/samples.yaml diff --git a/docs/inference/configs/yaml/grib-output_1.yaml b/docs/inference/configs/yaml/grib-output_1.yaml index 7f1d2671d..c6112ab91 100644 --- a/docs/inference/configs/yaml/grib-output_1.yaml +++ b/docs/inference/configs/yaml/grib-output_1.yaml @@ -6,13 +6,5 @@ output: class: rd check_encoding: true templates: - source: mars - date: 2001-01-01 - reuse: True - archive_requests: - path: archive.json - extra: - database: marsrd - patch: - number: null - levtype: sfc + - input + - builtin diff --git a/src/anemoi/inference/grib/templates/__init__.py b/src/anemoi/inference/grib/templates/__init__.py index d0677808f..5aa146e21 100644 --- a/src/anemoi/inference/grib/templates/__init__.py +++ b/src/anemoi/inference/grib/templates/__init__.py @@ -9,6 +9,7 @@ import logging +from typing import TYPE_CHECKING from typing import Any import earthkit.data as ekd @@ -16,6 +17,10 @@ from anemoi.utils.registry import Registry from anemoi.inference.config import Configuration +from anemoi.inference.output import Output + +if TYPE_CHECKING: + from .manager import TemplateManager LOG = logging.getLogger(__name__) @@ -23,12 +28,12 @@ template_provider_registry = Registry(__name__) -def create_template_provider(owner: Any, config: Configuration) -> "TemplateProvider": +def create_template_provider(owner: Output, config: Configuration) -> "TemplateProvider": """Create a template provider from the given configuration. Parameters ---------- - owner : Any + owner : Output The owner of the template provider. config : Configuration The configuration for the template provider. @@ -44,17 +49,20 @@ def create_template_provider(owner: Any, config: Configuration) -> "TemplateProv class TemplateProvider: """Base class for template providers.""" - def __init__(self, manager: Any) -> None: + def __init__(self, manager: "TemplateManager") -> None: """Initialize the template provider. Parameters ---------- - manager : Any + manager : TemplateManager The manager for the template provider. """ self.manager = manager - def template(self, variable: str, lookup: dict[str, Any]) -> ekd.Field: + def __repr__(self): + return f"{self.__class__.__name__}" + + def template(self, variable: str, lookup: dict[str, Any], **kwargs) -> ekd.Field | None: """Get the template for the given variable and lookup. Parameters @@ -63,10 +71,12 @@ def template(self, variable: str, lookup: dict[str, Any]) -> ekd.Field: The variable to get the template for. lookup : Dict[str, Any] The lookup dictionary. + kwargs + Extra arguments for specific template providers. Returns ------- - ekd.Field + ekd.Field | None The template field. """ raise NotImplementedError() @@ -75,55 +85,45 @@ def template(self, variable: str, lookup: dict[str, Any]) -> ekd.Field: class IndexTemplateProvider(TemplateProvider): """Template provider based on an index file.""" - def __init__(self, manager: Any, index_path: str) -> None: + def __init__(self, manager: "TemplateManager", index: str | list) -> None: """Initialize the index template provider. Parameters ---------- - manager : Any + manager : TemplateManager The manager for the template provider. - index_path : str - The path to the index file. + index : str | list + The path to the index.yaml file, or its contents directly as a list. """ super().__init__(manager) - self.index_path = index_path + self.index_path = index - with open(index_path) as f: - self.templates = yaml.safe_load(f) + if isinstance(index, str): + with open(index) as f: + self.templates = yaml.safe_load(f) + else: + self.templates = index if not isinstance(self.templates, list): - raise ValueError("Invalid templates.yaml, must be a list") + raise ValueError(f"Invalid index, must be a list. Got {self.templates}") # TODO: use pydantic for template in self.templates: if not isinstance(template, list): - raise ValueError(f"Invalid template in templates.yaml, must be a list: {template}") + raise ValueError(f"Invalid template index element, must be a list. Got {template}") if len(template) != 2: - raise ValueError(f"Invalid template in templates.yaml, must have exactly 2 elements: {template}") + raise ValueError( + f"Expected template index to be a 2-elements list as `[matching filter, grib file]`. Got {template}." + ) match, grib = template if not isinstance(match, dict): - raise ValueError(f"Invalid match in templates.yaml, must be a dict: {match}") + raise ValueError(f"Invalid match in index element, must be a dict: {match}") if not isinstance(grib, str): - raise ValueError(f"Invalid grib in templates.yaml, must be a string: {grib}") - - def template(self, variable: str, lookup: dict[str, Any]) -> ekd.Field | None: - """Get the template for the given variable and lookup. - - Parameters - ---------- - variable : str - The variable to get the template for. - lookup : Dict[str, Any] - The lookup dictionary. - - Returns - ------- - Optional[ekd.Field] - The template field if found, otherwise None. - """ + raise ValueError(f"Invalid grib in index element, must be a string: {grib}") + def template(self, variable: str, lookup: dict[str, Any], **kwargs) -> ekd.Field | None: def _as_list(value: Any) -> list[Any]: if not isinstance(value, list): return [value] @@ -132,7 +132,7 @@ def _as_list(value: Any) -> list[Any]: for template in self.templates: match, grib = template if LOG.isEnabledFor(logging.DEBUG): - LOG.debug("%s", [(lookup.get(k), _as_list(v)) for k, v in match.items()]) + LOG.debug(f"Matching {match} -> {[(lookup.get(k), _as_list(v)) for k, v in match.items()]}") if all(lookup.get(k) in _as_list(v) for k, v in match.items()): return self.load_template(grib, lookup) diff --git a/src/anemoi/inference/grib/templates/builtin.py b/src/anemoi/inference/grib/templates/builtin.py index 4658a300b..5ac0cb32e 100644 --- a/src/anemoi/inference/grib/templates/builtin.py +++ b/src/anemoi/inference/grib/templates/builtin.py @@ -17,6 +17,7 @@ from . import IndexTemplateProvider from . import template_provider_registry +from .manager import TemplateManager LOG = logging.getLogger(__name__) @@ -32,37 +33,12 @@ class BuiltinTemplates(IndexTemplateProvider): """Builtin templates provider.""" - def __init__(self, manager: Any, index_path: str | None = None) -> None: - """Initialize the BuiltinTemplates instance. - - Parameters - ---------- - manager : Any - The manager instance. - index_path : Optional[str], optional - The path to the index file, by default None. - """ + def __init__(self, manager: TemplateManager, index_path: str | None = None) -> None: if index_path is None: index_path = os.path.join(os.path.dirname(__file__), "builtin.yaml") super().__init__(manager, index_path) def load_template(self, grib: str, lookup: dict[str, Any]) -> ekd.Field | None: - """Load the template for the given GRIB and lookup. - - Parameters - ---------- - grib : str - The GRIB string. - lookup : Dict[str, Any] - The lookup dictionary. - - Returns - ------- - Optional[ekd.Field] - The loaded template field if found, otherwise None. - """ - import earthkit.data as ekd - template = zlib.decompress(base64.b64decode(grib)) return ekd.from_source("memory", template)[0] diff --git a/src/anemoi/inference/grib/templates/file.py b/src/anemoi/inference/grib/templates/file.py index 6dc483288..c7f889cdc 100644 --- a/src/anemoi/inference/grib/templates/file.py +++ b/src/anemoi/inference/grib/templates/file.py @@ -8,46 +8,81 @@ # nor does it submit to any jurisdiction. import logging +from functools import cached_property +from pathlib import Path from typing import Any +from typing import Literal import earthkit.data as ekd +from anemoi.inference.decorators import main_argument +from anemoi.inference.inputs.ekd import find_variable +from anemoi.inference.types import State + from . import TemplateProvider from . import template_provider_registry +from .manager import TemplateManager LOG = logging.getLogger(__name__) @template_provider_registry.register("file") +@main_argument("path") class FileTemplates(TemplateProvider): """Template provider using a single GRIB file.""" - def __init__(self, manager: Any, path: str) -> None: + def __init__( + self, + manager: TemplateManager, + *, + path: str, + mode: Literal["auto", "first", "last"] = "first", + variables: str | list | None = None, + ) -> None: """Initialize the FileTemplates instance. Parameters ---------- - manager : Any + manager : TemplateManager The manager instance. path : str The path to the GRIB file. + mode : Literal["auto", "first", "last"], optional + The method with which to select a message from the grib file to use as template, by default "first": + - "first": use the first message in the grib file + - "last": use the last message in the grib file + - "auto": select variable from the grib file matching the output variable name + variables : str | list, optional + The output variable name(s) for which to use this template file. If empty, applies to all variables. """ self.manager = manager - self.path = path + self.path = Path(path) + if not self.path.exists(): + raise FileNotFoundError(f"GRIB template file not found: {self.path}") + self.mode = mode + self.variables = variables if isinstance(variables, list) else [variables] if variables else None - def template(self, grib: str, lookup: dict[str, Any]) -> ekd.Field: - """Retrieve the template from the GRIB file. + def __repr__(self): + info = f"{self.__class__.__name__}({self.path.name},mode={self.mode}{{variables}})" + return info.format(variables=f",variables={self.variables}" if self.variables else "") - Parameters - ---------- - grib : str - The GRIB string. - lookup : Dict[str, Any] - The lookup dictionary. - - Returns - ------- - ekd.Field - The field from the GRIB file. - """ - return ekd.from_source("file", self.path)[0] + @cached_property + def _data(self): + return ekd.from_source("file", self.path) + + def template(self, variable: str, lookup: dict[str, Any], state: State, **kwargs) -> ekd.Field | None: + if self.variables and variable not in self.variables: + return None + + match self.mode: + case "first": + return self._data[0] + case "last": + return self._data[-1] + case "auto": + namer = getattr(state.get("_input"), "_namer", self.manager.owner.context.checkpoint.default_namer()) + field = find_variable(self._data, variable, namer) + if len(field) > 0: + return field[0] + + return None diff --git a/src/anemoi/inference/grib/templates/input.py b/src/anemoi/inference/grib/templates/input.py new file mode 100644 index 000000000..905d95d36 --- /dev/null +++ b/src/anemoi/inference/grib/templates/input.py @@ -0,0 +1,65 @@ +# (C) Copyright 2025 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +import logging +from typing import Any + +import earthkit.data as ekd + +from anemoi.inference.types import State + +from . import TemplateProvider +from . import template_provider_registry +from .manager import TemplateManager + +LOG = logging.getLogger(__name__) + + +@template_provider_registry.register("input") +class InputTemplates(TemplateProvider): + """Use input fields as the output GRIB template.""" + + def __init__(self, manager: TemplateManager, **fallback: dict[str, str]) -> None: + """Initialize the template provider. + + Parameters + ---------- + manager : TemplateManager + The manager for the template provider. + **fallback : dict[str, str] + A mapping of output to input variable names to use as templates from the input, + used as fallback when the output variable is not present in the input state (e.g., for diagnostic variables). + """ + super().__init__(manager) + + self.fallback = fallback + + def __repr__(self): + info = f"{self.__class__.__name__}{{fallback}}" + if fallback := ", ".join(f"{k}:{v}" for k, v in self.fallback.items()): + fallback = f"(fallback {fallback})" + return info.format(fallback=fallback) + + def template( + self, + variable: str, + lookup: dict[str, Any], + *, + state: State, + **kwargs, + ) -> ekd.Field | None: + if template := state.get("_grib_templates_for_output", {}).get(variable): + return template + + if fallback_variable := self.fallback.get(variable): + if template := state.get("_grib_templates_for_output", {}).get(fallback_variable): + return template + LOG.warning(f"Fallback variable '{fallback_variable}' for output '{variable}' not found in input state.") + + return None diff --git a/src/anemoi/inference/grib/templates/manager.py b/src/anemoi/inference/grib/templates/manager.py index 45692f2d9..e03249da4 100644 --- a/src/anemoi/inference/grib/templates/manager.py +++ b/src/anemoi/inference/grib/templates/manager.py @@ -10,10 +10,12 @@ import json import logging +from collections import defaultdict from typing import Any import earthkit.data as ekd +from anemoi.inference.output import Output from anemoi.inference.types import State from . import create_template_provider @@ -22,23 +24,25 @@ class TemplateManager: - """A class to manage GRIB templates.""" + """A class to manage GRIB template providers.""" - def __init__(self, owner: Any, templates: list[str] | str | None = None) -> None: + def __init__(self, owner: Output, templates: list[str] | str | None = None) -> None: """Initialize the TemplateManager. Parameters ---------- - owner : Any + owner : Output The owner of the TemplateManager. templates : Optional[Union[List[str], str]], optional - A list of template names or a single template name, by default None. + A list of template providers or a single provider, by default None. """ self.owner = owner self.checkpoint = owner.context.checkpoint self.typed_variables = self.checkpoint.typed_variables self._template_cache = {} + self._history = defaultdict(list) + self._logged_variables = defaultdict(set) if templates is None: templates = [] @@ -47,9 +51,31 @@ def __init__(self, owner: Any, templates: list[str] | str | None = None) -> None templates = [templates] if len(templates) == 0: - templates = ["builtin"] + templates = ["input", "builtin"] self.templates_providers = [create_template_provider(self, template) for template in templates] + LOG.info("GRIB template providers:") + for provider in self.templates_providers: + LOG.info(f" - {provider}") + + def log_summary(self) -> None: + """Log a summary of the loaded templates. + Repeated calls will only log newly loaded templates since the last call. + """ + + to_log = defaultdict(set) + + for provider, typed_list in self._history.items(): + variables = [variable for variable in typed_list if variable.param not in self._logged_variables[provider]] + if len(variables) > 0: + for variable in variables: + to_log[provider].add(variable.param) + self._logged_variables[provider].add(variable.param) + + if to_log: + LOG.info("GRIB template summary:") + for provider, variables in to_log.items(): + LOG.info(f" - {provider}: {', '.join(sorted(variables))}") def template(self, name: str, state: State, typed_variables: dict[str, Any]) -> ekd.Field | None: """Get the template for a given name and state. @@ -70,9 +96,6 @@ def template(self, name: str, state: State, typed_variables: dict[str, Any]) -> """ assert name is not None, name - # Use input fields as templates - self._template_cache.update(state.get("_grib_templates_for_output", {})) - if name not in self._template_cache: self.load_template(name, state, typed_variables) @@ -124,9 +147,10 @@ def load_template(self, name: str, state: State, typed_variables: dict[str, Any] tried = [] for provider in self.templates_providers: - template = provider.template(name, lookup) + template = provider.template(name, lookup, state=state) if template is not None: self._template_cache[name] = template + self._history[provider].append(typed) return tried.append(provider) diff --git a/src/anemoi/inference/grib/templates/samples.py b/src/anemoi/inference/grib/templates/samples.py index 99e894f71..7c6f7cdbd 100644 --- a/src/anemoi/inference/grib/templates/samples.py +++ b/src/anemoi/inference/grib/templates/samples.py @@ -15,6 +15,7 @@ from . import IndexTemplateProvider from . import template_provider_registry +from .manager import TemplateManager LOG = logging.getLogger(__name__) @@ -23,24 +24,20 @@ class SamplesTemplates(IndexTemplateProvider): """Class to provide GRIB templates from sample files.""" + def __init__(self, manager: TemplateManager, *args, index_path: str | None = None) -> None: + if index_path is not None: + return super().__init__(manager, index_path) + + if isinstance(args[0], str): + return super().__init__(manager, args[0]) + + return super().__init__(manager, [*args]) + def load_template(self, grib: str, lookup: dict[str, Any]) -> ekd.Field | None: - """Load a GRIB template based on the provided lookup dictionary. - - Parameters - ---------- - grib : str - The GRIB template string. - lookup : Dict[str, Any] - The lookup dictionary to format the GRIB template string. - - Returns - ------- - Optional[ekd.Field] - The loaded GRIB template field, or None if the template is not found. - """ template = grib.format(**lookup) if not os.path.exists(template): LOG.warning(f"Template not found: {template}") return None + LOG.debug(f"Loading sample file: {template}") return ekd.from_source("file", template)[0] diff --git a/src/anemoi/inference/inputs/ekd.py b/src/anemoi/inference/inputs/ekd.py index f80d13987..36d159b5c 100644 --- a/src/anemoi/inference/inputs/ekd.py +++ b/src/anemoi/inference/inputs/ekd.py @@ -31,6 +31,33 @@ LOG = logging.getLogger(__name__) +def find_variable(data: Any, name: str, namer: callable, **kwargs: Any) -> Any: + """Find a variable in an earthkit FieldList/FieldArray. + + Parameters + ---------- + data : Any + The data to search (FieldList or FieldArray). + name : str + The name of the variable to find. + namer: callable + The namer function to use for naming fields. + **kwargs : Any + Additional arguments for selecting the variable. + + Returns + ------- + Any + The selected variable (FieldArray subset). + """ + + def _name(field: Any, _: Any, original_metadata: dict[str, Any]) -> str: + return namer(field, original_metadata) + + data = FieldArray([f.clone(name=_name) for f in data]) + return data.sel(name=name, **kwargs) + + class RulesNamer: """A namer that uses rules to generate names.""" @@ -181,11 +208,7 @@ def _find_variable(self, data: Any, name: str, **kwargs: Any) -> Any: The selected variable (FieldArray subset). """ - def _name(field: Any, _: Any, original_metadata: dict[str, Any]) -> str: - return self._namer(field, original_metadata) - - data = FieldArray([f.clone(name=_name) for f in data]) - return data.sel(name=name, **kwargs) + return find_variable(data, name, self._namer, **kwargs) def _create_state( self, diff --git a/src/anemoi/inference/outputs/grib.py b/src/anemoi/inference/outputs/grib.py index a10340513..41d7089ff 100644 --- a/src/anemoi/inference/outputs/grib.py +++ b/src/anemoi/inference/outputs/grib.py @@ -326,8 +326,11 @@ def write_step(self, state: State) -> None: LOG.error("Error writing field %s", name) LOG.error("Template: %s", template) LOG.error("Keys:\n%s", json.dumps(keys, indent=4, default=str)) + self.template_manager.log_summary() raise + self.template_manager.log_summary() + @abstractmethod def write_message(self, message: FloatArray, *args: Any, **kwargs: Any) -> None: """Write a message to the grib file. diff --git a/tests/unit/test_templates.py b/tests/unit/test_templates.py new file mode 100644 index 000000000..b98a010bc --- /dev/null +++ b/tests/unit/test_templates.py @@ -0,0 +1,155 @@ +# (C) Copyright 2025 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +from pathlib import Path + +import pytest +import yaml +from anemoi.utils.testing import GetTestData +from earthkit.data.readers.grib.codes import GribField +from pytest_mock import MockerFixture + +from anemoi.inference.checkpoint import Checkpoint +from anemoi.inference.grib.templates.manager import TemplateManager +from anemoi.inference.testing import fake_checkpoints +from anemoi.inference.testing import files_for_tests + + +@pytest.fixture +def manager(mocker: MockerFixture) -> type[TemplateManager]: + @fake_checkpoints + def _manager(config=None): + checkpoint = Checkpoint(files_for_tests("unit/checkpoints/atmos.json")) + checkpoint.typed_variables["unknown"] = checkpoint.typed_variables["2t"] # used in auto-unknown test + owner = mocker.MagicMock() + owner.context.checkpoint = checkpoint + return TemplateManager(owner, templates=config) + + return _manager + + +@pytest.fixture +def grib_template(get_test_data: GetTestData) -> Path: + return get_test_data("anemoi-integration-tests/inference/single-o48-1.1/input.grib") + + +@pytest.fixture +def template_index(grib_template: Path, tmp_path: Path) -> Path: + index = [ + [{"grid": "O96", "levtype": "pl"}, grib_template], + [{"grid": "O96"}, grib_template], + ] + index_file = tmp_path / "templates_index.yaml" + with open(index_file, "w") as f: + yaml.dump(index, f) + return index_file + + +@pytest.mark.parametrize( + "variable, expected_param", + [ + pytest.param("2t", "lsm", id="sfc"), # lsm is the builtin template for sfc + pytest.param("w_100", "q", id="pl"), # q is the builtin template for pl + ], +) +def test_builtin(manager, variable, expected_param): + manager = manager() + template = manager.template(variable, state={}, typed_variables=manager.typed_variables) + + assert isinstance(template, GribField) + assert template.metadata("param") == expected_param + + +@pytest.mark.parametrize( + "file_config, variable, expected_param, expected_type", + [ + pytest.param({}, "2t", "10u", GribField, id="first"), + pytest.param({"mode": "last"}, "2t", "v", GribField, id="last"), + pytest.param({"mode": "auto"}, "2t", "2t", GribField, id="auto-sfc"), + pytest.param({"mode": "auto"}, "w_100", "w", GribField, id="auto-pl"), + pytest.param({"mode": "auto"}, "unknown", None, type(None), id="auto unknown"), + pytest.param({"variables": "10u"}, "2t", None, type(None), id="skip variable"), + ], +) +def test_file(manager, grib_template, file_config, variable, expected_param, expected_type): + config = { + "file": { + "path": grib_template, + **file_config, + } + } + manager = manager(config) + template = manager.template(variable, state={}, typed_variables=manager.typed_variables) + + assert isinstance(template, expected_type) + + if expected_param is not None: + assert template.metadata("param") == expected_param + + +def test_samples_index_path(manager, template_index): + config = { + "samples": {"index_path": str(template_index)}, + } + manager = manager(config) + template = manager.template("2t", state={}, typed_variables=manager.typed_variables) + + assert isinstance(template, GribField) + assert template.metadata("param") == "10u" # first field in the file + + +def test_samples_index_path_str(manager, template_index): + config = { + "samples": str(template_index), + } + manager = manager(config) + template = manager.template("2t", state={}, typed_variables=manager.typed_variables) + + assert isinstance(template, GribField) + assert template.metadata("param") == "10u" # first field in the file + + +def test_samples_direct_index(manager, template_index): + with open(template_index, "r") as f: + index = yaml.safe_load(f) + config = {"samples": index} + manager = manager(config) + template = manager.template("2t", state={}, typed_variables=manager.typed_variables) + + assert isinstance(template, GribField) + assert template.metadata("param") == "10u" # first field in the file + + +@pytest.mark.parametrize( + "config", + [ + pytest.param(None, id="default"), + pytest.param("input", id="input only"), + pytest.param(["input", "builtin"], id="input+builtin"), + pytest.param({"input": {"10u": "2t", "2t": "ignored"}}, id="input with fallback"), + ], +) +def test_input(manager, config, request): + manager = manager(config) + state = { + "_grib_templates_for_output": { + "2t": "2t_input_template", + } + } + + template = manager.template("2t", state=state, typed_variables=manager.typed_variables) + assert template == "2t_input_template" + + template = manager.template("10u", state=state, typed_variables=manager.typed_variables) + if request.node.callspec.id == "input only": + assert template is None + elif request.node.callspec.id == "input with fallback": + assert template == "2t_input_template" + else: + assert template.metadata("param") == "lsm" # builtin