Skip to content

Commit a017784

Browse files
JPXKQXpre-commit-ci[bot]anaprietonem
authored andcommitted
feat(graphs): add LimitedAreaMask for stretched hidden nodes (#671)
## Description <!-- What issue or task does this change relate to? --> This PR add a new node attribute builder, `LimitedAreaMask`. - Schemas updated - Documentation updated - fix: plot_interactive_nodes_2d. Update the argument `titlefont_size` to a newer version of plotly. <img width="1905" height="966" alt="Screenshot 2025-11-13 at 12 37 36" src="https://github.com/user-attachments/assets/0a2e6810-751d-4966-a6bd-f54f9456a63e" /> [hidden_nodes.html](https://github.com/user-attachments/files/23524994/hidden_nodes.html) ***As a contributor to the Anemoi framework, please ensure that your changes include unit tests, updates to any affected dependencies and documentation, and have been tested in a parallel setting (i.e., with multiple GPUs). As a reviewer, you are also responsible for verifying these aspects and requesting changes if they are not adequately addressed. For guidelines about those please refer to https://anemoi.readthedocs.io/en/latest/*** By opening this pull request, I affirm that all authors agree to the [Contributor License Agreement.](https://github.com/ecmwf/codex/blob/main/Legal/contributor_license_agreement.md) <!-- readthedocs-preview anemoi-training start --> ---- 📚 Documentation preview 📚: https://anemoi-training--671.org.readthedocs.build/en/671/ <!-- readthedocs-preview anemoi-training end --> <!-- readthedocs-preview anemoi-graphs start --> ---- 📚 Documentation preview 📚: https://anemoi-graphs--671.org.readthedocs.build/en/671/ <!-- readthedocs-preview anemoi-graphs end --> <!-- readthedocs-preview anemoi-models start --> ---- 📚 Documentation preview 📚: https://anemoi-models--671.org.readthedocs.build/en/671/ <!-- readthedocs-preview anemoi-models end --> --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Ana Prieto Nemesio <[email protected]>
1 parent 94b5362 commit a017784

File tree

8 files changed

+65
-14
lines changed

8 files changed

+65
-14
lines changed

graphs/docs/graphs/node_attributes.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ connections only between subsets of nodes.
2121

2222
node_attributes/weights
2323
node_attributes/anemoi_dataset_attribute
24+
node_attributes/area_masks
2425

2526
Additionally, different boolean operations have been implemented to
2627
support more complex use cases:
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
###########
2+
Area mask
3+
###########
4+
5+
The `LimitedAreaMask` node attribute builder creates a mask over the
6+
nodes covering the limited area.
7+
8+
The configuration for these masks, is specified in the YAML file:
9+
10+
.. literalinclude:: ../yaml/attributes_lam_mask.yaml
11+
:language: yaml
12+
13+
.. note::
14+
15+
This node attribute builder is only supported for nodes created using
16+
subclasses of ``StretchedIcosahedronNodes``. Currently, it is
17+
available exclusively for nodes built with the ``StretchedTriNodes``
18+
subclass.
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
nodes:
2+
data: ...
3+
hidden:
4+
node_builder:
5+
_target_: anemoi.graphs.nodes.StretchedTriNodes
6+
# ...
7+
attributes:
8+
cutout_mask:
9+
_target_: anemoi.graphs.nodes.attributes.LimitedAreaMask

graphs/src/anemoi/graphs/nodes/attributes/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from .boolean_op import BooleanOrMask
1919
from .masks import CutOutMask
2020
from .masks import GridsMask
21+
from .masks import LimitedAreaMask
2122
from .masks import NonmissingAnemoiDatasetVariable
2223
from .masks import NonzeroAnemoiDatasetVariable
2324

@@ -27,6 +28,7 @@
2728
"PlanarAreaWeights",
2829
"UniformWeights",
2930
"CutOutMask",
31+
"LimitedAreaMask",
3032
"MaskedPlanarAreaWeights",
3133
"NonmissingAnemoiDatasetVariable",
3234
"NonzeroAnemoiDatasetVariable",

graphs/src/anemoi/graphs/nodes/attributes/masks.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,3 +162,23 @@ class GridsMask(BaseCombineAnemoiDatasetsMask):
162162
def __init__(self, grids: int | list[int] = 0) -> None:
163163
self.grids = [grids] if isinstance(grids, int) else grids
164164
super().__init__()
165+
166+
167+
class LimitedAreaMask(BooleanBaseNodeAttribute):
168+
"""Limited area mask.
169+
170+
It adds a mask based on an area of interest. This mask is only defined
171+
for nodes built with a subclass of `StretchedIcosahedronNodes`.
172+
173+
Methods
174+
-------
175+
compute(self, graph, nodes_name)
176+
Compute the attribute for each node.
177+
"""
178+
179+
def get_raw_values(self, nodes: NodeStorage, **kwargs) -> torch.Tensor:
180+
assert nodes["node_type"] in [
181+
"StretchedTriNodes"
182+
], f"{self.__class__.__name__} can only be used with StretchedIcosahedronNodes."
183+
lam_mask = nodes["_area_mask_builder"].get_mask(nodes.x)
184+
return torch.from_numpy(lam_mask)

graphs/src/anemoi/graphs/nodes/builders/from_refined_icosahedron.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,6 @@ def __init__(
4747
"resolutions",
4848
"nx_graph",
4949
"node_ordering",
50-
"area_mask_builder",
5150
}
5251
if not hasattr(self, "multi_scale_edge_cls"):
5352
raise AttributeError("Classes inheriting from IcosahedralNodes must set 'multi_scale_edge_cls' attribute.")
@@ -86,6 +85,7 @@ def __init__(
8685
) -> None:
8786

8887
super().__init__(resolution, name)
88+
self.hidden_attributes = self.hidden_attributes | {"area_mask_builder"}
8989

9090
self.area_mask_builder = KNNAreaMaskBuilder(reference_node_name, margin_radius_km, mask_attr_name)
9191

@@ -160,7 +160,7 @@ def create_nodes(self) -> tuple[nx.Graph, np.ndarray, list[int]]:
160160
return create_hex_nodes(resolution=max(self.resolutions), area_mask_builder=self.area_mask_builder)
161161

162162

163-
class StretchedIcosahedronNodes(IcosahedralNodes, ABC):
163+
class StretchedIcosahedronNodes(LimitedAreaIcosahedralNodes, ABC):
164164
"""Nodes based on iterative refinements of an icosahedron with 2
165165
different resolutions.
166166
@@ -176,19 +176,18 @@ def __init__(
176176
lam_resolution: int,
177177
name: str,
178178
reference_node_name: str,
179-
mask_attr_name: str,
179+
mask_attr_name: str | None = None,
180180
margin_radius_km: float = 100.0,
181181
) -> None:
182-
183-
super().__init__(lam_resolution, name)
182+
super().__init__(
183+
resolution=lam_resolution,
184+
reference_node_name=reference_node_name,
185+
mask_attr_name=mask_attr_name,
186+
margin_radius_km=margin_radius_km,
187+
name=name,
188+
)
184189
self.global_resolution = global_resolution
185190

186-
self.area_mask_builder = KNNAreaMaskBuilder(reference_node_name, margin_radius_km, mask_attr_name)
187-
188-
def register_nodes(self, graph: HeteroData) -> None:
189-
self.area_mask_builder.fit(graph)
190-
return super().register_nodes(graph)
191-
192191

193192
class StretchedTriNodes(StretchedIcosahedronNodes):
194193
"""Nodes based on iterative refinements of an icosahedron with 2

graphs/src/anemoi/graphs/plotting/interactive_2d_html.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -237,7 +237,7 @@ def plot_interactive_nodes_2d(graph: HeteroData, nodes_name: str, out_file: Opti
237237
sliders=[
238238
dict(active=0, currentvalue={"visible": False}, len=0.4, x=0.5, xanchor="center", steps=slider_steps)
239239
],
240-
titlefont_size=16,
240+
title_font_size=16,
241241
showlegend=False,
242242
hovermode="closest",
243243
margin={"b": 20, "l": 5, "r": 5, "t": 40},

graphs/src/anemoi/graphs/schemas/node_attributes_schemas.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,10 @@ class SphericalAreaWeightSchema(BaseModel):
4949

5050

5151
class CutOutMaskSchema(BaseModel):
52-
target_: Literal["anemoi.graphs.nodes.attributes.CutOutMask"] = Field(..., alias="_target_")
53-
"Implementation of the cutout mask from anemoi.graphs.nodes.attributes."
52+
target_: Literal["anemoi.graphs.nodes.attributes.CutOutMask", "anemoi.graphs.nodes.attributes.LimitedAreaMask"] = (
53+
Field(..., alias="_target_")
54+
)
55+
"Implementation of the area masks from anemoi.graphs.nodes.attributes."
5456

5557

5658
class GridsMaskSchema(BaseModel):

0 commit comments

Comments
 (0)