
Commit 9823501

fix(AddNorm): backward pass
1 parent: 3c7be9b

File tree

2 files changed: +59 −90 lines


examples/models-usages/generation/transformer-text-generation/transformer-for-translation.ipynb

Lines changed: 39 additions & 54 deletions
Large diffs are not rendered by default.

neuralnetlib/layers.py

Lines changed: 20 additions & 36 deletions
@@ -3079,44 +3079,28 @@ def forward_pass(self, inputs: tuple[np.ndarray, np.ndarray]) -> np.ndarray:
         return self.gamma * self.normalized + self.beta
 
     def backward_pass(self, output_error: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
-        N = output_error.shape[-1]
-        batch_size = output_error.shape[0]
-        self.step += 1
-
-        scaled_error = output_error * self.grad_scale
-
-        d_gamma_raw = np.sum(scaled_error * self.output_before_gamma, axis=(0, 1), keepdims=True)
-        d_beta_raw = np.sum(scaled_error, axis=(0, 1), keepdims=True)
-
-        scale = 1.0 / (batch_size * output_error.shape[1])
-        d_gamma = d_gamma_raw * scale
-        d_beta = d_beta_raw * scale
-
-        self.update_gradient_stats(d_gamma)
-
-        self.d_gamma = self.normalize_gradients(d_gamma)
-        self.d_beta = self.normalize_gradients(d_beta)
-
-        d_normalized = scaled_error * self.gamma
-
-        d_variance = np.clip(
-            -0.5 * np.sum(d_normalized * self.output_before_gamma, axis=-1, keepdims=True) / self.std,
-            -self.grad_clip, self.grad_clip
-        )
-
-        d_mean = np.clip(
-            -np.sum(d_normalized / self.std, axis=-1, keepdims=True),
-            -self.grad_clip, self.grad_clip
-        )
+        dY = output_error
+        B, T, F = dY.shape
+        N = F
+
+        x_minus_mean = self.normalized * self.std
 
-        d_input = np.clip(
-            (d_normalized / self.std +
-             2.0 * d_variance * self.output_before_gamma / N +
-             d_mean / N),
-            -self.grad_clip, self.grad_clip
-        )
+        d_gamma = np.sum(dY * self.normalized, axis=(0, 1), keepdims=True)
+        d_beta = np.sum(dY, axis=(0, 1), keepdims=True)
+
+        d_normalized = dY * self.gamma
+
+        d_var = np.sum(d_normalized * x_minus_mean * (-0.5) / (self.std**3), axis=-1, keepdims=True)
 
-        return d_input, d_input
+        d_mean = np.sum(d_normalized * (-1.0 / self.std), axis=-1, keepdims=True) \
+            + d_var * np.mean(-2.0 * x_minus_mean, axis=-1, keepdims=True)
+
+        dx = (d_normalized / self.std) + (d_var * 2.0 * x_minus_mean / N) + (d_mean / N)
+
+        self.d_gamma = d_gamma
+        self.d_beta = d_beta
+
+        return dx, dx
 
     def __str__(self) -> str:
         return f'AddNorm(epsilon={self.epsilon}, random_state={self.random_state})'
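For reference, the rewritten method is the textbook layer-normalization backward pass over the feature axis of size N. With x̂ = (x − μ)/σ cached as self.normalized and σ = sqrt(Var(x) + ε) cached as self.std (so x_minus_mean = x̂ · σ recovers x − μ), the chain rule through y = γ x̂ + β gives:

\begin{aligned}
\frac{\partial L}{\partial \gamma} &= \sum_{b,t} \frac{\partial L}{\partial y}\,\hat{x}, \qquad
\frac{\partial L}{\partial \beta} = \sum_{b,t} \frac{\partial L}{\partial y}, \qquad
\frac{\partial L}{\partial \hat{x}} = \gamma\,\frac{\partial L}{\partial y}, \\
\frac{\partial L}{\partial \sigma^2} &= -\frac{1}{2\sigma^{3}} \sum_{i=1}^{N} \frac{\partial L}{\partial \hat{x}_i}\,(x_i - \mu), \\
\frac{\partial L}{\partial \mu} &= -\frac{1}{\sigma} \sum_{i=1}^{N} \frac{\partial L}{\partial \hat{x}_i} \;+\; \frac{\partial L}{\partial \sigma^2} \cdot \frac{1}{N} \sum_{i=1}^{N} \bigl(-2\,(x_i - \mu)\bigr), \\
\frac{\partial L}{\partial x_i} &= \frac{1}{\sigma}\,\frac{\partial L}{\partial \hat{x}_i} \;+\; \frac{\partial L}{\partial \sigma^2}\,\frac{2\,(x_i - \mu)}{N} \;+\; \frac{1}{N}\,\frac{\partial L}{\partial \mu}.
\end{aligned}

These map line-for-line onto d_gamma, d_beta, d_normalized, d_var, d_mean, and dx (the second d_mean term is analytically zero, since Σᵢ(xᵢ − μ) = 0, but is conventionally kept in this decomposition). The previous version mis-scaled d_variance (σ⁻¹ against the normalized values, in place of (x − μ) · σ⁻³) and layered grad_scale, clipping, and gradient normalization on top; the new version computes the plain analytic gradients, which is presumably the fix the commit message refers to. Since AddNorm normalizes the sum of its two inputs and ∂(a + b)/∂a = ∂(a + b)/∂b = 1, the same dx is passed back on both branches, hence return dx, dx.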

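As a sanity check, the new formulas can be verified against finite differences. The sketch below is standalone NumPy, not neuralnetlib's API: layernorm_forward and layernorm_backward are hypothetical helpers that mirror the normalization implied by the diff and restate the commit's backward math, then compare dx to a numerical gradient.

import numpy as np

def layernorm_forward(x, gamma, beta, eps=1e-5):
    # Normalize over the last axis, caching what the backward pass needs.
    mean = x.mean(axis=-1, keepdims=True)
    std = np.sqrt(x.var(axis=-1, keepdims=True) + eps)
    normalized = (x - mean) / std
    return gamma * normalized + beta, (normalized, std)

def layernorm_backward(dY, gamma, cache):
    # The corrected backward pass from the commit, restated standalone.
    normalized, std = cache
    N = dY.shape[-1]
    x_minus_mean = normalized * std
    d_gamma = np.sum(dY * normalized, axis=(0, 1), keepdims=True)
    d_beta = np.sum(dY, axis=(0, 1), keepdims=True)
    d_normalized = dY * gamma
    d_var = np.sum(d_normalized * x_minus_mean * (-0.5) / (std ** 3), axis=-1, keepdims=True)
    d_mean = np.sum(d_normalized * (-1.0 / std), axis=-1, keepdims=True) \
        + d_var * np.mean(-2.0 * x_minus_mean, axis=-1, keepdims=True)
    dx = (d_normalized / std) + (d_var * 2.0 * x_minus_mean / N) + (d_mean / N)
    return dx, d_gamma, d_beta

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 3, 8))
gamma = rng.standard_normal((1, 1, 8))
beta = rng.standard_normal((1, 1, 8))
dY = rng.standard_normal(x.shape)

_, cache = layernorm_forward(x, gamma, beta)
dx, _, _ = layernorm_backward(dY, gamma, cache)

# Central-difference approximation of dL/dx for the scalar loss L = sum(out * dY).
h = 1e-6
dx_num = np.zeros_like(x)
it = np.nditer(x, flags=['multi_index'])
for _ in it:
    idx = it.multi_index
    xp, xm = x.copy(), x.copy()
    xp[idx] += h
    xm[idx] -= h
    dx_num[idx] = np.sum((layernorm_forward(xp, gamma, beta)[0]
                          - layernorm_forward(xm, gamma, beta)[0]) * dY) / (2 * h)

print(np.max(np.abs(dx - dx_num)))  # expect roughly 1e-8: analytic and numeric agree

Swapping the old d_variance back in (σ⁻¹ against the normalized values, instead of (x − μ) · σ⁻³) makes this check fail, consistent with the commit message.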