@@ -7,7 +7,7 @@
 import numpy as np
 
 from neuralnetlib.activations import ActivationFunction
-from neuralnetlib.layers import compatibility_dict, Layer, Input, Activation, Dropout, TextVectorization, LSTM, Bidirectional, Embedding, Attention, Dense
+from neuralnetlib.layers import compatibility_dict, Layer, Input, Activation, Dropout, TextVectorization, LSTM, GRU, Bidirectional, Embedding, Attention, Dense
 from neuralnetlib.losses import LossFunction, CategoricalCrossentropy, SparseCategoricalCrossentropy
 from neuralnetlib.optimizers import Optimizer
 from neuralnetlib.preprocessing import PCA
@@ -75,7 +75,7 @@ def compile(self, loss_function: LossFunction | str, optimizer: Optimizer | str,
 
     def forward_pass(self, X: np.ndarray, training: bool = True) -> np.ndarray:
         for layer in self.layers:
-            if isinstance(layer, (Dropout, LSTM, Bidirectional)):
+            if isinstance(layer, (Dropout, LSTM, Bidirectional, GRU)):
                 X = layer.forward_pass(X, training)
             else:
                 X = layer.forward_pass(X)
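Note: as with Dropout, LSTM and Bidirectional, the GRU layer now receives the training flag explicitly, so training-only behaviour (such as caching state for backpropagation through time) can be skipped at inference. A minimal, purely illustrative sketch of a layer satisfying the forward_pass(X, training) call pattern used above, not the library's actual GRU implementation:

import numpy as np

class GRUSketch:
    # Illustrative stand-in: only mirrors the forward_pass(X, training) interface.
    def __init__(self, units: int, return_sequences: bool = False):
        self.units = units
        self.return_sequences = return_sequences
        self.cache = None

    def forward_pass(self, X: np.ndarray, training: bool = True) -> np.ndarray:
        batch, time_steps, _ = X.shape
        outputs = np.zeros((batch, time_steps, self.units))  # recurrence omitted
        if training:
            self.cache = X  # keep inputs for the backward pass only while training
        return outputs if self.return_sequences else outputs[:, -1, :]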
@@ -102,6 +102,10 @@ def backward_pass(self, error: np.ndarray):
                 self.optimizer.update(len(self.layers) - 1 - i, layer.cell.Wi, layer.cell.dWi, layer.cell.bi, layer.cell.dbi)
                 self.optimizer.update(len(self.layers) - 1 - i, layer.cell.Wc, layer.cell.dWc, layer.cell.bc, layer.cell.dbc)
                 self.optimizer.update(len(self.layers) - 1 - i, layer.cell.Wo, layer.cell.dWo, layer.cell.bo, layer.cell.dbo)
+            elif isinstance(layer, GRU):
+                self.optimizer.update(len(self.layers) - 1 - i, layer.cell.Wz, layer.cell.dWz, layer.cell.bz, layer.cell.dbz)
+                self.optimizer.update(len(self.layers) - 1 - i, layer.cell.Wr, layer.cell.dWr, layer.cell.br, layer.cell.dbr)
+                self.optimizer.update(len(self.layers) - 1 - i, layer.cell.Wh, layer.cell.dWh, layer.cell.bh, layer.cell.dbh)
             elif hasattr(layer, 'd_weights') and hasattr(layer, 'd_bias'):
                 self.optimizer.update(len(self.layers) - 1 - i, layer.weights, layer.d_weights, layer.bias, layer.d_bias)
             elif hasattr(layer, 'd_weights'):
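Note: the new GRU branch assumes the cell exposes update-gate (Wz, bz), reset-gate (Wr, br) and candidate-state (Wh, bh) parameters together with matching gradients, mirroring the LSTM cell attributes updated just above. A hypothetical sketch of that parameter layout, with illustrative shapes and initialisation:

import numpy as np

class GRUCellParamsSketch:
    # Hypothetical layout matching the attributes consumed by optimizer.update above.
    def __init__(self, input_size: int, hidden_size: int):
        concat = input_size + hidden_size
        # update gate (z), reset gate (r) and candidate state (h)
        self.Wz = np.random.randn(concat, hidden_size) * 0.01
        self.Wr = np.random.randn(concat, hidden_size) * 0.01
        self.Wh = np.random.randn(concat, hidden_size) * 0.01
        self.bz = np.zeros(hidden_size)
        self.br = np.zeros(hidden_size)
        self.bh = np.zeros(hidden_size)
        # gradients filled in by the backward pass, read by the optimizer
        self.dWz, self.dWr, self.dWh = (np.zeros_like(w) for w in (self.Wz, self.Wr, self.Wh))
        self.dbz, self.dbr, self.dbh = (np.zeros_like(b) for b in (self.bz, self.br, self.bh))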
@@ -116,7 +120,7 @@ def train_on_batch(self, x_batch: np.ndarray, y_batch: np.ndarray) -> float:
 
         if error.ndim == 1:
             error = error[:, None]
-        elif isinstance(self.layers[-1], (LSTM, Bidirectional)) and self.layers[-1].return_sequences:
+        elif isinstance(self.layers[-1], (LSTM, Bidirectional, GRU)) and self.layers[-1].return_sequences:
             error = error.reshape(error.shape[0], error.shape[1], -1)
 
         self.backward_pass(error)
@@ -166,12 +170,12 @@ def fit(self, x_train: np.ndarray, y_train: np.ndarray,
             if hasattr(layer, 'random_state'):
                 layer.random_state = random_state
 
-        has_lstm = any(isinstance(layer, (LSTM, Bidirectional)) for layer in self.layers)
+        has_lstm_or_gru = any(isinstance(layer, (LSTM, Bidirectional, GRU)) for layer in self.layers)
         has_embedding = any(isinstance(layer, Embedding) for layer in self.layers)
 
-        if has_lstm and not has_embedding:
+        if has_lstm_or_gru and not has_embedding:
             if len(x_train.shape) != 3:
-                raise ValueError("Input data must be 3D (batch_size, time_steps, features) for LSTM layers without Embedding")
+                raise ValueError("Input data must be 3D (batch_size, time_steps, features) for LSTM/GRU layers without Embedding")
         elif has_embedding:
             if len(x_train.shape) != 2:
                 raise ValueError("Input data must be 2D (batch_size, sequence_length) when using Embedding layer")
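Note: the reworded check means a recurrent stack fed raw feature vectors still needs 3D input, while an Embedding-fronted stack takes 2D integer sequences. The sizes below are made up purely for illustration:

import numpy as np

# Without an Embedding layer: (batch_size, time_steps, features)
x_raw = np.random.randn(32, 20, 8)                      # passes the 3D check
# With an Embedding layer: (batch_size, sequence_length) of token ids
x_tokens = np.random.randint(0, 10000, size=(32, 20))   # passes the 2D check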