
Commit 25ee51b

Authored by: awaelchli, Jeremy Jordan, Borda, jeremyjordan, williamFalcon
Continue Jeremy's early stopping PR #1504 (#2391)
* add state_dict for early stopping
* move best attr after monitor_op defined
* improve early stopping and model checkpoint callbacks
* fix formatting
* fix attr init order
* clean up setting of default_root_dir attr
* logger needs default root dir set first
* reorg trainer init
* remove direct references to checkpoint callback
* more fixes
* more bugfixes
* run callbacks at epoch end
* update tests to use on epoch end
* PR cleanup
* address failing tests
* refactor for homogeneity
* fix merge conflict
* separate tests
* tests for early stopping bug regressions
* small fixes
* revert model checkpoint change
* typo fix
* fix tests
* update train loop
* cannot pass an int as default_save_path
* refactor log message
* fix test case
* appease the linter
* fix some doctests
* move config to callback
* fixes from rebase
* fixes from rebase
* chlog
* docs
* reformat
* formatting
* fix
* fix
* fixes from rebase
* add new test for patience
* Update pytorch_lightning/callbacks/model_checkpoint.py Co-authored-by: Jirka Borovec <[email protected]>
* Update pytorch_lightning/callbacks/model_checkpoint.py Co-authored-by: Jirka Borovec <[email protected]>
* Update tests/callbacks/test_early_stopping.py Co-authored-by: Jirka Borovec <[email protected]>
* fix formatting
* remove enable_early_stop attribute
* fix test with new epoch indexing
* fix progress bar totals
* fix off by one error (see #2289) epoch starts at 0 now
* added missing imports
* fix hpc_save folderpath
* fix formatting
* fix tests
* small fixes from a rebase
* fix
* tmpdir
* tmpdir
* tmpdir
* wandb
* fix merge conflict
* add back evaluation after training
* test_resume_early_stopping_from_checkpoint TODO
* undo the horovod check
* update changelog
* remove a duplicate test from merge error
* try fix dp_resume test
* add the logger fix from master
* try remove default_root_dir
* try mocking numpy
* try import numpy in docs test
* fix wandb test
* pep 8 fix
* skip if no amp
* dont mock when doctesting
* install extra
* fix the resume ES test
* undo conf.py changes
* revert remove comet pickle from test
* Update CHANGELOG.md Co-authored-by: Jirka Borovec <[email protected]>
* Update weights_loading.rst
* Update weights_loading.rst
* Update weights_loading.rst
* renamed flag
* renamed flag
* revert the None check in logger experiment name/version
* add the old comments
* _experiment
* test chckpointing on DDP
* skip the ddp test on windows
* cloudpickle
* renamed flag
* renamed flag
* parentheses for clarity
* apply suggestion max epochs

Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: Jeremy Jordan <[email protected]>
Co-authored-by: Jirka <[email protected]>
Co-authored-by: Jeremy Jordan <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: William Falcon <[email protected]>
1 parent 1e16681 commit 25ee51b

32 files changed: +532 −230 lines

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
@@ -42,6 +42,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed loading model with kwargs ([#2387](https://github.com/PyTorchLightning/pytorch-lightning/pull/2387))
 
+- Fixed several issues with early stopping and checkpoint callbacks ([#1504](https://github.com/PyTorchLightning/pytorch-lightning/pull/1504), [#2391](https://github.com/PyTorchLightning/pytorch-lightning/pull/2391))
+
 - Fixed loading past checkpoints from v0.7.x ([#2405](https://github.com/PyTorchLightning/pytorch-lightning/pull/2405))
 
 - Fixed loading model without arguments ([#2403](https://github.com/PyTorchLightning/pytorch-lightning/pull/2403))

docs/source/callbacks.rst

Lines changed: 13 additions & 0 deletions
@@ -48,6 +48,19 @@ We successfully extended functionality without polluting our super clean
 
 ----------------
 
+Best Practices
+==============
+
+1. Callbacks should be isolated in their functionality. Your callback should not rely on the
+   behavior of other callbacks in order to work properly.
+2. Do not manually call methods from the callback. The callbacks are designed to be
+   invoked at specific times during training. Directly calling methods (eg. `on_validation_end`)
+   is strongly discouraged.
+3. Whenever possible, your callbacks should not depend on the order in which they are executed.
+
+
+---------
+
 .. automodule:: pytorch_lightning.callbacks.base
     :noindex:
     :exclude-members:

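As a concrete illustration of the best practices added above, here is a minimal, self-contained callback sketch. It is illustrative only and not part of this commit (the MetricsHistory name and its behavior are made up): the callback keeps all of its state on itself, never calls another callback's hooks, and does not depend on execution order.

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import Callback


class MetricsHistory(Callback):
    """Hypothetical callback: records validation metrics and touches nothing else."""

    def __init__(self):
        self.history = []  # state lives on the callback itself

    def on_validation_end(self, trainer, pl_module):
        # invoked by the Trainer at the right time; never call this hook manually
        self.history.append(dict(trainer.callback_metrics))


# the Trainer drives the hooks
trainer = Trainer(callbacks=[MetricsHistory()])
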
docs/source/experiment_logging.rst

Lines changed: 2 additions & 1 deletion
@@ -92,6 +92,7 @@ Then configure the logger and pass it to the :class:`~pytorch_lightning.trainer.
 .. testcode::
 
     from pytorch_lightning.loggers import NeptuneLogger
+
     neptune_logger = NeptuneLogger(
         api_key='ANONYMOUS',  # replace with your own
         project_name='shared/pytorch-lightning-integration',
@@ -193,7 +194,7 @@ Then configure the logger and pass it to the :class:`~pytorch_lightning.trainer.
 .. testcode::
 
     from pytorch_lightning.loggers import WandbLogger
-    wandb_logger = WandbLogger()
+    wandb_logger = WandbLogger(offline=True)
     trainer = Trainer(logger=wandb_logger)
 
 The :class:`~pytorch_lightning.loggers.WandbLogger` is available anywhere except ``__init__`` in your

docs/source/weights_loading.rst

Lines changed: 1 addition & 1 deletion
@@ -29,7 +29,7 @@ Automatic saving
 Checkpointing is enabled by default to the current working directory.
 To change the checkpoint path pass in:
 
-.. testcode::
+.. code-block:: python
 
     trainer = Trainer(default_root_dir='/your/path/to/save/checkpoints')
 

pytorch_lightning/callbacks/early_stopping.py

Lines changed: 36 additions & 21 deletions
@@ -5,6 +5,7 @@
 Monitor a validation metric and stop training when it stops improving.
 
 """
+from copy import deepcopy
 
 import numpy as np
 import torch
@@ -58,7 +59,7 @@ def __init__(self, monitor: str = 'val_loss', min_delta: float = 0.0, patience:
         self.verbose = verbose
         self.strict = strict
         self.min_delta = min_delta
-        self.wait = 0
+        self.wait_count = 0
         self.stopped_epoch = 0
         self.mode = mode
 
@@ -76,12 +77,17 @@ def __init__(self, monitor: str = 'val_loss', min_delta: float = 0.0, patience:
             log.info(f'EarlyStopping mode set to {self.mode} for monitoring {self.monitor}.')
 
         self.min_delta *= 1 if self.monitor_op == torch.gt else -1
+        self.best_score = torch_inf if self.monitor_op == torch.lt else -torch_inf
 
     def _validate_condition_metric(self, logs):
        """
         Checks that the condition metric for early stopping is good
-        :param logs:
-        :return:
+
+        Args:
+            logs: callback metrics from validation output
+
+        Return:
+            True if specified metric is available
         """
         monitor_val = logs.get(self.monitor)
         error_msg = (f'Early stopping conditioned on metric `{self.monitor}`'
@@ -103,39 +109,48 @@ def _validate_condition_metric(self, logs):
     def monitor_op(self):
         return self.mode_dict[self.mode]
 
-    def on_train_start(self, trainer, pl_module):
-        # Allow instances to be re-used
-        self.wait = 0
-        self.stopped_epoch = 0
-        self.best = torch_inf if self.monitor_op == torch.lt else -torch_inf
+    def state_dict(self):
+        return {
+            'wait_count': self.wait_count,
+            'stopped_epoch': self.stopped_epoch,
+            'best_score': self.best_score,
+            'patience': self.patience
+        }
+
+    def load_state_dict(self, state_dict):
+        state_dict = deepcopy(state_dict)
+        self.wait_count = state_dict['wait_count']
+        self.stopped_epoch = state_dict['stopped_epoch']
+        self.best_score = state_dict['best_score']
+        self.patience = state_dict['patience']
+
+    def on_sanity_check_end(self, trainer, pl_module):
+        logs = trainer.callback_metrics
+        self._validate_condition_metric(logs)
 
     def on_validation_end(self, trainer, pl_module):
-        return self._run_early_stopping_check(trainer, pl_module)
+        self._run_early_stopping_check(trainer, pl_module)
 
     def _run_early_stopping_check(self, trainer, pl_module):
         logs = trainer.callback_metrics
-        stop_training = False
         if not self._validate_condition_metric(logs):
-            return stop_training
+            return  # short circuit if metric not present
 
         current = logs.get(self.monitor)
         if not isinstance(current, torch.Tensor):
             current = torch.tensor(current)
 
-        if self.monitor_op(current - self.min_delta, self.best):
-            self.best = current
-            self.wait = 0
+        if self.monitor_op(current - self.min_delta, self.best_score):
+            self.best_score = current
+            self.wait_count = 0
         else:
-            self.wait += 1
-            if self.wait >= self.patience:
+            self.wait_count += 1
+            if self.wait_count >= self.patience:
                 self.stopped_epoch = trainer.current_epoch
-                stop_training = True
-                self.on_train_end(trainer, pl_module)
-
-        return stop_training
+                trainer.should_stop = True
 
     def on_train_end(self, trainer, pl_module):
         if self.stopped_epoch > 0 and self.verbose > 0:
             rank_zero_warn('Displayed epoch numbers by `EarlyStopping` start from "1" until v0.6.x,'
                            ' but will start from "0" in v0.8.0.', DeprecationWarning)
-            log.info(f'Epoch {self.stopped_epoch + 1:05d}: early stopping')
+            log.info(f'Epoch {self.stopped_epoch + 1:05d}: early stopping triggered.')

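The new state_dict/load_state_dict hooks are what let early-stopping progress survive a checkpoint resume. Below is a minimal sketch of the round trip, standalone and illustrative only (the file name is a placeholder; in practice the Trainer's checkpointing drives these calls):

import torch
from pytorch_lightning.callbacks import EarlyStopping

early_stop = EarlyStopping(monitor='val_loss', patience=3)
early_stop.wait_count = 2  # pretend two validations passed without improvement

# persist the callback state, e.g. alongside a model checkpoint
torch.save(early_stop.state_dict(), 'early_stopping_state.pt')

# later: rebuild the callback and restore its progress
restored = EarlyStopping(monitor='val_loss', patience=3)
restored.load_state_dict(torch.load('early_stopping_state.pt'))
assert restored.wait_count == 2
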
pytorch_lightning/callbacks/model_checkpoint.py

Lines changed: 35 additions & 0 deletions
@@ -226,6 +226,41 @@ def format_checkpoint_name(self, epoch, metrics, ver=None):
         filepath = os.path.join(self.dirpath, self.prefix + filename + str_ver + '.ckpt')
         return filepath
 
+    def on_train_start(self, trainer, pl_module):
+        """
+        Determine model checkpoint save directory at runtime. References attributes from the
+        Trainer's logger to determine where to save checkpoints.
+        """
+        if self.dirpath is not None:
+            return  # short circuit
+
+        self.filename = '{epoch}'
+
+        if trainer.logger is not None and trainer.logger.experiment is not None:
+            # weights_save_path overrides anything
+            if getattr(trainer, 'weights_save_path', None) is not None:
+                save_dir = trainer.weights_save_path
+            else:
+                save_dir = (getattr(trainer.logger, 'save_dir', None)
+                            or getattr(trainer.logger, '_save_dir', None)
+                            or trainer.default_root_dir)
+
+            version = trainer.logger.version if isinstance(
+                trainer.logger.version, str) else f'version_{trainer.logger.version}'
+            ckpt_path = os.path.join(
+                save_dir,
+                trainer.logger.name,
+                version,
+                "checkpoints"
+            )
+        else:
+            ckpt_path = os.path.join(trainer.default_root_dir, "checkpoints")
+
+        self.dirpath = ckpt_path
+        os.makedirs(self.dirpath, exist_ok=True)
+        trainer.ckpt_path = ckpt_path
+        trainer.weights_save_path = ckpt_path
+
     @rank_zero_only
     def on_validation_end(self, trainer, pl_module):
         # only run on main process

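The on_train_start hook above resolves the checkpoint directory from the logger at runtime: weights_save_path wins, then the logger's save dir, then default_root_dir. A standalone sketch of that resolution order, with illustrative argument names rather than the Trainer's real attributes:

import os

def resolve_ckpt_dir(default_root_dir, logger_name=None, logger_version=None,
                     save_dir=None, weights_save_path=None):
    # no logger: fall back to <default_root_dir>/checkpoints
    if logger_name is None:
        return os.path.join(default_root_dir, 'checkpoints')
    # weights_save_path overrides the logger's save dir, which overrides default_root_dir
    base = weights_save_path or save_dir or default_root_dir
    version = logger_version if isinstance(logger_version, str) else f'version_{logger_version}'
    return os.path.join(base, logger_name, version, 'checkpoints')

# e.g. lightning_logs/my_model/version_0/checkpoints
print(resolve_ckpt_dir('lightning_logs', logger_name='my_model', logger_version=0))
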
pytorch_lightning/loggers/wandb.py

Lines changed: 2 additions & 2 deletions
@@ -131,12 +131,12 @@ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) ->
         self.experiment.log({'global_step': step, **metrics} if step is not None else metrics)
 
     @property
-    def name(self) -> str:
+    def name(self) -> Optional[str]:
         # don't create an experiment if we don't have one
         name = self._experiment.project_name() if self._experiment else None
         return name
 
     @property
-    def version(self) -> str:
+    def version(self) -> Optional[str]:
         # don't create an experiment if we don't have one
         return self._experiment.id if self._experiment else None

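Since name and version are now typed as Optional[str] (they return None until a wandb experiment actually exists), downstream code should guard for the missing case. A hypothetical helper, purely for illustration:

from typing import Optional

def format_run_id(name: Optional[str], version: Optional[str]) -> str:
    # both may be None before the experiment is created
    if name is None or version is None:
        return 'no wandb run yet'
    return f'{name}/{version}'
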
pytorch_lightning/trainer/callback_config.py

Lines changed: 16 additions & 48 deletions
@@ -32,79 +32,47 @@ def save_checkpoint(self, *args):
     def is_overridden(self, *args):
         """Warning: this is just empty shell for code implemented in other class."""
 
-    def configure_checkpoint_callback(self):
+    def configure_checkpoint_callback(self, checkpoint_callback):
         """
         Weight path set in this priority:
         Checkpoint_callback's path (if passed in).
         User provided weights_saved_path
         Otherwise use os.getcwd()
         """
-        ckpt_path = self.default_root_dir
-        if self.checkpoint_callback:
-            # init a default one
-            if self.logger is not None and self.logger.experiment is not None:
-                save_dir = (getattr(self.logger, 'save_dir', None) or
-                            getattr(self.logger, '_save_dir', None) or
-                            self.default_root_dir)
-
-                # weights_save_path overrides anything
-                if self.weights_save_path is not None:
-                    save_dir = self.weights_save_path
-
-                version = self.logger.version if isinstance(
-                    self.logger.version, str) else f'version_{self.logger.version}'
-                ckpt_path = os.path.join(save_dir, self.logger.name, version, "checkpoints")
-            else:
-                ckpt_path = os.path.join(self.default_root_dir, "checkpoints")
-
+        if checkpoint_callback is True:
             # when no val step is defined, use 'loss' otherwise 'val_loss'
             train_step_only = not self.is_overridden('validation_step')
             monitor_key = 'loss' if train_step_only else 'val_loss'
+            checkpoint_callback = ModelCheckpoint(
+                filepath=None,
+                monitor=monitor_key
+            )
+        elif checkpoint_callback is False:
+            checkpoint_callback = None
 
-            if self.checkpoint_callback is True:
-                os.makedirs(ckpt_path, exist_ok=True)
-                self.checkpoint_callback = ModelCheckpoint(
-                    filepath=ckpt_path,
-                    monitor=monitor_key
-                )
-            # If user specified None in filepath, override with runtime default
-            elif isinstance(self.checkpoint_callback, ModelCheckpoint) \
-                    and self.checkpoint_callback.dirpath is None:
-                self.checkpoint_callback.dirpath = ckpt_path
-                self.checkpoint_callback.filename = '{epoch}'
-                os.makedirs(self.checkpoint_callback.dirpath, exist_ok=True)
-            elif self.checkpoint_callback is False:
-                self.checkpoint_callback = None
-
-        self.ckpt_path = ckpt_path
-
-        if self.checkpoint_callback:
-            # set the path for the callbacks
-            self.checkpoint_callback.save_function = self.save_checkpoint
-
-            # if checkpoint callback used, then override the weights path
-            self.weights_save_path = self.checkpoint_callback.dirpath
+        if checkpoint_callback:
+            checkpoint_callback.save_function = self.save_checkpoint
 
         # if weights_save_path is still none here, set to current working dir
         if self.weights_save_path is None:
             self.weights_save_path = self.default_root_dir
 
+        return checkpoint_callback
+
     def configure_early_stopping(self, early_stop_callback):
         if early_stop_callback is True or None:
-            self.early_stop_callback = EarlyStopping(
+            early_stop_callback = EarlyStopping(
                 monitor='val_loss',
                 patience=3,
                 strict=True,
                 verbose=True,
                 mode='min'
             )
-            self.enable_early_stop = True
         elif not early_stop_callback:
-            self.early_stop_callback = None
-            self.enable_early_stop = False
+            early_stop_callback = None
         else:
-            self.early_stop_callback = early_stop_callback
-            self.enable_early_stop = True
+            early_stop_callback = early_stop_callback
+        return early_stop_callback
 
     def configure_progress_bar(self, refresh_rate=1, process_position=0):
         progress_bars = [c for c in self.callbacks if isinstance(c, ProgressBarBase)]

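With configure_checkpoint_callback and configure_early_stopping now returning the callback instead of mutating Trainer attributes, the user-facing behavior stays the same: True builds a default callback, False disables it, and an instance is used as-is. An illustrative sketch of the three forms (values are placeholders, not part of this commit):

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint

# default callbacks built for you
trainer_defaults = Trainer(checkpoint_callback=True, early_stop_callback=True)

# both features disabled
trainer_disabled = Trainer(checkpoint_callback=False, early_stop_callback=False)

# custom instances used as-is
trainer_custom = Trainer(
    checkpoint_callback=ModelCheckpoint(monitor='val_loss'),
    early_stop_callback=EarlyStopping(monitor='val_loss', patience=5),
)
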
pytorch_lightning/trainer/distrib_data_parallel.py

Lines changed: 0 additions & 1 deletion
@@ -172,7 +172,6 @@ class TrainerDDPMixin(ABC):
     num_gpu_nodes: int
     gpus: List[int]
     logger: Union[LightningLoggerBase, bool]
-    checkpoint_callback: Union[ModelCheckpoint, bool]
     data_parallel_device_ids: ...
     distributed_backend: Optional[str]
     amp_level: str

pytorch_lightning/trainer/lr_finder.py

Lines changed: 0 additions & 3 deletions
@@ -163,7 +163,6 @@ def lr_find(self,
         # Disable standard checkpoint & early stopping
         self.checkpoint_callback = False
         self.early_stop_callback = None
-        self.enable_early_stop = False
 
         # Required for saving the model
         self.optimizers, self.schedulers = [], [],
@@ -215,7 +214,6 @@ def __lr_finder_dump_params(self, model):
             'max_steps': self.max_steps,
             'checkpoint_callback': self.checkpoint_callback,
             'early_stop_callback': self.early_stop_callback,
-            'enable_early_stop': self.enable_early_stop,
             'configure_optimizers': model.configure_optimizers,
         }
 
@@ -226,7 +224,6 @@ def __lr_finder_restore_params(self, model):
         self.max_steps = self.__dumped_params['max_steps']
         self.checkpoint_callback = self.__dumped_params['checkpoint_callback']
         self.early_stop_callback = self.__dumped_params['early_stop_callback']
-        self.enable_early_stop = self.__dumped_params['enable_early_stop']
         model.configure_optimizers = self.__dumped_params['configure_optimizers']
         del self.__dumped_params
 

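The lr_finder changes above fit the same pattern: lr_find stashes Trainer attributes (including the callbacks), overrides them for the search, and restores them afterwards, so dropping enable_early_stop simply means one fewer attribute to round-trip. A generic sketch of that stash/restore idea, not the library's actual implementation:

# hypothetical helper illustrating the dump/restore pattern used by lr_find
def run_with_overrides(obj, overrides, fn):
    saved = {name: getattr(obj, name) for name in overrides}
    try:
        for name, value in overrides.items():
            setattr(obj, name, value)
        return fn()
    finally:
        # restore the original attributes, even if fn() raised
        for name, value in saved.items():
            setattr(obj, name, value)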