diff --git a/README.md b/README.md
index a1ac4e85..9d818f04 100644
--- a/README.md
+++ b/README.md
@@ -41,17 +41,18 @@ that includes it. Once it reaches the end of its lifespan, the experiment will b
### Active experiments
-| Name | Type | Expected End Date | Dependencies | Cookbook | Discussion |
-|---------------------------------------|--------------------------------|-------------------|--------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|---------------|
-| [`InMemoryChatMessageStore`][1] | Memory Store | December 2024 | None | | [Discuss][4] |
-| [`ChatMessageRetriever`][2] | Memory Component | December 2024 | None | | [Discuss][4] |
-| [`ChatMessageWriter`][3] | Memory Component | December 2024 | None | | [Discuss][4] |
-| [`QueryExpander`][5] | Query Expansion Component | October 2025 | None | None | [Discuss][6] |
-| [`EmbeddingBasedDocumentSplitter`][8] | EmbeddingBasedDocumentSplitter | August 2025 | None | None | [Discuss][7] |
-| [`MultiQueryEmbeddingRetriever`][13] | MultiQueryEmbeddingRetriever | November 2025 | None | None | [Discuss][11] |
-| [`MultiQueryTextRetriever`][14] | MultiQueryTextRetriever | November 2025 | None | None | [Discuss][12] |
-| [`OpenAIChatGenerator`][9] | Chat Generator Component | November 2025 | None | | [Discuss][10] |
-| [`MarkdownHeaderLevelInferrer`][15] | Preprocessor | January 2025 | None | None | [Discuss][16] |
+| Name | Type | Expected End Date | Dependencies | Cookbook | Discussion |
+|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|--------------------------------|-------------------|--------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|---------------|
+| [`InMemoryChatMessageStore`][1] | Memory Store | December 2024 | None | | [Discuss][4] |
+| [`ChatMessageRetriever`][2] | Memory Component | December 2024 | None | | [Discuss][4] |
+| [`ChatMessageWriter`][3] | Memory Component | December 2024 | None | | [Discuss][4] |
+| [`QueryExpander`][5] | Query Expansion Component | October 2025 | None | None | [Discuss][6] |
+| [`EmbeddingBasedDocumentSplitter`][8] | EmbeddingBasedDocumentSplitter | August 2025 | None | None | [Discuss][7] |
+| [`MultiQueryEmbeddingRetriever`][13] | MultiQueryEmbeddingRetriever | November 2025 | None | None | [Discuss][11] |
+| [`MultiQueryTextRetriever`][14] | MultiQueryTextRetriever | November 2025 | None | None | [Discuss][12] |
+| [`OpenAIChatGenerator`][9] | Chat Generator Component | November 2025 | None | | [Discuss][10] |
+| [`MarkdownHeaderLevelInferrer`][15] | Preprocessor | January 2025 | None | None | [Discuss][16] |
+| [`Agent`][17]; [Confirmation Policies][18]; [ConfirmationUIs][19]; [ConfirmationStrategies][20]; [`ConfirmationUIResult` and `ToolExecutionDecision`][21]; [`HITLBreakpointException`][22] | Human in the Loop | December 2025 | rich | None | [Discuss][23] |
[1]: https://github.com/deepset-ai/haystack-experimental/blob/main/haystack_experimental/chat_message_stores/in_memory.py
[2]: https://github.com/deepset-ai/haystack-experimental/blob/main/haystack_experimental/components/retrievers/chat_message_retriever.py
@@ -69,8 +70,13 @@ that includes it. Once it reaches the end of its lifespan, the experiment will b
[14]: https://github.com/deepset-ai/haystack-experimental/blob/main/haystack_experimental/components/retrievers/multi_query_text_retriever.py
[15]: https://github.com/deepset-ai/haystack-experimental/blob/main/haystack_experimental/components/retrievers/md_header_level_inferrer.py
[16]: https://github.com/deepset-ai/haystack-experimental/discussions/376
-
-
+[17]: https://github.com/deepset-ai/haystack-experimental/blob/main/haystack_experimental/components/agents/agent.py
+[18]: https://github.com/deepset-ai/haystack-experimental/blob/main/haystack_experimental/components/agents/human_in_the_loop/policies.py
+[19]: https://github.com/deepset-ai/haystack-experimental/blob/main/haystack_experimental/components/agents/human_in_the_loop/user_interfaces.py
+[20]: https://github.com/deepset-ai/haystack-experimental/blob/main/haystack_experimental/components/agents/human_in_the_loop/strategies.py
+[21]: https://github.com/deepset-ai/haystack-experimental/blob/main/haystack_experimental/components/agents/human_in_the_loop/dataclasses.py
+[22]: https://github.com/deepset-ai/haystack-experimental/blob/main/haystack_experimental/components/agents/human_in_the_loop/errors.py
+[23]: https://github.com/deepset-ai/haystack-experimental/discussions/XXX
### Adopted experiments
| Name | Type | Final release |
diff --git a/docs/pydoc/config/agents_api.yml b/docs/pydoc/config/agents_api.yml
new file mode 100644
index 00000000..9a92549f
--- /dev/null
+++ b/docs/pydoc/config/agents_api.yml
@@ -0,0 +1,36 @@
+loaders:
+ - type: haystack_pydoc_tools.loaders.CustomPythonLoader
+ search_path: [../../../]
+ modules: [
+ "haystack_experimental.components.agents.agent",
+ "haystack_experimental.components.agents.human_in_the_loop.breakpoint",
+ "haystack_experimental.components.agents.human_in_the_loop.dataclasses",
+ "haystack_experimental.components.agents.human_in_the_loop.errors",
+ "haystack_experimental.components.agents.human_in_the_loop.policies",
+ "haystack_experimental.components.agents.human_in_the_loop.strategies",
+ "haystack_experimental.components.agents.human_in_the_loop.types",
+ "haystack_experimental.components.agents.human_in_the_loop.user_interfaces",
+ ]
+ ignore_when_discovered: ["__init__"]
+processors:
+ - type: filter
+ expression:
+ documented_only: true
+ do_not_filter_modules: false
+ skip_empty_modules: true
+ - type: smart
+ - type: crossref
+renderer:
+ type: haystack_pydoc_tools.renderers.ReadmeCoreRenderer
+ excerpt: Tool-using agents with provider-agnostic chat model support.
+ category_slug: haystack-api
+ title: Agents
+ slug: experimental-agents-api
+ order: 2
+ markdown:
+ descriptive_class_title: false
+ classdef_code_block: false
+ descriptive_module_title: true
+ add_method_class_prefix: true
+ add_member_class_prefix: false
+ filename: experimental_agents_api.md
diff --git a/haystack_experimental/components/agents/__init__.py b/haystack_experimental/components/agents/__init__.py
new file mode 100644
index 00000000..4170413e
--- /dev/null
+++ b/haystack_experimental/components/agents/__init__.py
@@ -0,0 +1,16 @@
+# SPDX-FileCopyrightText: 2022-present deepset GmbH
+#
+# SPDX-License-Identifier: Apache-2.0
+
+import sys
+from typing import TYPE_CHECKING
+
+from lazy_imports import LazyImporter
+
+_import_structure = {"agent": ["Agent"]}
+
+if TYPE_CHECKING:
+ from .agent import Agent as Agent
+
+else:
+ sys.modules[__name__] = LazyImporter(name=__name__, module_file=__file__, import_structure=_import_structure)
diff --git a/haystack_experimental/components/agents/agent.py b/haystack_experimental/components/agents/agent.py
new file mode 100644
index 00000000..60f451bc
--- /dev/null
+++ b/haystack_experimental/components/agents/agent.py
@@ -0,0 +1,634 @@
+# SPDX-FileCopyrightText: 2022-present deepset GmbH
+#
+# SPDX-License-Identifier: Apache-2.0
+
+# pylint: disable=wrong-import-order,wrong-import-position,ungrouped-imports
+# ruff: noqa: I001
+
+from dataclasses import dataclass
+from typing import Any, Optional, Union
+
+# Monkey patch Haystack's AgentSnapshot with our extended version
+import haystack.dataclasses.breakpoints as hdb
+from haystack_experimental.dataclasses.breakpoints import AgentSnapshot
+
+hdb.AgentSnapshot = AgentSnapshot # type: ignore[misc]
+
+# Monkey patch Haystack's breakpoint functions with our extended versions
+import haystack.core.pipeline.breakpoint as hs_breakpoint
+import haystack_experimental.core.pipeline.breakpoint as exp_breakpoint
+
+hs_breakpoint._create_agent_snapshot = exp_breakpoint._create_agent_snapshot
+hs_breakpoint._create_pipeline_snapshot_from_tool_invoker = exp_breakpoint._create_pipeline_snapshot_from_tool_invoker # type: ignore[assignment]
+hs_breakpoint._trigger_tool_invoker_breakpoint = exp_breakpoint._trigger_tool_invoker_breakpoint
+
+from haystack import logging
+from haystack.components.agents.agent import Agent as HaystackAgent
+from haystack.components.agents.agent import _ExecutionContext as Haystack_ExecutionContext
+from haystack.components.agents.agent import _schema_from_dict
+from haystack.components.agents.state import replace_values
+from haystack.components.generators.chat.types import ChatGenerator
+from haystack.core.errors import PipelineRuntimeError
+from haystack.core.pipeline import AsyncPipeline, Pipeline
+from haystack.core.pipeline.breakpoint import (
+ _create_pipeline_snapshot_from_chat_generator,
+ _create_pipeline_snapshot_from_tool_invoker,
+)
+from haystack.core.pipeline.utils import _deepcopy_with_exceptions
+from haystack.core.serialization import default_from_dict, import_class_by_name
+from haystack.dataclasses import ChatMessage
+from haystack.dataclasses.breakpoints import AgentBreakpoint, ToolBreakpoint
+from haystack.dataclasses.streaming_chunk import StreamingCallbackT
+from haystack.tools import Tool, Toolset, deserialize_tools_or_toolset_inplace
+from haystack.utils.callable_serialization import deserialize_callable
+from haystack.utils.deserialization import deserialize_chatgenerator_inplace
+
+from haystack_experimental.components.agents.human_in_the_loop import (
+ ConfirmationStrategy,
+ ToolExecutionDecision,
+ HITLBreakpointException,
+)
+from haystack_experimental.components.agents.human_in_the_loop.strategies import _process_confirmation_strategies
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class _ExecutionContext(Haystack_ExecutionContext):
+ """
+ Execution context for the Agent component
+
+ Extends Haystack's _ExecutionContext to include tool execution decisions for human-in-the-loop strategies.
+
+ :param tool_execution_decisions: Optional list of ToolExecutionDecision objects to use instead of prompting
+ the user. This is useful when restarting from a snapshot where tool execution decisions were already made.
+ """
+
+ tool_execution_decisions: Optional[list[ToolExecutionDecision]] = None
+
+
+class Agent(HaystackAgent):
+ """
+ A Haystack component that implements a tool-using agent with provider-agnostic chat model support.
+
+ NOTE: This class extends Haystack's Agent component to add support for human-in-the-loop confirmation strategies.
+
+ The component processes messages and executes tools until an exit condition is met.
+ The exit condition can be triggered either by a direct text response or by invoking a specific designated tool.
+ Multiple exit conditions can be specified.
+
+ When you call an Agent without tools, it acts as a ChatGenerator, produces one response, then exits.
+
+ ### Usage example
+ ```python
+ from haystack.components.generators.chat import OpenAIChatGenerator
+ from haystack.dataclasses import ChatMessage
+ from haystack.tools.tool import Tool
+
+ from haystack_experimental.components.agents import Agent
+ from haystack_experimental.components.agents.human_in_the_loop import (
+        BlockingConfirmationStrategy,
+ AlwaysAskPolicy,
+ NeverAskPolicy,
+ SimpleConsoleUI,
+ )
+
+ calculator_tool = Tool(name="calculator", description="A tool for performing mathematical calculations.", ...)
+ search_tool = Tool(name="search", description="A tool for searching the web.", ...)
+
+ agent = Agent(
+ chat_generator=OpenAIChatGenerator(),
+ tools=[calculator_tool, search_tool],
+ confirmation_strategies={
+            calculator_tool.name: BlockingConfirmationStrategy(
+ confirmation_policy=NeverAskPolicy(), confirmation_ui=SimpleConsoleUI()
+ ),
+            search_tool.name: BlockingConfirmationStrategy(
+ confirmation_policy=AlwaysAskPolicy(), confirmation_ui=SimpleConsoleUI()
+ ),
+ },
+ )
+
+ # Run the agent
+ result = agent.run(
+ messages=[ChatMessage.from_user("Find information about Haystack")]
+ )
+
+ assert "messages" in result # Contains conversation history
+ ```
+ """
+
+ def __init__(
+ self,
+ *,
+ chat_generator: ChatGenerator,
+ tools: Optional[Union[list[Tool], Toolset]] = None,
+ system_prompt: Optional[str] = None,
+ exit_conditions: Optional[list[str]] = None,
+ state_schema: Optional[dict[str, Any]] = None,
+ max_agent_steps: int = 100,
+ streaming_callback: Optional[StreamingCallbackT] = None,
+ raise_on_tool_invocation_failure: bool = False,
+ confirmation_strategies: Optional[dict[str, ConfirmationStrategy]] = None,
+ tool_invoker_kwargs: Optional[dict[str, Any]] = None,
+ ) -> None:
+ """
+ Initialize the agent component.
+
+ :param chat_generator: An instance of the chat generator that your agent should use. It must support tools.
+ :param tools: List of Tool objects or a Toolset that the agent can use.
+ :param system_prompt: System prompt for the agent.
+ :param exit_conditions: List of conditions that will cause the agent to return.
+ Can include "text" if the agent should return when it generates a message without tool calls,
+ or tool names that will cause the agent to return once the tool was executed. Defaults to ["text"].
+ :param state_schema: The schema for the runtime state used by the tools.
+ :param max_agent_steps: Maximum number of steps the agent will run before stopping. Defaults to 100.
+ If the agent exceeds this number of steps, it will stop and return the current state.
+ :param streaming_callback: A callback that will be invoked when a response is streamed from the LLM.
+ The same callback can be configured to emit tool results when a tool is called.
+ :param raise_on_tool_invocation_failure: Should the agent raise an exception when a tool invocation fails?
+ If set to False, the exception will be turned into a chat message and passed to the LLM.
+        :param confirmation_strategies: Optional mapping from tool name to the ConfirmationStrategy used to decide
+            whether and how that tool's invocation must be confirmed (for example, by a human) before execution.
+        :param tool_invoker_kwargs: Additional keyword arguments to pass to the ToolInvoker.
+ :raises TypeError: If the chat_generator does not support tools parameter in its run method.
+ :raises ValueError: If the exit_conditions are not valid.
+ """
+ super(Agent, self).__init__(
+ chat_generator=chat_generator,
+ tools=tools,
+ system_prompt=system_prompt,
+ exit_conditions=exit_conditions,
+ state_schema=state_schema,
+ max_agent_steps=max_agent_steps,
+ streaming_callback=streaming_callback,
+ raise_on_tool_invocation_failure=raise_on_tool_invocation_failure,
+ tool_invoker_kwargs=tool_invoker_kwargs,
+ )
+ self._confirmation_strategies = confirmation_strategies or {}
+
+ def _initialize_fresh_execution(
+ self,
+ messages: list[ChatMessage],
+ streaming_callback: Optional[StreamingCallbackT],
+ requires_async: bool,
+ *,
+ system_prompt: Optional[str] = None,
+ tools: Optional[Union[list[Tool], Toolset, list[str]]] = None,
+ **kwargs: dict[str, Any],
+ ) -> _ExecutionContext:
+ """
+ Initialize execution context for a fresh run of the agent.
+
+ :param messages: List of ChatMessage objects to start the agent with.
+ :param streaming_callback: Optional callback for streaming responses.
+ :param requires_async: Whether the agent run requires asynchronous execution.
+ :param system_prompt: System prompt for the agent. If provided, it overrides the default system prompt.
+ :param tools: Optional list of Tool objects, a Toolset, or list of tool names to use for this run.
+ When passing tool names, tools are selected from the Agent's originally configured tools.
+ :param kwargs: Additional data to pass to the State used by the Agent.
+ """
+ exe_context = super(Agent, self)._initialize_fresh_execution(
+ messages=messages,
+ streaming_callback=streaming_callback,
+ requires_async=requires_async,
+ system_prompt=system_prompt,
+ tools=tools,
+ **kwargs,
+ )
+ # NOTE: 1st difference with parent method to add this to tool_invoker_inputs
+ if self._tool_invoker:
+ exe_context.tool_invoker_inputs["enable_streaming_callback_passthrough"] = (
+ self._tool_invoker.enable_streaming_callback_passthrough
+ )
+ # NOTE: 2nd difference is to use the extended _ExecutionContext
+ return _ExecutionContext(
+ state=exe_context.state,
+ component_visits=exe_context.component_visits,
+ chat_generator_inputs=exe_context.chat_generator_inputs,
+ tool_invoker_inputs=exe_context.tool_invoker_inputs,
+ )
+
+ def _initialize_from_snapshot( # type: ignore[override]
+ self,
+ snapshot: AgentSnapshot,
+ streaming_callback: Optional[StreamingCallbackT],
+ requires_async: bool,
+ *,
+ tools: Optional[Union[list[Tool], Toolset, list[str]]] = None,
+ ) -> _ExecutionContext:
+ """
+ Initialize execution context from an AgentSnapshot.
+
+ :param snapshot: An AgentSnapshot containing the state of a previously saved agent execution.
+ :param streaming_callback: Optional callback for streaming responses.
+ :param requires_async: Whether the agent run requires asynchronous execution.
+ :param tools: Optional list of Tool objects, a Toolset, or list of tool names to use for this run.
+ When passing tool names, tools are selected from the Agent's originally configured tools.
+ """
+ exe_context = super(Agent, self)._initialize_from_snapshot(
+ snapshot=snapshot, streaming_callback=streaming_callback, requires_async=requires_async, tools=tools
+ )
+ # NOTE: 1st difference with parent method to add this to tool_invoker_inputs
+ if self._tool_invoker:
+ exe_context.tool_invoker_inputs["enable_streaming_callback_passthrough"] = (
+ self._tool_invoker.enable_streaming_callback_passthrough
+ )
+ # NOTE: 2nd difference is to use the extended _ExecutionContext and add tool_execution_decisions
+ return _ExecutionContext(
+ state=exe_context.state,
+ component_visits=exe_context.component_visits,
+ chat_generator_inputs=exe_context.chat_generator_inputs,
+ tool_invoker_inputs=exe_context.tool_invoker_inputs,
+ counter=exe_context.counter,
+ skip_chat_generator=exe_context.skip_chat_generator,
+ tool_execution_decisions=snapshot.tool_execution_decisions,
+ )
+
+ def run( # noqa: PLR0915
+ self,
+ messages: list[ChatMessage],
+ streaming_callback: Optional[StreamingCallbackT] = None,
+ *,
+ break_point: Optional[AgentBreakpoint] = None,
+ snapshot: Optional[AgentSnapshot] = None, # type: ignore[override]
+ system_prompt: Optional[str] = None,
+ tools: Optional[Union[list[Tool], Toolset, list[str]]] = None,
+ **kwargs: Any,
+ ) -> dict[str, Any]:
+ """
+ Process messages and execute tools until an exit condition is met.
+
+ :param messages: List of Haystack ChatMessage objects to process.
+ :param streaming_callback: A callback that will be invoked when a response is streamed from the LLM.
+ The same callback can be configured to emit tool results when a tool is called.
+ :param break_point: An AgentBreakpoint, can be a Breakpoint for the "chat_generator" or a ToolBreakpoint
+ for "tool_invoker".
+        :param snapshot: An AgentSnapshot of a previously saved agent execution. The snapshot contains
+ the relevant information to restart the Agent execution from where it left off.
+ :param system_prompt: System prompt for the agent. If provided, it overrides the default system prompt.
+ :param tools: Optional list of Tool objects, a Toolset, or list of tool names to use for this run.
+ When passing tool names, tools are selected from the Agent's originally configured tools.
+ :param kwargs: Additional data to pass to the State schema used by the Agent.
+ The keys must match the schema defined in the Agent's `state_schema`.
+ :returns:
+ A dictionary with the following keys:
+ - "messages": List of all messages exchanged during the agent's run.
+ - "last_message": The last message exchanged during the agent's run.
+ - Any additional keys defined in the `state_schema`.
+ :raises RuntimeError: If the Agent component wasn't warmed up before calling `run()`.
+ :raises BreakpointException: If an agent breakpoint is triggered.
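+
+        Example of resuming from a snapshot (illustrative sketch; ``agent_snapshot`` is assumed to be an
+        AgentSnapshot restored from a previously saved, interrupted run):
+
+        ```python
+        # The conversation state and any tool execution decisions are restored from the snapshot,
+        # so no new user messages are required.
+        result = agent.run(messages=[], snapshot=agent_snapshot)
+        print(result["last_message"].text)
+        ```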
+ """
+ # We pop parent_snapshot from kwargs to avoid passing it into State.
+ parent_snapshot = kwargs.pop("parent_snapshot", None)
+ agent_inputs = {
+ "messages": messages,
+ "streaming_callback": streaming_callback,
+ "break_point": break_point,
+ "snapshot": snapshot,
+ **kwargs,
+ }
+ self._runtime_checks(break_point=break_point, snapshot=snapshot)
+
+ if snapshot:
+ exe_context = self._initialize_from_snapshot(
+ snapshot=snapshot, streaming_callback=streaming_callback, requires_async=False, tools=tools
+ )
+ else:
+ exe_context = self._initialize_fresh_execution(
+ messages=messages,
+ streaming_callback=streaming_callback,
+ requires_async=False,
+ system_prompt=system_prompt,
+ tools=tools,
+ **kwargs,
+ )
+
+ with self._create_agent_span() as span:
+ span.set_content_tag("haystack.agent.input", _deepcopy_with_exceptions(agent_inputs))
+
+ while exe_context.counter < self.max_agent_steps:
+ # Handle breakpoint and ChatGenerator call
+ Agent._check_chat_generator_breakpoint(
+ execution_context=exe_context, break_point=break_point, parent_snapshot=parent_snapshot
+ )
+ # We skip the chat generator when restarting from a snapshot from a ToolBreakpoint
+ if exe_context.skip_chat_generator:
+ llm_messages = exe_context.state.get("messages", [])[-1:]
+ # Set to False so the next iteration will call the chat generator
+ exe_context.skip_chat_generator = False
+ else:
+ try:
+ result = Pipeline._run_component(
+ component_name="chat_generator",
+ component={"instance": self.chat_generator},
+ inputs={
+ "messages": exe_context.state.data["messages"],
+ **exe_context.chat_generator_inputs,
+ },
+ component_visits=exe_context.component_visits,
+ parent_span=span,
+ )
+ except PipelineRuntimeError as e:
+ pipeline_snapshot = _create_pipeline_snapshot_from_chat_generator(
+ agent_name=getattr(self, "__component_name__", None),
+ execution_context=exe_context,
+ parent_snapshot=parent_snapshot,
+ )
+ e.pipeline_snapshot = pipeline_snapshot
+ raise e
+
+ llm_messages = result["replies"]
+ exe_context.state.set("messages", llm_messages)
+
+ # Check if any of the LLM responses contain a tool call or if the LLM is not using tools
+ if not any(msg.tool_call for msg in llm_messages) or self._tool_invoker is None:
+ exe_context.counter += 1
+ break
+
+ # Apply confirmation strategies and update State and messages sent to ToolInvoker
+ try:
+ # Run confirmation strategies to get updated tool call messages and modified chat history
+ modified_tool_call_messages, new_chat_history = _process_confirmation_strategies(
+ confirmation_strategies=self._confirmation_strategies,
+ messages_with_tool_calls=llm_messages,
+ execution_context=exe_context,
+ )
+ # Replace the chat history in state with the modified one
+ exe_context.state.set(key="messages", value=new_chat_history, handler_override=replace_values)
+ except HITLBreakpointException as tbp_error:
+ # We create a break_point to pass into _check_tool_invoker_breakpoint
+ break_point = AgentBreakpoint(
+ agent_name=getattr(self, "__component_name__", ""),
+ break_point=ToolBreakpoint(
+ component_name="tool_invoker",
+ tool_name=tbp_error.tool_name,
+ visit_count=exe_context.component_visits["tool_invoker"],
+ snapshot_file_path=tbp_error.snapshot_file_path,
+ ),
+ )
+
+ # Handle breakpoint
+ Agent._check_tool_invoker_breakpoint(
+ execution_context=exe_context, break_point=break_point, parent_snapshot=parent_snapshot
+ )
+
+ # Run ToolInvoker
+ try:
+ # We only send the messages from the LLM to the tool invoker
+ tool_invoker_result = Pipeline._run_component(
+ component_name="tool_invoker",
+ component={"instance": self._tool_invoker},
+ inputs={
+ "messages": modified_tool_call_messages,
+ "state": exe_context.state,
+ **exe_context.tool_invoker_inputs,
+ },
+ component_visits=exe_context.component_visits,
+ parent_span=span,
+ )
+ except PipelineRuntimeError as e:
+ # Access the original Tool Invoker exception
+ original_error = e.__cause__
+ tool_name = getattr(original_error, "tool_name", None)
+
+ pipeline_snapshot = _create_pipeline_snapshot_from_tool_invoker(
+ tool_name=tool_name,
+ agent_name=getattr(self, "__component_name__", None),
+ execution_context=exe_context,
+ parent_snapshot=parent_snapshot,
+ )
+ e.pipeline_snapshot = pipeline_snapshot
+ raise e
+
+                # Reset the execution context's tool execution decisions after applying them because they should
+                # only be used once, for the current tool calls
+ exe_context.tool_execution_decisions = None
+ tool_messages = tool_invoker_result["tool_messages"]
+ exe_context.state = tool_invoker_result["state"]
+ exe_context.state.set("messages", tool_messages)
+
+ # Check if any LLM message's tool call name matches an exit condition
+ if self.exit_conditions != ["text"] and self._check_exit_conditions(llm_messages, tool_messages):
+ exe_context.counter += 1
+ break
+
+ # Increment the step counter
+ exe_context.counter += 1
+
+ if exe_context.counter >= self.max_agent_steps:
+ logger.warning(
+ "Agent reached maximum agent steps of {max_agent_steps}, stopping.",
+ max_agent_steps=self.max_agent_steps,
+ )
+ span.set_content_tag("haystack.agent.output", exe_context.state.data)
+ span.set_tag("haystack.agent.steps_taken", exe_context.counter)
+
+ result = {**exe_context.state.data}
+ if msgs := result.get("messages"):
+ result["last_message"] = msgs[-1]
+ return result
+
+ async def run_async(
+ self,
+ messages: list[ChatMessage],
+ streaming_callback: Optional[StreamingCallbackT] = None,
+ *,
+ break_point: Optional[AgentBreakpoint] = None,
+ snapshot: Optional[AgentSnapshot] = None, # type: ignore[override]
+ system_prompt: Optional[str] = None,
+ tools: Optional[Union[list[Tool], Toolset, list[str]]] = None,
+ **kwargs: Any,
+ ) -> dict[str, Any]:
+ """
+ Asynchronously process messages and execute tools until the exit condition is met.
+
+ This is the asynchronous version of the `run` method. It follows the same logic but uses
+ asynchronous operations where possible, such as calling the `run_async` method of the ChatGenerator
+ if available.
+
+ :param messages: List of Haystack ChatMessage objects to process.
+ :param streaming_callback: An asynchronous callback that will be invoked when a response is streamed from the
+ LLM. The same callback can be configured to emit tool results when a tool is called.
+ :param break_point: An AgentBreakpoint, can be a Breakpoint for the "chat_generator" or a ToolBreakpoint
+ for "tool_invoker".
+        :param snapshot: An AgentSnapshot of a previously saved agent execution. The snapshot contains
+ the relevant information to restart the Agent execution from where it left off.
+ :param system_prompt: System prompt for the agent. If provided, it overrides the default system prompt.
+ :param tools: Optional list of Tool objects, a Toolset, or list of tool names to use for this run.
+ :param kwargs: Additional data to pass to the State schema used by the Agent.
+ The keys must match the schema defined in the Agent's `state_schema`.
+ :returns:
+ A dictionary with the following keys:
+ - "messages": List of all messages exchanged during the agent's run.
+ - "last_message": The last message exchanged during the agent's run.
+ - Any additional keys defined in the `state_schema`.
+ :raises RuntimeError: If the Agent component wasn't warmed up before calling `run_async()`.
+ :raises BreakpointException: If an agent breakpoint is triggered.
+ """
+ # We pop parent_snapshot from kwargs to avoid passing it into State.
+ parent_snapshot = kwargs.pop("parent_snapshot", None)
+ agent_inputs = {
+ "messages": messages,
+ "streaming_callback": streaming_callback,
+ "break_point": break_point,
+ "snapshot": snapshot,
+ **kwargs,
+ }
+ self._runtime_checks(break_point=break_point, snapshot=snapshot)
+
+ if snapshot:
+ exe_context = self._initialize_from_snapshot(
+ snapshot=snapshot, streaming_callback=streaming_callback, requires_async=True, tools=tools
+ )
+ else:
+ exe_context = self._initialize_fresh_execution(
+ messages=messages,
+ streaming_callback=streaming_callback,
+ requires_async=True,
+ system_prompt=system_prompt,
+ tools=tools,
+ **kwargs,
+ )
+
+ with self._create_agent_span() as span:
+ span.set_content_tag("haystack.agent.input", _deepcopy_with_exceptions(agent_inputs))
+
+ while exe_context.counter < self.max_agent_steps:
+ # Handle breakpoint and ChatGenerator call
+ self._check_chat_generator_breakpoint(
+ execution_context=exe_context, break_point=break_point, parent_snapshot=parent_snapshot
+ )
+ # We skip the chat generator when restarting from a snapshot from a ToolBreakpoint
+ if exe_context.skip_chat_generator:
+ llm_messages = exe_context.state.get("messages", [])[-1:]
+ # Set to False so the next iteration will call the chat generator
+ exe_context.skip_chat_generator = False
+ else:
+ result = await AsyncPipeline._run_component_async(
+ component_name="chat_generator",
+ component={"instance": self.chat_generator},
+ component_inputs={
+ "messages": exe_context.state.data["messages"],
+ **exe_context.chat_generator_inputs,
+ },
+ component_visits=exe_context.component_visits,
+ parent_span=span,
+ )
+ llm_messages = result["replies"]
+ exe_context.state.set("messages", llm_messages)
+
+ # Check if any of the LLM responses contain a tool call or if the LLM is not using tools
+ if not any(msg.tool_call for msg in llm_messages) or self._tool_invoker is None:
+ exe_context.counter += 1
+ break
+
+ # Apply confirmation strategies and update State and messages sent to ToolInvoker
+ try:
+ # Run confirmation strategies to get updated tool call messages and modified chat history
+ modified_tool_call_messages, new_chat_history = _process_confirmation_strategies(
+ confirmation_strategies=self._confirmation_strategies,
+ messages_with_tool_calls=llm_messages,
+ execution_context=exe_context,
+ )
+ # Replace the chat history in state with the modified one
+ exe_context.state.set(key="messages", value=new_chat_history, handler_override=replace_values)
+ except HITLBreakpointException as tbp_error:
+ # We create a break_point to pass into _check_tool_invoker_breakpoint
+ break_point = AgentBreakpoint(
+ agent_name=getattr(self, "__component_name__", ""),
+ break_point=ToolBreakpoint(
+ component_name="tool_invoker",
+ tool_name=tbp_error.tool_name,
+ visit_count=exe_context.component_visits["tool_invoker"],
+ snapshot_file_path=tbp_error.snapshot_file_path,
+ ),
+ )
+
+ # Handle breakpoint
+ Agent._check_tool_invoker_breakpoint(
+ execution_context=exe_context, break_point=break_point, parent_snapshot=parent_snapshot
+ )
+
+ # Run ToolInvoker
+ # We only send the messages from the LLM to the tool invoker
+ tool_invoker_result = await AsyncPipeline._run_component_async(
+ component_name="tool_invoker",
+ component={"instance": self._tool_invoker},
+ component_inputs={
+ "messages": modified_tool_call_messages,
+ "state": exe_context.state,
+ **exe_context.tool_invoker_inputs,
+ },
+ component_visits=exe_context.component_visits,
+ parent_span=span,
+ )
+
+                # Reset the execution context's tool execution decisions after applying them because they should
+                # only be used once, for the current tool calls
+ exe_context.tool_execution_decisions = None
+ tool_messages = tool_invoker_result["tool_messages"]
+ exe_context.state = tool_invoker_result["state"]
+ exe_context.state.set("messages", tool_messages)
+
+ # Check if any LLM message's tool call name matches an exit condition
+ if self.exit_conditions != ["text"] and self._check_exit_conditions(llm_messages, tool_messages):
+ exe_context.counter += 1
+ break
+
+ # Increment the step counter
+ exe_context.counter += 1
+
+ if exe_context.counter >= self.max_agent_steps:
+ logger.warning(
+ "Agent reached maximum agent steps of {max_agent_steps}, stopping.",
+ max_agent_steps=self.max_agent_steps,
+ )
+ span.set_content_tag("haystack.agent.output", exe_context.state.data)
+ span.set_tag("haystack.agent.steps_taken", exe_context.counter)
+
+ result = {**exe_context.state.data}
+ if msgs := result.get("messages"):
+ result["last_message"] = msgs[-1]
+ return result
+
+ def to_dict(self) -> dict[str, Any]:
+ """
+ Serialize the component to a dictionary.
+
+ :return: Dictionary with serialized data
+ """
+ data = super(Agent, self).to_dict()
+ data["init_parameters"]["confirmation_strategies"] = (
+ {name: strategy.to_dict() for name, strategy in self._confirmation_strategies.items()}
+ if self._confirmation_strategies
+ else None
+ )
+ return data
+
+ @classmethod
+ def from_dict(cls, data: dict[str, Any]) -> "Agent":
+ """
+ Deserialize the agent from a dictionary.
+
+ :param data: Dictionary to deserialize from
+ :return: Deserialized agent
+ """
+ init_params = data.get("init_parameters", {})
+
+ deserialize_chatgenerator_inplace(init_params, key="chat_generator")
+
+ if "state_schema" in init_params:
+ init_params["state_schema"] = _schema_from_dict(init_params["state_schema"])
+
+ if init_params.get("streaming_callback") is not None:
+ init_params["streaming_callback"] = deserialize_callable(init_params["streaming_callback"])
+
+ deserialize_tools_or_toolset_inplace(init_params, key="tools")
+
+ if "confirmation_strategies" in init_params and init_params["confirmation_strategies"] is not None:
+ for name, strategy_dict in init_params["confirmation_strategies"].items():
+ strategy_class = import_class_by_name(strategy_dict["type"])
+ if not hasattr(strategy_class, "from_dict"):
+ raise TypeError(f"{strategy_class} does not have from_dict method implemented.")
+ init_params["confirmation_strategies"][name] = strategy_class.from_dict(strategy_dict)
+
+ return default_from_dict(cls, data)
diff --git a/haystack_experimental/components/agents/human_in_the_loop/__init__.py b/haystack_experimental/components/agents/human_in_the_loop/__init__.py
new file mode 100644
index 00000000..ab26024a
--- /dev/null
+++ b/haystack_experimental/components/agents/human_in_the_loop/__init__.py
@@ -0,0 +1,35 @@
+# SPDX-FileCopyrightText: 2022-present deepset GmbH
+#
+# SPDX-License-Identifier: Apache-2.0
+
+import sys
+from typing import TYPE_CHECKING
+
+from lazy_imports import LazyImporter
+
+_import_structure = {
+ "dataclasses": ["ConfirmationUIResult", "ToolExecutionDecision"],
+ "errors": ["HITLBreakpointException"],
+ "policies": ["AlwaysAskPolicy", "NeverAskPolicy", "AskOncePolicy"],
+ "strategies": ["BlockingConfirmationStrategy", "BreakpointConfirmationStrategy"],
+ "types": ["ConfirmationPolicy", "ConfirmationUI", "ConfirmationStrategy"],
+ "user_interfaces": ["RichConsoleUI", "SimpleConsoleUI"],
+}
+
+if TYPE_CHECKING:
+ from .dataclasses import ConfirmationUIResult as ConfirmationUIResult
+ from .dataclasses import ToolExecutionDecision as ToolExecutionDecision
+ from .errors import HITLBreakpointException as HITLBreakpointException
+ from .policies import AlwaysAskPolicy as AlwaysAskPolicy
+ from .policies import AskOncePolicy as AskOncePolicy
+ from .policies import NeverAskPolicy as NeverAskPolicy
+ from .strategies import BlockingConfirmationStrategy as BlockingConfirmationStrategy
+ from .strategies import BreakpointConfirmationStrategy as BreakpointConfirmationStrategy
+ from .types import ConfirmationPolicy as ConfirmationPolicy
+ from .types import ConfirmationStrategy as ConfirmationStrategy
+ from .types import ConfirmationUI as ConfirmationUI
+ from .user_interfaces import RichConsoleUI as RichConsoleUI
+ from .user_interfaces import SimpleConsoleUI as SimpleConsoleUI
+
+else:
+ sys.modules[__name__] = LazyImporter(name=__name__, module_file=__file__, import_structure=_import_structure)
diff --git a/haystack_experimental/components/agents/human_in_the_loop/breakpoint.py b/haystack_experimental/components/agents/human_in_the_loop/breakpoint.py
new file mode 100644
index 00000000..562435b4
--- /dev/null
+++ b/haystack_experimental/components/agents/human_in_the_loop/breakpoint.py
@@ -0,0 +1,63 @@
+# SPDX-FileCopyrightText: 2022-present deepset GmbH
+#
+# SPDX-License-Identifier: Apache-2.0
+
+from haystack.dataclasses.breakpoints import AgentSnapshot, ToolBreakpoint
+from haystack.utils import _deserialize_value_with_schema
+
+from haystack_experimental.components.agents.human_in_the_loop.strategies import _prepare_tool_args
+
+
+def get_tool_calls_and_descriptions_from_snapshot(
+ agent_snapshot: AgentSnapshot, breakpoint_tool_only: bool = True
+) -> tuple[list[dict], dict[str, str]]:
+ """
+ Extract tool calls and tool descriptions from an AgentSnapshot.
+
+ By default, only the tool call that caused the breakpoint is processed and its arguments are reconstructed.
+ This is useful for scenarios where you want to present the relevant tool call and its description
+ to a human for confirmation before execution.
+
+ :param agent_snapshot: The AgentSnapshot from which to extract tool calls and descriptions.
+ :param breakpoint_tool_only: If True, only the tool call that caused the breakpoint is returned. If False, all tool
+ calls are returned.
+ :returns:
+ A tuple containing a list of tool call dictionaries and a dictionary of tool descriptions
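+
+    Example (illustrative sketch; assumes ``agent_snapshot`` is an AgentSnapshot restored from a snapshot saved
+    after a ``BreakpointConfirmationStrategy`` paused the agent):
+
+    ```python
+    tool_calls, descriptions = get_tool_calls_and_descriptions_from_snapshot(agent_snapshot)
+    for tc in tool_calls:
+        # Each entry is a serialized tool call with the fully reconstructed arguments
+        print(tc["tool_name"], descriptions[tc["tool_name"]], tc["arguments"])
+    ```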
+ """
+ break_point = agent_snapshot.break_point.break_point
+ if not isinstance(break_point, ToolBreakpoint):
+ raise ValueError("The provided AgentSnapshot does not contain a ToolBreakpoint.")
+
+ tool_caused_break_point = break_point.tool_name
+
+ # Deserialize the tool invoker inputs from the snapshot
+ tool_invoker_inputs = _deserialize_value_with_schema(agent_snapshot.component_inputs["tool_invoker"])
+ tool_call_messages = tool_invoker_inputs["messages"]
+ state = tool_invoker_inputs["state"]
+ tool_name_to_tool = {t.name: t for t in tool_invoker_inputs["tools"]}
+
+ tool_calls = []
+ for msg in tool_call_messages:
+ if msg.tool_calls:
+ tool_calls.extend(msg.tool_calls)
+ serialized_tcs = [tc.to_dict() for tc in tool_calls]
+
+ # Reconstruct the final arguments for each tool call
+ tool_descriptions = {}
+ updated_tool_calls = []
+ for tc in serialized_tcs:
+ # Only process the tool that caused the breakpoint if breakpoint_tool_only is True
+ if breakpoint_tool_only and tc["tool_name"] != tool_caused_break_point:
+ continue
+
+ final_args = _prepare_tool_args(
+ tool=tool_name_to_tool[tc["tool_name"]],
+ tool_call_arguments=tc["arguments"],
+ state=state,
+ streaming_callback=tool_invoker_inputs.get("streaming_callback", None),
+ enable_streaming_passthrough=tool_invoker_inputs.get("enable_streaming_passthrough", False),
+ )
+ updated_tool_calls.append({**tc, "arguments": final_args})
+ tool_descriptions[tc["tool_name"]] = tool_name_to_tool[tc["tool_name"]].description
+
+ return updated_tool_calls, tool_descriptions
diff --git a/haystack_experimental/components/agents/human_in_the_loop/dataclasses.py b/haystack_experimental/components/agents/human_in_the_loop/dataclasses.py
new file mode 100644
index 00000000..74309819
--- /dev/null
+++ b/haystack_experimental/components/agents/human_in_the_loop/dataclasses.py
@@ -0,0 +1,72 @@
+# SPDX-FileCopyrightText: 2022-present deepset GmbH
+#
+# SPDX-License-Identifier: Apache-2.0
+
+from dataclasses import asdict, dataclass
+from typing import Any, Optional
+
+
+@dataclass
+class ConfirmationUIResult:
+ """
+ Result of the confirmation UI interaction.
+
+ :param action:
+ The action taken by the user such as "confirm", "reject", or "modify".
+ This action type is not enforced to allow for custom actions to be implemented.
+ :param feedback:
+ Optional feedback message from the user. For example, if the user rejects the tool execution,
+ they might provide a reason for the rejection.
+ :param new_tool_params:
+ Optional set of new parameters for the tool. For example, if the user chooses to modify the tool parameters,
+ they can provide a new set of parameters here.
+ """
+
+ action: str # "confirm", "reject", "modify"
+ feedback: Optional[str] = None
+ new_tool_params: Optional[dict[str, Any]] = None
+
+
+@dataclass
+class ToolExecutionDecision:
+ """
+ Decision made regarding tool execution.
+
+ :param tool_name:
+ The name of the tool to be executed.
+ :param execute:
+ A boolean indicating whether to execute the tool with the provided parameters.
+ :param tool_call_id:
+ Optional unique identifier for the tool call. This can be used to track and correlate the decision with a
+ specific tool invocation.
+ :param feedback:
+ Optional feedback message.
+ For example, if the tool execution is rejected, this can contain the reason. Or if the tool parameters were
+ modified, this can contain the modification details.
+ :param final_tool_params:
+ Optional final parameters for the tool if execution is confirmed or modified.
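+
+    Example with illustrative values:
+
+    ```python
+    # Approve execution, possibly with edited parameters
+    approve = ToolExecutionDecision(tool_name="search", execute=True, final_tool_params={"query": "Haystack"})
+
+    # Reject execution and explain why
+    reject = ToolExecutionDecision(tool_name="search", execute=False, feedback="Query looks unsafe.")
+
+    # Decisions survive serialization, e.g. when stored in and restored from an AgentSnapshot
+    assert ToolExecutionDecision.from_dict(approve.to_dict()) == approve
+    ```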
+ """
+
+ tool_name: str
+ execute: bool
+ tool_call_id: Optional[str] = None
+ feedback: Optional[str] = None
+ final_tool_params: Optional[dict[str, Any]] = None
+
+ def to_dict(self) -> dict[str, Any]:
+ """
+ Convert the ToolExecutionDecision to a dictionary representation.
+
+ :return: A dictionary containing the tool execution decision details.
+ """
+ return asdict(self)
+
+ @classmethod
+ def from_dict(cls, data: dict[str, Any]) -> "ToolExecutionDecision":
+ """
+ Populate the ToolExecutionDecision from a dictionary representation.
+
+ :param data: A dictionary containing the tool execution decision details.
+ :return: An instance of ToolExecutionDecision.
+ """
+ return cls(**data)
diff --git a/haystack_experimental/components/agents/human_in_the_loop/errors.py b/haystack_experimental/components/agents/human_in_the_loop/errors.py
new file mode 100644
index 00000000..1ec33161
--- /dev/null
+++ b/haystack_experimental/components/agents/human_in_the_loop/errors.py
@@ -0,0 +1,28 @@
+# SPDX-FileCopyrightText: 2022-present deepset GmbH
+#
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Optional
+
+
+class HITLBreakpointException(Exception):
+ """
+ Exception raised when a tool execution is paused by a ConfirmationStrategy (e.g. BreakpointConfirmationStrategy).
+ """
+
+ def __init__(
+ self, message: str, tool_name: str, snapshot_file_path: str, tool_call_id: Optional[str] = None
+ ) -> None:
+ """
+ Initialize the HITLBreakpointException.
+
+ :param message: The exception message.
+ :param tool_name: The name of the tool whose execution is paused.
+ :param snapshot_file_path: The file path to the saved pipeline snapshot.
+ :param tool_call_id: Optional unique identifier for the tool call. This can be used to track and correlate
+ the decision with a specific tool invocation.
+ """
+ super().__init__(message)
+ self.tool_name = tool_name
+ self.snapshot_file_path = snapshot_file_path
+ self.tool_call_id = tool_call_id
diff --git a/haystack_experimental/components/agents/human_in_the_loop/policies.py b/haystack_experimental/components/agents/human_in_the_loop/policies.py
new file mode 100644
index 00000000..f80d4a54
--- /dev/null
+++ b/haystack_experimental/components/agents/human_in_the_loop/policies.py
@@ -0,0 +1,78 @@
+# SPDX-FileCopyrightText: 2022-present deepset GmbH
+#
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+from haystack_experimental.components.agents.human_in_the_loop.dataclasses import ConfirmationUIResult
+from haystack_experimental.components.agents.human_in_the_loop.types import ConfirmationPolicy
+
+
+class AlwaysAskPolicy(ConfirmationPolicy):
+ """Always ask for confirmation."""
+
+ def should_ask(self, tool_name: str, tool_description: str, tool_params: dict[str, Any]) -> bool:
+ """
+ Always ask for confirmation before executing the tool.
+
+ :param tool_name: The name of the tool to be executed.
+ :param tool_description: The description of the tool.
+ :param tool_params: The parameters to be passed to the tool.
+ :returns: Always returns True, indicating confirmation is needed.
+ """
+ return True
+
+
+class NeverAskPolicy(ConfirmationPolicy):
+ """Never ask for confirmation."""
+
+ def should_ask(self, tool_name: str, tool_description: str, tool_params: dict[str, Any]) -> bool:
+ """
+ Never ask for confirmation, always proceed with tool execution.
+
+ :param tool_name: The name of the tool to be executed.
+ :param tool_description: The description of the tool.
+ :param tool_params: The parameters to be passed to the tool.
+ :returns: Always returns False, indicating no confirmation is needed.
+ """
+ return False
+
+
+class AskOncePolicy(ConfirmationPolicy):
+ """Ask only once per tool with specific parameters."""
+
+ def __init__(self):
+ self._asked_tools = {}
+
+ def should_ask(self, tool_name: str, tool_description: str, tool_params: dict[str, Any]) -> bool:
+ """
+ Ask for confirmation only once per tool with specific parameters.
+
+ :param tool_name: The name of the tool to be executed.
+ :param tool_description: The description of the tool.
+ :param tool_params: The parameters to be passed to the tool.
+ :returns: True if confirmation is needed, False if already asked with the same parameters.
+ """
+ # Don't ask again if we've already asked for this tool with the same parameters
+ return not (tool_name in self._asked_tools and self._asked_tools[tool_name] == tool_params)
+
+ def update_after_confirmation(
+ self,
+ tool_name: str,
+ tool_description: str,
+ tool_params: dict[str, Any],
+ confirmation_result: ConfirmationUIResult,
+ ) -> None:
+ """
+ Store the tool and parameters if the action was "confirm" to avoid asking again.
+
+ This method updates the internal state to remember that the user has already confirmed the execution of the
+ tool with the given parameters.
+
+ :param tool_name: The name of the tool that was executed.
+ :param tool_description: The description of the tool.
+ :param tool_params: The parameters that were passed to the tool.
+ :param confirmation_result: The result from the confirmation UI.
+ """
+ if confirmation_result.action == "confirm":
+ self._asked_tools[tool_name] = tool_params
diff --git a/haystack_experimental/components/agents/human_in_the_loop/strategies.py b/haystack_experimental/components/agents/human_in_the_loop/strategies.py
new file mode 100644
index 00000000..98679dfb
--- /dev/null
+++ b/haystack_experimental/components/agents/human_in_the_loop/strategies.py
@@ -0,0 +1,455 @@
+# SPDX-FileCopyrightText: 2022-present deepset GmbH
+#
+# SPDX-License-Identifier: Apache-2.0
+
+from dataclasses import replace
+from typing import TYPE_CHECKING, Any, Optional
+
+from haystack.components.agents.state import State
+from haystack.components.tools.tool_invoker import ToolInvoker
+from haystack.core.serialization import default_from_dict, default_to_dict, import_class_by_name
+from haystack.dataclasses import ChatMessage, StreamingCallbackT
+from haystack.tools import Tool
+
+from haystack_experimental.components.agents.human_in_the_loop import (
+ ConfirmationPolicy,
+ ConfirmationStrategy,
+ ConfirmationUI,
+ HITLBreakpointException,
+ ToolExecutionDecision,
+)
+
+if TYPE_CHECKING:
+ from haystack_experimental.components.agents.agent import _ExecutionContext
+
+
+_REJECTION_FEEDBACK_TEMPLATE = "Tool execution for '{tool_name}' was rejected by the user."
+_MODIFICATION_FEEDBACK_TEMPLATE = (
+ "The parameters for tool '{tool_name}' were updated by the user to:\n{final_tool_params}"
+)
+
+
+class BlockingConfirmationStrategy:
+ """
+ Confirmation strategy that blocks execution to gather user feedback.
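+
+    Usage sketch (illustrative; assumes a configured OpenAI API key and that ``my_tool`` is an existing Tool):
+
+    ```python
+    from haystack.components.generators.chat import OpenAIChatGenerator
+
+    from haystack_experimental.components.agents import Agent
+    from haystack_experimental.components.agents.human_in_the_loop import (
+        AlwaysAskPolicy,
+        BlockingConfirmationStrategy,
+        SimpleConsoleUI,
+    )
+
+    # Ask for confirmation in the console every time the tool is about to run
+    strategy = BlockingConfirmationStrategy(confirmation_policy=AlwaysAskPolicy(), confirmation_ui=SimpleConsoleUI())
+    agent = Agent(
+        chat_generator=OpenAIChatGenerator(),
+        tools=[my_tool],
+        confirmation_strategies={my_tool.name: strategy},
+    )
+    ```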
+ """
+
+ def __init__(self, confirmation_policy: ConfirmationPolicy, confirmation_ui: ConfirmationUI) -> None:
+ """
+ Initialize the BlockingConfirmationStrategy with a confirmation policy and UI.
+
+ :param confirmation_policy:
+ The confirmation policy to determine when to ask for user confirmation.
+ :param confirmation_ui:
+ The user interface to interact with the user for confirmation.
+ """
+ self.confirmation_policy = confirmation_policy
+ self.confirmation_ui = confirmation_ui
+
+ def run(
+ self, tool_name: str, tool_description: str, tool_params: dict[str, Any], tool_call_id: Optional[str] = None
+ ) -> ToolExecutionDecision:
+ """
+ Run the human-in-the-loop strategy for a given tool and its parameters.
+
+ :param tool_name:
+ The name of the tool to be executed.
+ :param tool_description:
+ The description of the tool.
+ :param tool_params:
+ The parameters to be passed to the tool.
+ :param tool_call_id:
+ Optional unique identifier for the tool call. This can be used to track and correlate the decision with a
+ specific tool invocation.
+
+ :returns:
+ A ToolExecutionDecision indicating whether to execute the tool with the given parameters, or a
+ feedback message if rejected.
+ """
+ # Check if we should ask based on policy
+ if not self.confirmation_policy.should_ask(
+ tool_name=tool_name, tool_description=tool_description, tool_params=tool_params
+ ):
+ return ToolExecutionDecision(
+ tool_name=tool_name, execute=True, tool_call_id=tool_call_id, final_tool_params=tool_params
+ )
+
+ # Get user confirmation through UI
+ confirmation_ui_result = self.confirmation_ui.get_user_confirmation(tool_name, tool_description, tool_params)
+
+ # Pass back the result to the policy for any learning/updating
+ self.confirmation_policy.update_after_confirmation(
+ tool_name, tool_description, tool_params, confirmation_ui_result
+ )
+
+ # Process the confirmation result
+ final_args = {}
+ if confirmation_ui_result.action == "reject":
+ explanation_text = _REJECTION_FEEDBACK_TEMPLATE.format(tool_name=tool_name)
+ if confirmation_ui_result.feedback:
+ explanation_text += f" With feedback: {confirmation_ui_result.feedback}"
+ return ToolExecutionDecision(
+ tool_name=tool_name, execute=False, tool_call_id=tool_call_id, feedback=explanation_text
+ )
+ elif confirmation_ui_result.action == "modify" and confirmation_ui_result.new_tool_params:
+ # Update the tool call params with the new params
+ final_args.update(confirmation_ui_result.new_tool_params)
+ explanation_text = _MODIFICATION_FEEDBACK_TEMPLATE.format(tool_name=tool_name, final_tool_params=final_args)
+ if confirmation_ui_result.feedback:
+ explanation_text += f" With feedback: {confirmation_ui_result.feedback}"
+ return ToolExecutionDecision(
+ tool_name=tool_name,
+ tool_call_id=tool_call_id,
+ execute=True,
+ feedback=explanation_text,
+ final_tool_params=final_args,
+ )
+ else: # action == "confirm"
+ return ToolExecutionDecision(
+ tool_name=tool_name, execute=True, tool_call_id=tool_call_id, final_tool_params=tool_params
+ )
+
+ def to_dict(self) -> dict[str, Any]:
+ """
+ Serializes the BlockingConfirmationStrategy to a dictionary.
+
+ :returns:
+ Dictionary with serialized data.
+ """
+ return default_to_dict(
+ self, confirmation_policy=self.confirmation_policy.to_dict(), confirmation_ui=self.confirmation_ui.to_dict()
+ )
+
+ @classmethod
+ def from_dict(cls, data: dict[str, Any]) -> "BlockingConfirmationStrategy":
+ """
+ Deserializes the BlockingConfirmationStrategy from a dictionary.
+
+ :param data:
+ Dictionary to deserialize from.
+
+ :returns:
+ Deserialized BlockingConfirmationStrategy.
+ """
+ policy_data = data["init_parameters"]["confirmation_policy"]
+ policy_class = import_class_by_name(policy_data["type"])
+ if not hasattr(policy_class, "from_dict"):
+ raise ValueError(f"Class {policy_class} does not implement from_dict method.")
+ ui_data = data["init_parameters"]["confirmation_ui"]
+ ui_class = import_class_by_name(ui_data["type"])
+ if not hasattr(ui_class, "from_dict"):
+ raise ValueError(f"Class {ui_class} does not implement from_dict method.")
+ return cls(confirmation_policy=policy_class.from_dict(policy_data), confirmation_ui=ui_class.from_dict(ui_data))
+
+
+class BreakpointConfirmationStrategy:
+ """
+ Confirmation strategy that raises a tool breakpoint exception to pause execution and gather user feedback.
+
+ This strategy is designed for scenarios where immediate user interaction is not possible.
+ When a tool execution requires confirmation, it raises an `HITLBreakpointException`, which is caught by the Agent.
+    The Agent then serializes its current state, including the tool call details. This information can then be used to
+ notify a user to review and confirm the tool execution.
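+
+    Usage sketch (illustrative; ``chat_generator`` and ``my_tool`` are assumed to be defined elsewhere, and the
+    snapshot directory must be writable):
+
+    ```python
+    from haystack_experimental.components.agents import Agent
+    from haystack_experimental.components.agents.human_in_the_loop import BreakpointConfirmationStrategy
+
+    agent = Agent(
+        chat_generator=chat_generator,
+        tools=[my_tool],
+        confirmation_strategies={my_tool.name: BreakpointConfirmationStrategy(snapshot_file_path="snapshots")},
+    )
+    # When the LLM requests the tool, the agent saves a snapshot under "snapshots" and stops with a breakpoint.
+    # Review the pending tool call, attach your ToolExecutionDecision to the restored AgentSnapshot
+    # (its ``tool_execution_decisions`` field), and resume via ``agent.run(messages=[], snapshot=...)``.
+    ```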
+ """
+
+ def __init__(self, snapshot_file_path: str) -> None:
+ """
+ Initialize the BreakpointConfirmationStrategy.
+
+        :param snapshot_file_path: The path to the directory where the snapshot should be saved.
+ """
+ self.snapshot_file_path = snapshot_file_path
+
+ def run(
+ self, tool_name: str, tool_description: str, tool_params: dict[str, Any], tool_call_id: Optional[str] = None
+ ) -> ToolExecutionDecision:
+ """
+ Run the breakpoint confirmation strategy for a given tool and its parameters.
+
+ :param tool_name:
+ The name of the tool to be executed.
+ :param tool_description:
+ The description of the tool.
+ :param tool_params:
+ The parameters to be passed to the tool.
+ :param tool_call_id:
+ Optional unique identifier for the tool call. This can be used to track and correlate the decision with a
+ specific tool invocation.
+
+ :raises HITLBreakpointException:
+            Always raises an `HITLBreakpointException` to signal that user confirmation is required.
+
+ :returns:
+ This method does not return; it always raises an exception.
+ """
+ raise HITLBreakpointException(
+ message=f"Tool execution for '{tool_name}' requires user confirmation.",
+ tool_name=tool_name,
+ tool_call_id=tool_call_id,
+ snapshot_file_path=self.snapshot_file_path,
+ )
+
+ def to_dict(self) -> dict[str, Any]:
+ """
+ Serializes the BreakpointConfirmationStrategy to a dictionary.
+ """
+ return default_to_dict(self, snapshot_file_path=self.snapshot_file_path)
+
+ @classmethod
+ def from_dict(cls, data: dict[str, Any]) -> "BreakpointConfirmationStrategy":
+ """
+ Deserializes the BreakpointConfirmationStrategy from a dictionary.
+
+ :param data:
+ Dictionary to deserialize from.
+
+ :returns:
+ Deserialized BreakpointConfirmationStrategy.
+ """
+ return default_from_dict(cls, data)
+
+
+def _prepare_tool_args(
+ *,
+ tool: Tool,
+ tool_call_arguments: dict[str, Any],
+ state: State,
+ streaming_callback: Optional[StreamingCallbackT] = None,
+ enable_streaming_passthrough: bool = False,
+) -> dict[str, Any]:
+ """
+ Prepare the final arguments for a tool by injecting state inputs and optionally a streaming callback.
+
+ :param tool:
+ The tool instance to prepare arguments for.
+ :param tool_call_arguments:
+ The initial arguments provided for the tool call.
+ :param state:
+ The current state containing inputs to be injected into the tool arguments.
+ :param streaming_callback:
+ Optional streaming callback to be injected if enabled and applicable.
+ :param enable_streaming_passthrough:
+ Flag indicating whether to inject the streaming callback into the tool arguments.
+
+ :returns:
+ A dictionary of final arguments ready for tool invocation.
+ """
+ # Combine user + state inputs
+ final_args = ToolInvoker._inject_state_args(tool, tool_call_arguments.copy(), state)
+ # Check whether to inject streaming_callback
+ if (
+ enable_streaming_passthrough
+ and streaming_callback is not None
+ and "streaming_callback" not in final_args
+ and "streaming_callback" in ToolInvoker._get_func_params(tool)
+ ):
+ final_args["streaming_callback"] = streaming_callback
+ return final_args
+
+
+def _process_confirmation_strategies(
+ *,
+ confirmation_strategies: dict[str, ConfirmationStrategy],
+ messages_with_tool_calls: list[ChatMessage],
+ execution_context: "_ExecutionContext",
+) -> tuple[list[ChatMessage], list[ChatMessage]]:
+ """
+ Run the confirmation strategies and return modified tool call messages and updated chat history.
+
+ :param confirmation_strategies: Mapping of tool names to their corresponding confirmation strategies
+ :param messages_with_tool_calls: Chat messages containing tool calls
+ :param execution_context: The current execution context of the agent
+ :returns:
+ Tuple of modified messages with confirmed tool calls and updated chat history
+ """
+ # Run confirmation strategies and get tool execution decisions
+ teds = _run_confirmation_strategies(
+ confirmation_strategies=confirmation_strategies,
+ messages_with_tool_calls=messages_with_tool_calls,
+ execution_context=execution_context,
+ )
+
+ # Apply tool execution decisions to messages_with_tool_calls
+ rejection_messages, modified_tool_call_messages = _apply_tool_execution_decisions(
+ tool_call_messages=messages_with_tool_calls,
+ tool_execution_decisions=teds,
+ )
+
+ # Update the chat history with rejection messages and new tool call messages
+ new_chat_history = _update_chat_history(
+ chat_history=execution_context.state.get("messages"),
+ rejection_messages=rejection_messages,
+ tool_call_and_explanation_messages=modified_tool_call_messages,
+ )
+
+ return modified_tool_call_messages, new_chat_history
+
+
+def _run_confirmation_strategies(
+ confirmation_strategies: dict[str, ConfirmationStrategy],
+ messages_with_tool_calls: list[ChatMessage],
+ execution_context: "_ExecutionContext",
+) -> list[ToolExecutionDecision]:
+ """
+ Run confirmation strategies for tool calls in the provided chat messages.
+
+ :param confirmation_strategies: Mapping of tool names to their corresponding confirmation strategies
+ :param messages_with_tool_calls: Messages containing tool calls to process
+ :param execution_context: The current execution context containing state and inputs
+ :returns:
+ A list of ToolExecutionDecision objects representing the decisions made for each tool call.
+ """
+ state = execution_context.state
+ tools_with_names = {tool.name: tool for tool in execution_context.tool_invoker_inputs["tools"]}
+ existing_teds = execution_context.tool_execution_decisions if execution_context.tool_execution_decisions else []
+ existing_teds_by_name = {ted.tool_name: ted for ted in existing_teds if ted.tool_name}
+ existing_teds_by_id = {ted.tool_call_id: ted for ted in existing_teds if ted.tool_call_id}
+
+ teds = []
+ for message in messages_with_tool_calls:
+ if not message.tool_calls:
+ continue
+
+ for tool_call in message.tool_calls:
+ tool_name = tool_call.tool_name
+ tool_to_invoke = tools_with_names[tool_name]
+
+ # Prepare final tool args
+ final_args = _prepare_tool_args(
+ tool=tool_to_invoke,
+ tool_call_arguments=tool_call.arguments,
+ state=state,
+ streaming_callback=execution_context.tool_invoker_inputs.get("streaming_callback"),
+ enable_streaming_passthrough=execution_context.tool_invoker_inputs.get(
+ "enable_streaming_passthrough", False
+ ),
+ )
+
+ # Get tool execution decisions from confirmation strategies
+ # If no confirmation strategy is defined for this tool, proceed with execution
+ if tool_name not in confirmation_strategies:
+ teds.append(
+ ToolExecutionDecision(
+ tool_call_id=tool_call.id,
+ tool_name=tool_name,
+ execute=True,
+ final_tool_params=final_args,
+ )
+ )
+ continue
+
+ # Check if there's already a decision for this tool call in the execution context
+ ted = existing_teds_by_id.get(tool_call.id or "") or existing_teds_by_name.get(tool_name)
+
+ # If not, run the confirmation strategy
+ if not ted:
+ ted = confirmation_strategies[tool_name].run(
+ tool_name=tool_name, tool_description=tool_to_invoke.description, tool_params=final_args
+ )
+ teds.append(ted)
+
+ return teds
+
+
+def _apply_tool_execution_decisions(
+ tool_call_messages: list[ChatMessage], tool_execution_decisions: list[ToolExecutionDecision]
+) -> tuple[list[ChatMessage], list[ChatMessage]]:
+ """
+ Apply the tool execution decisions to the tool call messages.
+
+ :param tool_call_messages: The tool call messages to apply the decisions to.
+ :param tool_execution_decisions: The tool execution decisions to apply.
+ :returns:
+ A tuple containing:
+ - A list of rejection messages for rejected tool calls. These are pairs of tool call and tool call result
+ messages.
+ - A list of tool call messages for confirmed or modified tool calls. If tool parameters were modified,
+ a user message explaining the modification is included before the tool call message.
+ """
+ decision_by_id = {d.tool_call_id: d for d in tool_execution_decisions if d.tool_call_id}
+ decision_by_name = {d.tool_name: d for d in tool_execution_decisions if d.tool_name}
+
+ def make_assistant_message(chat_message, tool_calls):
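+        # Rebuild the assistant message with the given tool calls, preserving text, meta, name, and reasoning.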
+ return ChatMessage.from_assistant(
+ text=chat_message.text,
+ meta=chat_message.meta,
+ name=chat_message.name,
+ tool_calls=tool_calls,
+ reasoning=chat_message.reasoning,
+ )
+
+ new_tool_call_messages = []
+ rejection_messages = []
+
+ for chat_msg in tool_call_messages:
+ new_tool_calls = []
+ for tc in chat_msg.tool_calls or []:
+ ted = decision_by_id.get(tc.id or "") or decision_by_name.get(tc.tool_name)
+ if not ted:
+                # This shouldn't happen; if it does, something went wrong in _run_confirmation_strategies
+ continue
+
+ if not ted.execute:
+ # rejected tool call
+ tool_result_text = ted.feedback or _REJECTION_FEEDBACK_TEMPLATE.format(tool_name=tc.tool_name)
+ rejection_messages.extend(
+ [
+ make_assistant_message(chat_msg, [tc]),
+ ChatMessage.from_tool(tool_result=tool_result_text, origin=tc, error=True),
+ ]
+ )
+ continue
+
+ # Covers confirm and modify cases
+ final_args = ted.final_tool_params or {}
+ if tc.arguments != final_args:
+                # In the modify case we add a user message explaining the modification; otherwise the LLM won't know
+                # why the tool parameters changed and will likely just try to call the tool again with the
+                # original parameters.
+ user_text = ted.feedback or _MODIFICATION_FEEDBACK_TEMPLATE.format(
+ tool_name=tc.tool_name, final_tool_params=final_args
+ )
+ new_tool_call_messages.append(ChatMessage.from_user(text=user_text))
+ new_tool_calls.append(replace(tc, arguments=final_args))
+
+ # Only add the tool call message if there are any tool calls left (i.e. not all were rejected)
+ if new_tool_calls:
+ new_tool_call_messages.append(make_assistant_message(chat_msg, new_tool_calls))
+
+ return rejection_messages, new_tool_call_messages
+
+
+def _update_chat_history(
+ chat_history: list[ChatMessage],
+ rejection_messages: list[ChatMessage],
+ tool_call_and_explanation_messages: list[ChatMessage],
+) -> list[ChatMessage]:
+ """
+ Update the chat history to include rejection messages and tool call messages at the appropriate positions.
+
+ Steps:
+ 1. Identify the last user message and the last tool message in the current chat history.
+ 2. Determine the insertion point as the maximum index of these two messages.
+ 3. Create a new chat history that includes:
+ - All messages up to the insertion point.
+ - Any rejection messages (pairs of tool call and tool call result messages).
+ - Any tool call messages for confirmed or modified tool calls, including user messages explaining modifications.
+
+ :param chat_history: The current chat history.
+ :param rejection_messages: Chat messages to add for rejected tool calls (pairs of tool call and tool call result
+ messages).
+ :param tool_call_and_explanation_messages: Tool call messages for confirmed or modified tool calls, which may
+ include user messages explaining modifications.
+ :returns:
+ The updated chat history.
+ """
+ user_indices = [i for i, message in enumerate(chat_history) if message.is_from("user")]
+ tool_indices = [i for i, message in enumerate(chat_history) if message.is_from("tool")]
+
+ last_user_idx = max(user_indices) if user_indices else -1
+ last_tool_idx = max(tool_indices) if tool_indices else -1
+
+ insertion_point = max(last_user_idx, last_tool_idx)
+
+ new_chat_history = chat_history[: insertion_point + 1] + rejection_messages + tool_call_and_explanation_messages
+ return new_chat_history
diff --git a/haystack_experimental/components/agents/human_in_the_loop/types.py b/haystack_experimental/components/agents/human_in_the_loop/types.py
new file mode 100644
index 00000000..8a0626d3
--- /dev/null
+++ b/haystack_experimental/components/agents/human_in_the_loop/types.py
@@ -0,0 +1,89 @@
+# SPDX-FileCopyrightText: 2022-present deepset GmbH
+#
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any, Optional, Protocol
+
+from haystack.core.serialization import default_from_dict, default_to_dict
+
+from haystack_experimental.components.agents.human_in_the_loop.dataclasses import (
+ ConfirmationUIResult,
+ ToolExecutionDecision,
+)
+
+# Ellipses are needed to define the Protocol, but pylint complains. See https://github.com/pylint-dev/pylint/issues/9319.
+# pylint: disable=unnecessary-ellipsis
+
+
+class ConfirmationUI(Protocol):
+ """Base class for confirmation UIs."""
+
+ def get_user_confirmation(
+ self, tool_name: str, tool_description: str, tool_params: dict[str, Any]
+ ) -> ConfirmationUIResult:
+ """Get user confirmation for tool execution."""
+ ...
+
+ def to_dict(self) -> dict[str, Any]:
+ """Serialize the UI to a dictionary."""
+ return default_to_dict(self)
+
+ @classmethod
+ def from_dict(cls, data: dict[str, Any]) -> "ConfirmationUI":
+ """Deserialize the ConfirmationUI from a dictionary."""
+ return default_from_dict(cls, data)
+
+
+class ConfirmationPolicy(Protocol):
+ """Base class for confirmation policies."""
+
+ def should_ask(self, tool_name: str, tool_description: str, tool_params: dict[str, Any]) -> bool:
+ """Determine whether to ask for confirmation."""
+ ...
+
+ def update_after_confirmation(
+ self,
+ tool_name: str,
+ tool_description: str,
+ tool_params: dict[str, Any],
+ confirmation_result: ConfirmationUIResult,
+ ) -> None:
+ """Update the policy based on the confirmation UI result."""
+ pass
+
+ def to_dict(self) -> dict[str, Any]:
+ """Serialize the policy to a dictionary."""
+ return default_to_dict(self)
+
+ @classmethod
+ def from_dict(cls, data: dict[str, Any]) -> "ConfirmationPolicy":
+ """Deserialize the policy from a dictionary."""
+ return default_from_dict(cls, data)
+
+
+class ConfirmationStrategy(Protocol):
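+    """Protocol for tool execution confirmation strategies."""
+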
+ def run(
+ self, tool_name: str, tool_description: str, tool_params: dict[str, Any], tool_call_id: Optional[str] = None
+ ) -> ToolExecutionDecision:
+ """
+ Run the confirmation strategy for a given tool and its parameters.
+
+ :param tool_name: The name of the tool to be executed.
+ :param tool_description: The description of the tool.
+ :param tool_params: The parameters to be passed to the tool.
+ :param tool_call_id: Optional unique identifier for the tool call. This can be used to track and correlate
+ the decision with a specific tool invocation.
+
+ :returns:
+            A ToolExecutionDecision capturing whether the tool should be executed and, if so, with which parameters.
+ """
+ ...
+
+ def to_dict(self) -> dict[str, Any]:
+ """Serialize the strategy to a dictionary."""
+ ...
+
+ @classmethod
+ def from_dict(cls, data: dict[str, Any]) -> "ConfirmationStrategy":
+ """Deserialize the strategy from a dictionary."""
+ ...
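+
+
+# Example (illustrative sketch only; AutoApproveStrategy is a hypothetical name, not part of this module):
+# a minimal ConfirmationStrategy that auto-approves every tool call with its original parameters.
+#
+# class AutoApproveStrategy:
+#     def run(
+#         self, tool_name: str, tool_description: str, tool_params: dict[str, Any], tool_call_id: Optional[str] = None
+#     ) -> ToolExecutionDecision:
+#         # Always execute the tool with the parameters proposed by the LLM.
+#         return ToolExecutionDecision(
+#             tool_name=tool_name, tool_call_id=tool_call_id, execute=True, final_tool_params=tool_params
+#         )
+#
+#     def to_dict(self) -> dict[str, Any]:
+#         return default_to_dict(self)
+#
+#     @classmethod
+#     def from_dict(cls, data: dict[str, Any]) -> "AutoApproveStrategy":
+#         return default_from_dict(cls, data)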
diff --git a/haystack_experimental/components/agents/human_in_the_loop/user_interfaces.py b/haystack_experimental/components/agents/human_in_the_loop/user_interfaces.py
new file mode 100644
index 00000000..1f6f2430
--- /dev/null
+++ b/haystack_experimental/components/agents/human_in_the_loop/user_interfaces.py
@@ -0,0 +1,209 @@
+# SPDX-FileCopyrightText: 2022-present deepset GmbH
+#
+# SPDX-License-Identifier: Apache-2.0
+
+import json
+from threading import Lock
+from typing import Any, Optional
+
+from haystack.core.serialization import default_to_dict
+from rich.console import Console
+from rich.panel import Panel
+from rich.prompt import Prompt
+
+from haystack_experimental.components.agents.human_in_the_loop.dataclasses import ConfirmationUIResult
+from haystack_experimental.components.agents.human_in_the_loop.types import ConfirmationUI
+
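+# Lock used to serialize console prompts when multiple tools or sub-agents request confirmation concurrently.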
+_ui_lock = Lock()
+
+
+class RichConsoleUI(ConfirmationUI):
+ """Rich console interface for user interaction."""
+
+ def __init__(self, console: Optional[Console] = None):
+ self.console = console or Console()
+
+ def get_user_confirmation(
+ self, tool_name: str, tool_description: str, tool_params: dict[str, Any]
+ ) -> ConfirmationUIResult:
+ """
+ Get user confirmation for tool execution via rich console prompts.
+
+ :param tool_name: The name of the tool to be executed.
+ :param tool_description: The description of the tool.
+ :param tool_params: The parameters to be passed to the tool.
+ :returns: ConfirmationUIResult based on user input.
+ """
+ with _ui_lock:
+ self._display_tool_info(tool_name, tool_description, tool_params)
+            # If invalid input is provided, Prompt.ask will re-prompt
+ choice = Prompt.ask("\nYour choice", choices=["y", "n", "m"], default="y", console=self.console)
+ return self._process_choice(choice, tool_params)
+
+ def _display_tool_info(self, tool_name: str, tool_description: str, tool_params: dict[str, Any]) -> None:
+ """
+ Display tool information and parameters in a rich panel.
+
+ :param tool_name: The name of the tool to be executed.
+ :param tool_description: The description of the tool.
+ :param tool_params: The parameters to be passed to the tool.
+ """
+ lines = [
+ f"[bold yellow]Tool:[/bold yellow] {tool_name}",
+ f"[bold yellow]Description:[/bold yellow] {tool_description}",
+ "\n[bold yellow]Arguments:[/bold yellow]",
+ ]
+
+ if tool_params:
+ for k, v in tool_params.items():
+ lines.append(f"[cyan]{k}:[/cyan] {v}")
+ else:
+ lines.append(" (No arguments)")
+
+ self.console.print(Panel("\n".join(lines), title="🔧 Tool Execution Request", title_align="left"))
+
+ def _process_choice(self, choice: str, tool_params: dict[str, Any]) -> ConfirmationUIResult:
+ """
+ Process the user's choice and return the corresponding ConfirmationUIResult.
+
+ :param choice: The user's choice ('y', 'n', or 'm').
+ :param tool_params: The original tool parameters.
+ :returns:
+ ConfirmationUIResult based on user input.
+ """
+ if choice == "y":
+ return ConfirmationUIResult(action="confirm")
+ elif choice == "m":
+ return self._modify_params(tool_params)
+ else: # reject
+ feedback = Prompt.ask("Feedback message (optional)", default="", console=self.console)
+ return ConfirmationUIResult(action="reject", feedback=feedback or None)
+
+ def _modify_params(self, tool_params: dict[str, Any]) -> ConfirmationUIResult:
+ """
+ Prompt the user to modify tool parameters.
+
+ :param tool_params: The original tool parameters.
+ :returns:
+ ConfirmationUIResult with modified parameters.
+ """
+ new_params: dict[str, Any] = {}
+ for k, v in tool_params.items():
+ # We don't JSON dump strings to avoid users needing to input extra quotes
+ default_val = json.dumps(v) if not isinstance(v, str) else v
+ while True:
+ new_val = Prompt.ask(f"Modify '{k}'", default=default_val, console=self.console)
+ try:
+ if isinstance(v, str):
+ # Always treat input as string
+ new_params[k] = new_val
+ else:
+ # Parse JSON for all non-string types
+ new_params[k] = json.loads(new_val)
+ break
+ except json.JSONDecodeError:
+ self.console.print("[red]❌ Invalid JSON, please try again.[/red]")
+
+ return ConfirmationUIResult(action="modify", new_tool_params=new_params)
+
+ def to_dict(self) -> dict[str, Any]:
+ """
+        Serializes the RichConsoleUI to a dictionary.
+
+ :returns:
+ Dictionary with serialized data.
+ """
+ # Note: Console object is not serializable; we store None
+ return default_to_dict(self, console=None)
+
+
+class SimpleConsoleUI(ConfirmationUI):
+ """Simple console interface using standard input/output."""
+
+ def get_user_confirmation(
+ self, tool_name: str, tool_description: str, tool_params: dict[str, Any]
+ ) -> ConfirmationUIResult:
+ """
+ Get user confirmation for tool execution via simple console prompts.
+
+ :param tool_name: The name of the tool to be executed.
+ :param tool_description: The description of the tool.
+        :param tool_params: The parameters to be passed to the tool.
+        :returns: ConfirmationUIResult based on user input.
+        """
+ with _ui_lock:
+ self._display_tool_info(tool_name, tool_description, tool_params)
+ valid_choices = {"y", "yes", "n", "no", "m", "modify"}
+ while True:
+ choice = input("Confirm execution? (y=confirm / n=reject / m=modify): ").strip().lower()
+ if choice in valid_choices:
+ break
+ print("Invalid input. Please enter 'y', 'n', or 'm'.")
+ return self._process_choice(choice, tool_params)
+
+ def _display_tool_info(self, tool_name: str, tool_description: str, tool_params: dict[str, Any]) -> None:
+ """
+ Display tool information and parameters in the console.
+
+ :param tool_name: The name of the tool to be executed.
+ :param tool_description: The description of the tool.
+ :param tool_params: The parameters to be passed to the tool.
+ """
+ print("\n--- Tool Execution Request ---")
+ print(f"Tool: {tool_name}")
+ print(f"Description: {tool_description}")
+ print("Arguments:")
+ if tool_params:
+ for k, v in tool_params.items():
+ print(f" {k}: {v}")
+ else:
+ print(" (No arguments)")
+ print("-" * 30)
+
+ def _process_choice(self, choice: str, tool_params: dict[str, Any]) -> ConfirmationUIResult:
+ """
+ Process the user's choice and return the corresponding ConfirmationUIResult.
+
+ :param choice: The user's choice ('y', 'n', or 'm').
+ :param tool_params: The original tool parameters.
+ :returns:
+ ConfirmationUIResult based on user input.
+ """
+ if choice in ("y", "yes"):
+ return ConfirmationUIResult(action="confirm")
+ elif choice in ("m", "modify"):
+ return self._modify_params(tool_params)
+ else: # reject
+ feedback = input("Feedback message (optional): ").strip()
+ return ConfirmationUIResult(action="reject", feedback=feedback or None)
+
+ def _modify_params(self, tool_params: dict[str, Any]) -> ConfirmationUIResult:
+ """
+ Prompt the user to modify tool parameters.
+
+ :param tool_params: The original tool parameters.
+ :returns:
+ ConfirmationUIResult with modified parameters.
+ """
+ new_params: dict[str, Any] = {}
+
+ if not tool_params:
+ print("No parameters to modify, skipping modification.")
+ return ConfirmationUIResult(action="modify", new_tool_params=new_params)
+
+ for k, v in tool_params.items():
+ # We don't JSON dump strings to avoid users needing to input extra quotes
+ default_val = json.dumps(v) if not isinstance(v, str) else v
+ while True:
+ new_val = input(f"Modify '{k}' (current: {default_val}): ").strip() or default_val
+ try:
+ if isinstance(v, str):
+ # Always treat input as string
+ new_params[k] = new_val
+ else:
+ # Parse JSON for all non-string types
+ new_params[k] = json.loads(new_val)
+ break
+ except json.JSONDecodeError:
+ print("❌ Invalid JSON, please try again.")
+
+ return ConfirmationUIResult(action="modify", new_tool_params=new_params)
diff --git a/haystack_experimental/core/__init__.py b/haystack_experimental/core/__init__.py
new file mode 100644
index 00000000..c1764a6e
--- /dev/null
+++ b/haystack_experimental/core/__init__.py
@@ -0,0 +1,3 @@
+# SPDX-FileCopyrightText: 2022-present deepset GmbH
+#
+# SPDX-License-Identifier: Apache-2.0
diff --git a/haystack_experimental/core/pipeline/__init__.py b/haystack_experimental/core/pipeline/__init__.py
new file mode 100644
index 00000000..c1764a6e
--- /dev/null
+++ b/haystack_experimental/core/pipeline/__init__.py
@@ -0,0 +1,3 @@
+# SPDX-FileCopyrightText: 2022-present deepset GmbH
+#
+# SPDX-License-Identifier: Apache-2.0
diff --git a/haystack_experimental/core/pipeline/breakpoint.py b/haystack_experimental/core/pipeline/breakpoint.py
new file mode 100644
index 00000000..da782ebb
--- /dev/null
+++ b/haystack_experimental/core/pipeline/breakpoint.py
@@ -0,0 +1,174 @@
+# SPDX-FileCopyrightText: 2022-present deepset GmbH
+#
+# SPDX-License-Identifier: Apache-2.0
+
+from copy import deepcopy
+from dataclasses import replace
+from datetime import datetime
+from typing import TYPE_CHECKING, Any, Optional
+
+from haystack import logging
+from haystack.core.errors import BreakpointException
+from haystack.core.pipeline.breakpoint import _save_pipeline_snapshot
+from haystack.dataclasses import ChatMessage
+from haystack.dataclasses.breakpoints import AgentBreakpoint, PipelineSnapshot, PipelineState, ToolBreakpoint
+from haystack.utils.base_serialization import _serialize_value_with_schema
+from haystack.utils.misc import _get_output_dir
+
+from haystack_experimental.dataclasses.breakpoints import AgentSnapshot
+
+if TYPE_CHECKING:
+ from haystack_experimental.components.agents.agent import _ExecutionContext
+ from haystack_experimental.components.agents.human_in_the_loop import ToolExecutionDecision
+
+
+logger = logging.getLogger(__name__)
+
+
+def _create_agent_snapshot(
+ *,
+ component_visits: dict[str, int],
+ agent_breakpoint: AgentBreakpoint,
+ component_inputs: dict[str, Any],
+ tool_execution_decisions: Optional[list["ToolExecutionDecision"]] = None,
+) -> AgentSnapshot:
+ """
+ Create a snapshot of the agent's state.
+
+    NOTE: The only difference from Haystack's native implementation is the addition of tool_execution_decisions to
+    the AgentSnapshot.
+
+ :param component_visits: The visit counts for the agent's components.
+ :param agent_breakpoint: AgentBreakpoint object containing breakpoints
+ :param component_inputs: The inputs to the agent's components.
+ :param tool_execution_decisions: Optional list of ToolExecutionDecision objects representing decisions made
+ regarding tool executions.
+ :return: An AgentSnapshot containing the agent's state and component visits.
+ """
+ return AgentSnapshot(
+ component_inputs={
+ "chat_generator": _serialize_value_with_schema(deepcopy(component_inputs["chat_generator"])),
+ "tool_invoker": _serialize_value_with_schema(deepcopy(component_inputs["tool_invoker"])),
+ },
+ component_visits=component_visits,
+ break_point=agent_breakpoint,
+ timestamp=datetime.now(),
+ tool_execution_decisions=tool_execution_decisions,
+ )
+
+
+def _create_pipeline_snapshot_from_tool_invoker(
+ *,
+ execution_context: "_ExecutionContext",
+ tool_name: Optional[str] = None,
+ agent_name: Optional[str] = None,
+ break_point: Optional[AgentBreakpoint] = None,
+ parent_snapshot: Optional[PipelineSnapshot] = None,
+) -> PipelineSnapshot:
+ """
+ Create a pipeline snapshot when a tool invoker breakpoint is raised or an exception during execution occurs.
+
+ :param execution_context: The current execution context of the agent.
+ :param tool_name: The name of the tool that triggered the breakpoint, if available.
+ :param agent_name: The name of the agent component if present in a pipeline.
+ :param break_point: An optional AgentBreakpoint object. If provided, it will be used instead of creating a new one.
+        A new breakpoint is created, for example, when an exception occurs during tool execution and we want to
+        capture the state at that point.
+ :param parent_snapshot: An optional parent PipelineSnapshot to build upon.
+ :returns:
+ A PipelineSnapshot containing the state of the pipeline and agent at the point of the breakpoint or exception.
+ """
+ if break_point is None:
+ agent_breakpoint = AgentBreakpoint(
+ agent_name=agent_name or "agent",
+ break_point=ToolBreakpoint(
+ component_name="tool_invoker",
+ visit_count=execution_context.component_visits["tool_invoker"],
+ tool_name=tool_name,
+ snapshot_file_path=_get_output_dir("pipeline_snapshot"),
+ ),
+ )
+ else:
+ agent_breakpoint = break_point
+
+ messages = execution_context.state.data["messages"]
+ agent_snapshot = _create_agent_snapshot(
+ component_visits=execution_context.component_visits,
+ agent_breakpoint=agent_breakpoint,
+ component_inputs={
+ "chat_generator": {"messages": messages[:-1], **execution_context.chat_generator_inputs},
+ "tool_invoker": {
+ "messages": messages[-1:], # tool invoker consumes last msg from the chat_generator, contains tool call
+ "state": execution_context.state,
+ **execution_context.tool_invoker_inputs,
+ },
+ },
+ tool_execution_decisions=execution_context.tool_execution_decisions,
+ )
+ if parent_snapshot is None:
+ # Create an empty pipeline snapshot if no parent snapshot is provided
+ final_snapshot = PipelineSnapshot(
+ pipeline_state=PipelineState(inputs={}, component_visits={}, pipeline_outputs={}),
+ timestamp=agent_snapshot.timestamp,
+ break_point=agent_snapshot.break_point,
+ agent_snapshot=agent_snapshot,
+ original_input_data={},
+ ordered_component_names=[],
+ include_outputs_from=set(),
+ )
+ else:
+ final_snapshot = replace(parent_snapshot, agent_snapshot=agent_snapshot)
+
+ return final_snapshot
+
+
+def _trigger_tool_invoker_breakpoint(*, llm_messages: list[ChatMessage], pipeline_snapshot: PipelineSnapshot) -> None:
+ """
+ Check if a tool call breakpoint should be triggered before executing the tool invoker.
+
+    NOTE: The only difference from Haystack's native implementation is that it includes the fix from
+    PR https://github.com/deepset-ai/haystack/pull/9853, which makes sure to check all tool calls in a chat message
+    when deciding whether a BreakpointException should be raised.
+
+ :param llm_messages: List of ChatMessage objects containing potential tool calls.
+ :param pipeline_snapshot: PipelineSnapshot object containing the state of the pipeline and Agent snapshot.
+ :raises BreakpointException: If the breakpoint is triggered, indicating a breakpoint has been reached for a tool
+ call.
+ """
+ if not pipeline_snapshot.agent_snapshot:
+ raise ValueError("PipelineSnapshot must contain an AgentSnapshot to trigger a tool call breakpoint.")
+
+ if not isinstance(pipeline_snapshot.agent_snapshot.break_point.break_point, ToolBreakpoint):
+ return
+
+ tool_breakpoint = pipeline_snapshot.agent_snapshot.break_point.break_point
+
+ # Check if we should break for this specific tool or all tools
+ if tool_breakpoint.tool_name is None:
+ # Break for any tool call
+ should_break = any(msg.tool_call for msg in llm_messages)
+ else:
+ # Break only for the specific tool
+ should_break = any(
+ tc.tool_name == tool_breakpoint.tool_name for msg in llm_messages for tc in msg.tool_calls or []
+ )
+
+ if not should_break:
+ return # No breakpoint triggered
+
+ _save_pipeline_snapshot(pipeline_snapshot=pipeline_snapshot)
+
+ msg = (
+ f"Breaking at {tool_breakpoint.component_name} visit count "
+ f"{pipeline_snapshot.agent_snapshot.component_visits[tool_breakpoint.component_name]}"
+ )
+ if tool_breakpoint.tool_name:
+ msg += f" for tool {tool_breakpoint.tool_name}"
+ logger.info(msg)
+
+ raise BreakpointException(
+ message=msg,
+ component=tool_breakpoint.component_name,
+ inputs=pipeline_snapshot.agent_snapshot.component_inputs,
+ results=pipeline_snapshot.agent_snapshot.component_inputs["tool_invoker"]["serialized_data"]["state"],
+ )
diff --git a/haystack_experimental/dataclasses/__init__.py b/haystack_experimental/dataclasses/__init__.py
new file mode 100644
index 00000000..c1764a6e
--- /dev/null
+++ b/haystack_experimental/dataclasses/__init__.py
@@ -0,0 +1,3 @@
+# SPDX-FileCopyrightText: 2022-present deepset GmbH
+#
+# SPDX-License-Identifier: Apache-2.0
diff --git a/haystack_experimental/dataclasses/breakpoints.py b/haystack_experimental/dataclasses/breakpoints.py
new file mode 100644
index 00000000..880badc4
--- /dev/null
+++ b/haystack_experimental/dataclasses/breakpoints.py
@@ -0,0 +1,53 @@
+# SPDX-FileCopyrightText: 2022-present deepset GmbH
+#
+# SPDX-License-Identifier: Apache-2.0
+
+from dataclasses import dataclass
+from datetime import datetime
+from typing import Any, Optional
+
+from haystack.dataclasses.breakpoints import AgentBreakpoint
+from haystack.dataclasses.breakpoints import AgentSnapshot as HaystackAgentSnapshot
+
+from haystack_experimental.components.agents.human_in_the_loop.dataclasses import ToolExecutionDecision
+
+
+@dataclass
+class AgentSnapshot(HaystackAgentSnapshot):
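+    """
+    Extension of Haystack's AgentSnapshot that also stores the tool execution decisions made so far.
+    """
+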
+ tool_execution_decisions: Optional[list[ToolExecutionDecision]] = None
+
+ def to_dict(self) -> dict[str, Any]:
+ """
+ Convert the AgentSnapshot to a dictionary representation.
+
+ :return: A dictionary containing the agent state, timestamp, and breakpoint.
+ """
+ return {
+ "component_inputs": self.component_inputs,
+ "component_visits": self.component_visits,
+ "break_point": self.break_point.to_dict(),
+ "timestamp": self.timestamp.isoformat() if self.timestamp else None,
+ "tool_execution_decisions": [ted.to_dict() for ted in self.tool_execution_decisions]
+ if self.tool_execution_decisions
+ else None,
+ }
+
+ @classmethod
+ def from_dict(cls, data: dict) -> "AgentSnapshot":
+ """
+ Populate the AgentSnapshot from a dictionary representation.
+
+ :param data: A dictionary containing the agent state, timestamp, and breakpoint.
+ :return: An instance of AgentSnapshot.
+ """
+ return cls(
+ component_inputs=data["component_inputs"],
+ component_visits=data["component_visits"],
+ break_point=AgentBreakpoint.from_dict(data["break_point"]),
+ timestamp=datetime.fromisoformat(data["timestamp"]) if data.get("timestamp") else None,
+ tool_execution_decisions=[
+ ToolExecutionDecision.from_dict(ted) for ted in data.get("tool_execution_decisions", [])
+ ]
+ if data.get("tool_execution_decisions")
+ else None,
+ )
diff --git a/hitl_breakpoint_example.py b/hitl_breakpoint_example.py
new file mode 100644
index 00000000..0b7a5cbe
--- /dev/null
+++ b/hitl_breakpoint_example.py
@@ -0,0 +1,186 @@
+# SPDX-FileCopyrightText: 2022-present deepset GmbH
+#
+# SPDX-License-Identifier: Apache-2.0
+
+import os
+from pathlib import Path
+from typing import Any, Optional
+
+from haystack.components.generators.chat import OpenAIChatGenerator
+from haystack.core.errors import BreakpointException
+from haystack.core.pipeline.breakpoint import load_pipeline_snapshot
+from haystack.dataclasses import ChatMessage
+from haystack.dataclasses.breakpoints import PipelineSnapshot
+from haystack.tools import create_tool_from_function
+from rich.console import Console
+
+from haystack_experimental.components.agents.agent import Agent
+from haystack_experimental.components.agents.human_in_the_loop import (
+ AlwaysAskPolicy,
+ BlockingConfirmationStrategy,
+ BreakpointConfirmationStrategy,
+ RichConsoleUI,
+ ToolExecutionDecision,
+)
+from haystack_experimental.components.agents.human_in_the_loop.breakpoint import (
+ get_tool_calls_and_descriptions_from_snapshot,
+)
+
+
+def get_bank_balance(account_id: str) -> str:
+ """
+ Simulate fetching a bank balance for a given account ID.
+
+ :param account_id: The ID of the bank account.
+ :returns:
+ A string representing the bank balance.
+ """
+ return f"Balance for account {account_id} is $1,234.56"
+
+
+def addition(a: float, b: float) -> float:
+ """
+ A simple addition function.
+
+ :param a: First float.
+ :param b: Second float.
+ :returns:
+ Sum of a and b.
+ """
+ return a + b
+
+
+def get_latest_snapshot(snapshot_file_path: str) -> PipelineSnapshot:
+ """
+    Load the most recent pipeline snapshot from the given snapshot directory.
+ """
+ snapshot_dir = Path(snapshot_file_path)
+ possible_snapshots = [snapshot_dir / f for f in os.listdir(snapshot_dir)]
+ latest_snapshot_file = str(max(possible_snapshots, key=os.path.getctime))
+ return load_pipeline_snapshot(latest_snapshot_file)
+
+
+def frontend_simulate_tool_decision(
+ tool_calls: list[dict[str, Any]], tool_descriptions: dict[str, str], console: Console
+) -> list[dict]:
+ """
+ Simulate front-end receiving tool calls, prompting user, and sending back decisions.
+
+ :param tool_calls:
+ A list of tool call dictionaries containing tool_name, id, and arguments.
+ :param tool_descriptions:
+ A dictionary mapping tool names to their descriptions.
+ :param console:
+ A Rich Console instance for displaying prompts and messages.
+ :returns:
+ A list of serialized ToolExecutionDecision dictionaries.
+ """
+
+ confirmation_strategy = BlockingConfirmationStrategy(
+ confirmation_policy=AlwaysAskPolicy(),
+ confirmation_ui=RichConsoleUI(console=console),
+ )
+
+ tool_execution_decisions = []
+ for tc in tool_calls:
+ tool_execution_decisions.append(
+ confirmation_strategy.run(
+ tool_name=tc["tool_name"],
+ tool_description=tool_descriptions[tc["tool_name"]],
+ tool_call_id=tc["id"],
+ tool_params=tc["arguments"],
+ )
+ )
+ return [ted.to_dict() for ted in tool_execution_decisions]
+
+
+def run_agent(
+ agent: Agent,
+ messages: list[ChatMessage],
+ console: Console,
+ snapshot_file_path: Optional[str] = None,
+ tool_execution_decisions: Optional[list[dict[str, Any]]] = None,
+) -> Optional[dict[str, Any]]:
+ """
+ Run the agent with the given messages and optional snapshot.
+ """
+ # Load the latest snapshot if a path is provided
+ snapshot = None
+ if snapshot_file_path:
+ snapshot = get_latest_snapshot(snapshot_file_path=snapshot_file_path)
+
+ # Add any new tool execution decisions to the snapshot
+ if tool_execution_decisions:
+ teds = [ToolExecutionDecision.from_dict(ted) for ted in tool_execution_decisions]
+ existing_decisions = snapshot.agent_snapshot.tool_execution_decisions or []
+ snapshot.agent_snapshot.tool_execution_decisions = existing_decisions + teds
+
+ try:
+ return agent.run(messages=messages, snapshot=snapshot.agent_snapshot if snapshot else None)
+ except BreakpointException as e:
+ console.print("[bold red]Execution paused by Breakpoint Confirmation Strategy:[/bold red]", str(e))
+ return None
+
+
+def main(user_message: str):
+ """
+ Main function to demonstrate the Breakpoint Confirmation Strategy with an agent.
+ """
+ cons = Console()
+ cons.print("\n[bold blue]=== Breakpoint Confirmation Strategy Example ===[/bold blue]\n")
+ cons.print(f"[bold yellow]User Message:[/bold yellow] {user_message}\n")
+
+ # Define agent with both tools and breakpoint confirmation strategies
+ addition_tool = create_tool_from_function(
+ function=addition,
+ name="addition",
+ description="Add two floats together.",
+ )
+ balance_tool = create_tool_from_function(
+ function=get_bank_balance,
+ name="get_bank_balance",
+ description="Get the bank balance for a given account ID.",
+ )
+ snapshot_fp = "pipeline_snapshots"
+ bank_agent = Agent(
+ chat_generator=OpenAIChatGenerator(model="gpt-4.1"),
+ tools=[balance_tool, addition_tool],
+ system_prompt="You are a helpful financial assistant. Use the provided tool to get bank balances when needed.",
+ confirmation_strategies={
+ balance_tool.name: BreakpointConfirmationStrategy(snapshot_file_path=snapshot_fp),
+ addition_tool.name: BreakpointConfirmationStrategy(snapshot_file_path=snapshot_fp),
+ },
+ )
+
+ # Step 1: Initial run
+ result = run_agent(bank_agent, [ChatMessage.from_user(user_message)], cons)
+
+    # Step 2: Loop to handle the breakpoint confirmation strategy until the agent completes
+ while result is None:
+ # Load the latest snapshot from disk and prep data for front-end
+ loaded_snapshot = get_latest_snapshot(snapshot_file_path=snapshot_fp)
+ serialized_tool_calls, tool_descripts = get_tool_calls_and_descriptions_from_snapshot(
+ agent_snapshot=loaded_snapshot.agent_snapshot, breakpoint_tool_only=True
+ )
+
+ # Simulate front-end interaction
+ serialized_teds = frontend_simulate_tool_decision(serialized_tool_calls, tool_descripts, cons)
+
+ # Re-run the agent with the new tool execution decisions
+ result = run_agent(bank_agent, [], cons, snapshot_fp, serialized_teds)
+
+ # Step 3: Final result
+ last_message = result["last_message"]
+ cons.print(f"\n[bold green]Agent Result:[/bold green] {last_message.text}")
+
+
+if __name__ == "__main__":
+ for usr_msg in [
+ # Single tool call question --> Works
+ "What's the balance of account 56789?",
+ # Two tool call question --> Works
+ "What's the balance of account 56789 and what is 5.5 + 3.2?",
+ # Multiple sequential tool calls question --> Works
+ "What's the balance of account 56789? If it's lower than $2000, what's the balance of account 12345?",
+ ]:
+ main(usr_msg)
diff --git a/hitl_intro_example.py b/hitl_intro_example.py
new file mode 100644
index 00000000..11985036
--- /dev/null
+++ b/hitl_intro_example.py
@@ -0,0 +1,114 @@
+# SPDX-FileCopyrightText: 2022-present deepset GmbH
+#
+# SPDX-License-Identifier: Apache-2.0
+
+from haystack.components.generators.chat import OpenAIChatGenerator
+from haystack.dataclasses import ChatMessage
+from haystack.tools import create_tool_from_function
+from rich.console import Console
+
+from haystack_experimental.components.agents.agent import Agent
+from haystack_experimental.components.agents.human_in_the_loop import (
+ AlwaysAskPolicy,
+ AskOncePolicy,
+ BlockingConfirmationStrategy,
+ NeverAskPolicy,
+ RichConsoleUI,
+ SimpleConsoleUI,
+)
+
+
+def addition(a: float, b: float) -> float:
+ """
+ A simple addition function.
+
+ :param a: First float.
+ :param b: Second float.
+ :returns:
+ Sum of a and b.
+ """
+ return a + b
+
+
+addition_tool = create_tool_from_function(
+ function=addition,
+ name="addition",
+ description="Add two floats together.",
+)
+
+
+def get_bank_balance(account_id: str) -> str:
+ """
+ Simulate fetching a bank balance for a given account ID.
+
+ :param account_id: The ID of the bank account.
+ :returns:
+ A string representing the bank balance.
+ """
+ return f"Balance for account {account_id} is $1,234.56"
+
+
+balance_tool = create_tool_from_function(
+ function=get_bank_balance,
+ name="get_bank_balance",
+ description="Get the bank balance for a given account ID.",
+)
+
+
+def get_phone_number(name: str) -> str:
+ """
+ Simulate fetching a phone number for a given name.
+
+ :param name: The name of the person.
+ :returns:
+ A string representing the phone number.
+ """
+ return f"The phone number for {name} is (123) 456-7890"
+
+
+phone_tool = create_tool_from_function(
+ function=get_phone_number,
+ name="get_phone_number",
+ description="Get the phone number for a given name.",
+)
+
+# Define shared console
+cons = Console()
+
+# Define Main Agent with multiple tools and different confirmation strategies
+agent = Agent(
+ chat_generator=OpenAIChatGenerator(model="gpt-4.1"),
+ tools=[balance_tool, addition_tool, phone_tool],
+ system_prompt="You are a helpful financial assistant. Use the provided tool to get bank balances when needed.",
+ confirmation_strategies={
+ balance_tool.name: BlockingConfirmationStrategy(
+ confirmation_policy=AlwaysAskPolicy(), confirmation_ui=RichConsoleUI(console=cons)
+ ),
+ addition_tool.name: BlockingConfirmationStrategy(
+ confirmation_policy=NeverAskPolicy(), confirmation_ui=SimpleConsoleUI()
+ ),
+ phone_tool.name: BlockingConfirmationStrategy(
+ confirmation_policy=AskOncePolicy(), confirmation_ui=SimpleConsoleUI()
+ ),
+ },
+)
+
+# Call bank tool with confirmation (Always Ask) using RichConsoleUI
+result = agent.run([ChatMessage.from_user("What's the balance of account 56789?")])
+last_message = result["last_message"]
+cons.print(f"\n[bold green]Agent Result:[/bold green] {last_message.text}")
+
+# Call addition tool with confirmation (Never Ask)
+result = agent.run([ChatMessage.from_user("What is 5.5 + 3.2?")])
+last_message = result["last_message"]
+print(f"\nAgent Result: {last_message.text}")
+
+# Call phone tool with confirmation (Ask Once) using SimpleConsoleUI
+result = agent.run([ChatMessage.from_user("What is the phone number of Alice?")])
+last_message = result["last_message"]
+print(f"\nAgent Result: {last_message.text}")
+
+# Call phone tool again to see that it doesn't ask for confirmation the second time
+result = agent.run([ChatMessage.from_user("What is the phone number of Alice?")])
+last_message = result["last_message"]
+print(f"\nAgent Result: {last_message.text}")
diff --git a/hitl_multi_agent_example.py b/hitl_multi_agent_example.py
new file mode 100644
index 00000000..ce352d0d
--- /dev/null
+++ b/hitl_multi_agent_example.py
@@ -0,0 +1,119 @@
+# SPDX-FileCopyrightText: 2022-present deepset GmbH
+#
+# SPDX-License-Identifier: Apache-2.0
+
+from haystack.components.generators.chat import OpenAIChatGenerator
+from haystack.dataclasses import ChatMessage
+from haystack.tools import ComponentTool, create_tool_from_function
+from rich.console import Console
+
+from haystack_experimental.components.agents.agent import Agent
+from haystack_experimental.components.agents.human_in_the_loop import (
+ AlwaysAskPolicy,
+ BlockingConfirmationStrategy,
+ RichConsoleUI,
+)
+
+
+def addition(a: float, b: float) -> float:
+ """
+ A simple addition function.
+
+ :param a: First float.
+ :param b: Second float.
+ :returns:
+ Sum of a and b.
+ """
+ return a + b
+
+
+addition_tool = create_tool_from_function(
+ function=addition,
+ name="addition",
+ description="Add two floats together.",
+)
+
+
+def get_bank_balance(account_id: str) -> str:
+ """
+ Simulate fetching a bank balance for a given account ID.
+
+ :param account_id: The ID of the bank account.
+ :returns:
+ A string representing the bank balance.
+ """
+ return f"Balance for account {account_id} is $1,234.56"
+
+
+balance_tool = create_tool_from_function(
+ function=get_bank_balance,
+ name="get_bank_balance",
+ description="Get the bank balance for a given account ID.",
+)
+
+# Define shared console for all UIs
+cons = Console()
+
+# Define Bank Sub-Agent
+bank_agent = Agent(
+ chat_generator=OpenAIChatGenerator(model="gpt-4.1"),
+ tools=[balance_tool],
+ system_prompt="You are a helpful financial assistant. Use the provided tool to get bank balances when needed.",
+ confirmation_strategies={
+ balance_tool.name: BlockingConfirmationStrategy(
+ confirmation_policy=AlwaysAskPolicy(), confirmation_ui=RichConsoleUI(console=cons)
+ ),
+ },
+)
+bank_agent_tool = ComponentTool(
+ component=bank_agent,
+ name="bank_agent_tool",
+ description="A bank agent that can get bank balances.",
+)
+
+# Define Math Sub-Agent
+math_agent = Agent(
+ chat_generator=OpenAIChatGenerator(model="gpt-4.1"),
+ tools=[addition_tool],
+ system_prompt="You are a helpful math assistant. Use the provided tool to perform addition when needed.",
+ confirmation_strategies={
+ addition_tool.name: BlockingConfirmationStrategy(
+ # We use AlwaysAskPolicy here for demonstration; in real scenarios, you might choose NeverAskPolicy
+ confirmation_policy=AlwaysAskPolicy(),
+ confirmation_ui=RichConsoleUI(console=cons),
+ ),
+ },
+)
+math_agent_tool = ComponentTool(
+ component=math_agent,
+ name="math_agent_tool",
+ description="A math agent that can perform addition.",
+)
+
+# Define Main Agent with Sub-Agents as tools
+planner_agent = Agent(
+ chat_generator=OpenAIChatGenerator(model="gpt-4.1"),
+ tools=[bank_agent_tool, math_agent_tool],
+ system_prompt="""You are a master agent that can delegate tasks to sub-agents based on the user's request.
+Available sub-agents:
+- bank_agent_tool: A bank agent that can get bank balances.
+- math_agent_tool: A math agent that can perform addition.
+Use the appropriate sub-agent to handle the user's request.
+""",
+)
+
+# Make bank balance request to planner agent
+result = planner_agent.run([ChatMessage.from_user("What's the balance of account 56789?")])
+last_message = result["last_message"]
+cons.print(f"\n[bold green]Agent Result:[/bold green] {last_message.text}")
+
+# Make addition request to planner agent
+result = planner_agent.run([ChatMessage.from_user("What is 5.5 + 3.2?")])
+last_message = result["last_message"]
+print(f"\nAgent Result: {last_message.text}")
+
+# Make bank balance request and addition request to planner agent
+# NOTE: This will try to invoke both sub-agents in parallel, requiring a thread-safe UI
+result = planner_agent.run([ChatMessage.from_user("What's the balance of account 56789 and what is 5.5 + 3.2?")])
+last_message = result["last_message"]
+print(f"\nAgent Result: {last_message.text}")
diff --git a/pyproject.toml b/pyproject.toml
index d3ed63f3..61fdde40 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -28,6 +28,7 @@ classifiers = [
dependencies = [
"haystack-ai",
+ "rich", # For pretty printing in the console used by human-in-the-loop utilities
]
[project.urls]
diff --git a/test/components/agents/__init__.py b/test/components/agents/__init__.py
new file mode 100644
index 00000000..c1764a6e
--- /dev/null
+++ b/test/components/agents/__init__.py
@@ -0,0 +1,3 @@
+# SPDX-FileCopyrightText: 2022-present deepset GmbH
+#
+# SPDX-License-Identifier: Apache-2.0
diff --git a/test/components/agents/human_in_the_loop/__init__.py b/test/components/agents/human_in_the_loop/__init__.py
new file mode 100644
index 00000000..c1764a6e
--- /dev/null
+++ b/test/components/agents/human_in_the_loop/__init__.py
@@ -0,0 +1,3 @@
+# SPDX-FileCopyrightText: 2022-present deepset GmbH
+#
+# SPDX-License-Identifier: Apache-2.0
diff --git a/test/components/agents/human_in_the_loop/test_breakpoint.py b/test/components/agents/human_in_the_loop/test_breakpoint.py
new file mode 100644
index 00000000..503d2a1c
--- /dev/null
+++ b/test/components/agents/human_in_the_loop/test_breakpoint.py
@@ -0,0 +1,148 @@
+# SPDX-FileCopyrightText: 2022-present deepset GmbH
+#
+# SPDX-License-Identifier: Apache-2.0
+
+from haystack.dataclasses.breakpoints import AgentBreakpoint, ToolBreakpoint
+
+from haystack_experimental.dataclasses.breakpoints import AgentSnapshot
+from haystack_experimental.components.agents.human_in_the_loop.breakpoint import (
+ get_tool_calls_and_descriptions_from_snapshot,
+)
+
+
+def get_bank_balance(account_id: str) -> str:
+ return f"The balance for account {account_id} is $1,234.56."
+
+
+def addition(a: float, b: float) -> float:
+ return a + b
+
+
+def test_get_tool_calls_and_descriptions_from_snapshot():
+ agent_snapshot = AgentSnapshot(
+ component_inputs={
+ "chat_generator": {},
+ "tool_invoker": {
+ "serialization_schema": {
+ "type": "object",
+ "properties": {
+ "messages": {
+ "type": "array",
+ "items": {"type": "haystack.dataclasses.chat_message.ChatMessage"},
+ },
+ "state": {"type": "haystack.components.agents.state.state.State"},
+ "tools": {"type": "array", "items": {"type": "haystack.tools.tool.Tool"}},
+ "enable_streaming_callback_passthrough": {"type": "boolean"},
+ },
+ },
+ "serialized_data": {
+ "messages": [
+ {
+ "role": "assistant",
+ "content": [
+ {
+ "tool_call": {
+ "tool_name": "get_bank_balance",
+ "arguments": {"account_id": "56789"},
+ "id": None,
+ }
+ }
+ ],
+ }
+ ],
+ "state": {
+ "schema": {
+ "messages": {
+ "type": "list[haystack.dataclasses.chat_message.ChatMessage]",
+ "handler": "haystack.components.agents.state.state_utils.merge_lists",
+ }
+ },
+ "data": {
+ "serialization_schema": {
+ "type": "object",
+ "properties": {
+ "messages": {
+ "type": "array",
+ "items": {"type": "haystack.dataclasses.chat_message.ChatMessage"},
+ }
+ },
+ },
+ "serialized_data": {
+ "messages": [
+ {
+ "role": "system",
+ "content": [
+ {
+ "text": "You are a helpful financial assistant. Use the provided tool to get bank balances when needed."
+ }
+ ],
+ },
+ {
+ "role": "user",
+ "content": [{"text": "What's the balance of account 56789?"}],
+ },
+ {
+ "role": "assistant",
+ "content": [
+ {
+ "tool_call": {
+ "tool_name": "get_bank_balance",
+ "arguments": {"account_id": "56789"},
+ "id": None,
+ }
+ }
+ ],
+ },
+ ]
+ },
+ },
+ },
+ "tools": [
+ {
+ "type": "haystack.tools.tool.Tool",
+ "data": {
+ "name": "get_bank_balance",
+ "description": "Get the bank balance for a given account ID.",
+ "parameters": {
+ "properties": {"account_id": {"type": "string"}},
+ "required": ["account_id"],
+ "type": "object",
+ },
+ "function": "test.components.agents.human_in_the_loop.test_breakpoint.get_bank_balance",
+ },
+ },
+ {
+ "type": "haystack.tools.tool.Tool",
+ "data": {
+ "name": "addition",
+ "description": "Add two floats together.",
+ "parameters": {
+ "properties": {"a": {"type": "number"}, "b": {"type": "number"}},
+ "required": ["a", "b"],
+ "type": "object",
+ },
+ "function": "test.components.agents.human_in_the_loop.test_breakpoint.addition",
+ },
+ },
+ ],
+ "enable_streaming_callback_passthrough": False,
+ },
+ },
+ },
+ component_visits={"chat_generator": 1, "tool_invoker": 0},
+ break_point=AgentBreakpoint(
+ agent_name="agent",
+ break_point=ToolBreakpoint(
+ tool_name="get_bank_balance", component_name="tool_invoker", visit_count=0, snapshot_file_path=None
+ ),
+ ),
+ )
+
+ tool_calls, tool_descriptions = get_tool_calls_and_descriptions_from_snapshot(
+ agent_snapshot=agent_snapshot, breakpoint_tool_only=True
+ )
+
+ assert len(tool_calls) == 1
+ assert tool_calls[0]["tool_name"] == "get_bank_balance"
+ assert tool_calls[0]["arguments"] == {"account_id": "56789"}
+ assert tool_descriptions == {"get_bank_balance": "Get the bank balance for a given account ID."}
diff --git a/test/components/agents/human_in_the_loop/test_dataclasses.py b/test/components/agents/human_in_the_loop/test_dataclasses.py
new file mode 100644
index 00000000..ff0d0cd2
--- /dev/null
+++ b/test/components/agents/human_in_the_loop/test_dataclasses.py
@@ -0,0 +1,64 @@
+# SPDX-FileCopyrightText: 2022-present deepset GmbH
+#
+# SPDX-License-Identifier: Apache-2.0
+
+from haystack_experimental.components.agents.human_in_the_loop import (
+ ConfirmationUIResult,
+ ToolExecutionDecision,
+)
+
+
+class TestConfirmationUIResult:
+ def test_init(self):
+ original = ConfirmationUIResult(
+ action="reject",
+ feedback="Changed my mind",
+ )
+ assert original.action == "reject"
+ assert original.feedback == "Changed my mind"
+ assert original.new_tool_params is None
+
+
+class TestToolExecutionDecision:
+ def test_init(self):
+ decision = ToolExecutionDecision(
+ execute=True,
+ tool_name="test_tool",
+ tool_call_id="test_tool_call_id",
+ final_tool_params={"param1": "new_value"},
+ )
+ assert decision.execute is True
+ assert decision.final_tool_params == {"param1": "new_value"}
+ assert decision.tool_call_id == "test_tool_call_id"
+ assert decision.tool_name == "test_tool"
+
+ def test_to_dict(self):
+ original = ToolExecutionDecision(
+ execute=True,
+ tool_name="test_tool",
+ tool_call_id="test_tool_call_id",
+ final_tool_params={"param1": "new_value"},
+ )
+ as_dict = original.to_dict()
+ assert as_dict == {
+ "execute": True,
+ "tool_name": "test_tool",
+ "tool_call_id": "test_tool_call_id",
+ "feedback": None,
+ "final_tool_params": {"param1": "new_value"},
+ }
+
+ def test_from_dict(self):
+ data = {
+ "execute": False,
+ "tool_name": "another_tool",
+ "tool_call_id": "another_tool_call_id",
+ "feedback": "Not needed",
+ "final_tool_params": {"paramA": 123},
+ }
+ decision = ToolExecutionDecision.from_dict(data)
+ assert decision.execute is False
+ assert decision.tool_name == "another_tool"
+ assert decision.tool_call_id == "another_tool_call_id"
+ assert decision.feedback == "Not needed"
+ assert decision.final_tool_params == {"paramA": 123}
diff --git a/test/components/agents/human_in_the_loop/test_policies.py b/test/components/agents/human_in_the_loop/test_policies.py
new file mode 100644
index 00000000..3f63fb41
--- /dev/null
+++ b/test/components/agents/human_in_the_loop/test_policies.py
@@ -0,0 +1,114 @@
+# SPDX-FileCopyrightText: 2022-present deepset GmbH
+#
+# SPDX-License-Identifier: Apache-2.0
+
+import pytest
+from haystack.tools import Tool, create_tool_from_function
+
+from haystack_experimental.components.agents.human_in_the_loop import (
+ AlwaysAskPolicy,
+ AskOncePolicy,
+ ConfirmationUIResult,
+ NeverAskPolicy,
+)
+
+
+def addition(x: int, y: int) -> int:
+ return x + y
+
+
+@pytest.fixture
+def addition_tool() -> Tool:
+ return create_tool_from_function(
+ function=addition,
+ name="Addition tool",
+ description="Adds two integers together.",
+ )
+
+
+class TestAlwaysAskPolicy:
+ def test_should_ask_always_true(self, addition_tool):
+ policy = AlwaysAskPolicy()
+ assert policy.should_ask(addition_tool.name, addition_tool.description, {"x": 1, "y": 2}) is True
+
+ def test_to_dict(self):
+ policy = AlwaysAskPolicy()
+ policy_dict = policy.to_dict()
+ assert (
+ policy_dict["type"] == "haystack_experimental.components.agents.human_in_the_loop.policies.AlwaysAskPolicy"
+ )
+ assert policy_dict["init_parameters"] == {}
+
+ def test_from_dict(self):
+ policy_dict = {
+ "type": "haystack_experimental.components.agents.human_in_the_loop.policies.AlwaysAskPolicy",
+ "init_parameters": {},
+ }
+ policy = AlwaysAskPolicy.from_dict(policy_dict)
+ assert isinstance(policy, AlwaysAskPolicy)
+
+
+class TestAskOncePolicy:
+ def test_should_ask_first_time_true(self, addition_tool):
+ policy = AskOncePolicy()
+ assert policy.should_ask(addition_tool.name, addition_tool.description, {"x": 1, "y": 2}) is True
+
+ def test_should_ask_second_time_false(self, addition_tool):
+ policy = AskOncePolicy()
+ params = {"x": 1, "y": 2}
+ assert policy.should_ask(addition_tool.name, addition_tool.description, params) is True
+ # Simulate the update after confirmation that occurs in HumanInTheLoopStrategy
+ policy.update_after_confirmation(
+ addition_tool.name, addition_tool.description, params, ConfirmationUIResult(action="confirm", feedback=None)
+ )
+ assert policy.should_ask(addition_tool.name, addition_tool.description, params) is False
+
+ def test_should_ask_different_params_true(self, addition_tool):
+ policy = AskOncePolicy()
+ params1 = {"x": 1, "y": 2}
+ params2 = {"x": 3, "y": 4}
+ assert policy.should_ask(addition_tool.name, addition_tool.description, params1) is True
+ # Simulate the update after confirmation that occurs in HumanInTheLoopStrategy
+ policy.update_after_confirmation(
+ addition_tool.name,
+ addition_tool.description,
+ params1,
+ ConfirmationUIResult(action="confirm", feedback=None),
+ )
+ assert policy.should_ask(addition_tool.name, addition_tool.description, params2) is True
+
+ def test_to_dict(self):
+ policy = AskOncePolicy()
+ policy_dict = policy.to_dict()
+ assert policy_dict["type"] == "haystack_experimental.components.agents.human_in_the_loop.policies.AskOncePolicy"
+ assert policy_dict["init_parameters"] == {}
+
+ def test_from_dict(self):
+ policy_dict = {
+ "type": "haystack_experimental.components.agents.human_in_the_loop.policies.AskOncePolicy",
+ "init_parameters": {},
+ }
+ policy = AskOncePolicy.from_dict(policy_dict)
+ assert isinstance(policy, AskOncePolicy)
+
+
+class TestNeverAskPolicy:
+ def test_should_ask_always_false(self, addition_tool):
+ policy = NeverAskPolicy()
+ assert policy.should_ask(addition_tool.name, addition_tool.description, {"x": 1, "y": 2}) is False
+
+ def test_to_dict(self):
+ policy = NeverAskPolicy()
+ policy_dict = policy.to_dict()
+ assert (
+ policy_dict["type"] == "haystack_experimental.components.agents.human_in_the_loop.policies.NeverAskPolicy"
+ )
+ assert policy_dict["init_parameters"] == {}
+
+ def test_from_dict(self):
+ policy_dict = {
+ "type": "haystack_experimental.components.agents.human_in_the_loop.policies.NeverAskPolicy",
+ "init_parameters": {},
+ }
+ policy = NeverAskPolicy.from_dict(policy_dict)
+ assert isinstance(policy, NeverAskPolicy)
diff --git a/test/components/agents/human_in_the_loop/test_strategies.py b/test/components/agents/human_in_the_loop/test_strategies.py
new file mode 100644
index 00000000..ea91c884
--- /dev/null
+++ b/test/components/agents/human_in_the_loop/test_strategies.py
@@ -0,0 +1,340 @@
+# SPDX-FileCopyrightText: 2022-present deepset GmbH
+#
+# SPDX-License-Identifier: Apache-2.0
+
+from dataclasses import replace
+import pytest
+from haystack.components.agents.state.state import State
+from haystack.dataclasses import ChatMessage, ToolCall
+from haystack.tools import Tool, create_tool_from_function
+
+from haystack_experimental.components.agents.agent import _ExecutionContext
+from haystack_experimental.components.agents.human_in_the_loop import (
+ AlwaysAskPolicy,
+ AskOncePolicy,
+ BlockingConfirmationStrategy,
+ BreakpointConfirmationStrategy,
+ HITLBreakpointException,
+ SimpleConsoleUI,
+ ToolExecutionDecision,
+ NeverAskPolicy,
+ ConfirmationUIResult,
+)
+from haystack_experimental.components.agents.human_in_the_loop.strategies import (
+ _apply_tool_execution_decisions,
+ _run_confirmation_strategies,
+ _update_chat_history,
+)
+
+
+def addition_tool(a: int, b: int) -> int:
+ return a + b
+
+
+@pytest.fixture
+def tools() -> list[Tool]:
+ tool = create_tool_from_function(
+ function=addition_tool, name="addition_tool", description="A tool that adds two integers together."
+ )
+ return [tool]
+
+
+@pytest.fixture
+def execution_context(tools) -> _ExecutionContext:
+ return _ExecutionContext(
+ state=State(schema={"messages": {"type": list[ChatMessage]}}),
+ component_visits={"chat_generator": 0, "tool_invoker": 0},
+ chat_generator_inputs={},
+ tool_invoker_inputs={"tools": tools},
+ counter=0,
+ skip_chat_generator=False,
+ tool_execution_decisions=None,
+ )
+
+
+class TestBlockingConfirmationStrategy:
+ def test_initialization(self):
+ strategy = BlockingConfirmationStrategy(confirmation_policy=AskOncePolicy(), confirmation_ui=SimpleConsoleUI())
+ assert isinstance(strategy.confirmation_policy, AskOncePolicy)
+ assert isinstance(strategy.confirmation_ui, SimpleConsoleUI)
+
+ def test_to_dict(self):
+ strategy = BlockingConfirmationStrategy(confirmation_policy=AskOncePolicy(), confirmation_ui=SimpleConsoleUI())
+ strategy_dict = strategy.to_dict()
+ assert strategy_dict == {
+ "type": "haystack_experimental.components.agents.human_in_the_loop.strategies.BlockingConfirmationStrategy",
+ "init_parameters": {
+ "confirmation_policy": {
+ "type": "haystack_experimental.components.agents.human_in_the_loop.policies.AskOncePolicy",
+ "init_parameters": {},
+ },
+ "confirmation_ui": {
+ "type": "haystack_experimental.components.agents.human_in_the_loop.user_interfaces.SimpleConsoleUI",
+ "init_parameters": {},
+ },
+ },
+ }
+
+ def test_from_dict(self):
+ strategy_dict = {
+ "type": "haystack_experimental.components.agents.human_in_the_loop.strategies.HumanInTheLoopStrategy",
+ "init_parameters": {
+ "confirmation_policy": {
+ "type": "haystack_experimental.components.agents.human_in_the_loop.policies.AskOncePolicy",
+ "init_parameters": {},
+ },
+ "confirmation_ui": {
+ "type": "haystack_experimental.components.agents.human_in_the_loop.user_interfaces.SimpleConsoleUI",
+ "init_parameters": {},
+ },
+ },
+ }
+ strategy = BlockingConfirmationStrategy.from_dict(strategy_dict)
+ assert isinstance(strategy, BlockingConfirmationStrategy)
+ assert isinstance(strategy.confirmation_policy, AskOncePolicy)
+ assert isinstance(strategy.confirmation_ui, SimpleConsoleUI)
+
+ def test_run_confirm(self, monkeypatch):
+ strategy = BlockingConfirmationStrategy(AlwaysAskPolicy(), SimpleConsoleUI())
+
+ # Mock the UI to always confirm
+ def mock_get_user_confirmation(tool_name, tool_description, tool_params):
+ return ConfirmationUIResult(action="confirm")
+
+ monkeypatch.setattr(strategy.confirmation_ui, "get_user_confirmation", mock_get_user_confirmation)
+
+ decision = strategy.run(tool_name="test_tool", tool_description="A test tool", tool_params={"param1": "value1"})
+ assert decision.tool_name == "test_tool"
+ assert decision.execute is True
+ assert decision.final_tool_params == {"param1": "value1"}
+
+ def test_run_modify(self, monkeypatch):
+ strategy = BlockingConfirmationStrategy(AlwaysAskPolicy(), SimpleConsoleUI())
+
+ # Mock the UI to always modify
+ def mock_get_user_confirmation(tool_name, tool_description, tool_params):
+ return ConfirmationUIResult(action="modify", new_tool_params={"param1": "new_value"})
+
+ monkeypatch.setattr(strategy.confirmation_ui, "get_user_confirmation", mock_get_user_confirmation)
+
+ decision = strategy.run(tool_name="test_tool", tool_description="A test tool", tool_params={"param1": "value1"})
+ assert decision.tool_name == "test_tool"
+ assert decision.execute is True
+ assert decision.final_tool_params == {"param1": "new_value"}
+ assert decision.feedback == (
+ "The parameters for tool 'test_tool' were updated by the user to:\n{'param1': 'new_value'}"
+ )
+
+ def test_run_reject(self, monkeypatch):
+ strategy = BlockingConfirmationStrategy(AlwaysAskPolicy(), SimpleConsoleUI())
+
+ # Mock the UI to always reject
+ def mock_get_user_confirmation(tool_name, tool_description, tool_params):
+ return ConfirmationUIResult(action="reject", feedback="Not needed")
+
+ monkeypatch.setattr(strategy.confirmation_ui, "get_user_confirmation", mock_get_user_confirmation)
+
+ decision = strategy.run(tool_name="test_tool", tool_description="A test tool", tool_params={"param1": "value1"})
+ assert decision.tool_name == "test_tool"
+ assert decision.execute is False
+ assert decision.final_tool_params is None
+ assert decision.feedback == "Tool execution for 'test_tool' was rejected by the user. With feedback: Not needed"
+
+
+class TestBreakpointConfirmationStrategy:
+ def test_initialization(self):
+ strategy = BreakpointConfirmationStrategy(snapshot_file_path="test")
+ assert strategy.snapshot_file_path == "test"
+
+ def test_to_dict(self):
+ strategy = BreakpointConfirmationStrategy(snapshot_file_path="test")
+ strategy_dict = strategy.to_dict()
+ assert strategy_dict == {
+ "type": "haystack_experimental.components.agents.human_in_the_loop.strategies.BreakpointConfirmationStrategy",
+ "init_parameters": {"snapshot_file_path": "test"},
+ }
+
+ def test_from_dict(self):
+ strategy_dict = {
+ "type": "haystack_experimental.components.agents.human_in_the_loop.strategies.BreakpointConfirmationStrategy",
+ "init_parameters": {"snapshot_file_path": "test"},
+ }
+ strategy = BreakpointConfirmationStrategy.from_dict(strategy_dict)
+ assert isinstance(strategy, BreakpointConfirmationStrategy)
+ assert strategy.snapshot_file_path == "test"
+
+ def test_run(self):
+ strategy = BreakpointConfirmationStrategy(snapshot_file_path="test")
+ with pytest.raises(HITLBreakpointException):
+ strategy.run(tool_name="test_tool", tool_description="A test tool", tool_params={"param1": "value1"})
+
+
+class TestRunConfirmationStrategies:
+ def test_run_confirmation_strategies_hitl_breakpoint(self, tmp_path, tools, execution_context):
+ with pytest.raises(HITLBreakpointException):
+ _run_confirmation_strategies(
+ confirmation_strategies={tools[0].name: BreakpointConfirmationStrategy(str(tmp_path))},
+ messages_with_tool_calls=[
+ ChatMessage.from_assistant(tool_calls=[ToolCall(tools[0].name, {"param1": "value1"})]),
+ ],
+ execution_context=execution_context,
+ )
+
+ def test_run_confirmation_strategies_no_strategy(self, tools, execution_context):
+ teds = _run_confirmation_strategies(
+ confirmation_strategies={},
+ messages_with_tool_calls=[
+ ChatMessage.from_assistant(tool_calls=[ToolCall(tools[0].name, {"param1": "value1"})]),
+ ],
+ execution_context=execution_context,
+ )
+ assert teds == [
+ ToolExecutionDecision(tool_name=tools[0].name, execute=True, final_tool_params={"param1": "value1"})
+ ]
+
+ def test_run_confirmation_strategies_with_strategy(self, tools, execution_context):
+ teds = _run_confirmation_strategies(
+ confirmation_strategies={tools[0].name: BlockingConfirmationStrategy(NeverAskPolicy(), SimpleConsoleUI())},
+ messages_with_tool_calls=[
+ ChatMessage.from_assistant(tool_calls=[ToolCall(tools[0].name, {"param1": "value1"})]),
+ ],
+ execution_context=execution_context,
+ )
+ assert teds == [
+ ToolExecutionDecision(tool_name=tools[0].name, execute=True, final_tool_params={"param1": "value1"})
+ ]
+
+ def test_run_confirmation_strategies_with_existing_teds(self, tools, execution_context):
+ exe_context_with_teds = replace(
+ execution_context,
+ tool_execution_decisions=[
+ ToolExecutionDecision(
+ tool_name=tools[0].name, execute=True, tool_call_id="123", final_tool_params={"param1": "new_value"}
+ )
+ ],
+ )
+ teds = _run_confirmation_strategies(
+ confirmation_strategies={tools[0].name: BlockingConfirmationStrategy(NeverAskPolicy(), SimpleConsoleUI())},
+ messages_with_tool_calls=[
+ ChatMessage.from_assistant(tool_calls=[ToolCall(tools[0].name, {"param1": "value1"}, id="123")]),
+ ],
+ execution_context=exe_context_with_teds,
+ )
+ assert teds == [
+ ToolExecutionDecision(
+ tool_name=tools[0].name, execute=True, tool_call_id="123", final_tool_params={"param1": "new_value"}
+ )
+ ]
+
+
+class TestApplyToolExecutionDecisions:
+ @pytest.fixture
+ def assistant_message(self, tools):
+ tool_call = ToolCall(tool_name=tools[0].name, arguments={"a": 1, "b": 2}, id="1")
+ return ChatMessage.from_assistant(tool_calls=[tool_call])
+
+ def test_apply_tool_execution_decisions_reject(self, tools, assistant_message):
+ rejection_messages, new_tool_call_messages = _apply_tool_execution_decisions(
+ tool_call_messages=[assistant_message],
+ tool_execution_decisions=[
+ ToolExecutionDecision(
+ tool_name=tools[0].name,
+ execute=False,
+ tool_call_id="1",
+ feedback=(
+ "The tool execution for 'addition_tool' was rejected by the user. With feedback: Not needed"
+ ),
+ )
+ ],
+ )
+ assert rejection_messages == [
+ assistant_message,
+ ChatMessage.from_tool(
+ tool_result=(
+ "The tool execution for 'addition_tool' was rejected by the user. With feedback: Not needed"
+ ),
+ origin=assistant_message.tool_call,
+ error=True,
+ ),
+ ]
+ assert new_tool_call_messages == []
+
+ def test_apply_tool_execution_decisions_confirm(self, tools, assistant_message):
+ rejection_messages, new_tool_call_messages = _apply_tool_execution_decisions(
+ tool_call_messages=[assistant_message],
+ tool_execution_decisions=[
+ ToolExecutionDecision(
+ tool_name=tools[0].name, execute=True, tool_call_id="1", final_tool_params={"a": 1, "b": 2}
+ )
+ ],
+ )
+ assert rejection_messages == []
+ assert new_tool_call_messages == [assistant_message]
+
+ def test_apply_tool_execution_decisions_modify(self, tools, assistant_message):
+ rejection_messages, new_tool_call_messages = _apply_tool_execution_decisions(
+ tool_call_messages=[assistant_message],
+ tool_execution_decisions=[
+ ToolExecutionDecision(
+ tool_name=tools[0].name,
+ execute=True,
+ tool_call_id="1",
+ final_tool_params={"a": 5, "b": 6},
+ feedback="The parameters for tool 'addition_tool' were updated by the user to:\n{'a': 5, 'b': 6}",
+ )
+ ],
+ )
+ assert rejection_messages == []
+ assert new_tool_call_messages == [
+ ChatMessage.from_user(
+ "The parameters for tool 'addition_tool' were updated by the user to:\n{'a': 5, 'b': 6}"
+ ),
+ ChatMessage.from_assistant(
+ tool_calls=[ToolCall(tool_name=tools[0].name, arguments={"a": 5, "b": 6}, id="1")]
+ ),
+ ]
+
+
+class TestUpdateChatHistory:
+ @pytest.fixture
+ def chat_history_one_tool_call(self):
+ return [
+ ChatMessage.from_user("Hello"),
+ ChatMessage.from_assistant(tool_calls=[ToolCall("tool1", {"a": 1, "b": 2}, id="1")]),
+ ]
+
+ def test_update_chat_history_rejection(self, chat_history_one_tool_call):
+ """Test that the new history includes a tool call result message after a rejection."""
+ rejection_messages = [
+ ChatMessage.from_assistant(tool_calls=[chat_history_one_tool_call[1].tool_call]),
+ ChatMessage.from_tool(
+ tool_result="The tool execution for 'tool1' was rejected by the user. With feedback: Not needed",
+ origin=chat_history_one_tool_call[1].tool_call,
+ error=True,
+ ),
+ ]
+ updated_messages = _update_chat_history(
+ chat_history_one_tool_call, rejection_messages=rejection_messages, tool_call_and_explanation_messages=[]
+ )
+ assert updated_messages == [chat_history_one_tool_call[0], *rejection_messages]
+
+ def test_update_chat_history_confirm(self, chat_history_one_tool_call):
+ """No changes should be made if the tool call was confirmed."""
+ tool_call_messages = [ChatMessage.from_assistant(tool_calls=[chat_history_one_tool_call[1].tool_call])]
+ updated_messages = _update_chat_history(
+ chat_history_one_tool_call, rejection_messages=[], tool_call_and_explanation_messages=tool_call_messages
+ )
+ assert updated_messages == chat_history_one_tool_call
+
+ def test_update_chat_history_modify(self, chat_history_one_tool_call):
+ """Test that the new history includes a user message and updated tool call after a modification."""
+ tool_call_messages = [
+ ChatMessage.from_user(
+ "The parameters for tool 'tool1' were updated by the user to:\n{'param': 'new_value'}"
+ ),
+ ChatMessage.from_assistant(tool_calls=[ToolCall("tool1", {"param": "new_value"}, id="1")]),
+ ]
+ updated_messages = _update_chat_history(
+ chat_history_one_tool_call, rejection_messages=[], tool_call_and_explanation_messages=tool_call_messages
+ )
+ assert updated_messages == [chat_history_one_tool_call[0], *tool_call_messages]
diff --git a/test/components/agents/human_in_the_loop/test_user_interfaces.py b/test/components/agents/human_in_the_loop/test_user_interfaces.py
new file mode 100644
index 00000000..de3474dc
--- /dev/null
+++ b/test/components/agents/human_in_the_loop/test_user_interfaces.py
@@ -0,0 +1,221 @@
+# SPDX-FileCopyrightText: 2022-present deepset GmbH
+#
+# SPDX-License-Identifier: Apache-2.0
+
+from unittest.mock import MagicMock, patch
+
+import pytest
+from haystack.tools import create_tool_from_function
+
+from haystack_experimental.components.agents.human_in_the_loop.dataclasses import ConfirmationUIResult
+from haystack_experimental.components.agents.human_in_the_loop.user_interfaces import RichConsoleUI, SimpleConsoleUI
+
+
+def multiply_tool(x: int) -> int:
+ return x * 2
+
+
+@pytest.fixture
+def tool():
+ return create_tool_from_function(
+ function=multiply_tool,
+ name="test_tool",
+ description="A test tool that multiplies input by 2.",
+ )
+
+
+class TestRichConsoleUI:
+ @pytest.mark.parametrize("choice", ["y"])
+ def test_process_choice_confirm(self, tool, choice):
+ ui = RichConsoleUI(console=MagicMock())
+
+ with patch(
+ "haystack_experimental.components.agents.human_in_the_loop.user_interfaces.Prompt.ask",
+ side_effect=[choice, "feedback"],
+ ):
+ result = ui.get_user_confirmation(tool.name, tool.description, {"x": 1})
+
+ assert isinstance(result, ConfirmationUIResult)
+ assert result.action == "confirm"
+ assert result.new_tool_params is None
+ assert result.feedback is None
+
+ @pytest.mark.parametrize("choice", ["m"])
+ def test_process_choice_modify(self, tool, choice):
+ ui = RichConsoleUI(console=MagicMock())
+
+ with patch(
+ "haystack_experimental.components.agents.human_in_the_loop.user_interfaces.Prompt.ask",
+ side_effect=["m", "2"],
+ ):
+ result = ui.get_user_confirmation(tool.name, tool.description, {"x": 1})
+
+ assert isinstance(result, ConfirmationUIResult)
+ assert result.action == "modify"
+ assert result.new_tool_params == {"x": 2}
+
+ def test_process_choice_modify_dict_param(self, tool):
+ ui = RichConsoleUI(console=MagicMock())
+
+ with patch(
+ "haystack_experimental.components.agents.human_in_the_loop.user_interfaces.Prompt.ask",
+ side_effect=["m", '{"key": "value"}'],
+ ):
+ result = ui.get_user_confirmation(tool.name, tool.description, {"param1": {"old_key": "old_value"}})
+
+ assert isinstance(result, ConfirmationUIResult)
+ assert result.action == "modify"
+ assert result.new_tool_params == {"param1": {"key": "value"}}
+
+ def test_process_choice_modify_dict_param_invalid_json(self, tool):
+ ui = RichConsoleUI(console=MagicMock())
+
+ with patch(
+ "haystack_experimental.components.agents.human_in_the_loop.user_interfaces.Prompt.ask",
+ side_effect=["m", "invalid_json", '{"key": "value"}'],
+ ):
+ result = ui.get_user_confirmation(tool.name, tool.description, {"param1": {"old_key": "old_value"}})
+
+ assert isinstance(result, ConfirmationUIResult)
+ assert result.action == "modify"
+ assert result.new_tool_params == {"param1": {"key": "value"}}
+
+ @pytest.mark.parametrize("choice", ["n"])
+ def test_process_choice_reject(self, tool, choice):
+ ui = RichConsoleUI(console=MagicMock())
+
+ with patch(
+ "haystack_experimental.components.agents.human_in_the_loop.user_interfaces.Prompt.ask",
+ side_effect=["n", "Changed my mind"],
+ ):
+ result = ui.get_user_confirmation(tool.name, tool.description, {"x": 1})
+
+ assert isinstance(result, ConfirmationUIResult)
+ assert result.action == "reject"
+ assert result.feedback == "Changed my mind"
+
+ def test_to_dict(self):
+ ui = RichConsoleUI()
+ data = ui.to_dict()
+ assert data["type"] == (
+ "haystack_experimental.components.agents.human_in_the_loop.user_interfaces.RichConsoleUI"
+ )
+ assert data["init_parameters"]["console"] is None
+
+ def test_from_dict(self):
+ ui = RichConsoleUI()
+ data = ui.to_dict()
+ new_ui = RichConsoleUI.from_dict(data)
+ assert isinstance(new_ui, RichConsoleUI)
+
+
+class TestSimpleConsoleUI:
+ @pytest.mark.parametrize("choice", ["y", "yes", "Y", "YES"])
+ def test_process_choice_confirm(self, tool, choice):
+ ui = SimpleConsoleUI()
+
+ with patch("builtins.input", side_effect=[choice]):
+ result = ui.get_user_confirmation(tool.name, tool.description, {"y": "abc"})
+
+ assert isinstance(result, ConfirmationUIResult)
+ assert result.action == "confirm"
+
+ @pytest.mark.parametrize("choice", ["m", "modify", "M", "MODIFY"])
+ def test_process_choice_modify(self, tool, choice):
+ ui = SimpleConsoleUI()
+
+ with patch("builtins.input", side_effect=[choice, "new_value"]):
+ result = ui.get_user_confirmation(tool.name, tool.description, {"y": "abc"})
+
+ assert isinstance(result, ConfirmationUIResult)
+ assert result.action == "modify"
+ assert result.new_tool_params == {"y": "new_value"}
+
+ def test_process_choice_modify_dict_param(self, tool):
+ ui = SimpleConsoleUI()
+
+ with patch("builtins.input", side_effect=["m", '{"key": "value"}']):
+ result = ui.get_user_confirmation(tool.name, tool.description, {"param1": {"old_key": "old_value"}})
+
+ assert isinstance(result, ConfirmationUIResult)
+ assert result.action == "modify"
+ assert result.new_tool_params == {"param1": {"key": "value"}}
+
+ def test_process_choice_modify_dict_param_invalid_json(self, tool):
+ ui = SimpleConsoleUI()
+
+ with patch("builtins.input", side_effect=["m", "invalid_json", '{"key": "value"}']):
+ result = ui.get_user_confirmation(tool.name, tool.description, {"param1": {"old_key": "old_value"}})
+
+ assert isinstance(result, ConfirmationUIResult)
+ assert result.action == "modify"
+ assert result.new_tool_params == {"param1": {"key": "value"}}
+
+ @pytest.mark.parametrize("choice", ["n", "no", "N", "NO"])
+ def test_process_choice_reject(self, tool, choice):
+ ui = SimpleConsoleUI()
+
+ with patch("builtins.input", side_effect=[choice, "Changed my mind"]):
+ result = ui.get_user_confirmation(tool.name, tool.description, {"param1": "value1"})
+
+ assert isinstance(result, ConfirmationUIResult)
+ assert result.action == "reject"
+ assert result.feedback == "Changed my mind"
+
+ def test_process_choice_no_tool_params_confirm(self, tool):
+ ui = SimpleConsoleUI()
+
+ with patch("builtins.input", side_effect=["y"]):
+ result = ui.get_user_confirmation(tool.name, tool.description, {})
+
+ assert isinstance(result, ConfirmationUIResult)
+ assert result.action == "confirm"
+ assert result.new_tool_params is None
+ assert result.feedback is None
+
+ def test_process_choice_no_tool_params_modify(self, tool):
+ ui = SimpleConsoleUI()
+
+ with patch("builtins.input", side_effect=["m"]):
+ result = ui.get_user_confirmation(tool.name, tool.description, {})
+
+ assert isinstance(result, ConfirmationUIResult)
+ assert result.action == "modify"
+ assert result.new_tool_params == {}
+ assert result.feedback is None
+
+ def test_process_choice_no_tool_params_reject(self, tool):
+ ui = SimpleConsoleUI()
+
+ with patch("builtins.input", side_effect=["n", "Changed my mind"]):
+ result = ui.get_user_confirmation(tool.name, tool.description, {})
+
+ assert isinstance(result, ConfirmationUIResult)
+ assert result.action == "reject"
+ assert result.new_tool_params is None
+ assert result.feedback == "Changed my mind"
+
+ def test_to_dict(self):
+ ui = SimpleConsoleUI()
+ data = ui.to_dict()
+ assert data["type"] == (
+ "haystack_experimental.components.agents.human_in_the_loop.user_interfaces.SimpleConsoleUI"
+ )
+ assert data["init_parameters"] == {}
+
+ def test_from_dict(self):
+ ui = SimpleConsoleUI()
+ data = ui.to_dict()
+ new_ui = SimpleConsoleUI.from_dict(data)
+ assert isinstance(new_ui, SimpleConsoleUI)
+
+ def test_get_user_confirmation_invalid_input_then_valid(self, tool):
+ ui = SimpleConsoleUI()
+
+ with patch("builtins.input", side_effect=["invalid", "y"]):
+ result = ui.get_user_confirmation(tool.name, tool.description, {"x": 1})
+
+ assert isinstance(result, ConfirmationUIResult)
+ assert result.action == "confirm"
+ assert result.new_tool_params is None
+ assert result.feedback is None
diff --git a/test/components/agents/test_agent.py b/test/components/agents/test_agent.py
new file mode 100644
index 00000000..849faf8e
--- /dev/null
+++ b/test/components/agents/test_agent.py
@@ -0,0 +1,348 @@
+# SPDX-FileCopyrightText: 2022-present deepset GmbH
+#
+# SPDX-License-Identifier: Apache-2.0
+
+import os
+from pathlib import Path
+from typing import Any, Optional
+
+import pytest
+from haystack.components.generators.chat.openai import OpenAIChatGenerator
+from haystack.core.errors import BreakpointException
+from haystack.core.pipeline.breakpoint import load_pipeline_snapshot
+from haystack.dataclasses import ChatMessage
+from haystack.dataclasses.breakpoints import PipelineSnapshot
+from haystack.tools import Tool, create_tool_from_function
+
+from haystack_experimental.components.agents.agent import Agent
+from haystack_experimental.components.agents.human_in_the_loop import (
+ AlwaysAskPolicy,
+ BlockingConfirmationStrategy,
+ BreakpointConfirmationStrategy,
+ ConfirmationStrategy,
+ ConfirmationUI,
+ ConfirmationUIResult,
+ NeverAskPolicy,
+ SimpleConsoleUI,
+ ToolExecutionDecision,
+)
+from haystack_experimental.components.agents.human_in_the_loop.breakpoint import (
+ get_tool_calls_and_descriptions_from_snapshot,
+)
+
+
+class TestUserInterface(ConfirmationUI):
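+    """Confirmation UI stub that returns a preset ConfirmationUIResult instead of prompting the user."""
+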
+ def __init__(self, ui_result: ConfirmationUIResult) -> None:
+ self.ui_result = ui_result
+
+ def get_user_confirmation(
+ self, tool_name: str, tool_description: str, tool_params: dict[str, Any]
+ ) -> ConfirmationUIResult:
+ return self.ui_result
+
+
+def frontend_simulate_tool_decision(
+ tool_calls: list[dict[str, Any]],
+ tool_descriptions: dict[str, str],
+ confirmation_ui_result: ConfirmationUIResult,
+) -> list[dict]:
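+    """Simulate a front end: apply the given confirmation result to each tool call and return serialized decisions."""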
+ confirmation_strategy = BlockingConfirmationStrategy(
+ confirmation_policy=AlwaysAskPolicy(),
+ confirmation_ui=TestUserInterface(ui_result=confirmation_ui_result),
+ )
+
+ tool_execution_decisions = []
+ for tc in tool_calls:
+ tool_execution_decisions.append(
+ confirmation_strategy.run(
+ tool_name=tc["tool_name"],
+ tool_description=tool_descriptions[tc["tool_name"]],
+ tool_call_id=tc["id"],
+ tool_params=tc["arguments"],
+ )
+ )
+ return [ted.to_dict() for ted in tool_execution_decisions]
+
+
+def get_latest_snapshot(snapshot_file_path: str) -> PipelineSnapshot:
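+    """Load the most recently written pipeline snapshot from the given directory."""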
+ snapshot_dir = Path(snapshot_file_path)
+ possible_snapshots = [snapshot_dir / f for f in os.listdir(snapshot_dir)]
+ latest_snapshot_file = str(max(possible_snapshots, key=os.path.getctime))
+ return load_pipeline_snapshot(latest_snapshot_file)
+
+
+def run_agent(
+ agent: Agent,
+ messages: list[ChatMessage],
+ snapshot_file_path: Optional[str] = None,
+ tool_execution_decisions: Optional[list[dict[str, Any]]] = None,
+) -> Optional[dict[str, Any]]:
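+    """Run the agent, resuming from the latest snapshot (plus any new tool execution decisions) when a
+    snapshot file path is given. Returns None if a HITL breakpoint interrupts the run."""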
+ # Load the latest snapshot if a path is provided
+ snapshot = None
+ if snapshot_file_path:
+ snapshot = get_latest_snapshot(snapshot_file_path=snapshot_file_path)
+
+ # Add any new tool execution decisions to the snapshot
+ if tool_execution_decisions:
+ teds = [ToolExecutionDecision.from_dict(ted) for ted in tool_execution_decisions]
+ existing_decisions = snapshot.agent_snapshot.tool_execution_decisions or []
+ snapshot.agent_snapshot.tool_execution_decisions = existing_decisions + teds
+
+ try:
+ return agent.run(messages=messages, snapshot=snapshot.agent_snapshot if snapshot else None)
+ except BreakpointException:
+ return None
+
+
+async def run_agent_async(
+ agent: Agent,
+ messages: list[ChatMessage],
+ snapshot_file_path: Optional[str] = None,
+ tool_execution_decisions: Optional[list[dict[str, Any]]] = None,
+) -> Optional[dict[str, Any]]:
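+    """Async counterpart of run_agent: resume from the latest snapshot when a snapshot file path is given."""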
+ # Load the latest snapshot if a path is provided
+ snapshot = None
+ if snapshot_file_path:
+ snapshot = get_latest_snapshot(snapshot_file_path=snapshot_file_path)
+
+ # Add any new tool execution decisions to the snapshot
+ if tool_execution_decisions:
+ teds = [ToolExecutionDecision.from_dict(ted) for ted in tool_execution_decisions]
+ existing_decisions = snapshot.agent_snapshot.tool_execution_decisions or []
+ snapshot.agent_snapshot.tool_execution_decisions = existing_decisions + teds
+
+ try:
+ return await agent.run_async(messages=messages, snapshot=snapshot.agent_snapshot if snapshot else None)
+ except BreakpointException:
+ return None
+
+
+def addition_tool(a: int, b: int) -> int:
+ return a + b
+
+
+@pytest.fixture
+def tools() -> list[Tool]:
+ tool = create_tool_from_function(
+ function=addition_tool, name="addition_tool", description="A tool that adds two integers together."
+ )
+ return [tool]
+
+
+@pytest.fixture
+def confirmation_strategies() -> dict[str, ConfirmationStrategy]:
+ return {"addition_tool": BlockingConfirmationStrategy(NeverAskPolicy(), SimpleConsoleUI())}
+
+
+class TestAgent:
+ def test_to_dict(self, tools, confirmation_strategies, monkeypatch):
+ monkeypatch.setenv("OPENAI_API_KEY", "test")
+ agent = Agent(
+ chat_generator=OpenAIChatGenerator(), tools=tools, confirmation_strategies=confirmation_strategies
+ )
+ agent_dict = agent.to_dict()
+ assert agent_dict == {
+ "type": "haystack_experimental.components.agents.agent.Agent",
+ "init_parameters": {
+ "chat_generator": {
+ "type": "haystack.components.generators.chat.openai.OpenAIChatGenerator",
+ "init_parameters": {
+ "model": "gpt-4o-mini",
+ "streaming_callback": None,
+ "api_base_url": None,
+ "organization": None,
+ "generation_kwargs": {},
+ "api_key": {"type": "env_var", "env_vars": ["OPENAI_API_KEY"], "strict": True},
+ "timeout": None,
+ "max_retries": None,
+ "tools": None,
+ "tools_strict": False,
+ "http_client_kwargs": None,
+ },
+ },
+ "tools": [
+ {
+ "type": "haystack.tools.tool.Tool",
+ "data": {
+ "name": "addition_tool",
+ "description": "A tool that adds two integers together.",
+ "parameters": {
+ "properties": {"a": {"type": "integer"}, "b": {"type": "integer"}},
+ "required": ["a", "b"],
+ "type": "object",
+ },
+ "function": "test.components.agents.test_agent.addition_tool",
+ "outputs_to_string": None,
+ "inputs_from_state": None,
+ "outputs_to_state": None,
+ },
+ }
+ ],
+ "system_prompt": None,
+ "exit_conditions": ["text"],
+ "state_schema": {},
+ "max_agent_steps": 100,
+ "streaming_callback": None,
+ "raise_on_tool_invocation_failure": False,
+ "tool_invoker_kwargs": None,
+ "confirmation_strategies": {
+ "addition_tool": {
+ "type": "haystack_experimental.components.agents.human_in_the_loop.strategies.BlockingConfirmationStrategy",
+ "init_parameters": {
+ "confirmation_policy": {
+ "type": "haystack_experimental.components.agents.human_in_the_loop.policies.NeverAskPolicy",
+ "init_parameters": {},
+ },
+ "confirmation_ui": {
+ "type": "haystack_experimental.components.agents.human_in_the_loop.user_interfaces.SimpleConsoleUI",
+ "init_parameters": {},
+ },
+ },
+ }
+ },
+ },
+ }
+
+ def test_from_dict(self, tools, confirmation_strategies, monkeypatch):
+ monkeypatch.setenv("OPENAI_API_KEY", "test")
+ agent = Agent(
+ chat_generator=OpenAIChatGenerator(), tools=tools, confirmation_strategies=confirmation_strategies
+ )
+ deserialized_agent = Agent.from_dict(agent.to_dict())
+ assert deserialized_agent.to_dict() == agent.to_dict()
+ assert isinstance(deserialized_agent.chat_generator, OpenAIChatGenerator)
+ assert len(deserialized_agent.tools) == 1
+ assert deserialized_agent.tools[0].name == "addition_tool"
+ assert isinstance(deserialized_agent._tool_invoker, type(agent._tool_invoker))
+ assert isinstance(deserialized_agent._confirmation_strategies["addition_tool"], BlockingConfirmationStrategy)
+ assert isinstance(
+ deserialized_agent._confirmation_strategies["addition_tool"].confirmation_policy, NeverAskPolicy
+ )
+ assert isinstance(deserialized_agent._confirmation_strategies["addition_tool"].confirmation_ui, SimpleConsoleUI)
+
+
+class TestAgentConfirmationStrategy:
+ @pytest.mark.skipif(not os.environ.get("OPENAI_API_KEY"), reason="OPENAI_API_KEY not set")
+ @pytest.mark.integration
+ def test_run_blocking_confirmation_strategy_modify(self, tools):
+ agent = Agent(
+ chat_generator=OpenAIChatGenerator(model="gpt-4o-mini"),
+ tools=tools,
+ confirmation_strategies={
+ "addition_tool": BlockingConfirmationStrategy(
+ AlwaysAskPolicy(),
+ TestUserInterface(ConfirmationUIResult(action="modify", new_tool_params={"a": 2, "b": 3})),
+ )
+ },
+ )
+ agent.warm_up()
+
+ result = agent.run([ChatMessage.from_user("What is 2+2?")])
+
+ assert isinstance(result["last_message"], ChatMessage)
+ assert "5" in result["last_message"].text
+
+ @pytest.mark.skipif(not os.environ.get("OPENAI_API_KEY"), reason="OPENAI_API_KEY not set")
+ @pytest.mark.integration
+ def test_run_breakpoint_confirmation_strategy_modify(self, tools, tmp_path):
+ agent = Agent(
+ chat_generator=OpenAIChatGenerator(model="gpt-4o-mini"),
+ tools=tools,
+ confirmation_strategies={
+ "addition_tool": BreakpointConfirmationStrategy(snapshot_file_path=str(tmp_path)),
+ },
+ )
+ agent.warm_up()
+
+ # Step 1: Initial run
+ result = run_agent(agent, [ChatMessage.from_user("What is 2+2?")])
+
+        # Step 2: Loop to handle the breakpoint confirmation strategy until the agent completes
+ while result is None:
+ # Load the latest snapshot from disk and prep data for front-end
+ loaded_snapshot = get_latest_snapshot(snapshot_file_path=str(tmp_path))
+ serialized_tool_calls, tool_descripts = get_tool_calls_and_descriptions_from_snapshot(
+ agent_snapshot=loaded_snapshot.agent_snapshot, breakpoint_tool_only=True
+ )
+
+ # Simulate front-end interaction
+ serialized_teds = frontend_simulate_tool_decision(
+ serialized_tool_calls,
+ tool_descripts,
+ ConfirmationUIResult(action="modify", new_tool_params={"a": 2, "b": 3}),
+ )
+
+ # Re-run the agent with the new tool execution decisions
+ result = run_agent(agent, [], str(tmp_path), serialized_teds)
+
+ # Step 3: Final result
+ last_message = result["last_message"]
+ assert isinstance(last_message, ChatMessage)
+ assert "5" in last_message.text
+
+ @pytest.mark.skipif(not os.environ.get("OPENAI_API_KEY"), reason="OPENAI_API_KEY not set")
+ @pytest.mark.integration
+ @pytest.mark.asyncio
+ async def test_run_async_blocking_confirmation_strategy_modify(self, tools):
+ agent = Agent(
+ chat_generator=OpenAIChatGenerator(model="gpt-4o-mini"),
+ tools=tools,
+ confirmation_strategies={
+ "addition_tool": BlockingConfirmationStrategy(
+ AlwaysAskPolicy(),
+ TestUserInterface(ConfirmationUIResult(action="modify", new_tool_params={"a": 2, "b": 3})),
+ )
+ },
+ )
+ agent.warm_up()
+
+ result = await agent.run_async([ChatMessage.from_user("What is 2+2?")])
+
+ assert isinstance(result["last_message"], ChatMessage)
+ assert "5" in result["last_message"].text
+
+ @pytest.mark.skipif(not os.environ.get("OPENAI_API_KEY"), reason="OPENAI_API_KEY not set")
+ @pytest.mark.integration
+ @pytest.mark.asyncio
+ async def test_run_async_breakpoint_confirmation_strategy_modify(self, tools, tmp_path):
+ agent = Agent(
+ chat_generator=OpenAIChatGenerator(model="gpt-4o-mini"),
+ tools=tools,
+ confirmation_strategies={
+ "addition_tool": BreakpointConfirmationStrategy(snapshot_file_path=str(tmp_path)),
+ },
+ )
+ agent.warm_up()
+
+ # Step 1: Initial run
+ result = await run_agent_async(agent, [ChatMessage.from_user("What is 2+2?")])
+
+        # Step 2: Loop to handle the breakpoint confirmation strategy until the agent completes
+ while result is None:
+ # Load the latest snapshot from disk and prep data for front-end
+ loaded_snapshot = get_latest_snapshot(snapshot_file_path=str(tmp_path))
+ serialized_tool_calls, tool_descripts = get_tool_calls_and_descriptions_from_snapshot(
+ agent_snapshot=loaded_snapshot.agent_snapshot, breakpoint_tool_only=True
+ )
+
+ # Simulate front-end interaction
+ serialized_teds = frontend_simulate_tool_decision(
+ serialized_tool_calls,
+ tool_descripts,
+ ConfirmationUIResult(action="modify", new_tool_params={"a": 2, "b": 3}),
+ )
+
+ # Re-run the agent with the new tool execution decisions
+ result = await run_agent_async(agent, [], str(tmp_path), serialized_teds)
+
+ # Step 3: Final result
+ last_message = result["last_message"]
+ assert isinstance(last_message, ChatMessage)
+ assert "5" in last_message.text