44import torch
55from torch_cluster import fps
66
7- from .utils import tensor , grad_dtypes
7+ from .utils import grad_dtypes , devices , tensor
88
9- devices = [torch .device ('cuda' )]
109
11-
12- @pytest .mark .skipif (not torch .cuda .is_available (), reason = 'CUDA not available' )
1310@pytest .mark .parametrize ('dtype,device' , product (grad_dtypes , devices ))
1411def test_fps (dtype , device ):
1512 x = tensor ([
@@ -26,25 +23,3 @@ def test_fps(dtype, device):
2623
2724 out = fps (x , batch , ratio = 0.5 , random_start = False )
2825 assert out .tolist () == [0 , 2 , 4 , 6 ]
29-
30-
31- @pytest .mark .skipif (not torch .cuda .is_available (), reason = 'CUDA not available' )
32- @pytest .mark .parametrize ('dtype,device' , product (grad_dtypes , devices ))
33- def test_fps_speed (dtype , device ):
34- return
35- batch_size , num_nodes = 100 , 10000
36- x = torch .randn ((batch_size * num_nodes , 3 ), dtype = dtype , device = device )
37- batch = torch .arange (batch_size , dtype = torch .long , device = device )
38- batch = batch .view (- 1 , 1 ).repeat (1 , num_nodes ).view (- 1 )
39-
40- out = fps (x , batch , ratio = 0.5 , random_start = True )
41- assert out .size (0 ) == batch_size * num_nodes * 0.5
42- assert out .min ().item () >= 0 and out .max ().item () < batch_size * num_nodes
43-
44- batch_size , num_nodes , dim = 100 , 300 , 128
45- x = torch .randn ((batch_size * num_nodes , dim ), dtype = dtype , device = device )
46- batch = torch .arange (batch_size , dtype = torch .long , device = device )
47- batch = batch .view (- 1 , 1 ).repeat (1 , num_nodes ).view (- 1 )
48- out = fps (x , batch , ratio = 0.5 , random_start = True )
49- assert out .size (0 ) == batch_size * num_nodes * 0.5
50- assert out .min ().item () >= 0 and out .max ().item () < batch_size * num_nodes
0 commit comments