diff --git a/optax/schedules/_schedule.py b/optax/schedules/_schedule.py index d5a3f0c70..21f49cd24 100644 --- a/optax/schedules/_schedule.py +++ b/optax/schedules/_schedule.py @@ -266,11 +266,11 @@ def cosine_decay_schedule( .. math:: - \frac{I (1 - E)}{2}(1+\cos(\pi\,\frac{t}{T})^p) + E\,, + \frac{(I - E)}{2}(1+\cos(\pi\,\frac{t}{T})^p) + E\,, where :math:`T` is the number of decay steps (``decay_steps``), :math:`p` is the ``exponent``, :math:`I` is the initial value (``init_value``) and - :math:`E` is the end value,. + :math:`E` is the end value (``end_value``). References: Loshchilov et al., `SGDR: Stochastic Gradient Descent with Warm Restarts @@ -286,8 +286,8 @@ def cosine_decay_schedule( ``t`` is the current timestep and ``T`` is the ``decay_steps``. The exponent modifies this to be ``(0.5 * (1 + cos(pi * t/T))) ** exponent``. Defaults to 1.0. - alpha: The minimum value of the multiplier used to adjust the - learning rate. Defaults to 0.0. + alpha: Deprecated, use end_value instead. The minimum value of the + multiplier used to adjust the learning rate. Defaults to 0.0. Returns: schedule @@ -316,8 +316,7 @@ def cosine_decay_schedule( def schedule(count): count = jnp.minimum(count, decay_steps) cosine_decay = 0.5 * (1 + jnp.cos(jnp.pi * count / decay_steps)) - decayed = (1 - end_value) * cosine_decay ** exponent + end_value - return init_value * decayed + return (init_value - end_value) * cosine_decay ** exponent + end_value return schedule @@ -501,7 +500,6 @@ def warmup_cosine_decay_schedule( schedule A function that maps step counts to values """ - alpha = 0. if peak_value == 0. else end_value / peak_value schedules = [ linear_schedule( init_value=init_value, @@ -511,7 +509,7 @@ def warmup_cosine_decay_schedule( cosine_decay_schedule( init_value=peak_value, decay_steps=decay_steps - warmup_steps, - alpha=alpha, + end_value=end_value, exponent=exponent, ), ] diff --git a/optax/schedules/_schedule_test.py b/optax/schedules/_schedule_test.py index 8826f981a..5b26de957 100644 --- a/optax/schedules/_schedule_test.py +++ b/optax/schedules/_schedule_test.py @@ -300,6 +300,24 @@ def test_immutable_count(self): class CosineDecayTest(chex.TestCase): + @chex.all_variants + def test_init_value_end_value(self): + """Check cosine schedule decay for the entire training schedule.""" + initial_value = 1.5 + end_value = 0.2 + num_steps = 10 + schedule_fn = self.variant( + _schedule.cosine_decay_schedule(initial_value, num_steps, end_value)) + # Test that generated values equal the expected schedule values. + generated_vals = [] + for count in range(num_steps + 1): + # Compute next value. + generated_vals.append(schedule_fn(count)) + + # Test that the first and last values are correct. + self.assertAlmostEqual(generated_vals[0], initial_value) + self.assertAlmostEqual(generated_vals[-1], end_value) + @chex.all_variants def test_decay_count_smaller_count(self): """Check cosine schedule decay for the entire training schedule.""" @@ -345,23 +363,28 @@ def test_decay_count_greater_count(self): def test_decay_count_greater_count_with_end_value(self): """Check cosine schedule decay for a part of the training schedule.""" # Get schedule function. - initial_value = 0.1 + initial_value = 0.2 + end_value = 0.1 + num_steps = 5 schedule_fn = self.variant( - _schedule.cosine_decay_schedule(initial_value, 5, 0.1)) + _schedule.cosine_decay_schedule(initial_value, num_steps, end_value)) # Test that generated values equal the expected schedule values. generated_vals = [] - for count in range(12): + for count in range(2 * num_steps): # Compute next value. generated_vals.append(schedule_fn(count)) # Test output. - expected_multipliers = np.array( - 0.5 + 0.5 * np.cos( - np.pi * np.array( - [0.0, 0.2, 0.4, 0.6, 0.8, 1., 1., 1., 1., 1., 1., 1.]))) - expected_multipliers = 0.9 * expected_multipliers + 0.1 + cos_values = 0.5 * (1 + np.cos(np.pi * np.linspace(0, 1, num_steps + 1))) + expected_values = ( + (initial_value - end_value) * cos_values + end_value + ) + # padd with [end_value] at the end. + expected_values = np.concatenate( + (expected_values, [end_value] * (num_steps - 1)) + ) np.testing.assert_allclose( - initial_value * expected_multipliers, + expected_values, np.array(generated_vals), atol=1e-3) def test_cosine_alpha_exception(self):