This repository was archived by the owner on Aug 2, 2024. It is now read-only.

Commit e7e3b2a

Fix CUDA and DDP training (#267)

* Fix GPU training for new PyTorch
* Fix CUDA DDP

1 parent e76e0d9 · commit e7e3b2a

File tree: 7 files changed, +48 −31 lines changed

examples/components/CCFRAUD/traininsilo/conda.yaml

Lines changed: 2 additions & 2 deletions

@@ -2,12 +2,12 @@ name: ccfraud_train_conda_env
 channels:
   - defaults
   - pytorch
+  - nvidia
 dependencies:
   - python=3.8
   - pip=22.3.1
   - pytorch=1.13.1
-  - torchvision=0.14.1
-  - cudatoolkit=11.3
+  - pytorch-cuda=11.6
   - pip:
     - azureml-mlflow==1.48.0
     - pandas==1.5.2

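All four training environments get the same dependency change: the nvidia channel is added and cudatoolkit=11.3 is replaced by the pytorch-cuda=11.6 metapackage, which is how CUDA-enabled PyTorch 1.13 conda builds pull in their CUDA runtime. A quick sanity check of the resolved environment (illustrative snippet, not part of the repo):

import torch

# Confirms the environment resolved a CUDA-enabled PyTorch build.
# torch.version.cuda is the CUDA version the package was built against;
# torch.cuda.is_available() additionally requires a visible GPU and driver.
print("torch version:", torch.__version__)
print("built with CUDA:", torch.version.cuda)
print("CUDA available:", torch.cuda.is_available())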
examples/components/CCFRAUD/traininsilo/run.py

Lines changed: 6 additions & 5 deletions

@@ -569,13 +569,13 @@ def run(args):
     logger.info(f"Distributed process rank: {os.environ['RANK']}")
     logger.info(f"Distributed world size: {os.environ['WORLD_SIZE']}")
 
-    if int(os.environ["WORLD_SIZE"]) > 1 and torch.cuda.is_available():
+    if int(os.environ.get("WORLD_SIZE", "1")) > 1 and torch.cuda.is_available():
         dist.init_process_group(
             "nccl",
             rank=int(os.environ["RANK"]),
-            world_size=int(os.environ["WORLD_SIZE"]),
+            world_size=int(os.environ.get("WORLD_SIZE", "1")),
         )
-    elif int(os.environ["WORLD_SIZE"]) > 1:
+    elif int(os.environ.get("WORLD_SIZE", "1")) > 1:
         dist.init_process_group("gloo")
 
     trainer = CCFraudTrainer(
@@ -594,11 +594,12 @@ def run(args):
         experiment_name=args.metrics_prefix,
         iteration_name=args.iteration_name,
         device_id=int(os.environ["RANK"]),
-        distributed=int(os.environ["WORLD_SIZE"]) > 1 and torch.cuda.is_available(),
+        distributed=int(os.environ.get("WORLD_SIZE", "1")) > 1
+        and torch.cuda.is_available(),
     )
     trainer.execute(args.checkpoint)
 
-    if torch.cuda.is_available() or int(os.environ["WORLD_SIZE"]) > 1:
+    if int(os.environ.get("WORLD_SIZE", "1")) > 1:
         dist.destroy_process_group()

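The run(args) changes above encode a pattern that repeats in the NER and PNEUMONIA scripts below: the world size is read with os.environ.get and a default of "1" wherever it drives control flow; NCCL is used only when CUDA is actually available, with gloo as the CPU fallback; and destroy_process_group is now called only when more than one process is running (previously it was also called for single-process GPU runs, where no process group had been initialized). A minimal standalone sketch of that guard, assuming the usual RANK/WORLD_SIZE variables set by the launcher (the helper name is illustrative, not from the repo):

import os

import torch
import torch.distributed as dist


def init_distributed_if_needed() -> bool:
    """Illustrative helper: initialize torch.distributed only when it is needed.

    Returns True if a process group was created, so the caller knows to tear it down.
    """
    world_size = int(os.environ.get("WORLD_SIZE", "1"))  # default keeps single-process runs working
    if world_size <= 1:
        return False
    if torch.cuda.is_available():
        # NCCL backend for GPU training; RANK/WORLD_SIZE come from the launcher (e.g. torchrun).
        dist.init_process_group(
            "nccl",
            rank=int(os.environ["RANK"]),
            world_size=world_size,
        )
    else:
        # Multi-process training without GPUs falls back to the gloo backend.
        dist.init_process_group("gloo")
    return True


if __name__ == "__main__":
    created = init_distributed_if_needed()
    # ... training would run here ...
    if created:
        dist.destroy_process_group()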
examples/components/MNIST/traininsilo/conda.yaml

Lines changed: 2 additions & 1 deletion

@@ -2,12 +2,13 @@ name: mnist_conda_env
 channels:
   - defaults
   - pytorch
+  - nvidia
 dependencies:
   - python=3.8
   - pip=22.3.1
   - pytorch=1.13.1
   - torchvision=0.14.1
-  - cudatoolkit=11.3
+  - pytorch-cuda=11.6
   - pip:
     - azureml-mlflow==1.48.0
     - pandas==1.5.2

examples/components/NER/traininsilo/conda.yaml

Lines changed: 2 additions & 1 deletion

@@ -2,11 +2,12 @@ name: ner_train_conda_env
 channels:
   - defaults
   - pytorch
+  - nvidia
 dependencies:
   - python=3.8
   - pip=22.3.1
   - pytorch=1.13.1
-  - cudatoolkit=11.3
+  - pytorch-cuda=11.6
   - pip:
     - azureml-mlflow==1.48.0
     - pandas==1.5.2

examples/components/NER/traininsilo/run.py

Lines changed: 12 additions & 11 deletions

@@ -140,15 +140,15 @@ def __init__(
 
         if self._distributed:
             logger.info("Setting up distributed samplers.")
-            self.train_sampler_ = DistributedSampler(self.train_dataset_)
-            self.test_sampler_ = DistributedSampler(self.test_dataset_)
+            self.train_sampler_ = DistributedSampler(train_dataset)
+            self.test_sampler_ = DistributedSampler(test_dataset)
         else:
             self.train_sampler_ = None
             self.test_sampler_ = None
 
         self.train_loader_ = DataLoader(
             train_dataset,
-            shuffle=True,
+            shuffle=(not self._distributed),
             collate_fn=data_collator,
             batch_size=self._batch_size,
             sampler=self.train_sampler_,
@@ -157,7 +157,7 @@ def __init__(
             test_dataset,
             collate_fn=data_collator,
             batch_size=self._batch_size,
-            sampler=self.train_sampler_,
+            sampler=self.test_sampler_,
         )
 
         logger.info(f"Train loader steps: {len(self.train_loader_)}")
@@ -179,19 +179,19 @@ def __init__(
             trainable_params += p.numel()
         logger.info(f"Trainable parameters: {trainable_params}")
 
-        self.model_.train()
+        self.model_.to(self.device_)
         if self._distributed:
             self.model_ = DDP(
                 self.model_,
                 device_ids=[self._rank] if self._rank is not None else None,
                 output_device=self._rank,
             )
-        self.model_.to(self.device_)
         self.metric_ = evaluate.load("seqeval")
 
         # DP
         logger.info(f"DP: {dp}")
         if dp:
+            self.model_.train()
             if not ModuleValidator.is_valid(self.model_):
                 self.model_ = ModuleValidator.fix(self.model_)
 
@@ -625,13 +625,13 @@ def run(args):
     logger.info(f"Distributed process rank: {os.environ['RANK']}")
     logger.info(f"Distributed world size: {os.environ['WORLD_SIZE']}")
 
-    if int(os.environ["WORLD_SIZE"]) > 1 and torch.cuda.is_available():
+    if int(os.environ.get("WORLD_SIZE", "1")) > 1 and torch.cuda.is_available():
         dist.init_process_group(
             "nccl",
             rank=int(os.environ["RANK"]),
-            world_size=int(os.environ["WORLD_SIZE"]),
+            world_size=int(os.environ.get("WORLD_SIZE", "1")),
         )
-    elif int(os.environ["WORLD_SIZE"]) > 1:
+    elif int(os.environ.get("WORLD_SIZE", "1")) > 1:
         dist.init_process_group("gloo")
 
     trainer = NERTrainer(
@@ -651,11 +651,12 @@ def run(args):
         iteration_num=args.iteration_num,
         batch_size=args.batch_size,
         device_id=int(os.environ["RANK"]),
-        distributed=int(os.environ["WORLD_SIZE"]) > 1 and torch.cuda.is_available(),
+        distributed=int(os.environ.get("WORLD_SIZE", "1")) > 1
+        and torch.cuda.is_available(),
    )
     trainer.execute(args.checkpoint)
 
-    if torch.cuda.is_available() or int(os.environ["WORLD_SIZE"]) > 1:
+    if int(os.environ.get("WORLD_SIZE", "1")) > 1:
         dist.destroy_process_group()

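Beyond the same run(args) guard, the NER trainer fixes the DDP wiring in __init__: the samplers are built from the train_dataset/test_dataset arguments, the test loader gets its own sampler instead of reusing the train sampler, shuffling is disabled whenever a sampler is supplied (DataLoader rejects shuffle=True combined with a sampler), the model is moved to its device before being wrapped in DDP, and model.train() is deferred to the differential-privacy branch. A condensed sketch of the loader wiring under those rules (function and variable names are illustrative):

import torch
from torch.utils.data import DataLoader, DistributedSampler, TensorDataset


def build_loaders(train_dataset, test_dataset, batch_size, distributed):
    """Illustrative sketch of the sampler/loader wiring this commit fixes."""
    if distributed:
        # DistributedSampler shards (and shuffles) the data per rank; each split
        # gets its own sampler rather than reusing the train sampler for the test set.
        train_sampler = DistributedSampler(train_dataset)
        test_sampler = DistributedSampler(test_dataset)
    else:
        train_sampler = None
        test_sampler = None

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        # DataLoader raises a ValueError if shuffle=True is combined with a sampler,
        # so shuffling is only requested in the non-distributed case.
        shuffle=(train_sampler is None),
        sampler=train_sampler,
    )
    test_loader = DataLoader(test_dataset, batch_size=batch_size, sampler=test_sampler)
    return train_loader, test_loader


if __name__ == "__main__":
    xs, ys = torch.randn(64, 8), torch.zeros(64, dtype=torch.long)
    train_loader, test_loader = build_loaders(
        TensorDataset(xs, ys), TensorDataset(xs, ys), batch_size=16, distributed=False
    )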
examples/components/PNEUMONIA/traininsilo/conda.yaml

Lines changed: 2 additions & 1 deletion

@@ -2,12 +2,13 @@ name: pneumonia_train_conda_env
 channels:
   - defaults
   - pytorch
+  - nvidia
 dependencies:
   - python=3.8
   - pip=22.3.1
   - pytorch=1.13.1
   - torchvision=0.13.1
-  - cudatoolkit=11.3
+  - pytorch-cuda=11.6
   - pip:
     - azureml-mlflow==1.48.0
     - opacus==1.3.0

examples/components/PNEUMONIA/traininsilo/run.py

Lines changed: 22 additions & 10 deletions

@@ -94,8 +94,13 @@ def __init__(
 
         # Training setup
         self.model_ = PneumoniaNetwork()
-        self.device_ = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
         self.model_.to(self.device_)
+        if self._distributed:
+            self.model_ = DDP(
+                self.model_,
+                device_ids=[self._rank] if self._rank is not None else None,
+                output_device=self._rank,
+            )
         self.loss_ = nn.CrossEntropyLoss()
 
         # Data setup
@@ -125,7 +130,7 @@ def __init__(
         self.train_loader_ = DataLoader(
             dataset=self.train_dataset_,
             batch_size=32,
-            shuffle=True,
+            shuffle=(not self._distributed),
             drop_last=True,
             sampler=self.train_sampler_,
         )
@@ -245,9 +250,15 @@ def local_train(self, checkpoint):
            checkpoint: Previous model checkpoint from where training has to be started.
         """
         if checkpoint:
-            self.model_.load_state_dict(
-                torch.load(checkpoint + "/model.pt", map_location=self.device_)
-            )
+            if self._distributed:
+                # DDP comes with "module." prefix: https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html
+                self.model_.module.load_state_dict(
+                    torch.load(checkpoint + "/model.pt", map_location=self.device_)
+                )
+            else:
+                self.model_.load_state_dict(
+                    torch.load(checkpoint + "/model.pt", map_location=self.device_)
+                )
 
         with mlflow.start_run() as mlflow_run:
             # get Mlflow client and root run id
@@ -449,13 +460,13 @@ def run(args):
     logger.info(f"Distributed process rank: {os.environ['RANK']}")
     logger.info(f"Distributed world size: {os.environ['WORLD_SIZE']}")
 
-    if int(os.environ["WORLD_SIZE"]) > 1 and torch.cuda.is_available():
+    if int(os.environ.get("WORLD_SIZE", "1")) > 1 and torch.cuda.is_available():
         dist.init_process_group(
             "nccl",
             rank=int(os.environ["RANK"]),
-            world_size=int(os.environ["WORLD_SIZE"]),
+            world_size=int(os.environ.get("WORLD_SIZE", "1")),
         )
-    elif int(os.environ["WORLD_SIZE"]) > 1:
+    elif int(os.environ.get("WORLD_SIZE", "1")) > 1:
         dist.init_process_group("gloo")
 
     trainer = PTLearner(
@@ -471,12 +482,13 @@ def run(args):
         iteration_num=args.iteration_num,
         model_path=args.model + "/model.pt",
         device_id=int(os.environ["RANK"]),
-        distributed=int(os.environ["WORLD_SIZE"]) > 1 and torch.cuda.is_available(),
+        distributed=int(os.environ.get("WORLD_SIZE", "1")) > 1
+        and torch.cuda.is_available(),
     )
 
     trainer.execute(args.checkpoint)
 
-    if torch.cuda.is_available() or int(os.environ["WORLD_SIZE"]) > 1:
+    if int(os.environ.get("WORLD_SIZE", "1")) > 1:
         dist.destroy_process_group()

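The PNEUMONIA trainer now wraps the model in DDP when running distributed and loads checkpoints through self.model_.module, because DDP registers the wrapped network's parameters under a "module." prefix. The explicit if/else in the diff can also be written generically; a small sketch under that assumption (load_checkpoint is an illustrative name, not from the repo):

import torch
from torch import nn
from torch.nn.parallel import DistributedDataParallel as DDP


def load_checkpoint(model: nn.Module, path: str, device: torch.device) -> None:
    """Illustrative helper: load a plain state dict into a bare or DDP-wrapped model."""
    state = torch.load(path, map_location=device)
    # DDP keeps the original network as .module, so un-prefixed checkpoints are
    # loaded into that inner module; bare models load the state dict directly.
    target = model.module if isinstance(model, DDP) else model
    target.load_state_dict(state)

Saving through model.module (or stripping the prefix on load) keeps checkpoints interchangeable between single-process and DDP runs.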