Skip to content

Commit 9725bf7

Browse files
committed
radius cpu version
1 parent 4e6cb0c commit 9725bf7

File tree

3 files changed

+15
-10
lines changed

3 files changed

+15
-10
lines changed

test/test_radius.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,9 @@
44
import torch
55
from torch_cluster import radius
66

7-
from .utils import tensor, grad_dtypes
7+
from .utils import grad_dtypes, devices, tensor
88

9-
devices = [torch.device('cuda')]
109

11-
12-
@pytest.mark.skipif(not torch.cuda.is_available(), reason='CUDA not available')
1310
@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
1411
def test_radius(dtype, device):
1512
x = tensor([

torch_cluster/knn.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,7 @@ def knn(x, y, k, batch_x=None, batch_y=None):
4646
assert y.size(0) == batch_y.size(0)
4747

4848
if x.is_cuda:
49-
assign_index = knn_cuda.knn(x, y, k, batch_x, batch_y)
50-
return assign_index
49+
return knn_cuda.knn(x, y, k, batch_x, batch_y)
5150

5251
# Rescale x and y.
5352
min_xy = min(x.min().item(), y.min().item())

torch_cluster/radius.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import torch
2+
import scipy.spatial
23

34
if torch.cuda.is_available():
45
import radius_cuda
@@ -40,17 +41,25 @@ def radius(x, y, r, batch_x=None, batch_y=None, max_num_neighbors=32):
4041
x = x.view(-1, 1) if x.dim() == 1 else x
4142
y = y.view(-1, 1) if y.dim() == 1 else y
4243

43-
assert x.is_cuda
4444
assert x.dim() == 2 and batch_x.dim() == 1
4545
assert y.dim() == 2 and batch_y.dim() == 1
4646
assert x.size(1) == y.size(1)
4747
assert x.size(0) == batch_x.size(0)
4848
assert y.size(0) == batch_y.size(0)
4949

50-
op = radius_cuda.radius if x.is_cuda else None
51-
assign_index = op(x, y, r, batch_x, batch_y, max_num_neighbors)
50+
if x.is_cuda:
51+
return radius_cuda.radius(x, y, r, batch_x, batch_y, max_num_neighbors)
5252

53-
return assign_index
53+
x = torch.cat([x, 2 * r * batch_x.view(-1, 1).to(x.dtype)], dim=-1)
54+
y = torch.cat([y, 2 * r * batch_y.view(-1, 1).to(y.dtype)], dim=-1)
55+
56+
tree = scipy.spatial.cKDTree(x)
57+
col = tree.query_ball_point(y, r)
58+
col = [torch.tensor(c) for c in col]
59+
row = [torch.full_like(c, i) for i, c in enumerate(col)]
60+
row, col = torch.cat(row, dim=0), torch.cat(col, dim=0)
61+
62+
return torch.stack([row, col], dim=0)
5463

5564

5665
def radius_graph(x, r, batch=None, loop=False, max_num_neighbors=32):

0 commit comments

Comments
 (0)