Skip to content

Commit 45d05ff

Browse files
Fixes #4141 (#4169)
* fix val epoch agg * fix val agg metrics * fix val agg metrics * fix val agg metrics
1 parent f064682 commit 45d05ff

File tree

6 files changed

+117
-10
lines changed

6 files changed

+117
-10
lines changed

pytorch_lightning/core/step_result.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,9 @@ def get_epoch_log_metrics(self) -> dict:
276276
if k == '_internal':
277277
continue
278278

279+
if options['forked']:
280+
continue
281+
279282
if options['logger'] and options['on_epoch']:
280283
if isinstance(self[k], Metric):
281284
result[k] = self[k].compute()
@@ -299,6 +302,9 @@ def get_epoch_pbar_metrics(self):
299302
if k == '_internal':
300303
continue
301304

305+
if options['forked']:
306+
continue
307+
302308
if options['prog_bar'] and options['on_epoch']:
303309
if isinstance(self[k], Metric):
304310
result[k] = self[k].compute()
@@ -311,6 +317,22 @@ def get_epoch_pbar_metrics(self):
311317

312318
return result
313319

320+
def get_forked_metrics(self):
    """
    Gets the forked metrics, i.e. every entry whose ``meta`` options have ``forked`` set.

    The ``'_internal'`` meta entry is bookkeeping, not a metric, and is always skipped.

    Returns:
        dict mapping each forked metric name to the value stored under that name
    """
    return {
        name: self[name]
        for name, options in self['meta'].items()
        # check the name first so the '_internal' options are never indexed
        if name != '_internal' and options['forked']
    }
335+
314336
def get_batch_pbar_metrics(self, include_forked_originals=True):
315337
"""
316338
Gets the metrics to log at the end of the batch step
@@ -443,6 +465,11 @@ def reduce_on_epoch_end(cls, outputs):
443465
if k == '_internal' or isinstance(result[k], Metric):
444466
continue
445467

468+
# for forked metrics don't reduce, just take the last val
469+
if option['forked']:
470+
result[k] = choose_last(result[k])
471+
continue
472+
446473
if option['on_epoch']:
447474
fx = option['reduce_fx']
448475
if fx == torch.mean:
@@ -531,6 +558,14 @@ def rename_keys(self, map_dict: dict):
531558
del meta[source]
532559

533560

561+
def choose_last(x):
    """
    Picks the last element of a collected value.

    Args:
        x: a ``torch.Tensor`` or ``list`` (indexed with ``[-1]``), or a ``dict``
            whose values are tensors, lists or nested dicts of those.

    Returns:
        ``x[-1]`` for a tensor or list. For a dict, the same dict object with every
        value replaced by its last element. ``None`` for any other type.
    """
    if isinstance(x, (torch.Tensor, list)):
        return x[-1]
    if isinstance(x, dict):
        # Only existing keys are reassigned, so updating while iterating is safe.
        for k, v in x.items():
            x[k] = choose_last(v) if isinstance(v, dict) else v[-1]
        # The caller stores the return value, so the dict has to be handed back;
        # without this the dict case evaluated to None.
        return x
    return None
567+
568+
534569
def recursive_gather(outputs: Sequence[dict], result: Optional[MutableMapping] = None) -> Optional[MutableMapping]:
535570
for out in outputs:
536571
if 'meta' in out:

pytorch_lightning/trainer/connectors/logger_connector.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -189,12 +189,16 @@ def _log_on_evaluation_epoch_end_metrics(self, epoch_logs):
189189
self.callback_metrics.update(logger_metrics)
190190
self.callback_metrics.update(pbar_metrics)
191191

192+
# forked metrics were dropped, enable them for callbacks
193+
forked_metrics = reduced_epoch_metrics.get_forked_metrics()
194+
self.callback_metrics.update(forked_metrics)
195+
192196
# track the final results for the dataloader
193197
self.eval_loop_results.append(deepcopy(self.callback_metrics))
194198

195199
# actually log
196-
if len(epoch_logger_metrics) > 0:
197-
metrics_to_log.append(epoch_logger_metrics)
200+
if len(logger_metrics) > 0:
201+
metrics_to_log.append(logger_metrics)
198202

199203
# log all the metrics as a single dict
200204
metrics_to_log = dict(ChainMap(*metrics_to_log))

pytorch_lightning/trainer/evaluation_loop.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,9 @@ def log_epoch_metrics(self, deprecated_eval_results, epoch_logs, test_mode):
218218
def __run_eval_epoch_end(self, num_dataloaders, using_eval_result):
219219
model = self.trainer.get_model()
220220

221+
# reset results
222+
model._results = Result()
223+
221224
# with a single dataloader don't pass an array
222225
outputs = self.outputs
223226
eval_results = outputs

tests/core/test_metric_result_integration.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,6 @@ def _ddp_test_fn(rank, worldsize):
8383

8484
epoch_expected = {
8585
"b": cumulative_sum * worldsize,
86-
"a": cumulative_sum * worldsize,
8786
"a_epoch": cumulative_sum * worldsize
8887
}
8988

@@ -136,7 +135,7 @@ def test_result_metric_integration():
136135
assert metric_b.x == metric_b._defaults['x']
137136
assert metric_c.x == metric_c._defaults['x']
138137

139-
epoch_expected = {"b": cumulative_sum, "a": cumulative_sum, "a_epoch": cumulative_sum}
138+
epoch_expected = {"b": cumulative_sum, "a_epoch": cumulative_sum}
140139

141140
assert set(epoch_log.keys()) == set(epoch_expected.keys())
142141
for k in epoch_expected.keys():

tests/trainer/logging/test_eval_loop_logging_1_0.py

Lines changed: 70 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
Tests to ensure that the training loop works with a dict (1.0)
1616
"""
1717
from pytorch_lightning import Trainer
18-
from pytorch_lightning import callbacks
18+
from pytorch_lightning import callbacks, seed_everything
1919
from tests.base.deterministic_model import DeterministicModel
2020
from tests.base import SimpleModule, BoringModel
2121
import os
@@ -68,7 +68,6 @@ def backward(self, loss, optimizer, optimizer_idx):
6868
'a2',
6969
'a_step',
7070
'a_epoch',
71-
'b',
7271
'b_step/epoch_0',
7372
'b_step/epoch_1',
7473
'b_epoch',
@@ -142,12 +141,10 @@ def backward(self, loss, optimizer, optimizer_idx):
142141
'b_step',
143142
'b_epoch',
144143
'c',
145-
'd',
146144
'd_step/epoch_0',
147145
'd_step/epoch_1',
148146
'd_epoch',
149147
'e',
150-
'f',
151148
'f_step/epoch_0',
152149
'f_step/epoch_1',
153150
'f_epoch',
@@ -247,6 +244,75 @@ def validation_step(self, batch, batch_idx):
247244
assert logged_metrics == expected_logged_metrics
248245

