
Commit 70ff7b8

Text Generation Functions: Add Benchmark Script (#342)
* Add text gen benchmark
* Fixes
* Augment config
* Fix top-p
* Scale up Transformer
* Fix top p search
* Increase max length
* Fix bug in top p search
* Small change
* Print config at the end of test
* Scale up model
* Small change
* Reduce num_samples
* Reduce max length
* Raise exceptions on wrong inputs
* Enhance error msg
* Address comments - I
* Fix imports
* Address comments - II
* Delete config file
* Refactor config
* Address NIT
* Fix
* Add README, add Beam Search
* Add BS run
* Address review comments
* Add table in README
1 parent a8e0dbe commit 70ff7b8

File tree

4 files changed: +259 lines, -1 line


keras_nlp/benchmarks/README.md

Lines changed: 27 additions & 0 deletions
@@ -0,0 +1,27 @@
# KerasNLP Benchmarks

This directory houses a collection of scripts for benchmarking APIs and utility
functions which KerasNLP provides.

## Text Generation

For benchmarking text generation functions, the following command can be run
from the root of the repository:

```sh
python3 ./keras_nlp/benchmarks/text_generation.py
```

Running this script on Google Colab (with a Tesla T4 GPU and TensorFlow 2.10.0)
gave the following results:

| **Decoding Strategy** | **Graph Mode (sec)** | **Graph Mode with XLA (sec)** |
|:---------------------:|:--------------------:|:-----------------------------:|
| Greedy Search         | 495.78               | 293.77                        |
| Beam Search           | 564.23               | 615.17                        |
| Random Search         | 446.55               | 296.21                        |
| Top-k Search          | 458.68               | 302.66                        |
| Top-p Search          | 468.63               | 565.50                        |

To change the configuration (for example, the number of layers in the Transformer
model used for inference), modify the config dictionaries given at
the top of the script.
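
As an illustration of that last point, adjusting the benchmark comes down to editing the config dictionaries defined at the top of `text_generation.py` (the script added later in this commit). A minimal sketch, reusing the `MODEL_ARGS` and `TEXT_GEN_ARGS` names from that script; the edited values are examples only:

```python
# Sketch only: these mirror the dictionaries at the top of
# keras_nlp/benchmarks/text_generation.py; the changed values are examples.
MODEL_ARGS = {
    "max_length": 300,
    "embed_dim": 768,
    "num_layers": 4,  # e.g. halve the depth from the script's default of 8
    "num_heads": 8,
    "ff_dim": 3072,
}

TEXT_GEN_ARGS = {
    "max_length": 128,  # e.g. benchmark longer generated sequences
    "end_token_id": 2,
    "pad_token_id": 0,
}
```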

keras_nlp/benchmarks/__init__.py

Lines changed: 13 additions & 0 deletions
@@ -0,0 +1,13 @@
# Copyright 2022 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
keras_nlp/benchmarks/text_generation.py

Lines changed: 218 additions & 0 deletions
@@ -0,0 +1,218 @@
# Copyright 2022 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Benchmark for text generation."""

import time

import tensorflow as tf
from tensorflow import keras

import keras_nlp
from keras_nlp.utils import beam_search
from keras_nlp.utils import greedy_search
from keras_nlp.utils import random_search
from keras_nlp.utils import top_k_search
from keras_nlp.utils import top_p_search

SEED = 42

DATASET_ARGS = {
    "vocab_size": 40000,
    "num_samples": 1000,
    "batch_size": 2,
}

TEXT_GEN_ARGS = {
    "max_length": 64,
    "end_token_id": 2,
    "pad_token_id": 0,
}

MODEL_ARGS = {
    "max_length": 300,
    "embed_dim": 768,
    "num_layers": 8,
    "num_heads": 8,
    "ff_dim": 3072,
}

TEST_RUNS = [
    {
        "decoding_fn": greedy_search,
        "execution_methods": ["xla", "graph"],
        "args": TEXT_GEN_ARGS,
    },
    {
        "decoding_fn": beam_search,
        "execution_methods": ["xla", "graph"],
        "args": {
            "num_beams": 2,
            "from_logits": True,
            **TEXT_GEN_ARGS,
        },
    },
    {
        "decoding_fn": random_search,
        "execution_methods": ["xla", "graph"],
        "args": {
            "seed": SEED,
            "from_logits": True,
            **TEXT_GEN_ARGS,
        },
    },
    {
        "decoding_fn": top_k_search,
        "execution_methods": ["xla", "graph"],
        "args": {
            "k": 5,
            "seed": SEED,
            "from_logits": True,
            **TEXT_GEN_ARGS,
        },
    },
    {
        "decoding_fn": top_p_search,
        "execution_methods": ["xla", "graph"],
        "args": {
            "p": 0.9,
            "seed": SEED,
            "from_logits": True,
            **TEXT_GEN_ARGS,
        },
    },
]


def generate_random_ds(vocab_size, num_samples, batch_size, seed):
    prompt_length = 2
    inputs = tf.random.uniform(
        shape=(num_samples, prompt_length),
        minval=0,
        maxval=vocab_size - 1,
        dtype=tf.dtypes.int32,
        seed=seed,
    )

    ds = tf.data.Dataset.from_tensor_slices(inputs)
    ds = ds.batch(batch_size)
    return ds


def build_model(
    vocab_size, max_length, embed_dim, num_layers, num_heads, ff_dim
):
    inputs = keras.layers.Input(shape=(None,), dtype=tf.int32)
    # Embedding.
    x = keras_nlp.layers.TokenAndPositionEmbedding(
        vocabulary_size=vocab_size,
        sequence_length=max_length,
        embedding_dim=embed_dim,
        mask_zero=True,
    )(inputs)
    # Transformer decoders.
    for _ in range(num_layers):
        x = keras_nlp.layers.TransformerDecoder(
            num_heads=num_heads,
            intermediate_dim=ff_dim,
        )(x)
    # Output.
    outputs = keras.layers.Dense(vocab_size)(x)
    model = keras.Model(inputs=inputs, outputs=outputs)
    return model


def generate_text(
    decoding_fn,
    token_probability_fn,
    prompt,
    text_gen_args,
    jit_compile,
):
    class TestModel(tf.keras.Model):
        def call(self, inputs):
            generated = decoding_fn(
                token_probability_fn=token_probability_fn,
                prompt=inputs,
                **text_gen_args,
            )
            return generated

    test_model = TestModel()
    test_model.compile(jit_compile=jit_compile)

    t0 = time.time()
    _ = test_model.predict(prompt)
    return time.time() - t0


def main():
    keras.utils.set_random_seed(SEED)
    csv_path = time.strftime("text_gen_%Y-%m-%d_%H-%M-%S.csv")

    ds = generate_random_ds(
        vocab_size=DATASET_ARGS["vocab_size"],
        num_samples=DATASET_ARGS["num_samples"],
        batch_size=DATASET_ARGS["batch_size"],
        seed=SEED,
    )

    model = build_model(
        vocab_size=DATASET_ARGS["vocab_size"],
        max_length=MODEL_ARGS["max_length"],
        embed_dim=MODEL_ARGS["embed_dim"],
        num_layers=MODEL_ARGS["num_layers"],
        num_heads=MODEL_ARGS["num_heads"],
        ff_dim=MODEL_ARGS["ff_dim"],
    )

    def token_logits_fn(inputs):
        output = model(inputs)
        return output[:, -1, :]

    print("*************************************\n")

    with open(csv_path, "w") as res_handler:
        res_handler.write("decoding_strategy,execution_method,time\n")
        for test_run in TEST_RUNS:
            decoding_fn = test_run["decoding_fn"]
            decoding_strategy = decoding_fn.__name__

            for execution_method in test_run["execution_methods"]:
                print(f"Running {decoding_strategy} in {execution_method} mode")

                if execution_method == "graph":
                    jit_compile = False
                elif execution_method == "xla":
                    jit_compile = True

                time_taken = generate_text(
                    decoding_fn=decoding_fn,
                    token_probability_fn=token_logits_fn,
                    prompt=ds,
                    text_gen_args=test_run["args"],
                    jit_compile=jit_compile,
                )
                print("Time taken: ", time_taken)
                res_handler.write(
                    f"{decoding_strategy},{execution_method}," f"{time_taken}\n"
                )
                print()
            print("*************************************")

    print(f"Writing results to {csv_path}")


if __name__ == "__main__":
    main()
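
The benchmark above writes its timings to a timestamped CSV with the header `decoding_strategy,execution_method,time`. As an optional follow-up, a small sketch for printing such a file as a table; the file name below is hypothetical, the script generates the real one via `time.strftime`:

```python
import csv

# Hypothetical file name; the benchmark writes one like
# "text_gen_2022-10-01_12-00-00.csv".
with open("text_gen_2022-10-01_12-00-00.csv") as f:
    for row in csv.DictReader(f):
        # Each row holds one decoding strategy / execution method timing.
        print(
            f"{row['decoding_strategy']:>15} | "
            f"{row['execution_method']:>5} | "
            f"{float(row['time']):8.2f} s"
        )
```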

keras_nlp/utils/text_generation.py

Lines changed: 1 addition & 1 deletion
@@ -802,7 +802,7 @@ def one_step(length, prompt):
         probs = tf.where(
             shifted_keep_mask,
             sorted_preds,
-            tf.zeros(pred.shape, dtype=sorted_preds.dtype),
+            tf.zeros(tf.shape(pred), dtype=sorted_preds.dtype),
         )
         sorted_next_token = tf.random.categorical(
             tf.math.log(probs), 1, seed=seed
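
For context on the one-line change above: under compiled (graph or XLA) execution, `pred.shape` is the static shape and may contain unknown (`None`) dimensions, which `tf.zeros` cannot materialize, while `tf.shape(pred)` returns the shape as a runtime tensor. A standalone sketch of the difference (not code from this commit):

```python
import tensorflow as tf


@tf.function(
    input_signature=[tf.TensorSpec(shape=[None, 5], dtype=tf.float32)]
)
def zeros_like_dynamic(pred):
    # Here pred.shape is (None, 5), so tf.zeros(pred.shape, ...) would fail;
    # tf.shape(pred) yields the concrete [batch, 5] shape at runtime.
    return tf.zeros(tf.shape(pred), dtype=pred.dtype)


print(zeros_like_dynamic(tf.ones([3, 5])).shape)  # (3, 5)
```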
