From 00727e85605ae808439e82c2cd36a279d9cb2b03 Mon Sep 17 00:00:00 2001
From: Yunhoe Ku
Date: Wed, 25 Jun 2025 11:26:41 +0900
Subject: [PATCH 1/7] add: `MultiModelDDPStrategy` and its execution codes

---
 .../generative_adversarial_net_ddp.py     | 260 ++++++++++++++++++
 src/lightning/pytorch/strategies/ddp.py   |  50 +++-
 2 files changed, 297 insertions(+), 13 deletions(-)
 create mode 100644 examples/pytorch/domain_templates/generative_adversarial_net_ddp.py

diff --git a/examples/pytorch/domain_templates/generative_adversarial_net_ddp.py b/examples/pytorch/domain_templates/generative_adversarial_net_ddp.py
new file mode 100644
index 0000000000000..ba5e1d98b328a
--- /dev/null
+++ b/examples/pytorch/domain_templates/generative_adversarial_net_ddp.py
@@ -0,0 +1,260 @@
+# Copyright The Lightning AI team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""To run this template just do: python generative_adversarial_net_ddp.py.
+
+After a few epochs, launch TensorBoard to see the images being generated at every batch:
+
+tensorboard --logdir default
+
+"""
+import math
+from argparse import ArgumentParser, Namespace
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+# ! TESTING
+import os
+import sys
+
+sys.path.append(os.path.join(os.getcwd(), "src"))  # noqa: E402
+# ! TESTING
+
+from lightning.pytorch import cli_lightning_logo
+from lightning.pytorch.core import LightningModule
+from lightning.pytorch.demos.mnist_datamodule import MNISTDataModule
+from lightning.pytorch.trainer import Trainer
+from lightning.pytorch.utilities.imports import _TORCHVISION_AVAILABLE
+from lightning.pytorch.strategies.ddp import DDPStrategy, MultiModelDDPStrategy
+
+if _TORCHVISION_AVAILABLE:
+    import torchvision
+
+
+class Generator(nn.Module):
+    """
+    >>> Generator(img_shape=(1, 8, 8))  # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
+    Generator(
+      (model): Sequential(...)
+    )
+    """
+
+    def __init__(self, latent_dim: int = 100, img_shape: tuple = (1, 28, 28)):
+        super().__init__()
+        self.img_shape = img_shape
+
+        def block(in_feat, out_feat, normalize=True):
+            layers = [nn.Linear(in_feat, out_feat)]
+            if normalize:
+                layers.append(nn.BatchNorm1d(out_feat, 0.8))
+            layers.append(nn.LeakyReLU(0.2, inplace=True))
+            return layers
+
+        self.model = nn.Sequential(
+            *block(latent_dim, 128, normalize=False),
+            *block(128, 256),
+            *block(256, 512),
+            *block(512, 1024),
+            nn.Linear(1024, int(math.prod(img_shape))),
+            nn.Tanh(),
+        )
+
+    def forward(self, z):
+        img = self.model(z)
+        return img.view(img.size(0), *self.img_shape)
+
+
+class Discriminator(nn.Module):
+    """
+    >>> Discriminator(img_shape=(1, 28, 28))  # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
+    Discriminator(
+      (model): Sequential(...)
+ ) + """ + + def __init__(self, img_shape): + super().__init__() + + self.model = nn.Sequential( + nn.Linear(int(math.prod(img_shape)), 512), + nn.LeakyReLU(0.2, inplace=True), + nn.Linear(512, 256), + nn.LeakyReLU(0.2, inplace=True), + nn.Linear(256, 1), + ) + + def forward(self, img): + img_flat = img.view(img.size(0), -1) + return self.model(img_flat) + + +class GAN(LightningModule): + """ + >>> GAN(img_shape=(1, 8, 8)) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + GAN( + (generator): Generator( + (model): Sequential(...) + ) + (discriminator): Discriminator( + (model): Sequential(...) + ) + ) + """ + + def __init__( + self, + img_shape: tuple = (1, 28, 28), + lr: float = 0.0002, + b1: float = 0.5, + b2: float = 0.999, + latent_dim: int = 100, + ): + super().__init__() + self.save_hyperparameters() + self.automatic_optimization = False + + # networks + self.generator = Generator(latent_dim=self.hparams.latent_dim, img_shape=img_shape) + self.discriminator = Discriminator(img_shape=img_shape) + + self.validation_z = torch.randn(8, self.hparams.latent_dim) + + self.example_input_array = torch.zeros(2, self.hparams.latent_dim) + + # ! TESTING + self.save_path = "pl_test_multi_gpu" + os.makedirs(self.save_path, exist_ok=True) + + def forward(self, z): + return self.generator(z) + + @staticmethod + def adversarial_loss(y_hat, y): + return F.binary_cross_entropy_with_logits(y_hat, y) + + def training_step(self, batch): + imgs, _ = batch + + opt_g, opt_d = self.optimizers() + + # sample noise + z = torch.randn(imgs.shape[0], self.hparams.latent_dim) + z = z.type_as(imgs) + + # Train generator + # ground truth result (ie: all fake) + # put on GPU because we created this tensor inside training_loop + valid = torch.ones(imgs.size(0), 1) + valid = valid.type_as(imgs) + + self.toggle_optimizer(opt_g) + # adversarial loss is binary cross-entropy + g_loss = self.adversarial_loss(self.discriminator(self(z)), valid) + opt_g.zero_grad() + self.manual_backward(g_loss) + opt_g.step() + self.untoggle_optimizer(opt_g) + + # Train discriminator + # Measure discriminator's ability to classify real from generated samples + # how well can it label as real? + valid = torch.ones(imgs.size(0), 1) + valid = valid.type_as(imgs) + + self.toggle_optimizer(opt_d) + real_loss = self.adversarial_loss(self.discriminator(imgs), valid) + + # how well can it label as fake? + fake = torch.zeros(imgs.size(0), 1) + fake = fake.type_as(imgs) + + fake_loss = self.adversarial_loss(self.discriminator(self(z).detach()), fake) + + # discriminator loss is the average of these + d_loss = (real_loss + fake_loss) / 2 + + opt_d.zero_grad() + self.manual_backward(d_loss) + opt_d.step() + self.untoggle_optimizer(opt_d) + + self.log_dict({"d_loss": d_loss, "g_loss": g_loss}) + + def configure_optimizers(self): + lr = self.hparams.lr + b1 = self.hparams.b1 + b2 = self.hparams.b2 + + opt_g = torch.optim.Adam(self.generator.parameters(), lr=lr, betas=(b1, b2)) + opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=lr, betas=(b1, b2)) + return opt_g, opt_d + + # ! TESTING + def on_train_epoch_start(self): + if self.trainer.is_global_zero: + print("GEN: ", self.generator.module.model[0].bias[:10]) + print("DISC: ", self.discriminator.module.model[0].bias[:10]) + + # ! TESTING + def validation_step(self, batch, batch_idx): + pass + + # ! 
TESTING + @torch.no_grad() + def on_validation_epoch_end(self): + if self.current_epoch % 5: + self.generator.eval(), self.discriminator.eval() + + z = self.validation_z.type_as(self.generator.module.model[0].weight) + sample_imgs = self(z) + + if self.trainer.is_global_zero: + grid = torchvision.utils.make_grid(sample_imgs) + torchvision.utils.save_image(grid, os.path.join(self.save_path, f"epoch_{self.current_epoch}.png")) + + self.generator.train(), self.discriminator.train() + + +def main(args: Namespace) -> None: + model = GAN(lr=args.lr, b1=args.b1, b2=args.b2, latent_dim=args.latent_dim) + + # ! `MultiModelDDPStrategy` is critical for multi-gpu training + # ! Otherwise, it will not work with multiple models. + # ! There are two ways to run training codes with previous `DDPStrategy`; + # ! 1) activate `find_unused_parameters=True`, 2) change from self.manual_backward(loss) to loss.backward() + # ! Neither of them is desirable. + dm = MNISTDataModule() + trainer = Trainer( + accelerator="auto", + devices=[0, 1, 2, 3], + strategy=MultiModelDDPStrategy(), + max_epochs=100, + ) + + trainer.fit(model, dm) + + +if __name__ == "__main__": + cli_lightning_logo() + parser = ArgumentParser() + + # Hyperparameters + parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate") + parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient") + parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of second order momentum of gradient") + parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space") + args = parser.parse_args() + + main(args) diff --git a/src/lightning/pytorch/strategies/ddp.py b/src/lightning/pytorch/strategies/ddp.py index fd3f66ef42471..4e4bce82f2c5c 100644 --- a/src/lightning/pytorch/strategies/ddp.py +++ b/src/lightning/pytorch/strategies/ddp.py @@ -107,9 +107,7 @@ def __init__( @property def is_distributed(self) -> bool: # pragma: no-cover """Legacy property kept for backwards compatibility.""" - rank_zero_deprecation( - f"`{type(self).__name__}.is_distributed` is deprecated. Use is discouraged.", stacklevel=6 - ) + rank_zero_deprecation(f"`{type(self).__name__}.is_distributed` is deprecated. Use is discouraged.", stacklevel=6) return True @property @@ -229,9 +227,7 @@ def _register_ddp_hooks(self) -> None: def _enable_model_averaging(self) -> None: log.debug(f"{self.__class__.__name__}: reinitializing optimizers with post localSGD") if self._model_averaging_period is None: - raise ValueError( - "Post-localSGD algorithm is used, but model averaging period is not provided to DDP strategy." - ) + raise ValueError("Post-localSGD algorithm is used, but model averaging period is not provided to DDP strategy.") from torch.distributed.optim import DistributedOptimizer, PostLocalSGDOptimizer, ZeroRedundancyOptimizer for optimizer in self.optimizers: @@ -240,10 +236,7 @@ def _enable_model_averaging(self) -> None: is_distributed_optimizer = isinstance(optimizer, DistributedOptimizer) if not _IS_WINDOWS else False if isinstance(optimizer, (ZeroRedundancyOptimizer, PostLocalSGDOptimizer)) or is_distributed_optimizer: - raise ValueError( - f"Currently model averaging cannot work with a distributed optimizer of type " - f"{optimizer.__class__.__name__}." 
- ) + raise ValueError(f"Currently model averaging cannot work with a distributed optimizer of type " f"{optimizer.__class__.__name__}.") assert self._ddp_comm_state is not None self._model_averager = torch.distributed.algorithms.model_averaging.averagers.PeriodicModelAverager( @@ -323,9 +316,7 @@ def model_to_device(self) -> None: self.model.to(self.root_device) @override - def reduce( - self, tensor: Tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = "mean" - ) -> Tensor: + def reduce(self, tensor: Tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = "mean") -> Tensor: """Reduces a tensor from several distributed processes to one aggregated tensor. Args: @@ -419,6 +410,39 @@ def teardown(self) -> None: super().teardown() +class MultiModelDDPStrategy(DDPStrategy): + @override + def _setup_model(self, model: Module) -> Module: + device_ids = self.determine_ddp_device_ids() + log.debug(f"setting up DDP model with device ids: {device_ids}, kwargs: {self._ddp_kwargs}") + # https://pytorch.org/docs/stable/notes/cuda.html#id5 + ctx = torch.cuda.stream(torch.cuda.Stream()) if device_ids is not None else nullcontext() + with ctx: + for name, module in model.named_children(): + if isinstance(module, Module): + ddp_module = DistributedDataParallel(module, device_ids=device_ids, **self._ddp_kwargs) + setattr(model, name, ddp_module) + + return model + + @override + def _register_ddp_hooks(self) -> None: + log.debug(f"{self.__class__.__name__}: registering ddp hooks") + # currently, DDP communication hooks only work with NCCL backend and SPSD (single process single device) mode + # https://github.com/pytorch/pytorch/blob/v1.8.0/torch/nn/parallel/distributed.py#L1080-L1084 + if self.root_device.type == "cuda": + assert isinstance(self.model, Module) + + for name, module in self.model.named_children(): + assert isinstance(module, DistributedDataParallel) + _register_ddp_comm_hook( + model=module, + ddp_comm_state=self._ddp_comm_state, + ddp_comm_hook=self._ddp_comm_hook, + ddp_comm_wrapper=self._ddp_comm_wrapper, + ) + + class _DDPForwardRedirection(_ForwardRedirection): @override def on_after_inner_forward(self, wrapper_module: Module, original_module: "pl.LightningModule") -> None: From e6b061afee0875490a9553f44cd7288df20209a7 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 25 Jun 2025 02:45:24 +0000 Subject: [PATCH 2/7] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../generative_adversarial_net_ddp.py | 13 +++++++------ src/lightning/pytorch/strategies/ddp.py | 17 +++++++++++++---- 2 files changed, 20 insertions(+), 10 deletions(-) diff --git a/examples/pytorch/domain_templates/generative_adversarial_net_ddp.py b/examples/pytorch/domain_templates/generative_adversarial_net_ddp.py index ba5e1d98b328a..7faec21cb8276 100644 --- a/examples/pytorch/domain_templates/generative_adversarial_net_ddp.py +++ b/examples/pytorch/domain_templates/generative_adversarial_net_ddp.py @@ -18,26 +18,27 @@ tensorboard --logdir default """ + import math + +# ! TESTING +import os +import sys from argparse import ArgumentParser, Namespace import torch import torch.nn as nn import torch.nn.functional as F -# ! TESTING -import os -import sys - -sys.path.append(os.path.join(os.getcwd(), "src")) # noqa: E402 +sys.path.append(os.path.join(os.getcwd(), "src")) # ! 
TESTING from lightning.pytorch import cli_lightning_logo from lightning.pytorch.core import LightningModule from lightning.pytorch.demos.mnist_datamodule import MNISTDataModule +from lightning.pytorch.strategies.ddp import MultiModelDDPStrategy from lightning.pytorch.trainer import Trainer from lightning.pytorch.utilities.imports import _TORCHVISION_AVAILABLE -from lightning.pytorch.strategies.ddp import DDPStrategy, MultiModelDDPStrategy if _TORCHVISION_AVAILABLE: import torchvision diff --git a/src/lightning/pytorch/strategies/ddp.py b/src/lightning/pytorch/strategies/ddp.py index 4e4bce82f2c5c..f69baa7ae2b13 100644 --- a/src/lightning/pytorch/strategies/ddp.py +++ b/src/lightning/pytorch/strategies/ddp.py @@ -107,7 +107,9 @@ def __init__( @property def is_distributed(self) -> bool: # pragma: no-cover """Legacy property kept for backwards compatibility.""" - rank_zero_deprecation(f"`{type(self).__name__}.is_distributed` is deprecated. Use is discouraged.", stacklevel=6) + rank_zero_deprecation( + f"`{type(self).__name__}.is_distributed` is deprecated. Use is discouraged.", stacklevel=6 + ) return True @property @@ -227,7 +229,9 @@ def _register_ddp_hooks(self) -> None: def _enable_model_averaging(self) -> None: log.debug(f"{self.__class__.__name__}: reinitializing optimizers with post localSGD") if self._model_averaging_period is None: - raise ValueError("Post-localSGD algorithm is used, but model averaging period is not provided to DDP strategy.") + raise ValueError( + "Post-localSGD algorithm is used, but model averaging period is not provided to DDP strategy." + ) from torch.distributed.optim import DistributedOptimizer, PostLocalSGDOptimizer, ZeroRedundancyOptimizer for optimizer in self.optimizers: @@ -236,7 +240,10 @@ def _enable_model_averaging(self) -> None: is_distributed_optimizer = isinstance(optimizer, DistributedOptimizer) if not _IS_WINDOWS else False if isinstance(optimizer, (ZeroRedundancyOptimizer, PostLocalSGDOptimizer)) or is_distributed_optimizer: - raise ValueError(f"Currently model averaging cannot work with a distributed optimizer of type " f"{optimizer.__class__.__name__}.") + raise ValueError( + f"Currently model averaging cannot work with a distributed optimizer of type " + f"{optimizer.__class__.__name__}." + ) assert self._ddp_comm_state is not None self._model_averager = torch.distributed.algorithms.model_averaging.averagers.PeriodicModelAverager( @@ -316,7 +323,9 @@ def model_to_device(self) -> None: self.model.to(self.root_device) @override - def reduce(self, tensor: Tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = "mean") -> Tensor: + def reduce( + self, tensor: Tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = "mean" + ) -> Tensor: """Reduces a tensor from several distributed processes to one aggregated tensor. 
Args: From aa9b027872b49402038d72895e0ae707ce677386 Mon Sep 17 00:00:00 2001 From: "Yunhoe, Ku" <59246456+samsara-ku@users.noreply.github.com> Date: Tue, 5 Aug 2025 17:10:11 +0900 Subject: [PATCH 3/7] refactor: extract block helper in GAN example --- .../generative_adversarial_net.py | 24 ++--- .../generative_adversarial_net_ddp.py | 23 ++--- .../strategies/test_multi_model_ddp.py | 95 +++++++++++++++++++ 3 files changed, 119 insertions(+), 23 deletions(-) create mode 100644 tests/tests_pytorch/strategies/test_multi_model_ddp.py diff --git a/examples/pytorch/domain_templates/generative_adversarial_net.py b/examples/pytorch/domain_templates/generative_adversarial_net.py index 7ce7682d82c76..310fe7af0e08c 100644 --- a/examples/pytorch/domain_templates/generative_adversarial_net.py +++ b/examples/pytorch/domain_templates/generative_adversarial_net.py @@ -36,6 +36,14 @@ import torchvision +def _block(in_feat: int, out_feat: int, normalize: bool = True): + layers = [nn.Linear(in_feat, out_feat)] + if normalize: + layers.append(nn.BatchNorm1d(out_feat, 0.8)) + layers.append(nn.LeakyReLU(0.2, inplace=True)) + return layers + + class Generator(nn.Module): """ >>> Generator(img_shape=(1, 8, 8)) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE @@ -47,19 +55,11 @@ class Generator(nn.Module): def __init__(self, latent_dim: int = 100, img_shape: tuple = (1, 28, 28)): super().__init__() self.img_shape = img_shape - - def block(in_feat, out_feat, normalize=True): - layers = [nn.Linear(in_feat, out_feat)] - if normalize: - layers.append(nn.BatchNorm1d(out_feat, 0.8)) - layers.append(nn.LeakyReLU(0.2, inplace=True)) - return layers - self.model = nn.Sequential( - *block(latent_dim, 128, normalize=False), - *block(128, 256), - *block(256, 512), - *block(512, 1024), + *_block(latent_dim, 128, normalize=False), + *_block(128, 256), + *_block(256, 512), + *_block(512, 1024), nn.Linear(1024, int(math.prod(img_shape))), nn.Tanh(), ) diff --git a/examples/pytorch/domain_templates/generative_adversarial_net_ddp.py b/examples/pytorch/domain_templates/generative_adversarial_net_ddp.py index 7faec21cb8276..6293a34706e2c 100644 --- a/examples/pytorch/domain_templates/generative_adversarial_net_ddp.py +++ b/examples/pytorch/domain_templates/generative_adversarial_net_ddp.py @@ -44,6 +44,14 @@ import torchvision +def _block(in_feat: int, out_feat: int, normalize: bool = True): + layers = [nn.Linear(in_feat, out_feat)] + if normalize: + layers.append(nn.BatchNorm1d(out_feat, 0.8)) + layers.append(nn.LeakyReLU(0.2, inplace=True)) + return layers + + class Generator(nn.Module): """ >>> Generator(img_shape=(1, 8, 8)) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE @@ -56,18 +64,11 @@ def __init__(self, latent_dim: int = 100, img_shape: tuple = (1, 28, 28)): super().__init__() self.img_shape = img_shape - def block(in_feat, out_feat, normalize=True): - layers = [nn.Linear(in_feat, out_feat)] - if normalize: - layers.append(nn.BatchNorm1d(out_feat, 0.8)) - layers.append(nn.LeakyReLU(0.2, inplace=True)) - return layers - self.model = nn.Sequential( - *block(latent_dim, 128, normalize=False), - *block(128, 256), - *block(256, 512), - *block(512, 1024), + *_block(latent_dim, 128, normalize=False), + *_block(128, 256), + *_block(256, 512), + *_block(512, 1024), nn.Linear(1024, int(math.prod(img_shape))), nn.Tanh(), ) diff --git a/tests/tests_pytorch/strategies/test_multi_model_ddp.py b/tests/tests_pytorch/strategies/test_multi_model_ddp.py new file mode 100644 index 0000000000000..6e44a05ae95c8 --- /dev/null +++ 
b/tests/tests_pytorch/strategies/test_multi_model_ddp.py @@ -0,0 +1,95 @@ +# Copyright The Lightning AI team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from unittest import mock +from unittest.mock import PropertyMock + +import torch +from torch import nn + +from lightning.pytorch.strategies.ddp import MultiModelDDPStrategy + + +def test_multi_model_ddp_setup_and_register_hooks(): + class Parent(nn.Module): + def __init__(self): + super().__init__() + self.gen = nn.Linear(1, 1) + self.dis = nn.Linear(1, 1) + + model = Parent() + original_children = [model.gen, model.dis] + + strategy = MultiModelDDPStrategy(parallel_devices=[torch.device("cpu")]) + + wrapped_modules = [] + wrapped_device_ids = [] + + class DummyDDP(nn.Module): + def __init__(self, module: nn.Module, device_ids=None, **kwargs): + super().__init__() + self.module = module + wrapped_modules.append(module) + wrapped_device_ids.append(device_ids) + + with mock.patch("lightning.pytorch.strategies.ddp.DistributedDataParallel", DummyDDP): + returned_model = strategy._setup_model(model) + assert returned_model is model + assert isinstance(model.gen, DummyDDP) + assert isinstance(model.dis, DummyDDP) + assert wrapped_modules == original_children + assert wrapped_device_ids == [None, None] + + strategy.model = model + with mock.patch("lightning.pytorch.strategies.ddp._register_ddp_comm_hook") as register_hook: + with mock.patch.object(MultiModelDDPStrategy, "root_device", new_callable=PropertyMock) as root_device: + root_device.return_value = torch.device("cuda", 0) + strategy._register_ddp_hooks() + + assert register_hook.call_count == 2 + register_hook.assert_any_call( + model=model.gen, + ddp_comm_state=strategy._ddp_comm_state, + ddp_comm_hook=strategy._ddp_comm_hook, + ddp_comm_wrapper=strategy._ddp_comm_wrapper, + ) + register_hook.assert_any_call( + model=model.dis, + ddp_comm_state=strategy._ddp_comm_state, + ddp_comm_hook=strategy._ddp_comm_hook, + ddp_comm_wrapper=strategy._ddp_comm_wrapper, + ) + + +def test_multi_model_ddp_register_hooks_cpu_noop(): + class Parent(nn.Module): + def __init__(self) -> None: + super().__init__() + self.gen = nn.Linear(1, 1) + self.dis = nn.Linear(1, 1) + + model = Parent() + strategy = MultiModelDDPStrategy(parallel_devices=[torch.device("cpu")]) + + class DummyDDP(nn.Module): + def __init__(self, module: nn.Module, device_ids=None, **kwargs): + super().__init__() + self.module = module + + with mock.patch("lightning.pytorch.strategies.ddp.DistributedDataParallel", DummyDDP): + strategy.model = strategy._setup_model(model) + + with mock.patch("lightning.pytorch.strategies.ddp._register_ddp_comm_hook") as register_hook: + strategy._register_ddp_hooks() + + register_hook.assert_not_called() From 1fb4027ecb70509f924162c3ddb3925a0ef43ec9 Mon Sep 17 00:00:00 2001 From: Jirka B Date: Fri, 8 Aug 2025 19:34:36 +0200 Subject: [PATCH 4/7] with --- tests/tests_pytorch/strategies/test_multi_model_ddp.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git 
a/tests/tests_pytorch/strategies/test_multi_model_ddp.py b/tests/tests_pytorch/strategies/test_multi_model_ddp.py index 6e44a05ae95c8..053f2a3312bf8 100644 --- a/tests/tests_pytorch/strategies/test_multi_model_ddp.py +++ b/tests/tests_pytorch/strategies/test_multi_model_ddp.py @@ -51,10 +51,12 @@ def __init__(self, module: nn.Module, device_ids=None, **kwargs): assert wrapped_device_ids == [None, None] strategy.model = model - with mock.patch("lightning.pytorch.strategies.ddp._register_ddp_comm_hook") as register_hook: - with mock.patch.object(MultiModelDDPStrategy, "root_device", new_callable=PropertyMock) as root_device: - root_device.return_value = torch.device("cuda", 0) - strategy._register_ddp_hooks() + with ( + mock.patch("lightning.pytorch.strategies.ddp._register_ddp_comm_hook") as register_hook, + mock.patch.object(MultiModelDDPStrategy, "root_device", new_callable=PropertyMock) as root_device, + ): + root_device.return_value = torch.device("cuda", 0) + strategy._register_ddp_hooks() assert register_hook.call_count == 2 register_hook.assert_any_call( From ec623976a5427abbec26f39da4e75c62c61bdfa1 Mon Sep 17 00:00:00 2001 From: Jirka Borovec <6035284+Borda@users.noreply.github.com> Date: Fri, 8 Aug 2025 19:37:48 +0200 Subject: [PATCH 5/7] Apply suggestions from code review --- .../generative_adversarial_net.py | 2 +- .../generative_adversarial_net_ddp.py | 17 +++++++++-------- 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/examples/pytorch/domain_templates/generative_adversarial_net.py b/examples/pytorch/domain_templates/generative_adversarial_net.py index 310fe7af0e08c..26dce027704e7 100644 --- a/examples/pytorch/domain_templates/generative_adversarial_net.py +++ b/examples/pytorch/domain_templates/generative_adversarial_net.py @@ -36,7 +36,7 @@ import torchvision -def _block(in_feat: int, out_feat: int, normalize: bool = True): +def _block(in_feat: int, out_feat: int, normalize: bool = True) -> list: layers = [nn.Linear(in_feat, out_feat)] if normalize: layers.append(nn.BatchNorm1d(out_feat, 0.8)) diff --git a/examples/pytorch/domain_templates/generative_adversarial_net_ddp.py b/examples/pytorch/domain_templates/generative_adversarial_net_ddp.py index 6293a34706e2c..dbcd3bcc50875 100644 --- a/examples/pytorch/domain_templates/generative_adversarial_net_ddp.py +++ b/examples/pytorch/domain_templates/generative_adversarial_net_ddp.py @@ -216,17 +216,18 @@ def validation_step(self, batch, batch_idx): # ! 
TESTING @torch.no_grad() def on_validation_epoch_end(self): - if self.current_epoch % 5: - self.generator.eval(), self.discriminator.eval() + if not self.current_epoch % 5: + return + self.generator.eval(), self.discriminator.eval() - z = self.validation_z.type_as(self.generator.module.model[0].weight) - sample_imgs = self(z) + z = self.validation_z.type_as(self.generator.module.model[0].weight) + sample_imgs = self(z) - if self.trainer.is_global_zero: - grid = torchvision.utils.make_grid(sample_imgs) - torchvision.utils.save_image(grid, os.path.join(self.save_path, f"epoch_{self.current_epoch}.png")) + if self.trainer.is_global_zero: + grid = torchvision.utils.make_grid(sample_imgs) + torchvision.utils.save_image(grid, os.path.join(self.save_path, f"epoch_{self.current_epoch}.png")) - self.generator.train(), self.discriminator.train() + self.generator.train(), self.discriminator.train() def main(args: Namespace) -> None: From ece7d38abb908e96eb72af363606d2585e8843f2 Mon Sep 17 00:00:00 2001 From: Jirka B Date: Fri, 8 Aug 2025 23:46:19 +0200 Subject: [PATCH 6/7] formating --- src/lightning/pytorch/strategies/ddp.py | 28 ++++++++++++------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/src/lightning/pytorch/strategies/ddp.py b/src/lightning/pytorch/strategies/ddp.py index f69baa7ae2b13..502e560d3e1de 100644 --- a/src/lightning/pytorch/strategies/ddp.py +++ b/src/lightning/pytorch/strategies/ddp.py @@ -214,7 +214,7 @@ def set_world_ranks(self) -> None: rank_zero_only.rank = utils_rank_zero_only.rank = self.global_rank def _register_ddp_hooks(self) -> None: - log.debug(f"{self.__class__.__name__}: registering ddp hooks") + log.debug(f"{self.__class__.__name__}: registering DDP hooks") # currently, DDP communication hooks only work with NCCL backend and SPSD (single process single device) mode # https://github.com/pytorch/pytorch/blob/v1.8.0/torch/nn/parallel/distributed.py#L1080-L1084 if self.root_device.type == "cuda": @@ -431,25 +431,25 @@ def _setup_model(self, model: Module) -> Module: if isinstance(module, Module): ddp_module = DistributedDataParallel(module, device_ids=device_ids, **self._ddp_kwargs) setattr(model, name, ddp_module) - return model @override def _register_ddp_hooks(self) -> None: - log.debug(f"{self.__class__.__name__}: registering ddp hooks") + log.debug(f"{self.__class__.__name__}: registering DDP hooks") # currently, DDP communication hooks only work with NCCL backend and SPSD (single process single device) mode # https://github.com/pytorch/pytorch/blob/v1.8.0/torch/nn/parallel/distributed.py#L1080-L1084 - if self.root_device.type == "cuda": - assert isinstance(self.model, Module) - - for name, module in self.model.named_children(): - assert isinstance(module, DistributedDataParallel) - _register_ddp_comm_hook( - model=module, - ddp_comm_state=self._ddp_comm_state, - ddp_comm_hook=self._ddp_comm_hook, - ddp_comm_wrapper=self._ddp_comm_wrapper, - ) + if self.root_device.type != "cuda": + return + assert isinstance(self.model, Module) + + for name, module in self.model.named_children(): + assert isinstance(module, DistributedDataParallel) + _register_ddp_comm_hook( + model=module, + ddp_comm_state=self._ddp_comm_state, + ddp_comm_hook=self._ddp_comm_hook, + ddp_comm_wrapper=self._ddp_comm_wrapper, + ) class _DDPForwardRedirection(_ForwardRedirection): From 8b1fe23f69f50664833b1ff5e17d0a7e33c0e8f7 Mon Sep 17 00:00:00 2001 From: "Yunhoe, Ku" Date: Mon, 11 Aug 2025 17:12:05 +0900 Subject: [PATCH 7/7] misc: resolve some review comments for product 
consistency --- .../generative_adversarial_net_ddp.py | 51 ++++++------------- src/lightning/pytorch/strategies/ddp.py | 2 +- 2 files changed, 17 insertions(+), 36 deletions(-) diff --git a/examples/pytorch/domain_templates/generative_adversarial_net_ddp.py b/examples/pytorch/domain_templates/generative_adversarial_net_ddp.py index dbcd3bcc50875..e0251185bb65b 100644 --- a/examples/pytorch/domain_templates/generative_adversarial_net_ddp.py +++ b/examples/pytorch/domain_templates/generative_adversarial_net_ddp.py @@ -20,19 +20,12 @@ """ import math - -# ! TESTING -import os -import sys from argparse import ArgumentParser, Namespace import torch import torch.nn as nn import torch.nn.functional as F -sys.path.append(os.path.join(os.getcwd(), "src")) -# ! TESTING - from lightning.pytorch import cli_lightning_logo from lightning.pytorch.core import LightningModule from lightning.pytorch.demos.mnist_datamodule import MNISTDataModule @@ -44,7 +37,7 @@ import torchvision -def _block(in_feat: int, out_feat: int, normalize: bool = True): +def _block(in_feat: int, out_feat: int, normalize: bool = True) -> list: layers = [nn.Linear(in_feat, out_feat)] if normalize: layers.append(nn.BatchNorm1d(out_feat, 0.8)) @@ -135,10 +128,6 @@ def __init__( self.example_input_array = torch.zeros(2, self.hparams.latent_dim) - # ! TESTING - self.save_path = "pl_test_multi_gpu" - os.makedirs(self.save_path, exist_ok=True) - def forward(self, z): return self.generator(z) @@ -203,36 +192,25 @@ def configure_optimizers(self): opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=lr, betas=(b1, b2)) return opt_g, opt_d - # ! TESTING - def on_train_epoch_start(self): - if self.trainer.is_global_zero: - print("GEN: ", self.generator.module.model[0].bias[:10]) - print("DISC: ", self.discriminator.module.model[0].bias[:10]) - - # ! TESTING - def validation_step(self, batch, batch_idx): - pass + def on_train_epoch_end(self): + z = self.validation_z.type_as(self.generator.model[0].weight) - # ! TESTING - @torch.no_grad() - def on_validation_epoch_end(self): - if not self.current_epoch % 5: - return - self.generator.eval(), self.discriminator.eval() - - z = self.validation_z.type_as(self.generator.module.model[0].weight) + # log sampled images` sample_imgs = self(z) - - if self.trainer.is_global_zero: - grid = torchvision.utils.make_grid(sample_imgs) - torchvision.utils.save_image(grid, os.path.join(self.save_path, f"epoch_{self.current_epoch}.png")) - - self.generator.train(), self.discriminator.train() + grid = torchvision.utils.make_grid(sample_imgs) + for logger in self.loggers: + logger.experiment.add_image("generated_images", grid, self.current_epoch) def main(args: Namespace) -> None: + # ------------------------ + # 1 INIT LIGHTNING MODEL + # ------------------------ model = GAN(lr=args.lr, b1=args.b1, b2=args.b2, latent_dim=args.latent_dim) + # ------------------------ + # 2 INIT TRAINER + # ------------------------ # ! `MultiModelDDPStrategy` is critical for multi-gpu training # ! Otherwise, it will not work with multiple models. # ! 
There are two ways to run training codes with previous `DDPStrategy`; @@ -246,6 +224,9 @@ def main(args: Namespace) -> None: max_epochs=100, ) + # ------------------------ + # 3 START TRAINING + # ------------------------ trainer.fit(model, dm) diff --git a/src/lightning/pytorch/strategies/ddp.py b/src/lightning/pytorch/strategies/ddp.py index 502e560d3e1de..aa758ba69c54d 100644 --- a/src/lightning/pytorch/strategies/ddp.py +++ b/src/lightning/pytorch/strategies/ddp.py @@ -421,7 +421,7 @@ def teardown(self) -> None: class MultiModelDDPStrategy(DDPStrategy): @override - def _setup_model(self, model: Module) -> Module: + def _setup_model(self, model: Module) -> DistributedDataParallel: device_ids = self.determine_ddp_device_ids() log.debug(f"setting up DDP model with device ids: {device_ids}, kwargs: {self._ddp_kwargs}") # https://pytorch.org/docs/stable/notes/cuda.html#id5