
Commit f064682 (1 parent: 4290c9e)

save initial arguments (#4163)

* save initial arguments
* typing
* chlog
* .

File tree: 4 files changed, +42 −3 lines

- CHANGELOG.md
- pytorch_lightning/core/lightning.py
- pytorch_lightning/trainer/training_loop.py
- tests/models/test_hparams.py

CHANGELOG.md (2 additions, 0 deletions)

@@ -21,6 +21,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Fixed
 
+- Fixed `hparams` saving - save the state when `save_hyperparameters()` is called [in `__init__`] ([#4163](https://github.com/PyTorchLightning/pytorch-lightning/pull/4163))
+
 
 
 ## [1.0.1] - 2020-10-14

pytorch_lightning/core/lightning.py (13 additions, 1 deletion)

@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import collections
+import copy
 import inspect
 import os
 import re
@@ -1448,9 +1449,11 @@ def save_hyperparameters(self, *args, frame=None) -> None:
         init_args = get_init_args(frame)
         assert init_args, "failed to inspect the self init"
         if not args:
+            # take all arguments
             hp = init_args
             self._hparams_name = "kwargs" if hp else None
         else:
+            # take only the arguments listed in `save_hparams`
            isx_non_str = [i for i, arg in enumerate(args) if not isinstance(arg, str)]
            if len(isx_non_str) == 1:
                hp = args[isx_non_str[0]]
@@ -1463,6 +1466,8 @@ def save_hyperparameters(self, *args, frame=None) -> None:
         # `hparams` are expected here
         if hp:
             self._set_hparams(hp)
+            # make a deep copy so that later runtime changes are not reflected
+            self._hparams_initial = copy.deepcopy(self._hparams)
 
     def _set_hparams(self, hp: Union[dict, Namespace, str]) -> None:
         if isinstance(hp, Namespace):
@@ -1594,11 +1599,18 @@ def to_torchscript(
         return torchscript_module
 
     @property
-    def hparams(self) -> Union[AttributeDict, str]:
+    def hparams(self) -> Union[AttributeDict, dict, Namespace]:
         if not hasattr(self, "_hparams"):
             self._hparams = AttributeDict()
         return self._hparams
 
+    @property
+    def hparams_initial(self) -> AttributeDict:
+        if not hasattr(self, "_hparams_initial"):
+            self._hparams_initial = AttributeDict()
+        # prevent any change
+        return copy.deepcopy(self._hparams_initial)
+
     @hparams.setter
     def hparams(self, hp: Union[dict, Namespace, Any]):
         hparams_assignment_name = self.__get_hparams_assignment_variable()
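Taken together, these changes mean `save_hyperparameters()` snapshots the constructor arguments at call time, and the new `hparams_initial` property returns a defensive deep copy of that snapshot. A minimal sketch of the resulting behavior (the `DemoModel` class and its `lr` argument are illustrative, not part of this commit):

```python
import pytorch_lightning as pl


class DemoModel(pl.LightningModule):
    def __init__(self, lr=0.001):
        super().__init__()
        # records `lr` in `hparams` and deep-copies it into `hparams_initial`
        self.save_hyperparameters()


model = DemoModel(lr=0.001)
model.hparams.lr = 0.1  # runtime mutation, e.g. after a learning-rate sweep

assert model.hparams.lr == 0.1            # reflects the mutation
assert model.hparams_initial.lr == 0.001  # still the __init__-time value
```

Because `hparams_initial` returns a copy, mutating the returned object does not touch the stored snapshot either.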

pytorch_lightning/trainer/training_loop.py (1 addition, 1 deletion)

@@ -129,7 +129,7 @@ def setup_training(self, model: LightningModule):
         # log hyper-parameters
         if self.trainer.logger is not None:
             # save exp to get started (this is where the first experiment logs are written)
-            self.trainer.logger.log_hyperparams(ref_model.hparams)
+            self.trainer.logger.log_hyperparams(ref_model.hparams_initial)
             self.trainer.logger.log_graph(ref_model)
             self.trainer.logger.save()
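On the trainer side, the logger now receives this initial snapshot instead of the live `hparams`, so the logged values match what was passed to `__init__` even if the model mutates `hparams` before training starts. A hedged end-to-end sketch (the model, data, and trainer flags below are illustrative; only the `hparams.yaml` behavior comes from this commit):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl


class DemoModel(pl.LightningModule):
    def __init__(self, lr=0.001):
        super().__init__()
        self.save_hyperparameters()  # snapshot taken here
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.mse_loss(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=self.hparams.lr)


model = DemoModel()
model.hparams.lr = 0.1  # mutated after __init__

data = DataLoader(TensorDataset(torch.randn(8, 32), torch.randn(8, 2)), batch_size=4)
trainer = pl.Trainer(max_epochs=1, limit_train_batches=2)
trainer.fit(model, data)
# The default logger's hparams.yaml records lr: 0.001, the value captured
# by save_hyperparameters(), not the mutated 0.1.
```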

tests/models/test_hparams.py (26 additions, 1 deletion)

@@ -25,7 +25,7 @@
 from pytorch_lightning import Trainer, LightningModule
 from pytorch_lightning.core.saving import save_hparams_to_yaml, load_hparams_from_yaml
 from pytorch_lightning.utilities import AttributeDict, is_picklable
-from tests.base import EvalModelTemplate, TrialMNIST
+from tests.base import EvalModelTemplate, TrialMNIST, BoringModel
 
 
 class SaveHparamsModel(EvalModelTemplate):
@@ -554,3 +554,28 @@ def test_args(tmpdir):
     with pytest.raises(TypeError, match="__init__\(\) got an unexpected keyword argument 'test'"):
         SubClassVarArgs.load_from_checkpoint(raw_checkpoint_path)
 
+
+class RuntimeParamChangeModel(BoringModel):
+    def __init__(self, running_arg):
+        super().__init__()
+        self.save_hyperparameters()
+
+
+def test_init_arg_with_runtime_change(tmpdir):
+    model = RuntimeParamChangeModel(123)
+    assert model.hparams.running_arg == 123
+    model.hparams.running_arg = -1
+    assert model.hparams.running_arg == -1
+
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        limit_train_batches=2,
+        limit_val_batches=2,
+        limit_test_batches=2,
+        max_epochs=1,
+    )
+    trainer.fit(model)
+
+    path_yaml = os.path.join(trainer.logger.log_dir, trainer.logger.NAME_HPARAMS_FILE)
+    hparams = load_hparams_from_yaml(path_yaml)
+    assert hparams.get('running_arg') == 123
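The final assertion is the crux of the regression test: although `hparams.running_arg` is mutated to -1 before `fit()`, the `hparams.yaml` written by the logger still contains 123, the value captured by `save_hyperparameters()` in `__init__`.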
