Skip to content

Commit e468890

Browse files
committed
refactor: usage of standalone functions everywhere
1 parent 25f1513 commit e468890

File tree

7 files changed

+140
-88
lines changed

7 files changed

+140
-88
lines changed

examples/models-usages/compression/autoencoder_fashonized_mnist_basic.ipynb

Lines changed: 17 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -31,14 +31,13 @@
3131
"import numpy as np\n",
3232
"import matplotlib.pyplot as plt\n",
3333
"\n",
34-
"from sklearn.metrics import mean_squared_error, mean_absolute_error\n",
3534
"from sklearn.datasets import fetch_openml\n",
3635
"\n",
3736
"from neuralnetlib.models import Autoencoder\n",
3837
"from neuralnetlib.layers import Input, Dense, BatchNormalization\n",
3938
"from neuralnetlib.preprocessing import PCA\n",
4039
"from neuralnetlib.utils import train_test_split\n",
41-
"from neuralnetlib.metrics import pearsonr\n",
40+
"from neuralnetlib.metrics import pearsonr, mean_absolute_error, mean_squared_error\n",
4241
"\n"
4342
]
4443
},
@@ -308,16 +307,16 @@
308307
"name": "stdout",
309308
"output_type": "stream",
310309
"text": [
311-
"[==============================] 100% Epoch 1/10 - loss: 0.0360 - - 8.65s \n",
312-
"[==============================] 100% Epoch 2/10 - loss: 0.0204 - - 8.42s \n",
313-
"[==============================] 100% Epoch 3/10 - loss: 0.0182 - - 8.22s \n",
314-
"[==============================] 100% Epoch 4/10 - loss: 0.0169 - - 8.39s \n",
315-
"[==============================] 100% Epoch 5/10 - loss: 0.0160 - - 8.38s \n",
316-
"[==============================] 100% Epoch 6/10 - loss: 0.0153 - - 8.38s \n",
317-
"[==============================] 100% Epoch 7/10 - loss: 0.0148 - - 8.24s \n",
318-
"[==============================] 100% Epoch 8/10 - loss: 0.0144 - - 8.16s \n",
319-
"[==============================] 100% Epoch 9/10 - loss: 0.0140 - - 8.20s \n",
320-
"[==============================] 100% Epoch 10/10 - loss: 0.0137 - - 8.71s \n",
310+
"[==============================] 100% Epoch 1/10 - loss: 0.0360 - - 8.55s \n",
311+
"[==============================] 100% Epoch 2/10 - loss: 0.0204 - - 8.40s \n",
312+
"[==============================] 100% Epoch 3/10 - loss: 0.0182 - - 8.32s \n",
313+
"[==============================] 100% Epoch 4/10 - loss: 0.0169 - - 8.41s \n",
314+
"[==============================] 100% Epoch 5/10 - loss: 0.0160 - - 8.47s \n",
315+
"[==============================] 100% Epoch 6/10 - loss: 0.0153 - - 8.47s \n",
316+
"[==============================] 100% Epoch 7/10 - loss: 0.0148 - - 8.53s \n",
317+
"[==============================] 100% Epoch 8/10 - loss: 0.0144 - - 9.89s \n",
318+
"[==============================] 100% Epoch 9/10 - loss: 0.0140 - - 8.33s \n",
319+
"[==============================] 100% Epoch 10/10 - loss: 0.0137 - - 8.42s \n",
321320
"\n"
322321
]
323322
}
@@ -483,7 +482,7 @@
483482
},
484483
{
485484
"cell_type": "code",
486-
"execution_count": 14,
485+
"execution_count": 11,
487486
"id": "cb013785",
488487
"metadata": {
489488
"ExecuteTime": {
@@ -498,7 +497,7 @@
498497
"Text(0.5, 1.0, 'Latent Space Distribution')"
499498
]
500499
},
501-
"execution_count": 14,
500+
"execution_count": 11,
502501
"metadata": {},
503502
"output_type": "execute_result"
504503
},
@@ -597,10 +596,10 @@
597596
"output_type": "stream",
598597
"text": [
599598
"Evaluation Results:\n",
600-
"mse_test: 0.03563702004658259\n",
601-
"mae_test: 0.1104170742069355\n",
602-
"mse_train: 0.035587902817990055\n",
603-
"mae_train: 0.11024269797983575\n",
599+
"mse_test: 0.03563702004658265\n",
600+
"mae_test: 0.11041707420693547\n",
601+
"mse_train: 0.035587902817990354\n",
602+
"mae_train: 0.11024269797983545\n",
604603
"avg_feature_correlation: 0.7689345609509709\n",
605604
"latent_skewness: 1.3513897110286115\n",
606605
"latent_kurtosis: 1.4631146187055855\n",

examples/models-usages/mlp-classification-regression/diabete_regression.ipynb

Lines changed: 113 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
},
1919
{
2020
"cell_type": "code",
21-
"execution_count": 1,
21+
"execution_count": null,
2222
"metadata": {
2323
"ExecuteTime": {
2424
"end_time": "2024-11-14T19:03:42.300332700Z",
@@ -27,8 +27,6 @@
2727
},
2828
"outputs": [],
2929
"source": [
30-
"import numpy as np\n",
31-
"\n",
3230
"from sklearn.datasets import load_diabetes\n",
3331
"\n",
3432
"from neuralnetlib.preprocessing import MinMaxScaler, StandardScaler\n",
@@ -49,7 +47,7 @@
4947
},
5048
{
5149
"cell_type": "code",
52-
"execution_count": 2,
50+
"execution_count": 11,
5351
"metadata": {
5452
"ExecuteTime": {
5553
"end_time": "2024-11-14T19:03:42.333243200Z",
@@ -71,7 +69,7 @@
7169
},
7270
{
7371
"cell_type": "code",
74-
"execution_count": 3,
72+
"execution_count": 12,
7573
"metadata": {
7674
"ExecuteTime": {
7775
"end_time": "2024-11-14T19:03:42.347365300Z",
@@ -97,7 +95,7 @@
9795
},
9896
{
9997
"cell_type": "code",
100-
"execution_count": 4,
98+
"execution_count": 13,
10199
"metadata": {
102100
"ExecuteTime": {
103101
"end_time": "2024-11-14T19:03:42.365577900Z",
@@ -133,7 +131,7 @@
133131
},
134132
{
135133
"cell_type": "code",
136-
"execution_count": 5,
134+
"execution_count": 14,
137135
"metadata": {
138136
"ExecuteTime": {
139137
"end_time": "2024-11-14T19:03:42.379141Z",
@@ -145,7 +143,7 @@
145143
"name": "stdout",
146144
"output_type": "stream",
147145
"text": [
148-
"Sequential(gradient_clip_threshold=5.0, enable_padding=False, padding_size=32, random_state=1733515038822283600)\n",
146+
"Sequential(gradient_clip_threshold=5.0, enable_padding=False, padding_size=32, random_state=1733786806176717100)\n",
149147
"-------------------------------------------------\n",
150148
"Layer 1: Input(input_shape=(10,))\n",
151149
"Layer 2: Dense(units=2)\n",
@@ -177,7 +175,7 @@
177175
},
178176
{
179177
"cell_type": "code",
180-
"execution_count": 6,
178+
"execution_count": 15,
181179
"metadata": {
182180
"ExecuteTime": {
183181
"end_time": "2024-11-14T19:03:42.516565900Z",
@@ -189,30 +187,120 @@
189187
"name": "stdout",
190188
"output_type": "stream",
191189
"text": [
192-
"[==============================] 100% Epoch 1/10 - 0.01s - loss: 1.2543 \n",
193-
"[==============================] 100% Epoch 2/10 - 0.01s - loss: 1.2482 \n",
194-
"[==============================] 100% Epoch 3/10 - 0.01s - loss: 1.2422 \n",
195-
"[==============================] 100% Epoch 4/10 - 0.01s - loss: 1.2366 \n",
196-
"[==============================] 100% Epoch 5/10 - 0.01s - loss: 1.2320 \n",
197-
"[==============================] 100% Epoch 6/10 - 0.01s - loss: 1.2275 \n",
198-
"[==============================] 100% Epoch 7/10 - 0.01s - loss: 1.2231 \n",
199-
"[==============================] 100% Epoch 8/10 - 0.01s - loss: 1.2183 \n",
200-
"[==============================] 100% Epoch 9/10 - 0.01s - loss: 1.2134 \n",
201-
"[==============================] 100% Epoch 10/10 - 0.01s - loss: 1.2083 \n",
190+
"[==============================] 100% Epoch 1/100 - 0.01s - loss: 1.2543 \n",
191+
"[==============================] 100% Epoch 2/100 - 0.01s - loss: 1.2482 \n",
192+
"[==============================] 100% Epoch 3/100 - 0.01s - loss: 1.2422 \n",
193+
"[==============================] 100% Epoch 4/100 - 0.01s - loss: 1.2366 \n",
194+
"[==============================] 100% Epoch 5/100 - 0.01s - loss: 1.2320 \n",
195+
"[==============================] 100% Epoch 6/100 - 0.01s - loss: 1.2275 \n",
196+
"[==============================] 100% Epoch 7/100 - 0.01s - loss: 1.2231 \n",
197+
"[==============================] 100% Epoch 8/100 - 0.01s - loss: 1.2183 \n",
198+
"[==============================] 100% Epoch 9/100 - 0.01s - loss: 1.2134 \n",
199+
"[==============================] 100% Epoch 10/100 - 0.01s - loss: 1.2083 \n",
200+
"[==============================] 100% Epoch 11/100 - 0.01s - loss: 1.2029 \n",
201+
"[==============================] 100% Epoch 12/100 - 0.01s - loss: 1.1975 \n",
202+
"[==============================] 100% Epoch 13/100 - 0.01s - loss: 1.1921 \n",
203+
"[==============================] 100% Epoch 14/100 - 0.01s - loss: 1.1864 \n",
204+
"[==============================] 100% Epoch 15/100 - 0.01s - loss: 1.1806 \n",
205+
"[==============================] 100% Epoch 16/100 - 0.01s - loss: 1.1746 \n",
206+
"[==============================] 100% Epoch 17/100 - 0.01s - loss: 1.1685 \n",
207+
"[==============================] 100% Epoch 18/100 - 0.01s - loss: 1.1622 \n",
208+
"[==============================] 100% Epoch 19/100 - 0.01s - loss: 1.1555 \n",
209+
"[==============================] 100% Epoch 20/100 - 0.01s - loss: 1.1489 \n",
210+
"[==============================] 100% Epoch 21/100 - 0.01s - loss: 1.1421 \n",
211+
"[==============================] 100% Epoch 22/100 - 0.01s - loss: 1.1353 \n",
212+
"[==============================] 100% Epoch 23/100 - 0.01s - loss: 1.1284 \n",
213+
"[==============================] 100% Epoch 24/100 - 0.01s - loss: 1.1213 \n",
214+
"[==============================] 100% Epoch 25/100 - 0.01s - loss: 1.1141 \n",
215+
"[==============================] 100% Epoch 26/100 - 0.01s - loss: 1.1068 \n",
216+
"[==============================] 100% Epoch 27/100 - 0.01s - loss: 1.0994 \n",
217+
"[==============================] 100% Epoch 28/100 - 0.01s - loss: 1.0919 \n",
218+
"[==============================] 100% Epoch 29/100 - 0.01s - loss: 1.0841 \n",
219+
"[==============================] 100% Epoch 30/100 - 0.01s - loss: 1.0761 \n",
220+
"[==============================] 100% Epoch 31/100 - 0.01s - loss: 1.0679 \n",
221+
"[==============================] 100% Epoch 32/100 - 0.01s - loss: 1.0597 \n",
222+
"[==============================] 100% Epoch 33/100 - 0.01s - loss: 1.0514 \n",
223+
"[==============================] 100% Epoch 34/100 - 0.01s - loss: 1.0429 \n",
224+
"[==============================] 100% Epoch 35/100 - 0.01s - loss: 1.0344 \n",
225+
"[==============================] 100% Epoch 36/100 - 0.01s - loss: 1.0259 \n",
226+
"[==============================] 100% Epoch 37/100 - 0.01s - loss: 1.0174 \n",
227+
"[==============================] 100% Epoch 38/100 - 0.01s - loss: 1.0090 \n",
228+
"[==============================] 100% Epoch 39/100 - 0.01s - loss: 1.0003 \n",
229+
"[==============================] 100% Epoch 40/100 - 0.01s - loss: 0.9914 \n",
230+
"[==============================] 100% Epoch 41/100 - 0.01s - loss: 0.9822 \n",
231+
"[==============================] 100% Epoch 42/100 - 0.01s - loss: 0.9729 \n",
232+
"[==============================] 100% Epoch 43/100 - 0.01s - loss: 0.9633 \n",
233+
"[==============================] 100% Epoch 44/100 - 0.01s - loss: 0.9534 \n",
234+
"[==============================] 100% Epoch 45/100 - 0.01s - loss: 0.9430 \n",
235+
"[==============================] 100% Epoch 46/100 - 0.01s - loss: 0.9322 \n",
236+
"[==============================] 100% Epoch 47/100 - 0.01s - loss: 0.9199 \n",
237+
"[==============================] 100% Epoch 48/100 - 0.01s - loss: 0.9069 \n",
238+
"[==============================] 100% Epoch 49/100 - 0.01s - loss: 0.8932 \n",
239+
"[==============================] 100% Epoch 50/100 - 0.01s - loss: 0.8791 \n",
240+
"[==============================] 100% Epoch 51/100 - 0.01s - loss: 0.8643 \n",
241+
"[==============================] 100% Epoch 52/100 - 0.01s - loss: 0.8502 \n",
242+
"[==============================] 100% Epoch 53/100 - 0.01s - loss: 0.8348 \n",
243+
"[==============================] 100% Epoch 54/100 - 0.01s - loss: 0.8178 \n",
244+
"[==============================] 100% Epoch 55/100 - 0.01s - loss: 0.7999 \n",
245+
"[==============================] 100% Epoch 56/100 - 0.01s - loss: 0.7825 \n",
246+
"[==============================] 100% Epoch 57/100 - 0.01s - loss: 0.7658 \n",
247+
"[==============================] 100% Epoch 58/100 - 0.01s - loss: 0.7502 \n",
248+
"[==============================] 100% Epoch 59/100 - 0.01s - loss: 0.7364 \n",
249+
"[==============================] 100% Epoch 60/100 - 0.01s - loss: 0.7243 \n",
250+
"[==============================] 100% Epoch 61/100 - 0.01s - loss: 0.7132 \n",
251+
"[==============================] 100% Epoch 62/100 - 0.01s - loss: 0.7031 \n",
252+
"[==============================] 100% Epoch 63/100 - 0.01s - loss: 0.6937 \n",
253+
"[==============================] 100% Epoch 64/100 - 0.01s - loss: 0.6846 \n",
254+
"[==============================] 100% Epoch 65/100 - 0.01s - loss: 0.6761 \n",
255+
"[==============================] 100% Epoch 66/100 - 0.01s - loss: 0.6674 \n",
256+
"[==============================] 100% Epoch 67/100 - 0.01s - loss: 0.6587 \n",
257+
"[==============================] 100% Epoch 68/100 - 0.01s - loss: 0.6508 \n",
258+
"[==============================] 100% Epoch 69/100 - 0.01s - loss: 0.6434 \n",
259+
"[==============================] 100% Epoch 70/100 - 0.01s - loss: 0.6365 \n",
260+
"[==============================] 100% Epoch 71/100 - 0.01s - loss: 0.6300 \n",
261+
"[==============================] 100% Epoch 72/100 - 0.01s - loss: 0.6237 \n",
262+
"[==============================] 100% Epoch 73/100 - 0.01s - loss: 0.6175 \n",
263+
"[==============================] 100% Epoch 74/100 - 0.01s - loss: 0.6110 \n",
264+
"[==============================] 100% Epoch 75/100 - 0.01s - loss: 0.6031 \n",
265+
"[==============================] 100% Epoch 76/100 - 0.01s - loss: 0.5917 \n",
266+
"[==============================] 100% Epoch 77/100 - 0.01s - loss: 0.5811 \n",
267+
"[==============================] 100% Epoch 78/100 - 0.01s - loss: 0.5711 \n",
268+
"[==============================] 100% Epoch 79/100 - 0.01s - loss: 0.5631 \n",
269+
"[==============================] 100% Epoch 80/100 - 0.01s - loss: 0.5569 \n",
270+
"[==============================] 100% Epoch 81/100 - 0.01s - loss: 0.5516 \n",
271+
"[==============================] 100% Epoch 82/100 - 0.01s - loss: 0.5470 \n",
272+
"[==============================] 100% Epoch 83/100 - 0.01s - loss: 0.5427 \n",
273+
"[==============================] 100% Epoch 84/100 - 0.01s - loss: 0.5384 \n",
274+
"[==============================] 100% Epoch 85/100 - 0.01s - loss: 0.5347 \n",
275+
"[==============================] 100% Epoch 86/100 - 0.01s - loss: 0.5312 \n",
276+
"[==============================] 100% Epoch 87/100 - 0.01s - loss: 0.5280 \n",
277+
"[==============================] 100% Epoch 88/100 - 0.01s - loss: 0.5252 \n",
278+
"[==============================] 100% Epoch 89/100 - 0.01s - loss: 0.5226 \n",
279+
"[==============================] 100% Epoch 90/100 - 0.01s - loss: 0.5202 \n",
280+
"[==============================] 100% Epoch 91/100 - 0.01s - loss: 0.5178 \n",
281+
"[==============================] 100% Epoch 92/100 - 0.01s - loss: 0.5156 \n",
282+
"[==============================] 100% Epoch 93/100 - 0.01s - loss: 0.5134 \n",
283+
"[==============================] 100% Epoch 94/100 - 0.01s - loss: 0.5112 \n",
284+
"[==============================] 100% Epoch 95/100 - 0.01s - loss: 0.5093 \n",
285+
"[==============================] 100% Epoch 96/100 - 0.01s - loss: 0.5075 \n",
286+
"[==============================] 100% Epoch 97/100 - 0.01s - loss: 0.5058 \n",
287+
"[==============================] 100% Epoch 98/100 - 0.01s - loss: 0.5041 \n",
288+
"[==============================] 100% Epoch 99/100 - 0.01s - loss: 0.5026 \n",
289+
"[==============================] 100% Epoch 100/100 - 0.01s - loss: 0.5011 \n",
202290
"\n"
203291
]
204292
},
205293
{
206294
"data": {
207295
"text/plain": []
208296
},
209-
"execution_count": 6,
297+
"execution_count": 15,
210298
"metadata": {},
211299
"output_type": "execute_result"
212300
}
213301
],
214302
"source": [
215-
"model.fit(x_train, y_train, epochs=10, batch_size=32, random_state=42)"
303+
"model.fit(x_train, y_train, epochs=100, batch_size=32, random_state=42)"
216304
]
217305
},
218306
{
@@ -224,7 +312,7 @@
224312
},
225313
{
226314
"cell_type": "code",
227-
"execution_count": 7,
315+
"execution_count": 16,
228316
"metadata": {
229317
"ExecuteTime": {
230318
"end_time": "2024-11-14T19:03:42.518566Z",
@@ -236,7 +324,7 @@
236324
"name": "stdout",
237325
"output_type": "stream",
238326
"text": [
239-
"Test loss: 1.1136541600695817 function=MeanSquaredError\n"
327+
"Test loss: 1.7401693358245864 function=MeanSquaredError\n"
240328
]
241329
}
242330
],
@@ -254,7 +342,7 @@
254342
},
255343
{
256344
"cell_type": "code",
257-
"execution_count": 8,
345+
"execution_count": 17,
258346
"metadata": {
259347
"ExecuteTime": {
260348
"end_time": "2024-11-14T19:03:42.519566900Z",
@@ -266,49 +354,14 @@
266354
"name": "stdout",
267355
"output_type": "stream",
268356
"text": [
269-
"MAE: 0.8748635782918366\n"
357+
"MAE: 1.0799955616547592\n"
270358
]
271359
}
272360
],
273361
"source": [
274362
"y_pred = model.predict(x_test)\n",
275363
"print(\"MAE: \", MeanAbsoluteError()(y_test, y_pred))"
276364
]
277-
},
278-
{
279-
"cell_type": "markdown",
280-
"metadata": {},
281-
"source": [
282-
"## 8. Getting original MAE (without normalization from StandardScaler)"
283-
]
284-
},
285-
{
286-
"cell_type": "code",
287-
"execution_count": 9,
288-
"metadata": {
289-
"ExecuteTime": {
290-
"end_time": "2024-11-14T19:03:42.533072900Z",
291-
"start_time": "2024-11-14T19:03:42.519566900Z"
292-
}
293-
},
294-
"outputs": [
295-
{
296-
"name": "stdout",
297-
"output_type": "stream",
298-
"text": [
299-
"MAE (original): 65.9770599527072\n"
300-
]
301-
}
302-
],
303-
"source": [
304-
"y_pred_scaled = model.predict(x_test)\n",
305-
"\n",
306-
"y_pred = scaler_y.inverse_transform(y_pred_scaled.reshape(-1, 1)).flatten()\n",
307-
"y_test_original = scaler_y.inverse_transform(y_test.reshape(-1, 1)).flatten()\n",
308-
"\n",
309-
"mae_original = np.mean(np.abs(y_test_original - y_pred))\n",
310-
"print(f'MAE (original): {mae_original}')"
311-
]
312365
}
313366
],
314367
"metadata": {

examples/models-usages/mlp-classification-regression/mnist_loading_saved_model.ipynb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
},
1919
{
2020
"cell_type": "code",
21-
"execution_count": 1,
21+
"execution_count": null,
2222
"metadata": {
2323
"ExecuteTime": {
2424
"end_time": "2024-11-14T19:04:46.439544500Z",
@@ -27,7 +27,7 @@
2727
},
2828
"outputs": [],
2929
"source": [
30-
"from tensorflow.keras.datasets import mnist # Dataset for testing\n",
30+
"from tensorflow.keras.datasets import mnist\n",
3131
"\n",
3232
"from neuralnetlib.models import Sequential\n",
3333
"from neuralnetlib.utils import train_test_split\n",

examples/models-usages/mlp-classification-regression/tictactoe/tic_tac_toe_alternative_dataset_shape.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
import pandas as pd
33
import requests
44
from scipy.io import arff
5-
from sklearn.model_selection import train_test_split
65

76
from neuralnetlib.activations import ReLU, Sigmoid
87
from neuralnetlib.callbacks import EarlyStopping
@@ -11,6 +10,7 @@
1110
from neuralnetlib.metrics import accuracy_score
1211
from neuralnetlib.models import Sequential
1312
from neuralnetlib.optimizers import Adam
13+
from neuralnetlib.utils import train_test_split
1414

1515

1616
def main():

0 commit comments

Comments (0)