Skip to content

Commit fca1e9a

Browse files
JPXKQXfrazane
authored andcommitted
fix(graphs,tests): new test and fix anemoi-graphs tests with gpu (#637)
## Description New test for `GaussianWeightsAttritbute`. Fix anemoi-graphs tests with GPU. ## Additional notes ## <!-- Include any additional information, caveats, or considerations that the reviewer should be aware of. --> ***As a contributor to the Anemoi framework, please ensure that your changes include unit tests, updates to any affected dependencies and documentation, and have been tested in a parallel setting (i.e., with multiple GPUs). As a reviewer, you are also responsible for verifying these aspects and requesting changes if they are not adequately addressed. For guidelines about those please refer to https://anemoi.readthedocs.io/en/latest/*** By opening this pull request, I affirm that all authors agree to the [Contributor License Agreement.](https://github.com/ecmwf/codex/blob/main/Legal/contributor_license_agreement.md)
1 parent 7e9b47b commit fca1e9a

File tree

4 files changed

+20
-15
lines changed

4 files changed

+20
-15
lines changed

graphs/src/anemoi/graphs/edges/attributes.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def subset_node_information(self, source_nodes: NodeStorage, target_nodes: NodeS
6060

6161
def forward(self, x: tuple[NodeStorage, NodeStorage], edge_index: Adj, size: Size = None) -> torch.Tensor:
6262
x = self.subset_node_information(*x)
63-
return self.propagate(edge_index, x=x, size=size)
63+
return self.propagate(edge_index.to(self.device), x=x, size=size)
6464

6565
@abstractmethod
6666
def compute(self, x_i: torch.Tensor, x_j: torch.Tensor) -> torch.Tensor: ...

graphs/src/anemoi/graphs/edges/builders/base.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -109,8 +109,8 @@ def register_attributes(self, graph: HeteroData, config: DotDict) -> HeteroData:
109109
"""
110110
for attr_name, attr_config in config.items():
111111
edge_index = graph[self.name].edge_index
112-
edge_builder = instantiate(attr_config)
113-
graph[self.name][attr_name] = edge_builder(
112+
edge_attribute_builder = instantiate(attr_config)
113+
graph[self.name][attr_name] = edge_attribute_builder(
114114
x=(graph[self.name[0]], graph[self.name[2]]), edge_index=edge_index
115115
)
116116
return graph

graphs/src/anemoi/graphs/processors/post_process.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828

2929

3030
class PostProcessor(ABC):
31+
"""Base PostProcessor class."""
3132

3233
@abstractmethod
3334
def update_graph(self, graph: HeteroData, **kwargs: Any) -> HeteroData:
@@ -44,6 +45,7 @@ def __init__(
4445
) -> None:
4546
self.nodes_names = (nodes_name,) if isinstance(nodes_name, str) else tuple(nodes_name)
4647
self.save_mask_indices_to_attr = save_mask_indices_to_attr
48+
super().__init__()
4749

4850
def removing_nodes(self, graph: HeteroData, mask: torch.Tensor, nodes_name: str) -> HeteroData:
4951
"""Remove nodes based on the mask passed."""
@@ -113,7 +115,7 @@ def update_graph(self, graph: HeteroData, **kwargs: Any) -> HeteroData:
113115
The post-processed graph.
114116
"""
115117
for nodes_name in self.nodes_names:
116-
mask = self.compute_mask(graph, nodes_name)
118+
mask = self.compute_mask(graph, nodes_name).cpu()
117119
LOGGER.info(f"Removing {(~mask).sum()} nodes from {nodes_name}.")
118120
graph = self.removing_nodes(graph, mask, nodes_name)
119121
graph = self.update_edge_indices(graph, mask, nodes_name)
@@ -248,6 +250,7 @@ class BaseSortEdgeIndex(PostProcessor, ABC):
248250
def __init__(self, descending: bool = True) -> None:
249251
assert self.nodes_axis is not None, f"{self.__class__.__name__} must define the nodes_axis class attribute."
250252
self.descending = descending
253+
super().__init__()
251254

252255
def get_sorting_mask(self, edges: dict) -> torch.Tensor:
253256
sort_indices = torch.sort(edges["edge_index"], descending=self.descending, dim=1)
@@ -311,15 +314,15 @@ def __init__(
311314
self.source_name = source_name
312315
self.target_name = target_name
313316
self.edges_name = (self.source_name, "to", self.target_name)
314-
self.mask: torch.Tensor = None
317+
super().__init__()
315318

316-
def removing_edges(self, graph: HeteroData) -> HeteroData:
319+
def removing_edges(self, graph: HeteroData, mask: torch.Tensor) -> HeteroData:
317320
"""Remove edges based on the mask passed."""
318321
for attr_name in graph[self.edges_name].edge_attrs():
319322
if attr_name == "edge_index":
320-
graph[self.edges_name][attr_name] = graph[self.edges_name][attr_name][:, self.mask]
323+
graph[self.edges_name][attr_name] = graph[self.edges_name][attr_name].cpu()[:, mask]
321324
else:
322-
graph[self.edges_name][attr_name] = graph[self.edges_name][attr_name][self.mask, :]
325+
graph[self.edges_name][attr_name] = graph[self.edges_name][attr_name].cpu()[mask, :]
323326

324327
return graph
325328

@@ -351,9 +354,9 @@ def update_graph(self, graph: HeteroData, **kwargs: Any) -> HeteroData:
351354
HeteroData
352355
The post-processed graph.
353356
"""
354-
self.mask = self.compute_mask(graph)
355-
LOGGER.info(f"Removing {(~self.mask).sum()} edges from {self.edges_name}.")
356-
graph = self.removing_edges(graph)
357+
mask = self.compute_mask(graph).cpu()
358+
LOGGER.info(f"Removing {(~mask).sum()} edges from {self.edges_name}.")
359+
graph = self.removing_edges(graph, mask)
357360
graph_config = kwargs.get("graph_config", {})
358361
graph = self.recompute_attributes(graph, graph_config)
359362
return graph
@@ -398,14 +401,14 @@ def compute_mask(self, graph: HeteroData) -> torch.Tensor:
398401
target_nodes = graph[self.target_name]
399402
edge_index = graph[self.edges_name].edge_index
400403
lengths = EARTH_RADIUS * EdgeLength()(x=(source_nodes, target_nodes), edge_index=edge_index)
401-
mask = torch.where(lengths > self.treshold, False, True).squeeze()
404+
mask = torch.where(lengths > self.treshold, False, True).squeeze().cpu()
402405
cases = [
403406
(self.source_mask_attr_name, source_nodes, 0),
404407
(self.target_mask_attr_name, target_nodes, 1),
405408
]
406409
for mask_attr_name, nodes, i in cases:
407410
if mask_attr_name:
408411
attr_mask = nodes[mask_attr_name].squeeze()
409-
edge_mask = attr_mask[edge_index[i]]
412+
edge_mask = attr_mask[edge_index[i]].cpu()
410413
mask = torch.logical_or(mask, ~edge_mask)
411414
return mask

graphs/tests/edges/test_edge_attributes.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from anemoi.graphs.edges.attributes import AttributeFromTargetNode
1717
from anemoi.graphs.edges.attributes import EdgeDirection
1818
from anemoi.graphs.edges.attributes import EdgeLength
19+
from anemoi.graphs.edges.attributes import GaussianDistanceWeights
1920

2021
TEST_EDGES = ("test_nodes", "to", "test_nodes")
2122

@@ -31,10 +32,11 @@ def test_directional_features(graph_nodes_and_edges, norm):
3132
assert isinstance(edge_attr, torch.Tensor)
3233

3334

35+
@pytest.mark.parametrize("edge_attr_cls", [EdgeLength, GaussianDistanceWeights])
3436
@pytest.mark.parametrize("norm", ["l1", "l2", "unit-max", "unit-std", "unit-range"])
35-
def test_edge_lengths(graph_nodes_and_edges, norm):
37+
def test_edge_lengths(edge_attr_cls, graph_nodes_and_edges, norm):
3638
"""Test EdgeLength compute method."""
37-
edge_attr_builder = EdgeLength(norm=norm)
39+
edge_attr_builder = edge_attr_cls(norm=norm)
3840
edge_index = graph_nodes_and_edges[TEST_EDGES].edge_index
3941
source_nodes = graph_nodes_and_edges[TEST_EDGES[0]]
4042
target_nodes = graph_nodes_and_edges[TEST_EDGES[2]]

0 commit comments

Comments
 (0)