Skip to content

Commit 07c92be

Browse files
committed
test graph gen
1 parent 9725bf7 commit 07c92be

File tree

2 files changed

+40
-9
lines changed

2 files changed

+40
-9
lines changed

test/test_knn.py

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,13 @@
22

33
import pytest
44
import torch
5-
from torch_cluster import knn
5+
from torch_cluster import knn, knn_graph
66

77
from .utils import grad_dtypes, devices, tensor
88

99

1010
@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
11-
def test_radius(dtype, device):
11+
def test_knn(dtype, device):
1212
x = tensor([
1313
[-1, -1],
1414
[-1, +1],
@@ -27,9 +27,24 @@ def test_radius(dtype, device):
2727
batch_x = tensor([0, 0, 0, 0, 1, 1, 1, 1], torch.long, device)
2828
batch_y = tensor([0, 1], torch.long, device)
2929

30-
out = knn(x, y, 2, batch_x, batch_y)
31-
assert out[0].tolist() == [0, 0, 1, 1]
32-
col = out[1][:2].tolist()
33-
assert col == [2, 3] or col == [3, 2]
34-
col = out[1][2:].tolist()
35-
assert col == [4, 5] or col == [5, 4]
30+
row, col = knn(x, y, 2, batch_x, batch_y)
31+
col = col.view(-1, 2).sort(dim=-1)[0].view(-1)
32+
33+
assert row.tolist() == [0, 0, 1, 1]
34+
assert col.tolist() == [2, 3, 4, 5]
35+
36+
37+
@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
38+
def test_knn_graph(dtype, device):
39+
x = tensor([
40+
[-1, -1],
41+
[-1, +1],
42+
[+1, +1],
43+
[+1, -1],
44+
], dtype, device)
45+
46+
row, col = knn_graph(x, k=2)
47+
col = col.view(-1, 2).sort(dim=-1)[0].view(-1)
48+
49+
assert row.tolist() == [0, 0, 1, 1, 2, 2, 3, 3]
50+
assert col.tolist() == [1, 3, 0, 2, 1, 3, 0, 2]

test/test_radius.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import pytest
44
import torch
5-
from torch_cluster import radius
5+
from torch_cluster import radius, radius_graph
66

77
from .utils import grad_dtypes, devices, tensor
88

@@ -28,4 +28,20 @@ def test_radius(dtype, device):
2828
batch_y = tensor([0, 1], torch.long, device)
2929

3030
out = radius(x, y, 2, batch_x, batch_y, max_num_neighbors=4)
31+
3132
assert out.tolist() == [[0, 0, 0, 0, 1, 1], [0, 1, 2, 3, 5, 6]]
33+
34+
35+
@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
36+
def test_radius_graph(dtype, device):
37+
x = tensor([
38+
[-1, -1],
39+
[-1, +1],
40+
[+1, +1],
41+
[+1, -1],
42+
], dtype, device)
43+
44+
row, col = radius_graph(x, r=2)
45+
46+
assert row.tolist() == [0, 0, 1, 1, 2, 2, 3, 3]
47+
assert col.tolist() == [1, 3, 0, 2, 1, 3, 0, 2]

0 commit comments

Comments
 (0)