 import numpy as np
+from typing import List, Dict, Any, Union

 from neuralnetlib.metrics import Metric
+from neuralnetlib.layers import Layer
+
+
+class ModelWeightManager:
+    @staticmethod
+    def get_model_weights(model) -> List[np.ndarray]:
+        """Extract the weight arrays from a Sequential, Autoencoder or Transformer model."""
+        weights = []
+
+        if hasattr(model, 'layers'):  # Sequential model
+            weights.extend([layer.weights for layer in model.layers if hasattr(layer, 'weights')])
+
+        elif hasattr(model, 'embedding'):  # Transformer (checked before the Autoencoder: it also has encoder/decoder layer lists)
+            if hasattr(model.embedding, 'weights'):
+                weights.append(model.embedding.weights)
+
+            for encoder_layer in model.encoder_layers:
+                if hasattr(encoder_layer, 'attention'):
+                    weights.extend([
+                        encoder_layer.attention.query_dense.weights,
+                        encoder_layer.attention.key_dense.weights,
+                        encoder_layer.attention.value_dense.weights,
+                        encoder_layer.attention.output_dense.weights
+                    ])
+                if hasattr(encoder_layer, 'ffn'):
+                    weights.extend([
+                        encoder_layer.ffn.dense1.weights,
+                        encoder_layer.ffn.dense2.weights
+                    ])
+
+            for decoder_layer in model.decoder_layers:
+                if hasattr(decoder_layer, 'self_attention'):
+                    weights.extend([
+                        decoder_layer.self_attention.query_dense.weights,
+                        decoder_layer.self_attention.key_dense.weights,
+                        decoder_layer.self_attention.value_dense.weights,
+                        decoder_layer.self_attention.output_dense.weights
+                    ])
+                if hasattr(decoder_layer, 'cross_attention'):
+                    weights.extend([
+                        decoder_layer.cross_attention.query_dense.weights,
+                        decoder_layer.cross_attention.key_dense.weights,
+                        decoder_layer.cross_attention.value_dense.weights,
+                        decoder_layer.cross_attention.output_dense.weights
+                    ])
+                if hasattr(decoder_layer, 'ffn'):
+                    weights.extend([
+                        decoder_layer.ffn.dense1.weights,
+                        decoder_layer.ffn.dense2.weights
+                    ])
+
+            if hasattr(model, 'output_layer') and hasattr(model.output_layer, 'weights'):
+                weights.append(model.output_layer.weights)
+
+        elif hasattr(model, 'encoder_layers') and hasattr(model, 'decoder_layers'):  # Autoencoder
+            weights.extend([layer.weights for layer in model.encoder_layers if hasattr(layer, 'weights')])
+            weights.extend([layer.weights for layer in model.decoder_layers if hasattr(layer, 'weights')])
+
+        return weights
+
+    @staticmethod
+    def set_model_weights(model, weights: List[np.ndarray]) -> None:
+        """Restore weights in the order produced by get_model_weights."""
+        weight_idx = 0
+
+        if hasattr(model, 'layers'):  # Sequential model
+            for layer in model.layers:
+                if hasattr(layer, 'weights'):
+                    layer.weights = weights[weight_idx]
+                    weight_idx += 1
+
+        elif hasattr(model, 'embedding'):  # Transformer (checked before the Autoencoder, mirroring get_model_weights)
+            if hasattr(model.embedding, 'weights'):
+                model.embedding.weights = weights[weight_idx]
+                weight_idx += 1
+
+            for encoder_layer in model.encoder_layers:
+                if hasattr(encoder_layer, 'attention'):
+                    encoder_layer.attention.query_dense.weights = weights[weight_idx]
+                    encoder_layer.attention.key_dense.weights = weights[weight_idx + 1]
+                    encoder_layer.attention.value_dense.weights = weights[weight_idx + 2]
+                    encoder_layer.attention.output_dense.weights = weights[weight_idx + 3]
+                    weight_idx += 4
+                if hasattr(encoder_layer, 'ffn'):
+                    encoder_layer.ffn.dense1.weights = weights[weight_idx]
+                    encoder_layer.ffn.dense2.weights = weights[weight_idx + 1]
+                    weight_idx += 2
+
+            for decoder_layer in model.decoder_layers:
+                if hasattr(decoder_layer, 'self_attention'):
+                    decoder_layer.self_attention.query_dense.weights = weights[weight_idx]
+                    decoder_layer.self_attention.key_dense.weights = weights[weight_idx + 1]
+                    decoder_layer.self_attention.value_dense.weights = weights[weight_idx + 2]
+                    decoder_layer.self_attention.output_dense.weights = weights[weight_idx + 3]
+                    weight_idx += 4
+                if hasattr(decoder_layer, 'cross_attention'):
+                    decoder_layer.cross_attention.query_dense.weights = weights[weight_idx]
+                    decoder_layer.cross_attention.key_dense.weights = weights[weight_idx + 1]
+                    decoder_layer.cross_attention.value_dense.weights = weights[weight_idx + 2]
+                    decoder_layer.cross_attention.output_dense.weights = weights[weight_idx + 3]
+                    weight_idx += 4
+                if hasattr(decoder_layer, 'ffn'):
+                    decoder_layer.ffn.dense1.weights = weights[weight_idx]
+                    decoder_layer.ffn.dense2.weights = weights[weight_idx + 1]
+                    weight_idx += 2
+
+            # Restore output layer weights
+            if hasattr(model, 'output_layer') and hasattr(model.output_layer, 'weights'):
+                model.output_layer.weights = weights[weight_idx]
+
+        elif hasattr(model, 'encoder_layers') and hasattr(model, 'decoder_layers'):  # Autoencoder
+            for layer in model.encoder_layers:
+                if hasattr(layer, 'weights'):
+                    layer.weights = weights[weight_idx]
+                    weight_idx += 1
+
+            for layer in model.decoder_layers:
+                if hasattr(layer, 'weights'):
+                    layer.weights = weights[weight_idx]
+                    weight_idx += 1


 class Callback:
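The two static methods above give a snapshot/restore round trip. A minimal sketch of that round trip, assuming `ModelWeightManager` is importable from `neuralnetlib.callbacks` (the module this commit edits); the `Dummy*` classes are hypothetical stand-ins, not library types:

```python
import numpy as np
from neuralnetlib.callbacks import ModelWeightManager  # assumed import path

class DummyDense:
    """Hypothetical stand-in: anything exposing a `weights` array."""
    def __init__(self, shape):
        self.weights = np.random.randn(*shape)

class DummySequential:
    """Has a `layers` list, so the manager treats it as a Sequential model."""
    def __init__(self):
        self.layers = [DummyDense((4, 8)), DummyDense((8, 2))]

model = DummySequential()

# Snapshot (copies, since get_model_weights returns live references),
# perturb the model, then restore.
snapshot = [w.copy() for w in ModelWeightManager.get_model_weights(model)]
model.layers[0].weights += 1.0
ModelWeightManager.set_model_weights(model, snapshot)
assert np.array_equal(model.layers[0].weights, snapshot[0])
```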
@@ -32,13 +153,14 @@ def __init__(self, patience: int = 5, min_delta: float = 0.001, restore_best_wei
         self.min_delta: float = min_delta
         self.restore_best_weights: bool = restore_best_weights
         self.start_from_epoch: int = start_from_epoch
-        self.monitor: Metric | str = Metric(monitor) if monitor != 'loss' else 'loss'
+        self.monitor: Union[Metric, str] = Metric(monitor) if monitor != 'loss' else 'loss'
         self.mode: str = mode
         self.baseline: float | None = baseline
-        self.best_weights: list | None = None
+        self.best_weights: List[np.ndarray] | None = None
         self.best_metric: float | None = None
         self.patience_counter: int = 0
         self.stop_training: bool = False
+        self.weight_manager = ModelWeightManager()

     def on_train_begin(self, logs: dict | None = None) -> None:
         self.patience_counter = 0
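With the new `weight_manager` field in place, constructing the callback is unchanged from the caller's side. A usage sketch; the keyword values mirror the defaults visible in the hunk header, and anything not shown in this diff is an assumption:

```python
from neuralnetlib.callbacks import EarlyStopping  # assumed import path

# Stop when the monitored value has not improved by at least min_delta
# for `patience` consecutive epochs, then roll back to the best weights.
early_stop = EarlyStopping(
    patience=5,
    min_delta=0.001,
    restore_best_weights=True,
    monitor='loss',  # any other string is wrapped as Metric(monitor)
)
```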
@@ -67,16 +189,14 @@ def on_epoch_end(self, epoch: int, logs: dict | None = None) -> bool:
             self.best_metric = current_metric
             self.patience_counter = 0
             if self.restore_best_weights:
-                self.best_weights = [layer.weights for layer in model.layers if hasattr(layer, 'weights')]
+                self.best_weights = self.weight_manager.get_model_weights(model)
         else:
             self.patience_counter += 1

         if self.patience_counter >= self.patience:
             self.stop_training = True
             if self.restore_best_weights and self.best_weights is not None:
-                for layer, best_weights in zip([layer for layer in model.layers if hasattr(layer, 'weights')],
-                                               self.best_weights):
-                    layer.weights = best_weights
+                self.weight_manager.set_model_weights(model, self.best_weights)
             print(f"\nEarly stopping triggered after epoch {epoch + 1}")
             return True

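One caveat worth noting about `restore_best_weights`: `get_model_weights` returns the live arrays, not copies. If the optimizers update weights in place (an assumption about neuralnetlib's optimizers, not something this diff shows), the stored "best" weights drift along with training and the restore becomes a no-op. A defensive copy at snapshot time sidesteps this; a sketch with hypothetical stand-in classes:

```python
import numpy as np
from neuralnetlib.callbacks import ModelWeightManager  # assumed import path

class _Layer:          # hypothetical stand-in layer
    def __init__(self):
        self.weights = np.ones((2, 2))

class _Model:          # a `layers` list makes this look Sequential to the manager
    def __init__(self):
        self.layers = [_Layer()]

model = _Model()
manager = ModelWeightManager()

best = manager.get_model_weights(model)    # references, not copies
model.layers[0].weights -= 0.1             # in-place, SGD-style update
print(np.array_equal(best[0], model.layers[0].weights))  # True: the "snapshot" moved too

best = [w.copy() for w in manager.get_model_weights(model)]  # defensive copy
```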
@@ -119,6 +239,13 @@ def on_epoch_begin(self, epoch: int, logs: dict | None = None) -> None:
             model.optimizer.learning_rate = new_lr
             if self.verbose > 0:
                 print(f"Epoch {epoch + 1}: Learning rate updated from {old_lr:.5f} to {new_lr:.5f}")
+        elif hasattr(model, 'encoder_optimizer') and hasattr(model, 'decoder_optimizer'):
+            old_encoder_lr = model.encoder_optimizer.learning_rate
+            old_decoder_lr = model.decoder_optimizer.learning_rate
+            model.encoder_optimizer.learning_rate = new_lr
+            model.decoder_optimizer.learning_rate = new_lr
+            if self.verbose > 0:
+                print(f"Epoch {epoch + 1}: Encoder learning rate updated from {old_encoder_lr:.5f} to {new_lr:.5f}")
+                print(f"Epoch {epoch + 1}: Decoder learning rate updated from {old_decoder_lr:.5f} to {new_lr:.5f}")
         else:
-            raise AttributeError("Model's optimizer does not have a learning rate attribute.")
-
+            raise AttributeError("Model's optimizer(s) do not have a learning rate attribute or are not properly configured.")
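The new branch mirrors the single-optimizer path for models that carry a separate encoder and decoder optimizer. A standalone sketch of that dispatch, with hypothetical stub classes; the real callback derives `new_lr` from its schedule, which this hunk does not show:

```python
class _Opt:
    """Hypothetical optimizer stub with the `learning_rate` attribute the callback expects."""
    def __init__(self, lr):
        self.learning_rate = lr

class _Seq2Seq:
    """Hypothetical model with separate optimizers, e.g. a Transformer-style setup."""
    def __init__(self):
        self.encoder_optimizer = _Opt(1e-3)
        self.decoder_optimizer = _Opt(1e-3)

def apply_lr(model, new_lr: float) -> None:
    # Same dispatch as the callback: single optimizer first, then the dual case.
    if hasattr(model, 'optimizer') and hasattr(model.optimizer, 'learning_rate'):
        model.optimizer.learning_rate = new_lr
    elif hasattr(model, 'encoder_optimizer') and hasattr(model, 'decoder_optimizer'):
        model.encoder_optimizer.learning_rate = new_lr
        model.decoder_optimizer.learning_rate = new_lr
    else:
        raise AttributeError("Model's optimizer(s) do not have a learning rate attribute.")

m = _Seq2Seq()
apply_lr(m, 5e-4)
assert m.encoder_optimizer.learning_rate == m.decoder_optimizer.learning_rate == 5e-4
```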