@@ -484,72 +484,35 @@ def test_with_giant_int_steps(self):
484484class WarmupCosineDecayTest (chex .TestCase ):
485485
486486 @chex .all_variants
487- @parameterized .named_parameters (
488- ('with end value' , 10 , 0.5 , 1e-4 ),
489- ('without end value' , 5 , 3 , 0.0 ),
490- )
491- def test_limits (self , init_value , peak_value , end_value ):
492- """Check cosine schedule decay for the entire training schedule."""
493- schedule_fn = self .variant (
494- _schedule .warmup_cosine_decay_schedule (
495- init_value = init_value ,
496- peak_value = peak_value ,
497- warmup_steps = 100 ,
498- decay_steps = 1000 ,
499- end_value = end_value ,
500- )
501- )
502-
503- np .testing .assert_allclose (init_value , schedule_fn (0 ))
504- np .testing .assert_allclose (peak_value , schedule_fn (100 ))
505- np .testing .assert_allclose (end_value , schedule_fn (1000 ), rtol = 1e-3 )
506-
507- @chex .all_variants
508- def test_with_exponent (self ):
509- """Check that we get correct results when running with exponent on."""
510- schedule_fn = self .variant (
511- _schedule .warmup_cosine_decay_schedule (
512- init_value = 0.2 ,
513- peak_value = 1.21 ,
514- end_value = - 3.0 ,
515- warmup_steps = 50 ,
516- decay_steps = 100 ,
517- exponent = 2 ,
518- )
519- )
520- output = schedule_fn (np .array ([0 , 10 , 50 , 75 , 100 ]))
521- np .testing .assert_allclose (
522- output ,
523- np .array ([
524- 0.20000004768371582 ,
525- 0.4020000100135803 ,
526- 1.2100000381469727 ,
527- - 1.947500228881836 ,
528- - 3.000000238418579 ,
529- ]),
530- rtol = 1e-6 ,
531- atol = 1e-8 ,
532- )
533-
534- @chex .all_variants
535- def test_zero_peak_value (self ):
536- """Check that we get correct results when running with zero peak value."""
537- schedule_fn = self .variant (
538- _schedule .warmup_cosine_decay_schedule (
539- init_value = 0.2 ,
540- peak_value = 0 ,
541- end_value = - 3.0 ,
542- warmup_steps = 50 ,
543- decay_steps = 100 ,
544- exponent = 2 ,
545- )
546- )
547- output = schedule_fn (np .array ([0 , 10 , 50 , 75 , 100 ]))
548- np .testing .assert_allclose (
549- output , np .array ([0.2 , 0.16 , 0.0 , 0.0 , 0.0 ]), rtol = 1e-6 , atol = 1e-8
550- )
551-
487+ def test_monotonicity_and_exponent_ordering (self ):
488+ init , peak , end = 0.0 , 1.0 , 0.1
489+ warmup_steps , decay_steps = 5 , 25
490+ base = self .variant (_schedule .warmup_cosine_decay_schedule (
491+ init_value = init , peak_value = peak ,
492+ warmup_steps = warmup_steps , decay_steps = decay_steps ,
493+ end_value = end , exponent = 1.0 ))
494+ steep = self .variant (_schedule .warmup_cosine_decay_schedule (
495+ init_value = init , peak_value = peak ,
496+ warmup_steps = warmup_steps , decay_steps = decay_steps ,
497+ end_value = end , exponent = 2.0 ))
498+ steps = np .arange (decay_steps + 1 )
499+ vals , vals2 = base (steps ), steep (steps )
500+ for t in range (warmup_steps ):
501+ self .assertLess (float (vals [t ]), float (vals [t + 1 ]))
502+ self .assertAlmostEqual (float (vals [warmup_steps ]), peak , places = 6 )
503+ for t in range (warmup_steps , decay_steps ):
504+ self .assertGreaterEqual (float (vals [t ]), float (vals [t + 1 ]))
505+ self .assertAlmostEqual (float (vals [- 1 ]), end , places = 6 )
506+ for t in range (warmup_steps , decay_steps + 1 ):
507+ self .assertLessEqual (float (vals2 [t ]), float (vals [t ]) + 1e-12 )
508+
509+ def test_raises_when_decay_equals_warmup (self ):
510+ with self .assertRaises (ValueError ):
511+ _schedule .warmup_cosine_decay_schedule (
512+ init_value = 0.2 , peak_value = 1.2 ,
513+ warmup_steps = 10 , decay_steps = 10 , end_value = 0.0 )
552514
515+ #ee
553516class SGDRTest (chex .TestCase ):
554517
555518 @chex .all_variants
0 commit comments