[generate] PromptLookupCandidateGenerator won't generate forbidden tokens #40726
@@ -1004,26 +1004,39 @@ class PromptLookupCandidateGenerator(CandidateGenerator):
     Read the following blog post for more information: https://github.com/apoorvumang/prompt-lookup-decoding
 
     Args:
-        max_matching_ngram_size (`int`):
-            The maximum ngram size to be considered for matching in the prompt
-        num_output_tokens (`int`):
+        eos_token_id (`torch.Tensor`, *optional*):
+            The token id of the end of sequence token.
+        num_output_tokens (`int`, *optional*, defaults to 10):
             The number of tokens to be output as candidate tokens.
-        max_length (`int`):
-            The number of total maximum tokens that can be generated. For decoder-only models that includes the prompt length.
-            Defaults to 20, which is the max length used as default in generation config.
+        max_matching_ngram_size (`int`, *optional*, defaults to 2):
+            The maximum ngram size to be considered for matching in the prompt
+        max_length (`int`, *optional*, defaults to 20):
+            The number of total maximum tokens that can be generated. For decoder-only models that includes the
+            prompt length. Defaults to 20, which is the max length used as default in generation config.
+        logits_processor (`LogitsProcessorList`, *optional*):
+            An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
+            used to modify the prediction scores of the language modeling head applied at each generation step. In
+            prompt lookup assisted generation, they are not used to manipulate probabilities, but rather to find
+            forbidden tokens (p = -inf) and block them from being valid candidates.
+        vocab_size (`int`, *optional*):
+            The size of the vocabulary. Required if `logits_processor` is provided.
     """
 
     def __init__(
         self,
         eos_token_id: Optional[torch.Tensor] = None,
         num_output_tokens: int = 10,
-        max_matching_ngram_size: Optional[int] = None,
+        max_matching_ngram_size: int = 2,
         max_length: int = 20,
+        logits_processor: Optional["LogitsProcessorList"] = None,
+        vocab_size: Optional[int] = None,
     ):
         self.num_output_tokens = num_output_tokens
-        self.max_matching_ngram_size = max_matching_ngram_size if max_matching_ngram_size else 2
+        self.max_matching_ngram_size = max_matching_ngram_size
         self.max_length = max_length
         self.eos_token_id = eos_token_id
+        self.logits_processor = logits_processor
+        self.vocab_size = vocab_size
 
         if self.max_matching_ngram_size <= 0 or self.num_output_tokens <= 0:
             raise ValueError("Invalid max_matching_ngram_size or num_output_tokens")
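For illustration only (not part of this diff): a minimal sketch of how the new `logits_processor` and `vocab_size` arguments might be wired up when building the candidate generator by hand. In practice `generate()` constructs this object internally; the processor class, model name, and forbidden ids below are made up, and the import path assumes `transformers.generation.candidate_generator`.

import torch
from transformers import AutoTokenizer, LogitsProcessor, LogitsProcessorList
from transformers.generation.candidate_generator import PromptLookupCandidateGenerator


class ForbidTokensProcessor(LogitsProcessor):
    # hypothetical processor: pushes the scores of the given token ids to -inf
    def __init__(self, forbidden_ids):
        self.forbidden_ids = forbidden_ids

    def __call__(self, input_ids, scores):
        scores[:, self.forbidden_ids] = -float("inf")
        return scores


tokenizer = AutoTokenizer.from_pretrained("gpt2")
candidate_generator = PromptLookupCandidateGenerator(
    eos_token_id=torch.tensor([tokenizer.eos_token_id]),
    num_output_tokens=10,
    max_matching_ngram_size=2,
    max_length=64,
    logits_processor=LogitsProcessorList([ForbidTokensProcessor([198])]),  # arbitrary example id
    vocab_size=len(tokenizer),  # required whenever logits_processor is passed
)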
@@ -1039,7 +1052,7 @@ def get_candidates(self, input_ids: torch.LongTensor) -> tuple[torch.LongTensor,
         Return:
             `torch.LongTensor` of shape `(num_candidates, candidate_length)`: The candidate sequences to be tried.
         """
-        input_length = input_ids.size(1)
+        bsz, input_length = input_ids.shape
 
         # Don't generate more than `max_length - 1` candidates since the target model generates one extra token.
         if self.max_length == input_length + 1:
@@ -1061,13 +1074,41 @@ def get_candidates(self, input_ids: torch.LongTensor) -> tuple[torch.LongTensor,
             match_indices = matches.nonzero(as_tuple=True)[1]
 
             # Iterate through match indices to find a valid continuation
+            # TODO (joao): this finds the first valid candidates (left to right), but perhaps we should find the
+            # longest valid candidates?
             for idx in match_indices:
                 start_idx = idx + ngram_size
                 end_idx = start_idx + self.num_output_tokens
                 end_idx = min(end_idx, input_length, self.max_length)
 
                 if start_idx < end_idx:
                     chosen_ids = input_ids[0, start_idx:end_idx]
+
+                    # Check if the each new candidate token is forbidden according to the logits processor. If all
+                    # tokens are allowed, we keep `chosen_ids` as is.
+                    # 1. create random logits.
+                    # 2. apply the logits processor to get output logits for the next token, using the arbitrary
+                    #    logits as input.
+                    # 3. compare the output logits with the next candidate token. If they are -inf, then the next
+                    #    candidate token is forbidden and we don't want to generate it.
+                    if self.logits_processor is not None:
+                        sequence_with_candidate = input_ids
+                        for candidate_idx, new_candidate_token in enumerate(chosen_ids):
+                            fake_input_logits = torch.ones((bsz, self.vocab_size), device=input_ids.device)
+                            fake_output_logits = self.logits_processor(sequence_with_candidate, fake_input_logits)
+                            fake_candidate_logits = fake_output_logits[0, new_candidate_token]
+                            # next candidate token is forbidden -> crop chosen_ids accordingly
+                            if fake_candidate_logits in (-float("Inf"), torch.finfo(fake_candidate_logits.dtype).min):
+                                chosen_ids = chosen_ids[:candidate_idx]
+                                break
+                            else:
+                                sequence_with_candidate = torch.cat(
+                                    (input_ids, chosen_ids[: candidate_idx + 1].unsqueeze(0)), dim=1
+                                )
+                        # no valid candidate tokens -> look for a different match
+                        if chosen_ids.shape[0] == 0:
+                            continue
+
                     match_found = True
 
                     # remove remaining candidate ids if an "eos" token is found, otherwise the target model may
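For illustration only (not part of this diff): a standalone sketch of the dummy-logits check added above, here using `MinLengthLogitsProcessor` (which blocks EOS until `min_length` is reached) as the processor; the vocabulary size and token ids are made up.

import torch
from transformers import LogitsProcessorList, MinLengthLogitsProcessor

vocab_size = 8
eos_token_id = 7
processors = LogitsProcessorList([MinLengthLogitsProcessor(5, eos_token_id)])

input_ids = torch.tensor([[0, 1, 2]])  # current sequence, still shorter than min_length
candidate_token = eos_token_id  # token proposed by prompt lookup

# 1. arbitrary input logits  2. run the processors  3. inspect the candidate's output score
fake_input_logits = torch.ones((1, vocab_size))
fake_output_logits = processors(input_ids, fake_input_logits)
forbidden = fake_output_logits[0, candidate_token] in (
    -float("inf"),
    torch.finfo(fake_output_logits.dtype).min,
)
print(forbidden)  # True -> the candidate would be cropped instead of being proposed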
@@ -1082,8 +1123,8 @@ def get_candidates(self, input_ids: torch.LongTensor) -> tuple[torch.LongTensor,
                 if match_found:
                     break
 
-        if chosen_ids is None or len(chosen_ids) == 0:
-            # In case we didn't find a match return the input sequence unchanged, reverts back to autoregressive decoding
+        # In case we didn't find a match return the input sequence unchanged, reverts back to autoregressive decoding
+        if not match_found or len(chosen_ids) == 0:
             return input_ids, None
 
         # Now need extend input_ids with chosen_ids
@@ -779,27 +779,6 @@ def test_prompt_lookup_decoding_matches_greedy_search(self):
                 "blip2",  # overridden `generate()` for all BLIP models
                 "instructblip",
                 "instructblipvideo",
-                # TODO: The list is growing huge 🙃! Let's try to check if the config has any of audio/image/video token id and skip the test!
-                # All models below: shouldn't suggest image tokens. Can be fixed by passing `suppress_ids` to candidate generator: @joaa @raushan
-                "llava",
-                "idefics2",
-                "idefics3",
-                "mllama",
-                "paligemma",
-                "emu3",
-                "gotocr2",
-                "qwen2vl",
-                "qwen2_5_vl",
-                "ayavision",
-                "janus",
-                "gemma3",
-                "mistral3",
-                "chameleon",
-                "internvl",
-                "qwen2_5omni",  # the file is named `qwen2_5_omni`, but the model class is `Qwen2_5Omni`
-                # All models below: shouldn't suggest audio tokens. Can be fixed by passing `suppress_ids` to candidate generator: @joaa @raushan
-                "voxtral",
-                "qwen2audio",
yaaaaaay 💃🏻
             ]
         ):
             self.skipTest(reason="May fix in the future: need model-specific fixes")
@@ -835,11 +814,12 @@ def test_prompt_lookup_decoding_matches_greedy_search(self):
                 "return_dict_in_generate": True,
                 "use_cache": True,
             }
+            logits_processor_kwargs = self._get_logits_processor_kwargs(config=model.config)
 
-            output_greedy = model.generate(**generation_kwargs, **inputs_dict)
+            output_greedy = model.generate(**generation_kwargs, **inputs_dict, **logits_processor_kwargs)
 
             generation_kwargs.update({"prompt_lookup_num_tokens": 2})  # see b)
-            output_prompt_lookup = model.generate(**generation_kwargs, **inputs_dict)
+            output_prompt_lookup = model.generate(**generation_kwargs, **inputs_dict, **logits_processor_kwargs)
 
             # The two outputs must match and their shape must be as expected
             self.assertTrue(has_similar_generate_outputs(output_greedy, output_prompt_lookup))
(docstring args were out of order, and some were missing)
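For illustration only (not part of this diff): a hedged end-to-end example of what the change enables, assuming `gpt2` as the model. With forbidden tokens filtered out of the candidates, prompt lookup decoding should keep matching greedy search even when a constraint such as `bad_words_ids` is active, which is what the updated test asserts.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

prompt = "The quick brown fox jumps over the lazy dog. The quick brown"
inputs = tokenizer(prompt, return_tensors="pt")
# forbid the token that prompt lookup would otherwise happily copy from the prompt
bad_words_ids = [tokenizer(" fox", add_special_tokens=False).input_ids]

greedy = model.generate(**inputs, max_new_tokens=10, do_sample=False, bad_words_ids=bad_words_ids)
assisted = model.generate(
    **inputs,
    max_new_tokens=10,
    do_sample=False,
    bad_words_ids=bad_words_ids,
    prompt_lookup_num_tokens=2,  # turns on PromptLookupCandidateGenerator
)
print(torch.equal(greedy, assisted))  # expected: True, assisted decoding matches greedy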