Fix wrong behavior of DDPStrategy option with simple GAN training using DDP #20936
base: master
Changes from all commits: 00727e8, e6b061a, aa9b027, 5503d3a, dc128b4, 1fb4027, ec62397, ece7d38
generative_adversarial_net.py (new file)
@@ -0,0 +1,263 @@
# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""To run this template just do: python generative_adversarial_net.py.

After a few epochs, launch TensorBoard to see the images being generated at every batch:

    tensorboard --logdir default

"""

import math

# ! TESTING
import os
import sys
from argparse import ArgumentParser, Namespace

import torch
import torch.nn as nn
import torch.nn.functional as F

sys.path.append(os.path.join(os.getcwd(), "src"))
# ! TESTING

from lightning.pytorch import cli_lightning_logo
from lightning.pytorch.core import LightningModule
from lightning.pytorch.demos.mnist_datamodule import MNISTDataModule
from lightning.pytorch.strategies.ddp import MultiModelDDPStrategy
from lightning.pytorch.trainer import Trainer
from lightning.pytorch.utilities.imports import _TORCHVISION_AVAILABLE

if _TORCHVISION_AVAILABLE:
    import torchvision


def _block(in_feat: int, out_feat: int, normalize: bool = True):
    layers = [nn.Linear(in_feat, out_feat)]
    if normalize:
        layers.append(nn.BatchNorm1d(out_feat, 0.8))
    layers.append(nn.LeakyReLU(0.2, inplace=True))
    return layers


class Generator(nn.Module):
    """
    >>> Generator(img_shape=(1, 8, 8))  # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
    Generator(
      (model): Sequential(...)
    )
    """

    def __init__(self, latent_dim: int = 100, img_shape: tuple = (1, 28, 28)):
        super().__init__()
        self.img_shape = img_shape

        self.model = nn.Sequential(
            *_block(latent_dim, 128, normalize=False),
            *_block(128, 256),
            *_block(256, 512),
            *_block(512, 1024),
            nn.Linear(1024, int(math.prod(img_shape))),
            nn.Tanh(),
        )

    def forward(self, z):
        img = self.model(z)
        return img.view(img.size(0), *self.img_shape)


class Discriminator(nn.Module):
    """
    >>> Discriminator(img_shape=(1, 28, 28))  # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
    Discriminator(
      (model): Sequential(...)
    )
    """

    def __init__(self, img_shape):
        super().__init__()

        self.model = nn.Sequential(
            nn.Linear(int(math.prod(img_shape)), 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
        )

    def forward(self, img):
        img_flat = img.view(img.size(0), -1)
        return self.model(img_flat)


class GAN(LightningModule):
    """
    >>> GAN(img_shape=(1, 8, 8))  # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
    GAN(
      (generator): Generator(
        (model): Sequential(...)
      )
      (discriminator): Discriminator(
        (model): Sequential(...)
      )
    )
    """

    def __init__(
        self,
        img_shape: tuple = (1, 28, 28),
        lr: float = 0.0002,
        b1: float = 0.5,
        b2: float = 0.999,
        latent_dim: int = 100,
    ):
        super().__init__()
        self.save_hyperparameters()
        self.automatic_optimization = False

        # networks
        self.generator = Generator(latent_dim=self.hparams.latent_dim, img_shape=img_shape)
        self.discriminator = Discriminator(img_shape=img_shape)

        self.validation_z = torch.randn(8, self.hparams.latent_dim)

        self.example_input_array = torch.zeros(2, self.hparams.latent_dim)

        # ! TESTING
        self.save_path = "pl_test_multi_gpu"
        os.makedirs(self.save_path, exist_ok=True)

    def forward(self, z):
        return self.generator(z)

    @staticmethod
    def adversarial_loss(y_hat, y):
        return F.binary_cross_entropy_with_logits(y_hat, y)

    def training_step(self, batch):
        imgs, _ = batch

        opt_g, opt_d = self.optimizers()

        # sample noise
        z = torch.randn(imgs.shape[0], self.hparams.latent_dim)
        z = z.type_as(imgs)

        # Train generator
        # ground truth result (i.e., all fake)
        # put on GPU because we created this tensor inside the training loop
        valid = torch.ones(imgs.size(0), 1)
        valid = valid.type_as(imgs)

        self.toggle_optimizer(opt_g)
        # adversarial loss is binary cross-entropy
        g_loss = self.adversarial_loss(self.discriminator(self(z)), valid)
        opt_g.zero_grad()
        self.manual_backward(g_loss)
        opt_g.step()
        self.untoggle_optimizer(opt_g)

        # Train discriminator
        # Measure the discriminator's ability to classify real from generated samples
        # how well can it label real images as real?
        valid = torch.ones(imgs.size(0), 1)
        valid = valid.type_as(imgs)

        self.toggle_optimizer(opt_d)
        real_loss = self.adversarial_loss(self.discriminator(imgs), valid)

        # how well can it label generated images as fake?
        fake = torch.zeros(imgs.size(0), 1)
        fake = fake.type_as(imgs)

        fake_loss = self.adversarial_loss(self.discriminator(self(z).detach()), fake)

        # discriminator loss is the average of these
        d_loss = (real_loss + fake_loss) / 2

        opt_d.zero_grad()
        self.manual_backward(d_loss)
        opt_d.step()
        self.untoggle_optimizer(opt_d)

        self.log_dict({"d_loss": d_loss, "g_loss": g_loss})

    def configure_optimizers(self):
        lr = self.hparams.lr
        b1 = self.hparams.b1
        b2 = self.hparams.b2

        opt_g = torch.optim.Adam(self.generator.parameters(), lr=lr, betas=(b1, b2))
        opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=lr, betas=(b1, b2))
        return opt_g, opt_d

    # ! TESTING
    def on_train_epoch_start(self):
        if self.trainer.is_global_zero:
            print("GEN: ", self.generator.module.model[0].bias[:10])
            print("DISC: ", self.discriminator.module.model[0].bias[:10])

    # ! TESTING
    def validation_step(self, batch, batch_idx):
        pass
    # ! TESTING
    @torch.no_grad()
    def on_validation_epoch_end(self):
        # save generated samples only every 5 epochs
        if self.current_epoch % 5:
            return
        self.generator.eval()
        self.discriminator.eval()

        z = self.validation_z.type_as(self.generator.module.model[0].weight)
        sample_imgs = self(z)

        if self.trainer.is_global_zero:
            grid = torchvision.utils.make_grid(sample_imgs)
            torchvision.utils.save_image(grid, os.path.join(self.save_path, f"epoch_{self.current_epoch}.png"))

        self.generator.train()
        self.discriminator.train()


def main(args: Namespace) -> None:
    model = GAN(lr=args.lr, b1=args.b1, b2=args.b2, latent_dim=args.latent_dim)

    # ! `MultiModelDDPStrategy` is critical for multi-GPU training;
    # ! otherwise, this will not work with multiple models.
    # ! There are two ways to run this training code with the previous `DDPStrategy`:
    # ! 1) enable `find_unused_parameters=True`, or 2) change `self.manual_backward(loss)` to `loss.backward()`.
    # ! Neither of them is desirable (see the sketch after this file's diff).
    dm = MNISTDataModule()
    trainer = Trainer(
        accelerator="auto",
        devices=[0, 1, 2, 3],
        strategy=MultiModelDDPStrategy(),
        max_epochs=100,
    )

    trainer.fit(model, dm)

Review comment on lines +241 to +249: maybe i am reading it wrong, but is this the only core difference between this new …


if __name__ == "__main__":
    cli_lightning_logo()
    parser = ArgumentParser()

    # Hyperparameters
    parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
    parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
    parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of second order momentum of gradient")
    parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")
    args = parser.parse_args()

    main(args)
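
The comments in `main()` above refer to two workarounds that were needed with the plain `DDPStrategy`. For contrast, here is a minimal sketch of the first workaround, not part of this PR; it relies only on `DDPStrategy` forwarding extra keyword arguments to `DistributedDataParallel`, and reuses the trainer arguments from the example above:

# Sketch of workaround 1 with the plain DDPStrategy (for contrast; not in this PR).
# A single DDP wrapper spans the whole LightningModule, so a step that exercises
# only one sub-network (e.g. only the discriminator) leaves the other network's
# parameters unused; `find_unused_parameters=True` suppresses the resulting DDP
# reduction errors, at the cost of an extra autograd-graph traversal per backward.
from lightning.pytorch.strategies import DDPStrategy
from lightning.pytorch.trainer import Trainer

trainer = Trainer(
    accelerator="auto",
    devices=[0, 1, 2, 3],
    strategy=DDPStrategy(find_unused_parameters=True),  # workaround 1 (slower)
    max_epochs=100,
)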

src/lightning/pytorch/strategies/ddp.py
@@ -214,7 +214,7 @@ def set_world_ranks(self) -> None:
        rank_zero_only.rank = utils_rank_zero_only.rank = self.global_rank

    def _register_ddp_hooks(self) -> None:
-       log.debug(f"{self.__class__.__name__}: registering ddp hooks")
+       log.debug(f"{self.__class__.__name__}: registering DDP hooks")
        # currently, DDP communication hooks only work with NCCL backend and SPSD (single process single device) mode
        # https://github.com/pytorch/pytorch/blob/v1.8.0/torch/nn/parallel/distributed.py#L1080-L1084
        if self.root_device.type == "cuda":
@@ -419,6 +419,39 @@ def teardown(self) -> None:
        super().teardown()

class MultiModelDDPStrategy(DDPStrategy):

Review comment (conversation marked as resolved by samsara-ku): please add docstring to class. What is the purpose of the class and when to use it compared to the standard …
    @override
    def _setup_model(self, model: Module) -> Module:

Review comment: The typing is not very happy here as the parent class has the following footprint: def _setup_model(self, model: Module) -> DistributedDataParallel

        device_ids = self.determine_ddp_device_ids()
        log.debug(f"setting up DDP model with device ids: {device_ids}, kwargs: {self._ddp_kwargs}")
        # https://pytorch.org/docs/stable/notes/cuda.html#id5
        ctx = torch.cuda.stream(torch.cuda.Stream()) if device_ids is not None else nullcontext()
        with ctx:
            for name, module in model.named_children():
                if isinstance(module, Module):
                    ddp_module = DistributedDataParallel(module, device_ids=device_ids, **self._ddp_kwargs)
                    setattr(model, name, ddp_module)
        return model

    @override
    def _register_ddp_hooks(self) -> None:
        log.debug(f"{self.__class__.__name__}: registering DDP hooks")
        # currently, DDP communication hooks only work with NCCL backend and SPSD (single process single device) mode
        # https://github.com/pytorch/pytorch/blob/v1.8.0/torch/nn/parallel/distributed.py#L1080-L1084
        if self.root_device.type != "cuda":
            return
        assert isinstance(self.model, Module)

        for name, module in self.model.named_children():
            assert isinstance(module, DistributedDataParallel)
            _register_ddp_comm_hook(
                model=module,
                ddp_comm_state=self._ddp_comm_state,
                ddp_comm_hook=self._ddp_comm_hook,
                ddp_comm_wrapper=self._ddp_comm_wrapper,
            )


class _DDPForwardRedirection(_ForwardRedirection):
    @override
    def on_after_inner_forward(self, wrapper_module: Module, original_module: "pl.LightningModule") -> None:
Review comment: What does this comment mean?
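
To make the review questions about the difference from the standard `DDPStrategy` concrete, here is a condensed sketch rather than the library code: `setup_single` paraphrases what the stock `DDPStrategy._setup_model` does (per the review comment above, it returns a `DistributedDataParallel`), while `setup_per_child` mirrors the override added in this PR. The function names are illustrative, and both assume `torch.distributed` has already been initialized.

import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel

def setup_single(model: nn.Module, device_ids):
    # stock DDPStrategy (paraphrased): ONE reducer spans generator and
    # discriminator, so a manual-optimization step that runs only one of
    # them leaves "unused" parameters and trips DDP's reduction checks
    return DistributedDataParallel(model, device_ids=device_ids)

def setup_per_child(model: nn.Module, device_ids):
    # MultiModelDDPStrategy (this PR): one reducer per direct child, so each
    # optimizer step only synchronizes the sub-network that actually ran
    for name, child in model.named_children():
        setattr(model, name, DistributedDataParallel(child, device_ids=device_ids))
    return model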