diff --git a/README.md b/README.md index a1ac4e85..9d818f04 100644 --- a/README.md +++ b/README.md @@ -41,17 +41,18 @@ that includes it. Once it reaches the end of its lifespan, the experiment will b ### Active experiments -| Name | Type | Expected End Date | Dependencies | Cookbook | Discussion | -|---------------------------------------|--------------------------------|-------------------|--------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|---------------| -| [`InMemoryChatMessageStore`][1] | Memory Store | December 2024 | None | Open In Colab | [Discuss][4] | -| [`ChatMessageRetriever`][2] | Memory Component | December 2024 | None | Open In Colab | [Discuss][4] | -| [`ChatMessageWriter`][3] | Memory Component | December 2024 | None | Open In Colab | [Discuss][4] | -| [`QueryExpander`][5] | Query Expansion Component | October 2025 | None | None | [Discuss][6] | -| [`EmbeddingBasedDocumentSplitter`][8] | EmbeddingBasedDocumentSplitter | August 2025 | None | None | [Discuss][7] | -| [`MultiQueryEmbeddingRetriever`][13] | MultiQueryEmbeddingRetriever | November 2025 | None | None | [Discuss][11] | -| [`MultiQueryTextRetriever`][14] | MultiQueryTextRetriever | November 2025 | None | None | [Discuss][12] | -| [`OpenAIChatGenerator`][9] | Chat Generator Component | November 2025 | None | Open In Colab | [Discuss][10] | -| [`MarkdownHeaderLevelInferrer`][15] | Preprocessor | January 2025 | None | None | [Discuss][16] | +| Name | Type | Expected End Date | Dependencies | Cookbook | Discussion | +|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|--------------------------------|-------------------|--------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|---------------| +| [`InMemoryChatMessageStore`][1] | Memory Store | December 2024 | None | Open In Colab | [Discuss][4] | +| [`ChatMessageRetriever`][2] | Memory Component | December 2024 | None | Open In Colab | [Discuss][4] | +| [`ChatMessageWriter`][3] | Memory Component | December 2024 | None | Open In Colab | [Discuss][4] | +| [`QueryExpander`][5] | Query Expansion Component | October 2025 | None | None | [Discuss][6] | +| [`EmbeddingBasedDocumentSplitter`][8] | EmbeddingBasedDocumentSplitter | August 2025 | None | None | [Discuss][7] | +| [`MultiQueryEmbeddingRetriever`][13] | MultiQueryEmbeddingRetriever | November 2025 | None | None | [Discuss][11] | +| [`MultiQueryTextRetriever`][14] | MultiQueryTextRetriever | November 2025 | None | None | [Discuss][12] | +| [`OpenAIChatGenerator`][9] | Chat Generator Component | November 2025 | None | Open In Colab | [Discuss][10] | +| [`MarkdownHeaderLevelInferrer`][15] | Preprocessor | January 2025 | None | None | [Discuss][16] | +| [`Agent`][17]; [Confirmation Policies][18]; [ConfirmationUIs][19]; [ConfirmationStrategies][20]; [`ConfirmationUIResult` and `ToolExecutionDecision`][21] [HITLBreakpointException][22] | Human in the Loop | December 2025 | rich | None | [Discuss][23] | [1]: 
https://github.com/deepset-ai/haystack-experimental/blob/main/haystack_experimental/chat_message_stores/in_memory.py [2]: https://github.com/deepset-ai/haystack-experimental/blob/main/haystack_experimental/components/retrievers/chat_message_retriever.py @@ -69,8 +70,13 @@ that includes it. Once it reaches the end of its lifespan, the experiment will b [14]: https://github.com/deepset-ai/haystack-experimental/blob/main/haystack_experimental/components/retrievers/multi_query_text_retriever.py [15]: https://github.com/deepset-ai/haystack-experimental/blob/main/haystack_experimental/components/retrievers/md_header_level_inferrer.py [16]: https://github.com/deepset-ai/haystack-experimental/discussions/376 - - +[17]: https://github.com/deepset-ai/haystack-experimental/blob/main/haystack_experimental/components/agents/agent.py +[18]: https://github.com/deepset-ai/haystack-experimental/blob/main/haystack_experimental/components/agents/human_in_the_loop/policies.py +[19]: https://github.com/deepset-ai/haystack-experimental/blob/main/haystack_experimental/components/agents/human_in_the_loop/user_interfaces.py +[20]: https://github.com/deepset-ai/haystack-experimental/blob/main/haystack_experimental/components/agents/human_in_the_loop/strategies.py +[21]: https://github.com/deepset-ai/haystack-experimental/blob/main/haystack_experimental/components/agents/human_in_the_loop/dataclasses.py +[22]: https://github.com/deepset-ai/haystack-experimental/blob/main/haystack_experimental/components/agents/human_in_the_loop/errors.py +[23]: https://github.com/deepset-ai/haystack-experimental/discussions/XXX ### Adopted experiments | Name | Type | Final release | diff --git a/docs/pydoc/config/agents_api.yml b/docs/pydoc/config/agents_api.yml new file mode 100644 index 00000000..9a92549f --- /dev/null +++ b/docs/pydoc/config/agents_api.yml @@ -0,0 +1,36 @@ +loaders: + - type: haystack_pydoc_tools.loaders.CustomPythonLoader + search_path: [../../../] + modules: [ + "haystack_experimental.components.agents.agent", + "haystack_experimental.components.agents.human_in_the_loop.breakpoint", + "haystack_experimental.components.agents.human_in_the_loop.dataclasses", + "haystack_experimental.components.agents.human_in_the_loop.errors", + "haystack_experimental.components.agents.human_in_the_loop.policies", + "haystack_experimental.components.agents.human_in_the_loop.strategies", + "haystack_experimental.components.agents.human_in_the_loop.types", + "haystack_experimental.components.agents.human_in_the_loop.user_interfaces", + ] + ignore_when_discovered: ["__init__"] +processors: + - type: filter + expression: + documented_only: true + do_not_filter_modules: false + skip_empty_modules: true + - type: smart + - type: crossref +renderer: + type: haystack_pydoc_tools.renderers.ReadmeCoreRenderer + excerpt: Tool-using agents with provider-agnostic chat model support. 
+ category_slug: haystack-api + title: Agents + slug: experimental-agents-api + order: 2 + markdown: + descriptive_class_title: false + classdef_code_block: false + descriptive_module_title: true + add_method_class_prefix: true + add_member_class_prefix: false + filename: experimental_agents_api.md diff --git a/haystack_experimental/components/agents/__init__.py b/haystack_experimental/components/agents/__init__.py new file mode 100644 index 00000000..4170413e --- /dev/null +++ b/haystack_experimental/components/agents/__init__.py @@ -0,0 +1,16 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +import sys +from typing import TYPE_CHECKING + +from lazy_imports import LazyImporter + +_import_structure = {"agent": ["Agent"]} + +if TYPE_CHECKING: + from .agent import Agent as Agent + +else: + sys.modules[__name__] = LazyImporter(name=__name__, module_file=__file__, import_structure=_import_structure) diff --git a/haystack_experimental/components/agents/agent.py b/haystack_experimental/components/agents/agent.py new file mode 100644 index 00000000..60f451bc --- /dev/null +++ b/haystack_experimental/components/agents/agent.py @@ -0,0 +1,634 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +# pylint: disable=wrong-import-order,wrong-import-position,ungrouped-imports +# ruff: noqa: I001 + +from dataclasses import dataclass +from typing import Any, Optional, Union + +# Monkey patch Haystack's AgentSnapshot with our extended version +import haystack.dataclasses.breakpoints as hdb +from haystack_experimental.dataclasses.breakpoints import AgentSnapshot + +hdb.AgentSnapshot = AgentSnapshot # type: ignore[misc] + +# Monkey patch Haystack's breakpoint functions with our extended versions +import haystack.core.pipeline.breakpoint as hs_breakpoint +import haystack_experimental.core.pipeline.breakpoint as exp_breakpoint + +hs_breakpoint._create_agent_snapshot = exp_breakpoint._create_agent_snapshot +hs_breakpoint._create_pipeline_snapshot_from_tool_invoker = exp_breakpoint._create_pipeline_snapshot_from_tool_invoker # type: ignore[assignment] +hs_breakpoint._trigger_tool_invoker_breakpoint = exp_breakpoint._trigger_tool_invoker_breakpoint + +from haystack import logging +from haystack.components.agents.agent import Agent as HaystackAgent +from haystack.components.agents.agent import _ExecutionContext as Haystack_ExecutionContext +from haystack.components.agents.agent import _schema_from_dict +from haystack.components.agents.state import replace_values +from haystack.components.generators.chat.types import ChatGenerator +from haystack.core.errors import PipelineRuntimeError +from haystack.core.pipeline import AsyncPipeline, Pipeline +from haystack.core.pipeline.breakpoint import ( + _create_pipeline_snapshot_from_chat_generator, + _create_pipeline_snapshot_from_tool_invoker, +) +from haystack.core.pipeline.utils import _deepcopy_with_exceptions +from haystack.core.serialization import default_from_dict, import_class_by_name +from haystack.dataclasses import ChatMessage +from haystack.dataclasses.breakpoints import AgentBreakpoint, ToolBreakpoint +from haystack.dataclasses.streaming_chunk import StreamingCallbackT +from haystack.tools import Tool, Toolset, deserialize_tools_or_toolset_inplace +from haystack.utils.callable_serialization import deserialize_callable +from haystack.utils.deserialization import deserialize_chatgenerator_inplace + +from haystack_experimental.components.agents.human_in_the_loop 
import (
+    ConfirmationStrategy,
+    ToolExecutionDecision,
+    HITLBreakpointException,
+)
+from haystack_experimental.components.agents.human_in_the_loop.strategies import _process_confirmation_strategies
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class _ExecutionContext(Haystack_ExecutionContext):
+    """
+    Execution context for the Agent component.
+
+    Extends Haystack's _ExecutionContext to include tool execution decisions for human-in-the-loop strategies.
+
+    :param tool_execution_decisions: Optional list of ToolExecutionDecision objects to use instead of prompting
+        the user. This is useful when restarting from a snapshot where tool execution decisions were already made.
+    """
+
+    tool_execution_decisions: Optional[list[ToolExecutionDecision]] = None
+
+
+class Agent(HaystackAgent):
+    """
+    A Haystack component that implements a tool-using agent with provider-agnostic chat model support.
+
+    NOTE: This class extends Haystack's Agent component to add support for human-in-the-loop confirmation strategies.
+
+    The component processes messages and executes tools until an exit condition is met.
+    The exit condition can be triggered either by a direct text response or by invoking a specific designated tool.
+    Multiple exit conditions can be specified.
+
+    When you call an Agent without tools, it acts as a ChatGenerator, produces one response, then exits.
+
+    ### Usage example
+    ```python
+    from haystack.components.generators.chat import OpenAIChatGenerator
+    from haystack.dataclasses import ChatMessage
+    from haystack.tools.tool import Tool
+
+    from haystack_experimental.components.agents import Agent
+    from haystack_experimental.components.agents.human_in_the_loop import (
+        BlockingConfirmationStrategy,
+        AlwaysAskPolicy,
+        NeverAskPolicy,
+        SimpleConsoleUI,
+    )
+
+    calculator_tool = Tool(name="calculator", description="A tool for performing mathematical calculations.", ...)
+    search_tool = Tool(name="search", description="A tool for searching the web.", ...)
+
+    agent = Agent(
+        chat_generator=OpenAIChatGenerator(),
+        tools=[calculator_tool, search_tool],
+        confirmation_strategies={
+            calculator_tool.name: BlockingConfirmationStrategy(
+                confirmation_policy=NeverAskPolicy(), confirmation_ui=SimpleConsoleUI()
+            ),
+            search_tool.name: BlockingConfirmationStrategy(
+                confirmation_policy=AlwaysAskPolicy(), confirmation_ui=SimpleConsoleUI()
+            ),
+        },
+    )
+
+    # Run the agent
+    result = agent.run(
+        messages=[ChatMessage.from_user("Find information about Haystack")]
+    )
+
+    assert "messages" in result  # Contains conversation history
+    ```
+    """
+
+    def __init__(
+        self,
+        *,
+        chat_generator: ChatGenerator,
+        tools: Optional[Union[list[Tool], Toolset]] = None,
+        system_prompt: Optional[str] = None,
+        exit_conditions: Optional[list[str]] = None,
+        state_schema: Optional[dict[str, Any]] = None,
+        max_agent_steps: int = 100,
+        streaming_callback: Optional[StreamingCallbackT] = None,
+        raise_on_tool_invocation_failure: bool = False,
+        confirmation_strategies: Optional[dict[str, ConfirmationStrategy]] = None,
+        tool_invoker_kwargs: Optional[dict[str, Any]] = None,
+    ) -> None:
+        """
+        Initialize the agent component.
+
+        :param chat_generator: An instance of the chat generator that your agent should use. It must support tools.
+        :param tools: List of Tool objects or a Toolset that the agent can use.
+        :param system_prompt: System prompt for the agent.
+        :param exit_conditions: List of conditions that will cause the agent to return.
+            Can include "text" if the agent should return when it generates a message without tool calls,
+            or tool names that will cause the agent to return once the tool was executed. Defaults to ["text"].
+        :param state_schema: The schema for the runtime state used by the tools.
+        :param max_agent_steps: Maximum number of steps the agent will run before stopping. Defaults to 100.
+            If the agent exceeds this number of steps, it will stop and return the current state.
+        :param streaming_callback: A callback that will be invoked when a response is streamed from the LLM.
+            The same callback can be configured to emit tool results when a tool is called.
+        :param raise_on_tool_invocation_failure: Should the agent raise an exception when a tool invocation fails?
+            If set to False, the exception will be turned into a chat message and passed to the LLM.
+        :param confirmation_strategies: Optional mapping of tool names to confirmation strategies that decide whether
+            and how a human is asked to confirm, modify, or reject a tool call before it is executed.
+        :param tool_invoker_kwargs: Additional keyword arguments to pass to the ToolInvoker.
+        :raises TypeError: If the chat_generator does not support the tools parameter in its run method.
+        :raises ValueError: If the exit_conditions are not valid.
+        """
+        super(Agent, self).__init__(
+            chat_generator=chat_generator,
+            tools=tools,
+            system_prompt=system_prompt,
+            exit_conditions=exit_conditions,
+            state_schema=state_schema,
+            max_agent_steps=max_agent_steps,
+            streaming_callback=streaming_callback,
+            raise_on_tool_invocation_failure=raise_on_tool_invocation_failure,
+            tool_invoker_kwargs=tool_invoker_kwargs,
+        )
+        self._confirmation_strategies = confirmation_strategies or {}
+
+    def _initialize_fresh_execution(
+        self,
+        messages: list[ChatMessage],
+        streaming_callback: Optional[StreamingCallbackT],
+        requires_async: bool,
+        *,
+        system_prompt: Optional[str] = None,
+        tools: Optional[Union[list[Tool], Toolset, list[str]]] = None,
+        **kwargs: dict[str, Any],
+    ) -> _ExecutionContext:
+        """
+        Initialize execution context for a fresh run of the agent.
+
+        :param messages: List of ChatMessage objects to start the agent with.
+        :param streaming_callback: Optional callback for streaming responses.
+        :param requires_async: Whether the agent run requires asynchronous execution.
+        :param system_prompt: System prompt for the agent. If provided, it overrides the default system prompt.
+        :param tools: Optional list of Tool objects, a Toolset, or list of tool names to use for this run.
+            When passing tool names, tools are selected from the Agent's originally configured tools.
+        :param kwargs: Additional data to pass to the State used by the Agent.
+ """ + exe_context = super(Agent, self)._initialize_fresh_execution( + messages=messages, + streaming_callback=streaming_callback, + requires_async=requires_async, + system_prompt=system_prompt, + tools=tools, + **kwargs, + ) + # NOTE: 1st difference with parent method to add this to tool_invoker_inputs + if self._tool_invoker: + exe_context.tool_invoker_inputs["enable_streaming_callback_passthrough"] = ( + self._tool_invoker.enable_streaming_callback_passthrough + ) + # NOTE: 2nd difference is to use the extended _ExecutionContext + return _ExecutionContext( + state=exe_context.state, + component_visits=exe_context.component_visits, + chat_generator_inputs=exe_context.chat_generator_inputs, + tool_invoker_inputs=exe_context.tool_invoker_inputs, + ) + + def _initialize_from_snapshot( # type: ignore[override] + self, + snapshot: AgentSnapshot, + streaming_callback: Optional[StreamingCallbackT], + requires_async: bool, + *, + tools: Optional[Union[list[Tool], Toolset, list[str]]] = None, + ) -> _ExecutionContext: + """ + Initialize execution context from an AgentSnapshot. + + :param snapshot: An AgentSnapshot containing the state of a previously saved agent execution. + :param streaming_callback: Optional callback for streaming responses. + :param requires_async: Whether the agent run requires asynchronous execution. + :param tools: Optional list of Tool objects, a Toolset, or list of tool names to use for this run. + When passing tool names, tools are selected from the Agent's originally configured tools. + """ + exe_context = super(Agent, self)._initialize_from_snapshot( + snapshot=snapshot, streaming_callback=streaming_callback, requires_async=requires_async, tools=tools + ) + # NOTE: 1st difference with parent method to add this to tool_invoker_inputs + if self._tool_invoker: + exe_context.tool_invoker_inputs["enable_streaming_callback_passthrough"] = ( + self._tool_invoker.enable_streaming_callback_passthrough + ) + # NOTE: 2nd difference is to use the extended _ExecutionContext and add tool_execution_decisions + return _ExecutionContext( + state=exe_context.state, + component_visits=exe_context.component_visits, + chat_generator_inputs=exe_context.chat_generator_inputs, + tool_invoker_inputs=exe_context.tool_invoker_inputs, + counter=exe_context.counter, + skip_chat_generator=exe_context.skip_chat_generator, + tool_execution_decisions=snapshot.tool_execution_decisions, + ) + + def run( # noqa: PLR0915 + self, + messages: list[ChatMessage], + streaming_callback: Optional[StreamingCallbackT] = None, + *, + break_point: Optional[AgentBreakpoint] = None, + snapshot: Optional[AgentSnapshot] = None, # type: ignore[override] + system_prompt: Optional[str] = None, + tools: Optional[Union[list[Tool], Toolset, list[str]]] = None, + **kwargs: Any, + ) -> dict[str, Any]: + """ + Process messages and execute tools until an exit condition is met. + + :param messages: List of Haystack ChatMessage objects to process. + :param streaming_callback: A callback that will be invoked when a response is streamed from the LLM. + The same callback can be configured to emit tool results when a tool is called. + :param break_point: An AgentBreakpoint, can be a Breakpoint for the "chat_generator" or a ToolBreakpoint + for "tool_invoker". + :param snapshot: A dictionary containing a snapshot of a previously saved agent execution. The snapshot contains + the relevant information to restart the Agent execution from where it left off. + :param system_prompt: System prompt for the agent. 
If provided, it overrides the default system prompt. + :param tools: Optional list of Tool objects, a Toolset, or list of tool names to use for this run. + When passing tool names, tools are selected from the Agent's originally configured tools. + :param kwargs: Additional data to pass to the State schema used by the Agent. + The keys must match the schema defined in the Agent's `state_schema`. + :returns: + A dictionary with the following keys: + - "messages": List of all messages exchanged during the agent's run. + - "last_message": The last message exchanged during the agent's run. + - Any additional keys defined in the `state_schema`. + :raises RuntimeError: If the Agent component wasn't warmed up before calling `run()`. + :raises BreakpointException: If an agent breakpoint is triggered. + """ + # We pop parent_snapshot from kwargs to avoid passing it into State. + parent_snapshot = kwargs.pop("parent_snapshot", None) + agent_inputs = { + "messages": messages, + "streaming_callback": streaming_callback, + "break_point": break_point, + "snapshot": snapshot, + **kwargs, + } + self._runtime_checks(break_point=break_point, snapshot=snapshot) + + if snapshot: + exe_context = self._initialize_from_snapshot( + snapshot=snapshot, streaming_callback=streaming_callback, requires_async=False, tools=tools + ) + else: + exe_context = self._initialize_fresh_execution( + messages=messages, + streaming_callback=streaming_callback, + requires_async=False, + system_prompt=system_prompt, + tools=tools, + **kwargs, + ) + + with self._create_agent_span() as span: + span.set_content_tag("haystack.agent.input", _deepcopy_with_exceptions(agent_inputs)) + + while exe_context.counter < self.max_agent_steps: + # Handle breakpoint and ChatGenerator call + Agent._check_chat_generator_breakpoint( + execution_context=exe_context, break_point=break_point, parent_snapshot=parent_snapshot + ) + # We skip the chat generator when restarting from a snapshot from a ToolBreakpoint + if exe_context.skip_chat_generator: + llm_messages = exe_context.state.get("messages", [])[-1:] + # Set to False so the next iteration will call the chat generator + exe_context.skip_chat_generator = False + else: + try: + result = Pipeline._run_component( + component_name="chat_generator", + component={"instance": self.chat_generator}, + inputs={ + "messages": exe_context.state.data["messages"], + **exe_context.chat_generator_inputs, + }, + component_visits=exe_context.component_visits, + parent_span=span, + ) + except PipelineRuntimeError as e: + pipeline_snapshot = _create_pipeline_snapshot_from_chat_generator( + agent_name=getattr(self, "__component_name__", None), + execution_context=exe_context, + parent_snapshot=parent_snapshot, + ) + e.pipeline_snapshot = pipeline_snapshot + raise e + + llm_messages = result["replies"] + exe_context.state.set("messages", llm_messages) + + # Check if any of the LLM responses contain a tool call or if the LLM is not using tools + if not any(msg.tool_call for msg in llm_messages) or self._tool_invoker is None: + exe_context.counter += 1 + break + + # Apply confirmation strategies and update State and messages sent to ToolInvoker + try: + # Run confirmation strategies to get updated tool call messages and modified chat history + modified_tool_call_messages, new_chat_history = _process_confirmation_strategies( + confirmation_strategies=self._confirmation_strategies, + messages_with_tool_calls=llm_messages, + execution_context=exe_context, + ) + # Replace the chat history in state with the modified one + 
exe_context.state.set(key="messages", value=new_chat_history, handler_override=replace_values) + except HITLBreakpointException as tbp_error: + # We create a break_point to pass into _check_tool_invoker_breakpoint + break_point = AgentBreakpoint( + agent_name=getattr(self, "__component_name__", ""), + break_point=ToolBreakpoint( + component_name="tool_invoker", + tool_name=tbp_error.tool_name, + visit_count=exe_context.component_visits["tool_invoker"], + snapshot_file_path=tbp_error.snapshot_file_path, + ), + ) + + # Handle breakpoint + Agent._check_tool_invoker_breakpoint( + execution_context=exe_context, break_point=break_point, parent_snapshot=parent_snapshot + ) + + # Run ToolInvoker + try: + # We only send the messages from the LLM to the tool invoker + tool_invoker_result = Pipeline._run_component( + component_name="tool_invoker", + component={"instance": self._tool_invoker}, + inputs={ + "messages": modified_tool_call_messages, + "state": exe_context.state, + **exe_context.tool_invoker_inputs, + }, + component_visits=exe_context.component_visits, + parent_span=span, + ) + except PipelineRuntimeError as e: + # Access the original Tool Invoker exception + original_error = e.__cause__ + tool_name = getattr(original_error, "tool_name", None) + + pipeline_snapshot = _create_pipeline_snapshot_from_tool_invoker( + tool_name=tool_name, + agent_name=getattr(self, "__component_name__", None), + execution_context=exe_context, + parent_snapshot=parent_snapshot, + ) + e.pipeline_snapshot = pipeline_snapshot + raise e + + # Set execution context tool execution decisions to empty after applying them b/c they should only + # be used once for the current tool calls + exe_context.tool_execution_decisions = None + tool_messages = tool_invoker_result["tool_messages"] + exe_context.state = tool_invoker_result["state"] + exe_context.state.set("messages", tool_messages) + + # Check if any LLM message's tool call name matches an exit condition + if self.exit_conditions != ["text"] and self._check_exit_conditions(llm_messages, tool_messages): + exe_context.counter += 1 + break + + # Increment the step counter + exe_context.counter += 1 + + if exe_context.counter >= self.max_agent_steps: + logger.warning( + "Agent reached maximum agent steps of {max_agent_steps}, stopping.", + max_agent_steps=self.max_agent_steps, + ) + span.set_content_tag("haystack.agent.output", exe_context.state.data) + span.set_tag("haystack.agent.steps_taken", exe_context.counter) + + result = {**exe_context.state.data} + if msgs := result.get("messages"): + result["last_message"] = msgs[-1] + return result + + async def run_async( + self, + messages: list[ChatMessage], + streaming_callback: Optional[StreamingCallbackT] = None, + *, + break_point: Optional[AgentBreakpoint] = None, + snapshot: Optional[AgentSnapshot] = None, # type: ignore[override] + system_prompt: Optional[str] = None, + tools: Optional[Union[list[Tool], Toolset, list[str]]] = None, + **kwargs: Any, + ) -> dict[str, Any]: + """ + Asynchronously process messages and execute tools until the exit condition is met. + + This is the asynchronous version of the `run` method. It follows the same logic but uses + asynchronous operations where possible, such as calling the `run_async` method of the ChatGenerator + if available. + + :param messages: List of Haystack ChatMessage objects to process. + :param streaming_callback: An asynchronous callback that will be invoked when a response is streamed from the + LLM. 
The same callback can be configured to emit tool results when a tool is called. + :param break_point: An AgentBreakpoint, can be a Breakpoint for the "chat_generator" or a ToolBreakpoint + for "tool_invoker". + :param snapshot: A dictionary containing a snapshot of a previously saved agent execution. The snapshot contains + the relevant information to restart the Agent execution from where it left off. + :param system_prompt: System prompt for the agent. If provided, it overrides the default system prompt. + :param tools: Optional list of Tool objects, a Toolset, or list of tool names to use for this run. + :param kwargs: Additional data to pass to the State schema used by the Agent. + The keys must match the schema defined in the Agent's `state_schema`. + :returns: + A dictionary with the following keys: + - "messages": List of all messages exchanged during the agent's run. + - "last_message": The last message exchanged during the agent's run. + - Any additional keys defined in the `state_schema`. + :raises RuntimeError: If the Agent component wasn't warmed up before calling `run_async()`. + :raises BreakpointException: If an agent breakpoint is triggered. + """ + # We pop parent_snapshot from kwargs to avoid passing it into State. + parent_snapshot = kwargs.pop("parent_snapshot", None) + agent_inputs = { + "messages": messages, + "streaming_callback": streaming_callback, + "break_point": break_point, + "snapshot": snapshot, + **kwargs, + } + self._runtime_checks(break_point=break_point, snapshot=snapshot) + + if snapshot: + exe_context = self._initialize_from_snapshot( + snapshot=snapshot, streaming_callback=streaming_callback, requires_async=True, tools=tools + ) + else: + exe_context = self._initialize_fresh_execution( + messages=messages, + streaming_callback=streaming_callback, + requires_async=True, + system_prompt=system_prompt, + tools=tools, + **kwargs, + ) + + with self._create_agent_span() as span: + span.set_content_tag("haystack.agent.input", _deepcopy_with_exceptions(agent_inputs)) + + while exe_context.counter < self.max_agent_steps: + # Handle breakpoint and ChatGenerator call + self._check_chat_generator_breakpoint( + execution_context=exe_context, break_point=break_point, parent_snapshot=parent_snapshot + ) + # We skip the chat generator when restarting from a snapshot from a ToolBreakpoint + if exe_context.skip_chat_generator: + llm_messages = exe_context.state.get("messages", [])[-1:] + # Set to False so the next iteration will call the chat generator + exe_context.skip_chat_generator = False + else: + result = await AsyncPipeline._run_component_async( + component_name="chat_generator", + component={"instance": self.chat_generator}, + component_inputs={ + "messages": exe_context.state.data["messages"], + **exe_context.chat_generator_inputs, + }, + component_visits=exe_context.component_visits, + parent_span=span, + ) + llm_messages = result["replies"] + exe_context.state.set("messages", llm_messages) + + # Check if any of the LLM responses contain a tool call or if the LLM is not using tools + if not any(msg.tool_call for msg in llm_messages) or self._tool_invoker is None: + exe_context.counter += 1 + break + + # Apply confirmation strategies and update State and messages sent to ToolInvoker + try: + # Run confirmation strategies to get updated tool call messages and modified chat history + modified_tool_call_messages, new_chat_history = _process_confirmation_strategies( + confirmation_strategies=self._confirmation_strategies, + messages_with_tool_calls=llm_messages, + 
execution_context=exe_context, + ) + # Replace the chat history in state with the modified one + exe_context.state.set(key="messages", value=new_chat_history, handler_override=replace_values) + except HITLBreakpointException as tbp_error: + # We create a break_point to pass into _check_tool_invoker_breakpoint + break_point = AgentBreakpoint( + agent_name=getattr(self, "__component_name__", ""), + break_point=ToolBreakpoint( + component_name="tool_invoker", + tool_name=tbp_error.tool_name, + visit_count=exe_context.component_visits["tool_invoker"], + snapshot_file_path=tbp_error.snapshot_file_path, + ), + ) + + # Handle breakpoint + Agent._check_tool_invoker_breakpoint( + execution_context=exe_context, break_point=break_point, parent_snapshot=parent_snapshot + ) + + # Run ToolInvoker + # We only send the messages from the LLM to the tool invoker + tool_invoker_result = await AsyncPipeline._run_component_async( + component_name="tool_invoker", + component={"instance": self._tool_invoker}, + component_inputs={ + "messages": modified_tool_call_messages, + "state": exe_context.state, + **exe_context.tool_invoker_inputs, + }, + component_visits=exe_context.component_visits, + parent_span=span, + ) + + # Set execution context tool execution decisions to empty after applying them b/c they should only + # be used once for the current tool calls + exe_context.tool_execution_decisions = None + tool_messages = tool_invoker_result["tool_messages"] + exe_context.state = tool_invoker_result["state"] + exe_context.state.set("messages", tool_messages) + + # Check if any LLM message's tool call name matches an exit condition + if self.exit_conditions != ["text"] and self._check_exit_conditions(llm_messages, tool_messages): + exe_context.counter += 1 + break + + # Increment the step counter + exe_context.counter += 1 + + if exe_context.counter >= self.max_agent_steps: + logger.warning( + "Agent reached maximum agent steps of {max_agent_steps}, stopping.", + max_agent_steps=self.max_agent_steps, + ) + span.set_content_tag("haystack.agent.output", exe_context.state.data) + span.set_tag("haystack.agent.steps_taken", exe_context.counter) + + result = {**exe_context.state.data} + if msgs := result.get("messages"): + result["last_message"] = msgs[-1] + return result + + def to_dict(self) -> dict[str, Any]: + """ + Serialize the component to a dictionary. + + :return: Dictionary with serialized data + """ + data = super(Agent, self).to_dict() + data["init_parameters"]["confirmation_strategies"] = ( + {name: strategy.to_dict() for name, strategy in self._confirmation_strategies.items()} + if self._confirmation_strategies + else None + ) + return data + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "Agent": + """ + Deserialize the agent from a dictionary. 
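+
+        Confirmation strategies are serialized by `to_dict()` and rebuilt here, so a round trip such as the
+        sketch below (assuming `agent` was created as in the class-level usage example) should be possible:
+
+        ```python
+        data = agent.to_dict()
+        restored = Agent.from_dict(data)
+        ```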
+ + :param data: Dictionary to deserialize from + :return: Deserialized agent + """ + init_params = data.get("init_parameters", {}) + + deserialize_chatgenerator_inplace(init_params, key="chat_generator") + + if "state_schema" in init_params: + init_params["state_schema"] = _schema_from_dict(init_params["state_schema"]) + + if init_params.get("streaming_callback") is not None: + init_params["streaming_callback"] = deserialize_callable(init_params["streaming_callback"]) + + deserialize_tools_or_toolset_inplace(init_params, key="tools") + + if "confirmation_strategies" in init_params and init_params["confirmation_strategies"] is not None: + for name, strategy_dict in init_params["confirmation_strategies"].items(): + strategy_class = import_class_by_name(strategy_dict["type"]) + if not hasattr(strategy_class, "from_dict"): + raise TypeError(f"{strategy_class} does not have from_dict method implemented.") + init_params["confirmation_strategies"][name] = strategy_class.from_dict(strategy_dict) + + return default_from_dict(cls, data) diff --git a/haystack_experimental/components/agents/human_in_the_loop/__init__.py b/haystack_experimental/components/agents/human_in_the_loop/__init__.py new file mode 100644 index 00000000..ab26024a --- /dev/null +++ b/haystack_experimental/components/agents/human_in_the_loop/__init__.py @@ -0,0 +1,35 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +import sys +from typing import TYPE_CHECKING + +from lazy_imports import LazyImporter + +_import_structure = { + "dataclasses": ["ConfirmationUIResult", "ToolExecutionDecision"], + "errors": ["HITLBreakpointException"], + "policies": ["AlwaysAskPolicy", "NeverAskPolicy", "AskOncePolicy"], + "strategies": ["BlockingConfirmationStrategy", "BreakpointConfirmationStrategy"], + "types": ["ConfirmationPolicy", "ConfirmationUI", "ConfirmationStrategy"], + "user_interfaces": ["RichConsoleUI", "SimpleConsoleUI"], +} + +if TYPE_CHECKING: + from .dataclasses import ConfirmationUIResult as ConfirmationUIResult + from .dataclasses import ToolExecutionDecision as ToolExecutionDecision + from .errors import HITLBreakpointException as HITLBreakpointException + from .policies import AlwaysAskPolicy as AlwaysAskPolicy + from .policies import AskOncePolicy as AskOncePolicy + from .policies import NeverAskPolicy as NeverAskPolicy + from .strategies import BlockingConfirmationStrategy as BlockingConfirmationStrategy + from .strategies import BreakpointConfirmationStrategy as BreakpointConfirmationStrategy + from .types import ConfirmationPolicy as ConfirmationPolicy + from .types import ConfirmationStrategy as ConfirmationStrategy + from .types import ConfirmationUI as ConfirmationUI + from .user_interfaces import RichConsoleUI as RichConsoleUI + from .user_interfaces import SimpleConsoleUI as SimpleConsoleUI + +else: + sys.modules[__name__] = LazyImporter(name=__name__, module_file=__file__, import_structure=_import_structure) diff --git a/haystack_experimental/components/agents/human_in_the_loop/breakpoint.py b/haystack_experimental/components/agents/human_in_the_loop/breakpoint.py new file mode 100644 index 00000000..562435b4 --- /dev/null +++ b/haystack_experimental/components/agents/human_in_the_loop/breakpoint.py @@ -0,0 +1,63 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +from haystack.dataclasses.breakpoints import AgentSnapshot, ToolBreakpoint +from haystack.utils import _deserialize_value_with_schema + +from 
haystack_experimental.components.agents.human_in_the_loop.strategies import _prepare_tool_args + + +def get_tool_calls_and_descriptions_from_snapshot( + agent_snapshot: AgentSnapshot, breakpoint_tool_only: bool = True +) -> tuple[list[dict], dict[str, str]]: + """ + Extract tool calls and tool descriptions from an AgentSnapshot. + + By default, only the tool call that caused the breakpoint is processed and its arguments are reconstructed. + This is useful for scenarios where you want to present the relevant tool call and its description + to a human for confirmation before execution. + + :param agent_snapshot: The AgentSnapshot from which to extract tool calls and descriptions. + :param breakpoint_tool_only: If True, only the tool call that caused the breakpoint is returned. If False, all tool + calls are returned. + :returns: + A tuple containing a list of tool call dictionaries and a dictionary of tool descriptions + """ + break_point = agent_snapshot.break_point.break_point + if not isinstance(break_point, ToolBreakpoint): + raise ValueError("The provided AgentSnapshot does not contain a ToolBreakpoint.") + + tool_caused_break_point = break_point.tool_name + + # Deserialize the tool invoker inputs from the snapshot + tool_invoker_inputs = _deserialize_value_with_schema(agent_snapshot.component_inputs["tool_invoker"]) + tool_call_messages = tool_invoker_inputs["messages"] + state = tool_invoker_inputs["state"] + tool_name_to_tool = {t.name: t for t in tool_invoker_inputs["tools"]} + + tool_calls = [] + for msg in tool_call_messages: + if msg.tool_calls: + tool_calls.extend(msg.tool_calls) + serialized_tcs = [tc.to_dict() for tc in tool_calls] + + # Reconstruct the final arguments for each tool call + tool_descriptions = {} + updated_tool_calls = [] + for tc in serialized_tcs: + # Only process the tool that caused the breakpoint if breakpoint_tool_only is True + if breakpoint_tool_only and tc["tool_name"] != tool_caused_break_point: + continue + + final_args = _prepare_tool_args( + tool=tool_name_to_tool[tc["tool_name"]], + tool_call_arguments=tc["arguments"], + state=state, + streaming_callback=tool_invoker_inputs.get("streaming_callback", None), + enable_streaming_passthrough=tool_invoker_inputs.get("enable_streaming_passthrough", False), + ) + updated_tool_calls.append({**tc, "arguments": final_args}) + tool_descriptions[tc["tool_name"]] = tool_name_to_tool[tc["tool_name"]].description + + return updated_tool_calls, tool_descriptions diff --git a/haystack_experimental/components/agents/human_in_the_loop/dataclasses.py b/haystack_experimental/components/agents/human_in_the_loop/dataclasses.py new file mode 100644 index 00000000..74309819 --- /dev/null +++ b/haystack_experimental/components/agents/human_in_the_loop/dataclasses.py @@ -0,0 +1,72 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +from dataclasses import asdict, dataclass +from typing import Any, Optional + + +@dataclass +class ConfirmationUIResult: + """ + Result of the confirmation UI interaction. + + :param action: + The action taken by the user such as "confirm", "reject", or "modify". + This action type is not enforced to allow for custom actions to be implemented. + :param feedback: + Optional feedback message from the user. For example, if the user rejects the tool execution, + they might provide a reason for the rejection. + :param new_tool_params: + Optional set of new parameters for the tool. 
For example, if the user chooses to modify the tool parameters, + they can provide a new set of parameters here. + """ + + action: str # "confirm", "reject", "modify" + feedback: Optional[str] = None + new_tool_params: Optional[dict[str, Any]] = None + + +@dataclass +class ToolExecutionDecision: + """ + Decision made regarding tool execution. + + :param tool_name: + The name of the tool to be executed. + :param execute: + A boolean indicating whether to execute the tool with the provided parameters. + :param tool_call_id: + Optional unique identifier for the tool call. This can be used to track and correlate the decision with a + specific tool invocation. + :param feedback: + Optional feedback message. + For example, if the tool execution is rejected, this can contain the reason. Or if the tool parameters were + modified, this can contain the modification details. + :param final_tool_params: + Optional final parameters for the tool if execution is confirmed or modified. + """ + + tool_name: str + execute: bool + tool_call_id: Optional[str] = None + feedback: Optional[str] = None + final_tool_params: Optional[dict[str, Any]] = None + + def to_dict(self) -> dict[str, Any]: + """ + Convert the ToolExecutionDecision to a dictionary representation. + + :return: A dictionary containing the tool execution decision details. + """ + return asdict(self) + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "ToolExecutionDecision": + """ + Populate the ToolExecutionDecision from a dictionary representation. + + :param data: A dictionary containing the tool execution decision details. + :return: An instance of ToolExecutionDecision. + """ + return cls(**data) diff --git a/haystack_experimental/components/agents/human_in_the_loop/errors.py b/haystack_experimental/components/agents/human_in_the_loop/errors.py new file mode 100644 index 00000000..1ec33161 --- /dev/null +++ b/haystack_experimental/components/agents/human_in_the_loop/errors.py @@ -0,0 +1,28 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +from typing import Optional + + +class HITLBreakpointException(Exception): + """ + Exception raised when a tool execution is paused by a ConfirmationStrategy (e.g. BreakpointConfirmationStrategy). + """ + + def __init__( + self, message: str, tool_name: str, snapshot_file_path: str, tool_call_id: Optional[str] = None + ) -> None: + """ + Initialize the HITLBreakpointException. + + :param message: The exception message. + :param tool_name: The name of the tool whose execution is paused. + :param snapshot_file_path: The file path to the saved pipeline snapshot. + :param tool_call_id: Optional unique identifier for the tool call. This can be used to track and correlate + the decision with a specific tool invocation. 
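+
+        A minimal sketch of raising this exception from a custom confirmation strategy (the tool name,
+        tool call id, and snapshot path below are illustrative placeholders):
+
+        ```python
+        from haystack_experimental.components.agents.human_in_the_loop import HITLBreakpointException
+
+        raise HITLBreakpointException(
+            message="Tool execution for 'search' requires user confirmation.",
+            tool_name="search",
+            snapshot_file_path="snapshots/",
+            tool_call_id="call_123",
+        )
+        ```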
+ """ + super().__init__(message) + self.tool_name = tool_name + self.snapshot_file_path = snapshot_file_path + self.tool_call_id = tool_call_id diff --git a/haystack_experimental/components/agents/human_in_the_loop/policies.py b/haystack_experimental/components/agents/human_in_the_loop/policies.py new file mode 100644 index 00000000..f80d4a54 --- /dev/null +++ b/haystack_experimental/components/agents/human_in_the_loop/policies.py @@ -0,0 +1,78 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any + +from haystack_experimental.components.agents.human_in_the_loop.dataclasses import ConfirmationUIResult +from haystack_experimental.components.agents.human_in_the_loop.types import ConfirmationPolicy + + +class AlwaysAskPolicy(ConfirmationPolicy): + """Always ask for confirmation.""" + + def should_ask(self, tool_name: str, tool_description: str, tool_params: dict[str, Any]) -> bool: + """ + Always ask for confirmation before executing the tool. + + :param tool_name: The name of the tool to be executed. + :param tool_description: The description of the tool. + :param tool_params: The parameters to be passed to the tool. + :returns: Always returns True, indicating confirmation is needed. + """ + return True + + +class NeverAskPolicy(ConfirmationPolicy): + """Never ask for confirmation.""" + + def should_ask(self, tool_name: str, tool_description: str, tool_params: dict[str, Any]) -> bool: + """ + Never ask for confirmation, always proceed with tool execution. + + :param tool_name: The name of the tool to be executed. + :param tool_description: The description of the tool. + :param tool_params: The parameters to be passed to the tool. + :returns: Always returns False, indicating no confirmation is needed. + """ + return False + + +class AskOncePolicy(ConfirmationPolicy): + """Ask only once per tool with specific parameters.""" + + def __init__(self): + self._asked_tools = {} + + def should_ask(self, tool_name: str, tool_description: str, tool_params: dict[str, Any]) -> bool: + """ + Ask for confirmation only once per tool with specific parameters. + + :param tool_name: The name of the tool to be executed. + :param tool_description: The description of the tool. + :param tool_params: The parameters to be passed to the tool. + :returns: True if confirmation is needed, False if already asked with the same parameters. + """ + # Don't ask again if we've already asked for this tool with the same parameters + return not (tool_name in self._asked_tools and self._asked_tools[tool_name] == tool_params) + + def update_after_confirmation( + self, + tool_name: str, + tool_description: str, + tool_params: dict[str, Any], + confirmation_result: ConfirmationUIResult, + ) -> None: + """ + Store the tool and parameters if the action was "confirm" to avoid asking again. + + This method updates the internal state to remember that the user has already confirmed the execution of the + tool with the given parameters. + + :param tool_name: The name of the tool that was executed. + :param tool_description: The description of the tool. + :param tool_params: The parameters that were passed to the tool. + :param confirmation_result: The result from the confirmation UI. 
+ """ + if confirmation_result.action == "confirm": + self._asked_tools[tool_name] = tool_params diff --git a/haystack_experimental/components/agents/human_in_the_loop/strategies.py b/haystack_experimental/components/agents/human_in_the_loop/strategies.py new file mode 100644 index 00000000..98679dfb --- /dev/null +++ b/haystack_experimental/components/agents/human_in_the_loop/strategies.py @@ -0,0 +1,455 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +from dataclasses import replace +from typing import TYPE_CHECKING, Any, Optional + +from haystack.components.agents.state import State +from haystack.components.tools.tool_invoker import ToolInvoker +from haystack.core.serialization import default_from_dict, default_to_dict, import_class_by_name +from haystack.dataclasses import ChatMessage, StreamingCallbackT +from haystack.tools import Tool + +from haystack_experimental.components.agents.human_in_the_loop import ( + ConfirmationPolicy, + ConfirmationStrategy, + ConfirmationUI, + HITLBreakpointException, + ToolExecutionDecision, +) + +if TYPE_CHECKING: + from haystack_experimental.components.agents.agent import _ExecutionContext + + +_REJECTION_FEEDBACK_TEMPLATE = "Tool execution for '{tool_name}' was rejected by the user." +_MODIFICATION_FEEDBACK_TEMPLATE = ( + "The parameters for tool '{tool_name}' were updated by the user to:\n{final_tool_params}" +) + + +class BlockingConfirmationStrategy: + """ + Confirmation strategy that blocks execution to gather user feedback. + """ + + def __init__(self, confirmation_policy: ConfirmationPolicy, confirmation_ui: ConfirmationUI) -> None: + """ + Initialize the BlockingConfirmationStrategy with a confirmation policy and UI. + + :param confirmation_policy: + The confirmation policy to determine when to ask for user confirmation. + :param confirmation_ui: + The user interface to interact with the user for confirmation. + """ + self.confirmation_policy = confirmation_policy + self.confirmation_ui = confirmation_ui + + def run( + self, tool_name: str, tool_description: str, tool_params: dict[str, Any], tool_call_id: Optional[str] = None + ) -> ToolExecutionDecision: + """ + Run the human-in-the-loop strategy for a given tool and its parameters. + + :param tool_name: + The name of the tool to be executed. + :param tool_description: + The description of the tool. + :param tool_params: + The parameters to be passed to the tool. + :param tool_call_id: + Optional unique identifier for the tool call. This can be used to track and correlate the decision with a + specific tool invocation. + + :returns: + A ToolExecutionDecision indicating whether to execute the tool with the given parameters, or a + feedback message if rejected. 
+ """ + # Check if we should ask based on policy + if not self.confirmation_policy.should_ask( + tool_name=tool_name, tool_description=tool_description, tool_params=tool_params + ): + return ToolExecutionDecision( + tool_name=tool_name, execute=True, tool_call_id=tool_call_id, final_tool_params=tool_params + ) + + # Get user confirmation through UI + confirmation_ui_result = self.confirmation_ui.get_user_confirmation(tool_name, tool_description, tool_params) + + # Pass back the result to the policy for any learning/updating + self.confirmation_policy.update_after_confirmation( + tool_name, tool_description, tool_params, confirmation_ui_result + ) + + # Process the confirmation result + final_args = {} + if confirmation_ui_result.action == "reject": + explanation_text = _REJECTION_FEEDBACK_TEMPLATE.format(tool_name=tool_name) + if confirmation_ui_result.feedback: + explanation_text += f" With feedback: {confirmation_ui_result.feedback}" + return ToolExecutionDecision( + tool_name=tool_name, execute=False, tool_call_id=tool_call_id, feedback=explanation_text + ) + elif confirmation_ui_result.action == "modify" and confirmation_ui_result.new_tool_params: + # Update the tool call params with the new params + final_args.update(confirmation_ui_result.new_tool_params) + explanation_text = _MODIFICATION_FEEDBACK_TEMPLATE.format(tool_name=tool_name, final_tool_params=final_args) + if confirmation_ui_result.feedback: + explanation_text += f" With feedback: {confirmation_ui_result.feedback}" + return ToolExecutionDecision( + tool_name=tool_name, + tool_call_id=tool_call_id, + execute=True, + feedback=explanation_text, + final_tool_params=final_args, + ) + else: # action == "confirm" + return ToolExecutionDecision( + tool_name=tool_name, execute=True, tool_call_id=tool_call_id, final_tool_params=tool_params + ) + + def to_dict(self) -> dict[str, Any]: + """ + Serializes the BlockingConfirmationStrategy to a dictionary. + + :returns: + Dictionary with serialized data. + """ + return default_to_dict( + self, confirmation_policy=self.confirmation_policy.to_dict(), confirmation_ui=self.confirmation_ui.to_dict() + ) + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "BlockingConfirmationStrategy": + """ + Deserializes the BlockingConfirmationStrategy from a dictionary. + + :param data: + Dictionary to deserialize from. + + :returns: + Deserialized BlockingConfirmationStrategy. + """ + policy_data = data["init_parameters"]["confirmation_policy"] + policy_class = import_class_by_name(policy_data["type"]) + if not hasattr(policy_class, "from_dict"): + raise ValueError(f"Class {policy_class} does not implement from_dict method.") + ui_data = data["init_parameters"]["confirmation_ui"] + ui_class = import_class_by_name(ui_data["type"]) + if not hasattr(ui_class, "from_dict"): + raise ValueError(f"Class {ui_class} does not implement from_dict method.") + return cls(confirmation_policy=policy_class.from_dict(policy_data), confirmation_ui=ui_class.from_dict(ui_data)) + + +class BreakpointConfirmationStrategy: + """ + Confirmation strategy that raises a tool breakpoint exception to pause execution and gather user feedback. + + This strategy is designed for scenarios where immediate user interaction is not possible. + When a tool execution requires confirmation, it raises an `HITLBreakpointException`, which is caught by the Agent. + The Agent then serialize its current state, including the tool call details. This information can then be used to + notify a user to review and confirm the tool execution. 
+ """ + + def __init__(self, snapshot_file_path: str) -> None: + """ + Initialize the BreakpointConfirmationStrategy. + + :param snapshot_file_path: The path to the directory that the snapshot should be saved. + """ + self.snapshot_file_path = snapshot_file_path + + def run( + self, tool_name: str, tool_description: str, tool_params: dict[str, Any], tool_call_id: Optional[str] = None + ) -> ToolExecutionDecision: + """ + Run the breakpoint confirmation strategy for a given tool and its parameters. + + :param tool_name: + The name of the tool to be executed. + :param tool_description: + The description of the tool. + :param tool_params: + The parameters to be passed to the tool. + :param tool_call_id: + Optional unique identifier for the tool call. This can be used to track and correlate the decision with a + specific tool invocation. + + :raises HITLBreakpointException: + Always raises an `HITLBreakpointException` exception to signal that user confirmation is required. + + :returns: + This method does not return; it always raises an exception. + """ + raise HITLBreakpointException( + message=f"Tool execution for '{tool_name}' requires user confirmation.", + tool_name=tool_name, + tool_call_id=tool_call_id, + snapshot_file_path=self.snapshot_file_path, + ) + + def to_dict(self) -> dict[str, Any]: + """ + Serializes the BreakpointConfirmationStrategy to a dictionary. + """ + return default_to_dict(self, snapshot_file_path=self.snapshot_file_path) + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "BreakpointConfirmationStrategy": + """ + Deserializes the BreakpointConfirmationStrategy from a dictionary. + + :param data: + Dictionary to deserialize from. + + :returns: + Deserialized BreakpointConfirmationStrategy. + """ + return default_from_dict(cls, data) + + +def _prepare_tool_args( + *, + tool: Tool, + tool_call_arguments: dict[str, Any], + state: State, + streaming_callback: Optional[StreamingCallbackT] = None, + enable_streaming_passthrough: bool = False, +) -> dict[str, Any]: + """ + Prepare the final arguments for a tool by injecting state inputs and optionally a streaming callback. + + :param tool: + The tool instance to prepare arguments for. + :param tool_call_arguments: + The initial arguments provided for the tool call. + :param state: + The current state containing inputs to be injected into the tool arguments. + :param streaming_callback: + Optional streaming callback to be injected if enabled and applicable. + :param enable_streaming_passthrough: + Flag indicating whether to inject the streaming callback into the tool arguments. + + :returns: + A dictionary of final arguments ready for tool invocation. + """ + # Combine user + state inputs + final_args = ToolInvoker._inject_state_args(tool, tool_call_arguments.copy(), state) + # Check whether to inject streaming_callback + if ( + enable_streaming_passthrough + and streaming_callback is not None + and "streaming_callback" not in final_args + and "streaming_callback" in ToolInvoker._get_func_params(tool) + ): + final_args["streaming_callback"] = streaming_callback + return final_args + + +def _process_confirmation_strategies( + *, + confirmation_strategies: dict[str, ConfirmationStrategy], + messages_with_tool_calls: list[ChatMessage], + execution_context: "_ExecutionContext", +) -> tuple[list[ChatMessage], list[ChatMessage]]: + """ + Run the confirmation strategies and return modified tool call messages and updated chat history. 
+ + :param confirmation_strategies: Mapping of tool names to their corresponding confirmation strategies + :param messages_with_tool_calls: Chat messages containing tool calls + :param execution_context: The current execution context of the agent + :returns: + Tuple of modified messages with confirmed tool calls and updated chat history + """ + # Run confirmation strategies and get tool execution decisions + teds = _run_confirmation_strategies( + confirmation_strategies=confirmation_strategies, + messages_with_tool_calls=messages_with_tool_calls, + execution_context=execution_context, + ) + + # Apply tool execution decisions to messages_with_tool_calls + rejection_messages, modified_tool_call_messages = _apply_tool_execution_decisions( + tool_call_messages=messages_with_tool_calls, + tool_execution_decisions=teds, + ) + + # Update the chat history with rejection messages and new tool call messages + new_chat_history = _update_chat_history( + chat_history=execution_context.state.get("messages"), + rejection_messages=rejection_messages, + tool_call_and_explanation_messages=modified_tool_call_messages, + ) + + return modified_tool_call_messages, new_chat_history + + +def _run_confirmation_strategies( + confirmation_strategies: dict[str, ConfirmationStrategy], + messages_with_tool_calls: list[ChatMessage], + execution_context: "_ExecutionContext", +) -> list[ToolExecutionDecision]: + """ + Run confirmation strategies for tool calls in the provided chat messages. + + :param confirmation_strategies: Mapping of tool names to their corresponding confirmation strategies + :param messages_with_tool_calls: Messages containing tool calls to process + :param execution_context: The current execution context containing state and inputs + :returns: + A list of ToolExecutionDecision objects representing the decisions made for each tool call. 
+ """ + state = execution_context.state + tools_with_names = {tool.name: tool for tool in execution_context.tool_invoker_inputs["tools"]} + existing_teds = execution_context.tool_execution_decisions if execution_context.tool_execution_decisions else [] + existing_teds_by_name = {ted.tool_name: ted for ted in existing_teds if ted.tool_name} + existing_teds_by_id = {ted.tool_call_id: ted for ted in existing_teds if ted.tool_call_id} + + teds = [] + for message in messages_with_tool_calls: + if not message.tool_calls: + continue + + for tool_call in message.tool_calls: + tool_name = tool_call.tool_name + tool_to_invoke = tools_with_names[tool_name] + + # Prepare final tool args + final_args = _prepare_tool_args( + tool=tool_to_invoke, + tool_call_arguments=tool_call.arguments, + state=state, + streaming_callback=execution_context.tool_invoker_inputs.get("streaming_callback"), + enable_streaming_passthrough=execution_context.tool_invoker_inputs.get( + "enable_streaming_passthrough", False + ), + ) + + # Get tool execution decisions from confirmation strategies + # If no confirmation strategy is defined for this tool, proceed with execution + if tool_name not in confirmation_strategies: + teds.append( + ToolExecutionDecision( + tool_call_id=tool_call.id, + tool_name=tool_name, + execute=True, + final_tool_params=final_args, + ) + ) + continue + + # Check if there's already a decision for this tool call in the execution context + ted = existing_teds_by_id.get(tool_call.id or "") or existing_teds_by_name.get(tool_name) + + # If not, run the confirmation strategy + if not ted: + ted = confirmation_strategies[tool_name].run( + tool_name=tool_name, tool_description=tool_to_invoke.description, tool_params=final_args + ) + teds.append(ted) + + return teds + + +def _apply_tool_execution_decisions( + tool_call_messages: list[ChatMessage], tool_execution_decisions: list[ToolExecutionDecision] +) -> tuple[list[ChatMessage], list[ChatMessage]]: + """ + Apply the tool execution decisions to the tool call messages. + + :param tool_call_messages: The tool call messages to apply the decisions to. + :param tool_execution_decisions: The tool execution decisions to apply. + :returns: + A tuple containing: + - A list of rejection messages for rejected tool calls. These are pairs of tool call and tool call result + messages. + - A list of tool call messages for confirmed or modified tool calls. If tool parameters were modified, + a user message explaining the modification is included before the tool call message. 
+ """ + decision_by_id = {d.tool_call_id: d for d in tool_execution_decisions if d.tool_call_id} + decision_by_name = {d.tool_name: d for d in tool_execution_decisions if d.tool_name} + + def make_assistant_message(chat_message, tool_calls): + return ChatMessage.from_assistant( + text=chat_message.text, + meta=chat_message.meta, + name=chat_message.name, + tool_calls=tool_calls, + reasoning=chat_message.reasoning, + ) + + new_tool_call_messages = [] + rejection_messages = [] + + for chat_msg in tool_call_messages: + new_tool_calls = [] + for tc in chat_msg.tool_calls or []: + ted = decision_by_id.get(tc.id or "") or decision_by_name.get(tc.tool_name) + if not ted: + # This shouldn't happen, if so something went wrong in _run_confirmation_strategies + continue + + if not ted.execute: + # rejected tool call + tool_result_text = ted.feedback or _REJECTION_FEEDBACK_TEMPLATE.format(tool_name=tc.tool_name) + rejection_messages.extend( + [ + make_assistant_message(chat_msg, [tc]), + ChatMessage.from_tool(tool_result=tool_result_text, origin=tc, error=True), + ] + ) + continue + + # Covers confirm and modify cases + final_args = ted.final_tool_params or {} + if tc.arguments != final_args: + # In the modify case we add a user message explaining the modification otherwise the LLM won't know + # why the tool parameters changed and will likely just try and call the tool again with the + # original parameters. + user_text = ted.feedback or _MODIFICATION_FEEDBACK_TEMPLATE.format( + tool_name=tc.tool_name, final_tool_params=final_args + ) + new_tool_call_messages.append(ChatMessage.from_user(text=user_text)) + new_tool_calls.append(replace(tc, arguments=final_args)) + + # Only add the tool call message if there are any tool calls left (i.e. not all were rejected) + if new_tool_calls: + new_tool_call_messages.append(make_assistant_message(chat_msg, new_tool_calls)) + + return rejection_messages, new_tool_call_messages + + +def _update_chat_history( + chat_history: list[ChatMessage], + rejection_messages: list[ChatMessage], + tool_call_and_explanation_messages: list[ChatMessage], +) -> list[ChatMessage]: + """ + Update the chat history to include rejection messages and tool call messages at the appropriate positions. + + Steps: + 1. Identify the last user message and the last tool message in the current chat history. + 2. Determine the insertion point as the maximum index of these two messages. + 3. Create a new chat history that includes: + - All messages up to the insertion point. + - Any rejection messages (pairs of tool call and tool call result messages). + - Any tool call messages for confirmed or modified tool calls, including user messages explaining modifications. + + :param chat_history: The current chat history. + :param rejection_messages: Chat messages to add for rejected tool calls (pairs of tool call and tool call result + messages). + :param tool_call_and_explanation_messages: Tool call messages for confirmed or modified tool calls, which may + include user messages explaining modifications. + :returns: + The updated chat history. 
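+
+    For example, if the chat history ends with an assistant message whose tool calls are still pending (i.e. it
+    comes after the last user or tool message), that trailing message is dropped and replaced by the rejection
+    messages and the confirmed or modified tool call messages.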
+    """
+    user_indices = [i for i, message in enumerate(chat_history) if message.is_from("user")]
+    tool_indices = [i for i, message in enumerate(chat_history) if message.is_from("tool")]
+
+    last_user_idx = max(user_indices) if user_indices else -1
+    last_tool_idx = max(tool_indices) if tool_indices else -1
+
+    insertion_point = max(last_user_idx, last_tool_idx)
+
+    new_chat_history = chat_history[: insertion_point + 1] + rejection_messages + tool_call_and_explanation_messages
+    return new_chat_history
diff --git a/haystack_experimental/components/agents/human_in_the_loop/types.py b/haystack_experimental/components/agents/human_in_the_loop/types.py
new file mode 100644
index 00000000..8a0626d3
--- /dev/null
+++ b/haystack_experimental/components/agents/human_in_the_loop/types.py
@@ -0,0 +1,89 @@
+# SPDX-FileCopyrightText: 2022-present deepset GmbH
+#
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any, Optional, Protocol
+
+from haystack.core.serialization import default_from_dict, default_to_dict
+
+from haystack_experimental.components.agents.human_in_the_loop.dataclasses import (
+    ConfirmationUIResult,
+    ToolExecutionDecision,
+)
+
+# Ellipses are needed to define the Protocol, but pylint complains. See https://github.com/pylint-dev/pylint/issues/9319.
+# pylint: disable=unnecessary-ellipsis
+
+
+class ConfirmationUI(Protocol):
+    """Base class for confirmation UIs."""
+
+    def get_user_confirmation(
+        self, tool_name: str, tool_description: str, tool_params: dict[str, Any]
+    ) -> ConfirmationUIResult:
+        """Get user confirmation for tool execution."""
+        ...
+
+    def to_dict(self) -> dict[str, Any]:
+        """Serialize the UI to a dictionary."""
+        return default_to_dict(self)
+
+    @classmethod
+    def from_dict(cls, data: dict[str, Any]) -> "ConfirmationUI":
+        """Deserialize the ConfirmationUI from a dictionary."""
+        return default_from_dict(cls, data)
+
+
+class ConfirmationPolicy(Protocol):
+    """Base class for confirmation policies."""
+
+    def should_ask(self, tool_name: str, tool_description: str, tool_params: dict[str, Any]) -> bool:
+        """Determine whether to ask for confirmation."""
+        ...
+
+    def update_after_confirmation(
+        self,
+        tool_name: str,
+        tool_description: str,
+        tool_params: dict[str, Any],
+        confirmation_result: ConfirmationUIResult,
+    ) -> None:
+        """Update the policy based on the confirmation UI result."""
+        pass
+
+    def to_dict(self) -> dict[str, Any]:
+        """Serialize the policy to a dictionary."""
+        return default_to_dict(self)
+
+    @classmethod
+    def from_dict(cls, data: dict[str, Any]) -> "ConfirmationPolicy":
+        """Deserialize the policy from a dictionary."""
+        return default_from_dict(cls, data)
+
+
+class ConfirmationStrategy(Protocol):
+    def run(
+        self, tool_name: str, tool_description: str, tool_params: dict[str, Any], tool_call_id: Optional[str] = None
+    ) -> ToolExecutionDecision:
+        """
+        Run the confirmation strategy for a given tool and its parameters.
+
+        :param tool_name: The name of the tool to be executed.
+        :param tool_description: The description of the tool.
+        :param tool_params: The parameters to be passed to the tool.
+        :param tool_call_id: Optional unique identifier for the tool call. This can be used to track and correlate
+            the decision with a specific tool invocation.
+
+        :returns:
+            A ToolExecutionDecision indicating whether the tool should be executed and, if so, with which parameters.
+        """
+        ...
+
+    def to_dict(self) -> dict[str, Any]:
+        """Serialize the strategy to a dictionary."""
+        ...
+ + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "ConfirmationStrategy": + """Deserialize the strategy from a dictionary.""" + ... diff --git a/haystack_experimental/components/agents/human_in_the_loop/user_interfaces.py b/haystack_experimental/components/agents/human_in_the_loop/user_interfaces.py new file mode 100644 index 00000000..1f6f2430 --- /dev/null +++ b/haystack_experimental/components/agents/human_in_the_loop/user_interfaces.py @@ -0,0 +1,209 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +import json +from threading import Lock +from typing import Any, Optional + +from haystack.core.serialization import default_to_dict +from rich.console import Console +from rich.panel import Panel +from rich.prompt import Prompt + +from haystack_experimental.components.agents.human_in_the_loop.dataclasses import ConfirmationUIResult +from haystack_experimental.components.agents.human_in_the_loop.types import ConfirmationUI + +_ui_lock = Lock() + + +class RichConsoleUI(ConfirmationUI): + """Rich console interface for user interaction.""" + + def __init__(self, console: Optional[Console] = None): + self.console = console or Console() + + def get_user_confirmation( + self, tool_name: str, tool_description: str, tool_params: dict[str, Any] + ) -> ConfirmationUIResult: + """ + Get user confirmation for tool execution via rich console prompts. + + :param tool_name: The name of the tool to be executed. + :param tool_description: The description of the tool. + :param tool_params: The parameters to be passed to the tool. + :returns: ConfirmationUIResult based on user input. + """ + with _ui_lock: + self._display_tool_info(tool_name, tool_description, tool_params) + # If wrong input is provided, Prompt.ask will re-prompt + choice = Prompt.ask("\nYour choice", choices=["y", "n", "m"], default="y", console=self.console) + return self._process_choice(choice, tool_params) + + def _display_tool_info(self, tool_name: str, tool_description: str, tool_params: dict[str, Any]) -> None: + """ + Display tool information and parameters in a rich panel. + + :param tool_name: The name of the tool to be executed. + :param tool_description: The description of the tool. + :param tool_params: The parameters to be passed to the tool. + """ + lines = [ + f"[bold yellow]Tool:[/bold yellow] {tool_name}", + f"[bold yellow]Description:[/bold yellow] {tool_description}", + "\n[bold yellow]Arguments:[/bold yellow]", + ] + + if tool_params: + for k, v in tool_params.items(): + lines.append(f"[cyan]{k}:[/cyan] {v}") + else: + lines.append(" (No arguments)") + + self.console.print(Panel("\n".join(lines), title="🔧 Tool Execution Request", title_align="left")) + + def _process_choice(self, choice: str, tool_params: dict[str, Any]) -> ConfirmationUIResult: + """ + Process the user's choice and return the corresponding ConfirmationUIResult. + + :param choice: The user's choice ('y', 'n', or 'm'). + :param tool_params: The original tool parameters. + :returns: + ConfirmationUIResult based on user input. + """ + if choice == "y": + return ConfirmationUIResult(action="confirm") + elif choice == "m": + return self._modify_params(tool_params) + else: # reject + feedback = Prompt.ask("Feedback message (optional)", default="", console=self.console) + return ConfirmationUIResult(action="reject", feedback=feedback or None) + + def _modify_params(self, tool_params: dict[str, Any]) -> ConfirmationUIResult: + """ + Prompt the user to modify tool parameters. 
+
+        :param tool_params: The original tool parameters.
+        :returns:
+            ConfirmationUIResult with modified parameters.
+        """
+        new_params: dict[str, Any] = {}
+        for k, v in tool_params.items():
+            # We don't JSON dump strings to avoid users needing to input extra quotes
+            default_val = json.dumps(v) if not isinstance(v, str) else v
+            while True:
+                new_val = Prompt.ask(f"Modify '{k}'", default=default_val, console=self.console)
+                try:
+                    if isinstance(v, str):
+                        # Always treat input as string
+                        new_params[k] = new_val
+                    else:
+                        # Parse JSON for all non-string types
+                        new_params[k] = json.loads(new_val)
+                    break
+                except json.JSONDecodeError:
+                    self.console.print("[red]❌ Invalid JSON, please try again.[/red]")
+
+        return ConfirmationUIResult(action="modify", new_tool_params=new_params)
+
+    def to_dict(self) -> dict[str, Any]:
+        """
+        Serializes the RichConsoleUI to a dictionary.
+
+        :returns:
+            Dictionary with serialized data.
+        """
+        # Note: Console object is not serializable; we store None
+        return default_to_dict(self, console=None)
+
+
+class SimpleConsoleUI(ConfirmationUI):
+    """Simple console interface using standard input/output."""
+
+    def get_user_confirmation(
+        self, tool_name: str, tool_description: str, tool_params: dict[str, Any]
+    ) -> ConfirmationUIResult:
+        """
+        Get user confirmation for tool execution via simple console prompts.
+
+        :param tool_name: The name of the tool to be executed.
+        :param tool_description: The description of the tool.
+        :param tool_params: The parameters to be passed to the tool.
+        """
+        with _ui_lock:
+            self._display_tool_info(tool_name, tool_description, tool_params)
+            valid_choices = {"y", "yes", "n", "no", "m", "modify"}
+            while True:
+                choice = input("Confirm execution? (y=confirm / n=reject / m=modify): ").strip().lower()
+                if choice in valid_choices:
+                    break
+                print("Invalid input. Please enter 'y', 'n', or 'm'.")
+            return self._process_choice(choice, tool_params)
+
+    def _display_tool_info(self, tool_name: str, tool_description: str, tool_params: dict[str, Any]) -> None:
+        """
+        Display tool information and parameters in the console.
+
+        :param tool_name: The name of the tool to be executed.
+        :param tool_description: The description of the tool.
+        :param tool_params: The parameters to be passed to the tool.
+        """
+        print("\n--- Tool Execution Request ---")
+        print(f"Tool: {tool_name}")
+        print(f"Description: {tool_description}")
+        print("Arguments:")
+        if tool_params:
+            for k, v in tool_params.items():
+                print(f"  {k}: {v}")
+        else:
+            print("  (No arguments)")
+        print("-" * 30)
+
+    def _process_choice(self, choice: str, tool_params: dict[str, Any]) -> ConfirmationUIResult:
+        """
+        Process the user's choice and return the corresponding ConfirmationUIResult.
+
+        :param choice: The user's choice ('y', 'n', or 'm').
+        :param tool_params: The original tool parameters.
+        :returns:
+            ConfirmationUIResult based on user input.
+        """
+        if choice in ("y", "yes"):
+            return ConfirmationUIResult(action="confirm")
+        elif choice in ("m", "modify"):
+            return self._modify_params(tool_params)
+        else:  # reject
+            feedback = input("Feedback message (optional): ").strip()
+            return ConfirmationUIResult(action="reject", feedback=feedback or None)
+
+    def _modify_params(self, tool_params: dict[str, Any]) -> ConfirmationUIResult:
+        """
+        Prompt the user to modify tool parameters.
+
+        :param tool_params: The original tool parameters.
+        :returns:
+            ConfirmationUIResult with modified parameters.
+ """ + new_params: dict[str, Any] = {} + + if not tool_params: + print("No parameters to modify, skipping modification.") + return ConfirmationUIResult(action="modify", new_tool_params=new_params) + + for k, v in tool_params.items(): + # We don't JSON dump strings to avoid users needing to input extra quotes + default_val = json.dumps(v) if not isinstance(v, str) else v + while True: + new_val = input(f"Modify '{k}' (current: {default_val}): ").strip() or default_val + try: + if isinstance(v, str): + # Always treat input as string + new_params[k] = new_val + else: + # Parse JSON for all non-string types + new_params[k] = json.loads(new_val) + break + except json.JSONDecodeError: + print("❌ Invalid JSON, please try again.") + + return ConfirmationUIResult(action="modify", new_tool_params=new_params) diff --git a/haystack_experimental/core/__init__.py b/haystack_experimental/core/__init__.py new file mode 100644 index 00000000..c1764a6e --- /dev/null +++ b/haystack_experimental/core/__init__.py @@ -0,0 +1,3 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 diff --git a/haystack_experimental/core/pipeline/__init__.py b/haystack_experimental/core/pipeline/__init__.py new file mode 100644 index 00000000..c1764a6e --- /dev/null +++ b/haystack_experimental/core/pipeline/__init__.py @@ -0,0 +1,3 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 diff --git a/haystack_experimental/core/pipeline/breakpoint.py b/haystack_experimental/core/pipeline/breakpoint.py new file mode 100644 index 00000000..da782ebb --- /dev/null +++ b/haystack_experimental/core/pipeline/breakpoint.py @@ -0,0 +1,174 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +from copy import deepcopy +from dataclasses import replace +from datetime import datetime +from typing import TYPE_CHECKING, Any, Optional + +from haystack import logging +from haystack.core.errors import BreakpointException +from haystack.core.pipeline.breakpoint import _save_pipeline_snapshot +from haystack.dataclasses import ChatMessage +from haystack.dataclasses.breakpoints import AgentBreakpoint, PipelineSnapshot, PipelineState, ToolBreakpoint +from haystack.utils.base_serialization import _serialize_value_with_schema +from haystack.utils.misc import _get_output_dir + +from haystack_experimental.dataclasses.breakpoints import AgentSnapshot + +if TYPE_CHECKING: + from haystack_experimental.components.agents.agent import _ExecutionContext + from haystack_experimental.components.agents.human_in_the_loop import ToolExecutionDecision + + +logger = logging.getLogger(__name__) + + +def _create_agent_snapshot( + *, + component_visits: dict[str, int], + agent_breakpoint: AgentBreakpoint, + component_inputs: dict[str, Any], + tool_execution_decisions: Optional[list["ToolExecutionDecision"]] = None, +) -> AgentSnapshot: + """ + Create a snapshot of the agent's state. + + NOTE: Only difference to Haystack's native implementation is the addition of tool_execution_decisions to the + AgentSnapshot. + + :param component_visits: The visit counts for the agent's components. + :param agent_breakpoint: AgentBreakpoint object containing breakpoints + :param component_inputs: The inputs to the agent's components. + :param tool_execution_decisions: Optional list of ToolExecutionDecision objects representing decisions made + regarding tool executions. + :return: An AgentSnapshot containing the agent's state and component visits. 
+ """ + return AgentSnapshot( + component_inputs={ + "chat_generator": _serialize_value_with_schema(deepcopy(component_inputs["chat_generator"])), + "tool_invoker": _serialize_value_with_schema(deepcopy(component_inputs["tool_invoker"])), + }, + component_visits=component_visits, + break_point=agent_breakpoint, + timestamp=datetime.now(), + tool_execution_decisions=tool_execution_decisions, + ) + + +def _create_pipeline_snapshot_from_tool_invoker( + *, + execution_context: "_ExecutionContext", + tool_name: Optional[str] = None, + agent_name: Optional[str] = None, + break_point: Optional[AgentBreakpoint] = None, + parent_snapshot: Optional[PipelineSnapshot] = None, +) -> PipelineSnapshot: + """ + Create a pipeline snapshot when a tool invoker breakpoint is raised or an exception during execution occurs. + + :param execution_context: The current execution context of the agent. + :param tool_name: The name of the tool that triggered the breakpoint, if available. + :param agent_name: The name of the agent component if present in a pipeline. + :param break_point: An optional AgentBreakpoint object. If provided, it will be used instead of creating a new one. + A scenario where a new breakpoint is created is when an exception occurs during tool execution and we want to + capture the state at that point. + :param parent_snapshot: An optional parent PipelineSnapshot to build upon. + :returns: + A PipelineSnapshot containing the state of the pipeline and agent at the point of the breakpoint or exception. + """ + if break_point is None: + agent_breakpoint = AgentBreakpoint( + agent_name=agent_name or "agent", + break_point=ToolBreakpoint( + component_name="tool_invoker", + visit_count=execution_context.component_visits["tool_invoker"], + tool_name=tool_name, + snapshot_file_path=_get_output_dir("pipeline_snapshot"), + ), + ) + else: + agent_breakpoint = break_point + + messages = execution_context.state.data["messages"] + agent_snapshot = _create_agent_snapshot( + component_visits=execution_context.component_visits, + agent_breakpoint=agent_breakpoint, + component_inputs={ + "chat_generator": {"messages": messages[:-1], **execution_context.chat_generator_inputs}, + "tool_invoker": { + "messages": messages[-1:], # tool invoker consumes last msg from the chat_generator, contains tool call + "state": execution_context.state, + **execution_context.tool_invoker_inputs, + }, + }, + tool_execution_decisions=execution_context.tool_execution_decisions, + ) + if parent_snapshot is None: + # Create an empty pipeline snapshot if no parent snapshot is provided + final_snapshot = PipelineSnapshot( + pipeline_state=PipelineState(inputs={}, component_visits={}, pipeline_outputs={}), + timestamp=agent_snapshot.timestamp, + break_point=agent_snapshot.break_point, + agent_snapshot=agent_snapshot, + original_input_data={}, + ordered_component_names=[], + include_outputs_from=set(), + ) + else: + final_snapshot = replace(parent_snapshot, agent_snapshot=agent_snapshot) + + return final_snapshot + + +def _trigger_tool_invoker_breakpoint(*, llm_messages: list[ChatMessage], pipeline_snapshot: PipelineSnapshot) -> None: + """ + Check if a tool call breakpoint should be triggered before executing the tool invoker. + + NOTE: Only difference to Haystack's native implementation is that it includes the fix from + PR https://github.com/deepset-ai/haystack/pull/9853 where we make sure to check all tool calls in a chat message + when checking if a BreakpointException should be made. 
+ + :param llm_messages: List of ChatMessage objects containing potential tool calls. + :param pipeline_snapshot: PipelineSnapshot object containing the state of the pipeline and Agent snapshot. + :raises BreakpointException: If the breakpoint is triggered, indicating a breakpoint has been reached for a tool + call. + """ + if not pipeline_snapshot.agent_snapshot: + raise ValueError("PipelineSnapshot must contain an AgentSnapshot to trigger a tool call breakpoint.") + + if not isinstance(pipeline_snapshot.agent_snapshot.break_point.break_point, ToolBreakpoint): + return + + tool_breakpoint = pipeline_snapshot.agent_snapshot.break_point.break_point + + # Check if we should break for this specific tool or all tools + if tool_breakpoint.tool_name is None: + # Break for any tool call + should_break = any(msg.tool_call for msg in llm_messages) + else: + # Break only for the specific tool + should_break = any( + tc.tool_name == tool_breakpoint.tool_name for msg in llm_messages for tc in msg.tool_calls or [] + ) + + if not should_break: + return # No breakpoint triggered + + _save_pipeline_snapshot(pipeline_snapshot=pipeline_snapshot) + + msg = ( + f"Breaking at {tool_breakpoint.component_name} visit count " + f"{pipeline_snapshot.agent_snapshot.component_visits[tool_breakpoint.component_name]}" + ) + if tool_breakpoint.tool_name: + msg += f" for tool {tool_breakpoint.tool_name}" + logger.info(msg) + + raise BreakpointException( + message=msg, + component=tool_breakpoint.component_name, + inputs=pipeline_snapshot.agent_snapshot.component_inputs, + results=pipeline_snapshot.agent_snapshot.component_inputs["tool_invoker"]["serialized_data"]["state"], + ) diff --git a/haystack_experimental/dataclasses/__init__.py b/haystack_experimental/dataclasses/__init__.py new file mode 100644 index 00000000..c1764a6e --- /dev/null +++ b/haystack_experimental/dataclasses/__init__.py @@ -0,0 +1,3 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 diff --git a/haystack_experimental/dataclasses/breakpoints.py b/haystack_experimental/dataclasses/breakpoints.py new file mode 100644 index 00000000..880badc4 --- /dev/null +++ b/haystack_experimental/dataclasses/breakpoints.py @@ -0,0 +1,53 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +from dataclasses import dataclass +from datetime import datetime +from typing import Any, Optional + +from haystack.dataclasses.breakpoints import AgentBreakpoint +from haystack.dataclasses.breakpoints import AgentSnapshot as HaystackAgentSnapshot + +from haystack_experimental.components.agents.human_in_the_loop.dataclasses import ToolExecutionDecision + + +@dataclass +class AgentSnapshot(HaystackAgentSnapshot): + tool_execution_decisions: Optional[list[ToolExecutionDecision]] = None + + def to_dict(self) -> dict[str, Any]: + """ + Convert the AgentSnapshot to a dictionary representation. + + :return: A dictionary containing the agent state, timestamp, and breakpoint. + """ + return { + "component_inputs": self.component_inputs, + "component_visits": self.component_visits, + "break_point": self.break_point.to_dict(), + "timestamp": self.timestamp.isoformat() if self.timestamp else None, + "tool_execution_decisions": [ted.to_dict() for ted in self.tool_execution_decisions] + if self.tool_execution_decisions + else None, + } + + @classmethod + def from_dict(cls, data: dict) -> "AgentSnapshot": + """ + Populate the AgentSnapshot from a dictionary representation. 
+ + :param data: A dictionary containing the agent state, timestamp, and breakpoint. + :return: An instance of AgentSnapshot. + """ + return cls( + component_inputs=data["component_inputs"], + component_visits=data["component_visits"], + break_point=AgentBreakpoint.from_dict(data["break_point"]), + timestamp=datetime.fromisoformat(data["timestamp"]) if data.get("timestamp") else None, + tool_execution_decisions=[ + ToolExecutionDecision.from_dict(ted) for ted in data.get("tool_execution_decisions", []) + ] + if data.get("tool_execution_decisions") + else None, + ) diff --git a/hitl_breakpoint_example.py b/hitl_breakpoint_example.py new file mode 100644 index 00000000..0b7a5cbe --- /dev/null +++ b/hitl_breakpoint_example.py @@ -0,0 +1,186 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +import os +from pathlib import Path +from typing import Any, Optional + +from haystack.components.generators.chat import OpenAIChatGenerator +from haystack.core.errors import BreakpointException +from haystack.core.pipeline.breakpoint import load_pipeline_snapshot +from haystack.dataclasses import ChatMessage +from haystack.dataclasses.breakpoints import PipelineSnapshot +from haystack.tools import create_tool_from_function +from rich.console import Console + +from haystack_experimental.components.agents.agent import Agent +from haystack_experimental.components.agents.human_in_the_loop import ( + AlwaysAskPolicy, + BlockingConfirmationStrategy, + BreakpointConfirmationStrategy, + RichConsoleUI, + ToolExecutionDecision, +) +from haystack_experimental.components.agents.human_in_the_loop.breakpoint import ( + get_tool_calls_and_descriptions_from_snapshot, +) + + +def get_bank_balance(account_id: str) -> str: + """ + Simulate fetching a bank balance for a given account ID. + + :param account_id: The ID of the bank account. + :returns: + A string representing the bank balance. + """ + return f"Balance for account {account_id} is $1,234.56" + + +def addition(a: float, b: float) -> float: + """ + A simple addition function. + + :param a: First float. + :param b: Second float. + :returns: + Sum of a and b. + """ + return a + b + + +def get_latest_snapshot(snapshot_file_path: str) -> PipelineSnapshot: + """ + Load the latest pipeline snapshot from the 'pipeline_snapshots' directory. + """ + snapshot_dir = Path(snapshot_file_path) + possible_snapshots = [snapshot_dir / f for f in os.listdir(snapshot_dir)] + latest_snapshot_file = str(max(possible_snapshots, key=os.path.getctime)) + return load_pipeline_snapshot(latest_snapshot_file) + + +def frontend_simulate_tool_decision( + tool_calls: list[dict[str, Any]], tool_descriptions: dict[str, str], console: Console +) -> list[dict]: + """ + Simulate front-end receiving tool calls, prompting user, and sending back decisions. + + :param tool_calls: + A list of tool call dictionaries containing tool_name, id, and arguments. + :param tool_descriptions: + A dictionary mapping tool names to their descriptions. + :param console: + A Rich Console instance for displaying prompts and messages. + :returns: + A list of serialized ToolExecutionDecision dictionaries. 
+ """ + + confirmation_strategy = BlockingConfirmationStrategy( + confirmation_policy=AlwaysAskPolicy(), + confirmation_ui=RichConsoleUI(console=console), + ) + + tool_execution_decisions = [] + for tc in tool_calls: + tool_execution_decisions.append( + confirmation_strategy.run( + tool_name=tc["tool_name"], + tool_description=tool_descriptions[tc["tool_name"]], + tool_call_id=tc["id"], + tool_params=tc["arguments"], + ) + ) + return [ted.to_dict() for ted in tool_execution_decisions] + + +def run_agent( + agent: Agent, + messages: list[ChatMessage], + console: Console, + snapshot_file_path: Optional[str] = None, + tool_execution_decisions: Optional[list[dict[str, Any]]] = None, +) -> Optional[dict[str, Any]]: + """ + Run the agent with the given messages and optional snapshot. + """ + # Load the latest snapshot if a path is provided + snapshot = None + if snapshot_file_path: + snapshot = get_latest_snapshot(snapshot_file_path=snapshot_file_path) + + # Add any new tool execution decisions to the snapshot + if tool_execution_decisions: + teds = [ToolExecutionDecision.from_dict(ted) for ted in tool_execution_decisions] + existing_decisions = snapshot.agent_snapshot.tool_execution_decisions or [] + snapshot.agent_snapshot.tool_execution_decisions = existing_decisions + teds + + try: + return agent.run(messages=messages, snapshot=snapshot.agent_snapshot if snapshot else None) + except BreakpointException as e: + console.print("[bold red]Execution paused by Breakpoint Confirmation Strategy:[/bold red]", str(e)) + return None + + +def main(user_message: str): + """ + Main function to demonstrate the Breakpoint Confirmation Strategy with an agent. + """ + cons = Console() + cons.print("\n[bold blue]=== Breakpoint Confirmation Strategy Example ===[/bold blue]\n") + cons.print(f"[bold yellow]User Message:[/bold yellow] {user_message}\n") + + # Define agent with both tools and breakpoint confirmation strategies + addition_tool = create_tool_from_function( + function=addition, + name="addition", + description="Add two floats together.", + ) + balance_tool = create_tool_from_function( + function=get_bank_balance, + name="get_bank_balance", + description="Get the bank balance for a given account ID.", + ) + snapshot_fp = "pipeline_snapshots" + bank_agent = Agent( + chat_generator=OpenAIChatGenerator(model="gpt-4.1"), + tools=[balance_tool, addition_tool], + system_prompt="You are a helpful financial assistant. 
Use the provided tool to get bank balances when needed.", + confirmation_strategies={ + balance_tool.name: BreakpointConfirmationStrategy(snapshot_file_path=snapshot_fp), + addition_tool.name: BreakpointConfirmationStrategy(snapshot_file_path=snapshot_fp), + }, + ) + + # Step 1: Initial run + result = run_agent(bank_agent, [ChatMessage.from_user(user_message)], cons) + + # Step 2: Loop to handle break point confirmation strategy until agent completes + while result is None: + # Load the latest snapshot from disk and prep data for front-end + loaded_snapshot = get_latest_snapshot(snapshot_file_path=snapshot_fp) + serialized_tool_calls, tool_descripts = get_tool_calls_and_descriptions_from_snapshot( + agent_snapshot=loaded_snapshot.agent_snapshot, breakpoint_tool_only=True + ) + + # Simulate front-end interaction + serialized_teds = frontend_simulate_tool_decision(serialized_tool_calls, tool_descripts, cons) + + # Re-run the agent with the new tool execution decisions + result = run_agent(bank_agent, [], cons, snapshot_fp, serialized_teds) + + # Step 3: Final result + last_message = result["last_message"] + cons.print(f"\n[bold green]Agent Result:[/bold green] {last_message.text}") + + +if __name__ == "__main__": + for usr_msg in [ + # Single tool call question --> Works + "What's the balance of account 56789?", + # Two tool call question --> Works + "What's the balance of account 56789 and what is 5.5 + 3.2?", + # Multiple sequential tool calls question --> Works + "What's the balance of account 56789? If it's lower than $2000, what's the balance of account 12345?", + ]: + main(usr_msg) diff --git a/hitl_intro_example.py b/hitl_intro_example.py new file mode 100644 index 00000000..11985036 --- /dev/null +++ b/hitl_intro_example.py @@ -0,0 +1,114 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +from haystack.components.generators.chat import OpenAIChatGenerator +from haystack.dataclasses import ChatMessage +from haystack.tools import create_tool_from_function +from rich.console import Console + +from haystack_experimental.components.agents.agent import Agent +from haystack_experimental.components.agents.human_in_the_loop import ( + AlwaysAskPolicy, + AskOncePolicy, + BlockingConfirmationStrategy, + NeverAskPolicy, + RichConsoleUI, + SimpleConsoleUI, +) + + +def addition(a: float, b: float) -> float: + """ + A simple addition function. + + :param a: First float. + :param b: Second float. + :returns: + Sum of a and b. + """ + return a + b + + +addition_tool = create_tool_from_function( + function=addition, + name="addition", + description="Add two floats together.", +) + + +def get_bank_balance(account_id: str) -> str: + """ + Simulate fetching a bank balance for a given account ID. + + :param account_id: The ID of the bank account. + :returns: + A string representing the bank balance. + """ + return f"Balance for account {account_id} is $1,234.56" + + +balance_tool = create_tool_from_function( + function=get_bank_balance, + name="get_bank_balance", + description="Get the bank balance for a given account ID.", +) + + +def get_phone_number(name: str) -> str: + """ + Simulate fetching a phone number for a given name. + + :param name: The name of the person. + :returns: + A string representing the phone number. 
+ """ + return f"The phone number for {name} is (123) 456-7890" + + +phone_tool = create_tool_from_function( + function=get_phone_number, + name="get_phone_number", + description="Get the phone number for a given name.", +) + +# Define shared console +cons = Console() + +# Define Main Agent with multiple tools and different confirmation strategies +agent = Agent( + chat_generator=OpenAIChatGenerator(model="gpt-4.1"), + tools=[balance_tool, addition_tool, phone_tool], + system_prompt="You are a helpful financial assistant. Use the provided tool to get bank balances when needed.", + confirmation_strategies={ + balance_tool.name: BlockingConfirmationStrategy( + confirmation_policy=AlwaysAskPolicy(), confirmation_ui=RichConsoleUI(console=cons) + ), + addition_tool.name: BlockingConfirmationStrategy( + confirmation_policy=NeverAskPolicy(), confirmation_ui=SimpleConsoleUI() + ), + phone_tool.name: BlockingConfirmationStrategy( + confirmation_policy=AskOncePolicy(), confirmation_ui=SimpleConsoleUI() + ), + }, +) + +# Call bank tool with confirmation (Always Ask) using RichConsoleUI +result = agent.run([ChatMessage.from_user("What's the balance of account 56789?")]) +last_message = result["last_message"] +cons.print(f"\n[bold green]Agent Result:[/bold green] {last_message.text}") + +# Call addition tool with confirmation (Never Ask) +result = agent.run([ChatMessage.from_user("What is 5.5 + 3.2?")]) +last_message = result["last_message"] +print(f"\nAgent Result: {last_message.text}") + +# Call phone tool with confirmation (Ask Once) using SimpleConsoleUI +result = agent.run([ChatMessage.from_user("What is the phone number of Alice?")]) +last_message = result["last_message"] +print(f"\nAgent Result: {last_message.text}") + +# Call phone tool again to see that it doesn't ask for confirmation the second time +result = agent.run([ChatMessage.from_user("What is the phone number of Alice?")]) +last_message = result["last_message"] +print(f"\nAgent Result: {last_message.text}") diff --git a/hitl_multi_agent_example.py b/hitl_multi_agent_example.py new file mode 100644 index 00000000..ce352d0d --- /dev/null +++ b/hitl_multi_agent_example.py @@ -0,0 +1,119 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +from haystack.components.generators.chat import OpenAIChatGenerator +from haystack.dataclasses import ChatMessage +from haystack.tools import ComponentTool, create_tool_from_function +from rich.console import Console + +from haystack_experimental.components.agents.agent import Agent +from haystack_experimental.components.agents.human_in_the_loop import ( + AlwaysAskPolicy, + BlockingConfirmationStrategy, + RichConsoleUI, +) + + +def addition(a: float, b: float) -> float: + """ + A simple addition function. + + :param a: First float. + :param b: Second float. + :returns: + Sum of a and b. + """ + return a + b + + +addition_tool = create_tool_from_function( + function=addition, + name="addition", + description="Add two floats together.", +) + + +def get_bank_balance(account_id: str) -> str: + """ + Simulate fetching a bank balance for a given account ID. + + :param account_id: The ID of the bank account. + :returns: + A string representing the bank balance. 
+ """ + return f"Balance for account {account_id} is $1,234.56" + + +balance_tool = create_tool_from_function( + function=get_bank_balance, + name="get_bank_balance", + description="Get the bank balance for a given account ID.", +) + +# Define shared console for all UIs +cons = Console() + +# Define Bank Sub-Agent +bank_agent = Agent( + chat_generator=OpenAIChatGenerator(model="gpt-4.1"), + tools=[balance_tool], + system_prompt="You are a helpful financial assistant. Use the provided tool to get bank balances when needed.", + confirmation_strategies={ + balance_tool.name: BlockingConfirmationStrategy( + confirmation_policy=AlwaysAskPolicy(), confirmation_ui=RichConsoleUI(console=cons) + ), + }, +) +bank_agent_tool = ComponentTool( + component=bank_agent, + name="bank_agent_tool", + description="A bank agent that can get bank balances.", +) + +# Define Math Sub-Agent +math_agent = Agent( + chat_generator=OpenAIChatGenerator(model="gpt-4.1"), + tools=[addition_tool], + system_prompt="You are a helpful math assistant. Use the provided tool to perform addition when needed.", + confirmation_strategies={ + addition_tool.name: BlockingConfirmationStrategy( + # We use AlwaysAskPolicy here for demonstration; in real scenarios, you might choose NeverAskPolicy + confirmation_policy=AlwaysAskPolicy(), + confirmation_ui=RichConsoleUI(console=cons), + ), + }, +) +math_agent_tool = ComponentTool( + component=math_agent, + name="math_agent_tool", + description="A math agent that can perform addition.", +) + +# Define Main Agent with Sub-Agents as tools +planner_agent = Agent( + chat_generator=OpenAIChatGenerator(model="gpt-4.1"), + tools=[bank_agent_tool, math_agent_tool], + system_prompt="""You are a master agent that can delegate tasks to sub-agents based on the user's request. +Available sub-agents: +- bank_agent_tool: A bank agent that can get bank balances. +- math_agent_tool: A math agent that can perform addition. +Use the appropriate sub-agent to handle the user's request. 
+""", +) + +# Make bank balance request to planner agent +result = planner_agent.run([ChatMessage.from_user("What's the balance of account 56789?")]) +last_message = result["last_message"] +cons.print(f"\n[bold green]Agent Result:[/bold green] {last_message.text}") + +# Make addition request to planner agent +result = planner_agent.run([ChatMessage.from_user("What is 5.5 + 3.2?")]) +last_message = result["last_message"] +print(f"\nAgent Result: {last_message.text}") + +# Make bank balance request and addition request to planner agent +# NOTE: This will try and invoke both sub-agents in parallel requiring a thread-safe UI +result = planner_agent.run([ChatMessage.from_user("What's the balance of account 56789 and what is 5.5 + 3.2?")]) +last_message = result["last_message"] +print(f"\nAgent Result: {last_message.text}") diff --git a/pyproject.toml b/pyproject.toml index d3ed63f3..61fdde40 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,6 +28,7 @@ classifiers = [ dependencies = [ "haystack-ai", + "rich", # For pretty printing in the console used by human-in-the-loop utilities ] [project.urls] diff --git a/test/components/agents/__init__.py b/test/components/agents/__init__.py new file mode 100644 index 00000000..c1764a6e --- /dev/null +++ b/test/components/agents/__init__.py @@ -0,0 +1,3 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 diff --git a/test/components/agents/human_in_the_loop/__init__.py b/test/components/agents/human_in_the_loop/__init__.py new file mode 100644 index 00000000..c1764a6e --- /dev/null +++ b/test/components/agents/human_in_the_loop/__init__.py @@ -0,0 +1,3 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 diff --git a/test/components/agents/human_in_the_loop/test_breakpoint.py b/test/components/agents/human_in_the_loop/test_breakpoint.py new file mode 100644 index 00000000..503d2a1c --- /dev/null +++ b/test/components/agents/human_in_the_loop/test_breakpoint.py @@ -0,0 +1,148 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +from haystack.dataclasses.breakpoints import AgentBreakpoint, ToolBreakpoint + +from haystack_experimental.dataclasses.breakpoints import AgentSnapshot +from haystack_experimental.components.agents.human_in_the_loop.breakpoint import ( + get_tool_calls_and_descriptions_from_snapshot, +) + + +def get_bank_balance(account_id: str) -> str: + return f"The balance for account {account_id} is $1,234.56." 
+ + +def addition(a: float, b: float) -> float: + return a + b + + +def test_get_tool_calls_and_descriptions_from_snapshot(): + agent_snapshot = AgentSnapshot( + component_inputs={ + "chat_generator": {}, + "tool_invoker": { + "serialization_schema": { + "type": "object", + "properties": { + "messages": { + "type": "array", + "items": {"type": "haystack.dataclasses.chat_message.ChatMessage"}, + }, + "state": {"type": "haystack.components.agents.state.state.State"}, + "tools": {"type": "array", "items": {"type": "haystack.tools.tool.Tool"}}, + "enable_streaming_callback_passthrough": {"type": "boolean"}, + }, + }, + "serialized_data": { + "messages": [ + { + "role": "assistant", + "content": [ + { + "tool_call": { + "tool_name": "get_bank_balance", + "arguments": {"account_id": "56789"}, + "id": None, + } + } + ], + } + ], + "state": { + "schema": { + "messages": { + "type": "list[haystack.dataclasses.chat_message.ChatMessage]", + "handler": "haystack.components.agents.state.state_utils.merge_lists", + } + }, + "data": { + "serialization_schema": { + "type": "object", + "properties": { + "messages": { + "type": "array", + "items": {"type": "haystack.dataclasses.chat_message.ChatMessage"}, + } + }, + }, + "serialized_data": { + "messages": [ + { + "role": "system", + "content": [ + { + "text": "You are a helpful financial assistant. Use the provided tool to get bank balances when needed." + } + ], + }, + { + "role": "user", + "content": [{"text": "What's the balance of account 56789?"}], + }, + { + "role": "assistant", + "content": [ + { + "tool_call": { + "tool_name": "get_bank_balance", + "arguments": {"account_id": "56789"}, + "id": None, + } + } + ], + }, + ] + }, + }, + }, + "tools": [ + { + "type": "haystack.tools.tool.Tool", + "data": { + "name": "get_bank_balance", + "description": "Get the bank balance for a given account ID.", + "parameters": { + "properties": {"account_id": {"type": "string"}}, + "required": ["account_id"], + "type": "object", + }, + "function": "test.components.agents.human_in_the_loop.test_breakpoint.get_bank_balance", + }, + }, + { + "type": "haystack.tools.tool.Tool", + "data": { + "name": "addition", + "description": "Add two floats together.", + "parameters": { + "properties": {"a": {"type": "number"}, "b": {"type": "number"}}, + "required": ["a", "b"], + "type": "object", + }, + "function": "test.components.agents.human_in_the_loop.test_breakpoint.addition", + }, + }, + ], + "enable_streaming_callback_passthrough": False, + }, + }, + }, + component_visits={"chat_generator": 1, "tool_invoker": 0}, + break_point=AgentBreakpoint( + agent_name="agent", + break_point=ToolBreakpoint( + tool_name="get_bank_balance", component_name="tool_invoker", visit_count=0, snapshot_file_path=None + ), + ), + ) + + tool_calls, tool_descriptions = get_tool_calls_and_descriptions_from_snapshot( + agent_snapshot=agent_snapshot, breakpoint_tool_only=True + ) + + assert len(tool_calls) == 1 + assert tool_calls[0]["tool_name"] == "get_bank_balance" + assert tool_calls[0]["arguments"] == {"account_id": "56789"} + assert tool_descriptions == {"get_bank_balance": "Get the bank balance for a given account ID."} diff --git a/test/components/agents/human_in_the_loop/test_dataclasses.py b/test/components/agents/human_in_the_loop/test_dataclasses.py new file mode 100644 index 00000000..ff0d0cd2 --- /dev/null +++ b/test/components/agents/human_in_the_loop/test_dataclasses.py @@ -0,0 +1,64 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + 
+from haystack_experimental.components.agents.human_in_the_loop import ( + ConfirmationUIResult, + ToolExecutionDecision, +) + + +class TestConfirmationUIResult: + def test_init(self): + original = ConfirmationUIResult( + action="reject", + feedback="Changed my mind", + ) + assert original.action == "reject" + assert original.feedback == "Changed my mind" + assert original.new_tool_params is None + + +class TestToolExecutionDecision: + def test_init(self): + decision = ToolExecutionDecision( + execute=True, + tool_name="test_tool", + tool_call_id="test_tool_call_id", + final_tool_params={"param1": "new_value"}, + ) + assert decision.execute is True + assert decision.final_tool_params == {"param1": "new_value"} + assert decision.tool_call_id == "test_tool_call_id" + assert decision.tool_name == "test_tool" + + def test_to_dict(self): + original = ToolExecutionDecision( + execute=True, + tool_name="test_tool", + tool_call_id="test_tool_call_id", + final_tool_params={"param1": "new_value"}, + ) + as_dict = original.to_dict() + assert as_dict == { + "execute": True, + "tool_name": "test_tool", + "tool_call_id": "test_tool_call_id", + "feedback": None, + "final_tool_params": {"param1": "new_value"}, + } + + def test_from_dict(self): + data = { + "execute": False, + "tool_name": "another_tool", + "tool_call_id": "another_tool_call_id", + "feedback": "Not needed", + "final_tool_params": {"paramA": 123}, + } + decision = ToolExecutionDecision.from_dict(data) + assert decision.execute is False + assert decision.tool_name == "another_tool" + assert decision.tool_call_id == "another_tool_call_id" + assert decision.feedback == "Not needed" + assert decision.final_tool_params == {"paramA": 123} diff --git a/test/components/agents/human_in_the_loop/test_policies.py b/test/components/agents/human_in_the_loop/test_policies.py new file mode 100644 index 00000000..3f63fb41 --- /dev/null +++ b/test/components/agents/human_in_the_loop/test_policies.py @@ -0,0 +1,114 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +import pytest +from haystack.tools import Tool, create_tool_from_function + +from haystack_experimental.components.agents.human_in_the_loop import ( + AlwaysAskPolicy, + AskOncePolicy, + ConfirmationUIResult, + NeverAskPolicy, +) + + +def addition(x: int, y: int) -> int: + return x + y + + +@pytest.fixture +def addition_tool() -> Tool: + return create_tool_from_function( + function=addition, + name="Addition tool", + description="Adds two integers together.", + ) + + +class TestAlwaysAskPolicy: + def test_should_ask_always_true(self, addition_tool): + policy = AlwaysAskPolicy() + assert policy.should_ask(addition_tool.name, addition_tool.description, {"x": 1, "y": 2}) is True + + def test_to_dict(self): + policy = AlwaysAskPolicy() + policy_dict = policy.to_dict() + assert ( + policy_dict["type"] == "haystack_experimental.components.agents.human_in_the_loop.policies.AlwaysAskPolicy" + ) + assert policy_dict["init_parameters"] == {} + + def test_from_dict(self): + policy_dict = { + "type": "haystack_experimental.components.agents.human_in_the_loop.policies.AlwaysAskPolicy", + "init_parameters": {}, + } + policy = AlwaysAskPolicy.from_dict(policy_dict) + assert isinstance(policy, AlwaysAskPolicy) + + +class TestAskOncePolicy: + def test_should_ask_first_time_true(self, addition_tool): + policy = AskOncePolicy() + assert policy.should_ask(addition_tool.name, addition_tool.description, {"x": 1, "y": 2}) is True + + def 
test_should_ask_second_time_false(self, addition_tool): + policy = AskOncePolicy() + params = {"x": 1, "y": 2} + assert policy.should_ask(addition_tool.name, addition_tool.description, params) is True + # Simulate the update after confirmation that occurs in HumanInTheLoopStrategy + policy.update_after_confirmation( + addition_tool.name, addition_tool.description, params, ConfirmationUIResult(action="confirm", feedback=None) + ) + assert policy.should_ask(addition_tool.name, addition_tool.description, params) is False + + def test_should_ask_different_params_true(self, addition_tool): + policy = AskOncePolicy() + params1 = {"x": 1, "y": 2} + params2 = {"x": 3, "y": 4} + assert policy.should_ask(addition_tool.name, addition_tool.description, params1) is True + # Simulate the update after confirmation that occurs in HumanInTheLoopStrategy + policy.update_after_confirmation( + addition_tool.name, + addition_tool.description, + params1, + ConfirmationUIResult(action="confirm", feedback=None), + ) + assert policy.should_ask(addition_tool.name, addition_tool.description, params2) is True + + def test_to_dict(self): + policy = AskOncePolicy() + policy_dict = policy.to_dict() + assert policy_dict["type"] == "haystack_experimental.components.agents.human_in_the_loop.policies.AskOncePolicy" + assert policy_dict["init_parameters"] == {} + + def test_from_dict(self): + policy_dict = { + "type": "haystack_experimental.components.agents.human_in_the_loop.policies.AskOncePolicy", + "init_parameters": {}, + } + policy = AskOncePolicy.from_dict(policy_dict) + assert isinstance(policy, AskOncePolicy) + + +class TestNeverAskPolicy: + def test_should_ask_always_false(self, addition_tool): + policy = NeverAskPolicy() + assert policy.should_ask(addition_tool.name, addition_tool.description, {"x": 1, "y": 2}) is False + + def test_to_dict(self): + policy = NeverAskPolicy() + policy_dict = policy.to_dict() + assert ( + policy_dict["type"] == "haystack_experimental.components.agents.human_in_the_loop.policies.NeverAskPolicy" + ) + assert policy_dict["init_parameters"] == {} + + def test_from_dict(self): + policy_dict = { + "type": "haystack_experimental.components.agents.human_in_the_loop.policies.NeverAskPolicy", + "init_parameters": {}, + } + policy = NeverAskPolicy.from_dict(policy_dict) + assert isinstance(policy, NeverAskPolicy) diff --git a/test/components/agents/human_in_the_loop/test_strategies.py b/test/components/agents/human_in_the_loop/test_strategies.py new file mode 100644 index 00000000..ea91c884 --- /dev/null +++ b/test/components/agents/human_in_the_loop/test_strategies.py @@ -0,0 +1,340 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +from dataclasses import replace +import pytest +from haystack.components.agents.state.state import State +from haystack.dataclasses import ChatMessage, ToolCall +from haystack.tools import Tool, create_tool_from_function + +from haystack_experimental.components.agents.agent import _ExecutionContext +from haystack_experimental.components.agents.human_in_the_loop import ( + AlwaysAskPolicy, + AskOncePolicy, + BlockingConfirmationStrategy, + BreakpointConfirmationStrategy, + HITLBreakpointException, + SimpleConsoleUI, + ToolExecutionDecision, + NeverAskPolicy, + ConfirmationUIResult, +) +from haystack_experimental.components.agents.human_in_the_loop.strategies import ( + _apply_tool_execution_decisions, + _run_confirmation_strategies, + _update_chat_history, +) + + +def addition_tool(a: int, b: int) -> int: + return a 
+ b + + +@pytest.fixture +def tools() -> list[Tool]: + tool = create_tool_from_function( + function=addition_tool, name="addition_tool", description="A tool that adds two integers together." + ) + return [tool] + + +@pytest.fixture +def execution_context(tools) -> _ExecutionContext: + return _ExecutionContext( + state=State(schema={"messages": {"type": list[ChatMessage]}}), + component_visits={"chat_generator": 0, "tool_invoker": 0}, + chat_generator_inputs={}, + tool_invoker_inputs={"tools": tools}, + counter=0, + skip_chat_generator=False, + tool_execution_decisions=None, + ) + + +class TestBlockingConfirmationStrategy: + def test_initialization(self): + strategy = BlockingConfirmationStrategy(confirmation_policy=AskOncePolicy(), confirmation_ui=SimpleConsoleUI()) + assert isinstance(strategy.confirmation_policy, AskOncePolicy) + assert isinstance(strategy.confirmation_ui, SimpleConsoleUI) + + def test_to_dict(self): + strategy = BlockingConfirmationStrategy(confirmation_policy=AskOncePolicy(), confirmation_ui=SimpleConsoleUI()) + strategy_dict = strategy.to_dict() + assert strategy_dict == { + "type": "haystack_experimental.components.agents.human_in_the_loop.strategies.BlockingConfirmationStrategy", + "init_parameters": { + "confirmation_policy": { + "type": "haystack_experimental.components.agents.human_in_the_loop.policies.AskOncePolicy", + "init_parameters": {}, + }, + "confirmation_ui": { + "type": "haystack_experimental.components.agents.human_in_the_loop.user_interfaces.SimpleConsoleUI", + "init_parameters": {}, + }, + }, + } + + def test_from_dict(self): + strategy_dict = { + "type": "haystack_experimental.components.agents.human_in_the_loop.strategies.HumanInTheLoopStrategy", + "init_parameters": { + "confirmation_policy": { + "type": "haystack_experimental.components.agents.human_in_the_loop.policies.AskOncePolicy", + "init_parameters": {}, + }, + "confirmation_ui": { + "type": "haystack_experimental.components.agents.human_in_the_loop.user_interfaces.SimpleConsoleUI", + "init_parameters": {}, + }, + }, + } + strategy = BlockingConfirmationStrategy.from_dict(strategy_dict) + assert isinstance(strategy, BlockingConfirmationStrategy) + assert isinstance(strategy.confirmation_policy, AskOncePolicy) + assert isinstance(strategy.confirmation_ui, SimpleConsoleUI) + + def test_run_confirm(self, monkeypatch): + strategy = BlockingConfirmationStrategy(AlwaysAskPolicy(), SimpleConsoleUI()) + + # Mock the UI to always confirm + def mock_get_user_confirmation(tool_name, tool_description, tool_params): + return ConfirmationUIResult(action="confirm") + + monkeypatch.setattr(strategy.confirmation_ui, "get_user_confirmation", mock_get_user_confirmation) + + decision = strategy.run(tool_name="test_tool", tool_description="A test tool", tool_params={"param1": "value1"}) + assert decision.tool_name == "test_tool" + assert decision.execute is True + assert decision.final_tool_params == {"param1": "value1"} + + def test_run_modify(self, monkeypatch): + strategy = BlockingConfirmationStrategy(AlwaysAskPolicy(), SimpleConsoleUI()) + + # Mock the UI to always modify + def mock_get_user_confirmation(tool_name, tool_description, tool_params): + return ConfirmationUIResult(action="modify", new_tool_params={"param1": "new_value"}) + + monkeypatch.setattr(strategy.confirmation_ui, "get_user_confirmation", mock_get_user_confirmation) + + decision = strategy.run(tool_name="test_tool", tool_description="A test tool", tool_params={"param1": "value1"}) + assert decision.tool_name == "test_tool" + assert 
decision.execute is True + assert decision.final_tool_params == {"param1": "new_value"} + assert decision.feedback == ( + "The parameters for tool 'test_tool' were updated by the user to:\n{'param1': 'new_value'}" + ) + + def test_run_reject(self, monkeypatch): + strategy = BlockingConfirmationStrategy(AlwaysAskPolicy(), SimpleConsoleUI()) + + # Mock the UI to always reject + def mock_get_user_confirmation(tool_name, tool_description, tool_params): + return ConfirmationUIResult(action="reject", feedback="Not needed") + + monkeypatch.setattr(strategy.confirmation_ui, "get_user_confirmation", mock_get_user_confirmation) + + decision = strategy.run(tool_name="test_tool", tool_description="A test tool", tool_params={"param1": "value1"}) + assert decision.tool_name == "test_tool" + assert decision.execute is False + assert decision.final_tool_params is None + assert decision.feedback == "Tool execution for 'test_tool' was rejected by the user. With feedback: Not needed" + + +class TestBreakpointConfirmationStrategy: + def test_initialization(self): + strategy = BreakpointConfirmationStrategy(snapshot_file_path="test") + assert strategy.snapshot_file_path == "test" + + def test_to_dict(self): + strategy = BreakpointConfirmationStrategy(snapshot_file_path="test") + strategy_dict = strategy.to_dict() + assert strategy_dict == { + "type": "haystack_experimental.components.agents.human_in_the_loop.strategies.BreakpointConfirmationStrategy", + "init_parameters": {"snapshot_file_path": "test"}, + } + + def test_from_dict(self): + strategy_dict = { + "type": "haystack_experimental.components.agents.human_in_the_loop.strategies.BreakpointConfirmationStrategy", + "init_parameters": {"snapshot_file_path": "test"}, + } + strategy = BreakpointConfirmationStrategy.from_dict(strategy_dict) + assert isinstance(strategy, BreakpointConfirmationStrategy) + assert strategy.snapshot_file_path == "test" + + def test_run(self): + strategy = BreakpointConfirmationStrategy(snapshot_file_path="test") + with pytest.raises(HITLBreakpointException): + strategy.run(tool_name="test_tool", tool_description="A test tool", tool_params={"param1": "value1"}) + + +class TestRunConfirmationStrategies: + def test_run_confirmation_strategies_hitl_breakpoint(self, tmp_path, tools, execution_context): + with pytest.raises(HITLBreakpointException): + _run_confirmation_strategies( + confirmation_strategies={tools[0].name: BreakpointConfirmationStrategy(str(tmp_path))}, + messages_with_tool_calls=[ + ChatMessage.from_assistant(tool_calls=[ToolCall(tools[0].name, {"param1": "value1"})]), + ], + execution_context=execution_context, + ) + + def test_run_confirmation_strategies_no_strategy(self, tools, execution_context): + teds = _run_confirmation_strategies( + confirmation_strategies={}, + messages_with_tool_calls=[ + ChatMessage.from_assistant(tool_calls=[ToolCall(tools[0].name, {"param1": "value1"})]), + ], + execution_context=execution_context, + ) + assert teds == [ + ToolExecutionDecision(tool_name=tools[0].name, execute=True, final_tool_params={"param1": "value1"}) + ] + + def test_run_confirmation_strategies_with_strategy(self, tools, execution_context): + teds = _run_confirmation_strategies( + confirmation_strategies={tools[0].name: BlockingConfirmationStrategy(NeverAskPolicy(), SimpleConsoleUI())}, + messages_with_tool_calls=[ + ChatMessage.from_assistant(tool_calls=[ToolCall(tools[0].name, {"param1": "value1"})]), + ], + execution_context=execution_context, + ) + assert teds == [ + ToolExecutionDecision(tool_name=tools[0].name, 
execute=True, final_tool_params={"param1": "value1"}) + ] + + def test_run_confirmation_strategies_with_existing_teds(self, tools, execution_context): + exe_context_with_teds = replace( + execution_context, + tool_execution_decisions=[ + ToolExecutionDecision( + tool_name=tools[0].name, execute=True, tool_call_id="123", final_tool_params={"param1": "new_value"} + ) + ], + ) + teds = _run_confirmation_strategies( + confirmation_strategies={tools[0].name: BlockingConfirmationStrategy(NeverAskPolicy(), SimpleConsoleUI())}, + messages_with_tool_calls=[ + ChatMessage.from_assistant(tool_calls=[ToolCall(tools[0].name, {"param1": "value1"}, id="123")]), + ], + execution_context=exe_context_with_teds, + ) + assert teds == [ + ToolExecutionDecision( + tool_name=tools[0].name, execute=True, tool_call_id="123", final_tool_params={"param1": "new_value"} + ) + ] + + +class TestApplyToolExecutionDecisions: + @pytest.fixture + def assistant_message(self, tools): + tool_call = ToolCall(tool_name=tools[0].name, arguments={"a": 1, "b": 2}, id="1") + return ChatMessage.from_assistant(tool_calls=[tool_call]) + + def test_apply_tool_execution_decisions_reject(self, tools, assistant_message): + rejection_messages, new_tool_call_messages = _apply_tool_execution_decisions( + tool_call_messages=[assistant_message], + tool_execution_decisions=[ + ToolExecutionDecision( + tool_name=tools[0].name, + execute=False, + tool_call_id="1", + feedback=( + "The tool execution for 'addition_tool' was rejected by the user. With feedback: Not needed" + ), + ) + ], + ) + assert rejection_messages == [ + assistant_message, + ChatMessage.from_tool( + tool_result=( + "The tool execution for 'addition_tool' was rejected by the user. With feedback: Not needed" + ), + origin=assistant_message.tool_call, + error=True, + ), + ] + assert new_tool_call_messages == [] + + def test_apply_tool_execution_decisions_confirm(self, tools, assistant_message): + rejection_messages, new_tool_call_messages = _apply_tool_execution_decisions( + tool_call_messages=[assistant_message], + tool_execution_decisions=[ + ToolExecutionDecision( + tool_name=tools[0].name, execute=True, tool_call_id="1", final_tool_params={"a": 1, "b": 2} + ) + ], + ) + assert rejection_messages == [] + assert new_tool_call_messages == [assistant_message] + + def test_apply_tool_execution_decisions_modify(self, tools, assistant_message): + rejection_messages, new_tool_call_messages = _apply_tool_execution_decisions( + tool_call_messages=[assistant_message], + tool_execution_decisions=[ + ToolExecutionDecision( + tool_name=tools[0].name, + execute=True, + tool_call_id="1", + final_tool_params={"a": 5, "b": 6}, + feedback="The parameters for tool 'addition_tool' were updated by the user to:\n{'a': 5, 'b': 6}", + ) + ], + ) + assert rejection_messages == [] + assert new_tool_call_messages == [ + ChatMessage.from_user( + "The parameters for tool 'addition_tool' were updated by the user to:\n{'a': 5, 'b': 6}" + ), + ChatMessage.from_assistant( + tool_calls=[ToolCall(tool_name=tools[0].name, arguments={"a": 5, "b": 6}, id="1")] + ), + ] + + +class TestUpdateChatHistory: + @pytest.fixture + def chat_history_one_tool_call(self): + return [ + ChatMessage.from_user("Hello"), + ChatMessage.from_assistant(tool_calls=[ToolCall("tool1", {"a": 1, "b": 2}, id="1")]), + ] + + def test_update_chat_history_rejection(self, chat_history_one_tool_call): + """Test that the new history includes a tool call result message after a rejection.""" + rejection_messages = [ + 
ChatMessage.from_assistant(tool_calls=[chat_history_one_tool_call[1].tool_call]), + ChatMessage.from_tool( + tool_result="The tool execution for 'tool1' was rejected by the user. With feedback: Not needed", + origin=chat_history_one_tool_call[1].tool_call, + error=True, + ), + ] + updated_messages = _update_chat_history( + chat_history_one_tool_call, rejection_messages=rejection_messages, tool_call_and_explanation_messages=[] + ) + assert updated_messages == [chat_history_one_tool_call[0], *rejection_messages] + + def test_update_chat_history_confirm(self, chat_history_one_tool_call): + """No changes should be made if the tool call was confirmed.""" + tool_call_messages = [ChatMessage.from_assistant(tool_calls=[chat_history_one_tool_call[1].tool_call])] + updated_messages = _update_chat_history( + chat_history_one_tool_call, rejection_messages=[], tool_call_and_explanation_messages=tool_call_messages + ) + assert updated_messages == chat_history_one_tool_call + + def test_update_chat_history_modify(self, chat_history_one_tool_call): + """Test that the new history includes a user message and updated tool call after a modification.""" + tool_call_messages = [ + ChatMessage.from_user( + "The parameters for tool 'tool1' were updated by the user to:\n{'param': 'new_value'}" + ), + ChatMessage.from_assistant(tool_calls=[ToolCall("tool1", {"param": "new_value"}, id="1")]), + ] + updated_messages = _update_chat_history( + chat_history_one_tool_call, rejection_messages=[], tool_call_and_explanation_messages=tool_call_messages + ) + assert updated_messages == [chat_history_one_tool_call[0], *tool_call_messages] diff --git a/test/components/agents/human_in_the_loop/test_user_interfaces.py b/test/components/agents/human_in_the_loop/test_user_interfaces.py new file mode 100644 index 00000000..de3474dc --- /dev/null +++ b/test/components/agents/human_in_the_loop/test_user_interfaces.py @@ -0,0 +1,221 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +from unittest.mock import MagicMock, patch + +import pytest +from haystack.tools import create_tool_from_function + +from haystack_experimental.components.agents.human_in_the_loop.dataclasses import ConfirmationUIResult +from haystack_experimental.components.agents.human_in_the_loop.user_interfaces import RichConsoleUI, SimpleConsoleUI + + +def multiply_tool(x: int) -> int: + return x * 2 + + +@pytest.fixture +def tool(): + return create_tool_from_function( + function=multiply_tool, + name="test_tool", + description="A test tool that multiplies input by 2.", + ) + + +class TestRichConsoleUI: + @pytest.mark.parametrize("choice", ["y"]) + def test_process_choice_confirm(self, tool, choice): + ui = RichConsoleUI(console=MagicMock()) + + with patch( + "haystack_experimental.components.agents.human_in_the_loop.user_interfaces.Prompt.ask", + side_effect=[choice, "feedback"], + ): + result = ui.get_user_confirmation(tool.name, tool.description, {"x": 1}) + + assert isinstance(result, ConfirmationUIResult) + assert result.action == "confirm" + assert result.new_tool_params is None + assert result.feedback is None + + @pytest.mark.parametrize("choice", ["m"]) + def test_process_choice_modify(self, tool, choice): + ui = RichConsoleUI(console=MagicMock()) + + with patch( + "haystack_experimental.components.agents.human_in_the_loop.user_interfaces.Prompt.ask", + side_effect=["m", "2"], + ): + result = ui.get_user_confirmation(tool.name, tool.description, {"x": 1}) + + assert isinstance(result, ConfirmationUIResult) + 
assert result.action == "modify" + assert result.new_tool_params == {"x": 2} + + def test_process_choice_modify_dict_param(self, tool): + ui = RichConsoleUI(console=MagicMock()) + + with patch( + "haystack_experimental.components.agents.human_in_the_loop.user_interfaces.Prompt.ask", + side_effect=["m", '{"key": "value"}'], + ): + result = ui.get_user_confirmation(tool.name, tool.description, {"param1": {"old_key": "old_value"}}) + + assert isinstance(result, ConfirmationUIResult) + assert result.action == "modify" + assert result.new_tool_params == {"param1": {"key": "value"}} + + def test_process_choice_modify_dict_param_invalid_json(self, tool): + ui = RichConsoleUI(console=MagicMock()) + + with patch( + "haystack_experimental.components.agents.human_in_the_loop.user_interfaces.Prompt.ask", + side_effect=["m", "invalid_json", '{"key": "value"}'], + ): + result = ui.get_user_confirmation(tool.name, tool.description, {"param1": {"old_key": "old_value"}}) + + assert isinstance(result, ConfirmationUIResult) + assert result.action == "modify" + assert result.new_tool_params == {"param1": {"key": "value"}} + + @pytest.mark.parametrize("choice", ["n"]) + def test_process_choice_reject(self, tool, choice): + ui = RichConsoleUI(console=MagicMock()) + + with patch( + "haystack_experimental.components.agents.human_in_the_loop.user_interfaces.Prompt.ask", + side_effect=["n", "Changed my mind"], + ): + result = ui.get_user_confirmation(tool.name, tool.description, {"x": 1}) + + assert isinstance(result, ConfirmationUIResult) + assert result.action == "reject" + assert result.feedback == "Changed my mind" + + def test_to_dict(self): + ui = RichConsoleUI() + data = ui.to_dict() + assert data["type"] == ( + "haystack_experimental.components.agents.human_in_the_loop.user_interfaces.RichConsoleUI" + ) + assert data["init_parameters"]["console"] is None + + def test_from_dict(self): + ui = RichConsoleUI() + data = ui.to_dict() + new_ui = RichConsoleUI.from_dict(data) + assert isinstance(new_ui, RichConsoleUI) + + +class TestSimpleConsoleUI: + @pytest.mark.parametrize("choice", ["y", "yes", "Y", "YES"]) + def test_process_choice_confirm(self, tool, choice): + ui = SimpleConsoleUI() + + with patch("builtins.input", side_effect=[choice]): + result = ui.get_user_confirmation(tool.name, tool.description, {"y": "abc"}) + + assert isinstance(result, ConfirmationUIResult) + assert result.action == "confirm" + + @pytest.mark.parametrize("choice", ["m", "modify", "M", "MODIFY"]) + def test_process_choice_modify(self, tool, choice): + ui = SimpleConsoleUI() + + with patch("builtins.input", side_effect=[choice, "new_value"]): + result = ui.get_user_confirmation(tool.name, tool.description, {"y": "abc"}) + + assert isinstance(result, ConfirmationUIResult) + assert result.action == "modify" + assert result.new_tool_params == {"y": "new_value"} + + def test_process_choice_modify_dict_param(self, tool): + ui = SimpleConsoleUI() + + with patch("builtins.input", side_effect=["m", '{"key": "value"}']): + result = ui.get_user_confirmation(tool.name, tool.description, {"param1": {"old_key": "old_value"}}) + + assert isinstance(result, ConfirmationUIResult) + assert result.action == "modify" + assert result.new_tool_params == {"param1": {"key": "value"}} + + def test_process_choice_modify_dict_param_invalid_json(self, tool): + ui = SimpleConsoleUI() + + with patch("builtins.input", side_effect=["m", "invalid_json", '{"key": "value"}']): + result = ui.get_user_confirmation(tool.name, tool.description, {"param1": {"old_key": 
"old_value"}}) + + assert isinstance(result, ConfirmationUIResult) + assert result.action == "modify" + assert result.new_tool_params == {"param1": {"key": "value"}} + + @pytest.mark.parametrize("choice", ["n", "no", "N", "NO"]) + def test_process_choice_reject(self, tool, choice): + ui = SimpleConsoleUI() + + with patch("builtins.input", side_effect=[choice, "Changed my mind"]): + result = ui.get_user_confirmation(tool.name, tool.description, {"param1": "value1"}) + + assert isinstance(result, ConfirmationUIResult) + assert result.action == "reject" + assert result.feedback == "Changed my mind" + + def test_process_choice_no_tool_params_confirm(self, tool): + ui = SimpleConsoleUI() + + with patch("builtins.input", side_effect=["y"]): + result = ui.get_user_confirmation(tool.name, tool.description, {}) + + assert isinstance(result, ConfirmationUIResult) + assert result.action == "confirm" + assert result.new_tool_params is None + assert result.feedback is None + + def test_process_choice_no_tool_params_modify(self, tool): + ui = SimpleConsoleUI() + + with patch("builtins.input", side_effect=["m"]): + result = ui.get_user_confirmation(tool.name, tool.description, {}) + + assert isinstance(result, ConfirmationUIResult) + assert result.action == "modify" + assert result.new_tool_params == {} + assert result.feedback is None + + def test_process_choice_no_tool_params_reject(self, tool): + ui = SimpleConsoleUI() + + with patch("builtins.input", side_effect=["n", "Changed my mind"]): + result = ui.get_user_confirmation(tool.name, tool.description, {}) + + assert isinstance(result, ConfirmationUIResult) + assert result.action == "reject" + assert result.new_tool_params is None + assert result.feedback == "Changed my mind" + + def test_to_dict(self): + ui = SimpleConsoleUI() + data = ui.to_dict() + assert data["type"] == ( + "haystack_experimental.components.agents.human_in_the_loop.user_interfaces.SimpleConsoleUI" + ) + assert data["init_parameters"] == {} + + def test_from_dict(self): + ui = SimpleConsoleUI() + data = ui.to_dict() + new_ui = SimpleConsoleUI.from_dict(data) + assert isinstance(new_ui, SimpleConsoleUI) + + def test_get_user_confirmation_invalid_input_then_valid(self, tool): + ui = SimpleConsoleUI() + + with patch("builtins.input", side_effect=["invalid", "y"]): + result = ui.get_user_confirmation(tool.name, tool.description, {"x": 1}) + + assert isinstance(result, ConfirmationUIResult) + assert result.action == "confirm" + assert result.new_tool_params is None + assert result.feedback is None diff --git a/test/components/agents/test_agent.py b/test/components/agents/test_agent.py new file mode 100644 index 00000000..849faf8e --- /dev/null +++ b/test/components/agents/test_agent.py @@ -0,0 +1,341 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +import os +from pathlib import Path +from typing import Any, Optional + +import pytest +from haystack.components.generators.chat.openai import OpenAIChatGenerator +from haystack.core.errors import BreakpointException +from haystack.core.pipeline.breakpoint import load_pipeline_snapshot +from haystack.dataclasses import ChatMessage +from haystack.dataclasses.breakpoints import PipelineSnapshot +from haystack.tools import Tool, create_tool_from_function + +from haystack_experimental.components.agents.agent import Agent +from haystack_experimental.components.agents.human_in_the_loop import ( + AlwaysAskPolicy, + BlockingConfirmationStrategy, + BreakpointConfirmationStrategy, + 
ConfirmationStrategy, + ConfirmationUI, + ConfirmationUIResult, + NeverAskPolicy, + SimpleConsoleUI, + ToolExecutionDecision, +) +from haystack_experimental.components.agents.human_in_the_loop.breakpoint import ( + get_tool_calls_and_descriptions_from_snapshot, +) + + +class TestUserInterface(ConfirmationUI): + def __init__(self, ui_result: ConfirmationUIResult) -> None: + self.ui_result = ui_result + + def get_user_confirmation( + self, tool_name: str, tool_description: str, tool_params: dict[str, Any] + ) -> ConfirmationUIResult: + return self.ui_result + + +def frontend_simulate_tool_decision( + tool_calls: list[dict[str, Any]], + tool_descriptions: dict[str, str], + confirmation_ui_result: ConfirmationUIResult, +) -> list[dict]: + confirmation_strategy = BlockingConfirmationStrategy( + confirmation_policy=AlwaysAskPolicy(), + confirmation_ui=TestUserInterface(ui_result=confirmation_ui_result), + ) + + tool_execution_decisions = [] + for tc in tool_calls: + tool_execution_decisions.append( + confirmation_strategy.run( + tool_name=tc["tool_name"], + tool_description=tool_descriptions[tc["tool_name"]], + tool_call_id=tc["id"], + tool_params=tc["arguments"], + ) + ) + return [ted.to_dict() for ted in tool_execution_decisions] + + +def get_latest_snapshot(snapshot_file_path: str) -> PipelineSnapshot: + snapshot_dir = Path(snapshot_file_path) + possible_snapshots = [snapshot_dir / f for f in os.listdir(snapshot_dir)] + latest_snapshot_file = str(max(possible_snapshots, key=os.path.getctime)) + return load_pipeline_snapshot(latest_snapshot_file) + + +def run_agent( + agent: Agent, + messages: list[ChatMessage], + snapshot_file_path: Optional[str] = None, + tool_execution_decisions: Optional[list[dict[str, Any]]] = None, +) -> Optional[dict[str, Any]]: + # Load the latest snapshot if a path is provided + snapshot = None + if snapshot_file_path: + snapshot = get_latest_snapshot(snapshot_file_path=snapshot_file_path) + + # Add any new tool execution decisions to the snapshot + if tool_execution_decisions: + teds = [ToolExecutionDecision.from_dict(ted) for ted in tool_execution_decisions] + existing_decisions = snapshot.agent_snapshot.tool_execution_decisions or [] + snapshot.agent_snapshot.tool_execution_decisions = existing_decisions + teds + + try: + return agent.run(messages=messages, snapshot=snapshot.agent_snapshot if snapshot else None) + except BreakpointException: + return None + + +async def run_agent_async( + agent: Agent, + messages: list[ChatMessage], + snapshot_file_path: Optional[str] = None, + tool_execution_decisions: Optional[list[dict[str, Any]]] = None, +) -> Optional[dict[str, Any]]: + # Load the latest snapshot if a path is provided + snapshot = None + if snapshot_file_path: + snapshot = get_latest_snapshot(snapshot_file_path=snapshot_file_path) + + # Add any new tool execution decisions to the snapshot + if tool_execution_decisions: + teds = [ToolExecutionDecision.from_dict(ted) for ted in tool_execution_decisions] + existing_decisions = snapshot.agent_snapshot.tool_execution_decisions or [] + snapshot.agent_snapshot.tool_execution_decisions = existing_decisions + teds + + try: + return await agent.run_async(messages=messages, snapshot=snapshot.agent_snapshot if snapshot else None) + except BreakpointException: + return None + + +def addition_tool(a: int, b: int) -> int: + return a + b + + +@pytest.fixture +def tools() -> list[Tool]: + tool = create_tool_from_function( + function=addition_tool, name="addition_tool", description="A tool that adds two integers together." 
+ ) + return [tool] + + +@pytest.fixture +def confirmation_strategies() -> dict[str, ConfirmationStrategy]: + return {"addition_tool": BlockingConfirmationStrategy(NeverAskPolicy(), SimpleConsoleUI())} + + +class TestAgent: + def test_to_dict(self, tools, confirmation_strategies, monkeypatch): + monkeypatch.setenv("OPENAI_API_KEY", "test") + agent = Agent( + chat_generator=OpenAIChatGenerator(), tools=tools, confirmation_strategies=confirmation_strategies + ) + agent_dict = agent.to_dict() + assert agent_dict == { + "type": "haystack_experimental.components.agents.agent.Agent", + "init_parameters": { + "chat_generator": { + "type": "haystack.components.generators.chat.openai.OpenAIChatGenerator", + "init_parameters": { + "model": "gpt-4o-mini", + "streaming_callback": None, + "api_base_url": None, + "organization": None, + "generation_kwargs": {}, + "api_key": {"type": "env_var", "env_vars": ["OPENAI_API_KEY"], "strict": True}, + "timeout": None, + "max_retries": None, + "tools": None, + "tools_strict": False, + "http_client_kwargs": None, + }, + }, + "tools": [ + { + "type": "haystack.tools.tool.Tool", + "data": { + "name": "addition_tool", + "description": "A tool that adds two integers together.", + "parameters": { + "properties": {"a": {"type": "integer"}, "b": {"type": "integer"}}, + "required": ["a", "b"], + "type": "object", + }, + "function": "test.components.agents.test_agent.addition_tool", + "outputs_to_string": None, + "inputs_from_state": None, + "outputs_to_state": None, + }, + } + ], + "system_prompt": None, + "exit_conditions": ["text"], + "state_schema": {}, + "max_agent_steps": 100, + "streaming_callback": None, + "raise_on_tool_invocation_failure": False, + "tool_invoker_kwargs": None, + "confirmation_strategies": { + "addition_tool": { + "type": "haystack_experimental.components.agents.human_in_the_loop.strategies.BlockingConfirmationStrategy", + "init_parameters": { + "confirmation_policy": { + "type": "haystack_experimental.components.agents.human_in_the_loop.policies.NeverAskPolicy", + "init_parameters": {}, + }, + "confirmation_ui": { + "type": "haystack_experimental.components.agents.human_in_the_loop.user_interfaces.SimpleConsoleUI", + "init_parameters": {}, + }, + }, + } + }, + }, + } + + def test_from_dict(self, tools, confirmation_strategies, monkeypatch): + monkeypatch.setenv("OPENAI_API_KEY", "test") + agent = Agent( + chat_generator=OpenAIChatGenerator(), tools=tools, confirmation_strategies=confirmation_strategies + ) + deserialized_agent = Agent.from_dict(agent.to_dict()) + assert deserialized_agent.to_dict() == agent.to_dict() + assert isinstance(deserialized_agent.chat_generator, OpenAIChatGenerator) + assert len(deserialized_agent.tools) == 1 + assert deserialized_agent.tools[0].name == "addition_tool" + assert isinstance(deserialized_agent._tool_invoker, type(agent._tool_invoker)) + assert isinstance(deserialized_agent._confirmation_strategies["addition_tool"], BlockingConfirmationStrategy) + assert isinstance( + deserialized_agent._confirmation_strategies["addition_tool"].confirmation_policy, NeverAskPolicy + ) + assert isinstance(deserialized_agent._confirmation_strategies["addition_tool"].confirmation_ui, SimpleConsoleUI) + + +class TestAgentConfirmationStrategy: + @pytest.mark.skipif(not os.environ.get("OPENAI_API_KEY"), reason="OPENAI_API_KEY not set") + @pytest.mark.integration + def test_run_blocking_confirmation_strategy_modify(self, tools): + agent = Agent( + chat_generator=OpenAIChatGenerator(model="gpt-4o-mini"), + tools=tools, + 
confirmation_strategies={
+                "addition_tool": BlockingConfirmationStrategy(
+                    AlwaysAskPolicy(),
+                    TestUserInterface(ConfirmationUIResult(action="modify", new_tool_params={"a": 2, "b": 3})),
+                )
+            },
+        )
+        agent.warm_up()
+
+        result = agent.run([ChatMessage.from_user("What is 2+2?")])
+
+        assert isinstance(result["last_message"], ChatMessage)
+        assert "5" in result["last_message"].text
+
+    @pytest.mark.skipif(not os.environ.get("OPENAI_API_KEY"), reason="OPENAI_API_KEY not set")
+    @pytest.mark.integration
+    def test_run_breakpoint_confirmation_strategy_modify(self, tools, tmp_path):
+        agent = Agent(
+            chat_generator=OpenAIChatGenerator(model="gpt-4o-mini"),
+            tools=tools,
+            confirmation_strategies={
+                "addition_tool": BreakpointConfirmationStrategy(snapshot_file_path=str(tmp_path)),
+            },
+        )
+        agent.warm_up()
+
+        # Step 1: Initial run
+        result = run_agent(agent, [ChatMessage.from_user("What is 2+2?")])
+
+        # Step 2: Loop to handle the breakpoint confirmation strategy until the agent completes
+        while result is None:
+            # Load the latest snapshot from disk and prep data for front-end
+            loaded_snapshot = get_latest_snapshot(snapshot_file_path=str(tmp_path))
+            serialized_tool_calls, tool_descripts = get_tool_calls_and_descriptions_from_snapshot(
+                agent_snapshot=loaded_snapshot.agent_snapshot, breakpoint_tool_only=True
+            )
+
+            # Simulate front-end interaction
+            serialized_teds = frontend_simulate_tool_decision(
+                serialized_tool_calls,
+                tool_descripts,
+                ConfirmationUIResult(action="modify", new_tool_params={"a": 2, "b": 3}),
+            )
+
+            # Re-run the agent with the new tool execution decisions
+            result = run_agent(agent, [], str(tmp_path), serialized_teds)
+
+        # Step 3: Final result
+        last_message = result["last_message"]
+        assert isinstance(last_message, ChatMessage)
+        assert "5" in last_message.text
+
+    @pytest.mark.skipif(not os.environ.get("OPENAI_API_KEY"), reason="OPENAI_API_KEY not set")
+    @pytest.mark.integration
+    @pytest.mark.asyncio
+    async def test_run_async_blocking_confirmation_strategy_modify(self, tools):
+        agent = Agent(
+            chat_generator=OpenAIChatGenerator(model="gpt-4o-mini"),
+            tools=tools,
+            confirmation_strategies={
+                "addition_tool": BlockingConfirmationStrategy(
+                    AlwaysAskPolicy(),
+                    TestUserInterface(ConfirmationUIResult(action="modify", new_tool_params={"a": 2, "b": 3})),
+                )
+            },
+        )
+        agent.warm_up()
+
+        result = await agent.run_async([ChatMessage.from_user("What is 2+2?")])
+
+        assert isinstance(result["last_message"], ChatMessage)
+        assert "5" in result["last_message"].text
+
+    @pytest.mark.skipif(not os.environ.get("OPENAI_API_KEY"), reason="OPENAI_API_KEY not set")
+    @pytest.mark.integration
+    @pytest.mark.asyncio
+    async def test_run_async_breakpoint_confirmation_strategy_modify(self, tools, tmp_path):
+        agent = Agent(
+            chat_generator=OpenAIChatGenerator(model="gpt-4o-mini"),
+            tools=tools,
+            confirmation_strategies={
+                "addition_tool": BreakpointConfirmationStrategy(snapshot_file_path=str(tmp_path)),
+            },
+        )
+        agent.warm_up()
+
+        # Step 1: Initial run
+        result = await run_agent_async(agent, [ChatMessage.from_user("What is 2+2?")])
+
+        # Step 2: Loop to handle the breakpoint confirmation strategy until the agent completes
+        while result is None:
+            # Load the latest snapshot from disk and prep data for front-end
+            loaded_snapshot = get_latest_snapshot(snapshot_file_path=str(tmp_path))
+            serialized_tool_calls, tool_descripts = get_tool_calls_and_descriptions_from_snapshot(
+                agent_snapshot=loaded_snapshot.agent_snapshot, breakpoint_tool_only=True
+            )
+
+            # Simulate front-end 
interaction + serialized_teds = frontend_simulate_tool_decision( + serialized_tool_calls, + tool_descripts, + ConfirmationUIResult(action="modify", new_tool_params={"a": 2, "b": 3}), + ) + + # Re-run the agent with the new tool execution decisions + result = await run_agent_async(agent, [], str(tmp_path), serialized_teds) + + # Step 3: Final result + last_message = result["last_message"] + assert isinstance(last_message, ChatMessage) + assert "5" in last_message.text