Skip to content

Commit 9fb5fbf

Browse files
committed
Plan smolagent CLI autonomy
1 parent 5c684c1 commit 9fb5fbf

File tree

7 files changed

+1255
-152
lines changed

7 files changed

+1255
-152
lines changed

src/smolagents/agents.py

Lines changed: 166 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@
9191
extract_code_from_text,
9292
is_valid_name,
9393
make_init_file,
94+
make_json_serializable,
9495
parse_code_blobs,
9596
truncate_content,
9697
)
@@ -253,6 +254,24 @@ def dict(self):
253254
}
254255

255256

257+
@dataclass
258+
class RunStateSnapshot:
259+
step_number: int
260+
completed_action_steps: int
261+
pending_task: str | None
262+
last_plan: str | None
263+
memory_summary: list[dict[str, Any]]
264+
265+
def dict(self) -> dict[str, Any]:
266+
return {
267+
"step_number": self.step_number,
268+
"completed_action_steps": self.completed_action_steps,
269+
"pending_task": self.pending_task,
270+
"last_plan": self.last_plan,
271+
"memory_summary": self.memory_summary,
272+
}
273+
274+
256275
StreamEvent: TypeAlias = Union[
257276
ChatMessageStreamDelta,
258277
ChatMessageToolCall,
@@ -289,6 +308,8 @@ class MultiStepAgent(ABC):
289308
- Take the final answer, the agent's memory, and the agent itself as arguments.
290309
- Return a boolean indicating whether the final answer is valid.
291310
return_full_result (`bool`, default `False`): Whether to return the full [`RunResult`] object or just the final answer output from the agent run.
311+
tool_retry_limit (`int`, default `0`): Number of retries to allow for recoverable step errors.
312+
stagnation_window (`int`, default `0`): Trigger an unscheduled planning step after this many repeated non-progress steps.
292313
"""
293314

294315
def __init__(
@@ -308,6 +329,8 @@ def __init__(
308329
provide_run_summary: bool = False,
309330
final_answer_checks: list[Callable] | None = None,
310331
return_full_result: bool = False,
332+
tool_retry_limit: int = 0,
333+
stagnation_window: int = 0,
311334
logger: AgentLogger | None = None,
312335
):
313336
self.agent_name = self.__class__.__name__
@@ -334,6 +357,12 @@ def __init__(
334357
self.provide_run_summary = provide_run_summary
335358
self.final_answer_checks = final_answer_checks if final_answer_checks is not None else []
336359
self.return_full_result = return_full_result
360+
if tool_retry_limit < 0:
361+
raise ValueError(f"tool_retry_limit must be >= 0, got {tool_retry_limit}")
362+
if stagnation_window < 0:
363+
raise ValueError(f"stagnation_window must be >= 0, got {stagnation_window}")
364+
self.tool_retry_limit = tool_retry_limit
365+
self.stagnation_window = stagnation_window
337366
self.instructions = instructions
338367
self._setup_managed_agents(managed_agents)
339368
self._setup_tools(tools, add_base_tools)
@@ -350,6 +379,7 @@ def __init__(
350379
self.monitor = Monitor(self.model, self.logger)
351380
self._setup_step_callbacks(step_callbacks)
352381
self.stream_outputs = False
382+
self._reset_autonomy_state()
353383

354384
@property
355385
def system_prompt(self) -> str:
@@ -433,6 +463,55 @@ def _setup_step_callbacks(self, step_callbacks):
433463
# Register monitor update_metrics only for ActionStep for backward compatibility
434464
self.step_callbacks.register(ActionStep, self.monitor.update_metrics)
435465

466+
def _reset_autonomy_state(self):
467+
self._force_plan_step = False
468+
self._stagnant_step_count = 0
469+
self._last_progress_signature: str | None = None
470+
471+
def _is_retryable_error(self, error: AgentError) -> bool:
    """Tell whether a failed step may be attempted again.

    Args:
        error: The error raised while running the step.

    Returns:
        `False` when `error` is an `AgentGenerationError` or an
        `AgentMaxStepsError`, `True` for every other error.
    """
    non_retryable_types = (AgentGenerationError, AgentMaxStepsError)
    if isinstance(error, non_retryable_types):
        return False
    return True
473+
474+
def _build_progress_signature(self, action_step: ActionStep) -> str:
    """Build a string fingerprint of what an action step did.

    Two steps with equal fingerprints are treated as repeats by the
    stagnation tracking.

    Args:
        action_step: The step to fingerprint.

    Returns:
        For a step that carries an error, `"error:<error class name>:<error message>"`.
        Otherwise a JSON string (keys sorted) covering the step's tool calls
        (name and arguments), observations, action output and model output.
    """
    step_error = action_step.error
    if step_error is not None:
        return f"error:{type(step_error).__name__}:{step_error.message}"

    # A missing or empty tool-call list contributes an empty list.
    serialized_calls = [
        {"name": call.name, "arguments": make_json_serializable(call.arguments)}
        for call in (action_step.tool_calls or [])
    ]
    payload = {
        "tool_calls": serialized_calls,
        "observations": action_step.observations,
        "action_output": make_json_serializable(action_step.action_output),
        "model_output": make_json_serializable(action_step.model_output),
    }
    serializable_payload = make_json_serializable(payload)
    # Sorted keys keep the dump independent of dict insertion order.
    return json.dumps(serializable_payload, sort_keys=True)
491+
492+
def _update_stagnation_tracking(self, action_step: ActionStep):
    """Record a finished action step and request a planning step when steps keep repeating.

    Does nothing when `stagnation_window` is zero or negative. A final-answer
    step clears the tracking. Otherwise the step's fingerprint is compared
    with the previous one: a match increments the repeat counter, a
    difference resets it and stores the new fingerprint. Once the counter
    reaches `stagnation_window`, `_force_plan_step` is set, the counter is
    reset and an INFO message is logged.

    Args:
        action_step: The action step that has just been finalized.
    """
    if self.stagnation_window <= 0:
        return

    if action_step.is_final_answer:
        self._last_progress_signature = None
        self._stagnant_step_count = 0
        return

    current_signature = self._build_progress_signature(action_step)
    if self._last_progress_signature != current_signature:
        # The step differs from the previous one: start counting afresh.
        self._last_progress_signature = current_signature
        self._stagnant_step_count = 0
    else:
        self._stagnant_step_count += 1

    if self._stagnant_step_count < self.stagnation_window:
        return

    # Enough consecutive repeats: request a planning step and restart the count.
    self._stagnant_step_count = 0
    self._force_plan_step = True
    self.logger.log(
        "Detected repeated non-progress steps. Triggering an unscheduled planning step.",
        level=LogLevel.INFO,
    )
514+
436515
def run(
437516
self,
438517
task: str,
@@ -468,6 +547,7 @@ def run(
468547
max_steps = max_steps or self.max_steps
469548
self.task = task
470549
self.interrupt_switch = False
550+
self._reset_autonomy_state()
471551
if additional_args:
472552
self.state.update(additional_args)
473553
self.task += f"""
@@ -542,14 +622,17 @@ def _run_stream(
542622
) -> Generator[ActionStep | PlanningStep | FinalAnswerStep | ChatMessageStreamDelta]:
543623
self.step_number = 1
544624
returned_final_answer = False
625+
final_answer = None
545626
while not returned_final_answer and self.step_number <= max_steps:
546627
if self.interrupt_switch:
547628
raise AgentError("Agent interrupted.", self.logger)
548629

549630
# Run a planning step if scheduled
550-
if self.planning_interval is not None and (
551-
self.step_number == 1 or (self.step_number - 1) % self.planning_interval == 0
552-
):
631+
should_plan = self._force_plan_step or (
632+
self.planning_interval is not None
633+
and (self.step_number == 1 or (self.step_number - 1) % self.planning_interval == 0)
634+
)
635+
if should_plan:
553636
planning_start_time = time.time()
554637
planning_step = None
555638
for element in self._generate_planning_step(
@@ -565,47 +648,67 @@ def _run_stream(
565648
)
566649
self._finalize_step(planning_step)
567650
self.memory.steps.append(planning_step)
568-
569-
# Start action step!
570-
action_step_start_time = time.time()
571-
action_step = ActionStep(
572-
step_number=self.step_number,
573-
timing=Timing(start_time=action_step_start_time),
574-
observations_images=images,
575-
)
576-
self.logger.log_rule(f"Step {self.step_number}", level=LogLevel.INFO)
577-
try:
578-
for output in self._step_stream(action_step):
579-
# Yield all
580-
yield output
581-
582-
if isinstance(output, ActionOutput) and output.is_final_answer:
583-
final_answer = output.output
651+
self._force_plan_step = False
652+
653+
retries_left = self.tool_retry_limit
654+
while True:
655+
action_step_start_time = time.time()
656+
action_step = ActionStep(
657+
step_number=self.step_number,
658+
timing=Timing(start_time=action_step_start_time),
659+
observations_images=images,
660+
)
661+
self.logger.log_rule(f"Step {self.step_number}", level=LogLevel.INFO)
662+
663+
retry_attempted = False
664+
try:
665+
for output in self._step_stream(action_step):
666+
# Yield all
667+
yield output
668+
669+
if isinstance(output, ActionOutput) and output.is_final_answer:
670+
final_answer = output.output
671+
self.logger.log(
672+
Text(f"Final answer: {final_answer}", style=f"bold {YELLOW_HEX}"),
673+
level=LogLevel.INFO,
674+
)
675+
676+
if self.final_answer_checks:
677+
self._validate_final_answer(final_answer)
678+
returned_final_answer = True
679+
action_step.is_final_answer = True
680+
except AgentGenerationError as e:
681+
# Agent generation errors are not caused by a Model error but an implementation error: so we should raise them and exit.
682+
action_step.error = e
683+
self._finalize_step(action_step)
684+
self.memory.steps.append(action_step)
685+
self._update_stagnation_tracking(action_step)
686+
yield action_step
687+
raise e
688+
except AgentError as e:
689+
# Other AgentError types are typically recoverable and can be retried when configured.
690+
action_step.error = e
691+
retry_attempted = retries_left > 0 and self._is_retryable_error(e)
692+
if retry_attempted:
693+
retries_left -= 1
584694
self.logger.log(
585-
Text(f"Final answer: {final_answer}", style=f"bold {YELLOW_HEX}"),
695+
f"Retrying step {self.step_number} after recoverable error ({retries_left} retries left).",
586696
level=LogLevel.INFO,
587697
)
588698

589-
if self.final_answer_checks:
590-
self._validate_final_answer(final_answer)
591-
returned_final_answer = True
592-
action_step.is_final_answer = True
593-
594-
except AgentGenerationError as e:
595-
# Agent generation errors are not caused by a Model error but an implementation error: so we should raise them and exit.
596-
raise e
597-
except AgentError as e:
598-
# Other AgentError types are caused by the Model, so we should log them and iterate.
599-
action_step.error = e
600-
finally:
601699
self._finalize_step(action_step)
602700
self.memory.steps.append(action_step)
701+
self._update_stagnation_tracking(action_step)
603702
yield action_step
703+
704+
if retry_attempted:
705+
continue
706+
604707
self.step_number += 1
708+
break
605709

606710
if not returned_final_answer and self.step_number == max_steps + 1:
607711
final_answer = self._handle_max_steps_reached(task)
608-
yield action_step
609712
final_answer_step = FinalAnswerStep(handle_agent_output_types(final_answer))
610713
self._finalize_step(final_answer_step)
611714
yield final_answer_step
@@ -865,6 +968,32 @@ def replay(self, detailed: bool = False):
865968
"""
866969
self.memory.replay(self.logger, detailed=detailed)
867970

971+
def get_run_state_snapshot(self) -> dict[str, Any]:
    """Return a compact, serializable snapshot of the current run state.

    Returns:
        A dictionary with the `RunStateSnapshot` fields (`step_number`,
        `completed_action_steps`, `pending_task`, `last_plan`,
        `memory_summary`) plus a `stagnation_state` entry holding the
        configured window and the last recorded step signature.
    """
    recorded_steps = self.memory.steps
    action_step_count = sum(1 for step in recorded_steps if isinstance(step, ActionStep))
    # Walk backwards so the first match is the most recent plan.
    latest_plan = next(
        (step.plan for step in reversed(recorded_steps) if isinstance(step, PlanningStep)),
        None,
    )

    # Each message's content is rendered to text and passed through
    # truncate_content with max_length=5000.
    summarized_messages = [
        {
            "role": str(message.role),
            "content": truncate_content(str(make_json_serializable(message.content)), max_length=5000),
        }
        for message in self.write_memory_to_messages(summary_mode=True)
    ]

    state = RunStateSnapshot(
        step_number=self.step_number,
        completed_action_steps=action_step_count,
        pending_task=self.task,
        last_plan=latest_plan,
        memory_summary=summarized_messages,
    )
    result = state.dict()
    result["stagnation_state"] = {
        "window": self.stagnation_window,
        "last_signature": self._last_progress_signature,
    }
    return result
996+
868997
def __call__(self, task: str, **kwargs):
869998
"""Adds additional prompting for the managed agent, runs it, and wraps the output.
870999
This method is called only by a managed agent.
@@ -1001,6 +1130,8 @@ def to_dict(self) -> dict[str, Any]:
10011130
"max_steps": self.max_steps,
10021131
"verbosity_level": int(self.logger.level),
10031132
"planning_interval": self.planning_interval,
1133+
"tool_retry_limit": self.tool_retry_limit,
1134+
"stagnation_window": self.stagnation_window,
10041135
"name": self.name,
10051136
"description": self.description,
10061137
"requirements": sorted(requirements),
@@ -1051,6 +1182,8 @@ def from_dict(cls, agent_dict: dict[str, Any], **kwargs) -> "MultiStepAgent":
10511182
"max_steps": agent_dict.get("max_steps"),
10521183
"verbosity_level": agent_dict.get("verbosity_level"),
10531184
"planning_interval": agent_dict.get("planning_interval"),
1185+
"tool_retry_limit": agent_dict.get("tool_retry_limit"),
1186+
"stagnation_window": agent_dict.get("stagnation_window"),
10541187
"name": agent_dict.get("name"),
10551188
"description": agent_dict.get("description"),
10561189
}

0 commit comments

Comments
 (0)