
Commit 6878f3b

Authored by ananthsub and SeanNaren
Enable DDP Plugin to pass through args to LightningDistributedDataParallel (#4382)
* Update ddp_plugin.py
* Update ddp_plugin.py
* Update ddp_plugin.py
* Update test_ddp_plugin.py
* Update pytorch_lightning/plugins/ddp_plugin.py
* Update pytorch_lightning/plugins/ddp_plugin.py
* Fixed imports, make ddp_kwargs protected

Co-authored-by: SeanNaren <[email protected]>
1 parent c50c225 commit 6878f3b
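In short, `DDPPlugin` now accepts keyword arguments and forwards them to `LightningDistributedDataParallel`, which in turn forwards them to `DistributedDataParallel`. A minimal usage sketch, assuming a script already set up for a multi-GPU DDP run (the argument values below are illustrative, not part of this commit):

    from pytorch_lightning import Trainer
    from pytorch_lightning.plugins.ddp_plugin import DDPPlugin

    # kwargs given to the plugin are forwarded by configure_ddp();
    # find_unused_parameters still defaults to True if left unset
    my_ddp = DDPPlugin(broadcast_buffers=False, find_unused_parameters=True)

    trainer = Trainer(distributed_backend="ddp", gpus=2, plugins=[my_ddp])
    # trainer.fit(model)  # model: any LightningModule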

File tree: 2 files changed (+103 -35 lines changed)

pytorch_lightning/plugins/ddp_plugin.py
tests/plugins/test_ddp_plugin.py


pytorch_lightning/plugins/ddp_plugin.py

Lines changed: 22 additions & 5 deletions
@@ -1,12 +1,16 @@
-from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel
+from typing import List, Dict, Any
+
 from pytorch_lightning.core.lightning import LightningModule
-from typing import List
+from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel
 
 
 class DDPPlugin(object):
     """
     Plugin to link a custom ddp implementation to any arbitrary accelerator.
 
+    This plugin forwards all constructor arguments to `LightningDistributedDataParallel`,
+    which in turn forwards all args to `DistributedDataParallel`.
+
     Example::
 
         class MyDDP(DDPPlugin):
@@ -17,11 +21,16 @@ def configure_ddp(self, model, device_ids):
 
         my_ddp = MyDDP()
         trainer = Trainer(accelerator='ddp_x', plugins=[my_ddp])
-
     """
 
-    def configure_ddp(self, model: LightningModule, device_ids: List[int]) -> LightningDistributedDataParallel:
+    def __init__(self, **kwargs):
+        self._ddp_kwargs: Dict[str, Any] = kwargs
+
+    def configure_ddp(
+        self, model: LightningModule, device_ids: List[int]
+    ) -> LightningDistributedDataParallel:
         """
+        Pass through all customizations from constructor to `LightningDistributedDataParallel`.
         Override to define a custom DDP implementation.
 
         .. note:: Only requirement is that your DDP implementation subclasses LightningDistributedDataParallel
@@ -43,5 +52,13 @@ def configure_ddp(self, model, device_ids):
             the model wrapped in LightningDistributedDataParallel
 
         """
-        model = LightningDistributedDataParallel(model, device_ids=device_ids, find_unused_parameters=True)
+        # if unset, default `find_unused_parameters` `True`
+        self._ddp_kwargs["find_unused_parameters"] = self._ddp_kwargs.get(
+            "find_unused_parameters", True
+        )
+        model = LightningDistributedDataParallel(
+            model,
+            device_ids=device_ids,
+            **self._ddp_kwargs,
+        )
         return model

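The defaulting logic above preserves the previous behaviour: `find_unused_parameters` stays `True` unless the caller overrides it, and every other keyword argument is passed through untouched. A minimal standalone sketch of the same pattern with plain `torch.nn.parallel.DistributedDataParallel` (the helper name is hypothetical and a process group is assumed to be initialized already):

    from typing import Any, List

    import torch
    from torch.nn.parallel import DistributedDataParallel

    def wrap_ddp(
        model: torch.nn.Module, device_ids: List[int], **ddp_kwargs: Any
    ) -> DistributedDataParallel:
        # hypothetical helper mirroring DDPPlugin.configure_ddp:
        # keep the old default unless the caller overrides it
        ddp_kwargs.setdefault("find_unused_parameters", True)
        # all remaining kwargs (e.g. broadcast_buffers) are forwarded unchanged
        return DistributedDataParallel(model, device_ids=device_ids, **ddp_kwargs)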
tests/plugins/test_ddp_plugin.py

Lines changed: 81 additions & 30 deletions
@@ -1,25 +1,30 @@
-from pytorch_lightning.callbacks import Callback
-from tests.base.boring_model import BoringModel
-from pytorch_lightning import accelerators, Trainer
-from pytorch_lightning.plugins.ddp_plugin import DDPPlugin
-import pytest
 import os
 from unittest import mock
 
+import pytest
+from pytorch_lightning import Trainer, accelerators
+from pytorch_lightning.callbacks import Callback
+from pytorch_lightning.plugins.ddp_plugin import DDPPlugin
+from tests.base.boring_model import BoringModel
 
-@mock.patch.dict(os.environ, {
-    "CUDA_VISIBLE_DEVICES": "0,1",
-    "SLURM_NTASKS": "2",
-    "SLURM_JOB_NAME": "SOME_NAME",
-    "SLURM_NODEID": "0",
-    "LOCAL_RANK": "0",
-    "SLURM_LOCALID": "0"
-})
-@mock.patch('torch.cuda.device_count', return_value=2)
-@pytest.mark.parametrize(['ddp_backend', 'gpus', 'num_processes'],
-                         [('ddp_cpu', None, None), ('ddp', 2, 0), ('ddp2', 2, 0), ('ddp_spawn', 2, 0)])
-def test_ddp_choice_default_ddp_cpu(tmpdir, ddp_backend, gpus, num_processes):
 
+@mock.patch.dict(
+    os.environ,
+    {
+        "CUDA_VISIBLE_DEVICES": "0,1",
+        "SLURM_NTASKS": "2",
+        "SLURM_JOB_NAME": "SOME_NAME",
+        "SLURM_NODEID": "0",
+        "LOCAL_RANK": "0",
+        "SLURM_LOCALID": "0",
+    },
+)
+@mock.patch("torch.cuda.device_count", return_value=2)
+@pytest.mark.parametrize(
+    ["ddp_backend", "gpus", "num_processes"],
+    [("ddp_cpu", None, None), ("ddp", 2, 0), ("ddp2", 2, 0), ("ddp_spawn", 2, 0)],
+)
+def test_ddp_choice_default_ddp_cpu(tmpdir, ddp_backend, gpus, num_processes):
     class CB(Callback):
         def on_fit_start(self, trainer, pl_module):
             assert isinstance(trainer.accelerator_backend.ddp_plugin, DDPPlugin)
@@ -31,24 +36,29 @@ def on_fit_start(self, trainer, pl_module):
         gpus=gpus,
         num_processes=num_processes,
         distributed_backend=ddp_backend,
-        callbacks=[CB()]
+        callbacks=[CB()],
     )
 
     with pytest.raises(SystemExit):
         trainer.fit(model)
 
 
-@mock.patch.dict(os.environ, {
-    "CUDA_VISIBLE_DEVICES": "0,1",
-    "SLURM_NTASKS": "2",
-    "SLURM_JOB_NAME": "SOME_NAME",
-    "SLURM_NODEID": "0",
-    "LOCAL_RANK": "0",
-    "SLURM_LOCALID": "0"
-})
-@mock.patch('torch.cuda.device_count', return_value=2)
-@pytest.mark.parametrize(['ddp_backend', 'gpus', 'num_processes'],
-                         [('ddp_cpu', None, None), ('ddp', 2, 0), ('ddp2', 2, 0), ('ddp_spawn', 2, 0)])
+@mock.patch.dict(
+    os.environ,
+    {
+        "CUDA_VISIBLE_DEVICES": "0,1",
+        "SLURM_NTASKS": "2",
+        "SLURM_JOB_NAME": "SOME_NAME",
+        "SLURM_NODEID": "0",
+        "LOCAL_RANK": "0",
+        "SLURM_LOCALID": "0",
+    },
+)
+@mock.patch("torch.cuda.device_count", return_value=2)
+@pytest.mark.parametrize(
+    ["ddp_backend", "gpus", "num_processes"],
+    [("ddp_cpu", None, None), ("ddp", 2, 0), ("ddp2", 2, 0), ("ddp_spawn", 2, 0)],
+)
 def test_ddp_choice_custom_ddp_cpu(tmpdir, ddp_backend, gpus, num_processes):
     class MyDDP(DDPPlugin):
         pass
@@ -65,7 +75,48 @@ def on_fit_start(self, trainer, pl_module):
         num_processes=num_processes,
         distributed_backend=ddp_backend,
         plugins=[MyDDP()],
-        callbacks=[CB()]
+        callbacks=[CB()],
+    )
+
+    with pytest.raises(SystemExit):
+        trainer.fit(model)
+
+
+@mock.patch.dict(
+    os.environ,
+    {
+        "CUDA_VISIBLE_DEVICES": "0,1",
+        "SLURM_NTASKS": "2",
+        "SLURM_JOB_NAME": "SOME_NAME",
+        "SLURM_NODEID": "0",
+        "LOCAL_RANK": "0",
+        "SLURM_LOCALID": "0",
+    },
+)
+@mock.patch("torch.cuda.device_count", return_value=2)
+@pytest.mark.parametrize(
+    ["ddp_backend", "gpus", "num_processes"],
+    [("ddp_cpu", None, None), ("ddp", 2, 0), ("ddp2", 2, 0), ("ddp_spawn", 2, 0)],
+)
+def test_ddp_choice_custom_ddp_cpu_custom_args(
+    tmpdir, ddp_backend, gpus, num_processes
+):
+    class MyDDP(DDPPlugin):
+        pass
+
+    class CB(Callback):
+        def on_fit_start(self, trainer, pl_module):
+            assert isinstance(trainer.accelerator_backend.ddp_plugin, MyDDP)
+            raise SystemExit()
+
+    model = BoringModel()
+    trainer = Trainer(
+        fast_dev_run=True,
+        gpus=gpus,
+        num_processes=num_processes,
+        distributed_backend=ddp_backend,
+        plugins=[MyDDP(broadcast_buffers=False, find_unused_parameters=True)],
+        callbacks=[CB()],
     )
 
     with pytest.raises(SystemExit):

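The new test only asserts the plugin type. A possible extension, not part of this commit and reaching into the protected `_ddp_kwargs` attribute introduced above, would be a callback that also checks the constructor arguments were captured before being forwarded; a sketch, reusing the `MyDDP` class from the test above:

    from pytorch_lightning.callbacks import Callback

    class CheckDDPKwargsCB(Callback):  # hypothetical callback, mirrors CB above
        def on_fit_start(self, trainer, pl_module):
            plugin = trainer.accelerator_backend.ddp_plugin
            assert isinstance(plugin, MyDDP)
            # kwargs passed to MyDDP(...) are stored on the plugin and later
            # forwarded to LightningDistributedDataParallel by configure_ddp()
            assert plugin._ddp_kwargs["broadcast_buffers"] is False
            assert plugin._ddp_kwargs["find_unused_parameters"] is True
            raise SystemExit()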