Skip to content
Merged
43 changes: 43 additions & 0 deletions graphs/docs/graphs/edge_attributes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,49 @@ latitude and longitude coordinates of the source and target nodes.
edge_length:
_target_: anemoi.graphs.edges.attributes.EdgeDirection

***********************
Directional Harmonics
***********************

The `Directional Harmonics` attribute computes harmonic features from
edge directions, providing a periodic encoding of the angle between
source and target nodes. For each order :math:`m` from 1 to the
specified maximum, it computes :math:`\sin(m\psi)` and
:math:`\cos(m\psi)` where :math:`\psi` is the edge direction angle.

.. code:: yaml

edges:
- source_name: ...
target_name: ...
edge_builders: ...
attributes:
dir_harmonics:
_target_: anemoi.graphs.edges.attributes.DirectionalHarmonics
order: 3

***********************
Radial Basis Features
***********************

The `Radial Basis Features` attribute computes Gaussian radial basis
function (RBF) features from edge distances. It evaluates a set of
Gaussian basis functions centered at different scaled distances. By
default, per-node adaptive scaling is used.

.. code:: yaml

edges:
- source_name: ...
target_name: ...
edge_builders: ...
attributes:
rbf_features:
_target_: anemoi.graphs.edges.attributes.RadialBasisFeatures
r_scale: auto
centers: [0.0, 0.25, 0.5, 0.75, 1.0]
sigma: 0.2

******************
Gaussian Weights
******************
Expand Down
20 changes: 18 additions & 2 deletions graphs/docs/graphs/edges/cutoff.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,11 @@ neighbourhood of the target nodes, :math:`V_{target}`.
:alt: Cut-off radius image
:align: center

The neighbourhood is defined by a `cut-off radius`, which is computed
as,
The neighbourhood is defined by a `cut-off radius`, which can be
specified in two mutually exclusive ways:

#. **Using cutoff_factor**: The radius is computed as the product of a
factor and the reference distance:

.. math::

