@@ -3079,44 +3079,28 @@ def forward_pass(self, inputs: tuple[np.ndarray, np.ndarray]) -> np.ndarray:
30793079 return self .gamma * self .normalized + self .beta
30803080
def backward_pass(self, output_error: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """Back-propagate an upstream gradient through the normalisation step.

    Reads ``self.normalized``, ``self.std`` and ``self.gamma`` and stores the
    parameter gradients on ``self.d_gamma`` / ``self.d_beta``.

    NOTE(review): the formulas assume ``self.normalized == (x - mean) / self.std``
    with statistics taken over the last axis, as cached by ``forward_pass`` --
    confirm against the forward implementation.

    Args:
        output_error: Gradient of the loss with respect to this layer's
            output. Must be rank 3; the last axis is the normalised one.

    Returns:
        A 2-tuple holding the same input-gradient array twice (presumably one
        entry per tensor of the two-input forward pass -- TODO confirm).

    Raises:
        ValueError: If ``output_error`` does not have exactly three dimensions.
    """
    # The three-way unpack doubles as a rank check; only the size of the
    # last (normalised) axis is needed afterwards.
    _, _, feature_count = output_error.shape

    # Undo the division by std to recover the centred input (x - mean).
    centered = self.normalized * self.std

    # Parameter gradients: reduce over the two leading axes, keeping dims so
    # the results broadcast against gamma / beta.
    grad_gamma = np.sum(output_error * self.normalized, axis=(0, 1), keepdims=True)
    grad_beta = np.sum(output_error, axis=(0, 1), keepdims=True)

    # Gradient with respect to the normalised activations.
    grad_hat = output_error * self.gamma

    # Gradient with respect to the per-row variance.
    variance_terms = grad_hat * centered * (-0.5) / (self.std ** 3)
    grad_var = np.sum(variance_terms, axis=-1, keepdims=True)

    # Gradient with respect to the per-row mean: the direct path plus the
    # path that goes through the variance.
    mean_direct = np.sum(grad_hat * (-1.0 / self.std), axis=-1, keepdims=True)
    mean_via_var = grad_var * np.mean(-2.0 * centered, axis=-1, keepdims=True)
    grad_mean = mean_direct + mean_via_var

    # Combine the three paths into the gradient with respect to the input.
    direct_path = grad_hat / self.std
    variance_path = grad_var * 2.0 * centered / feature_count
    mean_path = grad_mean / feature_count
    grad_input = direct_path + variance_path + mean_path

    # Publish the parameter gradients only once everything has been computed.
    self.d_gamma = grad_gamma
    self.d_beta = grad_beta

    return grad_input, grad_input
31203104
def __str__(self) -> str:
    """Return a short description listing ``epsilon`` and ``random_state``."""
    eps = self.epsilon
    seed = self.random_state
    return f'AddNorm(epsilon={eps} , random_state={seed} )'
0 commit comments