88
99from neuralnetlib .activations import ActivationFunction
1010from neuralnetlib .layers import compatibility_dict , Layer , Input , Activation , Dropout , TextVectorization , LSTM , GRU , Bidirectional , Embedding , Attention , Dense
11- from neuralnetlib .losses import LossFunction , CategoricalCrossentropy , SparseCategoricalCrossentropy
11+ from neuralnetlib .losses import LossFunction , CategoricalCrossentropy , SparseCategoricalCrossentropy , BinaryCrossentropy
1212from neuralnetlib .optimizers import Optimizer
1313from neuralnetlib .preprocessing import PCA
1414from neuralnetlib .utils import shuffle , progress_bar , is_interactive , is_display_available , History
@@ -83,33 +83,44 @@ def forward_pass(self, X: np.ndarray, training: bool = True) -> np.ndarray:
8383
def backward_pass(self, error: np.ndarray):
    """Propagate the loss gradient backwards through the network and update parameters.

    Args:
        error: Gradient of the loss w.r.t. the network output; same shape as
            ``self.predictions`` (both set by the preceding forward/loss step).

    Side effects:
        Calls ``self.optimizer.update(...)`` for every trainable layer,
        mutating layer weights and biases in place.
    """
    for i, layer in enumerate(reversed(self.layers)):
        # For the output activation layer (i == 0 in reversed order), use the
        # fused activation+loss gradient where a numerically stable closed
        # form exists, instead of chaining the two separate derivatives.
        if i == 0 and isinstance(layer, Activation):
            act_name = type(layer.activation_function).__name__
            if act_name == "Softmax" and isinstance(self.loss_function, CategoricalCrossentropy):
                # d(CCE ∘ softmax)/dz simplifies to (p - y).
                error = self.predictions - self.y_true
            elif act_name == "Sigmoid" and isinstance(self.loss_function, BinaryCrossentropy):
                # dBCE/dp = (p - y) / (p * (1 - p)); the epsilon guards
                # against division by zero when predictions saturate at 0/1.
                error = (self.predictions - self.y_true) / (
                    self.predictions * (1 - self.predictions) + 1e-15)
            elif isinstance(self.loss_function, SparseCategoricalCrossentropy):
                # Integer class labels: one-hot encode, then (p - y) shortcut.
                y_true_one_hot = np.zeros_like(self.predictions)
                y_true_one_hot[np.arange(len(self.y_true)), self.y_true] = 1
                error = self.predictions - y_true_one_hot
            else:
                # BUGFIX: previously, an output Activation with no matching
                # fused shortcut (e.g. Softmax + MSE) fell through all three
                # branches and its backward pass was silently skipped,
                # propagating an un-transformed gradient. Fall back to the
                # layer's own backward pass.
                error = layer.backward_pass(error)
        else:
            error = layer.backward_pass(error)

        # Optimizer indexes layers in forward order; i runs in reverse.
        layer_idx = len(self.layers) - 1 - i

        if isinstance(layer, LSTM):
            # Four LSTM gates: forget, input, candidate, output.
            cell = layer.cell
            self.optimizer.update(layer_idx, cell.Wf, cell.dWf, cell.bf, cell.dbf)
            self.optimizer.update(layer_idx, cell.Wi, cell.dWi, cell.bi, cell.dbi)
            self.optimizer.update(layer_idx, cell.Wc, cell.dWc, cell.bc, cell.dbc)
            self.optimizer.update(layer_idx, cell.Wo, cell.dWo, cell.bo, cell.dbo)
        elif isinstance(layer, GRU):
            # Three GRU gates: update, reset, candidate hidden state.
            cell = layer.cell
            self.optimizer.update(layer_idx, cell.Wz, cell.dWz, cell.bz, cell.dbz)
            self.optimizer.update(layer_idx, cell.Wr, cell.dWr, cell.br, cell.dbr)
            self.optimizer.update(layer_idx, cell.Wh, cell.dWh, cell.bh, cell.dbh)
        elif hasattr(layer, 'weights'):
            # Dense-style layers; bias is optional (e.g. bias-free layers).
            if hasattr(layer, 'd_bias'):
                self.optimizer.update(layer_idx, layer.weights, layer.d_weights,
                                      layer.bias, layer.d_bias)
            else:
                self.optimizer.update(layer_idx, layer.weights, layer.d_weights)
113124
114125 def train_on_batch (self , x_batch : np .ndarray , y_batch : np .ndarray ) -> float :
115126 self .y_true = y_batch
@@ -434,4 +445,4 @@ def __update_plot(self, epoch: int, x_train: np.ndarray, y_train: np.ndarray, ra
434445 ax .set_title (f"Decision Boundary (Epoch { epoch + 1 } )" )
435446
436447 fig .canvas .draw ()
437- plt .pause (0.1 )
448+ plt .pause (0.1 )
0 commit comments