Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 68 additions & 0 deletions src/anemoi/inference/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,15 @@
# nor does it submit to any jurisdiction.


import logging
from pathlib import Path
from typing import Any
from typing import TypeVar

from anemoi.inference.context import Context

LOG = logging.getLogger("anemoi.inference")

MARKER = object()
F = TypeVar("F", bound=type)

Expand Down Expand Up @@ -64,3 +68,67 @@ def __init__(wrapped_cls, context: Context, main: object = MARKER, *args: Any, *
super().__init__(context, *args, **kwargs)

return type(cls.__name__, (WrappedClass,), {})


class ensure_path:
"""Decorator to ensure a path argument is a Path object and optionally exists.

If `is_dir` is True, the path is treated as a directory, if not for files, the parent directory is treated as a directory.
If `must_exist` is True, the directory must exist.
If `create` is True, the directory will be created if it doesn't exist.

For example:
```
@ensure_path("dir", create=True)
class GribOutput
def __init__(context, dir=None, archive_requests=None):
...
"""

def __init__(self, arg: str, is_dir: bool = False, create: bool = True, must_exist: bool = False):
self.arg = arg
self.is_dir = is_dir
self.create = create
self.must_exist = must_exist

def __call__(self, cls: F) -> F:
"""Decorate the object to ensure the path argument is a Path object."""

class WrappedClass(cls):
def __init__(wrapped_cls, context: Context, *args: Any, **kwargs: Any) -> None:
if self.arg not in kwargs:
LOG.debug(f"Argument '{self.arg}' not found in kwargs, cannot ensure path.")
super().__init__(context, *args, **kwargs)
return

path = kwargs[self.arg] = Path(kwargs[self.arg])
if not self.is_dir:
path = path.parent

if self.must_exist:
if not path.exists():
raise FileNotFoundError(f"Path '{path}' does not exist.")
if self.create:
path.mkdir(parents=True, exist_ok=True)

super().__init__(context, *args, **kwargs)

return type(cls.__name__, (WrappedClass,), {})


class ensure_dir(ensure_path):
"""Decorator to ensure a directory path argument is a Path object and optionally exists.

If `must_exist` is True, the directory must exist.
If `create` is True, the directory will be created if it doesn't exist.

For example:
```
@ensure_dir("dir", create=True)
class PlotOutput
def __init__(context, dir=None, ...):
...
"""

def __init__(self, arg: str, create: bool = True, must_exist: bool = False):
super().__init__(arg, is_dir=True, create=create, must_exist=must_exist)
50 changes: 10 additions & 40 deletions src/anemoi/inference/grib/encoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,18 @@


import logging
import re
import warnings
from io import IOBase
from pathlib import Path
from typing import TYPE_CHECKING
from typing import Any
from typing import Hashable

import earthkit.data as ekd
from anemoi.utils.dates import as_timedelta

from anemoi.inference.utils.templating import render_template

if TYPE_CHECKING:
from anemoi.transform.variables import Variable

Expand Down Expand Up @@ -469,12 +472,12 @@ def encode_message(
class GribWriter:
"""Write GRIB messages to one or more files."""

def __init__(self, out: str | IOBase, split_output: bool = True) -> None:
def __init__(self, out: Path | IOBase, split_output: bool = True) -> None:
"""Initialize the GribWriter.

Parameters
----------
out : Union[str, IOBase]
out : Union[Path, IOBase]
Path or file-like object to write the grib data to.
If a string, it should be a file path.
If a file-like object, it should be opened in binary write mode.
Expand All @@ -487,7 +490,7 @@ def __init__(self, out: str | IOBase, split_output: bool = True) -> None:

self.out = out
self.split_output = split_output
self._files: dict[str, IOBase] = {}
self._files: dict[Hashable, IOBase] = {}

def close(self) -> None:
"""Close all open files."""
Expand Down Expand Up @@ -570,7 +573,7 @@ def write(

return handle, path

def target(self, handle: Any) -> tuple[IOBase, str]:
def target(self, handle: Any) -> tuple[IOBase, Path | str]:
"""Determine the target file for the GRIB message.

Parameters
Expand All @@ -584,7 +587,8 @@ def target(self, handle: Any) -> tuple[IOBase, str]:
The file object and the file path.
"""
if self.split_output:
out = render_template(self.out, handle)
assert not isinstance(self.out, IOBase), "Cannot split output when `out` is a file-like object."
out = render_template(str(self.out), handle)
else:
out = self.out

Expand All @@ -596,37 +600,3 @@ def target(self, handle: Any) -> tuple[IOBase, str]:
self._files[out] = open(out, "wb")

return self._files[out], out


_TEMPLATE_EXPRESSION_PATTERN = re.compile(r"\{(.*?)\}")


def render_template(template: str, handle: dict) -> str:
"""Render a template string with the given keyword arguments.

Given a template string such as '{dateTime}_{step:03}.grib' and
the GRIB handle, this function will replace the expressions in the
template with the corresponding values from the handle, formatted
according to the optional format specifier.

For example, the template '{dateTime}_{step:03}.grib' with a handle
containing 'dateTime' as '202501011200' and 'step' as 6 will
produce '202501011200_006.grib'.

Parameters
----------
template : str
The template string to render.
handle : Dict
The earthkit.data handle manager.

Returns
-------
str
The rendered template string.
"""
expressions = _TEMPLATE_EXPRESSION_PATTERN.findall(template)
expr_format = [el.split(":") if ":" in el else [el, ""] for el in expressions]
keys = {k[0]: format(handle.get(k[0]), k[1]) for k in expr_format}
path = template.format(**keys)
return path
12 changes: 6 additions & 6 deletions src/anemoi/inference/grib/templates/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def __init__(self, owner: Any, templates: list[str] | str | None = None) -> None

self.templates_providers = [create_template_provider(self, template) for template in templates]

def template(self, name: str, state: State, typed_variables: list[Any]) -> ekd.Field | None:
def template(self, name: str, state: State, typed_variables: dict[str, Any]) -> ekd.Field | None:
"""Get the template for a given name and state.

Parameters
Expand All @@ -60,8 +60,8 @@ def template(self, name: str, state: State, typed_variables: list[Any]) -> ekd.F
The name of the template.
state : State
The state object containing template information.
typed_variables : list of Any
The list of typed variables.
typed_variables : dict[str, Any]
The dictionary of typed variables.

Returns
-------
Expand All @@ -78,7 +78,7 @@ def template(self, name: str, state: State, typed_variables: list[Any]) -> ekd.F

return self._template_cache.get(name)

def load_template(self, name: str, state: State, typed_variables: list[Any]) -> ekd.Field | None:
def load_template(self, name: str, state: State, typed_variables: dict[str, Any]) -> ekd.Field | None:
"""Load the template for a given name and state.

Parameters
Expand All @@ -87,8 +87,8 @@ def load_template(self, name: str, state: State, typed_variables: list[Any]) ->
The name of the template.
state : State
The state object containing template information.
typed_variables : list of Any
The list of typed variables.
typed_variables : dict[str, Any]
The dictionary of typed variables.

Returns
-------
Expand Down
14 changes: 8 additions & 6 deletions src/anemoi/inference/output.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from abc import abstractmethod
from functools import cached_property
from typing import TYPE_CHECKING
from typing import Any

from anemoi.inference.post_processors import create_post_processor
from anemoi.inference.processor import Processor
Expand Down Expand Up @@ -244,20 +245,20 @@ class ForwardOutput(Output):
def __init__(
self,
context: "Context",
output: dict | None,
output: Output | Any,
variables: list[str] | None = None,
post_processors: list[ProcessorConfig] | None = None,
output_frequency: int | None = None,
write_initial_state: bool | None = None,
):
"""Initialize the ForwardOutput object.
"""Initialise the ForwardOutput object.

Parameters
----------
context : Context
The context in which the output operates.
output : dict
The output configuration dictionary.
output : Output | Any
The output configuration dictionary or an Output instance.
variables : list, optional
The list of variables, by default None.
post_processors : Optional[List[ProcessorConfig]], default None
Expand All @@ -277,8 +278,9 @@ def __init__(
output_frequency=None,
write_initial_state=write_initial_state,
)

self.output = None if output is None else create_output(context, output)
if not isinstance(output, Output):
output = create_output(context, output)
self.output = output

if self.context.output_frequency is not None:
LOG.warning("output_frequency is ignored for '%s'", self.__class__.__name__)
Expand Down
8 changes: 3 additions & 5 deletions src/anemoi/inference/outputs/grib.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ def __init__(
output_frequency: int | None = None,
write_initial_state: bool | None = None,
) -> None:
"""Initialize the GribOutput object.
"""Initialise the GribOutput object.

Parameters
----------
Expand Down Expand Up @@ -285,6 +285,8 @@ def write_step(self, state: State) -> None:
"""

reference_date = self.reference_date or self.context.reference_date
assert reference_date is not None, "No reference date set"

step = state["step"]
previous_step = state.get("previous_step")
start_steps = state.get("start_steps", {})
Expand Down Expand Up @@ -367,10 +369,6 @@ def template(self, state: State, name: str) -> object:
object
The template object.
"""

if self.template_manager is None:
self.template_manager = TemplateManager(self, self.templates)

return self.template_manager.template(name, state, self.typed_variables)

def template_lookup(self, name: str) -> dict:
Expand Down
15 changes: 9 additions & 6 deletions src/anemoi/inference/outputs/gribfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import logging
from collections import defaultdict
from io import IOBase
from pathlib import Path
from typing import Any

import earthkit.data as ekd
Expand All @@ -22,6 +23,7 @@
from anemoi.inference.types import FloatArray
from anemoi.inference.types import ProcessorConfig

from ..decorators import ensure_path
from ..decorators import main_argument
from ..grib.encoding import GribWriter
from ..grib.encoding import check_encoding
Expand Down Expand Up @@ -114,7 +116,7 @@ def __init__(
self,
context: Context,
*,
out: str | IOBase,
out: Path | IOBase,
post_processors: list[ProcessorConfig] | None = None,
encoding: dict[str, Any] | None = None,
archive_requests: dict[str, Any] | None = None,
Expand All @@ -134,9 +136,8 @@ def __init__(
----------
context : Context
The context.
out : Union[str, IOBase]
out : Union[Path, IOBase]
Path or file-like object to write the grib data to.
If a string, it should be a file path.
If a file-like object, it should be opened in binary write mode.
post_processors : Optional[List[ProcessorConfig]], default None
Post-processors to apply to the input
Expand Down Expand Up @@ -312,14 +313,15 @@ def _patch(r: DataRequest) -> DataRequest:

@output_registry.register("grib")
@main_argument("path")
@ensure_path("path")
class GribFileOutput(GribIoOutput):
"""Handles grib files."""

def __init__(
self,
context: Context,
*,
path: str,
path: Path,
post_processors: list[ProcessorConfig] | None = None,
encoding: dict[str, Any] | None = None,
archive_requests: dict[str, Any] | None = None,
Expand All @@ -333,14 +335,15 @@ def __init__(
write_initial_state: bool | None = None,
split_output: bool = True,
) -> None:
"""Initialize the GribFileOutput.
"""Initialise the GribFileOutput.

Parameters
----------
context : Context
The context.
path : str
path : Path
Path to the grib file to write the data to.
If the parent directory does not exist, it will be created.
post_processors : Optional[List[ProcessorConfig]], default None
Post-processors to apply to the input
encoding : dict, optional
Expand Down
Loading
Loading