From 3850871f3417a12a5217f35ab005507923c179c5 Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Mon, 5 May 2025 21:29:20 +0000 Subject: [PATCH] Add input seeds to sampling output for node sampler --- .../distributed/dist_neighbor_sampler.py | 4 +-- graphlearn_torch/python/loader/transform.py | 29 ++++++++++--------- 2 files changed, 18 insertions(+), 15 deletions(-) diff --git a/graphlearn_torch/python/distributed/dist_neighbor_sampler.py b/graphlearn_torch/python/distributed/dist_neighbor_sampler.py index 1e34b53f..89e2f58e 100644 --- a/graphlearn_torch/python/distributed/dist_neighbor_sampler.py +++ b/graphlearn_torch/python/distributed/dist_neighbor_sampler.py @@ -358,7 +358,7 @@ async def _sample_from_nodes( num_sampled_nodes=num_sampled_nodes, num_sampled_edges=num_sampled_edges, input_type=input_type, - metadata={} + metadata={'input_seeds': input_seeds}, ) else: @@ -389,7 +389,7 @@ async def _sample_from_nodes( batch=batch, num_sampled_nodes=num_sampled_nodes, num_sampled_edges=num_sampled_edges, - metadata={} + metadata={'input_seeds': input_seeds}, ) # Reclaim inducer into pool. self.inducer_pool.put(inducer) diff --git a/graphlearn_torch/python/loader/transform.py b/graphlearn_torch/python/loader/transform.py index 8c40c778..6cb98c1e 100644 --- a/graphlearn_torch/python/loader/transform.py +++ b/graphlearn_torch/python/loader/transform.py @@ -113,21 +113,24 @@ def to_hetero_data( # update meta data input_type = hetero_sampler_out.input_type if isinstance(hetero_sampler_out.metadata, dict): - # if edge_dir == 'out', we need to reverse the edge type - res_edge_type = reverse_edge_type(input_type) if edge_dir == 'out' else input_type for k, v in hetero_sampler_out.metadata.items(): - if k == 'edge_label_index': - if edge_dir == 'out': - data[res_edge_type]['edge_label_index'] = \ - torch.stack((v[1], v[0]), dim=0) + if isinstance(input_type, tuple): + # if edge_dir == 'out', we need to reverse the edge type + res_edge_type = reverse_edge_type(input_type) if edge_dir == 'out' else input_type + if k == 'edge_label_index': + if edge_dir == 'out': + data[res_edge_type]['edge_label_index'] = \ + torch.stack((v[1], v[0]), dim=0) + else: + data[res_edge_type]['edge_label_index'] = v + elif k == 'edge_label': + data[res_edge_type]['edge_label'] = v + elif k == 'src_index': + data[input_type[0]]['src_index'] = v + elif k in ['dst_pos_index', 'dst_neg_index']: + data[input_type[-1]][k] = v else: - data[res_edge_type]['edge_label_index'] = v - elif k == 'edge_label': - data[res_edge_type]['edge_label'] = v - elif k == 'src_index': - data[input_type[0]]['src_index'] = v - elif k in ['dst_pos_index', 'dst_neg_index']: - data[input_type[-1]][k] = v + data[k] = v else: data[k] = v elif hetero_sampler_out.metadata is not None: