Skip to content

Commit 4d1c67e

Browse files
committed
tests: fold warmup-cosine edge cases into _schedule_test.py and compress cases
1 parent e3a4a05 commit 4d1c67e

File tree

2 files changed

+28
-216
lines changed

2 files changed

+28
-216
lines changed

optax/schedules/_schedule_test.py

Lines changed: 28 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -484,72 +484,35 @@ def test_with_giant_int_steps(self):
484484
class WarmupCosineDecayTest(chex.TestCase):
485485

486486
@chex.all_variants
487-
@parameterized.named_parameters(
488-
('with end value', 10, 0.5, 1e-4),
489-
('without end value', 5, 3, 0.0),
490-
)
491-
def test_limits(self, init_value, peak_value, end_value):
492-
"""Check cosine schedule decay for the entire training schedule."""
493-
schedule_fn = self.variant(
494-
_schedule.warmup_cosine_decay_schedule(
495-
init_value=init_value,
496-
peak_value=peak_value,
497-
warmup_steps=100,
498-
decay_steps=1000,
499-
end_value=end_value,
500-
)
501-
)
502-
503-
np.testing.assert_allclose(init_value, schedule_fn(0))
504-
np.testing.assert_allclose(peak_value, schedule_fn(100))
505-
np.testing.assert_allclose(end_value, schedule_fn(1000), rtol=1e-3)
506-
507-
@chex.all_variants
508-
def test_with_exponent(self):
509-
"""Check that we get correct results when running with exponent on."""
510-
schedule_fn = self.variant(
511-
_schedule.warmup_cosine_decay_schedule(
512-
init_value=0.2,
513-
peak_value=1.21,
514-
end_value=-3.0,
515-
warmup_steps=50,
516-
decay_steps=100,
517-
exponent=2,
518-
)
519-
)
520-
output = schedule_fn(np.array([0, 10, 50, 75, 100]))
521-
np.testing.assert_allclose(
522-
output,
523-
np.array([
524-
0.20000004768371582,
525-
0.4020000100135803,
526-
1.2100000381469727,
527-
-1.947500228881836,
528-
-3.000000238418579,
529-
]),
530-
rtol=1e-6,
531-
atol=1e-8,
532-
)
533-
534-
@chex.all_variants
535-
def test_zero_peak_value(self):
536-
"""Check that we get correct results when running with zero peak value."""
537-
schedule_fn = self.variant(
538-
_schedule.warmup_cosine_decay_schedule(
539-
init_value=0.2,
540-
peak_value=0,
541-
end_value=-3.0,
542-
warmup_steps=50,
543-
decay_steps=100,
544-
exponent=2,
545-
)
546-
)
547-
output = schedule_fn(np.array([0, 10, 50, 75, 100]))
548-
np.testing.assert_allclose(
549-
output, np.array([0.2, 0.16, 0.0, 0.0, 0.0]), rtol=1e-6, atol=1e-8
550-
)
551-
487+
def test_monotonicity_and_exponent_ordering(self):
488+
init, peak, end = 0.0, 1.0, 0.1
489+
warmup_steps, decay_steps = 5, 25
490+
base = self.variant(_schedule.warmup_cosine_decay_schedule(
491+
init_value=init, peak_value=peak,
492+
warmup_steps=warmup_steps, decay_steps=decay_steps,
493+
end_value=end, exponent=1.0))
494+
steep = self.variant(_schedule.warmup_cosine_decay_schedule(
495+
init_value=init, peak_value=peak,
496+
warmup_steps=warmup_steps, decay_steps=decay_steps,
497+
end_value=end, exponent=2.0))
498+
steps = np.arange(decay_steps + 1)
499+
vals, vals2 = base(steps), steep(steps)
500+
for t in range(warmup_steps):
501+
self.assertLess(float(vals[t]), float(vals[t + 1]))
502+
self.assertAlmostEqual(float(vals[warmup_steps]), peak, places=6)
503+
for t in range(warmup_steps, decay_steps):
504+
self.assertGreaterEqual(float(vals[t]), float(vals[t + 1]))
505+
self.assertAlmostEqual(float(vals[-1]), end, places=6)
506+
for t in range(warmup_steps, decay_steps + 1):
507+
self.assertLessEqual(float(vals2[t]), float(vals[t]) + 1e-12)
508+
509+
def test_raises_when_decay_equals_warmup(self):
510+
with self.assertRaises(ValueError):
511+
_schedule.warmup_cosine_decay_schedule(
512+
init_value=0.2, peak_value=1.2,
513+
warmup_steps=10, decay_steps=10, end_value=0.0)
552514

515+
#ee
553516
class SGDRTest(chex.TestCase):
554517

555518
@chex.all_variants

optax/schedules/schedules_warmup_cosine_test.py

Lines changed: 0 additions & 151 deletions
This file was deleted.

0 commit comments

Comments
 (0)