Skip to content

Commit 4116005

Browse files
committed
fix neighbor sampling
1 parent 69fada5 commit 4116005

File tree

3 files changed

+9
-8
lines changed

3 files changed

+9
-8
lines changed

csrc/cpu/sampler_cpu.cpp

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,10 @@ torch::Tensor neighbor_sampler_cpu(torch::Tensor start, torch::Tensor rowptr,
1515
auto num_neighbors = row_end - row_start;
1616

1717
int64_t size = count;
18-
if (count < 1) {
18+
if (count < 1)
1919
size = int64_t(ceil(factor * float(num_neighbors)));
20-
}
20+
if (size > num_neighbors)
21+
size = num_neighbors;
2122

2223
// If the number of neighbors is approximately equal to the number of
2324
// neighbors which are requested, we use `randperm` to sample without
@@ -26,16 +27,16 @@ torch::Tensor neighbor_sampler_cpu(torch::Tensor start, torch::Tensor rowptr,
2627
std::unordered_set<int64_t> set;
2728
if (size < 0.7 * float(num_neighbors)) {
2829
while (int64_t(set.size()) < size) {
29-
int64_t sample = (rand() % num_neighbors) + row_start;
30-
set.insert(sample);
30+
int64_t sample = rand() % num_neighbors;
31+
set.insert(sample + row_start);
3132
}
3233
std::vector<int64_t> v(set.begin(), set.end());
3334
e_ids.insert(e_ids.end(), v.begin(), v.end());
3435
} else {
35-
auto sample = at::randperm(num_neighbors, start.options()) + row_start;
36+
auto sample = torch::randperm(num_neighbors, start.options());
3637
auto sample_data = sample.data_ptr<int64_t>();
3738
for (auto j = 0; j < size; j++) {
38-
e_ids.push_back(sample_data[j]);
39+
e_ids.push_back(sample_data[j] + row_start);
3940
}
4041
}
4142
}

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def get_extensions():
6363

6464
setup(
6565
name='torch_cluster',
66-
version='1.5.1',
66+
version='1.5.2',
6767
author='Matthias Fey',
6868
author_email='[email protected]',
6969
url='https://github.com/rusty1s/pytorch_cluster',

torch_cluster/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
import torch
55

6-
__version__ = '1.5.1'
6+
__version__ = '1.5.2'
77
expected_torch_version = (1, 4)
88

99
try:

0 commit comments

Comments
 (0)