|
2 | 2 | "cells": [ |
3 | 3 | { |
4 | 4 | "cell_type": "code", |
5 | | - "execution_count": 1, |
| 5 | + "execution_count": 11, |
6 | 6 | "id": "a036f9b8eee0491", |
7 | 7 | "metadata": { |
8 | 8 | "ExecuteTime": { |
|
13 | 13 | }, |
14 | 14 | "outputs": [], |
15 | 15 | "source": [ |
16 | | - "import numpy as np\n", |
17 | 16 | "from neuralnetlib.models import Transformer\n", |
18 | | - "from neuralnetlib.preprocessing import Tokenizer, pad_sequences" |
| 17 | + "from neuralnetlib.preprocessing import Tokenizer, pad_sequences\n", |
| 18 | + "from neuralnetlib.optimizers import Adam" |
19 | 19 | ] |
20 | 20 | }, |
21 | 21 | { |
22 | 22 | "cell_type": "code", |
23 | | - "execution_count": 2, |
| 23 | + "execution_count": 12, |
24 | 24 | "id": "be237a3421e586a2", |
25 | 25 | "metadata": { |
26 | 26 | "ExecuteTime": { |
|
50 | 50 | }, |
51 | 51 | { |
52 | 52 | "cell_type": "code", |
53 | | - "execution_count": 3, |
| 53 | + "execution_count": 13, |
54 | 54 | "id": "f4c0d8598f0ba7a", |
55 | 55 | "metadata": { |
56 | 56 | "ExecuteTime": { |
|
70 | 70 | }, |
71 | 71 | { |
72 | 72 | "cell_type": "code", |
73 | | - "execution_count": 4, |
| 73 | + "execution_count": 14, |
74 | 74 | "id": "67338439", |
75 | 75 | "metadata": {}, |
76 | 76 | "outputs": [], |
|
89 | 89 | }, |
90 | 90 | { |
91 | 91 | "cell_type": "code", |
92 | | - "execution_count": 5, |
| 92 | + "execution_count": 15, |
93 | 93 | "id": "5501a2c7", |
94 | 94 | "metadata": {}, |
95 | 95 | "outputs": [ |
|
126 | 126 | }, |
127 | 127 | { |
128 | 128 | "cell_type": "code", |
129 | | - "execution_count": 6, |
| 129 | + "execution_count": 16, |
130 | 130 | "id": "68d2884d", |
131 | 131 | "metadata": {}, |
132 | 132 | "outputs": [ |
|
136 | 136 | "text": [ |
137 | 137 | "Transformer(\n", |
138 | 138 | " vocab_size=24,\n", |
139 | | - " d_model=128,\n", |
140 | | - " n_heads=4,\n", |
141 | | - " n_encoder_layers=2,\n", |
142 | | - " n_decoder_layers=2,\n", |
143 | | - " d_ff=256,\n", |
144 | | - " dropout_rate=0.2,\n", |
145 | | - " max_sequence_length=5\n", |
| 139 | + " d_model=64,\n", |
| 140 | + " n_heads=2,\n", |
| 141 | + " n_encoder_layers=1,\n", |
| 142 | + " n_decoder_layers=1,\n", |
| 143 | + " d_ff=128,\n", |
| 144 | + " dropout_rate=0.1,\n", |
| 145 | + " max_sequence_length=512\n", |
146 | 146 | ")\n" |
147 | 147 | ] |
148 | 148 | } |
149 | 149 | ], |
150 | 150 | "source": [ |
151 | 151 | "model = Transformer(\n", |
152 | 152 | " vocab_size=max_vocab_size,\n", |
153 | | - " d_model=128,\n", |
154 | | - " n_heads=4,\n", |
155 | | - " n_encoder_layers=2,\n", |
156 | | - " n_decoder_layers=2,\n", |
157 | | - " d_ff=256,\n", |
158 | | - " dropout_rate=0.2,\n", |
159 | | - " max_sequence_length=max_seq_len,\n", |
160 | | - " temperature=0.7,\n", |
| 153 | + " d_model=64,\n", |
| 154 | + " n_heads=2,\n", |
| 155 | + " n_encoder_layers=1,\n", |
| 156 | + " n_decoder_layers=1,\n", |
| 157 | + " d_ff=128,\n", |
| 158 | + " dropout_rate=0.1,\n", |
| 159 | + " temperature=1.0,\n", |
161 | 160 | " random_state=42\n", |
162 | 161 | ")\n", |
163 | 162 | "\n", |
164 | 163 | "\n", |
165 | 164 | "model.compile(\n", |
166 | 165 | " loss_function='sequencecrossentropy',\n", |
167 | | - " optimizer='adam',\n", |
| 166 | + " optimizer=Adam(learning_rate=0.001),\n", |
168 | 167 | " verbose=True\n", |
169 | 168 | ")" |
170 | 169 | ] |
171 | 170 | }, |
172 | 171 | { |
173 | 172 | "cell_type": "code", |
174 | | - "execution_count": 7, |
| 173 | + "execution_count": 17, |
175 | 174 | "id": "845375dc", |
176 | 175 | "metadata": {}, |
177 | 176 | "outputs": [], |
|
181 | 180 | }, |
182 | 181 | { |
183 | 182 | "cell_type": "code", |
184 | | - "execution_count": 8, |
| 183 | + "execution_count": 18, |
185 | 184 | "id": "e3bdab93", |
186 | 185 | "metadata": {}, |
187 | 186 | "outputs": [ |
188 | 187 | { |
189 | 188 | "name": "stdout", |
190 | 189 | "output_type": "stream", |
191 | 190 | "text": [ |
192 | | - "[==============================] 100% Epoch 1/50 - loss: 12.6405 - - 0.07s\n", |
193 | | - "[==============================] 100% Epoch 2/50 - loss: 8.8913 - - 0.04s\n", |
194 | | - "[==============================] 100% Epoch 3/50 - loss: 5.5905 - - 0.03s\n", |
195 | | - "[==============================] 100% Epoch 4/50 - loss: 1.8309 - - 0.03s\n", |
196 | | - "[==============================] 100% Epoch 5/50 - loss: 1.3206 - - 0.03s\n", |
197 | | - "[==============================] 100% Epoch 6/50 - loss: 0.0618 - - 0.03s\n", |
198 | | - "[==============================] 100% Epoch 7/50 - loss: 0.0073 - - 0.04s\n", |
199 | | - "[==============================] 100% Epoch 8/50 - loss: 0.0071 - - 0.03s\n", |
200 | | - "[==============================] 100% Epoch 9/50 - loss: 0.0077 - - 0.03s\n", |
201 | | - "[==============================] 100% Epoch 10/50 - loss: 0.0088 - - 0.03s\n", |
202 | | - "[==============================] 100% Epoch 11/50 - loss: 0.0137 - - 0.03s\n", |
203 | | - "[==============================] 100% Epoch 12/50 - loss: 0.0133 - - 0.03s\n", |
204 | | - "[==============================] 100% Epoch 13/50 - loss: 0.0125 - - 0.03s\n", |
205 | | - "[==============================] 100% Epoch 14/50 - loss: 0.0065 - - 0.03s\n", |
206 | | - "[==============================] 100% Epoch 15/50 - loss: 0.0057 - - 0.03s\n", |
207 | | - "[==============================] 100% Epoch 16/50 - loss: 0.0051 - - 0.03s\n", |
208 | | - "[==============================] 100% Epoch 17/50 - loss: 0.0045 - - 0.03s\n", |
209 | | - "[==============================] 100% Epoch 18/50 - loss: 0.0040 - - 0.03s\n", |
210 | | - "[==============================] 100% Epoch 19/50 - loss: 0.0036 - - 0.03s\n", |
211 | | - "[==============================] 100% Epoch 20/50 - loss: 0.0033 - - 0.04s\n", |
212 | | - "[==============================] 100% Epoch 21/50 - loss: 0.0030 - - 0.04s\n", |
213 | | - "[==============================] 100% Epoch 22/50 - loss: 0.0027 - - 0.03s\n", |
214 | | - "[==============================] 100% Epoch 23/50 - loss: 0.0025 - - 0.03s\n", |
215 | | - "[==============================] 100% Epoch 24/50 - loss: 0.0023 - - 0.03s\n", |
216 | | - "[==============================] 100% Epoch 25/50 - loss: 0.0021 - - 0.04s\n", |
217 | | - "[==============================] 100% Epoch 26/50 - loss: 0.0020 - - 0.04s\n", |
218 | | - "[==============================] 100% Epoch 27/50 - loss: 0.0018 - - 0.03s\n", |
219 | | - "[==============================] 100% Epoch 28/50 - loss: 0.0017 - - 0.04s\n", |
220 | | - "[==============================] 100% Epoch 29/50 - loss: 0.0017 - - 0.06s\n", |
221 | | - "[==============================] 100% Epoch 30/50 - loss: 0.0018 - - 0.03s\n", |
222 | | - "[==============================] 100% Epoch 31/50 - loss: 0.0020 - - 0.03s\n", |
223 | | - "[==============================] 100% Epoch 32/50 - loss: 0.0024 - - 0.03s\n", |
224 | | - "[==============================] 100% Epoch 33/50 - loss: 0.0030 - - 0.03s\n", |
225 | | - "[==============================] 100% Epoch 34/50 - loss: 0.0086 - - 0.03s\n", |
226 | | - "[==============================] 100% Epoch 35/50 - loss: 0.0030 - - 0.03s\n", |
227 | | - "[==============================] 100% Epoch 36/50 - loss: 0.0030 - - 0.03s\n", |
228 | | - "[==============================] 100% Epoch 37/50 - loss: 0.0079 - - 0.03s\n", |
229 | | - "[==============================] 100% Epoch 38/50 - loss: 0.0032 - - 0.03s\n", |
230 | | - "[==============================] 100% Epoch 39/50 - loss: 0.0035 - - 0.03s\n", |
231 | | - "[==============================] 100% Epoch 40/50 - loss: 0.0043 - - 0.04s\n", |
232 | | - "[==============================] 100% Epoch 41/50 - loss: 0.0093 - - 0.03s\n", |
233 | | - "[==============================] 100% Epoch 42/50 - loss: 0.0043 - - 0.03s\n", |
234 | | - "[==============================] 100% Epoch 43/50 - loss: 0.0042 - - 0.03s\n", |
235 | | - "[==============================] 100% Epoch 44/50 - loss: 0.0044 - - 0.03s\n", |
236 | | - "[==============================] 100% Epoch 45/50 - loss: 0.0047 - - 0.03s\n", |
237 | | - "[==============================] 100% Epoch 46/50 - loss: 0.0093 - - 0.03s\n", |
238 | | - "[==============================] 100% Epoch 47/50 - loss: 0.0039 - - 0.03s\n", |
239 | | - "[==============================] 100% Epoch 48/50 - loss: 0.0034 - - 0.03s\n", |
240 | | - "[==============================] 100% Epoch 49/50 - loss: 0.0032 - - 0.04s\n", |
241 | | - "[==============================] 100% Epoch 50/50 - loss: 0.0030 - - 0.04s\n", |
| 191 | + "[==============================] 100% Epoch 1/50 - loss: 13.4469 - - 0.47s\n", |
| 192 | + "[==============================] 100% Epoch 2/50 - loss: 11.0863 - - 0.47s\n", |
| 193 | + "[==============================] 100% Epoch 3/50 - loss: 5.6369 - - 0.45s\n", |
| 194 | + "[==============================] 100% Epoch 4/50 - loss: 3.0461 - - 0.47s\n", |
| 195 | + "[==============================] 100% Epoch 5/50 - loss: 1.0563 - - 0.47s\n", |
| 196 | + "[==============================] 100% Epoch 6/50 - loss: 0.0950 - - 0.47s\n", |
| 197 | + "[==============================] 100% Epoch 7/50 - loss: 0.0527 - - 0.48s\n", |
| 198 | + "[==============================] 100% Epoch 8/50 - loss: 0.0339 - - 0.48s\n", |
| 199 | + "[==============================] 100% Epoch 9/50 - loss: 0.0211 - - 0.46s\n", |
| 200 | + "[==============================] 100% Epoch 10/50 - loss: 0.0163 - - 0.46s\n", |
| 201 | + "[==============================] 100% Epoch 11/50 - loss: 0.0121 - - 0.47s\n", |
| 202 | + "[==============================] 100% Epoch 12/50 - loss: 0.0085 - - 0.47s\n", |
| 203 | + "[==============================] 100% Epoch 13/50 - loss: 0.0055 - - 0.47s\n", |
| 204 | + "[==============================] 100% Epoch 14/50 - loss: 0.0044 - - 0.46s\n", |
| 205 | + "[==============================] 100% Epoch 15/50 - loss: 0.0047 - - 0.46s\n", |
| 206 | + "[==============================] 100% Epoch 16/50 - loss: 0.0053 - - 0.45s\n", |
| 207 | + "[==============================] 100% Epoch 17/50 - loss: 0.0060 - - 0.47s\n", |
| 208 | + "[==============================] 100% Epoch 18/50 - loss: 0.0069 - - 0.45s\n", |
| 209 | + "[==============================] 100% Epoch 19/50 - loss: 0.0079 - - 0.47s\n", |
| 210 | + "[==============================] 100% Epoch 20/50 - loss: 0.0090 - - 0.46s\n", |
| 211 | + "[==============================] 100% Epoch 21/50 - loss: 0.0102 - - 0.46s\n", |
| 212 | + "[==============================] 100% Epoch 22/50 - loss: 0.0118 - - 0.51s\n", |
| 213 | + "[==============================] 100% Epoch 23/50 - loss: 0.0164 - - 0.50s\n", |
| 214 | + "[==============================] 100% Epoch 24/50 - loss: 0.0209 - - 0.49s\n", |
| 215 | + "[==============================] 100% Epoch 25/50 - loss: 0.0248 - - 0.53s\n", |
| 216 | + "[==============================] 100% Epoch 26/50 - loss: 0.0269 - - 0.54s\n", |
| 217 | + "[==============================] 100% Epoch 27/50 - loss: 0.0124 - - 0.56s\n", |
| 218 | + "[==============================] 100% Epoch 28/50 - loss: 0.0110 - - 0.53s\n", |
| 219 | + "[==============================] 100% Epoch 29/50 - loss: 0.0099 - - 0.49s\n", |
| 220 | + "[==============================] 100% Epoch 30/50 - loss: 0.0089 - - 0.48s\n", |
| 221 | + "[==============================] 100% Epoch 31/50 - loss: 0.0077 - - 0.52s\n", |
| 222 | + "[==============================] 100% Epoch 32/50 - loss: 0.0066 - - 0.55s\n", |
| 223 | + "[==============================] 100% Epoch 33/50 - loss: 0.0053 - - 0.55s\n", |
| 224 | + "[==============================] 100% Epoch 34/50 - loss: 0.0040 - - 0.50s\n", |
| 225 | + "[==============================] 100% Epoch 35/50 - loss: 0.0036 - - 0.50s\n", |
| 226 | + "[==============================] 100% Epoch 36/50 - loss: 0.0036 - - 0.49s\n", |
| 227 | + "[==============================] 100% Epoch 37/50 - loss: 0.0035 - - 0.49s\n", |
| 228 | + "[==============================] 100% Epoch 38/50 - loss: 0.0034 - - 0.53s\n", |
| 229 | + "[==============================] 100% Epoch 39/50 - loss: 0.0033 - - 0.54s\n", |
| 230 | + "[==============================] 100% Epoch 40/50 - loss: 0.0032 - - 0.50s\n", |
| 231 | + "[==============================] 100% Epoch 41/50 - loss: 0.0031 - - 0.49s\n", |
| 232 | + "[==============================] 100% Epoch 42/50 - loss: 0.0029 - - 0.49s\n", |
| 233 | + "[==============================] 100% Epoch 43/50 - loss: 0.0028 - - 0.52s\n", |
| 234 | + "[==============================] 100% Epoch 44/50 - loss: 0.0026 - - 0.54s\n", |
| 235 | + "[==============================] 100% Epoch 45/50 - loss: 0.0025 - - 0.56s\n", |
| 236 | + "[==============================] 100% Epoch 46/50 - loss: 0.0023 - - 0.48s\n", |
| 237 | + "[==============================] 100% Epoch 47/50 - loss: 0.0022 - - 0.48s\n", |
| 238 | + "[==============================] 100% Epoch 48/50 - loss: 0.0021 - - 0.51s\n", |
| 239 | + "[==============================] 100% Epoch 49/50 - loss: 0.0021 - - 0.51s\n", |
| 240 | + "[==============================] 100% Epoch 50/50 - loss: 0.0020 - - 0.56s\n", |
242 | 241 | "\n" |
243 | 242 | ] |
244 | 243 | } |
|
247 | 246 | "history = model.fit(\n", |
248 | 247 | " x_train_padded, y_train_padded,\n", |
249 | 248 | " epochs=50,\n", |
250 | | - " batch_size=5,\n", |
| 249 | + " batch_size=2,\n", |
251 | 250 | " verbose=True\n", |
252 | 251 | ")" |
253 | 252 | ] |
254 | 253 | }, |
255 | 254 | { |
256 | 255 | "cell_type": "code", |
257 | | - "execution_count": 9, |
| 256 | + "execution_count": 19, |
258 | 257 | "id": "c1dc335b", |
259 | 258 | "metadata": {}, |
260 | 259 | "outputs": [ |
|
263 | 262 | "output_type": "stream", |
264 | 263 | "text": [ |
265 | 264 | "FR: je suis heureux.\n", |
266 | | - "EN: goodbye goodbye goodbye goodbye\n", |
| 265 | + "EN: you cats cats\n", |
267 | 266 | "\n", |
268 | 267 | "FR: comment allez-vous ?\n", |
269 | | - "EN: goodbye goodbye goodbye goodbye\n", |
| 268 | + "EN: you cats cats\n", |
270 | 269 | "\n", |
271 | 270 | "FR: bonjour le monde.\n", |
272 | | - "EN: goodbye goodbye goodbye goodbye\n", |
| 271 | + "EN: you cats cats\n", |
273 | 272 | "\n" |
274 | 273 | ] |
275 | 274 | } |
276 | 275 | ], |
277 | 276 | "source": [ |
278 | 277 | "def translate(sentence: str, model, fr_tokenizer, en_tokenizer) -> str:\n", |
279 | 278 | " tokens = fr_tokenizer.texts_to_sequences([sentence], preprocess_ponctuation=True)[0]\n", |
280 | | - " tokens = [model.SOS_IDX] + [t + 4 for t in tokens] + [model.EOS_IDX] # Shift indices by 4\n", |
| 279 | + " tokens = [model.SOS_IDX] + tokens + [model.EOS_IDX]\n", |
281 | 280 | " padded = pad_sequences([tokens], max_length=max_len_x, padding='post', pad_value=model.PAD_IDX)\n", |
282 | 281 | " \n", |
283 | 282 | " pred = model.predict(padded, max_length=max_seq_len)[0]\n", |
|
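
For readers skimming this change outside the notebook, below is a minimal sketch of how the updated cells fit together after this commit. It reuses only the calls visible in the diff (Transformer, compile with Adam, fit, predict, pad_sequences, texts_to_sequences); the tokenizer fitting, the construction of x_train_padded / y_train_padded, the values of max_vocab_size / max_seq_len / max_len_x, and the decoding of predicted ids back to words live in collapsed or truncated cells, so those parts are left as clearly marked assumptions rather than filled in.

from neuralnetlib.models import Transformer
from neuralnetlib.preprocessing import Tokenizer, pad_sequences
from neuralnetlib.optimizers import Adam

# Assumption: the collapsed cells build fr_tokenizer / en_tokenizer and the
# padded training pairs; that preprocessing is not shown in this diff.
# fr_tokenizer, en_tokenizer = ...
# x_train_padded, y_train_padded = ...
# max_vocab_size, max_seq_len, max_len_x = ...

# Smaller model than before this commit: d_model 128 -> 64, 2 -> 1 encoder and
# decoder layers, d_ff 256 -> 128, dropout 0.2 -> 0.1, and the explicit
# max_sequence_length argument is dropped (the printed repr shows 512).
model = Transformer(
    vocab_size=max_vocab_size,
    d_model=64,
    n_heads=2,
    n_encoder_layers=1,
    n_decoder_layers=1,
    d_ff=128,
    dropout_rate=0.1,
    temperature=1.0,
    random_state=42,
)

# The optimizer is now passed as an object instead of the 'adam' string.
model.compile(
    loss_function='sequencecrossentropy',
    optimizer=Adam(learning_rate=0.001),
    verbose=True,
)

history = model.fit(
    x_train_padded, y_train_padded,
    epochs=50,
    batch_size=2,  # was 5 before this commit
    verbose=True,
)

def translate(sentence: str, model, fr_tokenizer, en_tokenizer) -> str:
    # Token ids are now used as-is; the old "+ 4" index shift was removed.
    tokens = fr_tokenizer.texts_to_sequences([sentence], preprocess_ponctuation=True)[0]
    tokens = [model.SOS_IDX] + tokens + [model.EOS_IDX]
    padded = pad_sequences([tokens], max_length=max_len_x, padding='post', pad_value=model.PAD_IDX)
    pred = model.predict(padded, max_length=max_seq_len)[0]
    # The rest of this cell is truncated in the diff; mapping pred back to
    # English words (presumably via en_tokenizer) is an assumption here.
    ...

As the new sample outputs show ("you cats cats" for every French input), the retrained, smaller model still collapses to a repeated token on this tiny dataset, so the sketch above is a workflow illustration rather than a recipe for a usable translator.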