@@ -487,6 +487,7 @@ class WarmupCosineDecayTest(chex.TestCase):
487487 def test_monotonicity_and_exponent_ordering (self ):
488488 init , peak , end = 0.0 , 1.0 , 0.1
489489 warmup_steps , decay_steps = 5 , 25
490+
490491 base = self .variant (_schedule .warmup_cosine_decay_schedule (
491492 init_value = init , peak_value = peak ,
492493 warmup_steps = warmup_steps , decay_steps = decay_steps ,
@@ -495,16 +496,82 @@ def test_monotonicity_and_exponent_ordering(self):
495496 init_value = init , peak_value = peak ,
496497 warmup_steps = warmup_steps , decay_steps = decay_steps ,
497498 end_value = end , exponent = 2.0 ))
499+
498500 steps = np .arange (decay_steps + 1 )
499- vals , vals2 = base (steps ), steep (steps )
500- for t in range (warmup_steps ):
501- self .assertLess (float (vals [t ]), float (vals [t + 1 ]))
502- self .assertAlmostEqual (float (vals [warmup_steps ]), peak , places = 6 )
503- for t in range (warmup_steps , decay_steps ):
504- self .assertGreaterEqual (float (vals [t ]), float (vals [t + 1 ]))
505- self .assertAlmostEqual (float (vals [- 1 ]), end , places = 6 )
506- for t in range (warmup_steps , decay_steps + 1 ):
507- self .assertLessEqual (float (vals2 [t ]), float (vals [t ]) + 1e-12 )
501+ vals = base (steps )
502+ vals2 = steep (steps )
503+
504+ with self .subTest ("warmup increases" ):
505+ for t in range (warmup_steps ):
506+ self .assertLess (float (vals [t ]), float (vals [t + 1 ]))
507+
508+ with self .subTest ("peak at boundary" ):
509+ self .assertAlmostEqual (float (vals [warmup_steps ]), peak , places = 6 )
510+
511+ with self .subTest ("cosine decay nonincreasing and end value" ):
512+ for t in range (warmup_steps , decay_steps ):
513+ self .assertGreaterEqual (float (vals [t ]), float (vals [t + 1 ]))
514+ self .assertAlmostEqual (float (vals [- 1 ]), end , places = 6 )
515+
516+ with self .subTest ("exponent ordering (p=2 ≤ p=1)" ):
517+ for t in range (warmup_steps , decay_steps + 1 ):
518+ self .assertLessEqual (float (vals2 [t ]), float (vals [t ]) + 1e-12 )
519+
@chex.all_variants
def test_regression_values_subtests(self):
  """Pins `warmup_cosine_decay_schedule` outputs to known-good values.

  Each scenario runs in its own sub-test so that one failure does not hide
  the others:

  * limits: the schedule yields `init_value` at step 0, `peak_value` at
    `warmup_steps` and `end_value` (within `rtol=1e-3`) at `decay_steps`;
  * golden values for `exponent=2` combined with a negative `end_value`;
  * golden values for a `peak_value` of zero.
  """
  # (sub-test label, init_value, peak_value, end_value)
  limit_cases = (
      ("with end value", 10, 0.5, 1e-4),
      ("without end value", 5, 3, 0.0),
  )
  for name, init_value, peak_value, end_value in limit_cases:
    with self.subTest(f"limits: {name}"):
      schedule_fn = self.variant(_schedule.warmup_cosine_decay_schedule(
          init_value=init_value,
          peak_value=peak_value,
          warmup_steps=100,
          decay_steps=1000,
          end_value=end_value,
      ))
      # `assert_allclose` takes (actual, desired); the relative tolerance is
      # scaled by the expected value, so the computed value must come first.
      np.testing.assert_allclose(schedule_fn(0), init_value)
      np.testing.assert_allclose(schedule_fn(100), peak_value)
      np.testing.assert_allclose(schedule_fn(1000), end_value, rtol=1e-3)

  # Steps evaluated by both golden-value sub-tests below (warmup_steps=50,
  # decay_steps=100 in each of them).
  probe_steps = np.array([0, 10, 50, 75, 100])

  with self.subTest("with exponent golden values"):
    schedule_fn = self.variant(_schedule.warmup_cosine_decay_schedule(
        init_value=0.2,
        peak_value=1.21,
        end_value=-3.0,
        warmup_steps=50,
        decay_steps=100,
        exponent=2,
    ))
    expected = np.array([
        0.20000004768371582,
        0.4020000100135803,
        1.2100000381469727,
        -1.947500228881836,
        -3.000000238418579,
    ])
    np.testing.assert_allclose(
        schedule_fn(probe_steps), expected, rtol=1e-6, atol=1e-8,
    )

  with self.subTest("zero peak value"):
    schedule_fn = self.variant(_schedule.warmup_cosine_decay_schedule(
        init_value=0.2,
        peak_value=0,
        end_value=-3.0,
        warmup_steps=50,
        decay_steps=100,
        exponent=2,
    ))
    # Golden values: the output is pinned at 0.0 from the warmup boundary on,
    # even though `end_value` is -3.0.
    expected = np.array([0.2, 0.16, 0.0, 0.0, 0.0])
    np.testing.assert_allclose(
        schedule_fn(probe_steps), expected, rtol=1e-6, atol=1e-8,
    )
508575
509576 def test_raises_when_decay_equals_warmup (self ):
510577 with self .assertRaises (ValueError ):
0 commit comments