
Commit 3df60d6

Add GPT-2 Model and its Variants (#354)
* Add GPT-2
* Small edit
* Fix doc-string
* Fix some typos
* Rename GPTXLarge to GPTExtraLarge
* Small changes
* Small fix
* Format
* Address comments - II
* Reverse other file changes
* Small edit
* Update doc-string
* Remove param ct
1 parent 9bc5d37 commit 3df60d6

File tree

3 files changed: +383 -0 lines changed

keras_nlp/models/__init__.py
keras_nlp/models/gpt2.py
keras_nlp/models/gpt2_test.py


keras_nlp/models/__init__.py

Lines changed: 5 additions & 0 deletions
@@ -20,6 +20,11 @@
 from keras_nlp.models.bert import BertPreprocessor
 from keras_nlp.models.bert import BertSmall
 from keras_nlp.models.bert import BertTiny
+from keras_nlp.models.gpt2 import Gpt2Base
+from keras_nlp.models.gpt2 import Gpt2Custom
+from keras_nlp.models.gpt2 import Gpt2ExtraLarge
+from keras_nlp.models.gpt2 import Gpt2Large
+from keras_nlp.models.gpt2 import Gpt2Medium
 from keras_nlp.models.roberta import RobertaBase
 from keras_nlp.models.roberta import RobertaClassifier
 from keras_nlp.models.roberta import RobertaCustom
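
With this change, the new GPT-2 constructors are importable straight from the `keras_nlp.models` namespace, alongside the existing BERT and RoBERTa names. A minimal sketch of what that looks like from user code (50257 is GPT-2's standard BPE vocabulary size; any positive value works):

```python
import keras_nlp

# Each preset forwards fixed hyperparameters to Gpt2Custom (see gpt2.py below).
base = keras_nlp.models.Gpt2Base(vocabulary_size=50257)
medium = keras_nlp.models.Gpt2Medium(vocabulary_size=50257)
print(base.num_layers, medium.num_layers)  # 12, 24
```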

keras_nlp/models/gpt2.py

Lines changed: 286 additions & 0 deletions
@@ -0,0 +1,286 @@
# Copyright 2022 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""GPT-2 model configurable class, preconfigured versions, and task heads."""

import tensorflow as tf
from tensorflow import keras

from keras_nlp.layers import PositionEmbedding
from keras_nlp.layers import TransformerDecoder


def _gpt_2_kernel_initializer(stddev=0.02):
    return keras.initializers.RandomNormal(stddev=stddev)


class Gpt2Custom(keras.Model):
    """GPT-2 core network with customizable hyperparameters.

    This network implements a Transformer-based decoder network,
    Generative Pretrained Transformer-2 (GPT-2), as described in
    ["Language Models are Unsupervised Multitask Learners"](https://cdn.openai.com/better-language-models/language_models_are_unsupervised_multitask_learners.pdf).
    It includes the embedding lookups and transformer layers.

    This class gives a fully customizable GPT-2 model with any number of
    layers, heads, and embedding dimensions. For specific GPT-2 architectures
    defined in the paper, see, for example, `keras_nlp.models.Gpt2Base`.

    Args:
        vocabulary_size: int. The size of the token vocabulary.
        num_layers: int. The number of transformer layers.
        num_heads: int. The number of attention heads for each transformer.
            The hidden size must be divisible by the number of attention heads.
        hidden_dim: int. The size of the transformer encoding and pooler layers.
        intermediate_dim: int. The output dimension of the first Dense layer in
            a two-layer feedforward network for each transformer.
        dropout: float. Dropout probability for the Transformer encoder.
        max_sequence_length: int. The maximum sequence length that this encoder
            can consume. If None, `max_sequence_length` uses the value from
            sequence length. This determines the variable shape for positional
            embeddings.
        name: string, optional. Name of the model.
        trainable: boolean, optional. If the model's variables should be
            trainable.

    Example usage:
    ```python
    # Randomly initialized GPT-2 decoder
    model = keras_nlp.models.Gpt2Custom(
        vocabulary_size=50257,
        num_layers=12,
        num_heads=12,
        hidden_dim=768,
        intermediate_dim=3072,
        max_sequence_length=1024,
        name="encoder",
    )

    # Call encoder on the inputs
    input_data = {
        "token_ids": tf.random.uniform(
            shape=(1, 12), dtype=tf.int64, maxval=model.vocabulary_size
        ),
        "padding_mask": tf.constant(
            [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0], shape=(1, 12)
        ),
    }
    output = model(input_data)
    ```
    """

    def __init__(
        self,
        vocabulary_size,
        num_layers,
        num_heads,
        hidden_dim,
        intermediate_dim,
        dropout=0.1,
        max_sequence_length=1024,
        name=None,
        trainable=True,
    ):

        # Inputs
        token_ids = keras.Input(shape=(None,), dtype="int32", name="token_ids")
        padding_mask = keras.Input(
            shape=(None,), dtype="int32", name="padding_mask"
        )

        # Embed tokens, positions.
        token_embedding = keras.layers.Embedding(
            input_dim=vocabulary_size,
            output_dim=hidden_dim,
            embeddings_initializer=_gpt_2_kernel_initializer(stddev=0.01),
            name="token_embedding",
        )(token_ids)

        # Can't use `TokenAndPositionEmbedding` layer here because of different
        # initializers.
        position_embedding = PositionEmbedding(
            initializer=_gpt_2_kernel_initializer(stddev=0.02),
            sequence_length=max_sequence_length,
            name="position_embedding",
        )(token_embedding)

        # Sum and apply dropout to embeddings.
        x = keras.layers.Add()((token_embedding, position_embedding))
        x = keras.layers.Dropout(
            dropout,
            name="embeddings_dropout",
        )(x)

        # Apply successive transformer decoder blocks.
        for i in range(num_layers):
            x = TransformerDecoder(
                intermediate_dim=intermediate_dim,
                num_heads=num_heads,
                dropout=dropout,
                activation=lambda x: keras.activations.gelu(
                    x, approximate=True
                ),
                layer_norm_epsilon=1e-05,
                kernel_initializer=_gpt_2_kernel_initializer(stddev=0.02),
                normalize_first=True,
                name=f"transformer_layer_{i}",
            )(x, decoder_padding_mask=padding_mask)

        sequence_output = keras.layers.LayerNormalization(
            name="layer_norm",
            axis=-1,
            epsilon=1e-05,
            dtype=tf.float32,
        )(x)

        # Instantiate using Functional API Model constructor
        super().__init__(
            inputs={
                "token_ids": token_ids,
                "padding_mask": padding_mask,
            },
            outputs=sequence_output,
            name=name,
            trainable=trainable,
        )
        # All references to `self` below this line
        self.vocabulary_size = vocabulary_size
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.hidden_dim = hidden_dim
        self.intermediate_dim = intermediate_dim
        self.dropout = dropout
        self.max_sequence_length = max_sequence_length

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "vocabulary_size": self.vocabulary_size,
                "num_layers": self.num_layers,
                "num_heads": self.num_heads,
                "hidden_dim": self.hidden_dim,
                "intermediate_dim": self.intermediate_dim,
                "dropout": self.dropout,
                "max_sequence_length": self.max_sequence_length,
            }
        )
        return config


MODEL_DOCSTRING = """GPT-2 "{type}" architecture.

    This network implements a Transformer-based decoder as
    described in
    ["Language Models are Unsupervised Multitask Learners"](https://cdn.openai.com/better-language-models/language_models_are_unsupervised_multitask_learners.pdf).
    It includes the embedding lookups and transformer layers.

    Args:
        vocabulary_size: int, optional. The size of the token vocabulary.
        name: String, optional. Name of the model.
        trainable: boolean, optional. If the model's variables should be
            trainable.

    Example usage:
    ```python
    # Randomly initialized Gpt2{type} encoder
    model = keras_nlp.models.Gpt2{type}(vocabulary_size=10000)

    # Call encoder on the inputs.
    input_data = {{
        "token_ids": tf.random.uniform(
            shape=(1, 1024), dtype=tf.int64, maxval=model.vocabulary_size
        ),
        "padding_mask": tf.constant([1] * 1024, shape=(1, 1024)),
    }}
    output = model(input_data)
    """


def Gpt2Base(vocabulary_size, name=None, trainable=True):
    return Gpt2Custom(
        vocabulary_size=vocabulary_size,
        num_layers=12,
        num_heads=12,
        hidden_dim=768,
        intermediate_dim=3072,
        dropout=0.1,
        max_sequence_length=1024,
        name=name,
        trainable=trainable,
    )


def Gpt2Medium(vocabulary_size, name=None, trainable=True):
    return Gpt2Custom(
        vocabulary_size=vocabulary_size,
        num_layers=24,
        num_heads=16,
        hidden_dim=1024,
        intermediate_dim=4096,
        dropout=0.1,
        max_sequence_length=1024,
        name=name,
        trainable=trainable,
    )


def Gpt2Large(vocabulary_size, name=None, trainable=True):
    return Gpt2Custom(
        vocabulary_size=vocabulary_size,
        num_layers=36,
        num_heads=20,
        hidden_dim=1280,
        intermediate_dim=5120,
        dropout=0.1,
        max_sequence_length=1024,
        name=name,
        trainable=trainable,
    )


def Gpt2ExtraLarge(vocabulary_size, name=None, trainable=True):
    return Gpt2Custom(
        vocabulary_size=vocabulary_size,
        num_layers=48,
        num_heads=25,
        hidden_dim=1600,
        intermediate_dim=6400,
        dropout=0.1,
        max_sequence_length=1024,
        name=name,
        trainable=trainable,
    )


setattr(
    Gpt2Base,
    "__doc__",
    MODEL_DOCSTRING.format(type="Base"),
)
setattr(
    Gpt2Medium,
    "__doc__",
    MODEL_DOCSTRING.format(type="Medium"),
)
setattr(
    Gpt2Large,
    "__doc__",
    MODEL_DOCSTRING.format(type="Large"),
)
setattr(
    Gpt2ExtraLarge,
    "__doc__",
    MODEL_DOCSTRING.format(type="ExtraLarge"),
)
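
The preset constructors are thin wrappers that pass the paper's hyperparameters through to `Gpt2Custom`, so the core network maps a `(batch, sequence)` pair of `token_ids`/`padding_mask` tensors to a `(batch, sequence, hidden_dim)` tensor of decoder activations. A small sketch of that shape contract, using a deliberately tiny configuration (the same one as the tests below) rather than any of the presets:

```python
import tensorflow as tf
from keras_nlp.models import gpt2

# Tiny configuration purely for illustration; preset sizes are listed above.
model = gpt2.Gpt2Custom(
    vocabulary_size=1000,
    num_layers=2,
    num_heads=2,
    hidden_dim=64,
    intermediate_dim=128,
    max_sequence_length=128,
)
inputs = {
    "token_ids": tf.ones((1, 16), dtype="int32"),
    "padding_mask": tf.ones((1, 16), dtype="int32"),
}
outputs = model(inputs)
print(outputs.shape)  # (1, 16, 64) -> (batch, sequence, hidden_dim)
```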

keras_nlp/models/gpt2_test.py

Lines changed: 92 additions & 0 deletions
@@ -0,0 +1,92 @@
# Copyright 2022 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for GPT-2 model."""

import os

import tensorflow as tf
from absl.testing import parameterized
from tensorflow import keras

from keras_nlp.models import gpt2


class Gpt2Test(tf.test.TestCase, parameterized.TestCase):
    def setUp(self):
        self.model = gpt2.Gpt2Custom(
            vocabulary_size=1000,
            num_layers=2,
            num_heads=2,
            hidden_dim=64,
            intermediate_dim=128,
            max_sequence_length=128,
            name="gpt2_test",
        )
        self.batch_size = 8
        self.input_batch = {
            "token_ids": tf.ones(
                (self.batch_size, self.model.max_sequence_length), dtype="int32"
            ),
            "padding_mask": tf.ones(
                (self.batch_size, self.model.max_sequence_length), dtype="int32"
            ),
        }

        self.input_dataset = tf.data.Dataset.from_tensor_slices(
            self.input_batch
        ).batch(2)

    def test_valid_call_gpt2(self):
        self.model(self.input_batch)

    def test_variable_sequence_length_call_gpt2(self):
        for seq_length in (25, 50, 75):
            input_data = {
                "token_ids": tf.ones(
                    (self.batch_size, seq_length), dtype="int32"
                ),
                "padding_mask": tf.ones(
                    (self.batch_size, seq_length), dtype="int32"
                ),
            }
            self.model(input_data)

    def test_valid_call_gpt2_base(self):
        model = gpt2.Gpt2Base(vocabulary_size=1000, name="gpt2_base_test")
        model(self.input_batch)

    @parameterized.named_parameters(
        ("jit_compile_false", False), ("jit_compile_true", True)
    )
    def test_gpt2_base_compile(self, jit_compile):
        model = gpt2.Gpt2Base(vocabulary_size=1000, name="gpt2_base_test")
        model.compile(jit_compile=jit_compile)
        model.predict(self.input_batch)

    @parameterized.named_parameters(
        ("jit_compile_false", False), ("jit_compile_true", True)
    )
    def test_gpt2_base_compile_batched_ds(self, jit_compile):
        model = gpt2.Gpt2Base(vocabulary_size=1000, name="gpt2_base_test")
        model.compile(jit_compile=jit_compile)
        model.predict(self.input_dataset)

    def test_saving_model(self):
        model_output = self.model(self.input_batch)
        save_path = os.path.join(self.get_temp_dir(), "model")
        self.model.save(save_path)
        restored_model = keras.models.load_model(save_path)

        restored_output = restored_model(self.input_batch)
        self.assertAllClose(model_output, restored_output)
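
The tests cover the SavedModel round-trip; `get_config` additionally records the constructor hyperparameters, so an architecture-equivalent (freshly initialized, weights not copied) model can be rebuilt from a config. A sketch of that, reading only the keys this commit adds in `get_config`:

```python
from keras_nlp.models import gpt2

model = gpt2.Gpt2Custom(
    vocabulary_size=1000,
    num_layers=2,
    num_heads=2,
    hidden_dim=64,
    intermediate_dim=128,
    max_sequence_length=128,
)
config = model.get_config()

# Rebuild the same architecture from the recorded hyperparameters.
hyperparam_keys = [
    "vocabulary_size", "num_layers", "num_heads", "hidden_dim",
    "intermediate_dim", "dropout", "max_sequence_length",
]
clone = gpt2.Gpt2Custom(**{k: config[k] for k in hyperparam_keys})
assert clone.num_layers == model.num_layers
```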
