
Commit d7d6100

fix(callbacks): now compatible with every model architecture
1 parent 937d7c3 commit d7d6100

File tree

1 file changed: +135 -8 lines changed


neuralnetlib/callbacks.py

Lines changed: 135 additions & 8 deletions
@@ -1,6 +1,127 @@
 import numpy as np
+from typing import List, Dict, Any, Union
 
 from neuralnetlib.metrics import Metric
+from neuralnetlib.layers import Layer
+
+
+class ModelWeightManager:
+    @staticmethod
+    def get_model_weights(model) -> List[np.ndarray]:
+        """Extract weights from any model type."""
+        weights = []
+
+        if hasattr(model, 'layers'):  # Sequential model
+            weights.extend([layer.weights for layer in model.layers if hasattr(layer, 'weights')])
+
+        elif hasattr(model, 'encoder_layers') and hasattr(model, 'decoder_layers'):  # Autoencoder
+            weights.extend([layer.weights for layer in model.encoder_layers if hasattr(layer, 'weights')])
+            weights.extend([layer.weights for layer in model.decoder_layers if hasattr(layer, 'weights')])
+
+        elif hasattr(model, 'embedding'):  # Transformer
+            if hasattr(model.embedding, 'weights'):
+                weights.append(model.embedding.weights)
+
+            for encoder_layer in model.encoder_layers:
+                if hasattr(encoder_layer, 'attention'):
+                    weights.extend([
+                        encoder_layer.attention.query_dense.weights,
+                        encoder_layer.attention.key_dense.weights,
+                        encoder_layer.attention.value_dense.weights,
+                        encoder_layer.attention.output_dense.weights
+                    ])
+                if hasattr(encoder_layer, 'ffn'):
+                    weights.extend([
+                        encoder_layer.ffn.dense1.weights,
+                        encoder_layer.ffn.dense2.weights
+                    ])
+
+            for decoder_layer in model.decoder_layers:
+                if hasattr(decoder_layer, 'self_attention'):
+                    weights.extend([
+                        decoder_layer.self_attention.query_dense.weights,
+                        decoder_layer.self_attention.key_dense.weights,
+                        decoder_layer.self_attention.value_dense.weights,
+                        decoder_layer.self_attention.output_dense.weights
+                    ])
+                if hasattr(decoder_layer, 'cross_attention'):
+                    weights.extend([
+                        decoder_layer.cross_attention.query_dense.weights,
+                        decoder_layer.cross_attention.key_dense.weights,
+                        decoder_layer.cross_attention.value_dense.weights,
+                        decoder_layer.cross_attention.output_dense.weights
+                    ])
+                if hasattr(decoder_layer, 'ffn'):
+                    weights.extend([
+                        decoder_layer.ffn.dense1.weights,
+                        decoder_layer.ffn.dense2.weights
+                    ])
+
+            if hasattr(model.output_layer, 'weights'):
+                weights.append(model.output_layer.weights)
+
+        return weights
+
+    @staticmethod
+    def set_model_weights(model, weights: List[np.ndarray]) -> None:
+        """Restore weights to any model type."""
+        weight_idx = 0
+
+        if hasattr(model, 'layers'):  # Sequential model
+            for layer in model.layers:
+                if hasattr(layer, 'weights'):
+                    layer.weights = weights[weight_idx]
+                    weight_idx += 1
+
+        elif hasattr(model, 'encoder_layers') and hasattr(model, 'decoder_layers'):  # Autoencoder
+            for layer in model.encoder_layers:
+                if hasattr(layer, 'weights'):
+                    layer.weights = weights[weight_idx]
+                    weight_idx += 1
+
+            for layer in model.decoder_layers:
+                if hasattr(layer, 'weights'):
+                    layer.weights = weights[weight_idx]
+                    weight_idx += 1
+
+        elif hasattr(model, 'embedding'):
+            if hasattr(model.embedding, 'weights'):
+                model.embedding.weights = weights[weight_idx]
+                weight_idx += 1
+
+            for encoder_layer in model.encoder_layers:
+                if hasattr(encoder_layer, 'attention'):
+                    encoder_layer.attention.query_dense.weights = weights[weight_idx]
+                    encoder_layer.attention.key_dense.weights = weights[weight_idx + 1]
+                    encoder_layer.attention.value_dense.weights = weights[weight_idx + 2]
+                    encoder_layer.attention.output_dense.weights = weights[weight_idx + 3]
+                    weight_idx += 4
+                if hasattr(encoder_layer, 'ffn'):
+                    encoder_layer.ffn.dense1.weights = weights[weight_idx]
+                    encoder_layer.ffn.dense2.weights = weights[weight_idx + 1]
+                    weight_idx += 2
+
+            for decoder_layer in model.decoder_layers:
+                if hasattr(decoder_layer, 'self_attention'):
+                    decoder_layer.self_attention.query_dense.weights = weights[weight_idx]
+                    decoder_layer.self_attention.key_dense.weights = weights[weight_idx + 1]
+                    decoder_layer.self_attention.value_dense.weights = weights[weight_idx + 2]
+                    decoder_layer.self_attention.output_dense.weights = weights[weight_idx + 3]
+                    weight_idx += 4
+                if hasattr(decoder_layer, 'cross_attention'):
+                    decoder_layer.cross_attention.query_dense.weights = weights[weight_idx]
+                    decoder_layer.cross_attention.key_dense.weights = weights[weight_idx + 1]
+                    decoder_layer.cross_attention.value_dense.weights = weights[weight_idx + 2]
+                    decoder_layer.cross_attention.output_dense.weights = weights[weight_idx + 3]
+                    weight_idx += 4
+                if hasattr(decoder_layer, 'ffn'):
+                    decoder_layer.ffn.dense1.weights = weights[weight_idx]
+                    decoder_layer.ffn.dense2.weights = weights[weight_idx + 1]
+                    weight_idx += 2
+
+            # Restore output layer weights
+            if hasattr(model.output_layer, 'weights'):
+                model.output_layer.weights = weights[weight_idx]
 
 
 class Callback:
