diff --git a/examples/pytorch/domain_templates/generative_adversarial_net.py b/examples/pytorch/domain_templates/generative_adversarial_net.py index 7ce7682d82c76..26dce027704e7 100644 --- a/examples/pytorch/domain_templates/generative_adversarial_net.py +++ b/examples/pytorch/domain_templates/generative_adversarial_net.py @@ -36,6 +36,14 @@ import torchvision +def _block(in_feat: int, out_feat: int, normalize: bool = True) -> list: + layers = [nn.Linear(in_feat, out_feat)] + if normalize: + layers.append(nn.BatchNorm1d(out_feat, 0.8)) + layers.append(nn.LeakyReLU(0.2, inplace=True)) + return layers + + class Generator(nn.Module): """ >>> Generator(img_shape=(1, 8, 8)) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE @@ -47,19 +55,11 @@ class Generator(nn.Module): def __init__(self, latent_dim: int = 100, img_shape: tuple = (1, 28, 28)): super().__init__() self.img_shape = img_shape - - def block(in_feat, out_feat, normalize=True): - layers = [nn.Linear(in_feat, out_feat)] - if normalize: - layers.append(nn.BatchNorm1d(out_feat, 0.8)) - layers.append(nn.LeakyReLU(0.2, inplace=True)) - return layers - self.model = nn.Sequential( - *block(latent_dim, 128, normalize=False), - *block(128, 256), - *block(256, 512), - *block(512, 1024), + *_block(latent_dim, 128, normalize=False), + *_block(128, 256), + *_block(256, 512), + *_block(512, 1024), nn.Linear(1024, int(math.prod(img_shape))), nn.Tanh(), ) diff --git a/examples/pytorch/domain_templates/generative_adversarial_net_ddp.py b/examples/pytorch/domain_templates/generative_adversarial_net_ddp.py new file mode 100644 index 0000000000000..e0251185bb65b --- /dev/null +++ b/examples/pytorch/domain_templates/generative_adversarial_net_ddp.py @@ -0,0 +1,244 @@ +# Copyright The Lightning AI team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""To run this template just do: python generative_adversarial_net.py. + +After a few epochs, launch TensorBoard to see the images being generated at every batch: + +tensorboard --logdir default + +""" + +import math +from argparse import ArgumentParser, Namespace + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from lightning.pytorch import cli_lightning_logo +from lightning.pytorch.core import LightningModule +from lightning.pytorch.demos.mnist_datamodule import MNISTDataModule +from lightning.pytorch.strategies.ddp import MultiModelDDPStrategy +from lightning.pytorch.trainer import Trainer +from lightning.pytorch.utilities.imports import _TORCHVISION_AVAILABLE + +if _TORCHVISION_AVAILABLE: + import torchvision + + +def _block(in_feat: int, out_feat: int, normalize: bool = True) -> list: + layers = [nn.Linear(in_feat, out_feat)] + if normalize: + layers.append(nn.BatchNorm1d(out_feat, 0.8)) + layers.append(nn.LeakyReLU(0.2, inplace=True)) + return layers + + +class Generator(nn.Module): + """ + >>> Generator(img_shape=(1, 8, 8)) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + Generator( + (model): Sequential(...) + ) + """ + + def __init__(self, latent_dim: int = 100, img_shape: tuple = (1, 28, 28)): + super().__init__() + self.img_shape = img_shape + + self.model = nn.Sequential( + *_block(latent_dim, 128, normalize=False), + *_block(128, 256), + *_block(256, 512), + *_block(512, 1024), + nn.Linear(1024, int(math.prod(img_shape))), + nn.Tanh(), + ) + + def forward(self, z): + img = self.model(z) + return img.view(img.size(0), *self.img_shape) + + +class Discriminator(nn.Module): + """ + >>> Discriminator(img_shape=(1, 28, 28)) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + Discriminator( + (model): Sequential(...) + ) + """ + + def __init__(self, img_shape): + super().__init__() + + self.model = nn.Sequential( + nn.Linear(int(math.prod(img_shape)), 512), + nn.LeakyReLU(0.2, inplace=True), + nn.Linear(512, 256), + nn.LeakyReLU(0.2, inplace=True), + nn.Linear(256, 1), + ) + + def forward(self, img): + img_flat = img.view(img.size(0), -1) + return self.model(img_flat) + + +class GAN(LightningModule): + """ + >>> GAN(img_shape=(1, 8, 8)) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + GAN( + (generator): Generator( + (model): Sequential(...) + ) + (discriminator): Discriminator( + (model): Sequential(...) + ) + ) + """ + + def __init__( + self, + img_shape: tuple = (1, 28, 28), + lr: float = 0.0002, + b1: float = 0.5, + b2: float = 0.999, + latent_dim: int = 100, + ): + super().__init__() + self.save_hyperparameters() + self.automatic_optimization = False + + # networks + self.generator = Generator(latent_dim=self.hparams.latent_dim, img_shape=img_shape) + self.discriminator = Discriminator(img_shape=img_shape) + + self.validation_z = torch.randn(8, self.hparams.latent_dim) + + self.example_input_array = torch.zeros(2, self.hparams.latent_dim) + + def forward(self, z): + return self.generator(z) + + @staticmethod + def adversarial_loss(y_hat, y): + return F.binary_cross_entropy_with_logits(y_hat, y) + + def training_step(self, batch): + imgs, _ = batch + + opt_g, opt_d = self.optimizers() + + # sample noise + z = torch.randn(imgs.shape[0], self.hparams.latent_dim) + z = z.type_as(imgs) + + # Train generator + # ground truth result (ie: all fake) + # put on GPU because we created this tensor inside training_loop + valid = torch.ones(imgs.size(0), 1) + valid = valid.type_as(imgs) + + self.toggle_optimizer(opt_g) + # adversarial loss is binary cross-entropy + g_loss = self.adversarial_loss(self.discriminator(self(z)), valid) + opt_g.zero_grad() + self.manual_backward(g_loss) + opt_g.step() + self.untoggle_optimizer(opt_g) + + # Train discriminator + # Measure discriminator's ability to classify real from generated samples + # how well can it label as real? + valid = torch.ones(imgs.size(0), 1) + valid = valid.type_as(imgs) + + self.toggle_optimizer(opt_d) + real_loss = self.adversarial_loss(self.discriminator(imgs), valid) + + # how well can it label as fake? + fake = torch.zeros(imgs.size(0), 1) + fake = fake.type_as(imgs) + + fake_loss = self.adversarial_loss(self.discriminator(self(z).detach()), fake) + + # discriminator loss is the average of these + d_loss = (real_loss + fake_loss) / 2 + + opt_d.zero_grad() + self.manual_backward(d_loss) + opt_d.step() + self.untoggle_optimizer(opt_d) + + self.log_dict({"d_loss": d_loss, "g_loss": g_loss}) + + def configure_optimizers(self): + lr = self.hparams.lr + b1 = self.hparams.b1 + b2 = self.hparams.b2 + + opt_g = torch.optim.Adam(self.generator.parameters(), lr=lr, betas=(b1, b2)) + opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=lr, betas=(b1, b2)) + return opt_g, opt_d + + def on_train_epoch_end(self): + z = self.validation_z.type_as(self.generator.model[0].weight) + + # log sampled images` + sample_imgs = self(z) + grid = torchvision.utils.make_grid(sample_imgs) + for logger in self.loggers: + logger.experiment.add_image("generated_images", grid, self.current_epoch) + + +def main(args: Namespace) -> None: + # ------------------------ + # 1 INIT LIGHTNING MODEL + # ------------------------ + model = GAN(lr=args.lr, b1=args.b1, b2=args.b2, latent_dim=args.latent_dim) + + # ------------------------ + # 2 INIT TRAINER + # ------------------------ + # ! `MultiModelDDPStrategy` is critical for multi-gpu training + # ! Otherwise, it will not work with multiple models. + # ! There are two ways to run training codes with previous `DDPStrategy`; + # ! 1) activate `find_unused_parameters=True`, 2) change from self.manual_backward(loss) to loss.backward() + # ! Neither of them is desirable. + dm = MNISTDataModule() + trainer = Trainer( + accelerator="auto", + devices=[0, 1, 2, 3], + strategy=MultiModelDDPStrategy(), + max_epochs=100, + ) + + # ------------------------ + # 3 START TRAINING + # ------------------------ + trainer.fit(model, dm) + + +if __name__ == "__main__": + cli_lightning_logo() + parser = ArgumentParser() + + # Hyperparameters + parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate") + parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient") + parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of second order momentum of gradient") + parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space") + args = parser.parse_args() + + main(args) diff --git a/src/lightning/pytorch/strategies/ddp.py b/src/lightning/pytorch/strategies/ddp.py index fd3f66ef42471..aa758ba69c54d 100644 --- a/src/lightning/pytorch/strategies/ddp.py +++ b/src/lightning/pytorch/strategies/ddp.py @@ -214,7 +214,7 @@ def set_world_ranks(self) -> None: rank_zero_only.rank = utils_rank_zero_only.rank = self.global_rank def _register_ddp_hooks(self) -> None: - log.debug(f"{self.__class__.__name__}: registering ddp hooks") + log.debug(f"{self.__class__.__name__}: registering DDP hooks") # currently, DDP communication hooks only work with NCCL backend and SPSD (single process single device) mode # https://github.com/pytorch/pytorch/blob/v1.8.0/torch/nn/parallel/distributed.py#L1080-L1084 if self.root_device.type == "cuda": @@ -419,6 +419,39 @@ def teardown(self) -> None: super().teardown() +class MultiModelDDPStrategy(DDPStrategy): + @override + def _setup_model(self, model: Module) -> DistributedDataParallel: + device_ids = self.determine_ddp_device_ids() + log.debug(f"setting up DDP model with device ids: {device_ids}, kwargs: {self._ddp_kwargs}") + # https://pytorch.org/docs/stable/notes/cuda.html#id5 + ctx = torch.cuda.stream(torch.cuda.Stream()) if device_ids is not None else nullcontext() + with ctx: + for name, module in model.named_children(): + if isinstance(module, Module): + ddp_module = DistributedDataParallel(module, device_ids=device_ids, **self._ddp_kwargs) + setattr(model, name, ddp_module) + return model + + @override + def _register_ddp_hooks(self) -> None: + log.debug(f"{self.__class__.__name__}: registering DDP hooks") + # currently, DDP communication hooks only work with NCCL backend and SPSD (single process single device) mode + # https://github.com/pytorch/pytorch/blob/v1.8.0/torch/nn/parallel/distributed.py#L1080-L1084 + if self.root_device.type != "cuda": + return + assert isinstance(self.model, Module) + + for name, module in self.model.named_children(): + assert isinstance(module, DistributedDataParallel) + _register_ddp_comm_hook( + model=module, + ddp_comm_state=self._ddp_comm_state, + ddp_comm_hook=self._ddp_comm_hook, + ddp_comm_wrapper=self._ddp_comm_wrapper, + ) + + class _DDPForwardRedirection(_ForwardRedirection): @override def on_after_inner_forward(self, wrapper_module: Module, original_module: "pl.LightningModule") -> None: diff --git a/tests/tests_pytorch/strategies/test_multi_model_ddp.py b/tests/tests_pytorch/strategies/test_multi_model_ddp.py new file mode 100644 index 0000000000000..053f2a3312bf8 --- /dev/null +++ b/tests/tests_pytorch/strategies/test_multi_model_ddp.py @@ -0,0 +1,97 @@ +# Copyright The Lightning AI team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from unittest import mock +from unittest.mock import PropertyMock + +import torch +from torch import nn + +from lightning.pytorch.strategies.ddp import MultiModelDDPStrategy + + +def test_multi_model_ddp_setup_and_register_hooks(): + class Parent(nn.Module): + def __init__(self): + super().__init__() + self.gen = nn.Linear(1, 1) + self.dis = nn.Linear(1, 1) + + model = Parent() + original_children = [model.gen, model.dis] + + strategy = MultiModelDDPStrategy(parallel_devices=[torch.device("cpu")]) + + wrapped_modules = [] + wrapped_device_ids = [] + + class DummyDDP(nn.Module): + def __init__(self, module: nn.Module, device_ids=None, **kwargs): + super().__init__() + self.module = module + wrapped_modules.append(module) + wrapped_device_ids.append(device_ids) + + with mock.patch("lightning.pytorch.strategies.ddp.DistributedDataParallel", DummyDDP): + returned_model = strategy._setup_model(model) + assert returned_model is model + assert isinstance(model.gen, DummyDDP) + assert isinstance(model.dis, DummyDDP) + assert wrapped_modules == original_children + assert wrapped_device_ids == [None, None] + + strategy.model = model + with ( + mock.patch("lightning.pytorch.strategies.ddp._register_ddp_comm_hook") as register_hook, + mock.patch.object(MultiModelDDPStrategy, "root_device", new_callable=PropertyMock) as root_device, + ): + root_device.return_value = torch.device("cuda", 0) + strategy._register_ddp_hooks() + + assert register_hook.call_count == 2 + register_hook.assert_any_call( + model=model.gen, + ddp_comm_state=strategy._ddp_comm_state, + ddp_comm_hook=strategy._ddp_comm_hook, + ddp_comm_wrapper=strategy._ddp_comm_wrapper, + ) + register_hook.assert_any_call( + model=model.dis, + ddp_comm_state=strategy._ddp_comm_state, + ddp_comm_hook=strategy._ddp_comm_hook, + ddp_comm_wrapper=strategy._ddp_comm_wrapper, + ) + + +def test_multi_model_ddp_register_hooks_cpu_noop(): + class Parent(nn.Module): + def __init__(self) -> None: + super().__init__() + self.gen = nn.Linear(1, 1) + self.dis = nn.Linear(1, 1) + + model = Parent() + strategy = MultiModelDDPStrategy(parallel_devices=[torch.device("cpu")]) + + class DummyDDP(nn.Module): + def __init__(self, module: nn.Module, device_ids=None, **kwargs): + super().__init__() + self.module = module + + with mock.patch("lightning.pytorch.strategies.ddp.DistributedDataParallel", DummyDDP): + strategy.model = strategy._setup_model(model) + + with mock.patch("lightning.pytorch.strategies.ddp._register_ddp_comm_hook") as register_hook: + strategy._register_ddp_hooks() + + register_hook.assert_not_called()