Skip to content

Commit 4c4d0e6

Browse files
Merge pull request #1635 from PyTorchLightning/pkl
Fixes CPU DDP breaking change and DDP change
2 parents 8b82ce0 + b993a3e commit 4c4d0e6

File tree

5 files changed

+55
-38
lines changed

5 files changed

+55
-38
lines changed

pytorch_lightning/trainer/data_loading.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,7 @@ def auto_add_sampler(self, dataloader: DataLoader, train: bool) -> DataLoader:
133133
world_size = {
134134
'ddp': self.num_nodes * self.num_processes,
135135
'ddp2': self.num_nodes,
136+
'ddp_cpu': self.num_processes * self.num_nodes
136137
}
137138
sampler = DistributedSampler(
138139
dataloader.dataset,

pytorch_lightning/trainer/trainer.py

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -614,17 +614,8 @@ def allowed_type(x):
614614
return bool(parsing.strtobool(x))
615615

616616
if arg == 'gpus':
617-
def allowed_type(x):
618-
if ',' in x:
619-
return str(x)
620-
else:
621-
return int(x)
622-
623-
def arg_default(x):
624-
if ',' in x:
625-
return str(x)
626-
else:
627-
return int(x)
617+
allowed_type = Trainer.allowed_type
618+
arg_default = Trainer.arg_default
628619

629620
parser.add_argument(
630621
f'--{arg}',
@@ -637,6 +628,18 @@ def arg_default(x):
637628

638629
return parser
639630

631+
def allowed_type(x):
    """Convert the raw string given for the ``gpus`` command-line argument.

    A value that contains a comma is returned as a string; any other
    value is converted with ``int``.

    :param x: the raw argument string
    :return: ``str`` when ``x`` contains a comma, otherwise ``int``
    """
    return int(x) if ',' not in x else str(x)
636+
637+
def arg_default(x):
    """Convert a ``gpus`` value the same way ``allowed_type`` does.

    :param x: the raw value as a string
    :return: ``x`` as ``str`` when it contains a comma, otherwise ``int(x)``
    """
    has_comma = ',' in x
    if not has_comma:
        return int(x)
    return str(x)
642+
640643
@classmethod
641644
def from_argparse_args(cls, args, **kwargs):
642645

@@ -711,6 +714,10 @@ def fit(
711714
model.logger = self.logger
712715
self.copy_trainer_model_properties(model)
713716

717+
# clean hparams
718+
if hasattr(model, 'hparams'):
719+
parsing.clean_namespace(model.hparams)
720+
714721
# set up the passed in dataloaders (if needed)
715722
self.__attach_dataloaders(model, train_dataloader, val_dataloaders)
716723

pytorch_lightning/trainer/training_io.py

Lines changed: 2 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@
101101
LightningDistributedDataParallel,
102102
LightningDataParallel,
103103
)
104-
from pytorch_lightning.utilities import rank_zero_warn
104+
from pytorch_lightning.utilities import rank_zero_warn, parsing
105105

106106
try:
107107
import torch_xla
@@ -325,7 +325,7 @@ def dump_checkpoint(self):
325325
checkpoint['native_amp_scaling_state'] = self.scaler.state_dict()
326326

327327
if hasattr(model, "hparams"):
328-
self.__clean_namespace(model.hparams)
328+
parsing.clean_namespace(model.hparams)
329329
is_namespace = isinstance(model.hparams, Namespace)
330330
checkpoint['hparams'] = vars(model.hparams) if is_namespace else model.hparams
331331
checkpoint['hparams_type'] = 'namespace' if is_namespace else 'dict'
@@ -339,31 +339,6 @@ def dump_checkpoint(self):
339339

340340
return checkpoint
341341

342-
def __clean_namespace(self, hparams):
343-
"""
344-
Removes all functions from hparams so we can pickle
345-
:param hparams:
346-
:return:
347-
"""
348-
349-
if isinstance(hparams, Namespace):
350-
del_attrs = []
351-
for k in hparams.__dict__:
352-
if callable(getattr(hparams, k)):
353-
del_attrs.append(k)
354-
355-
for k in del_attrs:
356-
delattr(hparams, k)
357-
358-
elif isinstance(hparams, dict):
359-
del_attrs = []
360-
for k, v in hparams.items():
361-
if callable(v):
362-
del_attrs.append(k)
363-
364-
for k in del_attrs:
365-
del hparams[k]
366-
367342
# --------------------
368343
# HPC IO
369344
# --------------------

pytorch_lightning/utilities/parsing.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
from argparse import Namespace
2+
3+
14
def strtobool(val):
25
"""Convert a string representation of truth to true (1) or false (0).
36
Copied from the python implementation distutils.utils.strtobool
@@ -18,3 +21,29 @@ def strtobool(val):
1821
return 0
1922
else:
2023
raise ValueError(f'invalid truth value {val}')
24+
25+
26+
def clean_namespace(hparams):
    """
    Removes all callable entries from hparams, in place, so it can be pickled.

    :param hparams: an ``argparse.Namespace`` or a ``dict``; an object of
        any other type is left untouched
    :return: None, ``hparams`` itself is modified
    """
    if isinstance(hparams, Namespace):
        # Snapshot the names first: deleting attributes while iterating
        # over ``__dict__`` would change its size mid-iteration.
        callable_attrs = [k for k in hparams.__dict__ if callable(getattr(hparams, k))]
        for k in callable_attrs:
            delattr(hparams, k)

    elif isinstance(hparams, dict):
        # Same snapshot-then-delete approach for plain dictionaries.
        callable_keys = [k for k, v in hparams.items() if callable(v)]
        for k in callable_keys:
            del hparams[k]

tests/trainer/test_trainer_cli.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,11 @@ def test_add_argparse_args_redefined(cli_args):
4747
assert depr_name not in args
4848

4949
trainer = Trainer.from_argparse_args(args=args)
50+
51+
# make sure trainer can be pickled
52+
import pickle
53+
pickle.dumps(trainer)
54+
5055
assert isinstance(trainer, Trainer)
5156

5257

0 commit comments

Comments
 (0)