Skip to content

Commit 937d7c3

Browse files
committed
fix(Transformer): tokenization, sequence handling and shapes
1 parent 1ddaa82 commit 937d7c3

File tree

5 files changed

+278
-178
lines changed

5 files changed

+278
-178
lines changed

examples/generation/transformer-text-generation/transformer-for-translation.ipynb

Lines changed: 81 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
"cells": [
33
{
44
"cell_type": "code",
5-
"execution_count": 1,
5+
"execution_count": 11,
66
"id": "a036f9b8eee0491",
77
"metadata": {
88
"ExecuteTime": {
@@ -13,14 +13,14 @@
1313
},
1414
"outputs": [],
1515
"source": [
16-
"import numpy as np\n",
1716
"from neuralnetlib.models import Transformer\n",
18-
"from neuralnetlib.preprocessing import Tokenizer, pad_sequences"
17+
"from neuralnetlib.preprocessing import Tokenizer, pad_sequences\n",
18+
"from neuralnetlib.optimizers import Adam"
1919
]
2020
},
2121
{
2222
"cell_type": "code",
23-
"execution_count": 2,
23+
"execution_count": 12,
2424
"id": "be237a3421e586a2",
2525
"metadata": {
2626
"ExecuteTime": {
@@ -50,7 +50,7 @@
5050
},
5151
{
5252
"cell_type": "code",
53-
"execution_count": 3,
53+
"execution_count": 13,
5454
"id": "f4c0d8598f0ba7a",
5555
"metadata": {
5656
"ExecuteTime": {
@@ -70,7 +70,7 @@
7070
},
7171
{
7272
"cell_type": "code",
73-
"execution_count": 4,
73+
"execution_count": 14,
7474
"id": "67338439",
7575
"metadata": {},
7676
"outputs": [],
@@ -89,7 +89,7 @@
8989
},
9090
{
9191
"cell_type": "code",
92-
"execution_count": 5,
92+
"execution_count": 15,
9393
"id": "5501a2c7",
9494
"metadata": {},
9595
"outputs": [
@@ -126,7 +126,7 @@
126126
},
127127
{
128128
"cell_type": "code",
129-
"execution_count": 6,
129+
"execution_count": 16,
130130
"id": "68d2884d",
131131
"metadata": {},
132132
"outputs": [
@@ -136,42 +136,41 @@
136136
"text": [
137137
"Transformer(\n",
138138
" vocab_size=24,\n",
139-
" d_model=128,\n",
140-
" n_heads=4,\n",
141-
" n_encoder_layers=2,\n",
142-
" n_decoder_layers=2,\n",
143-
" d_ff=256,\n",
144-
" dropout_rate=0.2,\n",
145-
" max_sequence_length=5\n",
139+
" d_model=64,\n",
140+
" n_heads=2,\n",
141+
" n_encoder_layers=1,\n",
142+
" n_decoder_layers=1,\n",
143+
" d_ff=128,\n",
144+
" dropout_rate=0.1,\n",
145+
" max_sequence_length=512\n",
146146
")\n"
147147
]
148148
}
149149
],
150150
"source": [
151151
"model = Transformer(\n",
152152
" vocab_size=max_vocab_size,\n",
153-
" d_model=128,\n",
154-
" n_heads=4,\n",
155-
" n_encoder_layers=2,\n",
156-
" n_decoder_layers=2,\n",
157-
" d_ff=256,\n",
158-
" dropout_rate=0.2,\n",
159-
" max_sequence_length=max_seq_len,\n",
160-
" temperature=0.7,\n",
153+
" d_model=64,\n",
154+
" n_heads=2,\n",
155+
" n_encoder_layers=1,\n",
156+
" n_decoder_layers=1,\n",
157+
" d_ff=128,\n",
158+
" dropout_rate=0.1,\n",
159+
" temperature=1.0,\n",
161160
" random_state=42\n",
162161
")\n",
163162
"\n",
164163
"\n",
165164
"model.compile(\n",
166165
" loss_function='sequencecrossentropy',\n",
167-
" optimizer='adam',\n",
166+
" optimizer=Adam(learning_rate=0.001),\n",
168167
" verbose=True\n",
169168
")"
170169
]
171170
},
172171
{
173172
"cell_type": "code",
174-
"execution_count": 7,
173+
"execution_count": 17,
175174
"id": "845375dc",
176175
"metadata": {},
177176
"outputs": [],
@@ -181,64 +180,64 @@
181180
},
182181
{
183182
"cell_type": "code",
184-
"execution_count": 8,
183+
"execution_count": 18,
185184
"id": "e3bdab93",
186185
"metadata": {},
187186
"outputs": [
188187
{
189188
"name": "stdout",
190189
"output_type": "stream",
191190
"text": [
192-
"[==============================] 100% Epoch 1/50 - loss: 12.6405 - - 0.07s\n",
193-
"[==============================] 100% Epoch 2/50 - loss: 8.8913 - - 0.04s\n",
194-
"[==============================] 100% Epoch 3/50 - loss: 5.5905 - - 0.03s\n",
195-
"[==============================] 100% Epoch 4/50 - loss: 1.8309 - - 0.03s\n",
196-
"[==============================] 100% Epoch 5/50 - loss: 1.3206 - - 0.03s\n",
197-
"[==============================] 100% Epoch 6/50 - loss: 0.0618 - - 0.03s\n",
198-
"[==============================] 100% Epoch 7/50 - loss: 0.0073 - - 0.04s\n",
199-
"[==============================] 100% Epoch 8/50 - loss: 0.0071 - - 0.03s\n",
200-
"[==============================] 100% Epoch 9/50 - loss: 0.0077 - - 0.03s\n",
201-
"[==============================] 100% Epoch 10/50 - loss: 0.0088 - - 0.03s\n",
202-
"[==============================] 100% Epoch 11/50 - loss: 0.0137 - - 0.03s\n",
203-
"[==============================] 100% Epoch 12/50 - loss: 0.0133 - - 0.03s\n",
204-
"[==============================] 100% Epoch 13/50 - loss: 0.0125 - - 0.03s\n",
205-
"[==============================] 100% Epoch 14/50 - loss: 0.0065 - - 0.03s\n",
206-
"[==============================] 100% Epoch 15/50 - loss: 0.0057 - - 0.03s\n",
207-
"[==============================] 100% Epoch 16/50 - loss: 0.0051 - - 0.03s\n",
208-
"[==============================] 100% Epoch 17/50 - loss: 0.0045 - - 0.03s\n",
209-
"[==============================] 100% Epoch 18/50 - loss: 0.0040 - - 0.03s\n",
210-
"[==============================] 100% Epoch 19/50 - loss: 0.0036 - - 0.03s\n",
211-
"[==============================] 100% Epoch 20/50 - loss: 0.0033 - - 0.04s\n",
212-
"[==============================] 100% Epoch 21/50 - loss: 0.0030 - - 0.04s\n",
213-
"[==============================] 100% Epoch 22/50 - loss: 0.0027 - - 0.03s\n",
214-
"[==============================] 100% Epoch 23/50 - loss: 0.0025 - - 0.03s\n",
215-
"[==============================] 100% Epoch 24/50 - loss: 0.0023 - - 0.03s\n",
216-
"[==============================] 100% Epoch 25/50 - loss: 0.0021 - - 0.04s\n",
217-
"[==============================] 100% Epoch 26/50 - loss: 0.0020 - - 0.04s\n",
218-
"[==============================] 100% Epoch 27/50 - loss: 0.0018 - - 0.03s\n",
219-
"[==============================] 100% Epoch 28/50 - loss: 0.0017 - - 0.04s\n",
220-
"[==============================] 100% Epoch 29/50 - loss: 0.0017 - - 0.06s\n",
221-
"[==============================] 100% Epoch 30/50 - loss: 0.0018 - - 0.03s\n",
222-
"[==============================] 100% Epoch 31/50 - loss: 0.0020 - - 0.03s\n",
223-
"[==============================] 100% Epoch 32/50 - loss: 0.0024 - - 0.03s\n",
224-
"[==============================] 100% Epoch 33/50 - loss: 0.0030 - - 0.03s\n",
225-
"[==============================] 100% Epoch 34/50 - loss: 0.0086 - - 0.03s\n",
226-
"[==============================] 100% Epoch 35/50 - loss: 0.0030 - - 0.03s\n",
227-
"[==============================] 100% Epoch 36/50 - loss: 0.0030 - - 0.03s\n",
228-
"[==============================] 100% Epoch 37/50 - loss: 0.0079 - - 0.03s\n",
229-
"[==============================] 100% Epoch 38/50 - loss: 0.0032 - - 0.03s\n",
230-
"[==============================] 100% Epoch 39/50 - loss: 0.0035 - - 0.03s\n",
231-
"[==============================] 100% Epoch 40/50 - loss: 0.0043 - - 0.04s\n",
232-
"[==============================] 100% Epoch 41/50 - loss: 0.0093 - - 0.03s\n",
233-
"[==============================] 100% Epoch 42/50 - loss: 0.0043 - - 0.03s\n",
234-
"[==============================] 100% Epoch 43/50 - loss: 0.0042 - - 0.03s\n",
235-
"[==============================] 100% Epoch 44/50 - loss: 0.0044 - - 0.03s\n",
236-
"[==============================] 100% Epoch 45/50 - loss: 0.0047 - - 0.03s\n",
237-
"[==============================] 100% Epoch 46/50 - loss: 0.0093 - - 0.03s\n",
238-
"[==============================] 100% Epoch 47/50 - loss: 0.0039 - - 0.03s\n",
239-
"[==============================] 100% Epoch 48/50 - loss: 0.0034 - - 0.03s\n",
240-
"[==============================] 100% Epoch 49/50 - loss: 0.0032 - - 0.04s\n",
241-
"[==============================] 100% Epoch 50/50 - loss: 0.0030 - - 0.04s\n",
191+
"[==============================] 100% Epoch 1/50 - loss: 13.4469 - - 0.47s\n",
192+
"[==============================] 100% Epoch 2/50 - loss: 11.0863 - - 0.47s\n",
193+
"[==============================] 100% Epoch 3/50 - loss: 5.6369 - - 0.45s\n",
194+
"[==============================] 100% Epoch 4/50 - loss: 3.0461 - - 0.47s\n",
195+
"[==============================] 100% Epoch 5/50 - loss: 1.0563 - - 0.47s\n",
196+
"[==============================] 100% Epoch 6/50 - loss: 0.0950 - - 0.47s\n",
197+
"[==============================] 100% Epoch 7/50 - loss: 0.0527 - - 0.48s\n",
198+
"[==============================] 100% Epoch 8/50 - loss: 0.0339 - - 0.48s\n",
199+
"[==============================] 100% Epoch 9/50 - loss: 0.0211 - - 0.46s\n",
200+
"[==============================] 100% Epoch 10/50 - loss: 0.0163 - - 0.46s\n",
201+
"[==============================] 100% Epoch 11/50 - loss: 0.0121 - - 0.47s\n",
202+
"[==============================] 100% Epoch 12/50 - loss: 0.0085 - - 0.47s\n",
203+
"[==============================] 100% Epoch 13/50 - loss: 0.0055 - - 0.47s\n",
204+
"[==============================] 100% Epoch 14/50 - loss: 0.0044 - - 0.46s\n",
205+
"[==============================] 100% Epoch 15/50 - loss: 0.0047 - - 0.46s\n",
206+
"[==============================] 100% Epoch 16/50 - loss: 0.0053 - - 0.45s\n",
207+
"[==============================] 100% Epoch 17/50 - loss: 0.0060 - - 0.47s\n",
208+
"[==============================] 100% Epoch 18/50 - loss: 0.0069 - - 0.45s\n",
209+
"[==============================] 100% Epoch 19/50 - loss: 0.0079 - - 0.47s\n",
210+
"[==============================] 100% Epoch 20/50 - loss: 0.0090 - - 0.46s\n",
211+
"[==============================] 100% Epoch 21/50 - loss: 0.0102 - - 0.46s\n",
212+
"[==============================] 100% Epoch 22/50 - loss: 0.0118 - - 0.51s\n",
213+
"[==============================] 100% Epoch 23/50 - loss: 0.0164 - - 0.50s\n",
214+
"[==============================] 100% Epoch 24/50 - loss: 0.0209 - - 0.49s\n",
215+
"[==============================] 100% Epoch 25/50 - loss: 0.0248 - - 0.53s\n",
216+
"[==============================] 100% Epoch 26/50 - loss: 0.0269 - - 0.54s\n",
217+
"[==============================] 100% Epoch 27/50 - loss: 0.0124 - - 0.56s\n",
218+
"[==============================] 100% Epoch 28/50 - loss: 0.0110 - - 0.53s\n",
219+
"[==============================] 100% Epoch 29/50 - loss: 0.0099 - - 0.49s\n",
220+
"[==============================] 100% Epoch 30/50 - loss: 0.0089 - - 0.48s\n",
221+
"[==============================] 100% Epoch 31/50 - loss: 0.0077 - - 0.52s\n",
222+
"[==============================] 100% Epoch 32/50 - loss: 0.0066 - - 0.55s\n",
223+
"[==============================] 100% Epoch 33/50 - loss: 0.0053 - - 0.55s\n",
224+
"[==============================] 100% Epoch 34/50 - loss: 0.0040 - - 0.50s\n",
225+
"[==============================] 100% Epoch 35/50 - loss: 0.0036 - - 0.50s\n",
226+
"[==============================] 100% Epoch 36/50 - loss: 0.0036 - - 0.49s\n",
227+
"[==============================] 100% Epoch 37/50 - loss: 0.0035 - - 0.49s\n",
228+
"[==============================] 100% Epoch 38/50 - loss: 0.0034 - - 0.53s\n",
229+
"[==============================] 100% Epoch 39/50 - loss: 0.0033 - - 0.54s\n",
230+
"[==============================] 100% Epoch 40/50 - loss: 0.0032 - - 0.50s\n",
231+
"[==============================] 100% Epoch 41/50 - loss: 0.0031 - - 0.49s\n",
232+
"[==============================] 100% Epoch 42/50 - loss: 0.0029 - - 0.49s\n",
233+
"[==============================] 100% Epoch 43/50 - loss: 0.0028 - - 0.52s\n",
234+
"[==============================] 100% Epoch 44/50 - loss: 0.0026 - - 0.54s\n",
235+
"[==============================] 100% Epoch 45/50 - loss: 0.0025 - - 0.56s\n",
236+
"[==============================] 100% Epoch 46/50 - loss: 0.0023 - - 0.48s\n",
237+
"[==============================] 100% Epoch 47/50 - loss: 0.0022 - - 0.48s\n",
238+
"[==============================] 100% Epoch 48/50 - loss: 0.0021 - - 0.51s\n",
239+
"[==============================] 100% Epoch 49/50 - loss: 0.0021 - - 0.51s\n",
240+
"[==============================] 100% Epoch 50/50 - loss: 0.0020 - - 0.56s\n",
242241
"\n"
243242
]
244243
}
@@ -247,14 +246,14 @@
247246
"history = model.fit(\n",
248247
" x_train_padded, y_train_padded,\n",
249248
" epochs=50,\n",
250-
" batch_size=5,\n",
249+
" batch_size=2,\n",
251250
" verbose=True\n",
252251
")"
253252
]
254253
},
255254
{
256255
"cell_type": "code",
257-
"execution_count": 9,
256+
"execution_count": 19,
258257
"id": "c1dc335b",
259258
"metadata": {},
260259
"outputs": [
@@ -263,21 +262,21 @@
263262
"output_type": "stream",
264263
"text": [
265264
"FR: je suis heureux.\n",
266-
"EN: goodbye goodbye goodbye goodbye\n",
265+
"EN: you cats cats\n",
267266
"\n",
268267
"FR: comment allez-vous ?\n",
269-
"EN: goodbye goodbye goodbye goodbye\n",
268+
"EN: you cats cats\n",
270269
"\n",
271270
"FR: bonjour le monde.\n",
272-
"EN: goodbye goodbye goodbye goodbye\n",
271+
"EN: you cats cats\n",
273272
"\n"
274273
]
275274
}
276275
],
277276
"source": [
278277
"def translate(sentence: str, model, fr_tokenizer, en_tokenizer) -> str:\n",
279278
" tokens = fr_tokenizer.texts_to_sequences([sentence], preprocess_ponctuation=True)[0]\n",
280-
" tokens = [model.SOS_IDX] + [t + 4 for t in tokens] + [model.EOS_IDX] # Shift indices by 4\n",
279+
" tokens = [model.SOS_IDX] + tokens + [model.EOS_IDX]\n",
281280
" padded = pad_sequences([tokens], max_length=max_len_x, padding='post', pad_value=model.PAD_IDX)\n",
282281
" \n",
283282
" pred = model.predict(padded, max_length=max_seq_len)[0]\n",

0 commit comments

Comments (0)