We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 3369b5f commit 60836e2Copy full SHA for 60836e2
cpu/sampler.cpp
@@ -3,10 +3,8 @@
3
4
#include <TH/THGenerator.hpp>
5
6
-std::tuple<at::Tensor, at::Tensor> neighbor_sampler(at::Tensor start,
7
- at::Tensor cumdeg,
8
- at::Tensor col, size_t size,
9
- float factor) {
+at::Tensor neighbor_sampler(at::Tensor start, at::Tensor cumdeg, size_t size,
+ float factor) {
10
THGenerator *generator = THGenerator_new();
11
12
auto start_ptr = start.data<int64_t>();
@@ -46,9 +44,7 @@ std::tuple<at::Tensor, at::Tensor> neighbor_sampler(at::Tensor start,
46
44
47
45
int64_t len = e_ids.size();
48
auto e_id = torch::from_blob(e_ids.data(), {len}, start.options()).clone();
49
- auto n_id = std::get<0>(at::_unique(col.index_select(0, e_id)));
50
-
51
- return std::make_tuple(n_id, e_id);
+ return e_id;
52
}
53
54
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
test/test_sampler.py
@@ -8,12 +8,9 @@ def test_neighbor_sampler():
start = torch.tensor([0, 1])
cumdeg = torch.tensor([0, 3, 7])
- col = torch.tensor([1, 2, 3, 0, 2, 3, 4])
13
- n_id, e_id = neighbor_sampler(start, cumdeg, col, size=1.0)
14
- assert n_id.tolist() == [0, 1, 2, 3, 4]
+ e_id = neighbor_sampler(start, cumdeg, size=1.0)
15
assert e_id.tolist() == [0, 2, 1, 5, 6, 3, 4]
16
17
- n_id, e_id = neighbor_sampler(start, cumdeg, col, size=3)
18
- assert n_id.tolist() == [1, 2, 3, 4]
+ e_id = neighbor_sampler(start, cumdeg, size=3)
19
assert e_id.tolist() == [1, 0, 2, 4, 5, 6]
torch_cluster/sampler.py
@@ -1,7 +1,7 @@
1
import torch_cluster.sampler_cpu
2
-def neighbor_sampler(start, cumdeg, col, size):
+def neighbor_sampler(start, cumdeg, size):
assert not start.is_cuda
factor = 1
@@ -10,4 +10,4 @@ def neighbor_sampler(start, cumdeg, col, size):
size = 2147483647
op = torch_cluster.sampler_cpu.neighbor_sampler
- return op(start, cumdeg, col, size, factor)
+ return op(start, cumdeg, size, factor)
0 commit comments