Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 30 additions & 2 deletions python/packages/core/agent_framework/_mcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -756,10 +756,20 @@ async def call_tool(self, tool_name: str, **kwargs: Any) -> list[Content] | Any
# that should not be forwarded to external MCP servers.
# conversation_id is an internal tracking ID used by services like Azure AI.
# options contains metadata/store used by AG-UI for Azure AI client requirements.
# response_format is a type used for structured outputs and cannot be serialized.
filtered_kwargs = {
k: v
for k, v in kwargs.items()
if k not in {"chat_options", "tools", "tool_choice", "thread", "conversation_id", "options"}
if k
not in {
"chat_options",
"tools",
"tool_choice",
"thread",
"conversation_id",
"options",
"response_format",
}
}

# Try the operation, reconnecting once if the connection is closed
Expand Down Expand Up @@ -819,10 +829,28 @@ async def get_prompt(self, prompt_name: str, **kwargs: Any) -> list[ChatMessage]
"Prompts are not loaded for this server, please set load_prompts=True in the constructor."
)

# Filter out framework kwargs that cannot be serialized by the MCP SDK.
# These are internal objects passed through the function invocation pipeline
# that should not be forwarded to external MCP servers.
filtered_kwargs = {
k: v
for k, v in kwargs.items()
if k
not in {
"chat_options",
"tools",
"tool_choice",
"thread",
"conversation_id",
"options",
"response_format",
}
}

# Try the operation, reconnecting once if the connection is closed
for attempt in range(2):
try:
prompt_result = await self.session.get_prompt(prompt_name, arguments=kwargs) # type: ignore
prompt_result = await self.session.get_prompt(prompt_name, arguments=filtered_kwargs) # type: ignore
if self.parse_prompt_results is None:
return prompt_result
if self.parse_prompt_results is True:
Expand Down
147 changes: 147 additions & 0 deletions python/packages/core/tests/core/test_mcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -1014,6 +1014,78 @@ def get_mcp_client(self) -> _AsyncGeneratorContextManager[Any, None]:
await func.invoke(param="test_value")


async def test_mcp_tool_filters_framework_kwargs():
"""Test that call_tool method filters out framework-specific kwargs."""

class TestServer(MCPTool):
async def connect(self):
self.session = Mock(spec=ClientSession)
self.session.list_tools = AsyncMock(
return_value=types.ListToolsResult(
tools=[
types.Tool(
name="test_tool",
description="Test tool",
inputSchema={
"type": "object",
"properties": {
"param": {"type": "string"},
"code": {"type": "string"},
},
"required": ["param"],
},
)
]
)
)
self.session.call_tool = AsyncMock(
return_value=types.CallToolResult(content=[types.TextContent(type="text", text="Tool executed")])
)

def get_mcp_client(self) -> _AsyncGeneratorContextManager[Any, None]:
return None

server = TestServer(name="test_server")
async with server:
await server.load_tools()
func = server.functions[0]

# Create a mock response_format type
class MockResponseFormat(BaseModel):
result: str

# Call with framework kwargs that should be filtered out
result = await func.invoke(
param="test_value",
code="print('hello')",
chat_options={"key": "value"},
tools=["tool1", "tool2"],
tool_choice="auto",
thread="thread_id",
conversation_id="conv_123",
options={"store": "data"},
response_format=MockResponseFormat,
)

assert len(result) == 1
assert result[0].type == "text"
assert result[0].text == "Tool executed"

# Verify the session.call_tool was called with only the actual tool parameters
server.session.call_tool.assert_called_once()
call_args = server.session.call_tool.call_args
# Should only include param and code, not the framework kwargs
assert call_args.kwargs["arguments"] == {"param": "test_value", "code": "print('hello')"}
# Ensure none of the framework kwargs were passed
assert "chat_options" not in call_args.kwargs["arguments"]
assert "tools" not in call_args.kwargs["arguments"]
assert "tool_choice" not in call_args.kwargs["arguments"]
assert "thread" not in call_args.kwargs["arguments"]
assert "conversation_id" not in call_args.kwargs["arguments"]
assert "options" not in call_args.kwargs["arguments"]
assert "response_format" not in call_args.kwargs["arguments"]


async def test_local_mcp_server_prompt_execution():
"""Test prompt execution through MCP server."""

Expand Down Expand Up @@ -1059,6 +1131,81 @@ def get_mcp_client(self) -> _AsyncGeneratorContextManager[Any, None]:
assert result[0].contents[0].text == "Test message"


async def test_mcp_prompt_filters_framework_kwargs():
"""Test that get_prompt method filters out framework-specific kwargs."""

class TestMCPTool(MCPTool):
async def connect(self):
self.session = Mock(spec=ClientSession)
self.session.list_prompts = AsyncMock(
return_value=types.ListPromptsResult(
prompts=[
types.Prompt(
name="test_prompt",
description="Test prompt",
arguments=[
types.PromptArgument(name="arg", description="Test arg", required=True),
types.PromptArgument(name="code", description="Code arg", required=False),
],
)
]
)
)
self.session.get_prompt = AsyncMock(
return_value=types.GetPromptResult(
description="Generated prompt",
messages=[
types.PromptMessage(
role="user",
content=types.TextContent(type="text", text="Test message"),
)
],
)
)

def get_mcp_client(self) -> _AsyncGeneratorContextManager[Any, None]:
return None

server = TestMCPTool(name="test_server")
async with server:
await server.load_prompts()
prompt = server.functions[0]

# Create a mock response_format type
class MockResponseFormat(BaseModel):
result: str

# Call with framework kwargs that should be filtered out
result = await prompt.invoke(
arg="test_value",
code="print('hello')",
chat_options={"key": "value"},
tools=["tool1", "tool2"],
tool_choice="auto",
thread="thread_id",
conversation_id="conv_123",
options={"store": "data"},
response_format=MockResponseFormat,
)

assert len(result) == 1
assert isinstance(result[0], ChatMessage)

# Verify the session.get_prompt was called with only the actual prompt parameters
server.session.get_prompt.assert_called_once()
call_args = server.session.get_prompt.call_args
# Should only include arg and code, not the framework kwargs
assert call_args.kwargs["arguments"] == {"arg": "test_value", "code": "print('hello')"}
# Ensure none of the framework kwargs were passed
assert "chat_options" not in call_args.kwargs["arguments"]
assert "tools" not in call_args.kwargs["arguments"]
assert "tool_choice" not in call_args.kwargs["arguments"]
assert "thread" not in call_args.kwargs["arguments"]
assert "conversation_id" not in call_args.kwargs["arguments"]
assert "options" not in call_args.kwargs["arguments"]
assert "response_format" not in call_args.kwargs["arguments"]


@pytest.mark.parametrize(
"approval_mode,expected_approvals",
[
Expand Down