|
1 | 1 | import torch |
| 2 | +import scipy.spatial |
2 | 3 |
|
3 | 4 | if torch.cuda.is_available(): |
4 | 5 | import radius_cuda |
@@ -40,17 +41,25 @@ def radius(x, y, r, batch_x=None, batch_y=None, max_num_neighbors=32): |
40 | 41 | x = x.view(-1, 1) if x.dim() == 1 else x |
41 | 42 | y = y.view(-1, 1) if y.dim() == 1 else y |
42 | 43 |
|
43 | | - assert x.is_cuda |
44 | 44 | assert x.dim() == 2 and batch_x.dim() == 1 |
45 | 45 | assert y.dim() == 2 and batch_y.dim() == 1 |
46 | 46 | assert x.size(1) == y.size(1) |
47 | 47 | assert x.size(0) == batch_x.size(0) |
48 | 48 | assert y.size(0) == batch_y.size(0) |
49 | 49 |
|
50 | | - op = radius_cuda.radius if x.is_cuda else None |
51 | | - assign_index = op(x, y, r, batch_x, batch_y, max_num_neighbors) |
| 50 | + if x.is_cuda: |
| 51 | + return radius_cuda.radius(x, y, r, batch_x, batch_y, max_num_neighbors) |
52 | 52 |
|
53 | | - return assign_index |
| 53 | + x = torch.cat([x, 2 * r * batch_x.view(-1, 1).to(x.dtype)], dim=-1) |
| 54 | + y = torch.cat([y, 2 * r * batch_y.view(-1, 1).to(y.dtype)], dim=-1) |
| 55 | + |
| 56 | + tree = scipy.spatial.cKDTree(x) |
| 57 | + col = tree.query_ball_point(y, r) |
| 58 | + col = [torch.tensor(c) for c in col] |
| 59 | + row = [torch.full_like(c, i) for i, c in enumerate(col)] |
| 60 | + row, col = torch.cat(row, dim=0), torch.cat(col, dim=0) |
| 61 | + |
| 62 | + return torch.stack([row, col], dim=0) |
54 | 63 |
|
55 | 64 |
|
56 | 65 | def radius_graph(x, r, batch=None, loop=False, max_num_neighbors=32): |
|
0 commit comments