|
68 | 68 | }, |
69 | 69 | { |
70 | 70 | "cell_type": "code", |
71 | | - "execution_count": null, |
| 71 | + "execution_count": 3, |
72 | 72 | "metadata": { |
73 | 73 | "ExecuteTime": { |
74 | 74 | "end_time": "2024-11-14T19:16:46.274852700Z", |
|
84 | 84 | "x_test shape: (25000, 200)\n", |
85 | 85 | "y_train shape: (20000,)\n", |
86 | 86 | "y_test shape: (25000,)\n", |
87 | | - "x_train[0]: [4.500e+01 1.080e+02 1.000e+01 1.000e+01 1.100e+01 4.000e+00 6.500e+01\n", |
88 | | - " 3.960e+03 9.000e+00 1.100e+01 4.100e+01 4.020e+02 2.000e+00 7.800e+02\n", |
89 | | - " 3.300e+01 2.000e+00 6.130e+03 1.100e+01 2.000e+00 4.000e+00 2.763e+03\n", |
90 | | - " 8.440e+02 2.600e+01 2.000e+00 2.240e+02 5.000e+00 1.930e+02 3.960e+03\n", |
91 | | - " 3.900e+01 4.400e+01 7.900e+02 1.530e+02 1.540e+02 1.430e+02 4.100e+01\n", |
92 | | - " 2.521e+03 5.600e+01 8.000e+00 4.100e+01 2.028e+03 5.590e+02 1.100e+01\n", |
93 | | - " 4.000e+00 2.000e+01 4.400e+01 6.383e+03 5.284e+03 4.740e+02 4.820e+02\n", |
94 | | - " 1.300e+01 6.600e+01 9.200e+01 1.040e+02 2.250e+02 6.000e+00 4.040e+02\n", |
95 | | - " 5.240e+02 1.800e+01 3.960e+03 1.800e+01 1.110e+02 7.000e+00 1.780e+02\n", |
96 | | - " 3.960e+03 4.510e+02 4.420e+02 7.600e+01 9.900e+01 9.760e+02 6.000e+00\n", |
97 | | - " 1.369e+03 1.100e+01 2.630e+02 2.000e+00 4.600e+02 8.519e+03 2.000e+00\n", |
98 | | - " 9.000e+00 3.084e+03 5.900e+01 9.000e+00 5.500e+01 7.207e+03 2.000e+00\n", |
99 | | - " 5.000e+00 2.000e+00 5.900e+01 4.700e+01 7.750e+02 7.000e+00 9.963e+03\n", |
100 | | - " 5.900e+01 4.700e+01 6.000e+00 8.700e+01 3.930e+02 3.100e+01 1.500e+01\n", |
101 | | - " 3.775e+03 1.100e+01 1.290e+02 3.300e+02 7.300e+01 1.030e+02 4.000e+00\n", |
102 | | - " 2.000e+01 9.000e+00 1.200e+02 1.793e+03 8.000e+00 2.000e+00 2.000e+00\n", |
103 | | - " 2.000e+00 5.071e+03 3.960e+03 4.700e+01 2.470e+02 6.000e+00 5.879e+03\n", |
104 | | - " 8.220e+02 7.400e+01 2.000e+00 2.100e+01 1.460e+02 1.688e+03 8.000e+00\n", |
105 | | - " 4.909e+03 1.500e+01 4.800e+01 2.000e+00 1.999e+03 1.100e+01 4.000e+00\n", |
106 | | - " 2.170e+02 1.300e+01 1.040e+02 5.900e+01 8.000e+01 2.700e+03 8.300e+01\n", |
107 | | - " 1.200e+01 4.300e+01 1.700e+01 3.960e+03 3.418e+03 5.300e+01 9.760e+02\n", |
108 | | - " 5.000e+00 6.861e+03 1.700e+01 5.900e+01 2.140e+02 9.220e+02 2.000e+00\n", |
109 | | - " 4.600e+02 5.603e+03 2.000e+00 4.860e+02 5.000e+00 1.557e+03 2.000e+00\n", |
110 | | - " 5.500e+01 7.300e+01 1.400e+02 1.404e+03 5.000e+00 8.510e+02 1.400e+01\n", |
111 | | - " 2.000e+01 4.500e+01 2.400e+01 4.000e+01 2.330e+02 3.340e+02 8.740e+02\n", |
112 | | - " 1.100e+02 5.000e+00 6.000e+01 1.510e+02 1.200e+01 1.600e+01 5.260e+02\n", |
113 | | - " 3.400e+01 1.091e+03 2.000e+00 1.200e+01 9.000e+00 3.680e+02 7.000e+00\n", |
114 | | - " 2.000e+00 2.442e+03 8.700e+01 3.700e+02 1.102e+03 1.524e+03 5.000e+00\n", |
115 | | - " 7.300e+01 2.240e+02 2.060e+02 8.440e+02]\n", |
| 87 | + "x_train[0]: [ 45 108 10 10 11 4 65 3960 9 11 41 402 2 780\n", |
| 88 | + " 33 2 6130 11 2 4 2763 844 26 2 224 5 193 3960\n", |
| 89 | + " 39 44 790 153 154 143 41 2521 56 8 41 2028 559 11\n", |
| 90 | + " 4 20 44 6383 5284 474 482 13 66 92 104 225 6 404\n", |
| 91 | + " 524 18 3960 18 111 7 178 3960 451 442 76 99 976 6\n", |
| 92 | + " 1369 11 263 2 460 8519 2 9 3084 59 9 55 7207 2\n", |
| 93 | + " 5 2 59 47 775 7 9963 59 47 6 87 393 31 15\n", |
| 94 | + " 3775 11 129 330 73 103 4 20 9 120 1793 8 2 2\n", |
| 95 | + " 2 5071 3960 47 247 6 5879 822 74 2 21 146 1688 8\n", |
| 96 | + " 4909 15 48 2 1999 11 4 217 13 104 59 80 2700 83\n", |
| 97 | + " 12 43 17 3960 3418 53 976 5 6861 17 59 214 922 2\n", |
| 98 | + " 460 5603 2 486 5 1557 2 55 73 140 1404 5 851 14\n", |
| 99 | + " 20 45 24 40 233 334 874 110 5 60 151 12 16 526\n", |
| 100 | + " 34 1091 2 12 9 368 7 2 2442 87 370 1102 1524 5\n", |
| 101 | + " 73 224 206 844]\n", |
116 | 102 | "y_train[0]: 1\n" |
117 | 103 | ] |
118 | 104 | } |
|
184 | 170 | "name": "stdout", |
185 | 171 | "output_type": "stream", |
186 | 172 | "text": [ |
187 | | - "Sequential(temperature=1.0, gradient_clip_threshold=5.0, enable_padding=False, padding_size=32, random_state=1731611806261338000)\n", |
| 173 | + "Sequential(gradient_clip_threshold=5.0, enable_padding=False, padding_size=32, random_state=1733520050429276600)\n", |
188 | 174 | "-------------------------------------------------\n", |
189 | 175 | "Layer 1: Input(input_shape=(200,))\n", |
190 | 176 | "Layer 2: Embedding(input_dim=10000, output_dim=100)\n", |
191 | | - "Layer 3: Bidirectional(layer=LSTM(units=32, return_sequences=True, return_state=False, random_state=None, clip_value=5.0))\n", |
| 177 | + "Layer 3: Bidirectional(layer=LSTM(units=32, return_sequences=True, return_state=False, clip_value=5.0, random_state=None))\n", |
192 | 178 | "Layer 4: Attention(use_scale=True, score_mode=dot, return_sequences=False)\n", |
193 | 179 | "Layer 5: Dense(units=1)\n", |
194 | 180 | "Layer 6: Activation(Sigmoid)\n", |
195 | 181 | "-------------------------------------------------\n", |
196 | 182 | "Loss function: BinaryCrossentropy\n", |
197 | 183 | "Optimizer: Adam(learning_rate=0.001, beta_1=0.9, beta_2=0.999, epsilon=1e-08, clip_norm=None, clip_value=None)\n", |
198 | | - "-------------------------------------------------\n" |
| 184 | + "-------------------------------------------------\n", |
| 185 | + "\n" |
199 | 186 | ] |
200 | 187 | } |
201 | 188 | ], |
|
214 | 201 | }, |
215 | 202 | { |
216 | 203 | "cell_type": "code", |
217 | | - "execution_count": 6, |
| 204 | + "execution_count": null, |
218 | 205 | "metadata": { |
219 | 206 | "ExecuteTime": { |
220 | 207 | "end_time": "2024-11-14T19:49:32.469776300Z", |
|
226 | 213 | "name": "stdout", |
227 | 214 | "output_type": "stream", |
228 | 215 | "text": [ |
229 | | - "[==============================] 100% Epoch 1/10 - loss: 0.4515 - accuracy: 0.8001 - 149.41s - val_accuracy: 0.8397\n", |
230 | | - "[==============================] 100% Epoch 2/10 - loss: 0.2670 - accuracy: 0.8926 - 134.93s - val_accuracy: 0.8542\n", |
231 | | - "[==============================] 100% Epoch 3/10 - loss: 0.2317 - accuracy: 0.9129 - 136.07s - val_accuracy: 0.8512\n", |
232 | | - "[==============================] 100% Epoch 4/10 - loss: 0.2437 - accuracy: 0.9196 - 133.42s - val_accuracy: 0.8383\n", |
233 | | - "[==============================] 100% Epoch 5/10 - loss: 0.2506 - accuracy: 0.9280 - 135.57s - val_accuracy: 0.8449\n", |
234 | | - "[==============================] 100% Epoch 6/10 - loss: 0.2753 - accuracy: 0.9333 - 138.26s - val_accuracy: 0.8346\n", |
235 | | - "[==============================] 100% Epoch 7/10 - loss: 0.3047 - accuracy: 0.9371 - 141.18s - val_accuracy: 0.8236\n", |
236 | | - "[==============================] 100% Epoch 8/10 - loss: 0.3261 - accuracy: 0.9405 - 140.46s - val_accuracy: 0.8178\n", |
237 | | - "[==============================] 100% Epoch 9/10 - loss: 0.3593 - accuracy: 0.9459 - 135.74s - val_accuracy: 0.8236\n", |
238 | | - "[==============================] 100% Epoch 10/10 - loss: 0.3402 - accuracy: 0.9528 - 144.90s - val_accuracy: 0.8296\n" |
| 216 | + "[==============================] 100% Epoch 1/5 - 274.36s - loss: 0.6359 - accuracy: 0.6967 - val_loss: 0.7389 - val_accuracy: 0.7740\n", |
| 217 | + "[==============================] 100% Epoch 2/5 - 276.99s - loss: 0.4441 - accuracy: 0.8237 - val_loss: 1.0205 - val_accuracy: 0.8307\n", |
| 218 | + "[==============================] 100% Epoch 3/5 - 285.17s - loss: 0.3278 - accuracy: 0.8611 - val_loss: 1.4672 - val_accuracy: 0.8485\n", |
| 219 | + "[==============================] 100% Epoch 4/5 - 269.14s - loss: 0.2853 - accuracy: 0.8797 - val_loss: 1.9860 - val_accuracy: 0.8568\n", |
| 220 | + "[==============================] 100% Epoch 5/5 - 267.17s - loss: 0.2713 - accuracy: 0.8895 - val_loss: 2.5888 - val_accuracy: 0.8598\n", |
| 221 | + "\n" |
239 | 222 | ] |
240 | | - }, |
241 | | - { |
242 | | - "data": { |
243 | | - "text/plain": [] |
244 | | - }, |
245 | | - "execution_count": 6, |
246 | | - "metadata": {}, |
247 | | - "output_type": "execute_result" |
248 | 223 | } |
249 | 224 | ], |
250 | 225 | "source": [ |
251 | | - "model.fit(x_train, y_train, epochs=10, batch_size=32, validation_data=(x_test, y_test), metrics=['accuracy'], random_state=42)" |
| 226 | + "model.fit(x_train, y_train, epochs=5, batch_size=32, validation_data=(x_test, y_test), metrics=['accuracy'], random_state=42)" |
252 | 227 | ] |
253 | 228 | }, |
254 | 229 | { |
|
260 | 235 | }, |
261 | 236 | { |
262 | 237 | "cell_type": "code", |
263 | | - "execution_count": 7, |
| 238 | + "execution_count": 9, |
264 | 239 | "metadata": { |
265 | 240 | "ExecuteTime": { |
266 | 241 | "end_time": "2024-11-14T19:49:59.259717400Z", |
|
272 | 247 | "name": "stdout", |
273 | 248 | "output_type": "stream", |
274 | 249 | "text": [ |
275 | | - "Loss: 11.780504039134605\n", |
276 | | - "Accuracy: 0.831\n" |
| 250 | + "Loss: 2.6114417790014\n", |
| 251 | + "Accuracy: 0.8712\n" |
277 | 252 | ] |
278 | 253 | } |
279 | 254 | ], |
|
0 commit comments