Changes from 7 commits

Commits (22):
00727e8  add: `MultiModelDDPStrategy` and its execution codes (samsara-ku, Jun 25, 2025)
e6b061a  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jun 25, 2025)
aa9b027  refactor: extract block helper in GAN example (samsara-ku, Aug 5, 2025)
5503d3a  Merge pull request #1 from samsara-ku/codex/add-tests-for-multimodeld… (samsara-ku, Aug 5, 2025)
dc128b4  Merge branch 'master' into bugfix/gan-ddp-training (Borda, Aug 8, 2025)
1fb4027  with (Borda, Aug 8, 2025)
ec62397  Apply suggestions from code review (Borda, Aug 8, 2025)
ece7d38  formating (Borda, Aug 8, 2025)
8b1fe23  misc: resolve some review comments for product consistency (samsara-ku, Aug 11, 2025)
4b22284  misc: merge gan training example, add docstring of MultiModelDDPStrategy (samsara-ku, Aug 12, 2025)
97dabf8  misc: add docstring of MultiModelDDPStrategy (samsara-ku, Aug 12, 2025)
033e8e8  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Aug 12, 2025)
57e864a  update (Borda, Aug 13, 2025)
f157f59  Merge branch 'master' into bugfix/gan-ddp-training (SkafteNicki, Aug 13, 2025)
24872e7  long line (Borda, Aug 13, 2025)
3891102  Merge branch 'master' into bugfix/gan-ddp-training (SkafteNicki, Aug 14, 2025)
8121337  add: set base test case and __init__py for MultiModelDDPStrategy (samsara-ku, Aug 15, 2025)
c442fc3  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Aug 15, 2025)
ab9b2dd  Merge branch 'master' into bugfix/gan-ddp-training (Borda, Aug 18, 2025)
a58039e  Merge branch 'master' into bugfix/gan-ddp-training (Borda, Sep 2, 2025)
2ae8072  Merge branch 'master' into bugfix/gan-ddp-training (Borda, Sep 4, 2025)
77a81b4  Merge branch 'master' into bugfix/gan-ddp-training (Borda, Sep 25, 2025)
24 changes: 12 additions & 12 deletions examples/pytorch/domain_templates/generative_adversarial_net.py
@@ -36,6 +36,14 @@
     import torchvision


+def _block(in_feat: int, out_feat: int, normalize: bool = True) -> list:
+    layers = [nn.Linear(in_feat, out_feat)]
+    if normalize:
+        layers.append(nn.BatchNorm1d(out_feat, 0.8))
+    layers.append(nn.LeakyReLU(0.2, inplace=True))
+    return layers
+
+
 class Generator(nn.Module):
     """
     >>> Generator(img_shape=(1, 8, 8)) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
@@ -47,19 +55,11 @@ class Generator(nn.Module):
     def __init__(self, latent_dim: int = 100, img_shape: tuple = (1, 28, 28)):
         super().__init__()
         self.img_shape = img_shape
-
-        def block(in_feat, out_feat, normalize=True):
-            layers = [nn.Linear(in_feat, out_feat)]
-            if normalize:
-                layers.append(nn.BatchNorm1d(out_feat, 0.8))
-            layers.append(nn.LeakyReLU(0.2, inplace=True))
-            return layers
-
         self.model = nn.Sequential(
-            *block(latent_dim, 128, normalize=False),
-            *block(128, 256),
-            *block(256, 512),
-            *block(512, 1024),
+            *_block(latent_dim, 128, normalize=False),
+            *_block(128, 256),
+            *_block(256, 512),
+            *_block(512, 1024),
             nn.Linear(1024, int(math.prod(img_shape))),
             nn.Tanh(),
         )
263 changes: 263 additions & 0 deletions examples/pytorch/domain_templates/generative_adversarial_net_ddp.py
@@ -0,0 +1,263 @@
# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""To run this template just do: python generative_adversarial_net.py.
After a few epochs, launch TensorBoard to see the images being generated at every batch:
tensorboard --logdir default
"""

import math

# ! TESTING
import os
import sys
from argparse import ArgumentParser, Namespace

import torch
import torch.nn as nn
import torch.nn.functional as F

sys.path.append(os.path.join(os.getcwd(), "src"))
# ! TESTING

from lightning.pytorch import cli_lightning_logo
from lightning.pytorch.core import LightningModule
from lightning.pytorch.demos.mnist_datamodule import MNISTDataModule
from lightning.pytorch.strategies.ddp import MultiModelDDPStrategy
from lightning.pytorch.trainer import Trainer
from lightning.pytorch.utilities.imports import _TORCHVISION_AVAILABLE

if _TORCHVISION_AVAILABLE:
import torchvision


def _block(in_feat: int, out_feat: int, normalize: bool = True):
    layers = [nn.Linear(in_feat, out_feat)]
    if normalize:
        layers.append(nn.BatchNorm1d(out_feat, 0.8))
    layers.append(nn.LeakyReLU(0.2, inplace=True))
    return layers


class Generator(nn.Module):
"""
>>> Generator(img_shape=(1, 8, 8)) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
Generator(
(model): Sequential(...)
)
"""

def __init__(self, latent_dim: int = 100, img_shape: tuple = (1, 28, 28)):
super().__init__()
self.img_shape = img_shape

self.model = nn.Sequential(
*_block(latent_dim, 128, normalize=False),
*_block(128, 256),
*_block(256, 512),
*_block(512, 1024),
nn.Linear(1024, int(math.prod(img_shape))),
nn.Tanh(),
)

def forward(self, z):
img = self.model(z)
return img.view(img.size(0), *self.img_shape)


class Discriminator(nn.Module):
"""
>>> Discriminator(img_shape=(1, 28, 28)) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
Discriminator(
(model): Sequential(...)
)
"""

def __init__(self, img_shape):
super().__init__()

self.model = nn.Sequential(
nn.Linear(int(math.prod(img_shape)), 512),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(512, 256),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(256, 1),
)

def forward(self, img):
img_flat = img.view(img.size(0), -1)
return self.model(img_flat)


class GAN(LightningModule):
"""
>>> GAN(img_shape=(1, 8, 8)) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
GAN(
(generator): Generator(
(model): Sequential(...)
)
(discriminator): Discriminator(
(model): Sequential(...)
)
)
"""

def __init__(
self,
img_shape: tuple = (1, 28, 28),
lr: float = 0.0002,
b1: float = 0.5,
b2: float = 0.999,
latent_dim: int = 100,
):
super().__init__()
self.save_hyperparameters()
self.automatic_optimization = False

# networks
self.generator = Generator(latent_dim=self.hparams.latent_dim, img_shape=img_shape)
self.discriminator = Discriminator(img_shape=img_shape)

self.validation_z = torch.randn(8, self.hparams.latent_dim)

self.example_input_array = torch.zeros(2, self.hparams.latent_dim)

# ! TESTING
self.save_path = "pl_test_multi_gpu"
os.makedirs(self.save_path, exist_ok=True)

def forward(self, z):
return self.generator(z)

    @staticmethod
    def adversarial_loss(y_hat, y):
        return F.binary_cross_entropy_with_logits(y_hat, y)

    def training_step(self, batch):
        imgs, _ = batch

        opt_g, opt_d = self.optimizers()

        # sample noise
        z = torch.randn(imgs.shape[0], self.hparams.latent_dim)
        z = z.type_as(imgs)

        # Train generator
        # ground truth result (ie: all fake)
        # put on GPU because we created this tensor inside training_loop
        valid = torch.ones(imgs.size(0), 1)
        valid = valid.type_as(imgs)

        self.toggle_optimizer(opt_g)
        # adversarial loss is binary cross-entropy
        g_loss = self.adversarial_loss(self.discriminator(self(z)), valid)
        opt_g.zero_grad()
        self.manual_backward(g_loss)
        opt_g.step()
        self.untoggle_optimizer(opt_g)

        # Train discriminator
        # Measure discriminator's ability to classify real from generated samples
        # how well can it label as real?
        valid = torch.ones(imgs.size(0), 1)
        valid = valid.type_as(imgs)

        self.toggle_optimizer(opt_d)
        real_loss = self.adversarial_loss(self.discriminator(imgs), valid)

        # how well can it label as fake?
        fake = torch.zeros(imgs.size(0), 1)
        fake = fake.type_as(imgs)

        fake_loss = self.adversarial_loss(self.discriminator(self(z).detach()), fake)

        # discriminator loss is the average of these
        d_loss = (real_loss + fake_loss) / 2

        opt_d.zero_grad()
        self.manual_backward(d_loss)
        opt_d.step()
        self.untoggle_optimizer(opt_d)

        self.log_dict({"d_loss": d_loss, "g_loss": g_loss})

    def configure_optimizers(self):
        lr = self.hparams.lr
        b1 = self.hparams.b1
        b2 = self.hparams.b2

        opt_g = torch.optim.Adam(self.generator.parameters(), lr=lr, betas=(b1, b2))
        opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=lr, betas=(b1, b2))
        return opt_g, opt_d

    # ! TESTING
    def on_train_epoch_start(self):
        if self.trainer.is_global_zero:
            print("GEN: ", self.generator.module.model[0].bias[:10])
            print("DISC: ", self.discriminator.module.model[0].bias[:10])

    # ! TESTING
    def validation_step(self, batch, batch_idx):
        pass

    # ! TESTING
    @torch.no_grad()
    def on_validation_epoch_end(self):
        if not self.current_epoch % 5:
            return
        self.generator.eval(), self.discriminator.eval()

        z = self.validation_z.type_as(self.generator.module.model[0].weight)
        sample_imgs = self(z)

        if self.trainer.is_global_zero:
            grid = torchvision.utils.make_grid(sample_imgs)
            torchvision.utils.save_image(grid, os.path.join(self.save_path, f"epoch_{self.current_epoch}.png"))

        self.generator.train(), self.discriminator.train()


def main(args: Namespace) -> None:
    model = GAN(lr=args.lr, b1=args.b1, b2=args.b2, latent_dim=args.latent_dim)

    # ! `MultiModelDDPStrategy` is critical for multi-GPU training;
    # ! otherwise, training with multiple models will not work.
    # ! With the previous `DDPStrategy` there are two workarounds:
    # ! 1) enable `find_unused_parameters=True`, or 2) change `self.manual_backward(loss)` to `loss.backward()`.
    # ! Neither of them is desirable.
    dm = MNISTDataModule()
    trainer = Trainer(
        accelerator="auto",
        devices=[0, 1, 2, 3],
        strategy=MultiModelDDPStrategy(),
        max_epochs=100,
    )

    trainer.fit(model, dm)


if __name__ == "__main__":
    cli_lightning_logo()
    parser = ArgumentParser()

    # Hyperparameters
    parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
    parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
    parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of second order momentum of gradient")
    parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")
    args = parser.parse_args()

    main(args)
33 changes: 33 additions & 0 deletions src/lightning/pytorch/strategies/ddp.py
@@ -419,6 +419,39 @@ def teardown(self) -> None:
        super().teardown()


class MultiModelDDPStrategy(DDPStrategy):
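    """DDP strategy that wraps each direct child module of the LightningModule in its own
    :class:`~torch.nn.parallel.DistributedDataParallel` instance, instead of wrapping the whole
    module at once; intended for setups that hold several models (e.g. a GAN generator and
    discriminator) and train them with manual optimization."""
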
    @override
    def _setup_model(self, model: Module) -> Module:
        device_ids = self.determine_ddp_device_ids()
        log.debug(f"setting up DDP model with device ids: {device_ids}, kwargs: {self._ddp_kwargs}")
        # https://pytorch.org/docs/stable/notes/cuda.html#id5
        ctx = torch.cuda.stream(torch.cuda.Stream()) if device_ids is not None else nullcontext()
        with ctx:
            for name, module in model.named_children():
                if isinstance(module, Module):
                    ddp_module = DistributedDataParallel(module, device_ids=device_ids, **self._ddp_kwargs)
                    setattr(model, name, ddp_module)

        return model

    @override
    def _register_ddp_hooks(self) -> None:
        log.debug(f"{self.__class__.__name__}: registering ddp hooks")
        # currently, DDP communication hooks only work with NCCL backend and SPSD (single process single device) mode
        # https://github.com/pytorch/pytorch/blob/v1.8.0/torch/nn/parallel/distributed.py#L1080-L1084
        if self.root_device.type == "cuda":
            assert isinstance(self.model, Module)

            for name, module in self.model.named_children():
                assert isinstance(module, DistributedDataParallel)
                _register_ddp_comm_hook(
                    model=module,
                    ddp_comm_state=self._ddp_comm_state,
                    ddp_comm_hook=self._ddp_comm_hook,
                    ddp_comm_wrapper=self._ddp_comm_wrapper,
                )


class _DDPForwardRedirection(_ForwardRedirection):
    @override
    def on_after_inner_forward(self, wrapper_module: Module, original_module: "pl.LightningModule") -> None:
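
For orientation, a minimal usage sketch (not part of this diff) of how the new strategy is selected, assuming the `GAN` and `MNISTDataModule` classes from the example file above; with the plain `DDPStrategy`, the same setup would need `find_unused_parameters=True` or a raw `loss.backward()` instead of `self.manual_backward(loss)`, as noted in the example's comments.

# Illustrative sketch only -- assumes GAN and MNISTDataModule from the example above.
from lightning.pytorch import Trainer
from lightning.pytorch.strategies.ddp import MultiModelDDPStrategy

model = GAN()
dm = MNISTDataModule()

# MultiModelDDPStrategy._setup_model wraps each direct child of the LightningModule
# (here: model.generator and model.discriminator) in its own DistributedDataParallel
# instance, rather than wrapping the LightningModule as a whole.
trainer = Trainer(accelerator="gpu", devices=2, strategy=MultiModelDDPStrategy(), max_epochs=5)
trainer.fit(model, dm)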