Skip to content

Commit 7fabff9

Browse files
feat: Generic cos sin from rad filter (#180)
## Description This filter converts a variable in radian to the cosine and sine of the variable. ## What problem does this change solve? The available `cos_sin_mean_wave_direction` filter expects values in degree. ## What issue or task does this change relate to? ## Additional notes ## ***As a contributor to the Anemoi framework, please ensure that your changes include unit tests, updates to any affected dependencies and documentation, and have been tested in a parallel setting (i.e., with multiple GPUs). As a reviewer, you are also responsible for verifying these aspects and requesting changes if they are not adequately addressed. For guidelines about those please refer to https://anemoi.readthedocs.io/en/latest/*** By opening this pull request, I affirm that all authors agree to the [Contributor License Agreement.](https://github.com/ecmwf/codex/blob/main/Legal/contributor_license_agreement.md) --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 0801a4b commit 7fabff9

File tree

2 files changed

+257
-0
lines changed

2 files changed

+257
-0
lines changed
Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
# (C) Copyright 2025 Anemoi contributors.
2+
#
3+
# This software is licensed under the terms of the Apache Licence Version 2.0
4+
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
5+
#
6+
# In applying this licence, ECMWF does not waive the privileges and immunities
7+
# granted to it by virtue of its status as an intergovernmental organisation
8+
# nor does it submit to any jurisdiction.
9+
10+
11+
from collections.abc import Iterator
12+
from typing import Any
13+
14+
import earthkit.data as ekd
15+
import numpy as np
16+
17+
from anemoi.transform.filters import filter_registry
18+
from anemoi.transform.filters.matching import MatchingFieldsFilter
19+
from anemoi.transform.filters.matching import matching
20+
21+
22+
@filter_registry.register("cos_sin_from_rad")
23+
class CosSinFromRad(MatchingFieldsFilter):
24+
"""A filter to convert any variable in radians to cos() and sin() and back."""
25+
26+
@matching(
27+
select="param",
28+
forward=("param",),
29+
backward=("cos_param", "sin_param"),
30+
)
31+
def __init__(
32+
self,
33+
param: str,
34+
cos_param: str | None = None,
35+
sin_param: str | None = None,
36+
) -> None:
37+
"""Initialize the CosSinFromRad filter.
38+
39+
Parameters
40+
----------
41+
param : str
42+
The name of the variable.
43+
cos_param : str, optional
44+
The name of the cosine of the variable. Default is to prefix "cos_".
45+
sin_param : str, optional
46+
The name of the sine of the variable. Default is to prefix "sin_".
47+
"""
48+
49+
self.param = param
50+
self.cos_param = cos_param if cos_param is not None else f"cos_{param}"
51+
self.sin_param = sin_param if sin_param is not None else f"sin_{param}"
52+
53+
def forward_transform(
54+
self,
55+
param: ekd.Field,
56+
) -> Iterator[ekd.Field]:
57+
"""Convert a direction variable to its cosine and sine components.
58+
59+
Parameters
60+
----------
61+
param : ekd.Field
62+
The direction field.
63+
64+
Returns
65+
-------
66+
Iterator[ekd.Field]
67+
Fields of cosine and sine of the direction.
68+
"""
69+
data = param.to_numpy()
70+
if (min := data.min()) < -2 * np.pi:
71+
raise ValueError(f"Param {self.param} is expected in radians in the range [-2pi, pi], but {min=}")
72+
if (max := data.max()) > 2 * np.pi:
73+
raise ValueError(f"Param {self.param} is expected in radians in the range [-2pi, pi], but {max=}")
74+
75+
yield self.new_field_from_numpy(np.cos(data), template=param, param=self.cos_param)
76+
yield self.new_field_from_numpy(np.sin(data), template=param, param=self.sin_param)
77+
78+
def backward_transform(
79+
self,
80+
cos_param: ekd.Field,
81+
sin_param: ekd.Field,
82+
) -> Iterator[ekd.Field]:
83+
"""Convert cosine and sine components back to direction in radians in the range [-pi, pi).
84+
85+
Parameters
86+
----------
87+
cos_param : ekd.Field
88+
The cosine of the direction field.
89+
sin_param : ekd.Field
90+
The sine of the direction field.
91+
92+
Returns
93+
-------
94+
Iterator[ekd.Field]
95+
Field of the direction.
96+
"""
97+
direction = np.arctan2(sin_param.to_numpy(), cos_param.to_numpy())
98+
99+
yield self.new_field_from_numpy(direction, template=cos_param, param=self.param)
100+
101+
def patch_data_request(self, data_request: dict[str, Any]) -> dict[str, Any]:
102+
"""Modify the data request to include the direction.
103+
104+
Parameters
105+
----------
106+
data_request : Dict[str, Any]
107+
The original data request.
108+
109+
Returns
110+
-------
111+
Dict[str, Any]
112+
The modified data request.
113+
"""
114+
115+
param = data_request.get("param")
116+
if param is None:
117+
return data_request
118+
119+
if self.cos_param in param or self.sin_param in param:
120+
data_request["param"] = [p for p in param if p not in (self.cos_param, self.sin_param)]
121+
data_request["param"].append(self.param)
122+
123+
return data_request

tests/test_cos_sin_from_rad.py

Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
# (C) Copyright 2025 Anemoi contributors.
2+
#
3+
# This software is licensed under the terms of the Apache Licence Version 2.0
4+
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
5+
#
6+
# In applying this licence, ECMWF does not waive the privileges and immunities
7+
# granted to it by virtue of its status as an intergovernmental organisation
8+
# nor does it submit to any jurisdiction.
9+
import numpy as np
10+
import pytest
11+
12+
from anemoi.transform.filters import filter_registry
13+
14+
from .utils import assert_fields_equal
15+
from .utils import collect_fields_by_param
16+
17+
MOCK_FIELD_METADATA = {
18+
"latitudes": [10.0, 0.0, -10.0],
19+
"longitudes": [20, 40.0],
20+
"valid_datetime": "2018-08-01T09:00:00Z",
21+
}
22+
23+
RAD_VALUES = np.array([[2.67687254, 2.59108576], [1.83746659, 1.73104875], [1.1348185, 2.23051268]])
24+
25+
COS_RAD_VALUES = np.array([[-0.89394704, -0.85225947], [-0.26352086, -0.15956740], [0.42229696, -0.61289275]])
26+
27+
SIN_RAD_VALUES = np.array([[0.44817262, 0.52311930], [0.96465370, 0.98718704], [0.90645754, 0.79016611]])
28+
29+
30+
@pytest.fixture
31+
def RAD_source(test_source):
32+
RAD_SPEC = [
33+
{"param": "RAD", "values": RAD_VALUES, **MOCK_FIELD_METADATA},
34+
]
35+
return test_source(RAD_SPEC)
36+
37+
38+
@pytest.fixture
39+
def DEG_source(test_source):
40+
DEG_SPEC = [
41+
{"param": "DEG", "values": np.rad2deg(RAD_VALUES), **MOCK_FIELD_METADATA},
42+
]
43+
return test_source(DEG_SPEC)
44+
45+
46+
@pytest.fixture
47+
def cos_sin_RAD_source(test_source):
48+
COS_SIN_RAD = [
49+
{"param": "cos_RAD", "values": COS_RAD_VALUES, **MOCK_FIELD_METADATA},
50+
{"param": "sin_RAD", "values": SIN_RAD_VALUES, **MOCK_FIELD_METADATA},
51+
]
52+
return test_source(COS_SIN_RAD)
53+
54+
55+
def test_forward(RAD_source):
56+
"""Test the cos_sin_from_rad filter."""
57+
filter = filter_registry.create(
58+
"cos_sin_from_rad",
59+
param="RAD",
60+
)
61+
pipeline = RAD_source | filter
62+
63+
output_fields = collect_fields_by_param(pipeline)
64+
65+
assert set(output_fields) == {"cos_RAD", "sin_RAD"}
66+
assert len(output_fields["cos_RAD"]) == 1
67+
assert len(output_fields["sin_RAD"]) == 1
68+
69+
np.testing.assert_allclose(output_fields["cos_RAD"][0].to_numpy(), COS_RAD_VALUES)
70+
np.testing.assert_allclose(output_fields["sin_RAD"][0].to_numpy(), SIN_RAD_VALUES)
71+
72+
73+
def test_reverse(cos_sin_RAD_source):
74+
"""Test the cos_sin_from_rad filter in reverse."""
75+
filter = filter_registry.create(
76+
"cos_sin_from_rad",
77+
param="some_rad",
78+
cos_param="cos_RAD",
79+
sin_param="sin_RAD",
80+
).reverse()
81+
pipeline = cos_sin_RAD_source | filter
82+
83+
output_fields = collect_fields_by_param(pipeline)
84+
85+
assert set(output_fields) == {"some_rad"}
86+
assert len(output_fields["some_rad"]) == 1
87+
88+
np.testing.assert_allclose(output_fields["some_rad"][0].to_numpy(), RAD_VALUES)
89+
90+
91+
def test_round_trip(RAD_source):
92+
"""Test the cos_sin_from_rad filter reproduces inputs on a round trip."""
93+
filter = filter_registry.create(
94+
"cos_sin_from_rad",
95+
param="RAD",
96+
)
97+
cos_sin_RAD_source = RAD_source | filter
98+
pipeline = cos_sin_RAD_source | filter.reverse()
99+
100+
input_fields = collect_fields_by_param(RAD_source)
101+
intermediate_fields = collect_fields_by_param(cos_sin_RAD_source)
102+
output_fields = collect_fields_by_param(pipeline)
103+
104+
assert set(input_fields) == {"RAD"}
105+
assert set(intermediate_fields) == {"cos_RAD", "sin_RAD"}
106+
assert set(output_fields) == {"RAD"}
107+
108+
for input_field, output_field in zip(input_fields["RAD"], output_fields["RAD"]):
109+
assert_fields_equal(input_field, output_field)
110+
111+
112+
def test_exception(DEG_source):
113+
"""Test the cos_sin_from_rad exception.
114+
115+
Inpupt data in degrees.
116+
"""
117+
filter = filter_registry.create(
118+
"cos_sin_from_rad",
119+
param="DEG",
120+
)
121+
pipeline = DEG_source | filter
122+
123+
with pytest.raises(ValueError):
124+
collect_fields_by_param(pipeline)
125+
126+
127+
if __name__ == "__main__":
128+
"""
129+
Run all test functions that start with 'test_'.
130+
"""
131+
for name, obj in list(globals().items()):
132+
if name.startswith("test_") and callable(obj):
133+
print(f"Running {name}...")
134+
obj()

0 commit comments

Comments
 (0)