|
18 | 18 | }, |
19 | 19 | { |
20 | 20 | "cell_type": "code", |
21 | | - "execution_count": 1, |
| 21 | + "execution_count": null, |
22 | 22 | "metadata": { |
23 | 23 | "ExecuteTime": { |
24 | 24 | "end_time": "2024-11-14T19:03:42.300332700Z", |
|
27 | 27 | }, |
28 | 28 | "outputs": [], |
29 | 29 | "source": [ |
30 | | - "import numpy as np\n", |
31 | | - "\n", |
32 | 30 | "from sklearn.datasets import load_diabetes\n", |
33 | 31 | "\n", |
34 | 32 | "from neuralnetlib.preprocessing import MinMaxScaler, StandardScaler\n", |
|
49 | 47 | }, |
50 | 48 | { |
51 | 49 | "cell_type": "code", |
52 | | - "execution_count": 2, |
| 50 | + "execution_count": 11, |
53 | 51 | "metadata": { |
54 | 52 | "ExecuteTime": { |
55 | 53 | "end_time": "2024-11-14T19:03:42.333243200Z", |
|
71 | 69 | }, |
72 | 70 | { |
73 | 71 | "cell_type": "code", |
74 | | - "execution_count": 3, |
| 72 | + "execution_count": 12, |
75 | 73 | "metadata": { |
76 | 74 | "ExecuteTime": { |
77 | 75 | "end_time": "2024-11-14T19:03:42.347365300Z", |
|
97 | 95 | }, |
98 | 96 | { |
99 | 97 | "cell_type": "code", |
100 | | - "execution_count": 4, |
| 98 | + "execution_count": 13, |
101 | 99 | "metadata": { |
102 | 100 | "ExecuteTime": { |
103 | 101 | "end_time": "2024-11-14T19:03:42.365577900Z", |
|
133 | 131 | }, |
134 | 132 | { |
135 | 133 | "cell_type": "code", |
136 | | - "execution_count": 5, |
| 134 | + "execution_count": 14, |
137 | 135 | "metadata": { |
138 | 136 | "ExecuteTime": { |
139 | 137 | "end_time": "2024-11-14T19:03:42.379141Z", |
|
145 | 143 | "name": "stdout", |
146 | 144 | "output_type": "stream", |
147 | 145 | "text": [ |
148 | | - "Sequential(gradient_clip_threshold=5.0, enable_padding=False, padding_size=32, random_state=1733515038822283600)\n", |
| 146 | + "Sequential(gradient_clip_threshold=5.0, enable_padding=False, padding_size=32, random_state=1733786806176717100)\n", |
149 | 147 | "-------------------------------------------------\n", |
150 | 148 | "Layer 1: Input(input_shape=(10,))\n", |
151 | 149 | "Layer 2: Dense(units=2)\n", |
|
177 | 175 | }, |
178 | 176 | { |
179 | 177 | "cell_type": "code", |
180 | | - "execution_count": 6, |
| 178 | + "execution_count": 15, |
181 | 179 | "metadata": { |
182 | 180 | "ExecuteTime": { |
183 | 181 | "end_time": "2024-11-14T19:03:42.516565900Z", |
|
189 | 187 | "name": "stdout", |
190 | 188 | "output_type": "stream", |
191 | 189 | "text": [ |
192 | | - "[==============================] 100% Epoch 1/10 - 0.01s - loss: 1.2543 \n", |
193 | | - "[==============================] 100% Epoch 2/10 - 0.01s - loss: 1.2482 \n", |
194 | | - "[==============================] 100% Epoch 3/10 - 0.01s - loss: 1.2422 \n", |
195 | | - "[==============================] 100% Epoch 4/10 - 0.01s - loss: 1.2366 \n", |
196 | | - "[==============================] 100% Epoch 5/10 - 0.01s - loss: 1.2320 \n", |
197 | | - "[==============================] 100% Epoch 6/10 - 0.01s - loss: 1.2275 \n", |
198 | | - "[==============================] 100% Epoch 7/10 - 0.01s - loss: 1.2231 \n", |
199 | | - "[==============================] 100% Epoch 8/10 - 0.01s - loss: 1.2183 \n", |
200 | | - "[==============================] 100% Epoch 9/10 - 0.01s - loss: 1.2134 \n", |
201 | | - "[==============================] 100% Epoch 10/10 - 0.01s - loss: 1.2083 \n", |
| 190 | + "[==============================] 100% Epoch 1/100 - 0.01s - loss: 1.2543 \n", |
| 191 | + "[==============================] 100% Epoch 2/100 - 0.01s - loss: 1.2482 \n", |
| 192 | + "[==============================] 100% Epoch 3/100 - 0.01s - loss: 1.2422 \n", |
| 193 | + "[==============================] 100% Epoch 4/100 - 0.01s - loss: 1.2366 \n", |
| 194 | + "[==============================] 100% Epoch 5/100 - 0.01s - loss: 1.2320 \n", |
| 195 | + "[==============================] 100% Epoch 6/100 - 0.01s - loss: 1.2275 \n", |
| 196 | + "[==============================] 100% Epoch 7/100 - 0.01s - loss: 1.2231 \n", |
| 197 | + "[==============================] 100% Epoch 8/100 - 0.01s - loss: 1.2183 \n", |
| 198 | + "[==============================] 100% Epoch 9/100 - 0.01s - loss: 1.2134 \n", |
| 199 | + "[==============================] 100% Epoch 10/100 - 0.01s - loss: 1.2083 \n", |
| 200 | + "[==============================] 100% Epoch 11/100 - 0.01s - loss: 1.2029 \n", |
| 201 | + "[==============================] 100% Epoch 12/100 - 0.01s - loss: 1.1975 \n", |
| 202 | + "[==============================] 100% Epoch 13/100 - 0.01s - loss: 1.1921 \n", |
| 203 | + "[==============================] 100% Epoch 14/100 - 0.01s - loss: 1.1864 \n", |
| 204 | + "[==============================] 100% Epoch 15/100 - 0.01s - loss: 1.1806 \n", |
| 205 | + "[==============================] 100% Epoch 16/100 - 0.01s - loss: 1.1746 \n", |
| 206 | + "[==============================] 100% Epoch 17/100 - 0.01s - loss: 1.1685 \n", |
| 207 | + "[==============================] 100% Epoch 18/100 - 0.01s - loss: 1.1622 \n", |
| 208 | + "[==============================] 100% Epoch 19/100 - 0.01s - loss: 1.1555 \n", |
| 209 | + "[==============================] 100% Epoch 20/100 - 0.01s - loss: 1.1489 \n", |
| 210 | + "[==============================] 100% Epoch 21/100 - 0.01s - loss: 1.1421 \n", |
| 211 | + "[==============================] 100% Epoch 22/100 - 0.01s - loss: 1.1353 \n", |
| 212 | + "[==============================] 100% Epoch 23/100 - 0.01s - loss: 1.1284 \n", |
| 213 | + "[==============================] 100% Epoch 24/100 - 0.01s - loss: 1.1213 \n", |
| 214 | + "[==============================] 100% Epoch 25/100 - 0.01s - loss: 1.1141 \n", |
| 215 | + "[==============================] 100% Epoch 26/100 - 0.01s - loss: 1.1068 \n", |
| 216 | + "[==============================] 100% Epoch 27/100 - 0.01s - loss: 1.0994 \n", |
| 217 | + "[==============================] 100% Epoch 28/100 - 0.01s - loss: 1.0919 \n", |
| 218 | + "[==============================] 100% Epoch 29/100 - 0.01s - loss: 1.0841 \n", |
| 219 | + "[==============================] 100% Epoch 30/100 - 0.01s - loss: 1.0761 \n", |
| 220 | + "[==============================] 100% Epoch 31/100 - 0.01s - loss: 1.0679 \n", |
| 221 | + "[==============================] 100% Epoch 32/100 - 0.01s - loss: 1.0597 \n", |
| 222 | + "[==============================] 100% Epoch 33/100 - 0.01s - loss: 1.0514 \n", |
| 223 | + "[==============================] 100% Epoch 34/100 - 0.01s - loss: 1.0429 \n", |
| 224 | + "[==============================] 100% Epoch 35/100 - 0.01s - loss: 1.0344 \n", |
| 225 | + "[==============================] 100% Epoch 36/100 - 0.01s - loss: 1.0259 \n", |
| 226 | + "[==============================] 100% Epoch 37/100 - 0.01s - loss: 1.0174 \n", |
| 227 | + "[==============================] 100% Epoch 38/100 - 0.01s - loss: 1.0090 \n", |
| 228 | + "[==============================] 100% Epoch 39/100 - 0.01s - loss: 1.0003 \n", |
| 229 | + "[==============================] 100% Epoch 40/100 - 0.01s - loss: 0.9914 \n", |
| 230 | + "[==============================] 100% Epoch 41/100 - 0.01s - loss: 0.9822 \n", |
| 231 | + "[==============================] 100% Epoch 42/100 - 0.01s - loss: 0.9729 \n", |
| 232 | + "[==============================] 100% Epoch 43/100 - 0.01s - loss: 0.9633 \n", |
| 233 | + "[==============================] 100% Epoch 44/100 - 0.01s - loss: 0.9534 \n", |
| 234 | + "[==============================] 100% Epoch 45/100 - 0.01s - loss: 0.9430 \n", |
| 235 | + "[==============================] 100% Epoch 46/100 - 0.01s - loss: 0.9322 \n", |
| 236 | + "[==============================] 100% Epoch 47/100 - 0.01s - loss: 0.9199 \n", |
| 237 | + "[==============================] 100% Epoch 48/100 - 0.01s - loss: 0.9069 \n", |
| 238 | + "[==============================] 100% Epoch 49/100 - 0.01s - loss: 0.8932 \n", |
| 239 | + "[==============================] 100% Epoch 50/100 - 0.01s - loss: 0.8791 \n", |
| 240 | + "[==============================] 100% Epoch 51/100 - 0.01s - loss: 0.8643 \n", |
| 241 | + "[==============================] 100% Epoch 52/100 - 0.01s - loss: 0.8502 \n", |
| 242 | + "[==============================] 100% Epoch 53/100 - 0.01s - loss: 0.8348 \n", |
| 243 | + "[==============================] 100% Epoch 54/100 - 0.01s - loss: 0.8178 \n", |
| 244 | + "[==============================] 100% Epoch 55/100 - 0.01s - loss: 0.7999 \n", |
| 245 | + "[==============================] 100% Epoch 56/100 - 0.01s - loss: 0.7825 \n", |
| 246 | + "[==============================] 100% Epoch 57/100 - 0.01s - loss: 0.7658 \n", |
| 247 | + "[==============================] 100% Epoch 58/100 - 0.01s - loss: 0.7502 \n", |
| 248 | + "[==============================] 100% Epoch 59/100 - 0.01s - loss: 0.7364 \n", |
| 249 | + "[==============================] 100% Epoch 60/100 - 0.01s - loss: 0.7243 \n", |
| 250 | + "[==============================] 100% Epoch 61/100 - 0.01s - loss: 0.7132 \n", |
| 251 | + "[==============================] 100% Epoch 62/100 - 0.01s - loss: 0.7031 \n", |
| 252 | + "[==============================] 100% Epoch 63/100 - 0.01s - loss: 0.6937 \n", |
| 253 | + "[==============================] 100% Epoch 64/100 - 0.01s - loss: 0.6846 \n", |
| 254 | + "[==============================] 100% Epoch 65/100 - 0.01s - loss: 0.6761 \n", |
| 255 | + "[==============================] 100% Epoch 66/100 - 0.01s - loss: 0.6674 \n", |
| 256 | + "[==============================] 100% Epoch 67/100 - 0.01s - loss: 0.6587 \n", |
| 257 | + "[==============================] 100% Epoch 68/100 - 0.01s - loss: 0.6508 \n", |
| 258 | + "[==============================] 100% Epoch 69/100 - 0.01s - loss: 0.6434 \n", |
| 259 | + "[==============================] 100% Epoch 70/100 - 0.01s - loss: 0.6365 \n", |
| 260 | + "[==============================] 100% Epoch 71/100 - 0.01s - loss: 0.6300 \n", |
| 261 | + "[==============================] 100% Epoch 72/100 - 0.01s - loss: 0.6237 \n", |
| 262 | + "[==============================] 100% Epoch 73/100 - 0.01s - loss: 0.6175 \n", |
| 263 | + "[==============================] 100% Epoch 74/100 - 0.01s - loss: 0.6110 \n", |
| 264 | + "[==============================] 100% Epoch 75/100 - 0.01s - loss: 0.6031 \n", |
| 265 | + "[==============================] 100% Epoch 76/100 - 0.01s - loss: 0.5917 \n", |
| 266 | + "[==============================] 100% Epoch 77/100 - 0.01s - loss: 0.5811 \n", |
| 267 | + "[==============================] 100% Epoch 78/100 - 0.01s - loss: 0.5711 \n", |
| 268 | + "[==============================] 100% Epoch 79/100 - 0.01s - loss: 0.5631 \n", |
| 269 | + "[==============================] 100% Epoch 80/100 - 0.01s - loss: 0.5569 \n", |
| 270 | + "[==============================] 100% Epoch 81/100 - 0.01s - loss: 0.5516 \n", |
| 271 | + "[==============================] 100% Epoch 82/100 - 0.01s - loss: 0.5470 \n", |
| 272 | + "[==============================] 100% Epoch 83/100 - 0.01s - loss: 0.5427 \n", |
| 273 | + "[==============================] 100% Epoch 84/100 - 0.01s - loss: 0.5384 \n", |
| 274 | + "[==============================] 100% Epoch 85/100 - 0.01s - loss: 0.5347 \n", |
| 275 | + "[==============================] 100% Epoch 86/100 - 0.01s - loss: 0.5312 \n", |
| 276 | + "[==============================] 100% Epoch 87/100 - 0.01s - loss: 0.5280 \n", |
| 277 | + "[==============================] 100% Epoch 88/100 - 0.01s - loss: 0.5252 \n", |
| 278 | + "[==============================] 100% Epoch 89/100 - 0.01s - loss: 0.5226 \n", |
| 279 | + "[==============================] 100% Epoch 90/100 - 0.01s - loss: 0.5202 \n", |
| 280 | + "[==============================] 100% Epoch 91/100 - 0.01s - loss: 0.5178 \n", |
| 281 | + "[==============================] 100% Epoch 92/100 - 0.01s - loss: 0.5156 \n", |
| 282 | + "[==============================] 100% Epoch 93/100 - 0.01s - loss: 0.5134 \n", |
| 283 | + "[==============================] 100% Epoch 94/100 - 0.01s - loss: 0.5112 \n", |
| 284 | + "[==============================] 100% Epoch 95/100 - 0.01s - loss: 0.5093 \n", |
| 285 | + "[==============================] 100% Epoch 96/100 - 0.01s - loss: 0.5075 \n", |
| 286 | + "[==============================] 100% Epoch 97/100 - 0.01s - loss: 0.5058 \n", |
| 287 | + "[==============================] 100% Epoch 98/100 - 0.01s - loss: 0.5041 \n", |
| 288 | + "[==============================] 100% Epoch 99/100 - 0.01s - loss: 0.5026 \n", |
| 289 | + "[==============================] 100% Epoch 100/100 - 0.01s - loss: 0.5011 \n", |
202 | 290 | "\n" |
203 | 291 | ] |
204 | 292 | }, |
205 | 293 | { |
206 | 294 | "data": { |
207 | 295 | "text/plain": [] |
208 | 296 | }, |
209 | | - "execution_count": 6, |
| 297 | + "execution_count": 15, |
210 | 298 | "metadata": {}, |
211 | 299 | "output_type": "execute_result" |
212 | 300 | } |
213 | 301 | ], |
214 | 302 | "source": [ |
215 | | - "model.fit(x_train, y_train, epochs=10, batch_size=32, random_state=42)" |
| 303 | + "model.fit(x_train, y_train, epochs=100, batch_size=32, random_state=42)" |
216 | 304 | ] |
217 | 305 | }, |
218 | 306 | { |
|
224 | 312 | }, |
225 | 313 | { |
226 | 314 | "cell_type": "code", |
227 | | - "execution_count": 7, |
| 315 | + "execution_count": 16, |
228 | 316 | "metadata": { |
229 | 317 | "ExecuteTime": { |
230 | 318 | "end_time": "2024-11-14T19:03:42.518566Z", |
|
236 | 324 | "name": "stdout", |
237 | 325 | "output_type": "stream", |
238 | 326 | "text": [ |
239 | | - "Test loss: 1.1136541600695817 function=MeanSquaredError\n" |
| 327 | + "Test loss: 1.7401693358245864 function=MeanSquaredError\n" |
240 | 328 | ] |
241 | 329 | } |
242 | 330 | ], |
|
254 | 342 | }, |
255 | 343 | { |
256 | 344 | "cell_type": "code", |
257 | | - "execution_count": 8, |
| 345 | + "execution_count": 17, |
258 | 346 | "metadata": { |
259 | 347 | "ExecuteTime": { |
260 | 348 | "end_time": "2024-11-14T19:03:42.519566900Z", |
|
266 | 354 | "name": "stdout", |
267 | 355 | "output_type": "stream", |
268 | 356 | "text": [ |
269 | | - "MAE: 0.8748635782918366\n" |
| 357 | + "MAE: 1.0799955616547592\n" |
270 | 358 | ] |
271 | 359 | } |
272 | 360 | ], |
273 | 361 | "source": [ |
274 | 362 | "y_pred = model.predict(x_test)\n", |
275 | 363 | "print(\"MAE: \", MeanAbsoluteError()(y_test, y_pred))" |
276 | 364 | ] |
277 | | - }, |
278 | | - { |
279 | | - "cell_type": "markdown", |
280 | | - "metadata": {}, |
281 | | - "source": [ |
282 | | - "## 8. Getting original MAE (without normalization from StandardScaler)" |
283 | | - ] |
284 | | - }, |
285 | | - { |
286 | | - "cell_type": "code", |
287 | | - "execution_count": 9, |
288 | | - "metadata": { |
289 | | - "ExecuteTime": { |
290 | | - "end_time": "2024-11-14T19:03:42.533072900Z", |
291 | | - "start_time": "2024-11-14T19:03:42.519566900Z" |
292 | | - } |
293 | | - }, |
294 | | - "outputs": [ |
295 | | - { |
296 | | - "name": "stdout", |
297 | | - "output_type": "stream", |
298 | | - "text": [ |
299 | | - "MAE (original): 65.9770599527072\n" |
300 | | - ] |
301 | | - } |
302 | | - ], |
303 | | - "source": [ |
304 | | - "y_pred_scaled = model.predict(x_test)\n", |
305 | | - "\n", |
306 | | - "y_pred = scaler_y.inverse_transform(y_pred_scaled.reshape(-1, 1)).flatten()\n", |
307 | | - "y_test_original = scaler_y.inverse_transform(y_test.reshape(-1, 1)).flatten()\n", |
308 | | - "\n", |
309 | | - "mae_original = np.mean(np.abs(y_test_original - y_pred))\n", |
310 | | - "print(f'MAE (original): {mae_original}')" |
311 | | - ] |
312 | 365 | } |
313 | 366 | ], |
314 | 367 | "metadata": { |
|
0 commit comments