[generate] `PromptLookupCandidateGenerator` won't generate forbidden tokens #40726
Conversation
Super!!!!!
Great, thanks! One question: should we raise a warning that the prompt decoder will not filter candidates by logits, so we can't support all generation config params?
"qwen2_5omni", # the file is named `qwen2_5_omni`, but the model class is `Qwen2_5Omni`, | ||
# All models below: shouldn't suggest audio tokens. Can be fixed by passing `suppress_ids` to candidate generator: @joaa @raushan | ||
"voxtral", | ||
"qwen2audio", |
yaaaaaay 💃🏻
@zucchini-nlp great comment. We can indeed make it work with ANY logits processor that blocks tokens, not just `bad_words_ids`. The logic is slightly more complex, but it is added in the latest commit. In a nutshell, we simulate running the logits processor with fake input logits, using the same processors as we use for the main model, and any selected token whose processed logit comes out as `-inf` is treated as forbidden, so the candidate sequence is cropped at that point.

Regarding long-term maintenance: if we're keeping assisted generation (we are for now), we're also keeping this one. This is the candidate-based generation strategy with the fewest requirements.
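Roughly, the check works like this (a simplified sketch of the mechanism described above, not the exact merged code; `chosen_ids` holds the ngram-matched candidate tokens):

```python
import torch

def filter_forbidden_candidates(input_ids, chosen_ids, logits_processor, vocab_size):
    # Feed the processors uniform fake logits; blocking processors set
    # forbidden tokens to -inf, which we can detect without a model forward.
    fake_input_logits = torch.ones((input_ids.shape[0], vocab_size), device=input_ids.device)
    sequence_with_candidate = input_ids
    for candidate_idx, new_candidate_token in enumerate(chosen_ids):
        fake_logits = logits_processor(sequence_with_candidate, fake_input_logits)
        if fake_logits[0, new_candidate_token] == -float("inf"):
            return chosen_ids[:candidate_idx]  # crop at the first forbidden token
        sequence_with_candidate = torch.cat(
            (sequence_with_candidate, new_candidate_token[None, None]), dim=-1
        )
    return chosen_ids
```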
Title changed: `PromptLookupCandidateGenerator` accepts `bad_words_ids` -> `PromptLookupCandidateGenerator` won't generate forbidden tokens
max_matching_ngram_size (`int`):
    The maximum ngram size to be considered for matching in the prompt
num_output_tokens (`int`):
eos_token_id (`torch.Tensor`, *optional*):
(docstring args were out of order, and some were missing)
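For reference, roughly how the class is constructed with these args (a sketch; the defaults shown are assumptions and the merged signature may differ):

```python
import torch
from transformers.generation.candidate_generator import PromptLookupCandidateGenerator

candidate_generator = PromptLookupCandidateGenerator(
    eos_token_id=torch.tensor([2]),  # stop extending candidates at EOS
    num_output_tokens=10,            # number of candidate tokens to propose
    max_matching_ngram_size=2,       # longest prompt ngram to match against
)
```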
Nice, happy that it worked well! Thanks for aligning the prompt decoder with the common generation API. Left one tiny question, otherwise LGTM.
if self.logits_processor is not None:
    sequence_with_candidate = input_ids
    for candidate_idx, new_candidate_token in enumerate(chosen_ids):
        fake_input_logits = torch.ones((bsz, self.vocab_size), device=input_ids.device)
Prob we can create the fake logits tensor once, since it's always the same. Can be helpful with models like Gemma, which have a huge vocab size.
I thought so at first, but then considered that in-place ops (in custom logits processors) might behave poorly. Double-checked it now: it seems resilient to in-place ops, so I changed it.
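For illustration, the hazard being weighed here: a hypothetical processor that edits `scores` in place would dirty a reused fake-logits buffer (the class below is made up):

```python
import torch

class InPlaceBlocker:
    def __call__(self, input_ids, scores):
        scores[:, 3] = -float("inf")  # mutates the caller's tensor in place
        return scores

fake_input_logits = torch.ones((1, 8))
InPlaceBlocker()(None, fake_input_logits)
print(fake_input_logits[0, 3])  # tensor(-inf): the shared buffer now carries a stale -inf entry
```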
[For maintainers] Suggested jobs to run (before merge): run-slow: idefics2, idefics3, qwen2_5_vl, smolvlm
What does this PR do?
See title.
In the process, this deflakes a lot of tests :) (prompt lookup was generating forbidden tokens in tests, like image tokens in VLMs; now we can specify which sequences are forbidden)
py.test tests/models/voxtral/test_modeling_voxtral.py -k test_prompt_lookup_decoding_matches_greedy_search --flake-finder --flake-runs 1000

now runs without problems.
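For example, something along these lines now respects the forbidden tokens during candidate generation (a minimal sketch; `gpt2` is just a placeholder checkpoint, and the `bad_words_ids` value is a toy example):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("The quick brown fox jumps over the quick brown", return_tensors="pt")
outputs = model.generate(
    **inputs,
    prompt_lookup_num_tokens=10,               # enables prompt lookup candidate generation
    bad_words_ids=[[tokenizer.eos_token_id]],  # forbidden sequence, toy example
    max_new_tokens=20,
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```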