Expand All @@ -33,19 +36,32 @@ where :math:`d(x, y)` is the `Haversine distance
that can be adjusted to increase or decrease the size of the
neighbourhood, and consequently the number of connections in the graph.

2. **Using cutoff_distance_km**: The radius can be directly specified in
kilometers, which is converted to the appropriate unit sphere
distance.

To use this method to create your connections, you can use the following
YAML configuration:

.. code:: yaml

edges:
# Using cutoff_factor (relative to grid reference distance)
- source_name: source
target_name: destination
edge_builders:
- _target_: anemoi.graphs.edges.CutOffEdges
cutoff_factor: 0.6
# max_num_neighbours: 64

# Or using cutoff_distance_km (direct distance specification)
- source_name: source
target_name: destination
edge_builders:
- _target_: anemoi.graphs.edges.CutOffEdges
cutoff_distance_km: 300.0
# max_num_neighbours: 64

The optional argument ``max_num_neighbours`` (default: 64) can be set to
limit the maximum number of neighbours each node can connect to.

Expand Down
15 changes: 14 additions & 1 deletion graphs/src/anemoi/graphs/describe.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,8 @@ def get_edge_summary(self) -> list[list]:
- Number of edges.
- Number of isolated source nodes.
- Number of isolated target nodes.
- Min edges per target node.
- Max edges per target node.
- Total dimension of the attributes.
- List of attribute names.
"""
Expand All @@ -91,13 +93,22 @@ def get_edge_summary(self) -> list[list]:
attributes = edges.edge_attrs()
attributes.remove("edge_index")

# Compute edge counts per target node
target_edge_counts = torch.bincount(edges.edge_index[1], minlength=self.graph[dst_name].num_nodes)
# Only consider nodes that have at least one edge
connected_target_counts = target_edge_counts[target_edge_counts > 0]
min_edges_per_target = connected_target_counts.min().item() if len(connected_target_counts) > 0 else 0
max_edges_per_target = connected_target_counts.max().item() if len(connected_target_counts) > 0 else 0

edge_summary.append(
[
src_name,
dst_name,
edges.num_edges,
self.graph[src_name].num_nodes - len(torch.unique(edges.edge_index[0])),
self.graph[dst_name].num_nodes - len(torch.unique(edges.edge_index[1])),
min_edges_per_target,
max_edges_per_target,
sum(edges[attr].shape[1] for attr in attributes),
", ".join([f"{attr}({edges[attr].shape[1]}D)" for attr in attributes]),
]
Expand Down Expand Up @@ -190,10 +201,12 @@ def describe(self, show_attribute_distributions: bool = True) -> None:
"Num. edges",
"Isolated Source",
"Isolated Target",
"Min edges/target",
"Max edges/target",
"Attribute dim",
"Attributes",
],
align=["<", "<", ">", ">", ">", ">", ">"],
align=["<", "<", ">", ">", ">", ">", ">", ">", ">"],
margin=3,
)
)
Expand Down
188 changes: 188 additions & 0 deletions graphs/src/anemoi/graphs/edges/attributes.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from torch_geometric.typing import Adj
from torch_geometric.typing import PairTensor
from torch_geometric.typing import Size
from torch_geometric.utils import scatter

from anemoi.graphs.edges.directional import compute_directions
from anemoi.graphs.normalise import NormaliserMixin
Expand Down Expand Up @@ -99,6 +100,46 @@ def compute(self, x_i: torch.Tensor, x_j: torch.Tensor) -> torch.Tensor:
return edge_dirs


class DirectionalHarmonics(EdgeDirection):
"""Computes directional harmonics from edge directions.

Builds directional harmonics [sin(mψ), cos(mψ)]_{m=1..order} from per-edge
2D directions (dx, dy). Returns shape [N, 2*order].

Attributes
----------
order : int
The maximum order of harmonics to compute.
norm : str | None
Normalisation method. Options: None, "l1", "l2", "unit-max", "unit-range", "unit-std".

Methods
-------
compute(x_i, x_j)
Compute directional harmonics from edge directions.
"""

def __init__(self, order: int = 3, norm: str | None = None, dtype: str = "float32") -> None:
self.order = order
super().__init__(norm=norm, dtype=dtype)

def compute(self, x_i: torch.Tensor, x_j: torch.Tensor) -> torch.Tensor:
# Get the 2D direction vectors [dx, dy]
edge_dirs = compute_directions(x_i, x_j)

# Compute the angle ψ from the direction vectors
psi = torch.atan2(edge_dirs[:, 1], edge_dirs[:, 0]) # atan2(dy, dx)

# Build harmonics: [sin(ψ), cos(ψ), sin(2ψ), cos(2ψ), ..., sin(order*ψ), cos(order*ψ)]
harmonics = []
for m in range(1, self.order + 1):
harmonics.append(torch.sin(m * psi))
harmonics.append(torch.cos(m * psi))

# Stack into shape [N, 2*order]
return torch.stack(harmonics, dim=1)


class Azimuth(BasePositionalBuilder):
"""Compute the azimuth of the edge.

Expand Down Expand Up @@ -172,6 +213,153 @@ class AttributeFromTargetNode(BaseEdgeAttributeFromNodeBuilder):
nodes_axis = NodesAxis.TARGET


class RadialBasisFeatures(EdgeLength):
"""Radial basis features from edge distances using Gaussian RBFs.

Computes Gaussian radial basis function features from normalized great-circle distances:
phi_r = [exp(-((α - c)/σ)²) for c in centers], where α = r_ij / r_scale.

Provides RBF features via per-node adaptive scaling.
By default, each destination node's edges are normalized by that node's maximum edge length.
RBF features are normalized per target node per RBF center: within each RBF center,
all edges pointing to the same target node have values that sum to 1 (L1 norm).

Parameters
----------
r_scale : float | None, optional
Global scale factor for normalizing distances. Default is None.
If None: Use per-node adaptive scaling (max edge length per destination node).
If float: Use global scale for all nodes.
centers : list of float, optional
RBF center positions along normalized distance axis [0, 1].
Default is [0.0, 0.25, 0.5, 0.75, 1.0].
sigma : float, optional
Width (standard deviation) of Gaussian RBF functions. Default is 0.2.
Controls how localized each basis function is around its center.
epsilon : float, optional
Small constant to avoid division by zero. Default is 1e-10.
dtype : str, optional
Data type for computations. Default is "float32".

Note
----
RBF features are normalized per target node per RBF center.
Within each RBF center, all edges to the same target node sum to 1.

Methods
-------
compute(x_i, x_j)
Compute raw edge distances (RBF computation happens in aggregate).
aggregate(edge_features, index, ptr, dim_size)
Compute RBF features with adaptive scaling and per-target-node normalization.

Examples
--------
# Default: per-node adaptive scaling with grouped normalization
rbf = RadialBasisFeatures()

# To use global scale
rbf_global = RadialBasisFeatures(r_scale=1.0)

# Custom RBF centers and width
rbf_custom = RadialBasisFeatures(centers=[0.0, 0.33, 0.67, 1.0], sigma=0.15)

Notes
-----
- Closer edges → higher values at low-distance centers (0.0, 0.25)
- Farther edges → higher values at high-distance centers (0.75, 1.0)
"""

norm_by_group: bool = True # normalise the RBF features per destination node

def __init__(
self,
r_scale: float | None = None,
centers: list[float] | None = None,
sigma: float = 0.2,
norm: str = "l1",
epsilon: float = 1e-10,
dtype: str = "float32",
) -> None:
self.epsilon = epsilon
self.r_scale = r_scale

if self.r_scale is not None and self.r_scale < self.epsilon:
LOGGER.warning(
"r_scale (%f) is too small (< epsilon=%f). Clamping to epsilon to avoid division by zero.",
self.r_scale,
self.epsilon,
)
self.r_scale = self.epsilon

self.centers = centers if centers is not None else [0.0, 0.25, 0.5, 0.75, 1.0]

# Normalize centers if using global scaling
if self.r_scale is not None:
self.centers = [c / self.r_scale for c in self.centers]

# Check that centers are in the range [0, 1]
assert all(
0.0 <= c <= 1.0 for c in self.centers
), f"RBF centers must be in range [0, 1] (or [0, r_scale] if r_scale is set). Got centers: {centers}, r_scale: {r_scale}"

self.sigma = sigma
super().__init__(norm=norm, dtype=dtype)

def aggregate(self, edge_features: torch.Tensor, index: torch.Tensor, ptr=None, dim_size=None) -> torch.Tensor:
"""Aggregate edge features with per-node scaling and per-target-node normalization.

Parameters
----------
edge_features : torch.Tensor
Raw edge distances, shape [num_edges] or [num_edges, 1]
index : torch.Tensor
Destination node index for each edge
ptr : optional
CSR pointer (not used)
dim_size : int, optional
Number of destination nodes

Returns
-------
torch.Tensor
RBF features, shape [num_edges, num_centers].
Normalized per target node per RBF center .
"""
# Ensure edge_features is 1D
if edge_features.ndim == 2:
edge_features = edge_features.squeeze(-1)

# Compute scale factor per destination node
if self.r_scale is None:
# Per-node max edge length scaling
max_dists = scatter(edge_features, index.long(), dim=0, dim_size=dim_size, reduce="max")

# Clamp to epsilon to avoid division by zero
max_dists = torch.clamp(max_dists, min=self.epsilon)

# Broadcast to each edge
scales = max_dists[index]
alpha = edge_features / scales # Normalized distance [0, 1]
else:
# Global scaling
scales = torch.full_like(edge_features, self.r_scale)
alpha = edge_features / scales # Scaled distance [0, max_edge/r_scale]

# Compute Gaussian RBF for each center
rbf_features = []
for center in self.centers:
rbf = torch.exp(-(((alpha - center) / self.sigma) ** 2))
rbf_features.append(rbf)

rbf_features = torch.stack(rbf_features, dim=1)

# Within each RBF center, normalise edges to the same target node
rbf_features = self.normalise(rbf_features, index, dim_size)

return rbf_features


class GaussianDistanceWeights(EdgeLength):
"""Gaussian distance weights."""

Expand Down
Loading