diff --git a/python/packages/core/agent_framework/_mcp.py b/python/packages/core/agent_framework/_mcp.py
index 51116b71ae..1bd04d7b8b 100644
--- a/python/packages/core/agent_framework/_mcp.py
+++ b/python/packages/core/agent_framework/_mcp.py
@@ -756,10 +756,20 @@ async def call_tool(self, tool_name: str, **kwargs: Any) -> list[Content] | Any
         # that should not be forwarded to external MCP servers.
         # conversation_id is an internal tracking ID used by services like Azure AI.
         # options contains metadata/store used by AG-UI for Azure AI client requirements.
+        # response_format is a type used for structured outputs and cannot be serialized.
         filtered_kwargs = {
             k: v
             for k, v in kwargs.items()
-            if k not in {"chat_options", "tools", "tool_choice", "thread", "conversation_id", "options"}
+            if k
+            not in {
+                "chat_options",
+                "tools",
+                "tool_choice",
+                "thread",
+                "conversation_id",
+                "options",
+                "response_format",
+            }
         }
 
         # Try the operation, reconnecting once if the connection is closed
@@ -819,10 +829,28 @@ async def get_prompt(self, prompt_name: str, **kwargs: Any) -> list[ChatMessage]
                 "Prompts are not loaded for this server, please set load_prompts=True in the constructor."
             )
 
+        # Filter out framework kwargs that cannot be serialized by the MCP SDK.
+        # These are internal objects passed through the function invocation pipeline
+        # that should not be forwarded to external MCP servers.
+        filtered_kwargs = {
+            k: v
+            for k, v in kwargs.items()
+            if k
+            not in {
+                "chat_options",
+                "tools",
+                "tool_choice",
+                "thread",
+                "conversation_id",
+                "options",
+                "response_format",
+            }
+        }
+
         # Try the operation, reconnecting once if the connection is closed
         for attempt in range(2):
             try:
-                prompt_result = await self.session.get_prompt(prompt_name, arguments=kwargs)  # type: ignore
+                prompt_result = await self.session.get_prompt(prompt_name, arguments=filtered_kwargs)  # type: ignore
                 if self.parse_prompt_results is None:
                     return prompt_result
                 if self.parse_prompt_results is True:
diff --git a/python/packages/core/tests/core/test_mcp.py b/python/packages/core/tests/core/test_mcp.py
index 67bf94acaf..143e421344 100644
--- a/python/packages/core/tests/core/test_mcp.py
+++ b/python/packages/core/tests/core/test_mcp.py
@@ -1014,6 +1014,78 @@ def get_mcp_client(self) -> _AsyncGeneratorContextManager[Any, None]:
         await func.invoke(param="test_value")
 
 
+async def test_mcp_tool_filters_framework_kwargs():
+    """Test that call_tool method filters out framework-specific kwargs."""
+
+    class TestServer(MCPTool):
+        async def connect(self):
+            self.session = Mock(spec=ClientSession)
+            self.session.list_tools = AsyncMock(
+                return_value=types.ListToolsResult(
+                    tools=[
+                        types.Tool(
+                            name="test_tool",
+                            description="Test tool",
+                            inputSchema={
+                                "type": "object",
+                                "properties": {
+                                    "param": {"type": "string"},
+                                    "code": {"type": "string"},
+                                },
+                                "required": ["param"],
+                            },
+                        )
+                    ]
+                )
+            )
+            self.session.call_tool = AsyncMock(
+                return_value=types.CallToolResult(content=[types.TextContent(type="text", text="Tool executed")])
+            )
+
+        def get_mcp_client(self) -> _AsyncGeneratorContextManager[Any, None]:
+            return None
+
+    server = TestServer(name="test_server")
+    async with server:
+        await server.load_tools()
+        func = server.functions[0]
+
+        # Create a mock response_format type
+        class MockResponseFormat(BaseModel):
+            result: str
+
+        # Call with framework kwargs that should be filtered out
+        result = await func.invoke(
+            param="test_value",
+            code="print('hello')",
+            chat_options={"key": "value"},
+            tools=["tool1", "tool2"],
+            tool_choice="auto",
+            thread="thread_id",
+            conversation_id="conv_123",
+            options={"store": "data"},
+            response_format=MockResponseFormat,
+        )
+
+        assert len(result) == 1
+        assert result[0].type == "text"
+        assert result[0].text == "Tool executed"
+
+        # Verify the session.call_tool was called with only the actual tool parameters
+        server.session.call_tool.assert_called_once()
+        call_args = server.session.call_tool.call_args
+        # Should only include param and code, not the framework kwargs
+        assert call_args.kwargs["arguments"] == {"param": "test_value", "code": "print('hello')"}
+        # Ensure none of the framework kwargs were passed
+        assert "chat_options" not in call_args.kwargs["arguments"]
+        assert "tools" not in call_args.kwargs["arguments"]
+        assert "tool_choice" not in call_args.kwargs["arguments"]
+        assert "thread" not in call_args.kwargs["arguments"]
+        assert "conversation_id" not in call_args.kwargs["arguments"]
+        assert "options" not in call_args.kwargs["arguments"]
+        assert "response_format" not in call_args.kwargs["arguments"]
+
+
 async def test_local_mcp_server_prompt_execution():
     """Test prompt execution through MCP server."""
 
@@ -1059,6 +1131,81 @@ def get_mcp_client(self) -> _AsyncGeneratorContextManager[Any, None]:
         assert result[0].contents[0].text == "Test message"
 
 
+async def test_mcp_prompt_filters_framework_kwargs():
+    """Test that get_prompt method filters out framework-specific kwargs."""
+
+    class TestMCPTool(MCPTool):
+        async def connect(self):
+            self.session = Mock(spec=ClientSession)
+            self.session.list_prompts = AsyncMock(
+                return_value=types.ListPromptsResult(
+                    prompts=[
+                        types.Prompt(
+                            name="test_prompt",
+                            description="Test prompt",
+                            arguments=[
+                                types.PromptArgument(name="arg", description="Test arg", required=True),
+                                types.PromptArgument(name="code", description="Code arg", required=False),
+                            ],
+                        )
+                    ]
+                )
+            )
+            self.session.get_prompt = AsyncMock(
+                return_value=types.GetPromptResult(
+                    description="Generated prompt",
+                    messages=[
+                        types.PromptMessage(
+                            role="user",
+                            content=types.TextContent(type="text", text="Test message"),
+                        )
+                    ],
+                )
+            )
+
+        def get_mcp_client(self) -> _AsyncGeneratorContextManager[Any, None]:
+            return None
+
+    server = TestMCPTool(name="test_server")
+    async with server:
+        await server.load_prompts()
+        prompt = server.functions[0]
+
+        # Create a mock response_format type
+        class MockResponseFormat(BaseModel):
+            result: str
+
+        # Call with framework kwargs that should be filtered out
+        result = await prompt.invoke(
+            arg="test_value",
+            code="print('hello')",
+            chat_options={"key": "value"},
+            tools=["tool1", "tool2"],
+            tool_choice="auto",
+            thread="thread_id",
+            conversation_id="conv_123",
+            options={"store": "data"},
+            response_format=MockResponseFormat,
+        )
+
+        assert len(result) == 1
+        assert isinstance(result[0], ChatMessage)
+
+        # Verify the session.get_prompt was called with only the actual prompt parameters
+        server.session.get_prompt.assert_called_once()
+        call_args = server.session.get_prompt.call_args
+        # Should only include arg and code, not the framework kwargs
+        assert call_args.kwargs["arguments"] == {"arg": "test_value", "code": "print('hello')"}
+        # Ensure none of the framework kwargs were passed
+        assert "chat_options" not in call_args.kwargs["arguments"]
+        assert "tools" not in call_args.kwargs["arguments"]
+        assert "tool_choice" not in call_args.kwargs["arguments"]
+        assert "thread" not in call_args.kwargs["arguments"]
+        assert "conversation_id" not in call_args.kwargs["arguments"]
+        assert "options" not in call_args.kwargs["arguments"]
+        assert "response_format" not in call_args.kwargs["arguments"]
+
+
 @pytest.mark.parametrize(
     "approval_mode,expected_approvals",
     [
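For illustration only (not part of the patch): the technique both hunks add amounts to dropping a fixed set of framework-internal keys from kwargs before forwarding them to the MCP session. A minimal standalone sketch of the same filtering, using the hypothetical names _FRAMEWORK_KWARGS and filter_framework_kwargs (neither exists in the codebase):

from typing import Any

# Keys the framework injects into the invocation pipeline; taken from the
# exclusion sets added in this patch. Hypothetical constant for illustration.
_FRAMEWORK_KWARGS = frozenset(
    {
        "chat_options",
        "tools",
        "tool_choice",
        "thread",
        "conversation_id",
        "options",
        "response_format",
    }
)


def filter_framework_kwargs(kwargs: dict[str, Any]) -> dict[str, Any]:
    """Drop framework-internal kwargs that the MCP SDK cannot serialize."""
    return {k: v for k, v in kwargs.items() if k not in _FRAMEWORK_KWARGS}


if __name__ == "__main__":
    # Only user-facing tool arguments survive the filter.
    print(filter_framework_kwargs({"param": "test_value", "tool_choice": "auto"}))
    # -> {'param': 'test_value'}

The patch itself keeps the two literal sets inline in call_tool and get_prompt; hoisting them into a shared constant as sketched above is one way to keep the two code paths from drifting apart.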