Commit a78eef5

fix: GRU to Model exceptions
1 parent 014e044 commit a78eef5

File tree

1 file changed (+10 lines, -6 lines)

neuralnetlib/model.py

Lines changed: 10 additions & 6 deletions
@@ -7,7 +7,7 @@
 import numpy as np
 
 from neuralnetlib.activations import ActivationFunction
-from neuralnetlib.layers import compatibility_dict, Layer, Input, Activation, Dropout, TextVectorization, LSTM, Bidirectional, Embedding, Attention, Dense
+from neuralnetlib.layers import compatibility_dict, Layer, Input, Activation, Dropout, TextVectorization, LSTM, GRU, Bidirectional, Embedding, Attention, Dense
 from neuralnetlib.losses import LossFunction, CategoricalCrossentropy, SparseCategoricalCrossentropy
 from neuralnetlib.optimizers import Optimizer
 from neuralnetlib.preprocessing import PCA
@@ -75,7 +75,7 @@ def compile(self, loss_function: LossFunction | str, optimizer: Optimizer | str,
 
     def forward_pass(self, X: np.ndarray, training: bool = True) -> np.ndarray:
         for layer in self.layers:
-            if isinstance(layer, (Dropout, LSTM, Bidirectional)):
+            if isinstance(layer, (Dropout, LSTM, Bidirectional, GRU)):
                 X = layer.forward_pass(X, training)
             else:
                 X = layer.forward_pass(X)
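
Note: forward_pass hands the training flag only to layers whose behavior differs between training and inference (Dropout's masking, the recurrent layers' state handling), so every new train-aware layer must be appended to the isinstance tuple — exactly the omission this commit fixes for GRU. A duck-typed alternative (a sketch, not the library's code) would inspect the layer's signature instead:

import inspect

def forward_with_optional_training(layer, X, training):
    # Pass `training` only when the layer's forward_pass accepts it,
    # so new train-aware layers need no isinstance registration.
    if 'training' in inspect.signature(layer.forward_pass).parameters:
        return layer.forward_pass(X, training)
    return layer.forward_pass(X)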
@@ -102,6 +102,10 @@ def backward_pass(self, error: np.ndarray):
                 self.optimizer.update(len(self.layers) - 1 - i, layer.cell.Wi, layer.cell.dWi, layer.cell.bi, layer.cell.dbi)
                 self.optimizer.update(len(self.layers) - 1 - i, layer.cell.Wc, layer.cell.dWc, layer.cell.bc, layer.cell.dbc)
                 self.optimizer.update(len(self.layers) - 1 - i, layer.cell.Wo, layer.cell.dWo, layer.cell.bo, layer.cell.dbo)
+            elif isinstance(layer, GRU):
+                self.optimizer.update(len(self.layers) - 1 - i, layer.cell.Wz, layer.cell.dWz, layer.cell.bz, layer.cell.dbz)
+                self.optimizer.update(len(self.layers) - 1 - i, layer.cell.Wr, layer.cell.dWr, layer.cell.br, layer.cell.dbr)
+                self.optimizer.update(len(self.layers) - 1 - i, layer.cell.Wh, layer.cell.dWh, layer.cell.bh, layer.cell.dbh)
             elif hasattr(layer, 'd_weights') and hasattr(layer, 'd_bias'):
                 self.optimizer.update(len(self.layers) - 1 - i, layer.weights, layer.d_weights, layer.bias, layer.d_bias)
             elif hasattr(layer, 'd_weights'):
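
Note: the new elif mirrors the LSTM branch above it. A GRU cell carries three gate parameter pairs — update (Wz, bz), reset (Wr, br), and candidate (Wh, bh) — each with a d-prefixed gradient, and all three are updated under the same layer index. A minimal sketch of that layout and a common GRU step; only the attribute names are taken from the diff, the shapes and math are illustrative and may differ from neuralnetlib's actual cell:

import numpy as np

class SketchGRUCell:
    def __init__(self, input_size: int, hidden_size: int):
        concat = input_size + hidden_size
        # One (W, b) pair per gate; gradients would be dWz, dbz, etc.
        self.Wz = np.random.randn(concat, hidden_size) * 0.01  # update gate
        self.Wr = np.random.randn(concat, hidden_size) * 0.01  # reset gate
        self.Wh = np.random.randn(concat, hidden_size) * 0.01  # candidate
        self.bz = np.zeros(hidden_size)
        self.br = np.zeros(hidden_size)
        self.bh = np.zeros(hidden_size)

    def step(self, x: np.ndarray, h_prev: np.ndarray) -> np.ndarray:
        xh = np.concatenate([x, h_prev], axis=-1)
        z = 1.0 / (1.0 + np.exp(-(xh @ self.Wz + self.bz)))  # update gate
        r = 1.0 / (1.0 + np.exp(-(xh @ self.Wr + self.br)))  # reset gate
        xrh = np.concatenate([x, r * h_prev], axis=-1)
        h_cand = np.tanh(xrh @ self.Wh + self.bh)            # candidate state
        return (1.0 - z) * h_prev + z * h_cand               # blend old and new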
@@ -116,7 +120,7 @@ def train_on_batch(self, x_batch: np.ndarray, y_batch: np.ndarray) -> float:
 
         if error.ndim == 1:
             error = error[:, None]
-        elif isinstance(self.layers[-1], (LSTM, Bidirectional)) and self.layers[-1].return_sequences:
+        elif isinstance(self.layers[-1], (LSTM, Bidirectional, GRU)) and self.layers[-1].return_sequences:
             error = error.reshape(error.shape[0], error.shape[1], -1)
 
         self.backward_pass(error)
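
Note: the reshape matters when the final layer is recurrent with return_sequences=True, where the loss gradient arrives without an explicit feature axis and must regain a 3D (batch, time_steps, features) shape before backward_pass. A quick demonstration with hypothetical shapes:

import numpy as np

error = np.ones((32, 10))  # hypothetical (batch_size, time_steps) gradient
error = error.reshape(error.shape[0], error.shape[1], -1)
print(error.shape)         # (32, 10, 1): one feature per timestep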
@@ -166,12 +170,12 @@ def fit(self, x_train: np.ndarray, y_train: np.ndarray,
             if hasattr(layer, 'random_state'):
                 layer.random_state = random_state
 
-        has_lstm = any(isinstance(layer, (LSTM, Bidirectional)) for layer in self.layers)
+        has_lstm_or_gru = any(isinstance(layer, (LSTM, Bidirectional, GRU)) for layer in self.layers)
         has_embedding = any(isinstance(layer, Embedding) for layer in self.layers)
 
-        if has_lstm and not has_embedding:
+        if has_lstm_or_gru and not has_embedding:
             if len(x_train.shape) != 3:
-                raise ValueError("Input data must be 3D (batch_size, time_steps, features) for LSTM layers without Embedding")
+                raise ValueError("Input data must be 3D (batch_size, time_steps, features) for LSTM/GRU layers without Embedding")
         elif has_embedding:
             if len(x_train.shape) != 2:
                 raise ValueError("Input data must be 2D (batch_size, sequence_length) when using Embedding layer")
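
Note: taken together, the four changes let a GRU model follow the same path as an LSTM one — dispatched with the training flag, updated gate by gate, reshaped when returning sequences, and validated as 3D input. A hedged usage sketch follows; Model.add(), the GRU(units=...) constructor, and the 'mse'/'adam' string names are assumptions about the public API (the compile() type hints only confirm that strings are accepted):

import numpy as np
from neuralnetlib.model import Model
from neuralnetlib.layers import GRU, Dense

x_train = np.random.randn(64, 10, 8)  # 3D: (batch_size, time_steps, features)
y_train = np.random.randn(64, 1)

model = Model()
model.add(GRU(units=16))   # assumed constructor signature
model.add(Dense(1))        # assumed Model.add() API
model.compile(loss_function='mse', optimizer='adam')  # assumed string names
model.fit(x_train, y_train)  # now passes the has_lstm_or_gru 3D check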
