Commit 7778149

nshazeer authored and Ryan Sepassi committed
Revert "noam" learning-rate-scheme to use linear warmup. Add learning_rate_schedule hparam to specify a schedule that does not have separate warmup and decay phases.
PiperOrigin-RevId: 185042750
1 parent 5e8bc75 · commit 7778149
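For reference, the new "linear_warmup_rsqrt_decay" schedule introduced in optimize.py below is a single formula with no separate warmup and decay phases. A minimal NumPy sketch of that formula, taken from the diff (the hidden_size and warmup_steps values used here are only illustrative, not part of the commit):

import numpy as np

def linear_warmup_rsqrt_decay(step, hidden_size, warmup_steps):
  # Mirrors the branch added in optimize.learning_rate_schedule below:
  # the factor rises linearly with step until step + 1 == warmup_steps,
  # where the two branches of the minimum meet, then decays as 1/sqrt(step + 1).
  return 5000.0 * hidden_size ** -0.5 * np.minimum(
      (step + 1) * warmup_steps ** -1.5, (step + 1) ** -0.5)

steps = np.arange(50000)
factors = linear_warmup_rsqrt_decay(steps, hidden_size=512, warmup_steps=4000.0)
print(int(np.argmax(factors)))  # 3999 -- the peak sits at the warmup boundary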

File tree

4 files changed (+24, -6 lines)

tensor2tensor/layers/common_hparams.py

Lines changed: 5 additions & 0 deletions
@@ -63,6 +63,11 @@ def basic_params1():
       optimizer_momentum_nesterov=False,
       weight_decay=1e-6,
       weight_noise=0.0,
+      learning_rate_schedule="warmup_and_decay",
+      # If learning_rate_schedule=="warmup_and_decay", then this specifies
+      # the decay part of the schedule.
+      # The warmup is always exponential.
+      # TODO(noam): add a hyperparameter to control the warmup.
       learning_rate_decay_scheme="none",
       # decay_steps and decay_staircase for learning_rate_decay_scheme=="exp"
       learning_rate_decay_steps=5000,
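
The new hparam defaults to the pre-existing two-phase behavior, so models that do not set it are unaffected. A quick sketch of the default, assuming this revision of tensor2tensor is importable (the commented values are what the diff above implies):

from tensor2tensor.layers import common_hparams

hparams = common_hparams.basic_params1()
print(hparams.learning_rate_schedule)      # "warmup_and_decay" (new default)
print(hparams.learning_rate_decay_scheme)  # "none" (unchanged default)

# Opting a model into the single-formula schedule added in this commit:
hparams.learning_rate_schedule = "linear_warmup_rsqrt_decay"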

tensor2tensor/models/transformer.py

Lines changed: 1 addition & 1 deletion
@@ -877,7 +877,7 @@ def transformer_base_v1():
   hparams.max_length = 256
   hparams.clip_grad_norm = 0.  # i.e. no gradient clipping
   hparams.optimizer_adam_epsilon = 1e-9
-  hparams.learning_rate_decay_scheme = "noam"
+  hparams.learning_rate_schedule = "linear_warmup_rsqrt_decay"
   hparams.learning_rate = 0.1
   hparams.learning_rate_warmup_steps = 4000
   hparams.initializer_gain = 1.0
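
With this change transformer_base_v1 no longer touches learning_rate_decay_scheme at all; the new hparam selects the rsqrt schedule directly. A hedged check, again assuming this revision is importable (the commented values are what the diff implies, not verified output):

from tensor2tensor.models import transformer

hp = transformer.transformer_base_v1()
print(hp.learning_rate_schedule)      # expected: "linear_warmup_rsqrt_decay"
print(hp.learning_rate_decay_scheme)  # expected: "none", now that "noam" is gone
print(hp.learning_rate, hp.learning_rate_warmup_steps)  # 0.1 4000 (from the diff)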

tensor2tensor/utils/optimize.py

Lines changed: 17 additions & 4 deletions
@@ -168,10 +168,6 @@ def learning_rate_decay(hparams, warmup_steps=0):
         hparams.learning_rate_boundaries,
         hparams.learning_rate_multiples)
 
-  if scheme == "noam":
-    return 5000.0 * hparams.hidden_size**-0.5 * tf.minimum(
-        (global_step + 1) * warmup_steps**-1.5, (global_step + 1)**-0.5)
-
   if scheme == "cosine":
     cycle_steps = hparams.learning_rate_cosine_cycle_steps
     cycle_position = global_step % (2 * cycle_steps)
@@ -224,6 +220,23 @@ def learning_rate_decay_with_warmup(hparams, num_worker_replicas=1):
   return tf.where(global_step < warmup_steps, warmup, decay)
 
 
+def learning_rate_schedule(hparams, num_worker_replicas=1):
+  """Learning rate schedule based on hparams."""
+  schedule = hparams.learning_rate_schedule
+  warmup_steps = tf.to_float(hparams.learning_rate_warmup_steps)
+  global_step = tf.to_float(tf.train.get_or_create_global_step())
+  if hparams.learning_rate_decay_scheme == "noam":
+    # backwards compatibility with previous behavior
+    schedule = "linear_warmup_rsqrt_decay"
+  if schedule == "warmup_and_decay":
+    return learning_rate_decay_with_warmup(hparams, num_worker_replicas)
+  elif schedule == "linear_warmup_rsqrt_decay":
+    return 5000.0 * hparams.hidden_size**-0.5 * tf.minimum(
+        (global_step + 1) * warmup_steps**-1.5, (global_step + 1)**-0.5)
+  else:
+    raise ValueError("Unrecognized learning rate schedule: %s" % schedule)
+
+
 def weight_decay_and_noise(loss, hparams, learning_rate, var_list=None):
   """Apply weight decay and weight noise."""
   if var_list is None:
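
The dispatch above keeps old configs working: anything that still sets learning_rate_decay_scheme="noam" is routed to the new schedule before the branch on learning_rate_schedule. A TF-free sketch of just that selection logic (pick_schedule is a hypothetical helper for illustration, not part of the commit):

def pick_schedule(learning_rate_schedule, learning_rate_decay_scheme):
  # Mirrors the branch order in optimize.learning_rate_schedule above.
  schedule = learning_rate_schedule
  if learning_rate_decay_scheme == "noam":
    schedule = "linear_warmup_rsqrt_decay"  # backwards-compatibility override
  if schedule not in ("warmup_and_decay", "linear_warmup_rsqrt_decay"):
    raise ValueError("Unrecognized learning rate schedule: %s" % schedule)
  return schedule

assert pick_schedule("warmup_and_decay", "noam") == "linear_warmup_rsqrt_decay"
assert pick_schedule("warmup_and_decay", "none") == "warmup_and_decay"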

tensor2tensor/utils/t2t_model.py

Lines changed: 1 addition & 1 deletion
@@ -296,7 +296,7 @@ def optimize(self, loss, num_async_replicas=1):
     """Return a training op minimizing loss."""
     tf.logging.info("Base learning rate: %f", self.hparams.learning_rate)
     lr = self.hparams.learning_rate
-    decay_rate = optimize.learning_rate_decay_with_warmup(self.hparams)
+    decay_rate = optimize.learning_rate_schedule(self.hparams)
     lr *= decay_rate
     if self.hparams.learning_rate_minimum:
       lr_min = float(self.hparams.learning_rate_minimum)
