Skip to content

Commit 8b82ce0

Browse files
Merge pull request #1630 from PyTorchLightning/hparams_logger
Allow metrics logged together with hparams
2 parents 26933a9 + ccd49cf commit 8b82ce0

File tree

3 files changed

+23
-2
lines changed

3 files changed

+23
-2
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
99
### Added
1010

1111
### Changed
12+
13+
- Allow logging of metrics togther with hparams ([#1630](https://github.com/PyTorchLightning/pytorch-lightning/pull/1630))
1214

1315
### Deprecated
1416

pytorch_lightning/loggers/tensorboard.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,8 @@ def experiment(self) -> SummaryWriter:
101101
return self._experiment
102102

103103
@rank_zero_only
104-
def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None:
104+
def log_hyperparams(self, params: Union[Dict[str, Any], Namespace],
105+
metrics: Optional[Dict[str, Any]] = None) -> None:
105106
params = self._convert_params(params)
106107
params = self._flatten_dict(params)
107108
sanitized_params = self._sanitize_params(params)
@@ -114,7 +115,9 @@ def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None:
114115
)
115116
else:
116117
from torch.utils.tensorboard.summary import hparams
117-
exp, ssi, sei = hparams(sanitized_params, {})
118+
if metrics is None:
119+
metrics = {}
120+
exp, ssi, sei = hparams(sanitized_params, metrics)
118121
writer = self.experiment._get_file_writer()
119122
writer.add_summary(exp)
120123
writer.add_summary(ssi)

tests/loggers/test_tensorboard.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,3 +77,19 @@ def test_tensorboard_log_hyperparams(tmpdir):
7777
"layer": torch.nn.BatchNorm1d
7878
}
7979
logger.log_hyperparams(hparams)
80+
81+
82+
def test_tensorboard_log_hparams_and_metrics(tmpdir):
83+
logger = TensorBoardLogger(tmpdir)
84+
hparams = {
85+
"float": 0.3,
86+
"int": 1,
87+
"string": "abc",
88+
"bool": True,
89+
"dict": {'a': {'b': 'c'}},
90+
"list": [1, 2, 3],
91+
"namespace": Namespace(foo=Namespace(bar='buzz')),
92+
"layer": torch.nn.BatchNorm1d
93+
}
94+
metrics = {'abc': torch.tensor([0.54])}
95+
logger.log_hyperparams(hparams, metrics)

0 commit comments

Comments
 (0)