diff --git a/graphs/docs/graphs/edge_attributes.rst b/graphs/docs/graphs/edge_attributes.rst index 2ec435d8e..2d731f97b 100644 --- a/graphs/docs/graphs/edge_attributes.rst +++ b/graphs/docs/graphs/edge_attributes.rst @@ -45,6 +45,49 @@ latitude and longitude coordinates of the source and target nodes. edge_length: _target_: anemoi.graphs.edges.attributes.EdgeDirection +*********************** + Directional Harmonics +*********************** + +The `Directional Harmonics` attribute computes harmonic features from +edge directions, providing a periodic encoding of the angle between +source and target nodes. For each order :math:`m` from 1 to the +specified maximum, it computes :math:`\sin(m\psi)` and +:math:`\cos(m\psi)` where :math:`\psi` is the edge direction angle. + +.. code:: yaml + + edges: + - source_name: ... + target_name: ... + edge_builders: ... + attributes: + dir_harmonics: + _target_: anemoi.graphs.edges.attributes.DirectionalHarmonics + order: 3 + +*********************** + Radial Basis Features +*********************** + +The `Radial Basis Features` attribute computes Gaussian radial basis +function (RBF) features from edge distances. It evaluates a set of +Gaussian basis functions centered at different scaled distances. By +default, per-node adaptive scaling is used. + +.. code:: yaml + + edges: + - source_name: ... + target_name: ... + edge_builders: ... + attributes: + rbf_features: + _target_: anemoi.graphs.edges.attributes.RadialBasisFeatures + r_scale: null # per-node adaptive scaling (default); use a float for a global scale + centers: [0.0, 0.25, 0.5, 0.75, 1.0] + sigma: 0.2 + ****************** Gaussian Weights ****************** diff --git a/graphs/docs/graphs/edges/cutoff.rst b/graphs/docs/graphs/edges/cutoff.rst index 19ff56c25..6e1bb132b 100644 --- a/graphs/docs/graphs/edges/cutoff.rst +++ b/graphs/docs/graphs/edges/cutoff.rst @@ -13,8 +13,11 @@ neighbourhood of the target nodes, :math:`V_{target}`. 
:alt: Cut-off radius image :align: center -The neighbourhood is defined by a `cut-off radius`, which is computed -as, +The neighbourhood is defined by a `cut-off radius`, which can be +specified in two mutually exclusive ways: + +#. **Using cutoff_factor**: The radius is computed as the product of a + factor and the reference distance: .. math:: @@ -33,12 +36,17 @@ where :math:`d(x, y)` is the `Haversine distance that can be adjusted to increase or decrease the size of the neighbourhood, and consequently the number of connections in the graph. +2. **Using cutoff_distance_km**: The radius can be directly specified in + kilometers, which is converted to the appropriate unit sphere + distance. + To use this method to create your connections, you can use the following YAML configuration: .. code:: yaml edges: + # Using cutoff_factor (relative to grid reference distance) - source_name: source target_name: destination edge_builders: @@ -46,6 +54,14 @@ YAML configuration: cutoff_factor: 0.6 # max_num_neighbours: 64 + # Or using cutoff_distance_km (direct distance specification) + - source_name: source + target_name: destination + edge_builders: + - _target_: anemoi.graphs.edges.CutOffEdges + cutoff_distance_km: 300.0 + # max_num_neighbours: 64 + The optional argument ``max_num_neighbours`` (default: 64) can be set to limit the maximum number of neighbours each node can connect to. diff --git a/graphs/src/anemoi/graphs/describe.py b/graphs/src/anemoi/graphs/describe.py index f2db18a18..75db182d8 100644 --- a/graphs/src/anemoi/graphs/describe.py +++ b/graphs/src/anemoi/graphs/describe.py @@ -83,6 +83,8 @@ def get_edge_summary(self) -> list[list]: - Number of edges. - Number of isolated source nodes. - Number of isolated target nodes. + - Min edges per target node. + - Max edges per target node. - Total dimension of the attributes. - List of attribute names. 
""" @@ -91,6 +93,13 @@ def get_edge_summary(self) -> list[list]: attributes = edges.edge_attrs() attributes.remove("edge_index") + # Compute edge counts per target node + target_edge_counts = torch.bincount(edges.edge_index[1], minlength=self.graph[dst_name].num_nodes) + # Only consider nodes that have at least one edge + connected_target_counts = target_edge_counts[target_edge_counts > 0] + min_edges_per_target = connected_target_counts.min().item() if len(connected_target_counts) > 0 else 0 + max_edges_per_target = connected_target_counts.max().item() if len(connected_target_counts) > 0 else 0 + edge_summary.append( [ src_name, @@ -98,6 +107,8 @@ def get_edge_summary(self) -> list[list]: edges.num_edges, self.graph[src_name].num_nodes - len(torch.unique(edges.edge_index[0])), self.graph[dst_name].num_nodes - len(torch.unique(edges.edge_index[1])), + min_edges_per_target, + max_edges_per_target, sum(edges[attr].shape[1] for attr in attributes), ", ".join([f"{attr}({edges[attr].shape[1]}D)" for attr in attributes]), ] @@ -190,10 +201,12 @@ def describe(self, show_attribute_distributions: bool = True) -> None: "Num. 
edges", "Isolated Source", "Isolated Target", + "Min edges/target", + "Max edges/target", "Attribute dim", "Attributes", ], - align=["<", "<", ">", ">", ">", ">", ">"], + align=["<", "<", ">", ">", ">", ">", ">", ">", ">"], margin=3, ) ) diff --git a/graphs/src/anemoi/graphs/edges/attributes.py b/graphs/src/anemoi/graphs/edges/attributes.py index cf5c5035e..ad6e64179 100644 --- a/graphs/src/anemoi/graphs/edges/attributes.py +++ b/graphs/src/anemoi/graphs/edges/attributes.py @@ -18,6 +18,7 @@ from torch_geometric.typing import Adj from torch_geometric.typing import PairTensor from torch_geometric.typing import Size +from torch_geometric.utils import scatter from anemoi.graphs.edges.directional import compute_directions from anemoi.graphs.normalise import NormaliserMixin @@ -99,6 +100,46 @@ def compute(self, x_i: torch.Tensor, x_j: torch.Tensor) -> torch.Tensor: return edge_dirs +class DirectionalHarmonics(EdgeDirection): + """Computes directional harmonics from edge directions. + + Builds directional harmonics [sin(mψ), cos(mψ)]_{m=1..order} from per-edge + 2D directions (dx, dy). Returns shape [N, 2*order]. + + Attributes + ---------- + order : int + The maximum order of harmonics to compute. + norm : str | None + Normalisation method. Options: None, "l1", "l2", "unit-max", "unit-range", "unit-std". + + Methods + ------- + compute(x_i, x_j) + Compute directional harmonics from edge directions. 
+ """ + + def __init__(self, order: int = 3, norm: str | None = None, dtype: str = "float32") -> None: + self.order = order + super().__init__(norm=norm, dtype=dtype) + + def compute(self, x_i: torch.Tensor, x_j: torch.Tensor) -> torch.Tensor: + # Get the 2D direction vectors [dx, dy] + edge_dirs = compute_directions(x_i, x_j) + + # Compute the angle ψ from the direction vectors + psi = torch.atan2(edge_dirs[:, 1], edge_dirs[:, 0]) # atan2(dy, dx) + + # Build harmonics: [sin(ψ), cos(ψ), sin(2ψ), cos(2ψ), ..., sin(order*ψ), cos(order*ψ)] + harmonics = [] + for m in range(1, self.order + 1): + harmonics.append(torch.sin(m * psi)) + harmonics.append(torch.cos(m * psi)) + + # Stack into shape [N, 2*order] + return torch.stack(harmonics, dim=1) + + class Azimuth(BasePositionalBuilder): """Compute the azimuth of the edge. @@ -172,6 +213,153 @@ class AttributeFromTargetNode(BaseEdgeAttributeFromNodeBuilder): nodes_axis = NodesAxis.TARGET +class RadialBasisFeatures(EdgeLength): + """Radial basis features from edge distances using Gaussian RBFs. + + Computes Gaussian radial basis function features from normalized great-circle distances: + phi_r = [exp(-((α - c)/σ)²) for c in centers], where α = r_ij / r_scale. + + Provides RBF features via per-node adaptive scaling. + By default, each destination node's edges are normalized by that node's maximum edge length. + RBF features are normalized per target node per RBF center: within each RBF center, + all edges pointing to the same target node have values that sum to 1 (L1 norm). + + Parameters + ---------- + r_scale : float | None, optional + Global scale factor for normalizing distances. Default is None. + If None: Use per-node adaptive scaling (max edge length per destination node). + If float: Use global scale for all nodes. + centers : list of float, optional + RBF center positions along normalized distance axis [0, 1]. + Default is [0.0, 0.25, 0.5, 0.75, 1.0]. 
+ sigma : float, optional + Width (standard deviation) of Gaussian RBF functions. Default is 0.2. + Controls how localized each basis function is around its center. + epsilon : float, optional + Small constant to avoid division by zero. Default is 1e-10. + dtype : str, optional + Data type for computations. Default is "float32". + + Note + ---- + RBF features are normalized per target node per RBF center. + Within each RBF center, all edges to the same target node sum to 1. + + Methods + ------- + compute(x_i, x_j) + Compute raw edge distances (RBF computation happens in aggregate). + aggregate(edge_features, index, ptr, dim_size) + Compute RBF features with adaptive scaling and per-target-node normalization. + + Examples + -------- + # Default: per-node adaptive scaling with grouped normalization + rbf = RadialBasisFeatures() + + # To use global scale + rbf_global = RadialBasisFeatures(r_scale=1.0) + + # Custom RBF centers and width + rbf_custom = RadialBasisFeatures(centers=[0.0, 0.33, 0.67, 1.0], sigma=0.15) + + Notes + ----- + - Closer edges → higher values at low-distance centers (0.0, 0.25) + - Farther edges → higher values at high-distance centers (0.75, 1.0) + """ + + norm_by_group: bool = True # normalise the RBF features per destination node + + def __init__( + self, + r_scale: float | None = None, + centers: list[float] | None = None, + sigma: float = 0.2, + norm: str = "l1", + epsilon: float = 1e-10, + dtype: str = "float32", + ) -> None: + self.epsilon = epsilon + self.r_scale = r_scale + + if self.r_scale is not None and self.r_scale < self.epsilon: + LOGGER.warning( + "r_scale (%f) is too small (< epsilon=%f). 
Clamping to epsilon to avoid division by zero.", + self.r_scale, + self.epsilon, + ) + self.r_scale = self.epsilon + + self.centers = centers if centers is not None else [0.0, 0.25, 0.5, 0.75, 1.0] + + # Normalize centers if using global scaling + if self.r_scale is not None: + self.centers = [c / self.r_scale for c in self.centers] + + # Check that centers are in the range [0, 1] + assert all( + 0.0 <= c <= 1.0 for c in self.centers + ), f"RBF centers must be in range [0, 1] (or [0, r_scale] if r_scale is set). Got centers: {centers}, r_scale: {r_scale}" + + self.sigma = sigma + super().__init__(norm=norm, dtype=dtype) + + def aggregate(self, edge_features: torch.Tensor, index: torch.Tensor, ptr=None, dim_size=None) -> torch.Tensor: + """Aggregate edge features with per-node scaling and per-target-node normalization. + + Parameters + ---------- + edge_features : torch.Tensor + Raw edge distances, shape [num_edges] or [num_edges, 1] + index : torch.Tensor + Destination node index for each edge + ptr : optional + CSR pointer (not used) + dim_size : int, optional + Number of destination nodes + + Returns + ------- + torch.Tensor + RBF features, shape [num_edges, num_centers]. + Normalized per target node per RBF center . 
+ """ + # Ensure edge_features is 1D + if edge_features.ndim == 2: + edge_features = edge_features.squeeze(-1) + + # Compute scale factor per destination node + if self.r_scale is None: + # Per-node max edge length scaling + max_dists = scatter(edge_features, index.long(), dim=0, dim_size=dim_size, reduce="max") + + # Clamp to epsilon to avoid division by zero + max_dists = torch.clamp(max_dists, min=self.epsilon) + + # Broadcast to each edge + scales = max_dists[index] + alpha = edge_features / scales # Normalized distance [0, 1] + else: + # Global scaling + scales = torch.full_like(edge_features, self.r_scale) + alpha = edge_features / scales # Scaled distance [0, max_edge/r_scale] + + # Compute Gaussian RBF for each center + rbf_features = [] + for center in self.centers: + rbf = torch.exp(-(((alpha - center) / self.sigma) ** 2)) + rbf_features.append(rbf) + + rbf_features = torch.stack(rbf_features, dim=1) + + # Within each RBF center, normalise edges to the same target node + rbf_features = self.normalise(rbf_features, index, dim_size) + + return rbf_features + + class GaussianDistanceWeights(EdgeLength): """Gaussian distance weights.""" diff --git a/graphs/src/anemoi/graphs/edges/builders/cutoff.py b/graphs/src/anemoi/graphs/edges/builders/cutoff.py index 0f8cc07ac..3fba45903 100644 --- a/graphs/src/anemoi/graphs/edges/builders/cutoff.py +++ b/graphs/src/anemoi/graphs/edges/builders/cutoff.py @@ -40,8 +40,11 @@ class CutOffEdges(BaseDistanceEdgeBuilders): The name of the source nodes. target_name : str The name of the target nodes. - cutoff_factor : float + cutoff_factor : float | None Factor to multiply the grid reference distance to get the cut-off radius. + Mutually exclusive with cutoff_distance_km. + cutoff_distance_km : float | None + Cutoff radius in kilometers. Mutually exclusive with cutoff_factor. source_mask_attr_name : str | None The name of the source mask attribute to filter edge connections. 
target_mask_attr_name : str | None @@ -63,17 +66,33 @@ def __init__( self, source_name: str, target_name: str, - cutoff_factor: float, + cutoff_factor: float | None = None, + cutoff_distance_km: float | None = None, source_mask_attr_name: str | None = None, target_mask_attr_name: str | None = None, max_num_neighbours: int = 64, ) -> None: super().__init__(source_name, target_name, source_mask_attr_name, target_mask_attr_name) - assert isinstance(cutoff_factor, (int, float)), "Cutoff factor must be a float." + + # Validate that exactly one of cutoff_factor or cutoff_distance_km is provided + if cutoff_factor is None and cutoff_distance_km is None: + raise ValueError("Either cutoff_factor or cutoff_distance_km must be provided.") + if cutoff_factor is not None and cutoff_distance_km is not None: + raise ValueError("cutoff_factor and cutoff_distance_km are mutually exclusive. Provide only one.") + + if cutoff_factor is not None: + assert isinstance(cutoff_factor, (int, float)), "Cutoff factor must be a float." + assert cutoff_factor > 0, "Cutoff factor must be positive." + + if cutoff_distance_km is not None: + assert isinstance(cutoff_distance_km, (int, float)), "Cutoff distance must be a float." + assert cutoff_distance_km > 0, "Cutoff distance must be positive." + assert isinstance(max_num_neighbours, int), "Number of nearest neighbours must be an integer." - assert cutoff_factor > 0, "Cutoff factor must be positive." assert max_num_neighbours > 0, "Number of nearest neighbours must be positive." + self.cutoff_factor = cutoff_factor + self.cutoff_distance_km = cutoff_distance_km self.max_num_neighbours = max_num_neighbours @staticmethod @@ -104,8 +123,9 @@ def get_reference_distance(nodes: NodeStorage, mask_attr_name: torch.Tensor | No def get_cutoff_radius(self, graph: HeteroData): """Compute the cut-off radius. - The cut-off radius is computed as the product of the target nodes - reference distance and the cut-off factor. 
+ The cut-off radius is computed either as: + - The product of the target nodes reference distance and the cut-off factor, or + - Directly from the cutoff distance in kilometers. Parameters ---------- @@ -115,12 +135,20 @@ def get_cutoff_radius(self, graph: HeteroData): Returns ------- float - The cut-off radius. + The cut-off radius in Cartesian coordinates on unit sphere. """ - reference_dist = CutOffEdges.get_reference_distance( - graph[self.target_name], mask_attr_name=self.target_mask_attr_name - ) - return reference_dist * self.cutoff_factor + if self.cutoff_distance_km is not None: + # Convert km to Cartesian distance on unit sphere + # For small distances: Cartesian distance ≈ great circle distance (radians) + # radians = km / EARTH_RADIUS + radius = self.cutoff_distance_km / EARTH_RADIUS + else: + # Use factor-based approach + reference_dist = CutOffEdges.get_reference_distance( + graph[self.target_name], mask_attr_name=self.target_mask_attr_name + ) + radius = reference_dist * self.cutoff_factor + return radius def prepare_node_data(self, graph: HeteroData) -> tuple[NodeStorage, NodeStorage]: """Prepare node information and get source and target nodes.""" @@ -136,7 +164,7 @@ def _compute_edge_index_pyg(self, source_coords: torch.Tensor, target_coords: to def _crop_to_max_num_neighbours(self, adjmat): """Remove neighbors exceeding the maximum allowed limit.""" nodes_to_drop = np.maximum(np.bincount(adjmat.row) - self.max_num_neighbours, 0) - if num_nodes_to_drop := nodes_to_drop.sum() == 0: + if (num_nodes_to_drop := nodes_to_drop.sum()) == 0: return adjmat LOGGER.info( @@ -146,14 +174,26 @@ def _crop_to_max_num_neighbours(self, adjmat): self.max_num_neighbours, ) - # Compute indices to remove - mask = np.ones(adjmat.nnz, dtype=bool) - node_idx = np.where(nodes_to_drop > 0)[0] - for node_id in node_idx: - indices_of_largest_dist = np.argpartition(adjmat.data[adjmat.row == node_id], -nodes_to_drop[node_id])[ - -nodes_to_drop[node_id] : - ] - 
mask[np.where(adjmat.row == node_id)[0][indices_of_largest_dist]] = False + # Vectorized approach: sort edges by (row, distance) to group by node + # no repeated O(nnz) scans in a loop + sort_idx = np.lexsort((adjmat.data, adjmat.row)) + sorted_rows = adjmat.row[sort_idx] + + # Find where each row starts and ends + row_changes = np.concatenate(([0], np.where(np.diff(sorted_rows) != 0)[0] + 1, [len(sorted_rows)])) + + # Compute rank of each edge within its row + edge_rank_in_row = np.zeros(len(sorted_rows), dtype=int) + for i in range(len(row_changes) - 1): + start, end = row_changes[i], row_changes[i + 1] + edge_rank_in_row[start:end] = np.arange(end - start) + + # Keep edges where rank < max_num_neighbours (smallest distances are first due to sorting) + mask_sorted = edge_rank_in_row < self.max_num_neighbours + + # Map back to original order + mask = np.zeros(adjmat.nnz, dtype=bool) + mask[sort_idx] = mask_sorted # Define the new sparse matrix return coo_matrix((adjmat.data[mask], (adjmat.row[mask], adjmat.col[mask])), shape=adjmat.shape) @@ -183,12 +223,22 @@ def compute_edge_index(self, source_nodes: NodeStorage, target_nodes: NodeStorag torch.Tensor of shape (2, num_edges) The adjacency matrix. 
""" - LOGGER.info( - "Using CutOff-Edges (with radius = %.1f km) between %s and %s.", - self.radius * EARTH_RADIUS, - self.source_name, - self.target_name, - ) + radius_km = self.radius * EARTH_RADIUS + if self.cutoff_distance_km is not None: + LOGGER.info( + "Using CutOff-Edges (with radius = %.1f km [direct]) between %s and %s.", + radius_km, + self.source_name, + self.target_name, + ) + else: + LOGGER.info( + "Using CutOff-Edges (with radius = %.1f km [factor=%.2f]) between %s and %s.", + radius_km, + self.cutoff_factor, + self.source_name, + self.target_name, + ) return super().compute_edge_index(source_nodes=source_nodes, target_nodes=target_nodes) @@ -231,8 +281,9 @@ def get_cartesian_node_coordinates( def get_cutoff_radius(self, graph: HeteroData): """Compute the cut-off radius. - The cut-off radius is computed as the product of the target nodes - reference distance and the cut-off factor. + The cut-off radius is computed either as: + - The product of the source nodes reference distance and the cut-off factor, or + - Directly from the cutoff distance in kilometers. Parameters ---------- @@ -242,12 +293,20 @@ def get_cutoff_radius(self, graph: HeteroData): Returns ------- float - The cut-off radius. + The cut-off radius in Cartesian coordinates on unit sphere. 
""" - reference_dist = CutOffEdges.get_reference_distance( - graph[self.source_name], mask_attr_name=self.source_mask_attr_name - ) - return reference_dist * self.cutoff_factor + if self.cutoff_distance_km is not None: + # Convert km to Cartesian distance on unit sphere + # For small distances: Cartesian distance ≈ great circle distance (radians) + # radians = km / EARTH_RADIUS + radius = self.cutoff_distance_km / EARTH_RADIUS + else: + # Use factor-based approach + reference_dist = CutOffEdges.get_reference_distance( + graph[self.source_name], mask_attr_name=self.source_mask_attr_name + ) + radius = reference_dist * self.cutoff_factor + return radius def undo_masking_adj_matrix(self, adj_matrix, source_nodes: NodeStorage, target_nodes: NodeStorage): adj_matrix = adj_matrix.T diff --git a/graphs/src/anemoi/graphs/schemas/edge_attributes_schemas.py b/graphs/src/anemoi/graphs/schemas/edge_attributes_schemas.py index f90fc923f..f4d43e73f 100644 --- a/graphs/src/anemoi/graphs/schemas/edge_attributes_schemas.py +++ b/graphs/src/anemoi/graphs/schemas/edge_attributes_schemas.py @@ -19,8 +19,10 @@ class ImplementedEdgeAttributeSchema(str, Enum): edge_length = "anemoi.graphs.edges.attributes.EdgeLength" edge_dirs = "anemoi.graphs.edges.attributes.EdgeDirection" + directional_harmonics = "anemoi.graphs.edges.attributes.DirectionalHarmonics" azimuth = "anemoi.graphs.edges.attributes.Azimuth" gaussian_weights = "anemoi.graphs.edges.attributes.GaussianDistanceWeights" + radial_basis_features = "anemoi.graphs.edges.attributes.RadialBasisFeatures" class BaseEdgeAttributeSchema(BaseModel): @@ -40,4 +42,26 @@ class EdgeAttributeFromNodeSchema(BaseModel): "Normalisation method applied to the edge attribute." 
-EdgeAttributeSchema = BaseEdgeAttributeSchema | EdgeAttributeFromNodeSchema +class DirectionalHarmonicsSchema(BaseModel): + target_: Literal["anemoi.graphs.edges.attributes.DirectionalHarmonics"] = Field(..., alias="_target_") + "Directional harmonics from edge directions" + order: int = Field(default=3, description="Maximum order of harmonics to compute") + norm: Literal["unit-max", "l1", "l2", "unit-sum", "unit-range", "unit-std"] | None = Field( + default=None, description="Normalization method" + ) + dtype: str = Field(default="float32", description="Data type for computations") + + +class RadialBasisFeaturesSchema(BaseModel): + target_: Literal["anemoi.graphs.edges.attributes.RadialBasisFeatures"] = Field(..., alias="_target_") + "Radial basis function features from edge distances" + r_scale: float | None = Field(default=None, description="Global scale factor (None for adaptive per-node scaling)") + centers: list[float] | None = Field(default=None, description="RBF center positions [0, 1]") + sigma: float = Field(default=0.2, description="Width of Gaussian RBF functions") + epsilon: float = Field(default=1e-10, description="Small constant to avoid division by zero") + dtype: str = Field(default="float32", description="Data type for computations") + + +EdgeAttributeSchema = ( + BaseEdgeAttributeSchema | EdgeAttributeFromNodeSchema | DirectionalHarmonicsSchema | RadialBasisFeaturesSchema +) diff --git a/graphs/src/anemoi/graphs/schemas/edge_schemas.py b/graphs/src/anemoi/graphs/schemas/edge_schemas.py index b7f15762f..701eefb27 100644 --- a/graphs/src/anemoi/graphs/schemas/edge_schemas.py +++ b/graphs/src/anemoi/graphs/schemas/edge_schemas.py @@ -14,6 +14,7 @@ from pydantic import Field from pydantic import PositiveFloat from pydantic import PositiveInt +from pydantic import model_validator from anemoi.utils.schemas import BaseModel @@ -36,12 +37,29 @@ class CutoffEdgeSchema(BaseModel): ..., alias="_target_" ) "Cut-off based edges implementation from anemoi.graphs.edges." 
- cutoff_factor: PositiveFloat = Field(example=0.6) - "Factor to multiply the grid reference distance to get the cut-off radius. Default to 0.6." + cutoff_factor: PositiveFloat | None = Field(default=None, example=0.6) + "Factor to multiply the grid reference distance to get the cut-off radius. Mutually exclusive with cutoff_distance_km." + cutoff_distance_km: PositiveFloat | None = Field(default=None, example=500.0) + "Cutoff radius in kilometers. Mutually exclusive with cutoff_factor." source_mask_attr_name: str | None = Field(default=None, examples=["boundary_mask"]) "Mask to apply to source nodes of the edges. Default to None." target_mask_attr_name: str | None = Field(default=None, examples=["boundary_mask"]) "Mask to apply to target nodes of the edges. Default to None." + max_num_neighbours: PositiveInt = Field(default=64, example=64) + "Maximum number of nearest neighbours to consider when building edges. Default to 64." + + @model_validator(mode="after") + def validate_cutoff_params(self): + """Validate that exactly one of cutoff_factor or cutoff_distance_km is provided.""" + cutoff_factor = self.cutoff_factor + cutoff_distance_km = self.cutoff_distance_km + + if cutoff_factor is None and cutoff_distance_km is None: + raise ValueError("Either cutoff_factor or cutoff_distance_km must be provided.") + if cutoff_factor is not None and cutoff_distance_km is not None: + raise ValueError("cutoff_factor and cutoff_distance_km are mutually exclusive. 
Provide only one.") + + return self class MultiScaleEdgeSchema(BaseModel): diff --git a/graphs/tests/edges/test_cutoff.py b/graphs/tests/edges/test_cutoff.py index 810b10d17..7c65fdb0c 100644 --- a/graphs/tests/edges/test_cutoff.py +++ b/graphs/tests/edges/test_cutoff.py @@ -21,16 +21,46 @@ def test_init(edge_builder): @pytest.mark.parametrize("edge_builder", [CutOffEdges, ReversedCutOffEdges]) -@pytest.mark.parametrize("cutoff_factor", [-0.5, "hello", None]) -def test_fail_init(edge_builder, cutoff_factor: str): - """Test CutOffEdges initialization with invalid cutoff.""" +@pytest.mark.parametrize("cutoff_factor", [-0.5, "hello"]) +def test_fail_init_invalid_cutoff_factor(edge_builder, cutoff_factor: str): + """Test CutOffEdges initialization with invalid cutoff_factor.""" with pytest.raises(AssertionError): - edge_builder("test_nodes1", "test_nodes2", cutoff_factor) + edge_builder("test_nodes1", "test_nodes2", cutoff_factor=cutoff_factor) + + +@pytest.mark.parametrize("edge_builder", [CutOffEdges, ReversedCutOffEdges]) +def test_fail_init_no_params(edge_builder): + """Test CutOffEdges initialization with neither cutoff_factor nor cutoff_distance_km.""" + with pytest.raises(ValueError, match="Either cutoff_factor or cutoff_distance_km must be provided"): + edge_builder("test_nodes1", "test_nodes2") + + +@pytest.mark.parametrize("edge_builder", [CutOffEdges, ReversedCutOffEdges]) +def test_fail_init_both_params(edge_builder): + """Test CutOffEdges initialization with both cutoff_factor and cutoff_distance_km.""" + with pytest.raises(ValueError, match="mutually exclusive"): + edge_builder("test_nodes1", "test_nodes2", cutoff_factor=0.5, cutoff_distance_km=500.0) @pytest.mark.parametrize("edge_builder", [CutOffEdges, ReversedCutOffEdges]) def test_cutoff(edge_builder, graph_with_nodes: HeteroData): - """Test CutOffEdges.""" - builder = edge_builder("test_nodes", "test_nodes", 0.5) + """Test CutOffEdges with cutoff_factor.""" + builder = edge_builder("test_nodes", 
"test_nodes", cutoff_factor=0.5) graph = builder.update_graph(graph_with_nodes) assert ("test_nodes", "to", "test_nodes") in graph.edge_types + + +@pytest.mark.parametrize("edge_builder", [CutOffEdges, ReversedCutOffEdges]) +def test_cutoff_with_distance_km(edge_builder, graph_with_nodes: HeteroData): + """Test CutOffEdges with cutoff_distance_km.""" + builder = edge_builder("test_nodes", "test_nodes", cutoff_distance_km=500.0) + graph = builder.update_graph(graph_with_nodes) + assert ("test_nodes", "to", "test_nodes") in graph.edge_types + + +@pytest.mark.parametrize("edge_builder", [CutOffEdges, ReversedCutOffEdges]) +@pytest.mark.parametrize("cutoff_distance_km", [-500.0, "hello"]) +def test_fail_init_invalid_cutoff_distance_km(edge_builder, cutoff_distance_km): + """Test CutOffEdges initialization with invalid cutoff_distance_km.""" + with pytest.raises(AssertionError): + edge_builder("test_nodes1", "test_nodes2", cutoff_distance_km=cutoff_distance_km) diff --git a/graphs/tests/edges/test_edge_attributes.py b/graphs/tests/edges/test_edge_attributes.py index a67a3b3d6..164917234 100644 --- a/graphs/tests/edges/test_edge_attributes.py +++ b/graphs/tests/edges/test_edge_attributes.py @@ -14,9 +14,11 @@ from anemoi.graphs.edges.attributes import AttributeFromSourceNode from anemoi.graphs.edges.attributes import AttributeFromTargetNode +from anemoi.graphs.edges.attributes import DirectionalHarmonics from anemoi.graphs.edges.attributes import EdgeDirection from anemoi.graphs.edges.attributes import EdgeLength from anemoi.graphs.edges.attributes import GaussianDistanceWeights +from anemoi.graphs.edges.attributes import RadialBasisFeatures TEST_EDGES = ("test_nodes", "to", "test_nodes") @@ -63,3 +65,134 @@ def test_fail_edge_features(attribute_builder, graph_nodes_and_edges): source_nodes = graph_nodes_and_edges[TEST_EDGES[0]] target_nodes = graph_nodes_and_edges[TEST_EDGES[2]] attribute_builder(x=(source_nodes, target_nodes), edge_index=edge_index) + + +def 
test_radial_basis_features_default(graph_nodes_and_edges): + """Test RadialBasisFeatures with default parameters (adaptive scaling).""" + edge_attr_builder = RadialBasisFeatures() + edge_index = graph_nodes_and_edges[TEST_EDGES].edge_index + source_nodes = graph_nodes_and_edges[TEST_EDGES[0]] + target_nodes = graph_nodes_and_edges[TEST_EDGES[2]] + edge_attr = edge_attr_builder(x=(source_nodes, target_nodes), edge_index=edge_index) + + assert isinstance(edge_attr, torch.Tensor) + assert edge_attr.shape[0] == edge_index.shape[1] # num_edges + assert edge_attr.shape[1] == 5 # default 5 centers + + # Check per-target-node per-RBF-center normalization + # Within each RBF center, edges to same target should sum to ~1 + target_indices = edge_index[1] # Target nodes for each edge + for center_idx in range(edge_attr.shape[1]): + for target_node in target_indices.unique(): + mask = target_indices == target_node + center_sum = edge_attr[mask, center_idx].sum() + assert torch.isclose(center_sum, torch.tensor(1.0, dtype=edge_attr.dtype), atol=1e-6) + + +def test_radial_basis_features_global_scale(graph_nodes_and_edges): + """Test RadialBasisFeatures with global r_scale.""" + edge_attr_builder = RadialBasisFeatures(r_scale=1.0) + edge_index = graph_nodes_and_edges[TEST_EDGES].edge_index + source_nodes = graph_nodes_and_edges[TEST_EDGES[0]] + target_nodes = graph_nodes_and_edges[TEST_EDGES[2]] + edge_attr = edge_attr_builder(x=(source_nodes, target_nodes), edge_index=edge_index) + + assert isinstance(edge_attr, torch.Tensor) + assert edge_attr.shape[0] == edge_index.shape[1] + + # Check per-target-node per-RBF-center normalization + target_indices = edge_index[1] + for center_idx in range(edge_attr.shape[1]): + for target_node in target_indices.unique(): + mask = target_indices == target_node + center_sum = edge_attr[mask, center_idx].sum() + assert torch.isclose(center_sum, torch.tensor(1.0, dtype=edge_attr.dtype), atol=1e-6) + + +def 
test_radial_basis_features_custom_centers(graph_nodes_and_edges): + """Test RadialBasisFeatures with custom centers.""" + custom_centers = [0.0, 0.33, 0.67, 1.0] + edge_attr_builder = RadialBasisFeatures(centers=custom_centers, sigma=0.15) + edge_index = graph_nodes_and_edges[TEST_EDGES].edge_index + source_nodes = graph_nodes_and_edges[TEST_EDGES[0]] + target_nodes = graph_nodes_and_edges[TEST_EDGES[2]] + edge_attr = edge_attr_builder(x=(source_nodes, target_nodes), edge_index=edge_index) + + assert isinstance(edge_attr, torch.Tensor) + assert edge_attr.shape[1] == len(custom_centers) + + # Check per-target-node per-RBF-center normalization + target_indices = edge_index[1] + for center_idx in range(edge_attr.shape[1]): + for target_node in target_indices.unique(): + mask = target_indices == target_node + center_sum = edge_attr[mask, center_idx].sum() + assert torch.isclose(center_sum, torch.tensor(1.0, dtype=edge_attr.dtype), atol=1e-6) + + +def test_radial_basis_features_epsilon(graph_nodes_and_edges): + """Test RadialBasisFeatures with custom epsilon.""" + edge_attr_builder = RadialBasisFeatures(epsilon=1e-8) + edge_index = graph_nodes_and_edges[TEST_EDGES].edge_index + source_nodes = graph_nodes_and_edges[TEST_EDGES[0]] + target_nodes = graph_nodes_and_edges[TEST_EDGES[2]] + edge_attr = edge_attr_builder(x=(source_nodes, target_nodes), edge_index=edge_index) + + assert isinstance(edge_attr, torch.Tensor) + # Should not crash with division by zero + + +def test_directional_harmonics_default(graph_nodes_and_edges): + """Test DirectionalHarmonics with default parameters.""" + edge_attr_builder = DirectionalHarmonics() + edge_index = graph_nodes_and_edges[TEST_EDGES].edge_index + source_nodes = graph_nodes_and_edges[TEST_EDGES[0]] + target_nodes = graph_nodes_and_edges[TEST_EDGES[2]] + edge_attr = edge_attr_builder(x=(source_nodes, target_nodes), edge_index=edge_index) + + assert isinstance(edge_attr, torch.Tensor) + assert edge_attr.shape[0] == edge_index.shape[1] 
# num_edges + assert edge_attr.shape[1] == 2 * 3 # default order=3 -> 2*order features + + +@pytest.mark.parametrize("order", [1, 2, 3, 5]) +def test_directional_harmonics_order(graph_nodes_and_edges, order): + """Test DirectionalHarmonics with different orders.""" + edge_attr_builder = DirectionalHarmonics(order=order) + edge_index = graph_nodes_and_edges[TEST_EDGES].edge_index + source_nodes = graph_nodes_and_edges[TEST_EDGES[0]] + target_nodes = graph_nodes_and_edges[TEST_EDGES[2]] + edge_attr = edge_attr_builder(x=(source_nodes, target_nodes), edge_index=edge_index) + + assert isinstance(edge_attr, torch.Tensor) + assert edge_attr.shape[1] == 2 * order # sin and cos for each order + + +@pytest.mark.parametrize("norm", ["l1", "l2", "unit-max", "unit-std", "unit-range"]) +def test_directional_harmonics_with_norm(graph_nodes_and_edges, norm): + """Test DirectionalHarmonics with different normalization methods.""" + edge_attr_builder = DirectionalHarmonics(order=2, norm=norm) + edge_index = graph_nodes_and_edges[TEST_EDGES].edge_index + source_nodes = graph_nodes_and_edges[TEST_EDGES[0]] + target_nodes = graph_nodes_and_edges[TEST_EDGES[2]] + edge_attr = edge_attr_builder(x=(source_nodes, target_nodes), edge_index=edge_index) + + assert isinstance(edge_attr, torch.Tensor) + assert edge_attr.shape[1] == 2 * 2 # order=2 + + +def test_directional_harmonics_values(graph_nodes_and_edges): + """Test that DirectionalHarmonics produces reasonable values.""" + edge_attr_builder = DirectionalHarmonics(order=1, norm=None) + edge_index = graph_nodes_and_edges[TEST_EDGES].edge_index + source_nodes = graph_nodes_and_edges[TEST_EDGES[0]] + target_nodes = graph_nodes_and_edges[TEST_EDGES[2]] + edge_attr = edge_attr_builder(x=(source_nodes, target_nodes), edge_index=edge_index) + + # For order=1, features are [sin(ψ), cos(ψ)] + # sin²(ψ) + cos²(ψ) = 1 + sin_vals = edge_attr[:, 0] + cos_vals = edge_attr[:, 1] + norms = sin_vals**2 + cos_vals**2 + + assert torch.allclose(norms, 
torch.ones_like(norms), atol=1e-6)