@@ -1,4 +1,5 @@
 import torch
+import scipy.spatial
 
 if torch.cuda.is_available():
     import knn_cuda
@@ -38,17 +39,35 @@ def knn(x, y, k, batch_x=None, batch_y=None):
     x = x.view(-1, 1) if x.dim() == 1 else x
     y = y.view(-1, 1) if y.dim() == 1 else y
 
-    assert x.is_cuda
     assert x.dim() == 2 and batch_x.dim() == 1
     assert y.dim() == 2 and batch_y.dim() == 1
     assert x.size(1) == y.size(1)
     assert x.size(0) == batch_x.size(0)
     assert y.size(0) == batch_y.size(0)
 
-    op = knn_cuda.knn if x.is_cuda else None
-    assign_index = op(x, y, k, batch_x, batch_y)
+    if x.is_cuda:
+        assign_index = knn_cuda.knn(x, y, k, batch_x, batch_y)
+        return assign_index
 
-    return assign_index
+    # Rescale x and y to the unit interval.
+    min_xy = min(x.min().item(), y.min().item())
+    x, y = x - min_xy, y - min_xy
+
+    max_xy = max(x.max().item(), y.max().item())
+    x, y = x / max_xy, y / max_xy
+
+    # Concat batch indices as extra features to prevent cross-batch links.
+    x = torch.cat([x, 2 * x.size(1) * batch_x.view(-1, 1).to(x.dtype)], dim=-1)
+    y = torch.cat([y, 2 * y.size(1) * batch_y.view(-1, 1).to(y.dtype)], dim=-1)
+
+    tree = scipy.spatial.cKDTree(x)
+    dist, col = tree.query(y, k=k, distance_upper_bound=x.size(1))
+    dist, col = torch.tensor(dist), torch.tensor(col)
+    row = torch.arange(col.size(0)).view(-1, 1).repeat(1, k)
+    mask = ~torch.isinf(dist).view(-1)
+    row, col = row.view(-1)[mask], col.view(-1)[mask]
+
+    return torch.stack([row, col], dim=0)
 
 
 def knn_graph(x, k, batch=None, loop=False):
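The CPU fallback added here uses a standard trick for batched nearest-neighbor search: after shifting and rescaling all coordinates into [0, 1], each point gets its batch index, scaled by twice the feature dimension, appended as an extra coordinate. Same-example points then stay within sqrt(dim) of each other, while points from different examples are pushed farther apart than the distance_upper_bound, so the kd-tree can never link them. A minimal sketch of the offset step (values and names are made up for illustration):

    import torch

    # Two 1-D examples in one batch, already rescaled to [0, 1].
    x = torch.tensor([[0.0], [1.0], [0.5]])
    batch_x = torch.tensor([0, 0, 1])  # the third point belongs to example 1

    # Append the batch index, scaled by 2 * dim, as an extra coordinate.
    x_aug = torch.cat(
        [x, 2 * x.size(1) * batch_x.view(-1, 1).to(x.dtype)], dim=-1)
    # x_aug: [[0.0, 0.0], [1.0, 0.0], [0.5, 2.0]].
    # Within example 0, the largest distance is 1.0; any distance to the
    # point of example 1 exceeds 2.0, so a distance_upper_bound of
    # x_aug.size(1) = 2 drops every cross-example candidate.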
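With the patch applied, the same entry point should also work for CPU tensors by falling through to the scipy code path. A small usage sketch, assuming the function is importable as `knn` from this package (the import path and tensor values are assumptions, not from the patch):

    import torch
    from torch_cluster import knn  # assumed import path

    x = torch.rand(8, 3)  # 8 reference points in 3-D
    y = torch.rand(4, 3)  # 4 query points
    batch_x = torch.zeros(8, dtype=torch.long)
    batch_y = torch.zeros(4, dtype=torch.long)

    assign_index = knn(x, y, 2, batch_x, batch_y)
    # assign_index has shape [2, num_links]: the first row holds indices
    # into y (the queries), the second row the matched neighbor indices
    # into x.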