1111
1212import numpy as np
1313import torch
14- from torch_geometric .data import HeteroData
1514from jinja2 import Template
15+ from torch_geometric .data import HeteroData
1616
1717HTML_TEMPLATE_PATH = Path (__file__ ).parent / "interactive_3d.html.jinja"
1818
1919
20- def subset_graph (
20+ def extract_nodes_edges (
2121 graph : HeteroData ,
2222 nodes : list [str ] | None = None ,
2323 edges : list [str ] | None = None ,
2424) -> tuple [dict [str , torch .Tensor ], dict [str , torch .Tensor ]]:
25- """Load a hetero graph from a file and separate nodes and edges by type.
25+ """Extracts nodes and edges from a heterogeneous graph.
26+
27+ Optionally filters by specified types.
2628
2729 Parameters
2830 ----------
@@ -98,7 +100,7 @@ def plot_interactive_graph_3d(
98100 out_file : str | Path, optional
99101 Name of the file to save the plot. Default is None.
100102 """
101- nodes , edges = subset_graph (graph )
103+ nodes , edges = extract_nodes_edges (graph )
102104
103105 for node_set in nodes :
104106 node_lats , node_lons = coords_to_latlon (nodes [node_set ].numpy ())
@@ -140,4 +142,4 @@ def plot_interactive_graph_3d(
140142 template = Template (HTML_TEMPLATE )
141143 html_output = template .render (nodes = nodes_embed , edges = edges_embed , max_degree = 50 , min_degree = 1 )
142144 with open (out_file , "w" ) as f :
143- f .write (html_output )
145+ f .write (html_output )
0 commit comments