From 2792f04112979203465d9a41f7982d50a781ae82 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Thu, 6 Nov 2025 15:21:20 +0000 Subject: [PATCH] feat: add support for forcings in SimpleRunner --- src/anemoi/inference/runners/simple.py | 29 ++++++++++++++++++++++---- 1 file changed, 25 insertions(+), 4 deletions(-) diff --git a/src/anemoi/inference/runners/simple.py b/src/anemoi/inference/runners/simple.py index 9abaf68ca..384b79841 100644 --- a/src/anemoi/inference/runners/simple.py +++ b/src/anemoi/inference/runners/simple.py @@ -7,10 +7,14 @@ # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. - import logging from typing import Any +import numpy as np + +from anemoi.inference.context import Context +from anemoi.inference.types import Date +from anemoi.inference.types import FloatArray from anemoi.inference.types import IntArray from anemoi.inference.types import State @@ -58,11 +62,26 @@ def load_forcings_state(self, state: State, date: str) -> None: pass +class UserProvidedForcings(Forcings): + trace_name = "user_provided" + + def __init__(self, context: Context, forcings, variables: list[str], mask: IntArray): + super().__init__(context) + self.variables = variables + self.mask = mask + self.forcings = forcings + + def load_forcings_array(self, dates: list[Date], current_state: State) -> FloatArray: + indices = [self.forcings["dates"].index(d) for d in dates] + result = np.stack([self.forcings["fields"][var][indices] for var in self.variables], axis=0) + return result.astype(np.float32) + + @runner_registry.register("simple") class SimpleRunner(Runner): """Use that runner when using the low level API.""" - def __init__(self, *args: Any, **kwargs: Any) -> None: + def __init__(self, *args: Any, forcings: dict[str, Any] | None = None, **kwargs: Any) -> None: """Initialize the SimpleRunner. Parameters @@ -73,6 +92,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: Keyword arguments. """ super().__init__(*args, **kwargs) + self.forcings = forcings def create_constant_computed_forcings(self, variables: list[str], mask: IntArray) -> list[Forcings]: """Create constant computed forcings. @@ -151,5 +171,6 @@ def create_dynamic_coupled_forcings(self, variables: list[str], mask: IntArray) # This runner does not support coupled forcings # there are supposed to be already in the state dictionary # or managed by the user. - LOG.warning("Coupled forcings are not supported by this runner: %s", variables) - return [] + # LOG.warning("Coupled forcings are not supported by this runner: %s", variables) + + return [UserProvidedForcings(self, self.forcings, variables, mask)]