Skip to content

Commit 0961868

Browse files
committed
Merge branch 'master' into changes
2 parents 30a2666 + 346bf40 commit 0961868

File tree

3 files changed

+41
-14
lines changed

3 files changed

+41
-14
lines changed

keras/callbacks.py

Lines changed: 23 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1587,26 +1587,35 @@ def _get_most_recently_modified_file_matching_pattern(self, pattern):
15871587
class BackupAndRestore(Callback):
15881588
"""Callback to back up and restore the training state.
15891589
1590-
`BackupAndRestore` callback is intended to recover from interruptions that
1591-
happened in the middle of a model.fit execution by backing up the
1592-
training states in a temporary checkpoint file (based on TF CheckpointManager)
1593-
at the end of each epoch. If training restarted before completion, the
1594-
training state and model are restored to the most recently saved state at the
1595-
beginning of a new model.fit() run.
1596-
Note that user is responsible to bring jobs back up.
1590+
`BackupAndRestore` callback is intended to recover training from an
1591+
interruption that has happened in the middle of a `Model.fit` execution, by
1592+
backing up the training states in a temporary checkpoint file (with the help
1593+
of a `tf.train.CheckpointManager`), at the end of each epoch. Each backup
1594+
overwrites the previously written checkpoint file, so at any given time there
1595+
is at most one such checkpoint file for backup/restoring purpose.
1596+
1597+
If training restarts before completion, the training state (which includes the
1598+
`Model` weights and epoch number) is restored to the most recently saved state
1599+
at the beginning of a new `Model.fit` run. At the completion of a `Model.fit`
1600+
run, the temporary checkpoint file is deleted.
1601+
1602+
Note that the user is responsible to bring jobs back after the interruption.
15971603
This callback is important for the backup and restore mechanism for fault
1598-
tolerance purpose. And the model to be restored from an previous checkpoint is
1604+
tolerance purpose, and the model to be restored from an previous checkpoint is
15991605
expected to be the same as the one used to back up. If user changes arguments
16001606
passed to compile or fit, the checkpoint saved for fault tolerance can become
16011607
invalid.
16021608
16031609
Note:
1604-
1. This callback is not compatible with disabling eager execution.
1605-
2. A checkpoint is saved at the end of each epoch, when restoring we'll redo
1606-
any partial work from an unfinished epoch in which the training got restarted
1607-
(so the work done before a interruption doesn't affect the final model state).
1608-
3. This works for both single worker and multi-worker mode, only
1609-
MirroredStrategy and MultiWorkerMirroredStrategy are supported for now.
1610+
1. This callback is not compatible with eager execution disabled.
1611+
2. A checkpoint is saved at the end of each epoch. After restoring,
1612+
`Model.fit` redoes any partial work during the unfinished epoch in which the
1613+
training got restarted (so the work done before the interruption doesn't
1614+
affect the final model state).
1615+
3. This works for both single worker and multi-worker modes. When `Model.fit`
1616+
is used with `tf.distribute`, it supports `tf.distribute.MirroredStrategy`,
1617+
`tf.distribute.MultiWorkerMirroredStrategy`, `tf.distribute.TPUStrategy`, and
1618+
`tf.distribute.experimental.ParameterServerStrategy`.
16101619
16111620
Example:
16121621

keras/layers/convolutional.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3514,6 +3514,13 @@ def compute_output_shape(self, input_shape):
35143514
def call(self, inputs):
35153515
# pylint: disable=invalid-unary-operand-type
35163516
if self.data_format == 'channels_first':
3517+
if ((inputs.shape[2] is not None and
3518+
sum(self.cropping[0]) >= inputs.shape[2]) or
3519+
(inputs.shape[3] is not None and
3520+
sum(self.cropping[1]) >= inputs.shape[3])):
3521+
raise ValueError('Argument `cropping` must be '
3522+
'greater than the input shape. Received: inputs.shape='
3523+
f'{inputs.shape}, and cropping={self.cropping}')
35173524
if self.cropping[0][1] == self.cropping[1][1] == 0:
35183525
return inputs[:, :, self.cropping[0][0]:, self.cropping[1][0]:]
35193526
elif self.cropping[0][1] == 0:
@@ -3525,6 +3532,13 @@ def call(self, inputs):
35253532
return inputs[:, :, self.cropping[0][0]:-self.cropping[0][1],
35263533
self.cropping[1][0]:-self.cropping[1][1]]
35273534
else:
3535+
if ((inputs.shape[1] is not None and
3536+
sum(self.cropping[0]) >= inputs.shape[1]) or
3537+
(inputs.shape[2] is not None and
3538+
sum(self.cropping[1]) >= inputs.shape[2])):
3539+
raise ValueError('Argument `cropping` must be '
3540+
'greater than the input shape. Received: inputs.shape='
3541+
f'{inputs.shape}, and cropping={self.cropping}')
35283542
if self.cropping[0][1] == self.cropping[1][1] == 0:
35293543
return inputs[:, self.cropping[0][0]:, self.cropping[1][0]:, :]
35303544
elif self.cropping[0][1] == 0:

keras/layers/convolutional_test.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1082,6 +1082,10 @@ def test_cropping_2d(self):
10821082
keras.layers.Cropping2D(cropping=(1, 1, 1))
10831083
with self.assertRaises(ValueError):
10841084
keras.layers.Cropping2D(cropping=None)
1085+
with self.assertRaises(ValueError):
1086+
input_layer = keras.layers.Input(
1087+
shape=(num_samples, input_len_dim1, input_len_dim2, stack_size))
1088+
keras.layers.Cropping2D(cropping=((5, 4), (3, 4)))(input_layer)
10851089

10861090
def test_cropping_3d(self):
10871091
num_samples = 2

0 commit comments

Comments
 (0)