diff --git a/src/agents/agent.py b/src/agents/agent.py index b67a12c0d..28320bb9b 100644 --- a/src/agents/agent.py +++ b/src/agents/agent.py @@ -66,6 +66,7 @@ class MCPConfig(TypedDict): """ +@dataclass @dataclass class AgentBase(Generic[TContext]): """Base class for `Agent` and `RealtimeAgent`.""" @@ -94,6 +95,65 @@ class AgentBase(Generic[TContext]): mcp_config: MCPConfig = field(default_factory=lambda: MCPConfig()) """Configuration for MCP servers.""" + def __post_init__(self): + if not isinstance(self.name, str): + raise TypeError("Agent name must be a string.") + if not self.name.strip(): + raise ValueError("Agent name cannot be empty or whitespace.") + + async def get_mcp_tools(self, run_context: RunContextWrapper[TContext]) -> list[Tool]: + """Fetches the available tools from the MCP servers.""" + convert_schemas_to_strict = self.mcp_config.get("convert_schemas_to_strict", False) + return await MCPUtil.get_all_function_tools( + self.mcp_servers, convert_schemas_to_strict, run_context, self + ) + + async def get_all_tools(self, run_context: RunContextWrapper[Any]) -> list[Tool]: + """All agent tools, including MCP tools and function tools.""" + mcp_tools = await self.get_mcp_tools(run_context) + + async def _check_tool_enabled(tool: Tool) -> bool: + if not isinstance(tool, FunctionTool): + return True + + attr = tool.is_enabled + if isinstance(attr, bool): + return attr + res = attr(run_context, self) + if inspect.isawaitable(res): + return bool(await res) + return bool(res) + + results = await asyncio.gather(*(_check_tool_enabled(t) for t in self.tools)) + enabled: list[Tool] = [t for t, ok in zip(self.tools, results) if ok] + return [*mcp_tools, *enabled] + + """Base class for `Agent` and `RealtimeAgent`.""" + + name: str + """The name of the agent.""" + + handoff_description: str | None = None + """A description of the agent. This is used when the agent is used as a handoff, so that an + LLM knows what it does and when to invoke it. + """ + + tools: list[Tool] = field(default_factory=list) + """A list of tools that the agent can use.""" + + mcp_servers: list[MCPServer] = field(default_factory=list) + """A list of [Model Context Protocol](https://modelcontextprotocol.io/) servers that + the agent can use. Every time the agent runs, it will include tools from these servers in the + list of available tools. + + NOTE: You are expected to manage the lifecycle of these servers. Specifically, you must call + `server.connect()` before passing it to the agent, and `server.cleanup()` when the server is no + longer needed. + """ + + mcp_config: MCPConfig = field(default_factory=lambda: MCPConfig()) + """Configuration for MCP servers.""" + async def get_mcp_tools(self, run_context: RunContextWrapper[TContext]) -> list[Tool]: """Fetches the available tools from the MCP servers.""" convert_schemas_to_strict = self.mcp_config.get("convert_schemas_to_strict", False)