@@ -32,13 +153,14 @@ def __init__(self, patience: int = 5, min_delta: float = 0.001, restore_best_wei
         self.min_delta: float = min_delta
         self.restore_best_weights: bool = restore_best_weights
         self.start_from_epoch: int = start_from_epoch
-        self.monitor: Metric | str = Metric(monitor) if monitor != 'loss' else 'loss'
+        self.monitor: Union[Metric, str] = Metric(monitor) if monitor != 'loss' else 'loss'
         self.mode: str = mode
         self.baseline: float | None = baseline
-        self.best_weights: list | None = None
+        self.best_weights: List[np.ndarray] | None = None
         self.best_metric: float | None = None
         self.patience_counter: int = 0
         self.stop_training: bool = False
+        self.weight_manager = ModelWeightManager()
 
     def on_train_begin(self, logs: dict | None = None) -> None:
         self.patience_counter = 0
@@ -67,16 +189,14 @@ def on_epoch_end(self, epoch: int, logs: dict | None = None) -> bool:
             self.best_metric = current_metric
             self.patience_counter = 0
             if self.restore_best_weights:
-                self.best_weights = [layer.weights for layer in model.layers if hasattr(layer, 'weights')]
+                self.best_weights = self.weight_manager.get_model_weights(model)
         else:
             self.patience_counter += 1
 
         if self.patience_counter >= self.patience:
             self.stop_training = True
             if self.restore_best_weights and self.best_weights is not None:
-                for layer, best_weights in zip([layer for layer in model.layers if hasattr(layer, 'weights')],
-                                               self.best_weights):
-                    layer.weights = best_weights
+                self.weight_manager.set_model_weights(model, self.best_weights)
             print(f"\nEarly stopping triggered after epoch {epoch + 1}")
             return True
 

@@ -119,6 +239,13 @@ def on_epoch_begin(self, epoch: int, logs: dict | None = None) -> None:
             model.optimizer.learning_rate = new_lr
             if self.verbose > 0:
                 print(f"Epoch {epoch + 1}: Learning rate updated from {old_lr:.5f} to {new_lr:.5f}")
+        elif hasattr(model, 'encoder_optimizer') and hasattr(model, 'decoder_optimizer'):
+            old_encoder_lr = model.encoder_optimizer.learning_rate
+            old_decoder_lr = model.decoder_optimizer.learning_rate
+            model.encoder_optimizer.learning_rate = new_lr
+            model.decoder_optimizer.learning_rate = new_lr
+            if self.verbose > 0:
+                print(f"Epoch {epoch + 1}: Encoder learning rate updated from {old_encoder_lr:.5f} to {new_lr:.5f}")
+                print(f"Epoch {epoch + 1}: Decoder learning rate updated from {old_decoder_lr:.5f} to {new_lr:.5f}")
         else:
-            raise AttributeError("Model's optimizer does not have a learning rate attribute.")
-
+            raise AttributeError("Model's optimizer(s) do not have a learning rate attribute or are not properly configured.")
