@@ -35,6 +35,11 @@ def test_radius(dtype, device):
3535 assert to_set (edge_index ) == set ([(0 , 0 ), (0 , 1 ), (0 , 2 ), (0 , 3 ), (1 , 1 ),
3636 (1 , 2 ), (1 , 5 ), (1 , 6 )])
3737
38+ jit = torch .jit .script (radius )
39+ edge_index = jit (x , y , 2 , max_num_neighbors = 4 )
40+ assert to_set (edge_index ) == set ([(0 , 0 ), (0 , 1 ), (0 , 2 ), (0 , 3 ), (1 , 1 ),
41+ (1 , 2 ), (1 , 5 ), (1 , 6 )])
42+
3843 edge_index = radius (x , y , 2 , batch_x , batch_y , max_num_neighbors = 4 )
3944 assert to_set (edge_index ) == set ([(0 , 0 ), (0 , 1 ), (0 , 2 ), (0 , 3 ), (1 , 5 ),
4045 (1 , 6 )])
@@ -64,12 +69,20 @@ def test_radius_graph(dtype, device):
6469 assert to_set (edge_index ) == set ([(1 , 0 ), (3 , 0 ), (0 , 1 ), (2 , 1 ), (1 , 2 ),
6570 (3 , 2 ), (0 , 3 ), (2 , 3 )])
6671
72+ jit = torch .jit .script (radius_graph )
73+ edge_index = jit (x , r = 2.5 , flow = 'source_to_target' )
74+ assert to_set (edge_index ) == set ([(1 , 0 ), (3 , 0 ), (0 , 1 ), (2 , 1 ), (1 , 2 ),
75+ (3 , 2 ), (0 , 3 ), (2 , 3 )])
76+
6777
6878@pytest .mark .parametrize ('dtype,device' , product ([torch .float ], devices ))
6979def test_radius_graph_large (dtype , device ):
7080 x = torch .randn (1000 , 3 , dtype = dtype , device = device )
7181
72- edge_index = radius_graph (x , r = 0.5 , flow = 'target_to_source' , loop = True ,
82+ edge_index = radius_graph (x ,
83+ r = 0.5 ,
84+ flow = 'target_to_source' ,
85+ loop = True ,
7386 max_num_neighbors = 2000 )
7487
7588 tree = scipy .spatial .cKDTree (x .cpu ().numpy ())
0 commit comments