Use stopping criteria from transformers (and other minor transformer fixes) #1723
Conversation
Thanks a lot for the awesome fixes!! 🤗
Just some minor comments.
```diff
-        torch_dtype (`str`, *optional*):
-            The torch_dtype to initialize your model with.
+        dtype (`str`, *optional*):
+            The dtype to initialize your model with.
```
Is this a breaking change you introduced in transformers?
yes -- we'll guarantee BC until v5.0.0 I believe
```python
# BC: previously the type was set through `torch_dtype`. `dtype` is now preferred
torch_dtype = kwargs.pop("torch_dtype", None)
dtype = dtype or torch_dtype
```
OK, I see the explanation here.
Maybe we should emit a deprecation warning for smolagents users? What do you think?
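For illustration, such a warning could be as simple as the following sketch (hypothetical code, not part of this PR; `_resolve_dtype` is an invented helper name):

```python
import warnings

def _resolve_dtype(dtype=None, **kwargs):
    # Hypothetical helper: keep accepting the legacy `torch_dtype` kwarg,
    # but tell users it is deprecated in favor of `dtype`.
    torch_dtype = kwargs.pop("torch_dtype", None)
    if torch_dtype is not None:
        warnings.warn(
            "`torch_dtype` is deprecated; please use `dtype` instead.",
            FutureWarning,
        )
    return dtype or torch_dtype
```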
Additionally, this makes some CI tests fail:
```
TypeError: LlamaForCausalLM.__init__() got an unexpected keyword argument 'dtype'
```
```diff
         completion_kwargs["max_new_tokens"] = max_new_tokens
     return dict(
-        inputs=prompt_tensor,
+        **prompt_tensor,
```
Not sure I understand this change... 😅
The inputs were being prepared such that `prompt_tensor` only contained the `input_ids`. However, depending on the model and usage, the corresponding `attention_mask` (also returned by the tokenizer) may also be needed for a correct output. While using `smolagents` with `transformers` models, we could see a related warning being thrown :)

These changes make it so we pass all tokenizer encoding outputs (`input_ids` AND `attention_mask`) to `model.generate`, and thus guarantee correctness.
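To make the before/after concrete, here is a minimal sketch with a standard `transformers` tokenizer/model pair (the model id is just an example, not one used by this PR):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "HuggingFaceTB/SmolLM2-135M-Instruct"  # example model id
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

# The tokenizer returns a dict containing BOTH `input_ids` and `attention_mask`.
encoded = tokenizer("Hello, my name is", return_tensors="pt")

# Before: only the ids were forwarded, so `generate` had to infer a mask.
# output = model.generate(inputs=encoded["input_ids"], max_new_tokens=20)

# After: unpacking the whole encoding passes `attention_mask` explicitly.
output = model.generate(**encoded, max_new_tokens=20)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```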
Thanks a lot for the clear explanation! 🤗
Do you know why we get these errors for some models? https://github.com/huggingface/smolagents/actions/runs/17267725950/job/49003833323?pr=1723
```
LlamaForCausalLM.__init__() got an unexpected keyword argument 'dtype'
LlavaForConditionalGeneration.__init__() got an unexpected keyword argument 'dtype'
FAILED tests/test_agents.py::TestAgent::test_transformers_toolcalling_agent - TypeError: LlamaForCausalLM.__init__() got an unexpected keyword argument 'dtype'
FAILED tests/test_models.py::TestModel::test_transformers_message_no_tool - TypeError: LlamaForCausalLM.__init__() got an unexpected keyword argument 'dtype'
FAILED tests/test_models.py::TestModel::test_transformers_message_vl_no_tool - ValueError: Failed to load tokenizer and model for model_id='llava-hf/llava-interleave-qwen-0.5b-hf': LlavaForConditionalGeneration.__init__() got an unexpected keyword argument 'dtype'
```
Not sure, I'll have to debug :D (can do it tomorrow)
I am rerunning the tests after the latest
After the transformers 4.56.1 patch release, the previous `dtype` errors disappeared.

I think the current CI error could be easily fixed: `AssertionError: assert 'This is a photo' == 'This is a very'`

However, we support `transformers>=4.0.0`, and therefore those `dtype` errors could be an issue for some users.

I would suggest:
- This PR handles only the stopping criteria
- Leave the support for both `dtype` and `torch_dtype` for a subsequent PR

What do you think?
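If that support ends up being handled on the smolagents side, one possible approach is to pick the kwarg from the installed `transformers` version. A sketch, under the unverified assumption that the rename landed around the 4.56 release (`dtype_kwarg` is an invented helper name):

```python
from packaging import version

import transformers

def dtype_kwarg(dtype):
    # Hypothetical helper: return the dtype kwarg understood by the installed
    # transformers version. The 4.56.0 cutoff here is an assumption.
    if version.parse(transformers.__version__) >= version.parse("4.56.0"):
        return {"dtype": dtype}
    return {"torch_dtype": dtype}

# Usage sketch:
# model = AutoModelForCausalLM.from_pretrained(model_id, **dtype_kwarg("float16"))
```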
Related to #1703
This PR:

- uses the `transformers` stopping criteria to stop on arbitrary strings, as opposed to using a custom class (more details below)
- `torch_dtype` -> `dtype` on transformers-related code. This was a recent deprecation.
- passes the `attention_mask` to `generate`. Generating the attention mask on the fly from `input_ids` in `generate` is very brittle and should be avoided.

The custom stopping criteria defined in `smolagents` compares the generated text against the defined strings: if the generated text ends in the same strings, it stops generation. This has two problems:

- if the defined string is `foo bar`, but the model generates `foo bar` or `foo bar:`, generation won't stop, and we probably want it to stop.

Both issues are addressed in the class present in `transformers`, so let's use it instead 🤗