Skip to content

Commit 28be0cb

Browse files
committed
run pre-commit
1 parent 16326d9 commit 28be0cb

File tree

3 files changed

+11
-7
lines changed

3 files changed

+11
-7
lines changed

graphs/src/anemoi/graphs/inspect.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,4 +90,6 @@ def inspect(self):
9090
if self.show_nodes:
9191
LOGGER.info("Saving interactive plots of nodes ...")
9292
for nodes_name in self.graph.node_types:
93-
plot_interactive_nodes_2d(self.graph, nodes_name, out_file=self.output_path / f"{nodes_name}_nodes.html")
93+
plot_interactive_nodes_2d(
94+
self.graph, nodes_name, out_file=self.output_path / f"{nodes_name}_nodes.html"
95+
)

graphs/src/anemoi/graphs/plotting/interactive_3d.html.jinja

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -953,4 +953,4 @@ nor does it submit to any jurisdiction. #}
953953
</script>
954954
</body>
955955

956-
</html>
956+
</html>

graphs/src/anemoi/graphs/plotting/interactive_3d_html.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,18 +11,20 @@
1111

1212
import numpy as np
1313
import torch
14-
from torch_geometric.data import HeteroData
1514
from jinja2 import Template
15+
from torch_geometric.data import HeteroData
1616

1717
HTML_TEMPLATE_PATH = Path(__file__).parent / "interactive_3d.html.jinja"
1818

1919

20-
def subset_graph(
20+
def extract_nodes_edges(
2121
graph: HeteroData,
2222
nodes: list[str] | None = None,
2323
edges: list[str] | None = None,
2424
) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]:
25-
"""Load a hetero graph from a file and separate nodes and edges by type.
25+
"""Extracts nodes and edges from a heterogeneous graph.
26+
27+
Optionally filters by specified types.
2628
2729
Parameters
2830
----------
@@ -98,7 +100,7 @@ def plot_interactive_graph_3d(
98100
out_file : str | Path, optional
99101
Name of the file to save the plot. Default is None.
100102
"""
101-
nodes, edges = subset_graph(graph)
103+
nodes, edges = extract_nodes_edges(graph)
102104

103105
for node_set in nodes:
104106
node_lats, node_lons = coords_to_latlon(nodes[node_set].numpy())
@@ -140,4 +142,4 @@ def plot_interactive_graph_3d(
140142
template = Template(HTML_TEMPLATE)
141143
html_output = template.render(nodes=nodes_embed, edges=edges_embed, max_degree=50, min_degree=1)
142144
with open(out_file, "w") as f:
143-
f.write(html_output)
145+
f.write(html_output)

0 commit comments

Comments
 (0)