Skip to content

Commit c5d65b7

Browse files
committed
added exisiting tests
1 parent a39e717 commit c5d65b7

File tree

1 file changed

+76
-9
lines changed

1 file changed

+76
-9
lines changed

optax/schedules/_schedule_test.py

Lines changed: 76 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -487,6 +487,7 @@ class WarmupCosineDecayTest(chex.TestCase):
487487
def test_monotonicity_and_exponent_ordering(self):
488488
init, peak, end = 0.0, 1.0, 0.1
489489
warmup_steps, decay_steps = 5, 25
490+
490491
base = self.variant(_schedule.warmup_cosine_decay_schedule(
491492
init_value=init, peak_value=peak,
492493
warmup_steps=warmup_steps, decay_steps=decay_steps,
@@ -495,16 +496,82 @@ def test_monotonicity_and_exponent_ordering(self):
495496
init_value=init, peak_value=peak,
496497
warmup_steps=warmup_steps, decay_steps=decay_steps,
497498
end_value=end, exponent=2.0))
499+
498500
steps = np.arange(decay_steps + 1)
499-
vals, vals2 = base(steps), steep(steps)
500-
for t in range(warmup_steps):
501-
self.assertLess(float(vals[t]), float(vals[t + 1]))
502-
self.assertAlmostEqual(float(vals[warmup_steps]), peak, places=6)
503-
for t in range(warmup_steps, decay_steps):
504-
self.assertGreaterEqual(float(vals[t]), float(vals[t + 1]))
505-
self.assertAlmostEqual(float(vals[-1]), end, places=6)
506-
for t in range(warmup_steps, decay_steps + 1):
507-
self.assertLessEqual(float(vals2[t]), float(vals[t]) + 1e-12)
501+
vals = base(steps)
502+
vals2 = steep(steps)
503+
504+
with self.subTest("warmup increases"):
505+
for t in range(warmup_steps):
506+
self.assertLess(float(vals[t]), float(vals[t + 1]))
507+
508+
with self.subTest("peak at boundary"):
509+
self.assertAlmostEqual(float(vals[warmup_steps]), peak, places=6)
510+
511+
with self.subTest("cosine decay nonincreasing and end value"):
512+
for t in range(warmup_steps, decay_steps):
513+
self.assertGreaterEqual(float(vals[t]), float(vals[t + 1]))
514+
self.assertAlmostEqual(float(vals[-1]), end, places=6)
515+
516+
with self.subTest("exponent ordering (p=2 ≤ p=1)"):
517+
for t in range(warmup_steps, decay_steps + 1):
518+
self.assertLessEqual(float(vals2[t]), float(vals[t]) + 1e-12)
519+
520+
@chex.all_variants
521+
def test_regression_values_subtests(self):
522+
523+
for name, init_value, peak_value, end_value in [
524+
("with end value", 10, 0.5, 1e-4),
525+
("without end value", 5, 3, 0.0),
526+
]:
527+
with self.subTest(f"limits: {name}"):
528+
schedule_fn = self.variant(_schedule.warmup_cosine_decay_schedule(
529+
init_value=init_value,
530+
peak_value=peak_value,
531+
warmup_steps=100,
532+
decay_steps=1000,
533+
end_value=end_value,
534+
))
535+
np.testing.assert_allclose(init_value, schedule_fn(0))
536+
np.testing.assert_allclose(peak_value, schedule_fn(100))
537+
np.testing.assert_allclose(end_value, schedule_fn(1000), rtol=1e-3)
538+
539+
with self.subTest("with exponent golden values"):
540+
schedule_fn = self.variant(_schedule.warmup_cosine_decay_schedule(
541+
init_value=0.2,
542+
peak_value=1.21,
543+
end_value=-3.0,
544+
warmup_steps=50,
545+
decay_steps=100,
546+
exponent=2,
547+
))
548+
output = schedule_fn(np.array([0, 10, 50, 75, 100]))
549+
np.testing.assert_allclose(
550+
output,
551+
np.array([
552+
0.20000004768371582,
553+
0.4020000100135803,
554+
1.2100000381469727,
555+
-1.947500228881836,
556+
-3.000000238418579,
557+
]),
558+
rtol=1e-6, atol=1e-8,
559+
)
560+
561+
with self.subTest("zero peak value"):
562+
schedule_fn = self.variant(_schedule.warmup_cosine_decay_schedule(
563+
init_value=0.2,
564+
peak_value=0,
565+
end_value=-3.0,
566+
warmup_steps=50,
567+
decay_steps=100,
568+
exponent=2,
569+
))
570+
output = schedule_fn(np.array([0, 10, 50, 75, 100]))
571+
np.testing.assert_allclose(
572+
output, np.array([0.2, 0.16, 0.0, 0.0, 0.0]),
573+
rtol=1e-6, atol=1e-8,
574+
)
508575

509576
def test_raises_when_decay_equals_warmup(self):
510577
with self.assertRaises(ValueError):

0 commit comments

Comments
 (0)