1919import torch
2020import pytest
2121
22- from pytorch_lightning import Trainer
22+ from pytorch_lightning import Trainer , callbacks
2323from tests .base .deterministic_model import DeterministicModel
2424from torch .utils .data import Dataset
2525
@@ -80,6 +80,7 @@ def backward(self, loss, optimizer, optimizer_idx):
8080 max_epochs = 2 ,
8181 log_every_n_steps = 1 ,
8282 weights_summary = None ,
83+ checkpoint_callback = callbacks .ModelCheckpoint (monitor = 'l_se' )
8384 )
8485 trainer .fit (model )
8586
@@ -95,7 +96,6 @@ def backward(self, loss, optimizer, optimizer_idx):
9596 'default' ,
9697 'l_e' ,
9798 'l_s' ,
98- 'l_se' ,
9999 'l_se_step' ,
100100 'l_se_epoch' ,
101101 }
@@ -105,7 +105,6 @@ def backward(self, loss, optimizer, optimizer_idx):
105105 expected_pbar_metrics = {
106106 'p_e' ,
107107 'p_s' ,
108- 'p_se' ,
109108 'p_se_step' ,
110109 'p_se_epoch' ,
111110 }
@@ -116,6 +115,7 @@ def backward(self, loss, optimizer, optimizer_idx):
116115 expected_callback_metrics = set ()
117116 expected_callback_metrics = expected_callback_metrics .union (logged_metrics )
118117 expected_callback_metrics = expected_callback_metrics .union (pbar_metrics )
118+ expected_callback_metrics .update ({'p_se' , 'l_se' })
119119 expected_callback_metrics .remove ('epoch' )
120120 assert callback_metrics == expected_callback_metrics
121121
@@ -163,7 +163,7 @@ def backward(self, loss, optimizer, optimizer_idx):
163163
164164 # make sure all the metrics are available for callbacks
165165 logged_metrics = set (trainer .logged_metrics .keys ())
166- expected_logged_metrics = {'epoch' , 'a' , ' a_step' , 'a_epoch' , 'b' , 'b1' , 'a1' , 'a2' }
166+ expected_logged_metrics = {'epoch' , 'a_step' , 'a_epoch' , 'b' , 'b1' , 'a1' , 'a2' }
167167 assert logged_metrics == expected_logged_metrics
168168
169169 pbar_metrics = set (trainer .progress_bar_metrics .keys ())
@@ -178,6 +178,7 @@ def backward(self, loss, optimizer, optimizer_idx):
178178 expected_callback_metrics = expected_callback_metrics .union (logged_metrics )
179179 expected_callback_metrics = expected_callback_metrics .union (pbar_metrics )
180180 expected_callback_metrics .remove ('epoch' )
181+ expected_callback_metrics .add ('a' )
181182 assert callback_metrics == expected_callback_metrics
182183
183184
@@ -226,23 +227,24 @@ def training_epoch_end(self, outputs):
226227 # make sure all the metrics are available for callbacks
227228 logged_metrics = set (trainer .logged_metrics .keys ())
228229 expected_logged_metrics = {
229- 'a' , ' a_step' , 'a_epoch' ,
230- 'b' , ' b_step' , 'b_epoch' ,
230+ 'a_step' , 'a_epoch' ,
231+ 'b_step' , 'b_epoch' ,
231232 'c' ,
232233 'd/e/f' ,
233234 'epoch'
234235 }
235236 assert logged_metrics == expected_logged_metrics
236237
237238 pbar_metrics = set (trainer .progress_bar_metrics .keys ())
238- expected_pbar_metrics = {'b' , ' c' , 'b_epoch' , 'b_step' }
239+ expected_pbar_metrics = {'c' , 'b_epoch' , 'b_step' }
239240 assert pbar_metrics == expected_pbar_metrics
240241
241242 callback_metrics = set (trainer .callback_metrics .keys ())
242243 callback_metrics .remove ('debug_epoch' )
243244 expected_callback_metrics = set ()
244245 expected_callback_metrics = expected_callback_metrics .union (logged_metrics )
245246 expected_callback_metrics = expected_callback_metrics .union (pbar_metrics )
247+ expected_callback_metrics .update ({'a' , 'b' })
246248 expected_callback_metrics .remove ('epoch' )
247249 assert callback_metrics == expected_callback_metrics
248250
@@ -355,7 +357,7 @@ def train_dataloader(self):
355357 trainer .fit (model )
356358
357359 generated = set (trainer .logged_metrics .keys ())
358- expected = {'a' , ' a_step' , 'a_epoch' , 'epoch' }
360+ expected = {'a_step' , 'a_epoch' , 'epoch' }
359361 assert generated == expected
360362
361363
0 commit comments