Skip to content

Commit 969b787

Browse files
ssmmnn11anaprietonemJPXKQX
authored
fix: sparse export (#686)
Fixes for the export of a sparse matrix --------- Co-authored-by: Ana Prieto Nemesio <[email protected]> Co-authored-by: Mario Santa Cruz <[email protected]>
1 parent 14c0235 commit 969b787

File tree

3 files changed

+38
-12
lines changed

3 files changed

+38
-12
lines changed

graphs/docs/usage/create_sparse_matrices.rst

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ For example, to export only the ``gauss_weight`` attribute for the
7474
7575
% anemoi-graphs export_to_sparse graph_recipe.yaml output_dir/ \
7676
--edges-attributes-name gauss_weight \
77-
--edges-name data->down
77+
--edges-name data down
7878
7979
You can specify multiple attributes or subgraphs by repeating the
8080
arguments:
@@ -84,8 +84,8 @@ arguments:
8484
% anemoi-graphs export_to_sparse graph_recipe.yaml output_dir/ \
8585
--edges-attributes-name gauss_weight \
8686
--edges-attributes-name another_weight \
87-
--edges-name data->down \
88-
--edges-name down->data
87+
--edges-name data down \
88+
--edges-name down data
8989
9090
This flexibility allows you to generate only the sparse matrices you
9191
need for your application, reducing storage and processing time.

graphs/src/anemoi/graphs/commands/export_to_sparse.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ class ExportToSparse(Command):
1818
1919
Example usage specifying an edge attribute:
2020
```
21-
anemoi-graphs export-to-sparse graph.pt output_path --edge_attribute_name edge_attr
21+
anemoi-graphs export-to-sparse graph.pt output_path --edge-attribute-name edge_attr
2222
```
2323
2424
Example usage specifying a subset of edges:
@@ -33,7 +33,7 @@ class ExportToSparse(Command):
3333
def add_arguments(self, command_parser):
3434
command_parser.add_argument("graph", help="Path to the graph (a .PT file) or a config file defining the graph.")
3535
command_parser.add_argument("output_path", help="Path to store the inspection results.")
36-
command_parser.add_argument("--edge_attribute_name", default=None, help="Name of the edge attribute to export.")
36+
command_parser.add_argument("--edge-attribute-name", default=None, help="Name of the edge attribute to export.")
3737
command_parser.add_argument(
3838
"--edges-name",
3939
nargs=2,
@@ -46,8 +46,7 @@ def run(self, args):
4646
kwargs = vars(args)
4747
edges_name = kwargs.get("edges_name", None)
4848
if edges_name is not None:
49-
# Convert list of lists to list of tuples
50-
kwargs["edges_name"] = [tuple(pair) for pair in edges_name]
49+
kwargs["edges_name"] = [(pair[0], "to", pair[1]) for pair in edges_name]
5150

5251
GraphExporter(
5352
graph=kwargs["graph"],

graphs/src/anemoi/graphs/export.py

Lines changed: 32 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -63,19 +63,46 @@ def get_nodes_info(self, source_name: str, target_name: str) -> tuple[int, int]:
6363

6464
@staticmethod
6565
def get_sparse_matrix(edge_index, edge_attribute, num_source_nodes, num_target_nodes):
66-
# Create sparse matrix
66+
"""Create sparse matrix for y = A x.
67+
68+
x is defined on source nodes, and output is on target nodes.
69+
70+
Arguments
71+
---------
72+
edge_index : torch.Tensor
73+
(2, E) tensor with rows: target nodes, cols: source nodes
74+
edge_attribute : torch.Tensor
75+
Edge weights/attributes
76+
num_source_nodes : int
77+
Number of source nodes (n_in)
78+
num_target_nodes : int
79+
Number of target nodes (n_out)
80+
81+
Returns
82+
-------
83+
torch.sparse_coo_tensor
84+
Sparse COO tensor on target nodes
85+
"""
86+
rows = edge_index[1]
87+
cols = edge_index[0]
88+
indices = torch.stack([rows, cols])
89+
6790
A = torch.sparse_coo_tensor(
68-
edge_index, edge_attribute, (num_source_nodes, num_target_nodes), device=edge_index.device
91+
indices,
92+
edge_attribute,
93+
(num_target_nodes, num_source_nodes),
94+
device=edge_index.device,
6995
)
7096
return A.coalesce()
7197

7298
@staticmethod
7399
def convert_to_scipy_sparse(A):
74100
"""Convert PyTorch sparse tensor to SciPy sparse matrix and save.
75101
76-
Args:
77-
A: PyTorch sparse COO tensor
78-
filename: Output filename (.npz extension)
102+
Arguments
103+
---------
104+
A : torch.sparse_coo_tensor
105+
Sparse tensor
79106
"""
80107
# Get indices and values from PyTorch sparse tensor
81108
indices = A.indices().cpu().numpy()

0 commit comments

Comments
 (0)