Commit 8659de9

Authored by ssmmnn11, anaprietonem, and pre-commit-ci[bot]
feat(graphs): new edge attributes and faster graph cleaning (#617)
New edge attributes:

- Directional harmonics
- RadialBasisFeatures

Faster clean-up of the max number of neighbours (> 10x faster).

Allow the cutoff radius to be specified in km.

Co-authored-by: Ana Prieto Nemesio <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent d0f775e commit 8659de9

9 files changed: +569 -45 lines changed

graphs/docs/graphs/edge_attributes.rst

Lines changed: 43 additions & 0 deletions
@@ -45,6 +45,49 @@ latitude and longitude coordinates of the source and target nodes.
4545
edge_length:
4646
_target_: anemoi.graphs.edges.attributes.EdgeDirection
4747
48+
***********************
49+
Directional Harmonics
50+
***********************
51+
52+
The `Directional Harmonics` attribute computes harmonic features from
53+
edge directions, providing a periodic encoding of the angle between
54+
source and target nodes. For each order :math:`m` from 1 to the
55+
specified maximum, it computes :math:`\sin(m\psi)` and
56+
:math:`\cos(m\psi)` where :math:`\psi` is the edge direction angle.
57+
58+
.. code:: yaml
59+
60+
edges:
61+
- source_name: ...
62+
target_name: ...
63+
edge_builders: ...
64+
attributes:
65+
dir_harmonics:
66+
_target_: anemoi.graphs.edges.attributes.DirectionalHarmonics
67+
order: 3
68+
69+
***********************
70+
Radial Basis Features
71+
***********************
72+
73+
The `Radial Basis Features` attribute computes Gaussian radial basis
74+
function (RBF) features from edge distances. It evaluates a set of
75+
Gaussian basis functions centered at different scaled distances. By
76+
default, per-node adaptive scaling is used.
77+
78+
.. code:: yaml
79+
80+
edges:
81+
- source_name: ...
82+
target_name: ...
83+
edge_builders: ...
84+
attributes:
85+
rbf_features:
86+
_target_: anemoi.graphs.edges.attributes.RadialBasisFeatures
87+
r_scale: null  # null selects per-node adaptive scaling (the default); a float sets a global scale
88+
centers: [0.0, 0.25, 0.5, 0.75, 1.0]
89+
sigma: 0.2
90+
4891
******************
4992
Gaussian Weights
5093
******************
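
The Directional Harmonics encoding documented in this file can be illustrated for a single edge. The following is a minimal plain-Python sketch, not the library code: the function name `directional_harmonics` and the scalar `(dx, dy)` inputs are assumptions made for illustration, whereas the actual class operates on tensors of edge directions.

```python
import math


def directional_harmonics(dx: float, dy: float, order: int = 3) -> list:
    """Return [sin(psi), cos(psi), ..., sin(order*psi), cos(order*psi)].

    Hypothetical helper for illustration; psi is the edge direction angle
    recovered from a 2D direction vector (dx, dy).
    """
    psi = math.atan2(dy, dx)  # angle of the direction vector
    features = []
    for m in range(1, order + 1):
        features.append(math.sin(m * psi))
        features.append(math.cos(m * psi))
    return features


# An edge pointing along +x has psi = 0, so every sine term vanishes.
east = directional_harmonics(1.0, 0.0, order=3)
# An edge pointing along +y has psi = pi/2.
north = directional_harmonics(0.0, 1.0, order=3)
print(len(east), len(north))  # each edge yields 2 * order features
```

Because the features depend on the angle only through sines and cosines, directions that differ by a full turn map to the same encoding, which is the periodicity the documentation refers to.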

graphs/docs/graphs/edges/cutoff.rst

Lines changed: 18 additions & 2 deletions
@@ -13,8 +13,11 @@ neighbourhood of the target nodes, :math:`V_{target}`.
1313
:alt: Cut-off radius image
1414
:align: center
1515

16-
The neighbourhood is defined by a `cut-off radius`, which is computed
17-
as,
16+
The neighbourhood is defined by a `cut-off radius`, which can be
17+
specified in two mutually exclusive ways:
18+
19+
1. **Using cutoff_factor**: The radius is computed as the product of a
20+
factor and the reference distance:
1821

1922
.. math::
2023
@@ -33,19 +36,32 @@ where :math:`d(x, y)` is the `Haversine distance
3336
that can be adjusted to increase or decrease the size of the
3437
neighbourhood, and consequently the number of connections in the graph.
3538

39+
2. **Using cutoff_distance_km**: The radius is specified directly in
40+
kilometers and then converted to the equivalent distance on the unit
41+
sphere.
42+
3643
To use this method to create your connections, you can use the following
3744
YAML configuration:
3845

3946
.. code:: yaml
4047
4148
edges:
49+
# Using cutoff_factor (relative to grid reference distance)
4250
- source_name: source
4351
target_name: destination
4452
edge_builders:
4553
- _target_: anemoi.graphs.edges.CutOffEdges
4654
cutoff_factor: 0.6
4755
# max_num_neighbours: 64
4856
57+
# Or using cutoff_distance_km (direct distance specification)
58+
- source_name: source
59+
target_name: destination
60+
edge_builders:
61+
- _target_: anemoi.graphs.edges.CutOffEdges
62+
cutoff_distance_km: 300.0
63+
# max_num_neighbours: 64
64+
4965
The optional argument ``max_num_neighbours`` (default: 64) can be set to
5066
limit the maximum number of neighbours each node can connect to.
5167
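
The kilometre-based option described in this file can be sketched as follows. This is an illustrative plain-Python example, not the library implementation: the helper names and the Earth radius constant of 6371 km are assumptions, and the value used by `CutOffEdges` may differ.

```python
import math

EARTH_RADIUS_KM = 6371.0  # assumed mean Earth radius; the library's constant may differ


def km_to_unit_sphere(distance_km: float, radius_km: float = EARTH_RADIUS_KM) -> float:
    """Convert a great-circle distance in km to an arc length on the unit sphere."""
    if distance_km < 0:
        raise ValueError("distance_km must be non-negative")
    return distance_km / radius_km


def haversine_unit_sphere(lat1: float, lon1: float, lat2: float, lon2: float) -> float:
    """Haversine distance on the unit sphere (radians); inputs are in degrees."""
    phi1, phi2 = math.radians(lat1), math.radians(lat2)
    dphi = phi2 - phi1
    dlam = math.radians(lon2 - lon1)
    a = math.sin(dphi / 2) ** 2 + math.cos(phi1) * math.cos(phi2) * math.sin(dlam / 2) ** 2
    return 2 * math.asin(math.sqrt(a))


def within_cutoff(lat1, lon1, lat2, lon2, cutoff_distance_km: float) -> bool:
    """True if the two points are no farther apart than the cutoff radius."""
    return haversine_unit_sphere(lat1, lon1, lat2, lon2) <= km_to_unit_sphere(cutoff_distance_km)


# Two points on the equator: 2 degrees apart is inside a 300 km radius,
# 3 degrees apart is outside it (under the assumed Earth radius).
print(within_cutoff(0.0, 0.0, 0.0, 2.0, 300.0))
print(within_cutoff(0.0, 0.0, 0.0, 3.0, 300.0))
```

A fixed radius in kilometres keeps the neighbourhood size independent of the grid resolution, whereas `cutoff_factor` scales it with the grid's reference distance.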

graphs/src/anemoi/graphs/describe.py

Lines changed: 14 additions & 1 deletion
@@ -83,6 +83,8 @@ def get_edge_summary(self) -> list[list]:
8383
- Number of edges.
8484
- Number of isolated source nodes.
8585
- Number of isolated target nodes.
86+
- Min edges per target node.
87+
- Max edges per target node.
8688
- Total dimension of the attributes.
8789
- List of attribute names.
8890
"""
@@ -91,13 +93,22 @@ def get_edge_summary(self) -> list[list]:
9193
attributes = edges.edge_attrs()
9294
attributes.remove("edge_index")
9395

96+
# Compute edge counts per target node
97+
target_edge_counts = torch.bincount(edges.edge_index[1], minlength=self.graph[dst_name].num_nodes)
98+
# Only consider nodes that have at least one edge
99+
connected_target_counts = target_edge_counts[target_edge_counts > 0]
100+
min_edges_per_target = connected_target_counts.min().item() if len(connected_target_counts) > 0 else 0
101+
max_edges_per_target = connected_target_counts.max().item() if len(connected_target_counts) > 0 else 0
102+
94103
edge_summary.append(
95104
[
96105
src_name,
97106
dst_name,
98107
edges.num_edges,
99108
self.graph[src_name].num_nodes - len(torch.unique(edges.edge_index[0])),
100109
self.graph[dst_name].num_nodes - len(torch.unique(edges.edge_index[1])),
110+
min_edges_per_target,
111+
max_edges_per_target,
101112
sum(edges[attr].shape[1] for attr in attributes),
102113
", ".join([f"{attr}({edges[attr].shape[1]}D)" for attr in attributes]),
103114
]
@@ -190,10 +201,12 @@ def describe(self, show_attribute_distributions: bool = True) -> None:
190201
"Num. edges",
191202
"Isolated Source",
192203
"Isolated Target",
204+
"Min edges/target",
205+
"Max edges/target",
193206
"Attribute dim",
194207
"Attributes",
195208
],
196-
align=["<", "<", ">", ">", ">", ">", ">"],
209+
align=["<", "<", ">", ">", ">", ">", ">", ">", ">"],
197210
margin=3,
198211
)
199212
)
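
The new "Min edges/target" and "Max edges/target" columns count incoming edges per target node and ignore targets with no edges at all. The same logic is sketched here in plain Python; the helper name `edges_per_target_range` is hypothetical, and the change itself uses `torch.bincount` on `edge_index[1]`.

```python
from collections import Counter


def edges_per_target_range(target_index):
    """Return (min, max) incoming-edge counts over connected target nodes.

    Hypothetical helper: `target_index` holds the target node of every edge.
    Isolated targets never appear in it, so they are excluded automatically,
    matching the `target_edge_counts > 0` filter in the diff.
    """
    counts = Counter(target_index)
    if not counts:
        return 0, 0  # no edges at all
    return min(counts.values()), max(counts.values())


# Five target nodes (0..4); nodes 1 and 4 receive no edges and are ignored.
targets = [0, 0, 2, 2, 2, 3]
print(edges_per_target_range(targets))  # → (1, 3)
```

Isolated targets are already reported in their own column, so excluding them here avoids a minimum that would always be zero whenever any target is isolated.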

graphs/src/anemoi/graphs/edges/attributes.py

Lines changed: 188 additions & 0 deletions
@@ -18,6 +18,7 @@
1818
from torch_geometric.typing import Adj
1919
from torch_geometric.typing import PairTensor
2020
from torch_geometric.typing import Size
21+
from torch_geometric.utils import scatter
2122

2223
from anemoi.graphs.edges.directional import compute_directions
2324
from anemoi.graphs.normalise import NormaliserMixin
@@ -99,6 +100,46 @@ def compute(self, x_i: torch.Tensor, x_j: torch.Tensor) -> torch.Tensor:
99100
return edge_dirs
100101

101102

103+
class DirectionalHarmonics(EdgeDirection):
104+
"""Computes directional harmonics from edge directions.
105+
106+
Builds directional harmonics [sin(mψ), cos(mψ)]_{m=1..order} from per-edge
107+
2D directions (dx, dy). Returns shape [N, 2*order].
108+
109+
Attributes
110+
----------
111+
order : int
112+
The maximum order of harmonics to compute.
113+
norm : str | None
114+
Normalisation method. Options: None, "l1", "l2", "unit-max", "unit-range", "unit-std".
115+
116+
Methods
117+
-------
118+
compute(x_i, x_j)
119+
Compute directional harmonics from edge directions.
120+
"""
121+
122+
def __init__(self, order: int = 3, norm: str | None = None, dtype: str = "float32") -> None:
123+
self.order = order
124+
super().__init__(norm=norm, dtype=dtype)
125+
126+
def compute(self, x_i: torch.Tensor, x_j: torch.Tensor) -> torch.Tensor:
127+
# Get the 2D direction vectors [dx, dy]
128+
edge_dirs = compute_directions(x_i, x_j)
129+
130+
# Compute the angle ψ from the direction vectors
131+
psi = torch.atan2(edge_dirs[:, 1], edge_dirs[:, 0]) # atan2(dy, dx)
132+
133+
# Build harmonics: [sin(ψ), cos(ψ), sin(2ψ), cos(2ψ), ..., sin(order*ψ), cos(order*ψ)]
134+
harmonics = []
135+
for m in range(1, self.order + 1):
136+
harmonics.append(torch.sin(m * psi))
137+
harmonics.append(torch.cos(m * psi))
138+
139+
# Stack into shape [N, 2*order]
140+
return torch.stack(harmonics, dim=1)
141+
142+
102143
class Azimuth(BasePositionalBuilder):
103144
"""Compute the azimuth of the edge.
104145
@@ -172,6 +213,153 @@ class AttributeFromTargetNode(BaseEdgeAttributeFromNodeBuilder):
172213
nodes_axis = NodesAxis.TARGET
173214

174215

216+
class RadialBasisFeatures(EdgeLength):
217+
"""Radial basis features from edge distances using Gaussian RBFs.
218+
219+
Computes Gaussian radial basis function features from normalized great-circle distances:
220+
phi_r = [exp(-((α - c)/σ)²) for c in centers], where α = r_ij / r_scale.
221+
222+
Provides RBF features via per-node adaptive scaling.
223+
By default, each destination node's edges are normalized by that node's maximum edge length.
224+
RBF features are normalized per target node per RBF center: within each RBF center,
225+
all edges pointing to the same target node have values that sum to 1 (L1 norm).
226+
227+
Parameters
228+
----------
229+
r_scale : float | None, optional
230+
Global scale factor for normalizing distances. Default is None.
231+
If None: Use per-node adaptive scaling (max edge length per destination node).
232+
If float: Use global scale for all nodes.
233+
centers : list of float, optional
234+
RBF center positions along normalized distance axis [0, 1].
235+
Default is [0.0, 0.25, 0.5, 0.75, 1.0].
236+
sigma : float, optional
237+
Width (standard deviation) of Gaussian RBF functions. Default is 0.2.
238+
Controls how localized each basis function is around its center.
239+
epsilon : float, optional
240+
Small constant to avoid division by zero. Default is 1e-10.
241+
dtype : str, optional
242+
Data type for computations. Default is "float32".
243+
244+
Note
245+
----
246+
RBF features are normalized per target node per RBF center.
247+
Within each RBF center, all edges to the same target node sum to 1.
248+
249+
Methods
250+
-------
251+
compute(x_i, x_j)
252+
Compute raw edge distances (RBF computation happens in aggregate).
253+
aggregate(edge_features, index, ptr, dim_size)
254+
Compute RBF features with adaptive scaling and per-target-node normalization.
255+
256+
Examples
257+
--------
258+
# Default: per-node adaptive scaling with grouped normalization
259+
rbf = RadialBasisFeatures()
260+
261+
# To use global scale
262+
rbf_global = RadialBasisFeatures(r_scale=1.0)
263+
264+
# Custom RBF centers and width
265+
rbf_custom = RadialBasisFeatures(centers=[0.0, 0.33, 0.67, 1.0], sigma=0.15)
266+
267+
Notes
268+
-----
269+
- Closer edges → higher values at low-distance centers (0.0, 0.25)
270+
- Farther edges → higher values at high-distance centers (0.75, 1.0)
271+
"""
272+
273+
norm_by_group: bool = True # normalise the RBF features per destination node
274+
275+
def __init__(
276+
self,
277+
r_scale: float | None = None,
278+
centers: list[float] | None = None,
279+
sigma: float = 0.2,
280+
norm: str = "l1",
281+
epsilon: float = 1e-10,
282+
dtype: str = "float32",
283+
) -> None:
284+
self.epsilon = epsilon
285+
self.r_scale = r_scale
286+
287+
if self.r_scale is not None and self.r_scale < self.epsilon:
288+
LOGGER.warning(
289+
"r_scale (%f) is too small (< epsilon=%f). Clamping to epsilon to avoid division by zero.",
290+
self.r_scale,
291+
self.epsilon,
292+
)
293+
self.r_scale = self.epsilon
294+
295+
self.centers = centers if centers is not None else [0.0, 0.25, 0.5, 0.75, 1.0]
296+
297+
# Normalize centers if using global scaling
298+
if self.r_scale is not None:
299+
self.centers = [c / self.r_scale for c in self.centers]
300+
301+
# Check that centers are in the range [0, 1]
302+
assert all(
303+
0.0 <= c <= 1.0 for c in self.centers
304+
), f"RBF centers must be in range [0, 1] (or [0, r_scale] if r_scale is set). Got centers: {centers}, r_scale: {r_scale}"
305+
306+
self.sigma = sigma
307+
super().__init__(norm=norm, dtype=dtype)
308+
309+
def aggregate(self, edge_features: torch.Tensor, index: torch.Tensor, ptr=None, dim_size=None) -> torch.Tensor:
310+
"""Aggregate edge features with per-node scaling and per-target-node normalization.
311+
312+
Parameters
313+
----------
314+
edge_features : torch.Tensor
315+
Raw edge distances, shape [num_edges] or [num_edges, 1]
316+
index : torch.Tensor
317+
Destination node index for each edge
318+
ptr : optional
319+
CSR pointer (not used)
320+
dim_size : int, optional
321+
Number of destination nodes
322+
323+
Returns
324+
-------
325+
torch.Tensor
326+
RBF features, shape [num_edges, num_centers].
327+
Normalized per target node per RBF center.
328+
"""
329+
# Ensure edge_features is 1D
330+
if edge_features.ndim == 2:
331+
edge_features = edge_features.squeeze(-1)
332+
333+
# Compute scale factor per destination node
334+
if self.r_scale is None:
335+
# Per-node max edge length scaling
336+
max_dists = scatter(edge_features, index.long(), dim=0, dim_size=dim_size, reduce="max")
337+
338+
# Clamp to epsilon to avoid division by zero
339+
max_dists = torch.clamp(max_dists, min=self.epsilon)
340+
341+
# Broadcast to each edge
342+
scales = max_dists[index]
343+
alpha = edge_features / scales # Normalized distance [0, 1]
344+
else:
345+
# Global scaling
346+
scales = torch.full_like(edge_features, self.r_scale)
347+
alpha = edge_features / scales # Scaled distance [0, max_edge/r_scale]
348+
349+
# Compute Gaussian RBF for each center
350+
rbf_features = []
351+
for center in self.centers:
352+
rbf = torch.exp(-(((alpha - center) / self.sigma) ** 2))
353+
rbf_features.append(rbf)
354+
355+
rbf_features = torch.stack(rbf_features, dim=1)
356+
357+
# Within each RBF center, normalise edges to the same target node
358+
rbf_features = self.normalise(rbf_features, index, dim_size)
359+
360+
return rbf_features
361+
362+
175363
class GaussianDistanceWeights(EdgeLength):
176364
"""Gaussian distance weights."""
177365
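
The default behaviour of `RadialBasisFeatures` (per-node adaptive scaling followed by per-target, per-center L1 normalisation) can be sketched in plain Python. This is an illustration under stated assumptions rather than the library code: the function name `rbf_features` is hypothetical, lists replace tensors, and the grouped normalisation is written out by hand because the body of `self.normalise` is not shown in this diff.

```python
import math
from collections import defaultdict


def rbf_features(distances, targets, centers=(0.0, 0.25, 0.5, 0.75, 1.0), sigma=0.2, epsilon=1e-10):
    """Gaussian RBF features per edge, normalised per target node and per center."""
    # 1. Per-target maximum edge length (the adaptive scale).
    max_dist = defaultdict(float)
    for d, t in zip(distances, targets):
        max_dist[t] = max(max_dist[t], d)

    # 2. Raw Gaussian features on the normalised distance alpha in [0, 1].
    raw = []
    for d, t in zip(distances, targets):
        alpha = d / max(max_dist[t], epsilon)  # clamp to avoid division by zero
        raw.append([math.exp(-(((alpha - c) / sigma) ** 2)) for c in centers])

    # 3. L1 normalisation: for each target and each center, features sum to 1.
    sums = defaultdict(lambda: [0.0] * len(centers))
    for row, t in zip(raw, targets):
        for k, value in enumerate(row):
            sums[t][k] += value
    return [
        [value / sums[t][k] if sums[t][k] > 0.0 else 0.0 for k, value in enumerate(row)]
        for row, t in zip(raw, targets)
    ]


# Three edges into target 0 and one edge into target 1.
features = rbf_features([1.0, 2.0, 4.0, 3.0], [0, 0, 0, 1])
for row in features:
    print([round(v, 4) for v in row])
```

In this sketch the shortest edge into target 0 takes the largest share at the center 0.0 and the longest edge takes the largest share at the center 1.0, which matches the behaviour listed in the class docstring's notes.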
