|
18 | 18 | from torch_geometric.typing import Adj |
19 | 19 | from torch_geometric.typing import PairTensor |
20 | 20 | from torch_geometric.typing import Size |
| 21 | +from torch_geometric.utils import scatter |
21 | 22 |
|
22 | 23 | from anemoi.graphs.edges.directional import compute_directions |
23 | 24 | from anemoi.graphs.normalise import NormaliserMixin |
@@ -99,6 +100,46 @@ def compute(self, x_i: torch.Tensor, x_j: torch.Tensor) -> torch.Tensor: |
99 | 100 | return edge_dirs |
100 | 101 |
|
101 | 102 |
|
| 103 | +class DirectionalHarmonics(EdgeDirection): |
| 104 | + """Computes directional harmonics from edge directions. |
| 105 | +
|
| 106 | + Builds directional harmonics [sin(mψ), cos(mψ)]_{m=1..order} from per-edge |
| 107 | + 2D directions (dx, dy). Returns shape [N, 2*order]. |
| 108 | +
|
| 109 | + Attributes |
| 110 | + ---------- |
| 111 | + order : int |
| 112 | + The maximum order of harmonics to compute. |
| 113 | + norm : str | None |
| 114 | + Normalisation method. Options: None, "l1", "l2", "unit-max", "unit-range", "unit-std". |
| 115 | +
|
| 116 | + Methods |
| 117 | + ------- |
| 118 | + compute(x_i, x_j) |
| 119 | + Compute directional harmonics from edge directions. |
| 120 | + """ |
| 121 | + |
| 122 | + def __init__(self, order: int = 3, norm: str | None = None, dtype: str = "float32") -> None: |
| 123 | + self.order = order |
| 124 | + super().__init__(norm=norm, dtype=dtype) |
| 125 | + |
| 126 | + def compute(self, x_i: torch.Tensor, x_j: torch.Tensor) -> torch.Tensor: |
| 127 | + # Get the 2D direction vectors [dx, dy] |
| 128 | + edge_dirs = compute_directions(x_i, x_j) |
| 129 | + |
| 130 | + # Compute the angle ψ from the direction vectors |
| 131 | + psi = torch.atan2(edge_dirs[:, 1], edge_dirs[:, 0]) # atan2(dy, dx) |
| 132 | + |
| 133 | + # Build harmonics: [sin(ψ), cos(ψ), sin(2ψ), cos(2ψ), ..., sin(order*ψ), cos(order*ψ)] |
| 134 | + harmonics = [] |
| 135 | + for m in range(1, self.order + 1): |
| 136 | + harmonics.append(torch.sin(m * psi)) |
| 137 | + harmonics.append(torch.cos(m * psi)) |
| 138 | + |
| 139 | + # Stack into shape [N, 2*order] |
| 140 | + return torch.stack(harmonics, dim=1) |
| 141 | + |
| 142 | + |
102 | 143 | class Azimuth(BasePositionalBuilder): |
103 | 144 | """Compute the azimuth of the edge. |
104 | 145 |
|
@@ -172,6 +213,153 @@ class AttributeFromTargetNode(BaseEdgeAttributeFromNodeBuilder): |
172 | 213 | nodes_axis = NodesAxis.TARGET |
173 | 214 |
|
174 | 215 |
|
| 216 | +class RadialBasisFeatures(EdgeLength): |
| 217 | + """Radial basis features from edge distances using Gaussian RBFs. |
| 218 | +
|
| 219 | + Computes Gaussian radial basis function features from normalized great-circle distances: |
| 220 | + phi_r = [exp(-((α - c)/σ)²) for c in centers], where α = r_ij / r_scale. |
| 221 | +
|
| 222 | + Provides RBF features via per-node adaptive scaling. |
| 223 | + By default, each destination node's edges are normalized by that node's maximum edge length. |
| 224 | + RBF features are normalized per target node per RBF center: within each RBF center, |
| 225 | + all edges pointing to the same target node have values that sum to 1 (L1 norm). |
| 226 | +
|
| 227 | + Parameters |
| 228 | + ---------- |
| 229 | + r_scale : float | None, optional |
| 230 | + Global scale factor for normalizing distances. Default is None. |
| 231 | + If None: Use per-node adaptive scaling (max edge length per destination node). |
| 232 | + If float: Use global scale for all nodes. |
| 233 | + centers : list of float, optional |
| 234 | + RBF center positions along normalized distance axis [0, 1]. |
| 235 | + Default is [0.0, 0.25, 0.5, 0.75, 1.0]. |
| 236 | + sigma : float, optional |
| 237 | + Width (standard deviation) of Gaussian RBF functions. Default is 0.2. |
| 238 | + Controls how localized each basis function is around its center. |
| 239 | + epsilon : float, optional |
| 240 | + Small constant to avoid division by zero. Default is 1e-10. |
| 241 | + dtype : str, optional |
| 242 | + Data type for computations. Default is "float32". |
| 243 | +
|
| 244 | + Note |
| 245 | + ---- |
| 246 | + RBF features are normalized per target node per RBF center. |
| 247 | + Within each RBF center, all edges to the same target node sum to 1. |
| 248 | +
|
| 249 | + Methods |
| 250 | + ------- |
| 251 | + compute(x_i, x_j) |
| 252 | + Compute raw edge distances (RBF computation happens in aggregate). |
| 253 | + aggregate(edge_features, index, ptr, dim_size) |
| 254 | + Compute RBF features with adaptive scaling and per-target-node normalization. |
| 255 | +
|
| 256 | + Examples |
| 257 | + -------- |
| 258 | + # Default: per-node adaptive scaling with grouped normalization |
| 259 | + rbf = RadialBasisFeatures() |
| 260 | +
|
| 261 | + # To use global scale |
| 262 | + rbf_global = RadialBasisFeatures(r_scale=1.0) |
| 263 | +
|
| 264 | + # Custom RBF centers and width |
| 265 | + rbf_custom = RadialBasisFeatures(centers=[0.0, 0.33, 0.67, 1.0], sigma=0.15) |
| 266 | +
|
| 267 | + Notes |
| 268 | + ----- |
| 269 | + - Closer edges → higher values at low-distance centers (0.0, 0.25) |
| 270 | + - Farther edges → higher values at high-distance centers (0.75, 1.0) |
| 271 | + """ |
| 272 | + |
| 273 | + norm_by_group: bool = True # normalise the RBF features per destination node |
| 274 | + |
| 275 | + def __init__( |
| 276 | + self, |
| 277 | + r_scale: float | None = None, |
| 278 | + centers: list[float] | None = None, |
| 279 | + sigma: float = 0.2, |
| 280 | + norm: str = "l1", |
| 281 | + epsilon: float = 1e-10, |
| 282 | + dtype: str = "float32", |
| 283 | + ) -> None: |
| 284 | + self.epsilon = epsilon |
| 285 | + self.r_scale = r_scale |
| 286 | + |
| 287 | + if self.r_scale is not None and self.r_scale < self.epsilon: |
| 288 | + LOGGER.warning( |
| 289 | + "r_scale (%f) is too small (< epsilon=%f). Clamping to epsilon to avoid division by zero.", |
| 290 | + self.r_scale, |
| 291 | + self.epsilon, |
| 292 | + ) |
| 293 | + self.r_scale = self.epsilon |
| 294 | + |
| 295 | + self.centers = centers if centers is not None else [0.0, 0.25, 0.5, 0.75, 1.0] |
| 296 | + |
| 297 | + # Normalize centers if using global scaling |
| 298 | + if self.r_scale is not None: |
| 299 | + self.centers = [c / self.r_scale for c in self.centers] |
| 300 | + |
| 301 | + # Check that centers are in the range [0, 1] |
| 302 | + assert all( |
| 303 | + 0.0 <= c <= 1.0 for c in self.centers |
| 304 | + ), f"RBF centers must be in range [0, 1] (or [0, r_scale] if r_scale is set). Got centers: {centers}, r_scale: {r_scale}" |
| 305 | + |
| 306 | + self.sigma = sigma |
| 307 | + super().__init__(norm=norm, dtype=dtype) |
| 308 | + |
| 309 | + def aggregate(self, edge_features: torch.Tensor, index: torch.Tensor, ptr=None, dim_size=None) -> torch.Tensor: |
| 310 | + """Aggregate edge features with per-node scaling and per-target-node normalization. |
| 311 | +
|
| 312 | + Parameters |
| 313 | + ---------- |
| 314 | + edge_features : torch.Tensor |
| 315 | + Raw edge distances, shape [num_edges] or [num_edges, 1] |
| 316 | + index : torch.Tensor |
| 317 | + Destination node index for each edge |
| 318 | + ptr : optional |
| 319 | + CSR pointer (not used) |
| 320 | + dim_size : int, optional |
| 321 | + Number of destination nodes |
| 322 | +
|
| 323 | + Returns |
| 324 | + ------- |
| 325 | + torch.Tensor |
| 326 | + RBF features, shape [num_edges, num_centers]. |
| 327 | + Normalized per target node per RBF center . |
| 328 | + """ |
| 329 | + # Ensure edge_features is 1D |
| 330 | + if edge_features.ndim == 2: |
| 331 | + edge_features = edge_features.squeeze(-1) |
| 332 | + |
| 333 | + # Compute scale factor per destination node |
| 334 | + if self.r_scale is None: |
| 335 | + # Per-node max edge length scaling |
| 336 | + max_dists = scatter(edge_features, index.long(), dim=0, dim_size=dim_size, reduce="max") |
| 337 | + |
| 338 | + # Clamp to epsilon to avoid division by zero |
| 339 | + max_dists = torch.clamp(max_dists, min=self.epsilon) |
| 340 | + |
| 341 | + # Broadcast to each edge |
| 342 | + scales = max_dists[index] |
| 343 | + alpha = edge_features / scales # Normalized distance [0, 1] |
| 344 | + else: |
| 345 | + # Global scaling |
| 346 | + scales = torch.full_like(edge_features, self.r_scale) |
| 347 | + alpha = edge_features / scales # Scaled distance [0, max_edge/r_scale] |
| 348 | + |
| 349 | + # Compute Gaussian RBF for each center |
| 350 | + rbf_features = [] |
| 351 | + for center in self.centers: |
| 352 | + rbf = torch.exp(-(((alpha - center) / self.sigma) ** 2)) |
| 353 | + rbf_features.append(rbf) |
| 354 | + |
| 355 | + rbf_features = torch.stack(rbf_features, dim=1) |
| 356 | + |
| 357 | + # Within each RBF center, normalise edges to the same target node |
| 358 | + rbf_features = self.normalise(rbf_features, index, dim_size) |
| 359 | + |
| 360 | + return rbf_features |
| 361 | + |
| 362 | + |
175 | 363 | class GaussianDistanceWeights(EdgeLength): |
176 | 364 | """Gaussian distance weights.""" |
177 | 365 |
|
|
0 commit comments