9191 extract_code_from_text ,
9292 is_valid_name ,
9393 make_init_file ,
94+ make_json_serializable ,
9495 parse_code_blobs ,
9596 truncate_content ,
9697)
@@ -253,6 +254,24 @@ def dict(self):
253254 }
254255
255256
257+ @dataclass
258+ class RunStateSnapshot :
259+ step_number : int
260+ completed_action_steps : int
261+ pending_task : str | None
262+ last_plan : str | None
263+ memory_summary : list [dict [str , Any ]]
264+
265+ def dict (self ) -> dict [str , Any ]:
266+ return {
267+ "step_number" : self .step_number ,
268+ "completed_action_steps" : self .completed_action_steps ,
269+ "pending_task" : self .pending_task ,
270+ "last_plan" : self .last_plan ,
271+ "memory_summary" : self .memory_summary ,
272+ }
273+
274+
256275StreamEvent : TypeAlias = Union [
257276 ChatMessageStreamDelta ,
258277 ChatMessageToolCall ,
@@ -289,6 +308,8 @@ class MultiStepAgent(ABC):
289308 - Take the final answer, the agent's memory, and the agent itself as arguments.
290309 - Return a boolean indicating whether the final answer is valid.
291310 return_full_result (`bool`, default `False`): Whether to return the full [`RunResult`] object or just the final answer output from the agent run.
311+ tool_retry_limit (`int`, default `0`): Number of retries to allow for recoverable step errors.
312+ stagnation_window (`int`, default `0`): Trigger an unscheduled planning step after this many repeated non-progress steps.
292313 """
293314
294315 def __init__ (
@@ -308,6 +329,8 @@ def __init__(
308329 provide_run_summary : bool = False ,
309330 final_answer_checks : list [Callable ] | None = None ,
310331 return_full_result : bool = False ,
332+ tool_retry_limit : int = 0 ,
333+ stagnation_window : int = 0 ,
311334 logger : AgentLogger | None = None ,
312335 ):
313336 self .agent_name = self .__class__ .__name__
@@ -334,6 +357,12 @@ def __init__(
334357 self .provide_run_summary = provide_run_summary
335358 self .final_answer_checks = final_answer_checks if final_answer_checks is not None else []
336359 self .return_full_result = return_full_result
360+ if tool_retry_limit < 0 :
361+ raise ValueError (f"tool_retry_limit must be >= 0, got { tool_retry_limit } " )
362+ if stagnation_window < 0 :
363+ raise ValueError (f"stagnation_window must be >= 0, got { stagnation_window } " )
364+ self .tool_retry_limit = tool_retry_limit
365+ self .stagnation_window = stagnation_window
337366 self .instructions = instructions
338367 self ._setup_managed_agents (managed_agents )
339368 self ._setup_tools (tools , add_base_tools )
@@ -350,6 +379,7 @@ def __init__(
350379 self .monitor = Monitor (self .model , self .logger )
351380 self ._setup_step_callbacks (step_callbacks )
352381 self .stream_outputs = False
382+ self ._reset_autonomy_state ()
353383
354384 @property
355385 def system_prompt (self ) -> str :
@@ -433,6 +463,55 @@ def _setup_step_callbacks(self, step_callbacks):
433463 # Register monitor update_metrics only for ActionStep for backward compatibility
434464 self .step_callbacks .register (ActionStep , self .monitor .update_metrics )
435465
def _reset_autonomy_state(self):
    """Restore the stagnation/forced-planning bookkeeping to its initial values.

    Clears the remembered progress signature, zeroes the stagnant-step counter
    and drops any pending request for an unscheduled planning step.
    """
    self._last_progress_signature: str | None = None
    self._stagnant_step_count = 0
    self._force_plan_step = False
470+
def _is_retryable_error(self, error: AgentError) -> bool:
    """Tell whether a step error is eligible for a retry.

    Args:
        error (`AgentError`): The error raised while running a step.

    Returns:
        `bool`: `False` for `AgentGenerationError` and `AgentMaxStepsError`
        instances, `True` for anything else.
    """
    non_retryable_types = (AgentGenerationError, AgentMaxStepsError)
    if isinstance(error, non_retryable_types):
        return False
    return True
473+
def _build_progress_signature(self, action_step: ActionStep) -> str:
    """Produce a string fingerprint of what an action step did.

    Two steps with the same fingerprint are treated as a repetition by the
    stagnation tracking.

    Args:
        action_step (`ActionStep`): The step to fingerprint.

    Returns:
        `str`: For a failed step, a string built from the error's type name and
        message. Otherwise a key-sorted JSON dump of the step's tool calls
        (name and arguments), observations, action output and model output.
    """
    failure = action_step.error
    if failure is not None:
        return f"error:{type(failure).__name__}:{failure.message}"

    recorded_calls = []
    if action_step.tool_calls:
        for call in action_step.tool_calls:
            recorded_calls.append(
                {"name": call.name, "arguments": make_json_serializable(call.arguments)}
            )

    payload = {
        "tool_calls": recorded_calls,
        "observations": action_step.observations,
        "action_output": make_json_serializable(action_step.action_output),
        "model_output": make_json_serializable(action_step.model_output),
    }
    # Sorted keys keep the dump stable regardless of dict insertion order.
    return json.dumps(make_json_serializable(payload), sort_keys=True)
491+
def _update_stagnation_tracking(self, action_step: ActionStep):
    """Update the repeated-step counter and request a planning step when it hits the window.

    Does nothing when `stagnation_window` is not positive. A final-answer step
    clears the tracking. Otherwise the step's progress signature is compared with
    the previous one: a match increments the counter, a difference resets it and
    stores the new signature. Once the counter reaches `stagnation_window`, an
    unscheduled planning step is flagged, the counter is reset and a message is logged.

    Args:
        action_step (`ActionStep`): The step that was just recorded.
    """
    window = self.stagnation_window
    if window <= 0:
        return

    if action_step.is_final_answer:
        self._stagnant_step_count, self._last_progress_signature = 0, None
        return

    current_signature = self._build_progress_signature(action_step)
    if current_signature != self._last_progress_signature:
        # Something changed since the previous step: start counting afresh.
        self._stagnant_step_count = 0
        self._last_progress_signature = current_signature
    else:
        self._stagnant_step_count += 1

    if self._stagnant_step_count < window:
        return

    self._force_plan_step = True
    self._stagnant_step_count = 0
    self.logger.log(
        "Detected repeated non-progress steps. Triggering an unscheduled planning step.",
        level=LogLevel.INFO,
    )
514+
436515 def run (
437516 self ,
438517 task : str ,
@@ -468,6 +547,7 @@ def run(
468547 max_steps = max_steps or self .max_steps
469548 self .task = task
470549 self .interrupt_switch = False
550+ self ._reset_autonomy_state ()
471551 if additional_args :
472552 self .state .update (additional_args )
473553 self .task += f"""
@@ -542,14 +622,17 @@ def _run_stream(
542622 ) -> Generator [ActionStep | PlanningStep | FinalAnswerStep | ChatMessageStreamDelta ]:
543623 self .step_number = 1
544624 returned_final_answer = False
625+ final_answer = None
545626 while not returned_final_answer and self .step_number <= max_steps :
546627 if self .interrupt_switch :
547628 raise AgentError ("Agent interrupted." , self .logger )
548629
549630 # Run a planning step if scheduled
550- if self .planning_interval is not None and (
551- self .step_number == 1 or (self .step_number - 1 ) % self .planning_interval == 0
552- ):
631+ should_plan = self ._force_plan_step or (
632+ self .planning_interval is not None
633+ and (self .step_number == 1 or (self .step_number - 1 ) % self .planning_interval == 0 )
634+ )
635+ if should_plan :
553636 planning_start_time = time .time ()
554637 planning_step = None
555638 for element in self ._generate_planning_step (
@@ -565,47 +648,67 @@ def _run_stream(
565648 )
566649 self ._finalize_step (planning_step )
567650 self .memory .steps .append (planning_step )
568-
569- # Start action step!
570- action_step_start_time = time .time ()
571- action_step = ActionStep (
572- step_number = self .step_number ,
573- timing = Timing (start_time = action_step_start_time ),
574- observations_images = images ,
575- )
576- self .logger .log_rule (f"Step { self .step_number } " , level = LogLevel .INFO )
577- try :
578- for output in self ._step_stream (action_step ):
579- # Yield all
580- yield output
581-
582- if isinstance (output , ActionOutput ) and output .is_final_answer :
583- final_answer = output .output
651+ self ._force_plan_step = False
652+
653+ retries_left = self .tool_retry_limit
654+ while True :
655+ action_step_start_time = time .time ()
656+ action_step = ActionStep (
657+ step_number = self .step_number ,
658+ timing = Timing (start_time = action_step_start_time ),
659+ observations_images = images ,
660+ )
661+ self .logger .log_rule (f"Step { self .step_number } " , level = LogLevel .INFO )
662+
663+ retry_attempted = False
664+ try :
665+ for output in self ._step_stream (action_step ):
666+ # Yield all
667+ yield output
668+
669+ if isinstance (output , ActionOutput ) and output .is_final_answer :
670+ final_answer = output .output
671+ self .logger .log (
672+ Text (f"Final answer: { final_answer } " , style = f"bold { YELLOW_HEX } " ),
673+ level = LogLevel .INFO ,
674+ )
675+
676+ if self .final_answer_checks :
677+ self ._validate_final_answer (final_answer )
678+ returned_final_answer = True
679+ action_step .is_final_answer = True
680+ except AgentGenerationError as e :
681+ # Agent generation errors are not caused by a Model error but an implementation error: so we should raise them and exit.
682+ action_step .error = e
683+ self ._finalize_step (action_step )
684+ self .memory .steps .append (action_step )
685+ self ._update_stagnation_tracking (action_step )
686+ yield action_step
687+ raise e
688+ except AgentError as e :
689+ # Other AgentError types are typically recoverable and can be retried when configured.
690+ action_step .error = e
691+ retry_attempted = retries_left > 0 and self ._is_retryable_error (e )
692+ if retry_attempted :
693+ retries_left -= 1
584694 self .logger .log (
585- Text ( f"Final answer: { final_answer } " , style = f"bold { YELLOW_HEX } " ) ,
695+ f"Retrying step { self . step_number } after recoverable error ( { retries_left } retries left)." ,
586696 level = LogLevel .INFO ,
587697 )
588698
589- if self .final_answer_checks :
590- self ._validate_final_answer (final_answer )
591- returned_final_answer = True
592- action_step .is_final_answer = True
593-
594- except AgentGenerationError as e :
595- # Agent generation errors are not caused by a Model error but an implementation error: so we should raise them and exit.
596- raise e
597- except AgentError as e :
598- # Other AgentError types are caused by the Model, so we should log them and iterate.
599- action_step .error = e
600- finally :
601699 self ._finalize_step (action_step )
602700 self .memory .steps .append (action_step )
701+ self ._update_stagnation_tracking (action_step )
603702 yield action_step
703+
704+ if retry_attempted :
705+ continue
706+
604707 self .step_number += 1
708+ break
605709
606710 if not returned_final_answer and self .step_number == max_steps + 1 :
607711 final_answer = self ._handle_max_steps_reached (task )
608- yield action_step
609712 final_answer_step = FinalAnswerStep (handle_agent_output_types (final_answer ))
610713 self ._finalize_step (final_answer_step )
611714 yield final_answer_step
@@ -865,6 +968,32 @@ def replay(self, detailed: bool = False):
865968 """
866969 self .memory .replay (self .logger , detailed = detailed )
867970
def get_run_state_snapshot(self) -> dict[str, Any]:
    """Return a compact, serializable snapshot of the current run state.

    Returns:
        `dict[str, Any]`: The fields of a `RunStateSnapshot` (`step_number`,
        `completed_action_steps`, `pending_task`, `last_plan`, `memory_summary`)
        plus a `stagnation_state` entry holding the configured window and the
        last recorded progress signature. Each `memory_summary` entry has a
        `role` string and a `content` string truncated to 5000 characters.
    """
    # Single pass over memory: count action steps and remember the latest planning step.
    completed_action_steps = 0
    latest_planning_step = None
    for step in self.memory.steps:
        if isinstance(step, ActionStep):
            completed_action_steps += 1
        if isinstance(step, PlanningStep):
            latest_planning_step = step

    memory_summary = []
    for message in self.write_memory_to_messages(summary_mode=True):
        # NOTE(review): message.role looks like it may be an Enum member — confirm.
        # str() on an Enum member yields "ClassName.MEMBER", so prefer its
        # underlying value when one is exposed; plain strings pass through unchanged.
        role = getattr(message.role, "value", message.role)
        memory_summary.append(
            {
                "role": str(role),
                "content": truncate_content(str(make_json_serializable(message.content)), max_length=5000),
            }
        )

    snapshot = RunStateSnapshot(
        step_number=self.step_number,
        completed_action_steps=completed_action_steps,
        pending_task=self.task,
        last_plan=latest_planning_step.plan if latest_planning_step is not None else None,
        memory_summary=memory_summary,
    ).dict()
    snapshot["stagnation_state"] = {
        "window": self.stagnation_window,
        "last_signature": self._last_progress_signature,
    }
    return snapshot
996+
868997 def __call__ (self , task : str , ** kwargs ):
869998 """Adds additional prompting for the managed agent, runs it, and wraps the output.
870999 This method is called only by a managed agent.
@@ -1001,6 +1130,8 @@ def to_dict(self) -> dict[str, Any]:
10011130 "max_steps" : self .max_steps ,
10021131 "verbosity_level" : int (self .logger .level ),
10031132 "planning_interval" : self .planning_interval ,
1133+ "tool_retry_limit" : self .tool_retry_limit ,
1134+ "stagnation_window" : self .stagnation_window ,
10041135 "name" : self .name ,
10051136 "description" : self .description ,
10061137 "requirements" : sorted (requirements ),
@@ -1051,6 +1182,8 @@ def from_dict(cls, agent_dict: dict[str, Any], **kwargs) -> "MultiStepAgent":
10511182 "max_steps" : agent_dict .get ("max_steps" ),
10521183 "verbosity_level" : agent_dict .get ("verbosity_level" ),
10531184 "planning_interval" : agent_dict .get ("planning_interval" ),
1185+ "tool_retry_limit" : agent_dict .get ("tool_retry_limit" ),
1186+ "stagnation_window" : agent_dict .get ("stagnation_window" ),
10541187 "name" : agent_dict .get ("name" ),
10551188 "description" : agent_dict .get ("description" ),
10561189 }
0 commit comments