@@ -340,3 +340,121 @@ def __str__(self):
        return (f"{self.__class__.__name__}(learning_rate={self.learning_rate}, "
                f"beta_1={self.beta_1}, beta_2={self.beta_2}, epsilon={self.epsilon}, "
                f"clip_norm={self.clip_norm}, clip_value={self.clip_value})")


class RAdam(Optimizer):
    def __init__(self, learning_rate: float = 0.001, beta_1: float = 0.9, beta_2: float = 0.999,
                 epsilon: float = 1e-8, clip_norm: float = None, clip_value: float = None) -> None:
        super().__init__(learning_rate)
        self.beta_1 = beta_1
        self.beta_2 = beta_2
        self.epsilon = epsilon
        self.clip_norm = clip_norm
        self.clip_value = clip_value
        self.t = 0

        # First (m) and second (v) moment estimates, keyed by layer index.
        self.m_w, self.v_w = {}, {}
        self.m_b, self.v_b = {}, {}

        self._min_denom = 1e-16
        self._max_exp = np.log(np.finfo(np.float64).max)

        # Maximum length of the approximated SMA, used by the rectification term.
        self.rho_inf = 2 / (1 - beta_2) - 1

    def _clip_gradients(self, grad: np.ndarray) -> np.ndarray:
        if grad is None:
            return None

        # Rescale the gradient so its norm does not exceed clip_norm.
        if self.clip_norm is not None:
            grad_norm = np.linalg.norm(grad)
            if grad_norm > self.clip_norm:
                grad = grad * (self.clip_norm / (grad_norm + self._min_denom))

        # Clip each component to [-clip_value, clip_value].
        if self.clip_value is not None:
            grad = np.clip(grad, -self.clip_value, self.clip_value)

        return grad

    def _compute_moments(self, param: np.ndarray, grad: np.ndarray, m: np.ndarray, v: np.ndarray) -> tuple:
        grad = self._clip_gradients(grad)

        # Exponential moving averages of the gradient and the squared gradient.
        m = self.beta_1 * m + (1 - self.beta_1) * grad
        v = self.beta_2 * v + (1 - self.beta_2) * np.square(grad)

        beta1_t = self.beta_1 ** self.t
        beta2_t = self.beta_2 ** self.t

        m_hat = m / (1 - beta1_t)

        # Length of the approximated SMA at step t.
        rho_t = self.rho_inf - 2 * self.t * beta2_t / (1 - beta2_t)

        if rho_t > 4:
            # Variance is tractable: take the rectified adaptive step.
            v_hat = np.sqrt(v / (1 - beta2_t))
            r_t = np.sqrt(((rho_t - 4) * (rho_t - 2) * self.rho_inf) /
                          ((self.rho_inf - 4) * (self.rho_inf - 2) * rho_t))

            denom = v_hat + self.epsilon
            update = r_t * self.learning_rate * m_hat / np.maximum(denom, self._min_denom)
        else:
            # Variance not yet tractable: fall back to an un-adapted momentum step.
            update = self.learning_rate * m_hat

        update = np.nan_to_num(update, nan=0.0, posinf=0.0, neginf=0.0)
        param -= update

        return param, m, v

    def update(self, layer_index: int, weights: np.ndarray, weights_grad: np.ndarray, bias: np.ndarray,
               bias_grad: np.ndarray) -> None:
        # Lazily initialise the moment buffers for a layer on first use.
        if layer_index not in self.m_w:
            self.m_w[layer_index] = np.zeros_like(weights)
            self.v_w[layer_index] = np.zeros_like(weights)
            self.m_b[layer_index] = np.zeros_like(bias)
            self.v_b[layer_index] = np.zeros_like(bias)

        self.t += 1

        # Parameters are modified in place by _compute_moments.
        weights, self.m_w[layer_index], self.v_w[layer_index] = self._compute_moments(
            weights, weights_grad, self.m_w[layer_index], self.v_w[layer_index]
        )

        bias, self.m_b[layer_index], self.v_b[layer_index] = self._compute_moments(
            bias, bias_grad, self.m_b[layer_index], self.v_b[layer_index]
        )

    def get_config(self) -> dict:
        return {
            "name": self.__class__.__name__,
            "learning_rate": self.learning_rate,
            "beta_1": self.beta_1,
            "beta_2": self.beta_2,
            "epsilon": self.epsilon,
            "clip_norm": self.clip_norm,
            "clip_value": self.clip_value,
            "t": self.t,
            "m_w": dict_with_ndarray_to_dict_with_list(self.m_w),
            "v_w": dict_with_ndarray_to_dict_with_list(self.v_w),
            "m_b": dict_with_ndarray_to_dict_with_list(self.m_b),
            "v_b": dict_with_ndarray_to_dict_with_list(self.v_b)
        }

    @staticmethod
    def from_config(config: dict):
        radam = RAdam(
            learning_rate=config['learning_rate'],
            beta_1=config['beta_1'],
            beta_2=config['beta_2'],
            epsilon=config['epsilon'],
            clip_norm=config.get('clip_norm'),
            clip_value=config.get('clip_value')
        )
        radam.t = config['t']
        radam.m_w = dict_with_list_to_dict_with_ndarray(config['m_w'])
        radam.v_w = dict_with_list_to_dict_with_ndarray(config['v_w'])
        radam.m_b = dict_with_list_to_dict_with_ndarray(config['m_b'])
        radam.v_b = dict_with_list_to_dict_with_ndarray(config['v_b'])
        return radam

    def __str__(self):
        return (f"{self.__class__.__name__}(learning_rate={self.learning_rate}, "
                f"beta_1={self.beta_1}, beta_2={self.beta_2}, epsilon={self.epsilon}, "
                f"clip_norm={self.clip_norm}, clip_value={self.clip_value})")