Skip to content

Commit b0f9f81

Browse files
committed
fps cpu version
1 parent 0a03833 commit b0f9f81

File tree

3 files changed

+7
-32
lines changed

3 files changed

+7
-32
lines changed

cuda/fps_kernel.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,7 @@ at::Tensor fps_cuda(at::Tensor x, at::Tensor batch, float ratio, bool random) {
169169

170170
auto deg = degree(batch, batch_size);
171171
auto cum_deg = at::cat({at::zeros(1, deg.options()), deg.cumsum(0)}, 0);
172-
auto k = (deg.toType(at::kFloat) * ratio).round().toType(at::kLong);
172+
auto k = (deg.toType(at::kFloat) * ratio).ceil().toType(at::kLong);
173173
auto cum_k = at::cat({at::zeros(1, k.options()), k.cumsum(0)}, 0);
174174

175175
at::Tensor start;

test/test_fps.py

Lines changed: 1 addition & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,9 @@
44
import torch
55
from torch_cluster import fps
66

7-
from .utils import tensor, grad_dtypes
7+
from .utils import grad_dtypes, devices, tensor
88

9-
devices = [torch.device('cuda')]
109

11-
12-
@pytest.mark.skipif(not torch.cuda.is_available(), reason='CUDA not available')
1310
@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
1411
def test_fps(dtype, device):
1512
x = tensor([
@@ -26,25 +23,3 @@ def test_fps(dtype, device):
2623

2724
out = fps(x, batch, ratio=0.5, random_start=False)
2825
assert out.tolist() == [0, 2, 4, 6]
29-
30-
31-
@pytest.mark.skipif(not torch.cuda.is_available(), reason='CUDA not available')
32-
@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
33-
def test_fps_speed(dtype, device):
34-
return
35-
batch_size, num_nodes = 100, 10000
36-
x = torch.randn((batch_size * num_nodes, 3), dtype=dtype, device=device)
37-
batch = torch.arange(batch_size, dtype=torch.long, device=device)
38-
batch = batch.view(-1, 1).repeat(1, num_nodes).view(-1)
39-
40-
out = fps(x, batch, ratio=0.5, random_start=True)
41-
assert out.size(0) == batch_size * num_nodes * 0.5
42-
assert out.min().item() >= 0 and out.max().item() < batch_size * num_nodes
43-
44-
batch_size, num_nodes, dim = 100, 300, 128
45-
x = torch.randn((batch_size * num_nodes, dim), dtype=dtype, device=device)
46-
batch = torch.arange(batch_size, dtype=torch.long, device=device)
47-
batch = batch.view(-1, 1).repeat(1, num_nodes).view(-1)
48-
out = fps(x, batch, ratio=0.5, random_start=True)
49-
assert out.size(0) == batch_size * num_nodes * 0.5
50-
assert out.min().item() >= 0 and out.max().item() < batch_size * num_nodes

torch_cluster/fps.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import torch
2+
import fps_cpu
23

34
if torch.cuda.is_available():
45
import fps_cuda
@@ -39,12 +40,11 @@ def fps(x, batch=None, ratio=0.5, random_start=True):
3940

4041
x = x.view(-1, 1) if x.dim() == 1 else x
4142

42-
assert x.is_cuda
4343
assert x.dim() == 2 and batch.dim() == 1
4444
assert x.size(0) == batch.size(0)
4545
assert ratio > 0 and ratio < 1
4646

47-
op = fps_cuda.fps if x.is_cuda else None
48-
out = op(x, batch, ratio, random_start)
49-
50-
return out
47+
if x.is_cuda:
48+
return fps_cuda.fps(x, batch, ratio, random_start)
49+
else:
50+
return fps_cpu.fps(x, batch, ratio, random_start)

0 commit comments

Comments
 (0)