Skip to content

Commit b43c8d5

Browse files
committed
fix(Dense): correct temporal data handling in forward and backward passes
1 parent 244a6f2 commit b43c8d5

File tree

3 files changed

+99
-85
lines changed

3 files changed

+99
-85
lines changed

examples/classification-regression/mnist_multiclass.ipynb

Lines changed: 37 additions & 39 deletions
Large diffs are not rendered by default.

examples/classification-regression/sentiment_analysis.ipynb

Lines changed: 38 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@
2121
"execution_count": 1,
2222
"metadata": {
2323
"ExecuteTime": {
24-
"end_time": "2024-11-06T21:51:28.948615200Z",
25-
"start_time": "2024-11-06T21:51:19.721136Z"
24+
"end_time": "2024-11-09T15:29:05.393532Z",
25+
"start_time": "2024-11-09T15:28:57.267583700Z"
2626
}
2727
},
2828
"outputs": [],
@@ -51,8 +51,8 @@
5151
"execution_count": 2,
5252
"metadata": {
5353
"ExecuteTime": {
54-
"end_time": "2024-11-06T21:51:30.589179800Z",
55-
"start_time": "2024-11-06T21:51:28.950619500Z"
54+
"end_time": "2024-11-09T15:29:06.872553900Z",
55+
"start_time": "2024-11-09T15:29:05.396041700Z"
5656
}
5757
},
5858
"outputs": [],
@@ -72,8 +72,8 @@
7272
"execution_count": 3,
7373
"metadata": {
7474
"ExecuteTime": {
75-
"end_time": "2024-11-06T21:51:30.871205900Z",
76-
"start_time": "2024-11-06T21:51:30.590182500Z"
75+
"end_time": "2024-11-09T15:29:07.138228500Z",
76+
"start_time": "2024-11-09T15:29:06.873553300Z"
7777
}
7878
},
7979
"outputs": [
@@ -150,8 +150,8 @@
150150
"execution_count": 4,
151151
"metadata": {
152152
"ExecuteTime": {
153-
"end_time": "2024-11-06T21:51:30.899961500Z",
154-
"start_time": "2024-11-06T21:51:30.871205900Z"
153+
"end_time": "2024-11-09T15:29:07.182100500Z",
154+
"start_time": "2024-11-09T15:29:07.139267400Z"
155155
}
156156
},
157157
"outputs": [],
@@ -176,8 +176,8 @@
176176
"execution_count": 5,
177177
"metadata": {
178178
"ExecuteTime": {
179-
"end_time": "2024-11-06T21:51:30.904961800Z",
180-
"start_time": "2024-11-06T21:51:30.886456800Z"
179+
"end_time": "2024-11-09T15:29:07.185659500Z",
180+
"start_time": "2024-11-09T15:29:07.154336500Z"
181181
}
182182
},
183183
"outputs": [
@@ -189,8 +189,8 @@
189189
"-------------------------------------------------\n",
190190
"Layer 1: Input(input_shape=(200,))\n",
191191
"Layer 2: Embedding(input_dim=10000, output_dim=100)\n",
192-
"Layer 3: Bidirectional(layer=LSTM(units=32, return_sequences=True, return_state=False, random_state=None))\n",
193-
"Layer 4: Attention(use_scale=True, score_mode=dot)\n",
192+
"Layer 3: Bidirectional(layer=LSTM(units=32, return_sequences=True, return_state=False))\n",
193+
"Layer 4: Attention(use_scale=True, score_mode=dot, return_sequences=False)\n",
194194
"Layer 5: Dense(units=1)\n",
195195
"Layer 6: Activation(Sigmoid)\n",
196196
"-------------------------------------------------\n",
@@ -215,30 +215,37 @@
215215
},
216216
{
217217
"cell_type": "code",
218-
"execution_count": 7,
218+
"execution_count": 6,
219219
"metadata": {
220220
"ExecuteTime": {
221-
"end_time": "2024-11-06T22:17:05.632380200Z",
222-
"start_time": "2024-11-06T22:17:05.625379900Z"
221+
"end_time": "2024-11-09T15:57:58.751713500Z",
222+
"start_time": "2024-11-09T15:29:07.168952900Z"
223223
}
224224
},
225225
"outputs": [
226226
{
227227
"name": "stdout",
228228
"output_type": "stream",
229229
"text": [
230-
"\n",
231-
"[==============================] 100% Epoch 1/10 - loss: 0.6193 - accuracy: 0.7079 - 248.72s - val_accuracy: 0.8013\n",
232-
"[==============================] 100% Epoch 2/10 - loss: 0.4215 - accuracy: 0.8477 - 264.70s - val_accuracy: 0.8504\n",
233-
"[==============================] 100% Epoch 3/10 - loss: 0.3301 - accuracy: 0.8799 - 266.74s - val_accuracy: 0.8624\n",
234-
"[==============================] 100% Epoch 4/10 - loss: 0.2835 - accuracy: 0.8954 - 255.44s - val_accuracy: 0.8677\n",
235-
"[==============================] 100% Epoch 5/10 - loss: 0.2519 - accuracy: 0.9093 - 239.53s - val_accuracy: 0.8710\n",
236-
"[==============================] 100% Epoch 6/10 - loss: 0.2283 - accuracy: 0.9183 - 239.53s - val_accuracy: 0.8728\n",
237-
"[==============================] 100% Epoch 7/10 - loss: 0.2090 - accuracy: 0.9260 - 239.53s - val_accuracy: 0.8802\n",
238-
"[==============================] 100% Epoch 8/10 - loss: 0.1926 - accuracy: 0.9320 - 239.53s - val_accuracy: 0.8884\n",
239-
"[==============================] 100% Epoch 9/10 - loss: 0.1784 - accuracy: 0.9376 - 239.53s - val_accuracy: 0.8902\n",
240-
"[==============================] 100% Epoch 10/10 - loss: 0.1660 - accuracy: 0.9423 - 239.53s - val_accuracy: 0.9000\n"
230+
"[==============================] 100% Epoch 1/10 - loss: 0.4424 - accuracy: 0.7944 - 118.13s - val_accuracy: 0.8490\n",
231+
"[==============================] 100% Epoch 2/10 - loss: 0.2401 - accuracy: 0.9084 - 120.27s - val_accuracy: 0.8170\n",
232+
"[==============================] 100% Epoch 3/10 - loss: 0.1814 - accuracy: 0.9332 - 121.17s - val_accuracy: 0.8602\n",
233+
"[==============================] 100% Epoch 4/10 - loss: 0.1479 - accuracy: 0.9485 - 118.24s - val_accuracy: 0.8509\n",
234+
"[==============================] 100% Epoch 5/10 - loss: 0.1056 - accuracy: 0.9649 - 120.75s - val_accuracy: 0.8637\n",
235+
"[==============================] 100% Epoch 6/10 - loss: 0.0854 - accuracy: 0.9735 - 118.61s - val_accuracy: 0.8549\n",
236+
"[==============================] 100% Epoch 7/10 - loss: 0.0871 - accuracy: 0.9728 - 120.97s - val_accuracy: 0.8567\n",
237+
"[==============================] 100% Epoch 8/10 - loss: 0.0629 - accuracy: 0.9799 - 117.70s - val_accuracy: 0.8515\n",
238+
"[==============================] 100% Epoch 9/10 - loss: 0.0533 - accuracy: 0.9840 - 120.13s - val_accuracy: 0.8463\n",
239+
"[==============================] 100% Epoch 10/10 - loss: 0.0394 - accuracy: 0.9890 - 118.95s - val_accuracy: 0.8444\n"
241240
]
241+
},
242+
{
243+
"data": {
244+
"text/plain": ""
245+
},
246+
"execution_count": 6,
247+
"metadata": {},
248+
"output_type": "execute_result"
242249
}
243250
],
244251
"source": [
@@ -254,20 +261,20 @@
254261
},
255262
{
256263
"cell_type": "code",
257-
"execution_count": null,
264+
"execution_count": 7,
258265
"metadata": {
259266
"ExecuteTime": {
260-
"end_time": "2024-11-06T22:17:25.754433600Z",
261-
"start_time": "2024-11-06T22:17:14.398517800Z"
267+
"end_time": "2024-11-09T15:58:20.883275400Z",
268+
"start_time": "2024-11-09T15:57:58.748498900Z"
262269
}
263270
},
264271
"outputs": [
265272
{
266273
"name": "stdout",
267274
"output_type": "stream",
268275
"text": [
269-
"Loss: 1.4010948021794365\n",
270-
"Accuracy: 0.881\n"
276+
"Loss: 3.0061621606978313\n",
277+
"Accuracy: 0.8592\n"
271278
]
272279
}
273280
],

neuralnetlib/layers.py

Lines changed: 24 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -136,30 +136,39 @@ def initialize_weights(self, input_size: int):
136136

137137
def forward_pass(self, input_data: np.ndarray) -> np.ndarray:
138138
self.input_shape = input_data.shape
139+
self.input = input_data
139140

140141
if len(input_data.shape) == 3:
141142
batch_size, timesteps, features = input_data.shape
142-
input_data = input_data.mean(axis=1)
143-
143+
input_reshaped = input_data.reshape(-1, features)
144+
145+
if self.weights is None:
146+
self.initialize_weights(features)
147+
148+
output = np.dot(input_reshaped, self.weights) + self.bias
149+
150+
return output.reshape(batch_size, timesteps, self.units)
151+
144152
if self.weights is None:
145153
self.initialize_weights(input_data.shape[1])
146-
147-
self.input = input_data
148-
output = np.dot(self.input, self.weights) + self.bias
149-
return output
154+
155+
return np.dot(input_data, self.weights) + self.bias
150156

151157
def backward_pass(self, output_error: np.ndarray) -> np.ndarray:
152158
if len(output_error.shape) == 3:
153-
output_error = output_error.mean(axis=1)
154-
159+
batch_size, timesteps, _ = output_error.shape
160+
output_error_reshaped = output_error.reshape(-1, output_error.shape[-1])
161+
input_reshaped = self.input.reshape(-1, self.input.shape[-1])
162+
163+
input_error = np.dot(output_error_reshaped, self.weights.T)
164+
self.d_weights = np.dot(input_reshaped.T, output_error_reshaped)
165+
self.d_bias = np.sum(output_error_reshaped, axis=0, keepdims=True)
166+
167+
return input_error.reshape(batch_size, timesteps, -1)
168+
155169
input_error = np.dot(output_error, self.weights.T)
156170
self.d_weights = np.dot(self.input.T, output_error)
157171
self.d_bias = np.sum(output_error, axis=0, keepdims=True)
158-
159-
if len(self.input_shape) == 3:
160-
input_error = np.expand_dims(input_error, 1)
161-
input_error = np.repeat(input_error, self.input_shape[1], axis=1)
162-
163172
return input_error
164173

165174
def get_config(self) -> dict:
@@ -1680,7 +1689,7 @@ def from_config(config: dict):
16801689

16811690

16821691
class Attention(Layer):
1683-
def __init__(self, use_scale: bool = True, score_mode: str = "dot", return_sequences: bool = True):
1692+
def __init__(self, use_scale: bool = True, score_mode: str = "dot", return_sequences: bool = False):
16841693
super().__init__()
16851694
self.use_scale = use_scale
16861695
self.score_mode = score_mode
@@ -1714,7 +1723,7 @@ def forward_pass(self, input_data: np.ndarray) -> np.ndarray:
17141723
context[i] = np.dot(attention_weights[i], input_data[i])
17151724

17161725
if not self.return_sequences:
1717-
return np.mean(context, axis=1)
1726+
context = np.mean(context, axis=1)
17181727
return context
17191728

17201729
def backward_pass(self, output_error: np.ndarray) -> np.ndarray:

0 commit comments

Comments (0)