Skip to content

Commit e416db9

Browse files
committed
added exisiting tests
1 parent a39e717 commit e416db9

File tree

1 file changed

+88
-10
lines changed

1 file changed

+88
-10
lines changed

optax/schedules/_schedule_test.py

Lines changed: 88 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -483,10 +483,77 @@ def test_with_giant_int_steps(self):
483483

484484
class WarmupCosineDecayTest(chex.TestCase):
485485

486+
@chex.all_variants
487+
@parameterized.named_parameters(
488+
('with end value', 10, 0.5, 1e-4),
489+
('without end value', 5, 3, 0.0),
490+
)
491+
def test_limits(self, init_value, peak_value, end_value):
492+
"""Check cosine schedule decay for the entire training schedule."""
493+
schedule_fn = self.variant(
494+
_schedule.warmup_cosine_decay_schedule(
495+
init_value=init_value,
496+
peak_value=peak_value,
497+
warmup_steps=100,
498+
decay_steps=1000,
499+
end_value=end_value,
500+
)
501+
)
502+
503+
np.testing.assert_allclose(init_value, schedule_fn(0))
504+
np.testing.assert_allclose(peak_value, schedule_fn(100))
505+
np.testing.assert_allclose(end_value, schedule_fn(1000), rtol=1e-3)
506+
507+
@chex.all_variants
508+
def test_with_exponent(self):
509+
"""Check that we get correct results when running with exponent on."""
510+
schedule_fn = self.variant(
511+
_schedule.warmup_cosine_decay_schedule(
512+
init_value=0.2,
513+
peak_value=1.21,
514+
end_value=-3.0,
515+
warmup_steps=50,
516+
decay_steps=100,
517+
exponent=2,
518+
)
519+
)
520+
output = schedule_fn(np.array([0, 10, 50, 75, 100]))
521+
np.testing.assert_allclose(
522+
output,
523+
np.array([
524+
0.20000004768371582,
525+
0.4020000100135803,
526+
1.2100000381469727,
527+
-1.947500228881836,
528+
-3.000000238418579,
529+
]),
530+
rtol=1e-6,
531+
atol=1e-8,
532+
)
533+
534+
@chex.all_variants
535+
def test_zero_peak_value(self):
536+
"""Check that we get correct results when running with zero peak value."""
537+
schedule_fn = self.variant(
538+
_schedule.warmup_cosine_decay_schedule(
539+
init_value=0.2,
540+
peak_value=0,
541+
end_value=-3.0,
542+
warmup_steps=50,
543+
decay_steps=100,
544+
exponent=2,
545+
)
546+
)
547+
output = schedule_fn(np.array([0, 10, 50, 75, 100]))
548+
np.testing.assert_allclose(
549+
output, np.array([0.2, 0.16, 0.0, 0.0, 0.0]), rtol=1e-6, atol=1e-8
550+
)
551+
486552
@chex.all_variants
487553
def test_monotonicity_and_exponent_ordering(self):
488554
init, peak, end = 0.0, 1.0, 0.1
489555
warmup_steps, decay_steps = 5, 25
556+
490557
base = self.variant(_schedule.warmup_cosine_decay_schedule(
491558
init_value=init, peak_value=peak,
492559
warmup_steps=warmup_steps, decay_steps=decay_steps,
@@ -495,24 +562,35 @@ def test_monotonicity_and_exponent_ordering(self):
495562
init_value=init, peak_value=peak,
496563
warmup_steps=warmup_steps, decay_steps=decay_steps,
497564
end_value=end, exponent=2.0))
565+
498566
steps = np.arange(decay_steps + 1)
499-
vals, vals2 = base(steps), steep(steps)
500-
for t in range(warmup_steps):
501-
self.assertLess(float(vals[t]), float(vals[t + 1]))
502-
self.assertAlmostEqual(float(vals[warmup_steps]), peak, places=6)
503-
for t in range(warmup_steps, decay_steps):
504-
self.assertGreaterEqual(float(vals[t]), float(vals[t + 1]))
505-
self.assertAlmostEqual(float(vals[-1]), end, places=6)
506-
for t in range(warmup_steps, decay_steps + 1):
507-
self.assertLessEqual(float(vals2[t]), float(vals[t]) + 1e-12)
567+
vals = base(steps)
568+
vals2 = steep(steps)
569+
570+
with self.subTest("warmup increases"):
571+
for t in range(warmup_steps):
572+
self.assertLess(float(vals[t]), float(vals[t + 1]))
573+
574+
with self.subTest("peak at boundary"):
575+
self.assertAlmostEqual(float(vals[warmup_steps]), peak, places=6)
576+
577+
with self.subTest("cosine decay nonincreasing and end value"):
578+
for t in range(warmup_steps, decay_steps):
579+
self.assertGreaterEqual(float(vals[t]), float(vals[t + 1]))
580+
self.assertAlmostEqual(float(vals[-1]), end, places=6)
581+
582+
with self.subTest("exponent ordering (p=2 ≤ p=1)"):
583+
for t in range(warmup_steps, decay_steps + 1):
584+
self.assertLessEqual(float(vals2[t]), float(vals[t]) + 1e-12)
585+
508586

509587
def test_raises_when_decay_equals_warmup(self):
510588
with self.assertRaises(ValueError):
511589
_schedule.warmup_cosine_decay_schedule(
512590
init_value=0.2, peak_value=1.2,
513591
warmup_steps=10, decay_steps=10, end_value=0.0)
514592

515-
#ee
593+
516594
class SGDRTest(chex.TestCase):
517595

518596
@chex.all_variants

0 commit comments

Comments
 (0)