diff --git a/.changeset/seven-rice-shop.md b/.changeset/seven-rice-shop.md new file mode 100644 index 00000000..2e8dbe6f --- /dev/null +++ b/.changeset/seven-rice-shop.md @@ -0,0 +1,5 @@ +--- +'@openai/agents-core': patch +--- + +refactor: restructure mcp tools fetching with options object pattern diff --git a/examples/mcp/README.md b/examples/mcp/README.md index 858ed2a3..14c14bbe 100644 --- a/examples/mcp/README.md +++ b/examples/mcp/README.md @@ -18,3 +18,9 @@ pnpm -F mcp start:stdio ```bash pnpm -F mcp start:tool-filter ``` + +`get-all-mcp-tools-example.ts` demonstrates how to use the `getAllMcpTools` function to fetch tools from multiple MCP servers: + +```bash +pnpm -F mcp start:get-all-tools +``` diff --git a/examples/mcp/get-all-mcp-tools-example.ts b/examples/mcp/get-all-mcp-tools-example.ts new file mode 100644 index 00000000..419c718e --- /dev/null +++ b/examples/mcp/get-all-mcp-tools-example.ts @@ -0,0 +1,110 @@ +import { + Agent, + run, + MCPServerStdio, + getAllMcpTools, + withTrace, +} from '@openai/agents'; +import * as path from 'node:path'; + +async function main() { + const samplesDir = path.join(__dirname, 'sample_files'); + + // Create multiple MCP servers to demonstrate getAllMcpTools + const filesystemServer = new MCPServerStdio({ + name: 'Filesystem Server', + fullCommand: `npx -y @modelcontextprotocol/server-filesystem ${samplesDir}`, + }); + + // Note: This example shows how to use multiple servers + // In practice, you would have different servers with different tools + const servers = [filesystemServer]; + + // Connect all servers + for (const server of servers) { + await server.connect(); + } + + try { + await withTrace('getAllMcpTools Example', async () => { + console.log('=== Using getAllMcpTools to fetch all tools ===\n'); + + // Method 1: Simple array of servers + const allTools = await getAllMcpTools(servers); + console.log( + `Found ${allTools.length} tools from ${servers.length} server(s):`, + ); + allTools.forEach((tool) => { + const description = + tool.type === 'function' ? tool.description : 'No description'; + console.log(`- ${tool.name}: ${description}`); + }); + + console.log('\n=== Using getAllMcpTools with options object ===\n'); + + // Method 2: Using options object (recommended for more control) + const allToolsWithOptions = await getAllMcpTools({ + mcpServers: servers, + convertSchemasToStrict: true, // Convert schemas to strict mode + }); + + console.log( + `Found ${allToolsWithOptions.length} tools with strict schemas:`, + ); + allToolsWithOptions.forEach((tool) => { + const description = + tool.type === 'function' ? tool.description : 'No description'; + console.log(`- ${tool.name}: ${description}`); + }); + + console.log('\n=== Creating agent with pre-fetched tools ===\n'); + + // Create agent using the pre-fetched tools + const agent = new Agent({ + name: 'MCP Assistant with Pre-fetched Tools', + instructions: + 'Use the available tools to help the user with file operations.', + tools: allTools, // Use pre-fetched tools instead of mcpServers + }); + + // Test the agent + const message = 'List the available files and read one of them.'; + console.log(`Running: ${message}\n`); + const result = await run(agent, message); + console.log(result.finalOutput); + + console.log( + '\n=== Demonstrating tool filtering with getAllMcpTools ===\n', + ); + + // Add tool filter to one of the servers + filesystemServer.toolFilter = { + allowedToolNames: ['read_file'], // Only allow read_file tool + }; + + // Note: For callable filters to work, you need to pass runContext and agent + // This is typically done internally when the agent runs + const filteredTools = await getAllMcpTools({ + mcpServers: servers, + convertSchemasToStrict: false, + // runContext and agent would normally be provided by the agent runtime + // For demo purposes, we're showing the structure + }); + + console.log(`After filtering, found ${filteredTools.length} tools:`); + filteredTools.forEach((tool) => { + console.log(`- ${tool.name}`); + }); + }); + } finally { + // Clean up - close all servers + for (const server of servers) { + await server.close(); + } + } +} + +main().catch((err) => { + console.error('Error:', err); + process.exit(1); +}); diff --git a/examples/mcp/package.json b/examples/mcp/package.json index 3e0c255c..a65ed18e 100644 --- a/examples/mcp/package.json +++ b/examples/mcp/package.json @@ -14,6 +14,7 @@ "start:hosted-mcp-human-in-the-loop": "tsx hosted-mcp-human-in-the-loop.ts", "start:hosted-mcp-simple": "tsx hosted-mcp-simple.ts", "start:tool-filter": "tsx tool-filter-example.ts", - "start:sse": "tsx sse-example.ts" + "start:sse": "tsx sse-example.ts", + "start:get-all-tools": "tsx get-all-mcp-tools-example.ts" } } diff --git a/packages/agents-core/src/agent.ts b/packages/agents-core/src/agent.ts index 114ec884..d40dd610 100644 --- a/packages/agents-core/src/agent.ts +++ b/packages/agents-core/src/agent.ts @@ -518,7 +518,12 @@ export class Agent< runContext: RunContext, ): Promise[]> { if (this.mcpServers.length > 0) { - return getAllMcpTools(this.mcpServers, runContext, this, false); + return getAllMcpTools({ + mcpServers: this.mcpServers, + runContext, + agent: this, + convertSchemasToStrict: false, + }); } return []; diff --git a/packages/agents-core/src/index.ts b/packages/agents-core/src/index.ts index a5b9f9d3..6f3fd125 100644 --- a/packages/agents-core/src/index.ts +++ b/packages/agents-core/src/index.ts @@ -70,10 +70,12 @@ export { getLogger } from './logger'; export { getAllMcpTools, invalidateServerToolsCache, + mcpToFunctionTool, MCPServer, MCPServerStdio, MCPServerStreamableHttp, MCPServerSSE, + GetAllMcpToolsOptions, } from './mcp'; export { MCPToolFilterCallable, diff --git a/packages/agents-core/src/mcp.ts b/packages/agents-core/src/mcp.ts index 93eb2439..f56e7197 100644 --- a/packages/agents-core/src/mcp.ts +++ b/packages/agents-core/src/mcp.ts @@ -285,35 +285,6 @@ export class MCPServerSSE extends BaseMCPServerSSE { * Fetches and flattens all tools from multiple MCP servers. * Logs and skips any servers that fail to respond. */ -export async function getAllMcpFunctionTools( - mcpServers: MCPServer[], - runContext: RunContext, - agent: Agent, - convertSchemasToStrict = false, -): Promise[]> { - const allTools: Tool[] = []; - const toolNames = new Set(); - for (const server of mcpServers) { - const serverTools = await getFunctionToolsFromServer( - server, - runContext, - agent, - convertSchemasToStrict, - ); - const serverToolNames = new Set(serverTools.map((t) => t.name)); - const intersection = [...serverToolNames].filter((n) => toolNames.has(n)); - if (intersection.length > 0) { - throw new UserError( - `Duplicate tool names found across MCP servers: ${intersection.join(', ')}`, - ); - } - for (const t of serverTools) { - toolNames.add(t.name); - allTools.push(t); - } - } - return allTools; -} const _cachedTools: Record = {}; /** @@ -327,12 +298,17 @@ export async function invalidateServerToolsCache(serverName: string) { /** * Fetches all function tools from a single MCP server. */ -async function getFunctionToolsFromServer( - server: MCPServer, - runContext: RunContext, - agent: Agent, - convertSchemasToStrict: boolean, -): Promise[]> { +async function getFunctionToolsFromServer({ + server, + convertSchemasToStrict, + runContext, + agent, +}: { + server: MCPServer; + convertSchemasToStrict: boolean; + runContext?: RunContext; + agent?: Agent; +}): Promise[]> { if (server.cacheToolsList && _cachedTools[server.name]) { return _cachedTools[server.name].map((t) => mcpToFunctionTool(t, server, convertSchemasToStrict), @@ -341,52 +317,54 @@ async function getFunctionToolsFromServer( return withMCPListToolsSpan( async (span) => { const fetchedMcpTools = await server.listTools(); - const mcpTools: MCPTool[] = []; - const context = { - runContext, - agent, - serverName: server.name, - }; - for (const tool of fetchedMcpTools) { - const filter = server.toolFilter; - if (filter) { - if (filter && typeof filter === 'function') { - const filtered = await filter(context, tool); - if (!filtered) { - globalLogger.debug( - `MCP Tool (server: ${server.name}, tool: ${tool.name}) is blocked by the callable filter.`, - ); - continue; // skip this tool - } - } else { - const allowedToolNames = filter.allowedToolNames ?? []; - const blockedToolNames = filter.blockedToolNames ?? []; - if (allowedToolNames.length > 0 || blockedToolNames.length > 0) { - const allowed = - allowedToolNames.length > 0 - ? allowedToolNames.includes(tool.name) - : true; - const blocked = - blockedToolNames.length > 0 - ? blockedToolNames.includes(tool.name) - : false; - if (!allowed || blocked) { - if (blocked) { - globalLogger.debug( - `MCP Tool (server: ${server.name}, tool: ${tool.name}) is blocked by the static filter.`, - ); - } else if (!allowed) { - globalLogger.debug( - `MCP Tool (server: ${server.name}, tool: ${tool.name}) is not allowed by the static filter.`, - ); + let mcpTools: MCPTool[] = fetchedMcpTools; + + if (runContext && agent) { + const context = { runContext, agent, serverName: server.name }; + const filteredTools: MCPTool[] = []; + for (const tool of fetchedMcpTools) { + const filter = server.toolFilter; + if (filter) { + if (typeof filter === 'function') { + const filtered = await filter(context, tool); + if (!filtered) { + globalLogger.debug( + `MCP Tool (server: ${server.name}, tool: ${tool.name}) is blocked by the callable filter.`, + ); + continue; + } + } else { + const allowedToolNames = filter.allowedToolNames ?? []; + const blockedToolNames = filter.blockedToolNames ?? []; + if (allowedToolNames.length > 0 || blockedToolNames.length > 0) { + const allowed = + allowedToolNames.length > 0 + ? allowedToolNames.includes(tool.name) + : true; + const blocked = + blockedToolNames.length > 0 + ? blockedToolNames.includes(tool.name) + : false; + if (!allowed || blocked) { + if (blocked) { + globalLogger.debug( + `MCP Tool (server: ${server.name}, tool: ${tool.name}) is blocked by the static filter.`, + ); + } else if (!allowed) { + globalLogger.debug( + `MCP Tool (server: ${server.name}, tool: ${tool.name}) is not allowed by the static filter.`, + ); + } + continue; } - continue; // skip this tool } } } + filteredTools.push(tool); } - mcpTools.push(tool); + mcpTools = filteredTools; } + span.spanData.result = mcpTools.map((t) => t.name); const tools: FunctionTool[] = mcpTools.map((t) => mcpToFunctionTool(t, server, convertSchemasToStrict), @@ -400,21 +378,70 @@ async function getFunctionToolsFromServer( ); } +/** + * Options for fetching MCP tools. + */ +export type GetAllMcpToolsOptions = { + mcpServers: MCPServer[]; + convertSchemasToStrict?: boolean; + runContext?: RunContext; + agent?: Agent; +}; + /** * Returns all MCP tools from the provided servers, using the function tool conversion. + * If runContext and agent are provided, callable tool filters will be applied. */ export async function getAllMcpTools( mcpServers: MCPServer[], - runContext: RunContext, - agent: Agent, +): Promise[]>; +export async function getAllMcpTools( + opts: GetAllMcpToolsOptions, +): Promise[]>; +export async function getAllMcpTools( + mcpServersOrOpts: MCPServer[] | GetAllMcpToolsOptions, + runContext?: RunContext, + agent?: Agent, convertSchemasToStrict = false, ): Promise[]> { - return getAllMcpFunctionTools( + const opts = Array.isArray(mcpServersOrOpts) + ? { + mcpServers: mcpServersOrOpts, + runContext, + agent, + convertSchemasToStrict, + } + : mcpServersOrOpts; + + const { mcpServers, - runContext, - agent, - convertSchemasToStrict, - ); + convertSchemasToStrict: convertSchemasToStrictFromOpts = false, + runContext: runContextFromOpts, + agent: agentFromOpts, + } = opts; + const allTools: Tool[] = []; + const toolNames = new Set(); + + for (const server of mcpServers) { + const serverTools = await getFunctionToolsFromServer({ + server, + convertSchemasToStrict: convertSchemasToStrictFromOpts, + runContext: runContextFromOpts, + agent: agentFromOpts, + }); + const serverToolNames = new Set(serverTools.map((t) => t.name)); + const intersection = [...serverToolNames].filter((n) => toolNames.has(n)); + if (intersection.length > 0) { + throw new UserError( + `Duplicate tool names found across MCP servers: ${intersection.join(', ')}`, + ); + } + for (const t of serverTools) { + toolNames.add(t.name); + allTools.push(t); + } + } + return allTools; } /** diff --git a/packages/agents-core/test/mcpCache.test.ts b/packages/agents-core/test/mcpCache.test.ts index 537b9983..42c83eb5 100644 --- a/packages/agents-core/test/mcpCache.test.ts +++ b/packages/agents-core/test/mcpCache.test.ts @@ -51,27 +51,27 @@ describe('MCP tools cache invalidation', () => { ]; const server = new StubServer('server', toolsA); - let tools = await getAllMcpTools( - [server], - new RunContext({}), - new Agent({ name: 'test' }), - ); + let tools = await getAllMcpTools({ + mcpServers: [server], + runContext: new RunContext({}), + agent: new Agent({ name: 'test' }), + }); expect(tools.map((t) => t.name)).toEqual(['a']); server.toolList = toolsB; - tools = await getAllMcpTools( - [server], - new RunContext({}), - new Agent({ name: 'test' }), - ); + tools = await getAllMcpTools({ + mcpServers: [server], + runContext: new RunContext({}), + agent: new Agent({ name: 'test' }), + }); expect(tools.map((t) => t.name)).toEqual(['a']); await server.invalidateToolsCache(); - tools = await getAllMcpTools( - [server], - new RunContext({}), - new Agent({ name: 'test' }), - ); + tools = await getAllMcpTools({ + mcpServers: [server], + runContext: new RunContext({}), + agent: new Agent({ name: 'test' }), + }); expect(tools.map((t) => t.name)).toEqual(['b']); }); }); @@ -87,11 +87,11 @@ describe('MCP tools cache invalidation', () => { ]; const serverA = new StubServer('server', tools); - await getAllMcpTools( - [serverA], - new RunContext({}), - new Agent({ name: 'test' }), - ); + await getAllMcpTools({ + mcpServers: [serverA], + runContext: new RunContext({}), + agent: new Agent({ name: 'test' }), + }); const serverB = new StubServer('server', tools); let called = false; @@ -100,11 +100,11 @@ describe('MCP tools cache invalidation', () => { return []; }; - const cachedTools = (await getAllMcpTools( - [serverB], - new RunContext({}), - new Agent({ name: 'test' }), - )) as FunctionTool[]; + const cachedTools = (await getAllMcpTools({ + mcpServers: [serverB], + runContext: new RunContext({}), + agent: new Agent({ name: 'test' }), + })) as FunctionTool[]; await cachedTools[0].invoke({} as any, '{}'); expect(called).toBe(true);