Skip to content

Commit 60836e2

Browse files
committed
remove node id
1 parent 3369b5f commit 60836e2

File tree

3 files changed

+7
-14
lines changed

3 files changed

+7
-14
lines changed

cpu/sampler.cpp

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,8 @@
33

44
#include <TH/THGenerator.hpp>
55

6-
std::tuple<at::Tensor, at::Tensor> neighbor_sampler(at::Tensor start,
7-
at::Tensor cumdeg,
8-
at::Tensor col, size_t size,
9-
float factor) {
6+
at::Tensor neighbor_sampler(at::Tensor start, at::Tensor cumdeg, size_t size,
7+
float factor) {
108
THGenerator *generator = THGenerator_new();
119

1210
auto start_ptr = start.data<int64_t>();
@@ -46,9 +44,7 @@ std::tuple<at::Tensor, at::Tensor> neighbor_sampler(at::Tensor start,
4644

4745
int64_t len = e_ids.size();
4846
auto e_id = torch::from_blob(e_ids.data(), {len}, start.options()).clone();
49-
auto n_id = std::get<0>(at::_unique(col.index_select(0, e_id)));
50-
51-
return std::make_tuple(n_id, e_id);
47+
return e_id;
5248
}
5349

5450
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {

test/test_sampler.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,9 @@ def test_neighbor_sampler():
88

99
start = torch.tensor([0, 1])
1010
cumdeg = torch.tensor([0, 3, 7])
11-
col = torch.tensor([1, 2, 3, 0, 2, 3, 4])
1211

13-
n_id, e_id = neighbor_sampler(start, cumdeg, col, size=1.0)
14-
assert n_id.tolist() == [0, 1, 2, 3, 4]
12+
e_id = neighbor_sampler(start, cumdeg, size=1.0)
1513
assert e_id.tolist() == [0, 2, 1, 5, 6, 3, 4]
1614

17-
n_id, e_id = neighbor_sampler(start, cumdeg, col, size=3)
18-
assert n_id.tolist() == [1, 2, 3, 4]
15+
e_id = neighbor_sampler(start, cumdeg, size=3)
1916
assert e_id.tolist() == [1, 0, 2, 4, 5, 6]

torch_cluster/sampler.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import torch_cluster.sampler_cpu
22

33

4-
def neighbor_sampler(start, cumdeg, col, size):
4+
def neighbor_sampler(start, cumdeg, size):
55
assert not start.is_cuda
66

77
factor = 1
@@ -10,4 +10,4 @@ def neighbor_sampler(start, cumdeg, col, size):
1010
size = 2147483647
1111

1212
op = torch_cluster.sampler_cpu.neighbor_sampler
13-
return op(start, cumdeg, col, size, factor)
13+
return op(start, cumdeg, size, factor)

0 commit comments

Comments
 (0)