
Commit 4e6cb0c: knn cpu implementation

1 parent a9ad9d5 commit 4e6cb0c

File tree: 3 files changed, +31 -13 lines

setup.py

Lines changed: 2 additions & 3 deletions
@@ -24,7 +24,7 @@
 __version__ = '1.2.1'
 url = 'https://github.com/rusty1s/pytorch_cluster'
 
-install_requires = []
+install_requires = ['scipy']
 setup_requires = ['pytest-runner']
 tests_require = ['pytest', 'pytest-cov']
 
@@ -43,5 +43,4 @@
     tests_require=tests_require,
     ext_modules=ext_modules,
     cmdclass=cmdclass,
-    packages=find_packages(),
-)
+    packages=find_packages(), )
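
The new install_requires entry makes scipy a hard runtime dependency, because the CPU fallback in torch_cluster/knn.py below builds on scipy.spatial.cKDTree. A quick smoke test like the following (an illustrative sketch, not part of this commit) confirms that the dependency resolves after installation:

import scipy.spatial

# Build a tiny k-d tree and query it; this only exercises the
# scipy.spatial.cKDTree API that the new CPU code path relies on.
tree = scipy.spatial.cKDTree([[0.0, 0.0], [1.0, 1.0]])
dist, idx = tree.query([[0.1, 0.1]], k=1)
print(dist, idx)  # small distance to the first point, index 0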

test/test_knn.py

Lines changed: 6 additions & 6 deletions
@@ -4,13 +4,9 @@
 import torch
 from torch_cluster import knn
 
-from .utils import tensor
+from .utils import grad_dtypes, devices, tensor
 
-devices = [torch.device('cuda')]
-grad_dtypes = [torch.float]
 
-
-@pytest.mark.skipif(not torch.cuda.is_available(), reason='CUDA not available')
 @pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
 def test_radius(dtype, device):
     x = tensor([
@@ -32,4 +28,8 @@ def test_radius(dtype, device):
     batch_y = tensor([0, 1], torch.long, device)
 
     out = knn(x, y, 2, batch_x, batch_y)
-    assert out.tolist() == [[0, 0, 1, 1], [2, 3, 4, 5]]
+    assert out[0].tolist() == [0, 0, 1, 1]
+    col = out[1][:2].tolist()
+    assert col == [2, 3] or col == [3, 2]
+    col = out[1][2:].tolist()
+    assert col == [4, 5] or col == [5, 4]
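
The assertions are relaxed because the KD-tree based CPU path may return equidistant neighbors in either order, so only the set of neighbors per query point is deterministic. A self-contained way to express the same order-insensitive check (the helper name assert_knn_neighbors is hypothetical, not part of the test suite) could look like this:

import torch

def assert_knn_neighbors(out, expected):
    # `out` is a (2, E) assign_index tensor as returned by knn;
    # `expected` maps each query index to its allowed set of neighbors.
    found = {}
    for r, c in zip(out[0].tolist(), out[1].tolist()):
        found.setdefault(r, set()).add(c)
    assert found == expected

# The values asserted in the updated test, with cols in one valid order:
out = torch.tensor([[0, 0, 1, 1], [3, 2, 5, 4]])
assert_knn_neighbors(out, {0: {2, 3}, 1: {4, 5}})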

torch_cluster/knn.py

Lines changed: 23 additions & 4 deletions
@@ -1,4 +1,5 @@
 import torch
+import scipy.spatial
 
 if torch.cuda.is_available():
     import knn_cuda
@@ -38,17 +39,35 @@ def knn(x, y, k, batch_x=None, batch_y=None):
     x = x.view(-1, 1) if x.dim() == 1 else x
     y = y.view(-1, 1) if y.dim() == 1 else y
 
-    assert x.is_cuda
     assert x.dim() == 2 and batch_x.dim() == 1
     assert y.dim() == 2 and batch_y.dim() == 1
     assert x.size(1) == y.size(1)
     assert x.size(0) == batch_x.size(0)
     assert y.size(0) == batch_y.size(0)
 
-    op = knn_cuda.knn if x.is_cuda else None
-    assign_index = op(x, y, k, batch_x, batch_y)
+    if x.is_cuda:
+        assign_index = knn_cuda.knn(x, y, k, batch_x, batch_y)
+        return assign_index
 
-    return assign_index
+    # Rescale x and y.
+    min_xy = min(x.min().item(), y.min().item())
+    x, y = x - min_xy, y - min_xy
+
+    max_xy = max(x.max().item(), y.max().item())
+    x, y, = x / max_xy, y / max_xy
+
+    # Concat batch/features to ensure no cross-links between examples exist.
+    x = torch.cat([x, 2 * x.size(1) * batch_x.view(-1, 1).to(x.dtype)], dim=-1)
+    y = torch.cat([y, 2 * y.size(1) * batch_y.view(-1, 1).to(y.dtype)], dim=-1)
+
+    tree = scipy.spatial.cKDTree(x)
+    dist, col = tree.query(y, k=k, distance_upper_bound=x.size(1))
+    dist, col = torch.tensor(dist), torch.tensor(col)
+    row = torch.arange(col.size(0)).view(-1, 1).repeat(1, k)
+    mask = 1 - torch.isinf(dist).view(-1)
+    row, col = row.view(-1)[mask], col.view(-1)[mask]
+
+    return torch.stack([row, col], dim=0)
 
 
 def knn_graph(x, k, batch=None, loop=False):
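
Two details of the CPU path deserve a note. After shifting and scaling all coordinates into [0, 1], the largest possible distance between two points of the same example is sqrt(D), where D is the feature dimension. Each point then gets an extra coordinate of 2 * D * batch_id, which pushes points of different examples far enough apart that the distance_upper_bound passed to cKDTree.query filters them out; such filtered (or missing) neighbors come back with infinite distance, and the mask drops them from the edge list. The standalone sketch below restates this logic outside the library: the function name knn_cpu is illustrative, and the commit's `1 - torch.isinf(...)` is written as `~torch.isinf(...)` so the mask stays valid on newer PyTorch versions where it is boolean.

import torch
import scipy.spatial


def knn_cpu(x, y, k, batch_x, batch_y):
    # Illustrative restatement of the CPU path added in this commit.
    # Rescale both point sets jointly into [0, 1].
    min_xy = min(x.min().item(), y.min().item())
    x, y = x - min_xy, y - min_xy
    max_xy = max(x.max().item(), y.max().item())
    x, y = x / max_xy, y / max_xy

    # Append 2 * D * batch_id as an extra coordinate so that points of
    # different examples are farther apart than any same-example pair.
    d = x.size(1)
    x = torch.cat([x, 2 * d * batch_x.view(-1, 1).to(x.dtype)], dim=-1)
    y = torch.cat([y, 2 * d * batch_y.view(-1, 1).to(y.dtype)], dim=-1)

    # Query the k-d tree built on x; candidates beyond the upper bound
    # (i.e. from another example) are reported with dist == inf.
    tree = scipy.spatial.cKDTree(x)
    dist, col = tree.query(y, k=k, distance_upper_bound=x.size(1))
    dist, col = torch.tensor(dist), torch.tensor(col)

    # Each query point y[i] contributes up to k edges (i, col[i, j]).
    row = torch.arange(col.size(0)).view(-1, 1).repeat(1, k)
    mask = ~torch.isinf(dist).view(-1)
    row, col = row.view(-1)[mask], col.view(-1)[mask]
    return torch.stack([row, col], dim=0)


# Toy usage: two examples with two 1-D points each; every query point
# should only be linked to points from its own example.
x = torch.tensor([[0.0], [1.0], [10.0], [11.0]])
batch_x = torch.tensor([0, 0, 1, 1])
y = torch.tensor([[0.4], [10.6]])
batch_y = torch.tensor([0, 1])
print(knn_cpu(x, y, k=2, batch_x=batch_x, batch_y=batch_y))
# edges: query 0 -> points {0, 1}, query 1 -> points {2, 3}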
