Commit 1278f1e

fix: general fixes and improvements
1 parent 019cfef commit 1278f1e

6 files changed: +173 -166 lines changed

neuralnetlib/callbacks.py

Lines changed: 102 additions & 95 deletions
@@ -6,124 +6,131 @@
 
 class ModelWeightManager:
     @staticmethod
-    def get_model_weights(model) -> list[np.ndarray]:
-        """Extract weights from any model type."""
-        weights = []
+    def get_model_weights(model) -> list[tuple[np.ndarray, np.ndarray | None]]:
+        """Extract weights and biases from any model type."""
+        params = []
+
+        def get_params_from_layer(layer):
+            if hasattr(layer, 'weights'):
+                weights = layer.weights.copy()
+                bias = layer.bias.copy() if hasattr(layer, 'bias') else None
+                return (weights, bias)
+            return None
+
+        def get_params_from_dense_layers(layers):
+            layer_params = []
+            for layer in layers:
+                p = get_params_from_layer(layer)
+                if p:
+                    layer_params.append(p)
+            return layer_params
 
         if hasattr(model, 'layers'):  # Sequential model
-            weights.extend(
-                [layer.weights for layer in model.layers if hasattr(layer, 'weights')])
+            for layer in model.layers:
+                p = get_params_from_layer(layer)
+                if p:
+                    params.append(p)
 
         elif hasattr(model, 'encoder_layers') and hasattr(model, 'decoder_layers'):  # Autoencoder
-            weights.extend(
-                [layer.weights for layer in model.encoder_layers if hasattr(layer, 'weights')])
-            weights.extend(
-                [layer.weights for layer in model.decoder_layers if hasattr(layer, 'weights')])
+            for layer in model.encoder_layers:
+                p = get_params_from_layer(layer)
+                if p:
+                    params.append(p)
+            for layer in model.decoder_layers:
+                p = get_params_from_layer(layer)
+                if p:
+                    params.append(p)
 
-        elif hasattr(model, 'embedding'):  # Transformer
-            if hasattr(model.embedding, 'weights'):
-                weights.append(model.embedding.weights)
+        elif hasattr(model, 'src_embedding'):  # Transformer
+            params.append(get_params_from_layer(model.src_embedding))
+            params.append(get_params_from_layer(model.tgt_embedding))
 
             for encoder_layer in model.encoder_layers:
-                if hasattr(encoder_layer, 'attention'):
-                    weights.extend([
-                        encoder_layer.attention.query_dense.weights,
-                        encoder_layer.attention.key_dense.weights,
-                        encoder_layer.attention.value_dense.weights,
-                        encoder_layer.attention.output_dense.weights
-                    ])
-                if hasattr(encoder_layer, 'ffn'):
-                    weights.extend([
-                        encoder_layer.ffn.dense1.weights,
-                        encoder_layer.ffn.dense2.weights
-                    ])
+                params.extend(get_params_from_dense_layers([
+                    encoder_layer.attention.query_dense,
+                    encoder_layer.attention.key_dense,
+                    encoder_layer.attention.value_dense,
+                    encoder_layer.attention.output_dense,
+                    encoder_layer.ffn.dense1,
+                    encoder_layer.ffn.dense2
+                ]))
 
             for decoder_layer in model.decoder_layers:
-                if hasattr(decoder_layer, 'self_attention'):
-                    weights.extend([
-                        decoder_layer.self_attention.query_dense.weights,
-                        decoder_layer.self_attention.key_dense.weights,
-                        decoder_layer.self_attention.value_dense.weights,
-                        decoder_layer.self_attention.output_dense.weights
-                    ])
-                if hasattr(decoder_layer, 'cross_attention'):
-                    weights.extend([
-                        decoder_layer.cross_attention.query_dense.weights,
-                        decoder_layer.cross_attention.key_dense.weights,
-                        decoder_layer.cross_attention.value_dense.weights,
-                        decoder_layer.cross_attention.output_dense.weights
-                    ])
-                if hasattr(decoder_layer, 'ffn'):
-                    weights.extend([
-                        decoder_layer.ffn.dense1.weights,
-                        decoder_layer.ffn.dense2.weights
-                    ])
-
-            if hasattr(model.output_layer, 'weights'):
-                weights.append(model.output_layer.weights)
-
-        return weights
+                params.extend(get_params_from_dense_layers([
+                    decoder_layer.self_attention.query_dense,
+                    decoder_layer.self_attention.key_dense,
+                    decoder_layer.self_attention.value_dense,
+                    decoder_layer.self_attention.output_dense,
+                    decoder_layer.cross_attention.query_dense,
+                    decoder_layer.cross_attention.key_dense,
+                    decoder_layer.cross_attention.value_dense,
+                    decoder_layer.cross_attention.output_dense,
+                    decoder_layer.ffn.dense1,
+                    decoder_layer.ffn.dense2
+                ]))
+
+            params.append(get_params_from_layer(model.output_layer))
+
+        return [p for p in params if p is not None]
 
     @staticmethod
-    def set_model_weights(model, weights: list[np.ndarray]) -> None:
-        """Restore weights to any model type."""
-        weight_idx = 0
+    def set_model_weights(model, params: list[tuple[np.ndarray, np.ndarray | None]]) -> None:
+        """Restore weights and biases to any model type."""
+        param_idx = 0
+
+        def set_params_for_layer(layer):
+            nonlocal param_idx
+            if hasattr(layer, 'weights'):
+                if param_idx < len(params):
+                    weights, bias = params[param_idx]
+                    layer.weights = weights.copy()
+                    if hasattr(layer, 'bias') and bias is not None:
+                        layer.bias = bias.copy()
+                param_idx += 1
+
+        def set_params_for_dense_layers(layers):
+            for layer in layers:
+                set_params_for_layer(layer)
 
         if hasattr(model, 'layers'):  # Sequential model
             for layer in model.layers:
-                if hasattr(layer, 'weights'):
-                    layer.weights = weights[weight_idx]
-                    weight_idx += 1
+                set_params_for_layer(layer)
 
         elif hasattr(model, 'encoder_layers') and hasattr(model, 'decoder_layers'):  # Autoencoder
             for layer in model.encoder_layers:
-                if hasattr(layer, 'weights'):
-                    layer.weights = weights[weight_idx]
-                    weight_idx += 1
-
+                set_params_for_layer(layer)
             for layer in model.decoder_layers:
-                if hasattr(layer, 'weights'):
-                    layer.weights = weights[weight_idx]
-                    weight_idx += 1
+                set_params_for_layer(layer)
 
-        elif hasattr(model, 'embedding'):
-            if hasattr(model.embedding, 'weights'):
-                model.embedding.weights = weights[weight_idx]
-                weight_idx += 1
+        elif hasattr(model, 'src_embedding'):  # Transformer
+            set_params_for_layer(model.src_embedding)
+            set_params_for_layer(model.tgt_embedding)
 
             for encoder_layer in model.encoder_layers:
-                if hasattr(encoder_layer, 'attention'):
-                    encoder_layer.attention.query_dense.weights = weights[weight_idx]
-                    encoder_layer.attention.key_dense.weights = weights[weight_idx + 1]
-                    encoder_layer.attention.value_dense.weights = weights[weight_idx + 2]
-                    encoder_layer.attention.output_dense.weights = weights[weight_idx + 3]
-                    weight_idx += 4
-                if hasattr(encoder_layer, 'ffn'):
-                    encoder_layer.ffn.dense1.weights = weights[weight_idx]
-                    encoder_layer.ffn.dense2.weights = weights[weight_idx + 1]
-                    weight_idx += 2
+                set_params_for_dense_layers([
+                    encoder_layer.attention.query_dense,
+                    encoder_layer.attention.key_dense,
+                    encoder_layer.attention.value_dense,
+                    encoder_layer.attention.output_dense,
+                    encoder_layer.ffn.dense1,
+                    encoder_layer.ffn.dense2
+                ])
 
             for decoder_layer in model.decoder_layers:
-                if hasattr(decoder_layer, 'self_attention'):
-                    decoder_layer.self_attention.query_dense.weights = weights[weight_idx]
-                    decoder_layer.self_attention.key_dense.weights = weights[weight_idx + 1]
-                    decoder_layer.self_attention.value_dense.weights = weights[weight_idx + 2]
-                    decoder_layer.self_attention.output_dense.weights = weights[weight_idx + 3]
-                    weight_idx += 4
-                if hasattr(decoder_layer, 'cross_attention'):
-                    decoder_layer.cross_attention.query_dense.weights = weights[weight_idx]
-                    decoder_layer.cross_attention.key_dense.weights = weights[weight_idx + 1]
-                    decoder_layer.cross_attention.value_dense.weights = weights[weight_idx + 2]
-                    decoder_layer.cross_attention.output_dense.weights = weights[weight_idx + 3]
-                    weight_idx += 4
-                if hasattr(decoder_layer, 'ffn'):
-                    decoder_layer.ffn.dense1.weights = weights[weight_idx]
-                    decoder_layer.ffn.dense2.weights = weights[weight_idx + 1]
-                    weight_idx += 2
-
-            # Restore output layer weights
-            if hasattr(model.output_layer, 'weights'):
-                model.output_layer.weights = weights[weight_idx]
+                set_params_for_dense_layers([
+                    decoder_layer.self_attention.query_dense,
+                    decoder_layer.self_attention.key_dense,
+                    decoder_layer.self_attention.value_dense,
+                    decoder_layer.self_attention.output_dense,
+                    decoder_layer.cross_attention.query_dense,
+                    decoder_layer.cross_attention.key_dense,
+                    decoder_layer.cross_attention.value_dense,
+                    decoder_layer.cross_attention.output_dense,
+                    decoder_layer.ffn.dense1,
+                    decoder_layer.ffn.dense2
+                ])
+
+            set_params_for_layer(model.output_layer)
 
 
 class Callback:
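In practice this means checkpoint-style callbacks now snapshot biases as well as weights, and the snapshots are real copies rather than references into the live layers. A minimal usage sketch (the `model` object and the restore trigger are assumptions; only `get_model_weights`/`set_model_weights` come from this commit):

from neuralnetlib.callbacks import ModelWeightManager

# Take a snapshot: each entry is a (weights, bias) tuple and the arrays are
# .copy()-ed, so later training steps cannot mutate the saved state.
best_params = ModelWeightManager.get_model_weights(model)

# ... keep training; if the monitored metric stops improving, roll back ...
ModelWeightManager.set_model_weights(model, best_params)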

neuralnetlib/layers.py

Lines changed: 11 additions & 19 deletions
@@ -153,9 +153,8 @@ def forward_pass(self, input_data: np.ndarray) -> np.ndarray:
         self.input_shape = input_data.shape
         self.input = input_data
 
-        if len(input_data.shape) == 1 and self.input_dim:
-            batch_size = input_data.shape[0]
-            input_data = input_data.reshape(batch_size, self.input_dim)
+        if input_data.ndim == 1:
+            input_data = input_data.reshape(1, -1)
             self.input = input_data
 
         if len(input_data.shape) == 3:
@@ -316,13 +315,7 @@ def forward_pass(self, input_data: np.ndarray, training: bool = True) -> np.ndar
     def backward_pass(self, output_error: np.ndarray) -> np.ndarray:
         if self.adaptive:
             return self.dropout_impl.gradient(output_error)
-
-        if output_error.shape[0] != self.mask.shape[0]:
-            rng = np.random.default_rng(
-                self.random_state if self.random_state is not None else int(time.time_ns()))
-            self.mask = rng.binomial(1, 1 - self.rate,
-                                     size=(output_error.shape[0], self.mask.shape[1])) / (1 - self.rate)
-
+
         return output_error * self.mask
 
     def get_config(self) -> dict:
@@ -373,22 +366,21 @@ def __init__(self, filters: int, kernel_size: int | tuple, strides: int | tuple
 
     def initialize_weights(self, input_shape: tuple):
         _, _, _, in_channels = input_shape
+        fan_in = np.prod(self.kernel_size) * in_channels
+        fan_out = np.prod(self.kernel_size) * self.filters
 
         self.rng = np.random.default_rng(
             self.random_state if self.random_state is not None else int(time.time_ns()))
 
-        if self.weights_init == "xavier":
-            self.weights = self.rng.normal(0, np.sqrt(2 / (np.prod(self.kernel_size) * in_channels)),
-                                           (*self.kernel_size, in_channels, self.filters))
+        if self.weights_init == "glorot_uniform" or self.weights_init == "xavier":
+            limit = np.sqrt(6 / (fan_in + fan_out))
+            self.weights = self.rng.uniform(-limit, limit, (*self.kernel_size, in_channels, self.filters))
         elif self.weights_init == "he":
-            self.weights = self.rng.normal(0, np.sqrt(2 / (in_channels * np.prod(self.kernel_size))),
-                                           (*self.kernel_size, in_channels, self.filters))
+            self.weights = self.rng.normal(0, np.sqrt(2 / fan_in), (*self.kernel_size, in_channels, self.filters))
         elif self.weights_init == "default":
-            self.weights = self.rng.normal(
-                0, 0.01, (*self.kernel_size, in_channels, self.filters))
+            self.weights = self.rng.normal(0, 0.01, (*self.kernel_size, in_channels, self.filters))
         else:
-            raise ValueError(
-                "Invalid weights_init value. Possible values are 'xavier', 'he', and 'default'.")
+            raise ValueError("Invalid weights_init value. Possible values are 'xavier', 'he', and 'default'.")
 
         if self.bias_init == "default":
             self.bias = np.zeros((1, self.filters))
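The old 'xavier' branch drew from N(0, sqrt(2 / fan_in)), i.e. the He formula, so this is a behavioural fix rather than a rename; note the ValueError message still lists only 'xavier', 'he' and 'default' even though 'glorot_uniform' is now accepted. For reference, the fan-based formulas work out as below; this is an illustrative sketch of the math (Glorot/Xavier uniform and He normal), not library code, using the same (*kernel_size, in_channels, filters) weight layout as the diff.

import numpy as np

kernel_size, in_channels, filters = (3, 3), 16, 32
fan_in = np.prod(kernel_size) * in_channels    # 3 * 3 * 16 = 144
fan_out = np.prod(kernel_size) * filters       # 3 * 3 * 32 = 288

rng = np.random.default_rng(0)

# Glorot/Xavier uniform: U(-limit, limit) with limit = sqrt(6 / (fan_in + fan_out))
limit = np.sqrt(6 / (fan_in + fan_out))
w_glorot = rng.uniform(-limit, limit, (*kernel_size, in_channels, filters))

# He normal: N(0, sqrt(2 / fan_in)); same distribution as before, now via the precomputed fan_in
w_he = rng.normal(0, np.sqrt(2 / fan_in), (*kernel_size, in_channels, filters))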

neuralnetlib/losses.py

Lines changed: 1 addition & 1 deletion
@@ -81,7 +81,7 @@ def __call__(self, y_true: np.ndarray, y_pred: np.ndarray) -> float:
         return np.mean(np.square(y_true - y_pred))
 
     def derivative(self, y_true: np.ndarray, y_pred: np.ndarray) -> np.ndarray:
-        return 2 * (y_pred - y_true) / y_true.shape[0]
+        return 2 * (y_pred - y_true) / y_true.size
 
     def __str__(self):
         return "MeanSquaredError"
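Normalising by y_true.size instead of y_true.shape[0] makes the derivative consistent with the forward value, which is a mean over every element, not just over the batch axis; with multi-output targets the old version over-scaled the gradient by the number of outputs. A standalone finite-difference check of the corrected formula (sketch only, not library code):

import numpy as np

rng = np.random.default_rng(0)
y_true = rng.normal(size=(4, 3))
y_pred = rng.normal(size=(4, 3))

mse = lambda yp: np.mean(np.square(y_true - yp))
grad = 2 * (y_pred - y_true) / y_true.size   # the corrected derivative

eps = 1e-6
bumped = y_pred.copy()
bumped[1, 2] += eps
numeric = (mse(bumped) - mse(y_pred)) / eps  # finite-difference estimate of grad[1, 2]
assert np.isclose(grad[1, 2], numeric, atol=1e-4)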

neuralnetlib/metrics.py

Lines changed: 16 additions & 10 deletions
@@ -307,18 +307,24 @@ def mean_absolute_percentage_error(y_pred: np.ndarray, y_true: np.ndarray, thres
     return np.mean(np.abs((y_true[mask] - y_pred[mask]) / y_true[mask])) * 100
 
 
-def r2_score(y_pred: np.ndarray, y_true: np.ndarray, threshold: float = 0.5) -> float:
+def r2_score(y_pred: np.ndarray, y_true: np.ndarray) -> float:
     y_pred, y_true = _reshape_inputs(y_pred, y_true)
-    if y_pred.shape[1] == 1:
-        y_pred_classes = (y_pred >= threshold).astype(int).ravel()
-        y_true_classes = y_true.ravel()
-    else:
-        y_pred_classes = np.argmax(y_pred, axis=1)
-        y_true_classes = np.argmax(y_true, axis=1)
 
-    ss_res = np.sum((y_true_classes - y_pred_classes) ** 2)
-    ss_tot = np.sum((y_true_classes - np.mean(y_true_classes)) ** 2)
-    return 1 - (ss_res / ss_tot) if ss_tot != 0 else 0.0
+    if y_pred.shape[1] == 1:
+        y_pred_ = y_pred.ravel()
+        y_true_ = y_true.ravel()
+        ss_res = np.sum((y_true_ - y_pred_) ** 2)
+        ss_tot = np.sum((y_true_ - np.mean(y_true_)) ** 2)
+        return 1.0 - (ss_res / ss_tot) if ss_tot != 0 else 0.0
+
+    r2s = []
+    for j in range(y_pred.shape[1]):
+        yp = y_pred[:, j]
+        yt = y_true[:, j]
+        ss_res = np.sum((yt - yp) ** 2)
+        ss_tot = np.sum((yt - np.mean(yt)) ** 2)
+        r2s.append(1.0 - (ss_res / ss_tot) if ss_tot != 0 else 0.0)
+    return float(np.mean(r2s)) if r2s else 0.0
 
 
 def bleu_score(y_pred: np.ndarray, y_true: np.ndarray, threshold: float | None = None, n_gram: int = 4, smooth: bool = False) -> float:
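r2_score is now a genuine regression metric: it computes the coefficient of determination on the raw predictions, per output column, and averages the columns, instead of thresholding or argmax-ing the predictions into class labels first. A hand-checkable sketch of the per-column formula on the classic four-point example (note the library's argument order puts predictions first, r2_score(y_pred, y_true)):

import numpy as np

y_true = np.array([3.0, -0.5, 2.0, 7.0])
y_pred = np.array([2.5, 0.0, 2.0, 8.0])

ss_res = np.sum((y_true - y_pred) ** 2)           # 1.5
ss_tot = np.sum((y_true - np.mean(y_true)) ** 2)  # 29.1875
r2 = 1.0 - ss_res / ss_tot                        # ~0.9486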

neuralnetlib/models.py

Lines changed: 0 additions & 2 deletions
@@ -193,8 +193,6 @@ def forward_pass(self, X: np.ndarray, training: bool = True, labels: np.ndarray
         return X
 
     def backward_pass(self, error: np.ndarray, gan: bool = False, compute_only: bool = False) -> np.ndarray:
-        if self.n_classes is not None and error.shape[1] > error.shape[1] - self.n_classes:
-            error = error[:, :-self.n_classes]
 
         for i, layer in enumerate(reversed(self.layers)):
             if i == 0 and isinstance(layer, Activation):
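The removed guard was redundant rather than load-bearing: for any positive self.n_classes, error.shape[1] > error.shape[1] - self.n_classes reduces to self.n_classes > 0, so the last n_classes columns of the error (presumably the label-conditioning part) were always sliced off whenever n_classes was set. After this change the incoming error is passed to the layer stack unchanged.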
