Skip to content

Commit 9c33077

Browse files
committed
fix fps implementation
1 parent aff91e0 commit 9c33077

File tree

5 files changed

+21
-9
lines changed

5 files changed

+21
-9
lines changed

csrc/cpu/fps_cpu.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ torch::Tensor fps_cpu(torch::Tensor src, torch::Tensor ptr, double ratio,
3535

3636
int64_t start_idx = 0;
3737
if (random_start) {
38-
start_idx = rand() % src.size(0);
38+
start_idx = rand() % y.size(0);
3939
}
4040

4141
out_data[out_start] = src_start + start_idx;

csrc/cuda/fps_cuda.cu

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
#include "utils.cuh"
66

7-
#define THREADS 1024
7+
#define THREADS 256
88

99
template <typename scalar_t>
1010
__global__ void fps_kernel(const scalar_t *src, const int64_t *ptr,
@@ -31,15 +31,15 @@ __global__ void fps_kernel(const scalar_t *src, const int64_t *ptr,
3131
int64_t best_idx = 0;
3232

3333
for (int64_t n = start_idx + thread_idx; n < end_idx; n += THREADS) {
34-
scalar_t tmp;
35-
scalar_t dd = (scalar_t)0.;
34+
scalar_t tmp, dd = (scalar_t)0.;
3635
for (int64_t d = 0; d < dim; d++) {
3736
tmp = src[dim * old + d] - src[dim * n + d];
3837
dd += tmp * tmp;
3938
}
40-
dist[n] = min(dist[n], dd);
41-
if (dist[n] > best) {
42-
best = dist[n];
39+
dd = min(dist[n], dd);
40+
dist[n] = dd;
41+
if (dd > best) {
42+
best = dd;
4343
best_idx = n;
4444
}
4545
}

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def get_extensions():
6363

6464
setup(
6565
name='torch_cluster',
66-
version='1.5.2',
66+
version='1.5.3',
6767
author='Matthias Fey',
6868
author_email='[email protected]',
6969
url='https://github.com/rusty1s/pytorch_cluster',

test/test_fps.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,3 +26,15 @@ def test_fps(dtype, device):
2626

2727
out = fps(x, ratio=0.5, random_start=False)
2828
assert out.sort()[0].tolist() == [0, 5, 6, 7]
29+
30+
31+
@pytest.mark.parametrize('device', devices)
32+
def test_random_fps(device):
33+
N = 1024
34+
for _ in range(5):
35+
pos = torch.randn((2 * N, 3), device=device)
36+
batch_1 = torch.zeros(N, dtype=torch.long, device=device)
37+
batch_2 = torch.ones(N, dtype=torch.long, device=device)
38+
batch = torch.cat([batch_1, batch_2])
39+
idx = fps(pos, batch, ratio=0.5)
40+
assert idx.min() >= 0 and idx.max() < 2 * N

torch_cluster/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
import torch
55

6-
__version__ = '1.5.2'
6+
__version__ = '1.5.3'
77
expected_torch_version = (1, 4)
88

99
try:

0 commit comments

Comments
 (0)