1- from pytorch_lightning .callbacks import Callback
2- from tests .base .boring_model import BoringModel
3- from pytorch_lightning import accelerators , Trainer
4- from pytorch_lightning .plugins .ddp_plugin import DDPPlugin
5- import pytest
61import os
72from unittest import mock
83
4+ import pytest
5+ from pytorch_lightning import Trainer , accelerators
6+ from pytorch_lightning .callbacks import Callback
7+ from pytorch_lightning .plugins .ddp_plugin import DDPPlugin
8+ from tests .base .boring_model import BoringModel
99
10- @mock .patch .dict (os .environ , {
11- "CUDA_VISIBLE_DEVICES" : "0,1" ,
12- "SLURM_NTASKS" : "2" ,
13- "SLURM_JOB_NAME" : "SOME_NAME" ,
14- "SLURM_NODEID" : "0" ,
15- "LOCAL_RANK" : "0" ,
16- "SLURM_LOCALID" : "0"
17- })
18- @mock .patch ('torch.cuda.device_count' , return_value = 2 )
19- @pytest .mark .parametrize (['ddp_backend' , 'gpus' , 'num_processes' ],
20- [('ddp_cpu' , None , None ), ('ddp' , 2 , 0 ), ('ddp2' , 2 , 0 ), ('ddp_spawn' , 2 , 0 )])
21- def test_ddp_choice_default_ddp_cpu (tmpdir , ddp_backend , gpus , num_processes ):
2210
11+ @mock .patch .dict (
12+ os .environ ,
13+ {
14+ "CUDA_VISIBLE_DEVICES" : "0,1" ,
15+ "SLURM_NTASKS" : "2" ,
16+ "SLURM_JOB_NAME" : "SOME_NAME" ,
17+ "SLURM_NODEID" : "0" ,
18+ "LOCAL_RANK" : "0" ,
19+ "SLURM_LOCALID" : "0" ,
20+ },
21+ )
22+ @mock .patch ("torch.cuda.device_count" , return_value = 2 )
23+ @pytest .mark .parametrize (
24+ ["ddp_backend" , "gpus" , "num_processes" ],
25+ [("ddp_cpu" , None , None ), ("ddp" , 2 , 0 ), ("ddp2" , 2 , 0 ), ("ddp_spawn" , 2 , 0 )],
26+ )
27+ def test_ddp_choice_default_ddp_cpu (tmpdir , ddp_backend , gpus , num_processes ):
2328 class CB (Callback ):
2429 def on_fit_start (self , trainer , pl_module ):
2530 assert isinstance (trainer .accelerator_backend .ddp_plugin , DDPPlugin )
@@ -31,24 +36,29 @@ def on_fit_start(self, trainer, pl_module):
3136 gpus = gpus ,
3237 num_processes = num_processes ,
3338 distributed_backend = ddp_backend ,
34- callbacks = [CB ()]
39+ callbacks = [CB ()],
3540 )
3641
3742 with pytest .raises (SystemExit ):
3843 trainer .fit (model )
3944
4045
41- @mock .patch .dict (os .environ , {
42- "CUDA_VISIBLE_DEVICES" : "0,1" ,
43- "SLURM_NTASKS" : "2" ,
44- "SLURM_JOB_NAME" : "SOME_NAME" ,
45- "SLURM_NODEID" : "0" ,
46- "LOCAL_RANK" : "0" ,
47- "SLURM_LOCALID" : "0"
48- })
49- @mock .patch ('torch.cuda.device_count' , return_value = 2 )
50- @pytest .mark .parametrize (['ddp_backend' , 'gpus' , 'num_processes' ],
51- [('ddp_cpu' , None , None ), ('ddp' , 2 , 0 ), ('ddp2' , 2 , 0 ), ('ddp_spawn' , 2 , 0 )])
46+ @mock .patch .dict (
47+ os .environ ,
48+ {
49+ "CUDA_VISIBLE_DEVICES" : "0,1" ,
50+ "SLURM_NTASKS" : "2" ,
51+ "SLURM_JOB_NAME" : "SOME_NAME" ,
52+ "SLURM_NODEID" : "0" ,
53+ "LOCAL_RANK" : "0" ,
54+ "SLURM_LOCALID" : "0" ,
55+ },
56+ )
57+ @mock .patch ("torch.cuda.device_count" , return_value = 2 )
58+ @pytest .mark .parametrize (
59+ ["ddp_backend" , "gpus" , "num_processes" ],
60+ [("ddp_cpu" , None , None ), ("ddp" , 2 , 0 ), ("ddp2" , 2 , 0 ), ("ddp_spawn" , 2 , 0 )],
61+ )
5262def test_ddp_choice_custom_ddp_cpu (tmpdir , ddp_backend , gpus , num_processes ):
5363 class MyDDP (DDPPlugin ):
5464 pass
@@ -65,7 +75,48 @@ def on_fit_start(self, trainer, pl_module):
6575 num_processes = num_processes ,
6676 distributed_backend = ddp_backend ,
6777 plugins = [MyDDP ()],
68- callbacks = [CB ()]
78+ callbacks = [CB ()],
79+ )
80+
81+ with pytest .raises (SystemExit ):
82+ trainer .fit (model )
83+
84+
85+ @mock .patch .dict (
86+ os .environ ,
87+ {
88+ "CUDA_VISIBLE_DEVICES" : "0,1" ,
89+ "SLURM_NTASKS" : "2" ,
90+ "SLURM_JOB_NAME" : "SOME_NAME" ,
91+ "SLURM_NODEID" : "0" ,
92+ "LOCAL_RANK" : "0" ,
93+ "SLURM_LOCALID" : "0" ,
94+ },
95+ )
96+ @mock .patch ("torch.cuda.device_count" , return_value = 2 )
97+ @pytest .mark .parametrize (
98+ ["ddp_backend" , "gpus" , "num_processes" ],
99+ [("ddp_cpu" , None , None ), ("ddp" , 2 , 0 ), ("ddp2" , 2 , 0 ), ("ddp_spawn" , 2 , 0 )],
100+ )
101+ def test_ddp_choice_custom_ddp_cpu_custom_args (
102+ tmpdir , ddp_backend , gpus , num_processes
103+ ):
104+ class MyDDP (DDPPlugin ):
105+ pass
106+
107+ class CB (Callback ):
108+ def on_fit_start (self , trainer , pl_module ):
109+ assert isinstance (trainer .accelerator_backend .ddp_plugin , MyDDP )
110+ raise SystemExit ()
111+
112+ model = BoringModel ()
113+ trainer = Trainer (
114+ fast_dev_run = True ,
115+ gpus = gpus ,
116+ num_processes = num_processes ,
117+ distributed_backend = ddp_backend ,
118+ plugins = [MyDDP (broadcast_buffers = False , find_unused_parameters = True )],
119+ callbacks = [CB ()],
69120 )
70121
71122 with pytest .raises (SystemExit ):
0 commit comments