|
1 | 1 | import torch |
| 2 | +import scipy.cluster |
2 | 3 |
|
3 | 4 | if torch.cuda.is_available(): |
4 | 5 | import nearest_cuda |
@@ -35,14 +36,24 @@ def nearest(x, y, batch_x=None, batch_y=None): |
35 | 36 | x = x.view(-1, 1) if x.dim() == 1 else x |
36 | 37 | y = y.view(-1, 1) if y.dim() == 1 else y |
37 | 38 |
|
38 | | - assert x.is_cuda |
39 | 39 | assert x.dim() == 2 and batch_x.dim() == 1 |
40 | 40 | assert y.dim() == 2 and batch_y.dim() == 1 |
41 | 41 | assert x.size(1) == y.size(1) |
42 | 42 | assert x.size(0) == batch_x.size(0) |
43 | 43 | assert y.size(0) == batch_y.size(0) |
44 | 44 |
|
45 | | - op = nearest_cuda.nearest if x.is_cuda else None |
46 | | - out = op(x, y, batch_x, batch_y) |
| 45 | + if x.is_cuda: |
| 46 | + return nearest_cuda.nearest(x, y, batch_x, batch_y) |
47 | 47 |
|
48 | | - return out |
| 48 | + # Rescale x and y. |
| 49 | + min_xy = min(x.min().item(), y.min().item()) |
| 50 | + x, y = x - min_xy, y - min_xy |
| 51 | + |
| 52 | + max_xy = max(x.max().item(), y.max().item()) |
| 53 | + x, y, = x / max_xy, y / max_xy |
| 54 | + |
| 55 | + # Concat batch/features to ensure no cross-links between examples exist. |
| 56 | + x = torch.cat([x, 2 * x.size(1) * batch_x.view(-1, 1).to(x.dtype)], dim=-1) |
| 57 | + y = torch.cat([y, 2 * y.size(1) * batch_y.view(-1, 1).to(y.dtype)], dim=-1) |
| 58 | + |
| 59 | + return torch.from_numpy(scipy.cluster.vq.vq(x, y)[0]).to(torch.long) |
0 commit comments