@@ -483,10 +483,77 @@ def test_with_giant_int_steps(self):
483483
484484class WarmupCosineDecayTest (chex .TestCase ):
485485
@chex.all_variants
@parameterized.named_parameters(
    ('with end value', 10, 0.5, 1e-4),
    ('without end value', 5, 3, 0.0),
)
def test_limits(self, init_value, peak_value, end_value):
  """Check the schedule at its start, at the warmup boundary and at the end.

  The schedule is built with 100 warmup steps and 1000 decay steps and is
  probed at steps 0, 100 and 1000.

  Args:
    init_value: value the schedule is expected to return at step 0.
    peak_value: value the schedule is expected to return at step 100.
    end_value: value the schedule is expected to return at step 1000.
  """
  schedule_fn = self.variant(
      _schedule.warmup_cosine_decay_schedule(
          init_value=init_value,
          peak_value=peak_value,
          warmup_steps=100,
          decay_steps=1000,
          end_value=end_value,
      )
  )

  # assert_allclose expects (actual, desired): the tolerance is scaled by the
  # second argument and the failure report labels the operands accordingly,
  # so the computed schedule value goes first and the reference value second.
  np.testing.assert_allclose(schedule_fn(0), init_value)
  np.testing.assert_allclose(schedule_fn(100), peak_value)
  # The end of the decay is only compared up to a relative tolerance of 1e-3.
  np.testing.assert_allclose(schedule_fn(1000), end_value, rtol=1e-3)
506+
@chex.all_variants
def test_with_exponent(self):
  """Compare a schedule built with exponent=2 against reference values."""
  schedule_kwargs = dict(
      init_value=0.2,
      peak_value=1.21,
      end_value=-3.0,
      warmup_steps=50,
      decay_steps=100,
      exponent=2,
  )
  schedule_fn = self.variant(
      _schedule.warmup_cosine_decay_schedule(**schedule_kwargs)
  )

  # Probe the start, a warmup step, the warmup boundary, a decay step and
  # the final decay step in a single vectorised call.
  steps = np.array([0, 10, 50, 75, 100])
  expected = np.array([
      0.20000004768371582,
      0.4020000100135803,
      1.2100000381469727,
      -1.947500228881836,
      -3.000000238418579,
  ])

  np.testing.assert_allclose(
      schedule_fn(steps), expected, rtol=1e-6, atol=1e-8
  )
533+
@chex.all_variants
def test_zero_peak_value(self):
  """Compare a schedule built with peak_value=0 against reference values."""
  schedule_kwargs = dict(
      init_value=0.2,
      peak_value=0,
      end_value=-3.0,
      warmup_steps=50,
      decay_steps=100,
      exponent=2,
  )
  schedule_fn = self.variant(
      _schedule.warmup_cosine_decay_schedule(**schedule_kwargs)
  )

  # Same probe points as the other reference-value test: start, mid-warmup,
  # warmup boundary, mid-decay and final decay step.
  steps = np.array([0, 10, 50, 75, 100])
  expected = np.array([0.2, 0.16, 0.0, 0.0, 0.0])

  np.testing.assert_allclose(
      schedule_fn(steps), expected, rtol=1e-6, atol=1e-8
  )
551+
486552 @chex .all_variants
487553 def test_monotonicity_and_exponent_ordering (self ):
488554 init , peak , end = 0.0 , 1.0 , 0.1
489555 warmup_steps , decay_steps = 5 , 25
556+
490557 base = self .variant (_schedule .warmup_cosine_decay_schedule (
491558 init_value = init , peak_value = peak ,
492559 warmup_steps = warmup_steps , decay_steps = decay_steps ,
@@ -495,24 +562,35 @@ def test_monotonicity_and_exponent_ordering(self):
495562 init_value = init , peak_value = peak ,
496563 warmup_steps = warmup_steps , decay_steps = decay_steps ,
497564 end_value = end , exponent = 2.0 ))
565+
498566 steps = np .arange (decay_steps + 1 )
499- vals , vals2 = base (steps ), steep (steps )
500- for t in range (warmup_steps ):
501- self .assertLess (float (vals [t ]), float (vals [t + 1 ]))
502- self .assertAlmostEqual (float (vals [warmup_steps ]), peak , places = 6 )
503- for t in range (warmup_steps , decay_steps ):
504- self .assertGreaterEqual (float (vals [t ]), float (vals [t + 1 ]))
505- self .assertAlmostEqual (float (vals [- 1 ]), end , places = 6 )
506- for t in range (warmup_steps , decay_steps + 1 ):
507- self .assertLessEqual (float (vals2 [t ]), float (vals [t ]) + 1e-12 )
567+ vals = base (steps )
568+ vals2 = steep (steps )
569+
570+ with self .subTest ("warmup increases" ):
571+ for t in range (warmup_steps ):
572+ self .assertLess (float (vals [t ]), float (vals [t + 1 ]))
573+
574+ with self .subTest ("peak at boundary" ):
575+ self .assertAlmostEqual (float (vals [warmup_steps ]), peak , places = 6 )
576+
577+ with self .subTest ("cosine decay nonincreasing and end value" ):
578+ for t in range (warmup_steps , decay_steps ):
579+ self .assertGreaterEqual (float (vals [t ]), float (vals [t + 1 ]))
580+ self .assertAlmostEqual (float (vals [- 1 ]), end , places = 6 )
581+
582+ with self .subTest ("exponent ordering (p=2 ≤ p=1)" ):
583+ for t in range (warmup_steps , decay_steps + 1 ):
584+ self .assertLessEqual (float (vals2 [t ]), float (vals [t ]) + 1e-12 )
585+
508586
def test_raises_when_decay_equals_warmup(self):
  """Building a schedule with decay_steps == warmup_steps raises ValueError."""
  equal_steps = 10  # Used for both the warmup and the decay length.
  self.assertRaises(
      ValueError,
      _schedule.warmup_cosine_decay_schedule,
      init_value=0.2,
      peak_value=1.2,
      warmup_steps=equal_steps,
      decay_steps=equal_steps,
      end_value=0.0,
  )
514592
515- #ee
593+
516594class SGDRTest (chex .TestCase ):
517595
518596 @chex .all_variants
0 commit comments