8 changes: 8 additions & 0 deletions configs/neox_arguments.md
@@ -912,6 +912,14 @@ Text Generation arguments
 
 
 
+- **generating_batch_size**: int
+
+    Default = 1
+
+    The number of samples to generate at the same time
+
+
+
 ## NeoXArgsTokenizer
 
 Tokenizer Arguments
5 changes: 5 additions & 0 deletions megatron/neox_arguments/neox_args.py
@@ -974,3 +974,8 @@ class NeoXArgsTextgen(NeoXArgsTemplate):
     """
     Tasks to evaluate on using lm_eval_harness
     """
+
+    generating_batch_size: int = 1
+    """
+    The number of samples to generate at the same time.
+    """
58 changes: 32 additions & 26 deletions megatron/text_generation_utils.py
@@ -270,7 +270,7 @@ def stream_tokens(
     last_token_index_to_generate = min(
         neox_args.seq_length
         - 1,  # never generate more than the model's sequence length
-        token_index_to_generate + maximum_tokens - 1,
+        token_generation_start_index.max().item() + maximum_tokens - 1,
     )
 
     with torch.no_grad():
@@ -351,6 +351,10 @@ def stream_tokens(
                 token_generation_start_index <= token_index_to_generate
             )  # check which batch items have been started
 
+            state_started = state_started & (
+                token_generation_start_index + maximum_tokens > token_index_to_generate
+            )  # check which batch items have been ended
+
             # switch out padding tokens for generated tokens
             context_tokens[:, token_index_to_generate] = switch(
                 context_tokens[:, token_index_to_generate].view(-1),
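
Because each prompt in a batch starts generating at its own index, an item that started early reaches its maximum_tokens budget while later items are still going; the added `state_started` mask keeps finished items from being overwritten, since the `switch()` call above only swaps in a generated token where the mask is True. A small illustration of that window logic with made-up tensors (not the NeoX code itself):

```python
# Toy illustration of the per-item generation window (assumed values, not NeoX code).
import torch

token_generation_start_index = torch.tensor([2, 5])  # first position to generate, per batch item
maximum_tokens = 3

# with the change above, the loop runs until the latest-starting item has had
# maximum_tokens steps: max(2, 5) + 3 - 1 = 7
for token_index_to_generate in range(2, 8):
    started = token_generation_start_index <= token_index_to_generate
    not_ended = token_generation_start_index + maximum_tokens > token_index_to_generate
    write_mask = started & not_ended  # only these items receive a generated token
    print(token_index_to_generate, write_mask.tolist())

# item 0 is written at positions 2-4, item 1 at positions 5-7;
# outside an item's window, its existing prompt/padding tokens are left untouched.
```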
Expand Down Expand Up @@ -439,29 +443,31 @@ def generate_samples_from_prompt(

start_time = time.time()
# Tokenize text, and check whether we should terminate process
batch_size = min(neox_args.generating_batch_size, input_count - input_pos)
terminate_runs = 0
if input_pos == input_count:
terminate_runs = 1
else:
raw_text = text[input_pos]
input_pos += 1

if raw_text == "":
context_tokens = [eos_token_id]
else:
context_tokens = neox_args.tokenizer.tokenize(raw_text)
context_length = len(context_tokens)

if context_length >= (neox_args.seq_length // 2):
print_rank_0(
"\nWarning! Context length",
context_length,
"\nPlease give smaller context (e.g. half of the "
"max sequence length)!",
)
context_tokens_list = []
for pos in range(input_pos, input_pos + batch_size):
raw_text = text[pos]
if raw_text == "":
context_tokens = [eos_token_id]
else:
context_tokens = neox_args.tokenizer.tokenize(raw_text)
context_length = len(context_tokens)
if context_length >= (neox_args.seq_length // 2):
print_rank_0(
"\nWarning! Context length",
context_length,
"\nPlease give smaller context (e.g. half of the "
"max sequence length)!",
)
context_tokens_list.append(context_tokens)
input_pos += batch_size
if not is_mp_rank_0():
context_tokens = neox_args.tokenizer.tokenize("EMPTY TEXT")
context_length = len(context_tokens)
context_tokens_list = [
neox_args.tokenizer.tokenize("EMPTY TEXT") for _ in range(batch_size)]
terminate_runs = 0

terminate_runs = broadcast_terminate_signal(terminate_runs)
@@ -476,7 +482,7 @@
         ) in stream_tokens(
             neox_args=neox_args,
             model=model,
-            context_tokens=[context_tokens],
+            context_tokens=context_tokens_list,
             eos_token_id=eos_token_id,
             maximum_tokens=maximum_tokens,
             recompute=recompute,
@@ -496,12 +502,12 @@
             )
             batch_is_done = is_done.cpu().numpy().tolist()
 
-            for tokens, start_index, end_index, is_done in zip(
-                batch_context_tokens,
-                batch_token_generation_start_index,
-                batch_token_generation_end_index,
-                batch_is_done,
-            ):
+            for i in range(batch_size):
+                tokens = batch_context_tokens[i]
+                start_index = batch_token_generation_start_index[i]
+                end_index = batch_token_generation_end_index[i]
+                is_done = batch_is_done[i]
+                raw_text = text[input_pos - batch_size + i]
 
                 if end_index >= start_index:
                     generated_tokens = tokens[start_index : end_index + 1]
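
One non-obvious detail in the unpacking loop above: `input_pos` has already been advanced by `batch_size`, so `text[input_pos - batch_size + i]` is the i-th prompt of the round that was just generated. A quick standalone check of that index arithmetic (toy data, not NeoX code):

```python
text = ["p0", "p1", "p2", "p3", "p4"]
generating_batch_size = 2

input_pos = 0
while input_pos < len(text):
    batch_size = min(generating_batch_size, len(text) - input_pos)
    input_pos += batch_size  # advanced before the results are unpacked, as in the diff
    for i in range(batch_size):
        raw_text = text[input_pos - batch_size + i]  # i-th prompt of the round just run
        print(input_pos, i, raw_text)
# each prompt is recovered exactly once, in order, across rounds of size 2, 2, 1
```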