Description
Problem
Hi!
I have discovered that stop_sequences used in TransformersModel.generate do not always work as intended. In particular, when running CodeAgent with the default stop_sequences = ["Observation:", "Calling tools:"], I still see (sometimes multiple) "Calling tools:" strings in the assistant's messages.
Steps to reproduce
Here is an example to reproduce the bug.
from smolagents import CodeAgent, TransformersModel

model = TransformersModel("Qwen/Qwen2.5-Coder-7B-Instruct")
agent = CodeAgent(
    tools=[],
    model=model,
)

# Ensure multi-turn code generation so that the agent sees "Calling tools:" in the message history
agent.run("Concatenate strings 'foo' and 'bar'")
agent.run("Concatenate strings 'foo' and 'bar' again", reset=False)

# Print only the assistant's messages (generated by the model)
for m in agent.write_memory_to_messages():
    if m.role == "assistant":
        print(m.content[0]["text"])
        print(100 * '-')
Actual behavior
Here is the printed output:
Thought: To concatenate the strings 'foo' and 'bar', I will use Python's string concatenation operator `+`.
<code>
result = 'foo' + 'bar'
final_answer(result)
</code>
----------------------------------------------------------------------------------------------------
Thought: Since we already successfully concatenated 'foo' and 'bar' in the previous task, I will repeat the same process to ensure consistency.
<code>
result = 'foo' + 'bar'
final_answer(result)
</code>
Calling tools:
[{'id': 'call_2', 'type': 'function', 'function': {'name': 'python_interpreter', 'arguments': "result = 'foo' + 'bar'\nfinal_answer(result)"}}]</code>
----------------------------------------------------------------------------------------------------
As can be seen, the model appends a spurious "Calling tools:" block to the last generated message, which should not be there.
Environment:
I am using the latest (as of now) released version smolagents[toolkit,transformers]==1.21.2. Other environment details are not relevant here.
Reason:
The reason for this behavior lies in the current implementation of make_stopping_criteria. To be more precise, StopOnStrings.__call__ checks for the presence of stop strings at the very end of the generation stream, to which newly generated tokens are appended. However, since the tokenizer decodes tokens in multi-character chunks, a stop string can enter the stream together with extra trailing characters. For instance, the string "Calling tools:\n" would be tokenized as ["Calling", " tools", ":\n"], so it is appended to the stream with a trailing '\n'; therefore, exact string matching via endswith fails to detect the stop string "Calling tools:".
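To illustrate the failure mode, here is a minimal sketch that reuses the decoded chunks from the example above (the chunking itself is just the assumed decoding, as described):

# Minimal illustration: the stop string never sits exactly at the end of the stream
chunks = ["Calling", " tools", ":\n"]
stream = ""
for chunk in chunks:
    stream += chunk
    print(stream.endswith("Calling tools:"))  # False, False, False -- the last chunk appends ":\n"
    print("Calling tools:" in stream)         # True on the final chunk, so substring search would catch it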
Fix:
I propose the following version of StopOnStrings, which solves this problem:
from transformers import StoppingCriteria


class StopOnStrings(StoppingCriteria):
    def __init__(self, stop_strings: list[str], tokenizer):
        self.stop_strings = stop_strings
        self.tokenizer = tokenizer
        self.prev_tail = ""
        if any(s == "" for s in self.stop_strings):
            raise ValueError("stop_strings must be non-empty")
        # Keep only the last (max_len - 1) chars of already-checked text
        self.max_tail_len = max(len(s) for s in self.stop_strings) - 1

    def reset(self):
        self.prev_tail = ""

    def __call__(self, input_ids, scores, **kwargs):
        """
        Return True if any stop string appears in the generated token stream; otherwise return False.
        """
        generated = self.tokenizer.decode(input_ids[0][-1], skip_special_tokens=True)
        if not generated:
            return False
        # Build the minimal search window that contains every possible
        # newly created occurrence of any stop string
        window = self.prev_tail + generated
        # Update the tail
        self.prev_tail = window[-self.max_tail_len:]
        # Check each stop string -- we only check inside the window
        for s in self.stop_strings:
            if s in window:
                return True
        return False
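For reference, such a criterion plugs into generate via StoppingCriteriaList. The snippet below is only a sketch of the intended usage (the model choice and prompt are placeholders, not part of the fix):

from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteriaList

# Sketch only: model and prompt are placeholders for illustration
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-Coder-7B-Instruct")
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-Coder-7B-Instruct")

stop_criteria = StopOnStrings(["Observation:", "Calling tools:"], tokenizer)
inputs = tokenizer("Thought:", return_tensors="pt")
outputs = model.generate(
    **inputs,
    max_new_tokens=256,
    stopping_criteria=StoppingCriteriaList([stop_criteria]),
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))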
For similar reasons, I suggest fixing remove_stop_sequences as well:
def remove_stop_sequences(content: str, stop_sequences: list[str], delta: int = 128) -> str:
    """
    Removes stop sequences if they occur at the end of the content string, but accounts
    for extra trailing characters due to tokenization artefacts.
    `delta` sets the allowed right offset for extra characters.
    """
    for stop_seq in stop_sequences:
        subseq_len = len(stop_seq) + delta
        if stop_seq in content[-subseq_len:]:
            idx = content.rindex(stop_seq)
            content = content[:idx]
    return content
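For example, on content shaped like the reproduction output above (shortened here for illustration), the fixed function strips the trailing "Calling tools:" block even though extra characters follow it:

# Example input loosely based on the reproduction output above
content = "final_answer(result)\n</code>\nCalling tools:\n[{'id': 'call_2', ...}]"
print(remove_stop_sequences(content, ["Observation:", "Calling tools:"]))
# -> "final_answer(result)\n</code>\n"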
In addition, I would like to note that after adding these simple fixes to my project, I noticed a significant reduction in the average step execution time of my agent and even some improvement in quality. It seems that generating these spurious "Calling tools:" continuations wastes a lot of compute.
I would be glad to submit a pull request that resolves this issue.
Checklist
- I have searched the existing issues and have not found a similar bug report.
- I have provided a minimal, reproducible example.
- I have provided the full traceback of the error.
- I have provided my environment details.
- I am willing to work on this issue and submit a pull request.