diff --git a/src/agents/__init__.py b/src/agents/__init__.py index 1296b72be..ca3c9f94a 100644 --- a/src/agents/__init__.py +++ b/src/agents/__init__.py @@ -5,7 +5,7 @@ from openai import AsyncOpenAI from . import _config -from .agent import Agent, ToolsToFinalOutputFunction, ToolsToFinalOutputResult +from .agent import Agent, AgentBase, ToolsToFinalOutputFunction, ToolsToFinalOutputResult from .agent_output import AgentOutputSchema, AgentOutputSchemaBase from .computer import AsyncComputer, Button, Computer, Environment from .exceptions import ( @@ -160,6 +160,7 @@ def enable_verbose_stdout_logging(): __all__ = [ "Agent", + "AgentBase", "ToolsToFinalOutputFunction", "ToolsToFinalOutputResult", "Runner", diff --git a/src/agents/agent.py b/src/agents/agent.py index 6c87297f1..b855e03b4 100644 --- a/src/agents/agent.py +++ b/src/agents/agent.py @@ -67,7 +67,36 @@ class MCPConfig(TypedDict): @dataclass -class Agent(Generic[TContext]): +class AgentBase: + """Base class for `Agent` and `RealtimeAgent`.""" + + name: str + """The name of the agent.""" + + handoff_description: str | None = None + """A description of the agent. This is used when the agent is used as a handoff, so that an + LLM knows what it does and when to invoke it. + """ + + tools: list[Tool] = field(default_factory=list) + """A list of tools that the agent can use.""" + + mcp_servers: list[MCPServer] = field(default_factory=list) + """A list of [Model Context Protocol](https://modelcontextprotocol.io/) servers that + the agent can use. Every time the agent runs, it will include tools from these servers in the + list of available tools. + + NOTE: You are expected to manage the lifecycle of these servers. Specifically, you must call + `server.connect()` before passing it to the agent, and `server.cleanup()` when the server is no + longer needed. + """ + + mcp_config: MCPConfig = field(default_factory=lambda: MCPConfig()) + """Configuration for MCP servers.""" + + +@dataclass +class Agent(AgentBase, Generic[TContext]): """An agent is an AI model configured with instructions, tools, guardrails, handoffs and more. We strongly recommend passing `instructions`, which is the "system prompt" for the agent. In @@ -76,10 +105,9 @@ class Agent(Generic[TContext]): Agents are generic on the context type. The context is a (mutable) object you create. It is passed to tool functions, handoffs, guardrails, etc. - """ - name: str - """The name of the agent.""" + See `AgentBase` for base parameters that are shared with `RealtimeAgent`s. + """ instructions: ( str @@ -103,11 +131,6 @@ class Agent(Generic[TContext]): usable with OpenAI models, using the Responses API. """ - handoff_description: str | None = None - """A description of the agent. This is used when the agent is used as a handoff, so that an - LLM knows what it does and when to invoke it. - """ - handoffs: list[Agent[Any] | Handoff[TContext]] = field(default_factory=list) """Handoffs are sub-agents that the agent can delegate to. You can provide a list of handoffs, and the agent can choose to delegate to them if relevant. Allows for separation of concerns and @@ -125,22 +148,6 @@ class Agent(Generic[TContext]): """Configures model-specific tuning parameters (e.g. temperature, top_p). """ - tools: list[Tool] = field(default_factory=list) - """A list of tools that the agent can use.""" - - mcp_servers: list[MCPServer] = field(default_factory=list) - """A list of [Model Context Protocol](https://modelcontextprotocol.io/) servers that - the agent can use. Every time the agent runs, it will include tools from these servers in the - list of available tools. - - NOTE: You are expected to manage the lifecycle of these servers. Specifically, you must call - `server.connect()` before passing it to the agent, and `server.cleanup()` when the server is no - longer needed. - """ - - mcp_config: MCPConfig = field(default_factory=lambda: MCPConfig()) - """Configuration for MCP servers.""" - input_guardrails: list[InputGuardrail[TContext]] = field(default_factory=list) """A list of checks that run in parallel to the agent's execution, before generating a response. Runs only if the agent is the first agent in the chain. @@ -256,9 +263,7 @@ async def get_prompt( """Get the prompt for the agent.""" return await PromptUtil.to_model_input(self.prompt, run_context, self) - async def get_mcp_tools( - self, run_context: RunContextWrapper[TContext] - ) -> list[Tool]: + async def get_mcp_tools(self, run_context: RunContextWrapper[TContext]) -> list[Tool]: """Fetches the available tools from the MCP servers.""" convert_schemas_to_strict = self.mcp_config.get("convert_schemas_to_strict", False) return await MCPUtil.get_all_function_tools( diff --git a/src/agents/lifecycle.py b/src/agents/lifecycle.py index 8643248b1..2cce496c8 100644 --- a/src/agents/lifecycle.py +++ b/src/agents/lifecycle.py @@ -1,25 +1,27 @@ from typing import Any, Generic -from .agent import Agent +from typing_extensions import TypeVar + +from .agent import Agent, AgentBase from .run_context import RunContextWrapper, TContext from .tool import Tool +TAgent = TypeVar("TAgent", bound=AgentBase, default=AgentBase) + -class RunHooks(Generic[TContext]): +class RunHooksBase(Generic[TContext, TAgent]): """A class that receives callbacks on various lifecycle events in an agent run. Subclass and override the methods you need. """ - async def on_agent_start( - self, context: RunContextWrapper[TContext], agent: Agent[TContext] - ) -> None: + async def on_agent_start(self, context: RunContextWrapper[TContext], agent: TAgent) -> None: """Called before the agent is invoked. Called each time the current agent changes.""" pass async def on_agent_end( self, context: RunContextWrapper[TContext], - agent: Agent[TContext], + agent: TAgent, output: Any, ) -> None: """Called when the agent produces a final output.""" @@ -28,8 +30,8 @@ async def on_agent_end( async def on_handoff( self, context: RunContextWrapper[TContext], - from_agent: Agent[TContext], - to_agent: Agent[TContext], + from_agent: TAgent, + to_agent: TAgent, ) -> None: """Called when a handoff occurs.""" pass @@ -37,7 +39,7 @@ async def on_handoff( async def on_tool_start( self, context: RunContextWrapper[TContext], - agent: Agent[TContext], + agent: TAgent, tool: Tool, ) -> None: """Called before a tool is invoked.""" @@ -46,7 +48,7 @@ async def on_tool_start( async def on_tool_end( self, context: RunContextWrapper[TContext], - agent: Agent[TContext], + agent: TAgent, tool: Tool, result: str, ) -> None: @@ -54,14 +56,14 @@ async def on_tool_end( pass -class AgentHooks(Generic[TContext]): +class AgentHooksBase(Generic[TContext, TAgent]): """A class that receives callbacks on various lifecycle events for a specific agent. You can set this on `agent.hooks` to receive events for that specific agent. Subclass and override the methods you need. """ - async def on_start(self, context: RunContextWrapper[TContext], agent: Agent[TContext]) -> None: + async def on_start(self, context: RunContextWrapper[TContext], agent: TAgent) -> None: """Called before the agent is invoked. Called each time the running agent is changed to this agent.""" pass @@ -69,7 +71,7 @@ async def on_start(self, context: RunContextWrapper[TContext], agent: Agent[TCon async def on_end( self, context: RunContextWrapper[TContext], - agent: Agent[TContext], + agent: TAgent, output: Any, ) -> None: """Called when the agent produces a final output.""" @@ -78,8 +80,8 @@ async def on_end( async def on_handoff( self, context: RunContextWrapper[TContext], - agent: Agent[TContext], - source: Agent[TContext], + agent: TAgent, + source: TAgent, ) -> None: """Called when the agent is being handed off to. The `source` is the agent that is handing off to this agent.""" @@ -88,7 +90,7 @@ async def on_handoff( async def on_tool_start( self, context: RunContextWrapper[TContext], - agent: Agent[TContext], + agent: TAgent, tool: Tool, ) -> None: """Called before a tool is invoked.""" @@ -97,9 +99,16 @@ async def on_tool_start( async def on_tool_end( self, context: RunContextWrapper[TContext], - agent: Agent[TContext], + agent: TAgent, tool: Tool, result: str, ) -> None: """Called after a tool is invoked.""" pass + + +RunHooks = RunHooksBase[TContext, Agent] +"""Run hooks when using `Agent`.""" + +AgentHooks = AgentHooksBase[TContext, Agent] +"""Agent hooks for `Agent`s.""" diff --git a/src/agents/mcp/server.py b/src/agents/mcp/server.py index 4fd606e34..91a9274fc 100644 --- a/src/agents/mcp/server.py +++ b/src/agents/mcp/server.py @@ -22,7 +22,7 @@ from .util import ToolFilter, ToolFilterCallable, ToolFilterContext, ToolFilterStatic if TYPE_CHECKING: - from ..agent import Agent + from ..agent import AgentBase class MCPServer(abc.ABC): @@ -53,7 +53,7 @@ async def cleanup(self): async def list_tools( self, run_context: RunContextWrapper[Any] | None = None, - agent: Agent[Any] | None = None, + agent: AgentBase | None = None, ) -> list[MCPTool]: """List the tools available on the server.""" pass @@ -117,7 +117,7 @@ async def _apply_tool_filter( self, tools: list[MCPTool], run_context: RunContextWrapper[Any], - agent: Agent[Any], + agent: AgentBase, ) -> list[MCPTool]: """Apply the tool filter to the list of tools.""" if self.tool_filter is None: @@ -153,7 +153,7 @@ async def _apply_dynamic_tool_filter( self, tools: list[MCPTool], run_context: RunContextWrapper[Any], - agent: Agent[Any], + agent: AgentBase, ) -> list[MCPTool]: """Apply dynamic tool filtering using a callable filter function.""" @@ -244,7 +244,7 @@ async def connect(self): async def list_tools( self, run_context: RunContextWrapper[Any] | None = None, - agent: Agent[Any] | None = None, + agent: AgentBase | None = None, ) -> list[MCPTool]: """List the tools available on the server.""" if not self.session: diff --git a/src/agents/mcp/util.py b/src/agents/mcp/util.py index 48da9f841..18cf4440a 100644 --- a/src/agents/mcp/util.py +++ b/src/agents/mcp/util.py @@ -5,12 +5,11 @@ from typing_extensions import NotRequired, TypedDict -from agents.strict_schema import ensure_strict_json_schema - from .. import _debug from ..exceptions import AgentsException, ModelBehaviorError, UserError from ..logger import logger from ..run_context import RunContextWrapper +from ..strict_schema import ensure_strict_json_schema from ..tool import FunctionTool, Tool from ..tracing import FunctionSpanData, get_current_span, mcp_tools_span from ..util._types import MaybeAwaitable @@ -18,7 +17,7 @@ if TYPE_CHECKING: from mcp.types import Tool as MCPTool - from ..agent import Agent + from ..agent import AgentBase from .server import MCPServer @@ -29,7 +28,7 @@ class ToolFilterContext: run_context: RunContextWrapper[Any] """The current run context.""" - agent: "Agent[Any]" + agent: "AgentBase" """The agent that is requesting the tool list.""" server_name: str @@ -100,7 +99,7 @@ async def get_all_function_tools( servers: list["MCPServer"], convert_schemas_to_strict: bool, run_context: RunContextWrapper[Any], - agent: "Agent[Any]", + agent: "AgentBase", ) -> list[Tool]: """Get all function tools from a list of MCP servers.""" tools = [] @@ -126,7 +125,7 @@ async def get_function_tools( server: "MCPServer", convert_schemas_to_strict: bool, run_context: RunContextWrapper[Any], - agent: "Agent[Any]", + agent: "AgentBase", ) -> list[Tool]: """Get all function tools from a single MCP server.""" diff --git a/src/agents/tool.py b/src/agents/tool.py index 3aab47752..0faf09846 100644 --- a/src/agents/tool.py +++ b/src/agents/tool.py @@ -30,8 +30,7 @@ from .util._types import MaybeAwaitable if TYPE_CHECKING: - - from .agent import Agent + from .agent import Agent, AgentBase ToolParams = ParamSpec("ToolParams") @@ -88,7 +87,7 @@ class FunctionTool: """Whether the JSON schema is in strict mode. We **strongly** recommend setting this to True, as it increases the likelihood of correct JSON input.""" - is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True + is_enabled: bool | Callable[[RunContextWrapper[Any], AgentBase], MaybeAwaitable[bool]] = True """Whether the tool is enabled. Either a bool or a Callable that takes the run context and agent and returns whether the tool is enabled. You can use this to dynamically enable/disable a tool based on your context/state.""" @@ -301,7 +300,7 @@ def function_tool( use_docstring_info: bool = True, failure_error_function: ToolErrorFunction | None = None, strict_mode: bool = True, - is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True, + is_enabled: bool | Callable[[RunContextWrapper[Any], AgentBase], MaybeAwaitable[bool]] = True, ) -> FunctionTool: """Overload for usage as @function_tool (no parentheses).""" ... @@ -316,7 +315,7 @@ def function_tool( use_docstring_info: bool = True, failure_error_function: ToolErrorFunction | None = None, strict_mode: bool = True, - is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True, + is_enabled: bool | Callable[[RunContextWrapper[Any], AgentBase], MaybeAwaitable[bool]] = True, ) -> Callable[[ToolFunction[...]], FunctionTool]: """Overload for usage as @function_tool(...).""" ... @@ -331,7 +330,7 @@ def function_tool( use_docstring_info: bool = True, failure_error_function: ToolErrorFunction | None = default_tool_error_function, strict_mode: bool = True, - is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True, + is_enabled: bool | Callable[[RunContextWrapper[Any], AgentBase], MaybeAwaitable[bool]] = True, ) -> FunctionTool | Callable[[ToolFunction[...]], FunctionTool]: """ Decorator to create a FunctionTool from a function. By default, we will: diff --git a/tests/test_function_tool.py b/tests/test_function_tool.py index b232bf75e..9b02a2c42 100644 --- a/tests/test_function_tool.py +++ b/tests/test_function_tool.py @@ -5,7 +5,14 @@ from pydantic import BaseModel from typing_extensions import TypedDict -from agents import Agent, FunctionTool, ModelBehaviorError, RunContextWrapper, function_tool +from agents import ( + Agent, + AgentBase, + FunctionTool, + ModelBehaviorError, + RunContextWrapper, + function_tool, +) from agents.tool import default_tool_error_function from agents.tool_context import ToolContext @@ -268,7 +275,7 @@ async def test_is_enabled_bool_and_callable(): def disabled_tool(): return "nope" - async def cond_enabled(ctx: RunContextWrapper[BoolCtx], agent: Agent[Any]) -> bool: + async def cond_enabled(ctx: RunContextWrapper[BoolCtx], agent: AgentBase) -> bool: return ctx.context.enable_tools @function_tool(is_enabled=cond_enabled)