
Commit 48a71b2

feat(optimizers): add RAdam
1 parent d059bad commit 48a71b2

File tree

1 file changed: +118 -0 lines changed


neuralnetlib/optimizers.py

Lines changed: 118 additions & 0 deletions
@@ -340,3 +340,121 @@ def __str__(self):
         return (f"{self.__class__.__name__}(learning_rate={self.learning_rate}, "
                 f"beta_1={self.beta_1}, beta_2={self.beta_2}, epsilon={self.epsilon}, "
                 f"clip_norm={self.clip_norm}, clip_value={self.clip_value})")
+
+
+class RAdam(Optimizer):
+    def __init__(self, learning_rate: float = 0.001, beta_1: float = 0.9, beta_2: float = 0.999,
+                 epsilon: float = 1e-8, clip_norm: float = None, clip_value: float = None) -> None:
+        super().__init__(learning_rate)
+        self.beta_1 = beta_1
+        self.beta_2 = beta_2
+        self.epsilon = epsilon
+        self.clip_norm = clip_norm
+        self.clip_value = clip_value
+        self.t = 0
+
+        self.m_w, self.v_w = {}, {}
+        self.m_b, self.v_b = {}, {}
+
+        self._min_denom = 1e-16
+        self._max_exp = np.log(np.finfo(np.float64).max)
+
+        self.rho_inf = 2/(1-beta_2) - 1
+
+    def _clip_gradients(self, grad: np.ndarray) -> np.ndarray:
+        if grad is None:
+            return None
+
+        if self.clip_norm is not None:
+            grad_norm = np.linalg.norm(grad)
+            if grad_norm > self.clip_norm:
+                grad = grad * (self.clip_norm / (grad_norm + self._min_denom))
+
+        if self.clip_value is not None:
+            grad = np.clip(grad, -self.clip_value, self.clip_value)
+
+        return grad
+
+    def _compute_moments(self, param: np.ndarray, grad: np.ndarray, m: np.ndarray, v: np.ndarray) -> tuple:
+        grad = self._clip_gradients(grad)
+
+        m = self.beta_1 * m + (1 - self.beta_1) * grad
+        v = self.beta_2 * v + (1 - self.beta_2) * np.square(grad)
+
+        beta1_t = self.beta_1 ** self.t
+        beta2_t = self.beta_2 ** self.t
+
+        m_hat = m / (1 - beta1_t)
+
+        rho_t = self.rho_inf - 2 * self.t * beta2_t / (1 - beta2_t)
+
+        if rho_t > 4:
+            v_hat = np.sqrt(v / (1 - beta2_t))
+            r_t = np.sqrt(((rho_t - 4) * (rho_t - 2) * self.rho_inf) /
+                          ((self.rho_inf - 4) * (self.rho_inf - 2) * rho_t))
+
+            denom = v_hat + self.epsilon
+            update = r_t * self.learning_rate * m_hat / np.maximum(denom, self._min_denom)
+        else:
+            update = self.learning_rate * m_hat
+
+        update = np.nan_to_num(update, nan=0.0, posinf=0.0, neginf=0.0)
+        param -= update
+
+        return param, m, v
+
+    def update(self, layer_index: int, weights: np.ndarray, weights_grad: np.ndarray, bias: np.ndarray,
+               bias_grad: np.ndarray) -> None:
+        if layer_index not in self.m_w:
+            self.m_w[layer_index] = np.zeros_like(weights)
+            self.v_w[layer_index] = np.zeros_like(weights)
+            self.m_b[layer_index] = np.zeros_like(bias)
+            self.v_b[layer_index] = np.zeros_like(bias)
+
+        self.t += 1
+
+        weights, self.m_w[layer_index], self.v_w[layer_index] = self._compute_moments(
+            weights, weights_grad, self.m_w[layer_index], self.v_w[layer_index]
+        )
+
+        bias, self.m_b[layer_index], self.v_b[layer_index] = self._compute_moments(
+            bias, bias_grad, self.m_b[layer_index], self.v_b[layer_index]
+        )
+
+    def get_config(self) -> dict:
+        return {
+            "name": self.__class__.__name__,
+            "learning_rate": self.learning_rate,
+            "beta_1": self.beta_1,
+            "beta_2": self.beta_2,
+            "epsilon": self.epsilon,
+            "clip_norm": self.clip_norm,
+            "clip_value": self.clip_value,
+            "t": self.t,
+            "m_w": dict_with_ndarray_to_dict_with_list(self.m_w),
+            "v_w": dict_with_ndarray_to_dict_with_list(self.v_w),
+            "m_b": dict_with_ndarray_to_dict_with_list(self.m_b),
+            "v_b": dict_with_ndarray_to_dict_with_list(self.v_b)
+        }
+
+    @staticmethod
+    def from_config(config: dict):
+        radam = RAdam(
+            learning_rate=config['learning_rate'],
+            beta_1=config['beta_1'],
+            beta_2=config['beta_2'],
+            epsilon=config['epsilon'],
+            clip_norm=config.get('clip_norm'),
+            clip_value=config.get('clip_value')
+        )
+        radam.t = config['t']
+        radam.m_w = dict_with_list_to_dict_with_ndarray(config['m_w'])
+        radam.v_w = dict_with_list_to_dict_with_ndarray(config['v_w'])
+        radam.m_b = dict_with_list_to_dict_with_ndarray(config['m_b'])
+        radam.v_b = dict_with_list_to_dict_with_ndarray(config['v_b'])
+        return radam
+
+    def __str__(self):
+        return (f"{self.__class__.__name__}(learning_rate={self.learning_rate}, "
+                f"beta_1={self.beta_1}, beta_2={self.beta_2}, epsilon={self.epsilon}, "
+                f"clip_norm={self.clip_norm}, clip_value={self.clip_value})")

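A minimal smoke-test sketch of the new class, assuming RAdam is importable from neuralnetlib.optimizers (the import path, layer index, and toy shapes below are illustrative assumptions, not part of the diff):

import numpy as np
from neuralnetlib.optimizers import RAdam  # assumed import path

rng = np.random.default_rng(0)
weights = rng.normal(size=(4, 3))  # toy dense-layer parameters
bias = np.zeros((1, 3))

optimizer = RAdam(learning_rate=0.001, clip_norm=1.0)

for step in range(100):
    # Stand-in gradients; a real network would pass backprop results here.
    weights_grad = rng.normal(size=weights.shape)
    bias_grad = rng.normal(size=bias.shape)
    # update() mutates weights and bias in place (param -= update in _compute_moments).
    optimizer.update(0, weights, weights_grad, bias, bias_grad)

print(optimizer)                     # __str__ reports the hyperparameters
config = optimizer.get_config()      # serializable state, including the m/v moments
restored = RAdam.from_config(config)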
0 commit comments

Comments
 (0)