249246

247+
def test_eval_logging_auto_reduce(tmpdir):
    """
    Tests that a metric logged in validation_step with both on_step and on_epoch
    is logged per step, reduced to its mean at epoch end, and available to callbacks
    """
    seed_everything(1234)

    os.environ['PL_DEV_DEBUG'] = '1'

    class TestModel(BoringModel):
        def on_pretrain_routine_end(self) -> None:
            self.seen_vals = []
            self.manual_epoch_end_mean = None

        def on_validation_epoch_start(self) -> None:
            # start every validation epoch with a fresh record of the step losses
            self.seen_vals = []

        def validation_step(self, batch, batch_idx):
            output = self.layer(batch)
            loss = self.loss(batch, output)
            self.seen_vals.append(loss)
            self.log('val_loss', loss, on_epoch=True, on_step=True, prog_bar=True)
            return {"x": loss}

        def validation_epoch_end(self, outputs) -> None:
            # the step outputs handed in must be exactly the values seen in validation_step
            for passed_in, manually_tracked in zip(outputs, self.seen_vals):
                assert passed_in['x'] == manually_tracked
            self.manual_epoch_end_mean = torch.stack([x['x'] for x in outputs]).mean()

    model = TestModel()

    trainer = Trainer(
        default_root_dir=tmpdir,
        limit_train_batches=3,
        limit_val_batches=3,
        max_epochs=1,
        log_every_n_steps=1,
        weights_summary=None,
        checkpoint_callback=callbacks.ModelCheckpoint('val_loss')
    )
    trainer.fit(model)

    # make sure all the metrics are available for callbacks
    manual_mean = model.manual_epoch_end_mean
    callback_metrics = set(trainer.callback_metrics)
    assert callback_metrics == {'debug_epoch', 'val_loss', 'val_loss_epoch'}

    # make sure values are correct
    assert trainer.logged_metrics['val_loss_epoch'] == manual_mean
    assert trainer.callback_metrics['val_loss'] == trainer.logged_metrics['val_loss_step/epoch_0']

    # make sure correct values were logged
    logged_val = trainer.dev_debugger.logged_metrics

    # sanity check
    assert logged_val[0]['global_step'] == 0
    assert logged_val[1]['global_step'] == 0

    # 3 val batches, logged right after the two sanity-check entries
    for batch_idx in range(3):
        assert logged_val[2 + batch_idx]['val_loss_step/epoch_0'] == model.seen_vals[batch_idx]

    # epoch mean
    assert logged_val[5]['val_loss_epoch'] == manual_mean

    # only those logged
    assert len(logged_val) == 6
314+
315+
250316
def test_monitor_val_epoch_end(tmpdir):
251317
epoch_min_loss_override = 0
252318
model = SimpleModule()

tests/trainer/logging/test_train_loop_logging_1_0.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -395,8 +395,8 @@ def val_dataloader(self):
395395

396396
generated = set(trainer.logger_connector.logged_metrics)
397397
expected = {
398-
'a_epoch', 'a',
399-
'n', 'n_step/epoch_0', 'n_epoch',
398+
'a_epoch',
399+
'n_step/epoch_0', 'n_epoch',
400400
'epoch'
401401
}
402402

0 commit comments

Comments
 (0)