Skip to content

Commit 21208fc

Browse files
committed
nearest cpu
1 parent 07c92be commit 21208fc

File tree

2 files changed

+16
-8
lines changed

2 files changed

+16
-8
lines changed

test/test_nearest.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,9 @@
44
import torch
55
from torch_cluster import nearest
66

7-
from .utils import tensor, grad_dtypes
7+
from .utils import grad_dtypes, devices, tensor
88

9-
devices = [torch.device('cuda')]
109

11-
12-
@pytest.mark.skipif(not torch.cuda.is_available(), reason='CUDA not available')
1310
@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
1411
def test_nearest(dtype, device):
1512
x = tensor([

torch_cluster/nearest.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import torch
2+
import scipy.cluster
23

34
if torch.cuda.is_available():
45
import nearest_cuda
@@ -35,14 +36,24 @@ def nearest(x, y, batch_x=None, batch_y=None):
3536
x = x.view(-1, 1) if x.dim() == 1 else x
3637
y = y.view(-1, 1) if y.dim() == 1 else y
3738

38-
assert x.is_cuda
3939
assert x.dim() == 2 and batch_x.dim() == 1
4040
assert y.dim() == 2 and batch_y.dim() == 1
4141
assert x.size(1) == y.size(1)
4242
assert x.size(0) == batch_x.size(0)
4343
assert y.size(0) == batch_y.size(0)
4444

45-
op = nearest_cuda.nearest if x.is_cuda else None
46-
out = op(x, y, batch_x, batch_y)
45+
if x.is_cuda:
46+
return nearest_cuda.nearest(x, y, batch_x, batch_y)
4747

48-
return out
48+
# Rescale x and y.
49+
min_xy = min(x.min().item(), y.min().item())
50+
x, y = x - min_xy, y - min_xy
51+
52+
max_xy = max(x.max().item(), y.max().item())
53+
x, y, = x / max_xy, y / max_xy
54+
55+
# Concat batch/features to ensure no cross-links between examples exist.
56+
x = torch.cat([x, 2 * x.size(1) * batch_x.view(-1, 1).to(x.dtype)], dim=-1)
57+
y = torch.cat([y, 2 * y.size(1) * batch_y.view(-1, 1).to(y.dtype)], dim=-1)
58+
59+
return torch.from_numpy(scipy.cluster.vq.vq(x, y)[0]).to(torch.long)

0 commit comments

Comments
 (0)