diff --git a/src/anemoi/inference/decorators.py b/src/anemoi/inference/decorators.py index bb8a15dd..a6f36919 100644 --- a/src/anemoi/inference/decorators.py +++ b/src/anemoi/inference/decorators.py @@ -8,11 +8,15 @@ # nor does it submit to any jurisdiction. +import logging +from pathlib import Path from typing import Any from typing import TypeVar from anemoi.inference.context import Context +LOG = logging.getLogger("anemoi.inference") + MARKER = object() F = TypeVar("F", bound=type) @@ -64,3 +68,67 @@ def __init__(wrapped_cls, context: Context, main: object = MARKER, *args: Any, * super().__init__(context, *args, **kwargs) return type(cls.__name__, (WrappedClass,), {}) + + +class ensure_path: + """Decorator to ensure a path argument is a Path object and optionally exists. + + If `is_dir` is True, the path is treated as a directory, if not for files, the parent directory is treated as a directory. + If `must_exist` is True, the directory must exist. + If `create` is True, the directory will be created if it doesn't exist. + + For example: + ``` + @ensure_path("dir", create=True) + class GribOutput + def __init__(context, dir=None, archive_requests=None): + ... + """ + + def __init__(self, arg: str, is_dir: bool = False, create: bool = True, must_exist: bool = False): + self.arg = arg + self.is_dir = is_dir + self.create = create + self.must_exist = must_exist + + def __call__(self, cls: F) -> F: + """Decorate the object to ensure the path argument is a Path object.""" + + class WrappedClass(cls): + def __init__(wrapped_cls, context: Context, *args: Any, **kwargs: Any) -> None: + if self.arg not in kwargs: + LOG.debug(f"Argument '{self.arg}' not found in kwargs, cannot ensure path.") + super().__init__(context, *args, **kwargs) + return + + path = kwargs[self.arg] = Path(kwargs[self.arg]) + if not self.is_dir: + path = path.parent + + if self.must_exist: + if not path.exists(): + raise FileNotFoundError(f"Path '{path}' does not exist.") + if self.create: + path.mkdir(parents=True, exist_ok=True) + + super().__init__(context, *args, **kwargs) + + return type(cls.__name__, (WrappedClass,), {}) + + +class ensure_dir(ensure_path): + """Decorator to ensure a directory path argument is a Path object and optionally exists. + + If `must_exist` is True, the directory must exist. + If `create` is True, the directory will be created if it doesn't exist. + + For example: + ``` + @ensure_dir("dir", create=True) + class PlotOutput + def __init__(context, dir=None, ...): + ... + """ + + def __init__(self, arg: str, create: bool = True, must_exist: bool = False): + super().__init__(arg, is_dir=True, create=create, must_exist=must_exist) diff --git a/src/anemoi/inference/grib/encoding.py b/src/anemoi/inference/grib/encoding.py index ecb03f78..24e34e80 100644 --- a/src/anemoi/inference/grib/encoding.py +++ b/src/anemoi/inference/grib/encoding.py @@ -9,15 +9,18 @@ import logging -import re import warnings from io import IOBase +from pathlib import Path from typing import TYPE_CHECKING from typing import Any +from typing import Hashable import earthkit.data as ekd from anemoi.utils.dates import as_timedelta +from anemoi.inference.utils.templating import render_template + if TYPE_CHECKING: from anemoi.transform.variables import Variable @@ -469,12 +472,12 @@ def encode_message( class GribWriter: """Write GRIB messages to one or more files.""" - def __init__(self, out: str | IOBase, split_output: bool = True) -> None: + def __init__(self, out: Path | IOBase, split_output: bool = True) -> None: """Initialize the GribWriter. Parameters ---------- - out : Union[str, IOBase] + out : Union[Path, IOBase] Path or file-like object to write the grib data to. If a string, it should be a file path. If a file-like object, it should be opened in binary write mode. @@ -487,7 +490,7 @@ def __init__(self, out: str | IOBase, split_output: bool = True) -> None: self.out = out self.split_output = split_output - self._files: dict[str, IOBase] = {} + self._files: dict[Hashable, IOBase] = {} def close(self) -> None: """Close all open files.""" @@ -570,7 +573,7 @@ def write( return handle, path - def target(self, handle: Any) -> tuple[IOBase, str]: + def target(self, handle: Any) -> tuple[IOBase, Path | str]: """Determine the target file for the GRIB message. Parameters @@ -584,7 +587,8 @@ def target(self, handle: Any) -> tuple[IOBase, str]: The file object and the file path. """ if self.split_output: - out = render_template(self.out, handle) + assert not isinstance(self.out, IOBase), "Cannot split output when `out` is a file-like object." + out = render_template(str(self.out), handle) else: out = self.out @@ -596,37 +600,3 @@ def target(self, handle: Any) -> tuple[IOBase, str]: self._files[out] = open(out, "wb") return self._files[out], out - - -_TEMPLATE_EXPRESSION_PATTERN = re.compile(r"\{(.*?)\}") - - -def render_template(template: str, handle: dict) -> str: - """Render a template string with the given keyword arguments. - - Given a template string such as '{dateTime}_{step:03}.grib' and - the GRIB handle, this function will replace the expressions in the - template with the corresponding values from the handle, formatted - according to the optional format specifier. - - For example, the template '{dateTime}_{step:03}.grib' with a handle - containing 'dateTime' as '202501011200' and 'step' as 6 will - produce '202501011200_006.grib'. - - Parameters - ---------- - template : str - The template string to render. - handle : Dict - The earthkit.data handle manager. - - Returns - ------- - str - The rendered template string. - """ - expressions = _TEMPLATE_EXPRESSION_PATTERN.findall(template) - expr_format = [el.split(":") if ":" in el else [el, ""] for el in expressions] - keys = {k[0]: format(handle.get(k[0]), k[1]) for k in expr_format} - path = template.format(**keys) - return path diff --git a/src/anemoi/inference/grib/templates/manager.py b/src/anemoi/inference/grib/templates/manager.py index ba09a88d..45692f2d 100644 --- a/src/anemoi/inference/grib/templates/manager.py +++ b/src/anemoi/inference/grib/templates/manager.py @@ -51,7 +51,7 @@ def __init__(self, owner: Any, templates: list[str] | str | None = None) -> None self.templates_providers = [create_template_provider(self, template) for template in templates] - def template(self, name: str, state: State, typed_variables: list[Any]) -> ekd.Field | None: + def template(self, name: str, state: State, typed_variables: dict[str, Any]) -> ekd.Field | None: """Get the template for a given name and state. Parameters @@ -60,8 +60,8 @@ def template(self, name: str, state: State, typed_variables: list[Any]) -> ekd.F The name of the template. state : State The state object containing template information. - typed_variables : list of Any - The list of typed variables. + typed_variables : dict[str, Any] + The dictionary of typed variables. Returns ------- @@ -78,7 +78,7 @@ def template(self, name: str, state: State, typed_variables: list[Any]) -> ekd.F return self._template_cache.get(name) - def load_template(self, name: str, state: State, typed_variables: list[Any]) -> ekd.Field | None: + def load_template(self, name: str, state: State, typed_variables: dict[str, Any]) -> ekd.Field | None: """Load the template for a given name and state. Parameters @@ -87,8 +87,8 @@ def load_template(self, name: str, state: State, typed_variables: list[Any]) -> The name of the template. state : State The state object containing template information. - typed_variables : list of Any - The list of typed variables. + typed_variables : dict[str, Any] + The dictionary of typed variables. Returns ------- diff --git a/src/anemoi/inference/output.py b/src/anemoi/inference/output.py index 02d7a8f2..42714f7a 100644 --- a/src/anemoi/inference/output.py +++ b/src/anemoi/inference/output.py @@ -12,6 +12,7 @@ from abc import abstractmethod from functools import cached_property from typing import TYPE_CHECKING +from typing import Any from anemoi.inference.post_processors import create_post_processor from anemoi.inference.processor import Processor @@ -244,20 +245,20 @@ class ForwardOutput(Output): def __init__( self, context: "Context", - output: dict | None, + output: Output | Any, variables: list[str] | None = None, post_processors: list[ProcessorConfig] | None = None, output_frequency: int | None = None, write_initial_state: bool | None = None, ): - """Initialize the ForwardOutput object. + """Initialise the ForwardOutput object. Parameters ---------- context : Context The context in which the output operates. - output : dict - The output configuration dictionary. + output : Output | Any + The output configuration dictionary or an Output instance. variables : list, optional The list of variables, by default None. post_processors : Optional[List[ProcessorConfig]], default None @@ -277,8 +278,9 @@ def __init__( output_frequency=None, write_initial_state=write_initial_state, ) - - self.output = None if output is None else create_output(context, output) + if not isinstance(output, Output): + output = create_output(context, output) + self.output = output if self.context.output_frequency is not None: LOG.warning("output_frequency is ignored for '%s'", self.__class__.__name__) diff --git a/src/anemoi/inference/outputs/grib.py b/src/anemoi/inference/outputs/grib.py index 1afd5ffe..71aadca1 100644 --- a/src/anemoi/inference/outputs/grib.py +++ b/src/anemoi/inference/outputs/grib.py @@ -174,7 +174,7 @@ def __init__( output_frequency: int | None = None, write_initial_state: bool | None = None, ) -> None: - """Initialize the GribOutput object. + """Initialise the GribOutput object. Parameters ---------- @@ -285,6 +285,8 @@ def write_step(self, state: State) -> None: """ reference_date = self.reference_date or self.context.reference_date + assert reference_date is not None, "No reference date set" + step = state["step"] previous_step = state.get("previous_step") start_steps = state.get("start_steps", {}) @@ -367,10 +369,6 @@ def template(self, state: State, name: str) -> object: object The template object. """ - - if self.template_manager is None: - self.template_manager = TemplateManager(self, self.templates) - return self.template_manager.template(name, state, self.typed_variables) def template_lookup(self, name: str) -> dict: diff --git a/src/anemoi/inference/outputs/gribfile.py b/src/anemoi/inference/outputs/gribfile.py index 9089d22b..0d9329bf 100644 --- a/src/anemoi/inference/outputs/gribfile.py +++ b/src/anemoi/inference/outputs/gribfile.py @@ -12,6 +12,7 @@ import logging from collections import defaultdict from io import IOBase +from pathlib import Path from typing import Any import earthkit.data as ekd @@ -22,6 +23,7 @@ from anemoi.inference.types import FloatArray from anemoi.inference.types import ProcessorConfig +from ..decorators import ensure_path from ..decorators import main_argument from ..grib.encoding import GribWriter from ..grib.encoding import check_encoding @@ -114,7 +116,7 @@ def __init__( self, context: Context, *, - out: str | IOBase, + out: Path | IOBase, post_processors: list[ProcessorConfig] | None = None, encoding: dict[str, Any] | None = None, archive_requests: dict[str, Any] | None = None, @@ -134,9 +136,8 @@ def __init__( ---------- context : Context The context. - out : Union[str, IOBase] + out : Union[Path, IOBase] Path or file-like object to write the grib data to. - If a string, it should be a file path. If a file-like object, it should be opened in binary write mode. post_processors : Optional[List[ProcessorConfig]], default None Post-processors to apply to the input @@ -312,6 +313,7 @@ def _patch(r: DataRequest) -> DataRequest: @output_registry.register("grib") @main_argument("path") +@ensure_path("path") class GribFileOutput(GribIoOutput): """Handles grib files.""" @@ -319,7 +321,7 @@ def __init__( self, context: Context, *, - path: str, + path: Path, post_processors: list[ProcessorConfig] | None = None, encoding: dict[str, Any] | None = None, archive_requests: dict[str, Any] | None = None, @@ -333,14 +335,15 @@ def __init__( write_initial_state: bool | None = None, split_output: bool = True, ) -> None: - """Initialize the GribFileOutput. + """Initialise the GribFileOutput. Parameters ---------- context : Context The context. - path : str + path : Path Path to the grib file to write the data to. + If the parent directory does not exist, it will be created. post_processors : Optional[List[ProcessorConfig]], default None Post-processors to apply to the input encoding : dict, optional diff --git a/src/anemoi/inference/outputs/netcdf.py b/src/anemoi/inference/outputs/netcdf.py index ccc76979..0a0a779a 100644 --- a/src/anemoi/inference/outputs/netcdf.py +++ b/src/anemoi/inference/outputs/netcdf.py @@ -10,6 +10,7 @@ import logging import os import threading +from pathlib import Path import numpy as np @@ -17,6 +18,7 @@ from anemoi.inference.types import ProcessorConfig from anemoi.inference.types import State +from ..decorators import ensure_path from ..decorators import main_argument from ..output import Output from . import output_registry @@ -30,13 +32,14 @@ @output_registry.register("netcdf") @main_argument("path") +@ensure_path("path") class NetCDFOutput(Output): """NetCDF output class.""" def __init__( self, context: Context, - path: str, + path: Path, variables: list[str] | None = None, post_processors: list[ProcessorConfig] | None = None, output_frequency: int | None = None, @@ -44,14 +47,15 @@ def __init__( float_size: str = "f4", missing_value: float | None = np.nan, ) -> None: - """Initialize the NetCDF output object. + """Initialise the NetCDF output object. Parameters ---------- context : dict The context dictionary. - path : str - The path to save the NetCDF file. + path : Path + The path to save the NetCDF file to. + If the parent directory does not exist, it will be created. post_processors : Optional[List[ProcessorConfig]], default None Post-processors to apply to the input output_frequency : int, optional diff --git a/src/anemoi/inference/outputs/plot.py b/src/anemoi/inference/outputs/plot.py index 7c6b1395..4cf5e300 100644 --- a/src/anemoi/inference/outputs/plot.py +++ b/src/anemoi/inference/outputs/plot.py @@ -8,16 +8,18 @@ # nor does it submit to any jurisdiction. import logging -import os +from pathlib import Path import numpy as np from anemoi.utils.grib import units from anemoi.inference.context import Context +from anemoi.inference.decorators import ensure_dir from anemoi.inference.decorators import main_argument from anemoi.inference.types import FloatArray from anemoi.inference.types import ProcessorConfig from anemoi.inference.types import State +from anemoi.inference.utils.templating import render_template from ..output import Output from . import output_registry @@ -42,55 +44,59 @@ def fix(lons: FloatArray) -> FloatArray: @output_registry.register("plot") -@main_argument("path") +@main_argument("dir") +@ensure_dir("dir") class PlotOutput(Output): """Use `earthkit-plots` to plot the outputs.""" def __init__( self, context: Context, - path: str, + dir: Path, *, variables: list[str] | None = None, mode: str = "subplots", domain: str | list[str] | None = None, - strftime: str = "%Y%m%d%H%M%S", template: str = "plot_{date}.{format}", format: str = "png", - missing_value: float | None = None, post_processors: list[ProcessorConfig] | None = None, output_frequency: int | None = None, write_initial_state: bool | None = None, **kwargs, ) -> None: - """Initialize the PlotOutput. + """Initialise the PlotOutput. Parameters ---------- context : Context The context. - path : str - The path to save the plots. + dir : Path + The directory to save the plots. + If the directory does not exist, it will be created. variables : list, optional The list of variables to plot, by default all. mode : str, optional The plotting mode, can be "subplots" or "overlay", by default "subplots". domain : str | list[str] | None, optional The domain/s to plot, by default None. - strftime : str, optional - The date format string, by default "%Y%m%d%H%M%S". template : str, optional The template for plot filenames, by default "plot_{date}.{format}". + Has access to the following variables: + - date: the date of the forecast step + - basetime: the base time of the forecast + - domain: the domain being plotted + - format: the format of the plot + - variables: the variables being plotted (joined by underscores) format : str, optional The format of the plot, by default "png". - missing_value : float, optional - The value to use for missing data, by default None. post_processors : Optional[List[ProcessorConfig]], default None Post-processors to apply to the input output_frequency : int, optional The frequency of output, by default None. write_initial_state : bool, optional Whether to write the initial state, by default None. + **kwargs : Any + Additional keyword arguments to pass to `earthkit.plots.quickplot`. """ super().__init__( @@ -100,12 +106,11 @@ def __init__( output_frequency=output_frequency, write_initial_state=write_initial_state, ) - self.path = path + + self.dir = dir self.format = format self.variables = variables - self.strftime = strftime self.template = template - self.missing_value = missing_value self.domain = domain self.mode = mode self.kwargs = kwargs @@ -121,8 +126,6 @@ def write_step(self, state: State) -> None: import earthkit.data as ekd import earthkit.plots as ekp - os.makedirs(self.path, exist_ok=True) - longitudes = fix(state["longitudes"]) latitudes = state["latitudes"] date = state["date"] @@ -134,7 +137,7 @@ def write_step(self, state: State) -> None: if self.skip_variable(name): continue - variable = self.context.checkpoint.typed_variables[name] + variable = self.typed_variables[name] param = variable.param plotting_fields.append( @@ -153,10 +156,19 @@ def write_step(self, state: State) -> None: ) fig = ekp.quickplot( - ekd.FieldList.from_fields((plotting_fields)), mode=self.mode, domain=self.domain, **self.kwargs + ekd.FieldList.from_fields(plotting_fields), mode=self.mode, domain=self.domain, **self.kwargs + ) + fname = render_template( + self.template, + { + "date": date, + "basetime": basetime, + "domain": self.domain, + "format": self.format, + "variables": "_".join(self.variables or []), + }, ) - fname = self.template.format(date=date, format=self.format) - fname = os.path.join(self.path, fname) + fname = self.dir / fname fig.save(fname) del fig diff --git a/src/anemoi/inference/outputs/printer.py b/src/anemoi/inference/outputs/printer.py index 2584d687..f1dab062 100644 --- a/src/anemoi/inference/outputs/printer.py +++ b/src/anemoi/inference/outputs/printer.py @@ -11,6 +11,7 @@ import logging from collections.abc import Callable from functools import partial +from pathlib import Path from typing import Any from typing import Literal from typing import Union @@ -18,9 +19,9 @@ import numpy as np from anemoi.inference.context import Context -from anemoi.inference.types import ProcessorConfig from anemoi.inference.types import State +from ..decorators import ensure_path from ..decorators import main_argument from ..output import Output from . import output_registry @@ -105,36 +106,42 @@ def print_state( @output_registry.register("printer") @main_argument("max_lines") +@ensure_path("path") class PrinterOutput(Output): """Printer output class.""" def __init__( self, context: Context, - post_processors: list[ProcessorConfig] | None = None, - path: str | None = None, + path: Path | None = None, variables: ListOrAll | None = None, + max_lines: int = 4, **kwargs: Any, ) -> None: - """Initialize the PrinterOutput. + """Initialise the PrinterOutput. Parameters ---------- context : Context The context. - post_processors : Optional[List[ProcessorConfig]] = None - Post-processors to apply to the input - path : str, optional + path : Path, optional The path to save the printed output, by default None. + If the parent directory does not exist, it will be created. variables : list, optional The list of variables to print, by default None. + max_lines : int, optional + The maximum number of lines to print, by default 4. + If set to 0, all variables will be printed. **kwargs : Any Additional keyword arguments. """ - super().__init__(context, variables=variables, post_processors=post_processors) + super().__init__(context, variables=variables, **kwargs) self.print = print self.variables = variables + self.max_lines = max_lines + + self.f = None if path is not None: self.f = open(path, "w") @@ -148,4 +155,9 @@ def write_step(self, state: State) -> None: state : State The state dictionary. """ - print_state(state, print=self.print, variables=self.variables) + print_state(state, print=self.print, variables=self.variables, max_lines=self.max_lines) + + def close(self) -> None: + if self.f is not None: + self.f.close() + return super().close() diff --git a/src/anemoi/inference/outputs/raw.py b/src/anemoi/inference/outputs/raw.py index 0528e509..58778b7d 100644 --- a/src/anemoi/inference/outputs/raw.py +++ b/src/anemoi/inference/outputs/raw.py @@ -8,14 +8,15 @@ # nor does it submit to any jurisdiction. import logging -import os +from pathlib import Path import numpy as np from anemoi.inference.context import Context -from anemoi.inference.types import ProcessorConfig from anemoi.inference.types import State +from anemoi.inference.utils.templating import render_template +from ..decorators import ensure_dir from ..decorators import main_argument from ..output import Output from . import output_registry @@ -25,47 +26,36 @@ @output_registry.register("raw") @main_argument("path") +@ensure_dir("dir") class RawOutput(Output): """Raw output class.""" def __init__( self, context: Context, - path: str, + dir: Path, template: str = "{date}.npz", strftime: str = "%Y%m%d%H%M%S", variables: list[str] | None = None, - post_processors: list[ProcessorConfig] | None = None, - output_frequency: int | None = None, - write_initial_state: bool | None = None, + **kwargs, ) -> None: - """Initialize the RawOutput class. + """Initialise the RawOutput class. Parameters ---------- context : dict The context. - path : str - The path to save the raw output. + dir : Path + The directory to save the raw output. + If the parent directory does not exist, it will be created. template : str, optional The template for filenames, by default "{date}.npz". + Variables available are `date`, `basetime` `step`. strftime : str, optional The date format string, by default "%Y%m%d%H%M%S". - post_processors : Optional[List[ProcessorConfig]], default None - Post-processors to apply to the input - output_frequency : int, optional - The frequency of output, by default None. - write_initial_state : bool, optional - Whether to write the initial state, by default None. """ - super().__init__( - context, - variables=variables, - post_processors=post_processors, - output_frequency=output_frequency, - write_initial_state=write_initial_state, - ) - self.path = path + super().__init__(context, variables=variables, **kwargs) + self.dir = dir self.template = template self.strftime = strftime @@ -77,7 +67,7 @@ def __repr__(self) -> str: str String representation of the RawOutput object. """ - return f"RawOutput({self.path})" + return f"RawOutput({self.dir})" def write_step(self, state: State) -> None: """Write the state to a compressed .npz file. @@ -87,9 +77,16 @@ def write_step(self, state: State) -> None: state : State The state to be written. """ - os.makedirs(self.path, exist_ok=True) - date = state["date"].strftime(self.strftime) - fn_state = f"{self.path}/{self.template.format(date=date)}" + date = state["date"] + basetime = date - state["step"] + + format_info = { + "date": date.strftime(self.strftime), + "step": state["step"], + "basetime": basetime.strftime(self.strftime), + } + + fn_state = f"{self.dir}/{render_template(self.template, format_info)}" restate = {f"field_{key}": val for key, val in state["fields"].items() if not self.skip_variable(key)} for key in ["date"]: diff --git a/src/anemoi/inference/outputs/tee.py b/src/anemoi/inference/outputs/tee.py index bf49a184..a9705316 100644 --- a/src/anemoi/inference/outputs/tee.py +++ b/src/anemoi/inference/outputs/tee.py @@ -12,11 +12,10 @@ from collections.abc import Sequence from typing import Any -from anemoi.inference.config import Configuration from anemoi.inference.context import Context from anemoi.inference.types import State -from ..output import ForwardOutput +from ..output import Output from . import create_output from . import output_registry @@ -24,39 +23,39 @@ @output_registry.register("tee") -class TeeOutput(ForwardOutput): +class TeeOutput(Output): """TeeOutput class to manage multiple outputs.""" def __init__( self, context: Context, - *args: Configuration, - outputs: Sequence[Configuration] | None = None, + *args, + outputs: Sequence | None = None, **kwargs: Any, ): - """Initialize the TeeOutput. + """Initialise the TeeOutput. Parameters ---------- context : object The context object. - *args : Configuration + *args : Additional positional arguments. - outputs : Sequence[Configuration], optional + outputs : Sequence, optional Outputs to be created. **kwargs : Any Additional keyword arguments. """ super().__init__( context, - None, **kwargs, ) if outputs is None: outputs = args + else: + outputs = [*args, *outputs] - assert isinstance(outputs, (list, tuple)), outputs self.outputs = [create_output(context, output) for output in outputs] # We override write_initial_state and write_state diff --git a/src/anemoi/inference/outputs/truth.py b/src/anemoi/inference/outputs/truth.py index b20e1ebe..9b30e635 100644 --- a/src/anemoi/inference/outputs/truth.py +++ b/src/anemoi/inference/outputs/truth.py @@ -10,11 +10,11 @@ import logging from typing import Any -from anemoi.inference.config import Configuration +from anemoi.inference.runners.default import DefaultRunner from anemoi.inference.state import reduce_state from anemoi.inference.types import State -from ..context import Context +from ..decorators import main_argument from ..output import ForwardOutput from . import output_registry @@ -22,6 +22,7 @@ @output_registry.register("truth") +@main_argument("output") class TruthOutput(ForwardOutput): """Write the input state at the same time for each output state. @@ -29,20 +30,23 @@ class TruthOutput(ForwardOutput): the forecasts, effectively only for times in the past. """ - def __init__(self, context: Context, output: Configuration, **kwargs: Any) -> None: - """Initialize the TruthOutput. + def __init__(self, context: DefaultRunner, output, **kwargs: Any) -> None: + """Initialise the TruthOutput. Parameters ---------- context : Context The context for the output. - output : Configuration + output : The output configuration. kwargs : dict Additional keyword arguments. """ + if not isinstance(context, DefaultRunner): + raise ValueError("TruthOutput can only be used with `DefaultRunner`") + super().__init__(context, output, None, **kwargs) - self._input = self.context.create_input() + self._input = context.create_prognostics_input() def write_step(self, state: State) -> None: """Write a step of the state. @@ -53,6 +57,8 @@ def write_step(self, state: State) -> None: The state to write. """ truth_state = self._input.create_input_state(date=state["date"]) + truth_state["step"] = state["step"] + reduced_state = reduce_state(truth_state) self.output.write_state(reduced_state) diff --git a/src/anemoi/inference/outputs/zarr.py b/src/anemoi/inference/outputs/zarr.py index 0a8450a1..7f72f6a9 100644 --- a/src/anemoi/inference/outputs/zarr.py +++ b/src/anemoi/inference/outputs/zarr.py @@ -10,8 +10,9 @@ from __future__ import annotations import logging -import os import shutil +from pathlib import Path +from typing import TYPE_CHECKING from typing import Any from typing import Literal @@ -27,14 +28,17 @@ LOG = logging.getLogger(__name__) +if TYPE_CHECKING: + from zarr.storage import StoreLike + def create_zarr_array( - store: Any, + store: "StoreLike", name: str, shape: tuple, dtype: str, dimensions: tuple[str, ...], - chunks: tuple[int, ...] | Literal["auto"] | bool, + chunks: tuple[int, ...] | str | bool, fill_value: float | None = None, ) -> Any: """Create a Zarr array with the given parameters. @@ -45,8 +49,6 @@ def create_zarr_array( chunks = chunks if zarr.__version__ >= "3" else chunks if not chunks == "auto" else True - store: zarr.Group = store - if zarr.__version__ >= "3": from zarr import create_array else: @@ -78,21 +80,21 @@ class ZarrOutput(Output): def __init__( self, context: Context, - store: Any, + store: "StoreLike", variables: list[str] | None = None, output_frequency: int | None = None, write_initial_state: bool | None = None, missing_value: float | None = np.nan, float_size: str = "f4", - chunks: tuple[int, ...] | Literal["auto"] = "auto", + chunks: tuple[int, ...] | Literal["auto"] | bool = "auto", ) -> None: - """Initialize the ZarrOutput object. + """Initialise the ZarrOutput object. Parameters ---------- context : dict The context dictionary. - store : Any + store : StoreLike The Zarr store to save the data. Can be a file path or a Zarr store. If an existing store is provided, it is assumed to @@ -107,7 +109,7 @@ def __init__( The size of the float, by default "f4". missing_value : float, optional The missing value, by default np.nan. - chunks : tuple[int, ...] | Literal['auto'], optional + chunks : tuple[int, ...] | Literal['auto'] | bool, optional The chunk size for the Zarr arrays, by default 'auto'. """ @@ -115,8 +117,6 @@ def __init__( context, variables=variables, output_frequency=output_frequency, write_initial_state=write_initial_state ) - from zarr.storage import StoreLike - self.zarr_store: StoreLike = store self.missing_value = missing_value self.chunks = chunks @@ -138,8 +138,11 @@ def open(self, state: State) -> None: """ import zarr - if isinstance(self.zarr_store, str): - if os.path.exists(self.zarr_store): + if isinstance(self.zarr_store, (str, Path)): + zarr_store = Path(self.zarr_store) + zarr_store.parent.mkdir(parents=True, exist_ok=True) + + if zarr_store.exists(): LOG.warning(f"Zarr store {self.zarr_store} already exists. It will be overwritten.") shutil.rmtree(self.zarr_store) diff --git a/src/anemoi/inference/testing/checks.py b/src/anemoi/inference/testing/checks.py index 681fe1b2..83d0ebaf 100644 --- a/src/anemoi/inference/testing/checks.py +++ b/src/anemoi/inference/testing/checks.py @@ -158,3 +158,24 @@ def check_lam( elif reference_file: # check against a reference file, implement when needed raise NotImplementedError("Reference file check is not implemented yet.") + + +@testing_registry.register("check_file_exist") +def check_file_exist(*, file: Path, **kwargs) -> None: + LOG.info(f"Checking file exists: {file}") + assert file.exists(), f"File {file} does not exist." + assert file.stat().st_size > 0, f"File {file} is empty." + + +@testing_registry.register("check_files_in_directory") +def check_files_in_directory(*, file: Path, expected_files: int | None = None, **kwargs) -> None: + LOG.info(f"Checking directory: {file}") + assert file.exists() and file.is_dir(), f"Directory {file} does not exist or is not a directory." + if expected_files is not None: + actual_files = [f for f in file.iterdir() if f.is_file()] + if expected_files < 0: + assert len(actual_files) > 0, "Expected at least one file, but found none." + else: + assert ( + len(actual_files) == expected_files + ), f"Expected {expected_files} files, but found {len(actual_files)}." diff --git a/src/anemoi/inference/utils/templating.py b/src/anemoi/inference/utils/templating.py new file mode 100644 index 00000000..180d0806 --- /dev/null +++ b/src/anemoi/inference/utils/templating.py @@ -0,0 +1,44 @@ +# (C) Copyright 2025- Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + + +import re + +_TEMPLATE_EXPRESSION_PATTERN = re.compile(r"\{(.*?)\}") + + +def render_template(template: str, handle: dict) -> str: + """Render a template string with the given keyword arguments. + + Given a template string such as '{dateTime}_{step:03}.grib' and + the GRIB handle, this function will replace the expressions in the + template with the corresponding values from the handle, formatted + according to the optional format specifier. + + For example, the template '{dateTime}_{step:03}.grib' with a handle + containing 'dateTime' as '202501011200' and 'step' as 6 will + produce '202501011200_006.grib'. + + Parameters + ---------- + template : str + The template string to render. + handle : dict + The dictionary to use for rendering the template. + + Returns + ------- + str + The rendered template string. + """ + expressions = _TEMPLATE_EXPRESSION_PATTERN.findall(str(template)) + expr_format = [el.split(":") if ":" in el else [el, ""] for el in expressions] + keys = {k[0]: format(handle.get(k[0]), k[1]) for k in expr_format} + path = str(template).format(**keys) + return path diff --git a/tests/integration/single-o48-1.1/config.yaml b/tests/integration/single-o48-1.1/config.yaml index b2b36676..b93775f0 100644 --- a/tests/integration/single-o48-1.1/config.yaml +++ b/tests/integration/single-o48-1.1/config.yaml @@ -83,3 +83,132 @@ grib: ${input:} output: zarr: ${output:} + +- name: grib-in-plots-out + input: input.grib + output: plots + checks: + - check_files_in_directory: + expected_files: 8 # 48h / 6 + inference_config: + write_initial_state: false + checkpoint: ${checkpoint:} + input: + grib: ${input:} + output: + plot: + dir: ${output:} + variables: [2t, 10u, 10v] + +- name: grib-in-plots-out + input: input.grib + output: plots + checks: + - check_files_in_directory: + expected_files: 9 # 48h / 6 + 1 + inference_config: + write_initial_state: true + checkpoint: ${checkpoint:} + input: + grib: ${input:} + output: + plot: + dir: ${output:} + mode: overlay + variables: [2t] + +- name: grib-in-truth-out + input: + output: output.grib + checks: + - check_grib: + expected_variables: + - 2t + - q + grib_keys: + stream: oper + class: ai + type: fc + check_nans: true + inference_config: + write_initial_state: false + checkpoint: ${checkpoint:} + input: dummy + date: '20250101' + output: + truth: + output: + grib: + path: ${output:} + encoding: + stream: oper + class: ai + type: fc + variables: + - 2t + - q_600 + +- name: grib-in-print-output + input: input.grib + output: output.txt + checks: + - check_file_exist: {} + inference_config: + write_initial_state: false + checkpoint: ${checkpoint:} + input: + grib: ${input:} + output: + printer: + path: ${output:} + +- name: grib-in-raw-output + input: input.grib + output: output + checks: + - check_files_in_directory: + expected_files: 8 + inference_config: + write_initial_state: false + checkpoint: ${checkpoint:} + input: + grib: ${input:} + output: + raw: + dir: ${output:} + +- name: grib-in-raw-output-templated + input: input.grib + output: output + checks: + - check_files_in_directory: + expected_files: 8 + inference_config: + write_initial_state: false + checkpoint: ${checkpoint:} + input: + grib: ${input:} + output: + raw: + dir: ${output:} + template: "{date}_step{step}.npz" + + +- name: grib-in-tee-output + input: input.grib + output: output + checks: + - check_files_in_directory: + expected_files: 3 + inference_config: + write_initial_state: false + checkpoint: ${checkpoint:} + input: + grib: ${input:} + output: + tee: + - grib: ${output:}/output.grib + - netcdf: ${output:}/output.nc + - tee: + outputs: + - netcdf: ${output:}/output.nc2