22
33import pytest
44import torch
5- from torch_cluster import knn
5+ from torch_cluster import knn , knn_graph
66
77from .utils import grad_dtypes , devices , tensor
88
99
1010@pytest .mark .parametrize ('dtype,device' , product (grad_dtypes , devices ))
11- def test_radius (dtype , device ):
11+ def test_knn (dtype , device ):
1212 x = tensor ([
1313 [- 1 , - 1 ],
1414 [- 1 , + 1 ],
@@ -27,9 +27,24 @@ def test_radius(dtype, device):
2727 batch_x = tensor ([0 , 0 , 0 , 0 , 1 , 1 , 1 , 1 ], torch .long , device )
2828 batch_y = tensor ([0 , 1 ], torch .long , device )
2929
30- out = knn (x , y , 2 , batch_x , batch_y )
31- assert out [0 ].tolist () == [0 , 0 , 1 , 1 ]
32- col = out [1 ][:2 ].tolist ()
33- assert col == [2 , 3 ] or col == [3 , 2 ]
34- col = out [1 ][2 :].tolist ()
35- assert col == [4 , 5 ] or col == [5 , 4 ]
30+ row , col = knn (x , y , 2 , batch_x , batch_y )
31+ col = col .view (- 1 , 2 ).sort (dim = - 1 )[0 ].view (- 1 )
32+
33+ assert row .tolist () == [0 , 0 , 1 , 1 ]
34+ assert col .tolist () == [2 , 3 , 4 , 5 ]
35+
36+
37+ @pytest .mark .parametrize ('dtype,device' , product (grad_dtypes , devices ))
38+ def test_knn_graph (dtype , device ):
39+ x = tensor ([
40+ [- 1 , - 1 ],
41+ [- 1 , + 1 ],
42+ [+ 1 , + 1 ],
43+ [+ 1 , - 1 ],
44+ ], dtype , device )
45+
46+ row , col = knn_graph (x , k = 2 )
47+ col = col .view (- 1 , 2 ).sort (dim = - 1 )[0 ].view (- 1 )
48+
49+ assert row .tolist () == [0 , 0 , 1 , 1 , 2 , 2 , 3 , 3 ]
50+ assert col .tolist () == [1 , 3 , 0 , 2 , 1 , 3 , 0 , 2 ]
0 commit comments