-
Notifications
You must be signed in to change notification settings - Fork 21
feat: Add utilities for converting tools into human-in-the-loop tools #368
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
# SPDX-FileCopyrightText: 2022-present deepset GmbH <[email protected]> | ||
# | ||
# SPDX-License-Identifier: Apache-2.0 |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,183 @@ | ||
# SPDX-FileCopyrightText: 2022-present deepset GmbH <[email protected]> | ||
# | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
from dataclasses import dataclass, replace | ||
from typing import Any, Optional | ||
|
||
from haystack.tools import Tool | ||
from rich.console import Console | ||
from rich.panel import Panel | ||
from rich.prompt import Prompt | ||
|
||
from haystack_experimental.tools.types.protocol import ConfirmationPrompt, ExecutionPolicy | ||
|
||
|
||
@dataclass | ||
class ConfirmationResult: | ||
""" | ||
Result of the confirmation prompt to capture a user's decision. | ||
:param action: The action chosen by the user (e.g. "confirm", "reject", or "modify"). | ||
:param feedback: Optional feedback message if the action is "reject". | ||
:param new_params: Optional new parameters if the action is "modify". | ||
""" | ||
|
||
action: str # This is left as a string to allow users to define their own actions if needed. | ||
feedback: Optional[str] = None | ||
new_params: Optional[dict[str, Any]] = None | ||
|
||
|
||
class RichConsolePrompt: | ||
""" | ||
Confirmation prompt using Rich library for enhanced console interaction. | ||
""" | ||
|
||
def __init__(self, console: Optional[Console] = None) -> None: | ||
""" | ||
:param console: Optional Rich Console instance. If None, a new Console will be created. | ||
""" | ||
self.console = console or Console() | ||
|
||
def confirm(self, tool_name: str, params: dict[str, Any]) -> ConfirmationResult: | ||
""" | ||
Ask for user confirmation before executing a tool. | ||
:param tool_name: Name of the tool to be executed. | ||
:param params: Parameters to be passed to the tool. | ||
:returns: | ||
ConfirmationResult with action (e.g. "confirm" or "reject"), optional feedback message and new parameters | ||
if modified. | ||
""" | ||
# Display info | ||
lines = [f"[bold yellow]Tool:[/bold yellow] {tool_name}"] | ||
if params: | ||
lines.append("\n[bold yellow]Arguments:[/bold yellow]") | ||
for k, v in params.items(): | ||
lines.append(f"\n[cyan]{k}:[/cyan]\n {v}") | ||
self.console.print(Panel("\n".join(lines), title="🔧 Tool Execution Request")) | ||
|
||
# Ask action | ||
choice = Prompt.ask( | ||
"\nYour choice", | ||
choices=["y", "n", "m"], # confirm, reject, modify | ||
default="y", | ||
) | ||
if choice == "y": | ||
return ConfirmationResult(action="confirm") | ||
elif choice == "m": | ||
new_params = {} | ||
for k, v in params.items(): | ||
new_val = Prompt.ask(f"Modify '{k}'", default=str(v)) | ||
new_params[k] = new_val | ||
return ConfirmationResult(action="modify", new_params=new_params) | ||
else: # reject | ||
feedback = Prompt.ask("Feedback message (optional)", default="") | ||
return ConfirmationResult(action="reject", feedback=feedback or None) | ||
|
||
|
||
class SimpleInputPrompt: | ||
""" | ||
Simple confirmation prompt using standard input/output. | ||
""" | ||
|
||
def confirm(self, tool_name: str, params: dict[str, Any]) -> ConfirmationResult: | ||
""" | ||
Ask for user confirmation before executing a tool. | ||
:param tool_name: Name of the tool to be executed. | ||
:param params: Parameters to be passed to the tool. | ||
:returns: | ||
ConfirmationResult with action (e.g. "confirm" or "reject"), optional feedback message and new parameters | ||
if modified. | ||
""" | ||
print(f"Tool: {tool_name}") | ||
if params: | ||
print("Arguments:") | ||
for k, v in params.items(): | ||
print(f" {k}: {v}") | ||
|
||
choice = input("Confirm execution? (y=confirm / n=reject / m=modify): ").strip().lower() | ||
if choice == "y": | ||
return ConfirmationResult(action="confirm") | ||
elif choice == "m": | ||
new_params = {} | ||
for k, v in params.items(): | ||
new_val = input(f"Modify '{k}' [{v}]: ").strip() or v | ||
new_params[k] = new_val | ||
return ConfirmationResult(action="modify", new_params=new_params) | ||
else: # modify | ||
feedback = input("Feedback message (optional): ").strip() | ||
return ConfirmationResult(action="reject", feedback=feedback or None) | ||
|
||
|
||
class DefaultPolicy: | ||
""" | ||
Default execution policy: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. When I read policy, I first thought this is a choice between AlwaysAsk, AskOnce, NeverAsk. Here, the user is always asked. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I gotcha, would a rename to AlwaysAskPolicy help make it more clear? |
||
- If confirmed, run the tool with original params. | ||
- If rejected, return a rejection message. | ||
- If modified, run the tool immediately with new params. | ||
""" | ||
|
||
def handle(self, result: ConfirmationResult, tool: Tool, kwargs: dict[str, Any]) -> Any: | ||
""" | ||
Handle the confirmation result and execute the tool accordingly. | ||
:param result: The result from the confirmation prompt. | ||
:param tool: The tool to potentially execute. | ||
:param kwargs: The original parameters for the tool. | ||
:returns: | ||
The result of the tool execution or a rejection message. | ||
""" | ||
if result.action == "reject": | ||
return { | ||
"status": "rejected", | ||
"tool": tool.name, | ||
"feedback": result.feedback or "Tool execution rejected by user", | ||
} | ||
elif result.action == "modify" and result.new_params: | ||
# Run immediately with new params | ||
return tool.function(**result.new_params) | ||
anakin87 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
return tool.function(**kwargs) | ||
|
||
|
||
class AutoConfirmPolicy: | ||
""" | ||
Always confirm and run the tool, ignoring user input. | ||
""" | ||
|
||
def handle(self, result: ConfirmationResult, tool: Tool, kwargs: dict[str, Any]) -> Any: | ||
""" | ||
Always execute the tool, ignoring any rejection from the user. | ||
:param result: The result from the confirmation prompt (ignored). | ||
:param tool: The tool to execute. | ||
:param kwargs: The original parameters for the tool. | ||
:returns: The result of the tool execution. | ||
""" | ||
# Always run, ignore user rejection | ||
return tool.function(**kwargs) | ||
|
||
|
||
def confirmation_wrapper( | ||
tool: Tool, | ||
strategy: ConfirmationPrompt, | ||
policy: ExecutionPolicy = DefaultPolicy(), | ||
sjrl marked this conversation as resolved.
Show resolved
Hide resolved
|
||
) -> Tool: | ||
""" | ||
Wrap a tool with a human-in-the-loop confirmation step. | ||
:param tool: The tool to wrap. | ||
:param strategy: The confirmation prompt strategy to use. | ||
:param policy: The execution policy to apply based on user input. | ||
:return: A new Tool instance with confirmation logic. | ||
""" | ||
|
||
def wrapped_function(**kwargs: Any) -> Any: | ||
result = strategy.confirm(tool.name, kwargs) | ||
return policy.handle(result, tool, kwargs) | ||
|
||
return replace(tool, function=wrapped_function) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
# SPDX-FileCopyrightText: 2022-present deepset GmbH <[email protected]> | ||
# | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
from .protocol import ConfirmationPrompt, ExecutionPolicy | ||
|
||
__all__ = ["ConfirmationPrompt", "ExecutionPolicy"] |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
# SPDX-FileCopyrightText: 2022-present deepset GmbH <[email protected]> | ||
# | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
from typing import TYPE_CHECKING, Any, Protocol | ||
|
||
if TYPE_CHECKING: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this is done to avoid circular imports or for other reasons? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yeah circular imports. I'll double check it's actually needed and leave a comment if it is |
||
from haystack.tools import Tool | ||
|
||
from haystack_experimental.tools.human_in_the_loop import ConfirmationResult | ||
|
||
|
||
class ConfirmationPrompt(Protocol): | ||
def confirm(self, tool_name: str, params: dict[str, Any]) -> "ConfirmationResult": | ||
""" | ||
Ask for user confirmation before executing a tool. | ||
:param tool_name: Name of the tool to be executed. | ||
:param params: Parameters to be passed to the tool. | ||
:returns: | ||
ConfirmationResult with action (e.g. "confirm" or "reject") and optional feedback message. | ||
""" | ||
|
||
|
||
class ExecutionPolicy(Protocol): | ||
def handle(self, result: "ConfirmationResult", tool: "Tool", kwargs: dict[str, Any]) -> Any: | ||
""" | ||
Handle the execution policy based on the user's confirmation result. | ||
:param result: The result from the confirmation prompt. | ||
:param tool: The tool to be executed. | ||
:param kwargs: The parameters to be passed to the tool. | ||
sjrl marked this conversation as resolved.
Show resolved
Hide resolved
|
||
:returns: | ||
The result of the execution policy (e.g., tool output, rejection message, etc.). | ||
""" |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
# SPDX-FileCopyrightText: 2022-present deepset GmbH <[email protected]> | ||
# | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
from haystack.components.agents import Agent | ||
from haystack.components.generators.chat import OpenAIChatGenerator | ||
from haystack.dataclasses import ChatMessage | ||
from haystack.tools import create_tool_from_function | ||
from rich.console import Console | ||
|
||
from haystack_experimental.tools.human_in_the_loop import ( | ||
RichConsolePrompt, | ||
SimpleInputPrompt, | ||
confirmation_wrapper, | ||
) | ||
|
||
|
||
def get_bank_balance(account_id: str) -> str: | ||
""" | ||
Simulate fetching a bank balance for a given account ID. | ||
|
||
:param account_id: The ID of the bank account. | ||
:returns: | ||
A string representing the bank balance. | ||
""" | ||
return f"Balance for account {account_id} is $1,234.56" | ||
|
||
|
||
balance_tool = create_tool_from_function( | ||
function=get_bank_balance, | ||
name="get_bank_balance", | ||
description="Get the bank balance for a given account ID.", | ||
) | ||
|
||
# | ||
# Example: Run Tool individually with different Prompts | ||
# | ||
|
||
# Use the console version | ||
cons = Console() | ||
console_tool = confirmation_wrapper(balance_tool, RichConsolePrompt(cons)) | ||
cons.print("\n[bold]Using console confirmation tool:[/bold]") | ||
res = console_tool.invoke(account_id="123456") | ||
cons.print(f"\n[bold green]Result:[/bold green] {res}") | ||
|
||
# Use the simple input version | ||
simple_tool = confirmation_wrapper(balance_tool, SimpleInputPrompt()) | ||
print("\nUsing simple input confirmation tool:") | ||
res = simple_tool.invoke(account_id="123456") | ||
print(f"\nResult: {res}") | ||
|
||
|
||
# | ||
# Example: Running with an Agent | ||
# | ||
|
||
agent = Agent( | ||
chat_generator=OpenAIChatGenerator(model="gpt-4.1"), | ||
tools=[console_tool], # or simple_tool | ||
system_prompt=""" | ||
You are a helpful financial assistant. Use the provided tool to get bank balances when needed. | ||
""", | ||
) | ||
|
||
result = agent.run([ChatMessage.from_user("What's the balance of account 56789?")]) | ||
last_message = result["last_message"] | ||
cons.print(f"\n[bold green]Agent Result:[/bold green] {last_message.text}") |
Uh oh!
There was an error while loading. Please reload this page.