
Commit e7894b9

perf(Adam/RNN): better gradient handling
1 parent 4a4448d commit e7894b9

3 files changed: 93 additions and 65 deletions


neuralnetlib/layers.py

Lines changed: 28 additions & 18 deletions
@@ -1657,9 +1657,8 @@ def backward_pass(self, output_error: np.ndarray) -> np.ndarray:
             np.sum(self.cell.dWo**2) + np.sum(self.cell.dUo**2) + np.sum(self.cell.dbo**2))
 
         global_norm = np.sqrt(squared_norm_sum)
-
-        scaling_factor = min(1.0, self.clip_value / (global_norm + 1e-8)) / timesteps
-        if scaling_factor < 1.0:  # Only scale if necessary
+        scaling_factor = min(1.0, self.clip_value / (global_norm + 1e-8))
+        if scaling_factor < 1.0:
            dx *= scaling_factor
        for grad in self.cell.__dict__:
            if grad.startswith('d'):
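
Note: the LSTM backward pass (and, further down, the GRU one) now applies plain clip-by-global-norm, without the extra division by timesteps that previously shrank gradients even when they were already below the threshold. A minimal standalone sketch of that rule, where the gradient dict and clip_value are illustrative rather than the layer's real attributes:

import numpy as np

def clip_by_global_norm(grads: dict, clip_value: float) -> dict:
    # Global norm taken over every gradient array, as in the updated backward_pass.
    global_norm = np.sqrt(sum(np.sum(g ** 2) for g in grads.values()))
    # Scale only when the norm exceeds the threshold; no extra division by timesteps.
    scale = min(1.0, clip_value / (global_norm + 1e-8))
    if scale < 1.0:
        grads = {name: g * scale for name, g in grads.items()}
    return grads

example = {"dWf": np.full((4, 4), 3.0), "dbf": np.full(4, 3.0)}
clipped = clip_by_global_norm(example, clip_value=5.0)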
@@ -1709,43 +1708,54 @@ def __str__(self) -> str:
     def forward_pass(self, input_data: np.ndarray, training: bool = True) -> np.ndarray:
         self.forward_output = self.forward_layer.forward_pass(
             input_data, training)
+
         backward_input = input_data[:, ::-1, :]
         self.backward_output = self.backward_layer.forward_pass(
             backward_input, training)
 
         if isinstance(self.forward_output, tuple):
-            forward_seq, forward_h, forward_c = self.forward_output
-            backward_seq, backward_h, backward_c = self.backward_output
-
             if self.forward_layer.return_sequences:
+                forward_seq, forward_h, forward_c = self.forward_output
+                backward_seq, backward_h, backward_c = self.backward_output
+
                 backward_seq = backward_seq[:, ::-1, :]
-                return np.concatenate([forward_seq, backward_seq], axis=-1), \
-                    np.concatenate([forward_h, backward_h], axis=-1), \
-                    np.concatenate([forward_c, backward_c], axis=-1)
+
+                combined_seq = np.concatenate([forward_seq, backward_seq], axis=-1)
+                combined_h = np.concatenate([forward_h, backward_h], axis=-1)
+                combined_c = np.concatenate([forward_c, backward_c], axis=-1)
+
+                return combined_seq, combined_h, combined_c
             else:
-                return np.concatenate([forward_h, backward_h], axis=-1)
+                forward_h, _, forward_c = self.forward_output
+                backward_h, _, backward_c = self.backward_output
+                combined_h = np.concatenate([forward_h, backward_h], axis=-1)
+                combined_c = np.concatenate([forward_c, backward_c], axis=-1)
+                return combined_h, combined_h, combined_c
         else:
             if self.forward_layer.return_sequences:
-                self.backward_output = self.backward_output[:, ::-1, :]
-            return np.concatenate([self.forward_output, self.backward_output], axis=-1)
+                backward_seq = self.backward_output[:, ::-1, :]
+                return np.concatenate([self.forward_output, backward_seq], axis=-1)
+            else:
+                return np.concatenate([self.forward_output, self.backward_output], axis=-1)
 
     def backward_pass(self, output_error: np.ndarray) -> np.ndarray:
         forward_dim = output_error.shape[-1] // 2
-
+
         if len(output_error.shape) == 3:
             forward_error = output_error[:, :, :forward_dim]
             backward_error = output_error[:, :, forward_dim:]
+
             backward_error = backward_error[:, ::-1, :]
         else:
             forward_error = output_error[:, :forward_dim]
             backward_error = output_error[:, forward_dim:]
-
+
         forward_dx = self.forward_layer.backward_pass(forward_error)
         backward_dx = self.backward_layer.backward_pass(backward_error)
-
+
         if len(output_error.shape) == 3:
             backward_dx = backward_dx[:, ::-1, :]
-
+
         return forward_dx + backward_dx
 
     def get_config(self) -> dict:
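
Note: the Bidirectional forward pass is reorganised so that tuple-returning cells are unpacked per branch, and the non-sequence case now concatenates both hidden and cell states instead of only the hidden state. A rough sketch of the sequence-returning path, assuming forward_layer and backward_layer are stand-in callables rather than the library's classes:

import numpy as np

def bidirectional_forward(forward_layer, backward_layer, x):
    # x: (batch, timesteps, features); both layers are assumed to return full sequences.
    forward_seq = forward_layer(x)                     # (batch, T, units)
    backward_seq = backward_layer(x[:, ::-1, :])       # run on the time-reversed input
    backward_seq = backward_seq[:, ::-1, :]            # re-align with forward time order
    return np.concatenate([forward_seq, backward_seq], axis=-1)   # (batch, T, 2 * units)

# Toy usage with identity "layers", just to show the shapes.
x = np.random.randn(2, 5, 3)
out = bidirectional_forward(lambda a: a, lambda a: a, x)   # shape (2, 5, 6)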
@@ -2045,7 +2055,7 @@ def backward_pass(self, output_error: np.ndarray) -> np.ndarray:
             np.sum(self.cell.dWh**2) + np.sum(self.cell.dUh**2) + np.sum(self.cell.dbh**2))
 
         global_norm = np.sqrt(squared_norm_sum)
-        scaling_factor = min(1.0, self.clip_value / (global_norm + 1e-8)) / timesteps
+        scaling_factor = min(1.0, self.clip_value / (global_norm + 1e-8))
         if scaling_factor < 1.0:
             dx *= scaling_factor
         for grad in self.cell.__dict__:
@@ -2122,7 +2132,7 @@ def backward_pass(self, output_error: np.ndarray) -> np.ndarray:
         input_data = self.cache['input']
 
         if not self.return_sequences:
-            output_error = np.expand_dims(output_error, 1) / seq_length
+            output_error = np.expand_dims(output_error, 1)
             output_error = np.repeat(output_error, seq_length, axis=1)
 
         d_input = np.zeros((batch_size, seq_length, features))
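
Note: when return_sequences is False, the single-step output error is now broadcast to every timestep unscaled; the old code divided it by seq_length first. A small NumPy illustration of the broadcast, with shapes chosen only for the example:

import numpy as np

batch_size, seq_length, units = 2, 5, 3
output_error = np.ones((batch_size, units))         # upstream error for the final step only

expanded = np.expand_dims(output_error, 1)           # (batch, 1, units)
repeated = np.repeat(expanded, seq_length, axis=1)   # (batch, seq_length, units), unscaled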

neuralnetlib/model.py

Lines changed: 36 additions & 25 deletions
@@ -8,7 +8,7 @@
 
 from neuralnetlib.activations import ActivationFunction
 from neuralnetlib.layers import compatibility_dict, Layer, Input, Activation, Dropout, TextVectorization, LSTM, GRU, Bidirectional, Embedding, Attention, Dense
-from neuralnetlib.losses import LossFunction, CategoricalCrossentropy, SparseCategoricalCrossentropy
+from neuralnetlib.losses import LossFunction, CategoricalCrossentropy, SparseCategoricalCrossentropy, BinaryCrossentropy
 from neuralnetlib.optimizers import Optimizer
 from neuralnetlib.preprocessing import PCA
 from neuralnetlib.utils import shuffle, progress_bar, is_interactive, is_display_available, History
@@ -83,33 +83,44 @@ def forward_pass(self, X: np.ndarray, training: bool = True) -> np.ndarray:
 
     def backward_pass(self, error: np.ndarray):
         for i, layer in enumerate(reversed(self.layers)):
-            if i == 0 and isinstance(layer, Activation) and type(layer.activation_function).__name__ == "Softmax" and (
-                    isinstance(self.loss_function, CategoricalCrossentropy or isinstance(self.loss_function, SparseCategoricalCrossentropy))):
-                error = self.predictions - self.y_true
+            if i == 0 and isinstance(layer, Activation):
+                if (type(layer.activation_function).__name__ == "Softmax" and
+                        (isinstance(self.loss_function, CategoricalCrossentropy))):
+                    error = self.predictions - self.y_true
+
+                elif (type(layer.activation_function).__name__ == "Sigmoid" and
+                        isinstance(self.loss_function, BinaryCrossentropy)):
+                    error = (self.predictions - self.y_true) / (self.predictions * (1 - self.predictions) + 1e-15)
+
+                elif isinstance(self.loss_function, SparseCategoricalCrossentropy):
+                    y_true_one_hot = np.zeros_like(self.predictions)
+                    y_true_one_hot[np.arange(len(self.y_true)), self.y_true] = 1
+                    error = self.predictions - y_true_one_hot
             else:
                 error = layer.backward_pass(error)
 
-            if hasattr(layer, 'weights'):
-                if hasattr(layer, 'd_weights') and hasattr(layer, 'd_bias'):
-                    self.optimizer.update(len(self.layers) - 1 - i, layer.weights, layer.d_weights, layer.bias,
-                                          layer.d_bias)
-                elif hasattr(layer, 'd_weights'):
-                    self.optimizer.update(
-                        len(self.layers) - 1 - i, layer.weights, layer.d_weights)
-
-            elif isinstance(layer, LSTM):
-                self.optimizer.update(len(self.layers) - 1 - i, layer.cell.Wf, layer.cell.dWf, layer.cell.bf, layer.cell.dbf)
-                self.optimizer.update(len(self.layers) - 1 - i, layer.cell.Wi, layer.cell.dWi, layer.cell.bi, layer.cell.dbi)
-                self.optimizer.update(len(self.layers) - 1 - i, layer.cell.Wc, layer.cell.dWc, layer.cell.bc, layer.cell.dbc)
-                self.optimizer.update(len(self.layers) - 1 - i, layer.cell.Wo, layer.cell.dWo, layer.cell.bo, layer.cell.dbo)
+            if isinstance(layer, LSTM):
+                layer_idx = len(self.layers) - 1 - i
+                cell = layer.cell
+                self.optimizer.update(layer_idx, cell.Wf, cell.dWf, cell.bf, cell.dbf)
+                self.optimizer.update(layer_idx, cell.Wi, cell.dWi, cell.bi, cell.dbi)
+                self.optimizer.update(layer_idx, cell.Wc, cell.dWc, cell.bc, cell.dbc)
+                self.optimizer.update(layer_idx, cell.Wo, cell.dWo, cell.bo, cell.dbo)
+
             elif isinstance(layer, GRU):
-                self.optimizer.update(len(self.layers) - 1 - i, layer.cell.Wz, layer.cell.dWz, layer.cell.bz, layer.cell.dbz)
-                self.optimizer.update(len(self.layers) - 1 - i, layer.cell.Wr, layer.cell.dWr, layer.cell.br, layer.cell.dbr)
-                self.optimizer.update(len(self.layers) - 1 - i, layer.cell.Wh, layer.cell.dWh, layer.cell.bh, layer.cell.dbh)
-            elif hasattr(layer, 'd_weights') and hasattr(layer, 'd_bias'):
-                self.optimizer.update(len(self.layers) - 1 - i, layer.weights, layer.d_weights, layer.bias, layer.d_bias)
-            elif hasattr(layer, 'd_weights'):
-                self.optimizer.update(len(self.layers) - 1 - i, layer.weights, layer.d_weights)
+                layer_idx = len(self.layers) - 1 - i
+                cell = layer.cell
+                self.optimizer.update(layer_idx, cell.Wz, cell.dWz, cell.bz, cell.dbz)
+                self.optimizer.update(layer_idx, cell.Wr, cell.dWr, cell.br, cell.dbr)
+                self.optimizer.update(layer_idx, cell.Wh, cell.dWh, cell.bh, cell.dbh)
+
+            elif hasattr(layer, 'weights'):
+                layer_idx = len(self.layers) - 1 - i
+                if hasattr(layer, 'd_bias'):
+                    self.optimizer.update(layer_idx, layer.weights, layer.d_weights,
+                                          layer.bias, layer.d_bias)
+                else:
+                    self.optimizer.update(layer_idx, layer.weights, layer.d_weights)
 
     def train_on_batch(self, x_batch: np.ndarray, y_batch: np.ndarray) -> float:
         self.y_true = y_batch
@@ -434,4 +445,4 @@ def __update_plot(self, epoch: int, x_train: np.ndarray, y_train: np.ndarray, ra
             ax.set_title(f"Decision Boundary (Epoch {epoch + 1})")
 
             fig.canvas.draw()
-            plt.pause(0.1)
+            plt.pause(0.1)
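
Note: Model.backward_pass now picks the analytic output-layer gradient by activation/loss pair: softmax with categorical cross-entropy collapses to predictions minus targets, sigmoid with binary cross-entropy uses the BCE derivative (p − y) / (p(1 − p) + 1e-15), and sparse categorical cross-entropy one-hot encodes the integer labels before taking predictions minus one-hot. A standalone NumPy sketch of the sparse-label case, with made-up values:

import numpy as np

predictions = np.array([[0.7, 0.2, 0.1],
                        [0.1, 0.3, 0.6]])   # softmax outputs
y_true = np.array([0, 2])                   # sparse integer labels

# One-hot encode, then the combined softmax + cross-entropy gradient
# with respect to the logits is simply predictions - one_hot.
y_true_one_hot = np.zeros_like(predictions)
y_true_one_hot[np.arange(len(y_true)), y_true] = 1
error = predictions - y_true_one_hot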

neuralnetlib/optimizers.py

Lines changed: 29 additions & 22 deletions
@@ -130,7 +130,8 @@ def __init__(self, learning_rate: float = 0.001, beta_1: float = 0.9, beta_2: fl
         self.epsilon = epsilon
         self.clip_norm = clip_norm
         self.clip_value = clip_value
-        self.t = 0
+        self.t = 1
+        self.needs_time_increment = True
 
         self.m_w, self.v_w = {}, {}
         self.m_b, self.v_b = {}, {}
@@ -159,36 +160,42 @@ def _compute_moments(self, param: np.ndarray, grad: np.ndarray, m: np.ndarray, v
         m = self.beta_1 * m + (1 - self.beta_1) * grad
         v = self.beta_2 * v + (1 - self.beta_2) * np.square(grad)
 
-        beta1_t = np.minimum(self.beta_1 ** self.t, 1 - self._min_denom)
-        beta2_t = np.minimum(self.beta_2 ** self.t, 1 - self._min_denom)
+        m_hat = m / (1 - self.beta_1 ** self.t)
+        v_hat = v / (1 - self.beta_2 ** self.t)
 
-        m_hat = m / (1 - beta1_t)
-        v_hat = v / (1 - beta2_t)
-
-        denom = np.sqrt(v_hat) + self.epsilon
-        update = self.learning_rate * m_hat / np.maximum(denom, self._min_denom)
+        denom = np.sqrt(v_hat + self.epsilon)
+        update = self.learning_rate * m_hat / denom
 
         update = np.nan_to_num(update, nan=0.0, posinf=0.0, neginf=0.0)
         param -= update
 
         return param, m, v
 
     def update(self, layer_index: int, weights: np.ndarray, weights_grad: np.ndarray, bias: np.ndarray, bias_grad: np.ndarray) -> None:
+        if weights_grad is None and bias_grad is None:
+            return
+
         if layer_index not in self.m_w:
-            self.m_w[layer_index] = np.zeros_like(weights)
-            self.v_w[layer_index] = np.zeros_like(weights)
-            self.m_b[layer_index] = np.zeros_like(bias)
-            self.v_b[layer_index] = np.zeros_like(bias)
-
-        self.t += 1
+            if weights is not None:
+                self.m_w[layer_index] = np.zeros_like(weights)
+                self.v_w[layer_index] = np.zeros_like(weights)
+            if bias is not None:
+                self.m_b[layer_index] = np.zeros_like(bias)
+                self.v_b[layer_index] = np.zeros_like(bias)
+
+        if self.needs_time_increment:
+            self.t += 1
+            self.needs_time_increment = False
+
+        if weights is not None:
+            weights, self.m_w[layer_index], self.v_w[layer_index] = self._compute_moments(
+                weights, weights_grad, self.m_w[layer_index], self.v_w[layer_index]
+            )
 
-        weights, self.m_w[layer_index], self.v_w[layer_index] = self._compute_moments(
-            weights, weights_grad, self.m_w[layer_index], self.v_w[layer_index]
-        )
-
-        bias, self.m_b[layer_index], self.v_b[layer_index] = self._compute_moments(
-            bias, bias_grad, self.m_b[layer_index], self.v_b[layer_index]
-        )
+        if bias is not None:
+            bias, self.m_b[layer_index], self.v_b[layer_index] = self._compute_moments(
+                bias, bias_grad, self.m_b[layer_index], self.v_b[layer_index]
+            )
 
     def get_config(self) -> dict:
         return {
@@ -226,4 +233,4 @@ def from_config(config: dict):
     def __str__(self):
         return (f"{self.__class__.__name__}(learning_rate={self.learning_rate}, "
                 f"beta_1={self.beta_1}, beta_2={self.beta_2}, epsilon={self.epsilon}, "
-                f"clip_norm={self.clip_norm}, clip_value={self.clip_value})")
+                f"clip_norm={self.clip_norm}, clip_value={self.clip_value})")
