diff --git a/docs/mcp/client.md b/docs/mcp/client.md index ce2f30d332..87de09751d 100644 --- a/docs/mcp/client.md +++ b/docs/mcp/client.md @@ -280,7 +280,7 @@ async def main(): ``` 1. When you supply `http_client`, Pydantic AI re-uses this client for every - request. Anything supported by **httpx** (`verify`, `cert`, custom + request. Anything supported by **httpx** (`verify`, `cert`, custom proxies, timeouts, etc.) therefore applies to all MCP traffic. ## MCP Sampling @@ -391,3 +391,143 @@ server = MCPServerStdio( allow_sampling=False, ) ``` + +## Elicitation + +In MCP, [elicitation](https://modelcontextprotocol.io/docs/concepts/elicitation) allows a server to request [structured input](https://modelcontextprotocol.io/specification/2025-06-18/client/elicitation#supported-schema-types) from the client for missing or additional context during a session. + +Elicitation lets models essentially say "Hold on - I need to know X before I can continue" rather than requiring everything upfront or taking a shot in the dark. + +### How Elicitation works + +Elicitation introduces a new protocol message type called [`ElicitRequest`](https://modelcontextprotocol.io/specification/2025-06-18/schema#elicitrequest), which is sent from the server to the client when it needs additional information. The client can then respond with an [`ElicitResult`](https://modelcontextprotocol.io/specification/2025-06-18/schema#elicitresult) or an `ErrorData` message. + +Here's a typical interaction: + +- User makes a request to the MCP server (e.g. "Book a table at that Italian place") +- The server identifies that it needs more information (e.g. "Which Italian place?", "What date and time?") +- The server sends an `ElicitRequest` to the client asking for the missing information. +- The client receives the request, presents it to the user (e.g. via a terminal prompt, GUI dialog, or web interface). +- The user provides the requested information, or chooses to `decline` or `cancel` the request. 
+- The client sends an `ElicitResult` back to the server with the user's response. +- With the structured data, the server can continue processing the original request. + +This allows for a more interactive and user-friendly experience, especially for multi-staged workflows. Instead of requiring all information upfront, the server can ask for it as needed, making the interaction feel more natural. + +### Setting up Elicitation + +To enable elicitation, provide an [`elicitation_callback`][pydantic_ai.mcp.MCPServer.elicitation_callback] function when creating your MCP server instance: + +```python {title="restaurant_server.py" py="3.10"} +from mcp.server.fastmcp import Context, FastMCP +from pydantic import BaseModel, Field + +mcp = FastMCP(name='Restaurant Booking') + + +class BookingDetails(BaseModel): + """Schema for restaurant booking information.""" + + restaurant: str = Field(description='Choose a restaurant') + party_size: int = Field(description='Number of people', ge=1, le=8) + date: str = Field(description='Reservation date (DD-MM-YYYY)') + + +@mcp.tool() +async def book_table(ctx: Context) -> str: + """Book a restaurant table with user input.""" + # Ask user for booking details using Pydantic schema + result = await ctx.elicit(message='Please provide your booking details:', schema=BookingDetails) + + if result.action == 'accept' and result.data: + booking = result.data + return f'✅ Booked table for {booking.party_size} at {booking.restaurant} on {booking.date}' + elif result.action == 'decline': + return 'No problem! Maybe another time.' + else: # cancel + return 'Booking cancelled.' + + +if __name__ == '__main__': + mcp.run(transport='stdio') +``` + +This server demonstrates elicitation by requesting structured booking details from the client when the `book_table` tool is called. 
Here's how to create a client that handles these elicitation requests: + +```python {title="client_example.py" py="3.10" requires="restaurant_server.py" test="skip"} +import asyncio +from typing import Any + +from mcp.client.session import ClientSession +from mcp.shared.context import RequestContext +from mcp.types import ElicitRequestParams, ElicitResult + +from pydantic_ai import Agent +from pydantic_ai.mcp import MCPServerStdio + + +async def handle_elicitation( + context: RequestContext[ClientSession, Any, Any], + params: ElicitRequestParams, +) -> ElicitResult: + """Handle elicitation requests from MCP server.""" + print(f'\n{params.message}') + + if not params.requestedSchema: + response = input('Response: ') + return ElicitResult(action='accept', content={'response': response}) + + # Collect data for each field + properties = params.requestedSchema['properties'] + data = {} + + for field, info in properties.items(): + description = info.get('description', field) + + value = input(f'{description}: ') + + # Convert to proper type based on JSON schema + if info.get('type') == 'integer': + data[field] = int(value) + else: + data[field] = value + + # Confirm + confirm = input('\nConfirm booking? 
(y/n/c): ').lower() + + if confirm == 'y': + print('Booking details:', data) + return ElicitResult(action='accept', content=data) + elif confirm == 'n': + return ElicitResult(action='decline') + else: + return ElicitResult(action='cancel') + + +# Set up MCP server connection +restaurant_server = MCPServerStdio( + command='python', args=['restaurant_server.py'], elicitation_callback=handle_elicitation +) + +# Create agent +agent = Agent('openai:gpt-4o', toolsets=[restaurant_server]) + + +async def main(): + """Run the agent to book a restaurant table.""" + async with agent: + result = await agent.run('Book me a table') + print(f'\nResult: {result.output}') + + +if __name__ == '__main__': + asyncio.run(main()) +``` + +### Supported Schema Types + +MCP elicitation supports string, number, boolean, and enum types with flat object structures only. These limitations ensure reliable cross-client compatibility. See [supported schema types](https://modelcontextprotocol.io/specification/2025-06-18/client/elicitation#supported-schema-types) for details. + +### Security + +MCP Elicitation requires careful handling - servers must not request sensitive information, and clients must implement user approval controls with clear explanations. See [security considerations](https://modelcontextprotocol.io/specification/2025-06-18/client/elicitation#security-considerations) for details. 
diff --git a/pydantic_ai_slim/pydantic_ai/mcp.py b/pydantic_ai_slim/pydantic_ai/mcp.py index 8dd155992d..35f0297f52 100644 --- a/pydantic_ai_slim/pydantic_ai/mcp.py +++ b/pydantic_ai_slim/pydantic_ai/mcp.py @@ -18,14 +18,13 @@ from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from typing_extensions import Self, assert_never, deprecated -from pydantic_ai._run_context import RunContext -from pydantic_ai.tools import ToolDefinition +from pydantic_ai.tools import RunContext, ToolDefinition from .toolsets.abstract import AbstractToolset, ToolsetTool try: from mcp import types as mcp_types - from mcp.client.session import ClientSession, LoggingFnT + from mcp.client.session import ClientSession, ElicitationFnT, LoggingFnT from mcp.client.sse import sse_client from mcp.client.stdio import StdioServerParameters, stdio_client from mcp.client.streamable_http import GetSessionIdCallback, streamablehttp_client @@ -65,6 +64,7 @@ class MCPServer(AbstractToolset[Any], ABC): allow_sampling: bool sampling_model: models.Model | None max_retries: int + elicitation_callback: ElicitationFnT | None = None _id: str | None @@ -87,6 +87,7 @@ def __init__( allow_sampling: bool = True, sampling_model: models.Model | None = None, max_retries: int = 1, + elicitation_callback: ElicitationFnT | None = None, *, id: str | None = None, ): @@ -99,6 +100,7 @@ def __init__( self.allow_sampling = allow_sampling self.sampling_model = sampling_model self.max_retries = max_retries + self.elicitation_callback = elicitation_callback self._id = id or tool_prefix @@ -247,6 +249,7 @@ async def __aenter__(self) -> Self: read_stream=self._read_stream, write_stream=self._write_stream, sampling_callback=self._sampling_callback if self.allow_sampling else None, + elicitation_callback=self.elicitation_callback, logging_callback=self.log_handler, read_timeout_seconds=timedelta(seconds=self.read_timeout), ) @@ -445,6 +448,9 @@ async def main(): max_retries: int """The maximum number of 
times to retry a tool call.""" + elicitation_callback: ElicitationFnT | None = None + """Callback function to handle elicitation requests from the server.""" + def __init__( self, command: str, @@ -460,6 +466,7 @@ def __init__( allow_sampling: bool = True, sampling_model: models.Model | None = None, max_retries: int = 1, + elicitation_callback: ElicitationFnT | None = None, *, id: str | None = None, ): @@ -479,6 +486,7 @@ def __init__( allow_sampling: Whether to allow MCP sampling through this client. sampling_model: The model to use for sampling. max_retries: The maximum number of times to retry a tool call. + elicitation_callback: Callback function to handle elicitation requests from the server. id: An optional unique ID for the MCP server. An MCP server needs to have an ID in order to be used in a durable execution environment like Temporal, in which case the ID will be used to identify the server's activities within the workflow. """ self.command = command @@ -496,6 +504,7 @@ def __init__( allow_sampling, sampling_model, max_retries, + elicitation_callback, id=id, ) @@ -605,6 +614,9 @@ class _MCPServerHTTP(MCPServer): max_retries: int """The maximum number of times to retry a tool call.""" + elicitation_callback: ElicitationFnT | None = None + """Callback function to handle elicitation requests from the server.""" + def __init__( self, *, @@ -621,6 +633,7 @@ def __init__( allow_sampling: bool = True, sampling_model: models.Model | None = None, max_retries: int = 1, + elicitation_callback: ElicitationFnT | None = None, **_deprecated_kwargs: Any, ): """Build a new MCP server. @@ -639,6 +652,7 @@ def __init__( allow_sampling: Whether to allow MCP sampling through this client. sampling_model: The model to use for sampling. max_retries: The maximum number of times to retry a tool call. + elicitation_callback: Callback function to handle elicitation requests from the server. 
""" if 'sse_read_timeout' in _deprecated_kwargs: if read_timeout is not None: @@ -668,6 +682,7 @@ def __init__( allow_sampling, sampling_model, max_retries, + elicitation_callback, id=id, ) diff --git a/pydantic_ai_slim/pyproject.toml b/pydantic_ai_slim/pyproject.toml index 65ef56d789..1885dec40d 100644 --- a/pydantic_ai_slim/pyproject.toml +++ b/pydantic_ai_slim/pyproject.toml @@ -77,7 +77,7 @@ tavily = ["tavily-python>=0.5.0"] # CLI cli = ["rich>=13", "prompt-toolkit>=3", "argcomplete>=3.5.0"] # MCP -mcp = ["mcp>=1.10.0; python_version >= '3.10'"] +mcp = ["mcp>=1.12.3; python_version >= '3.10'"] # Evals evals = ["pydantic-evals=={{ version }}"] # A2A diff --git a/tests/mcp_server.py b/tests/mcp_server.py index f7fa75e9b4..07a866a50b 100644 --- a/tests/mcp_server.py +++ b/tests/mcp_server.py @@ -13,7 +13,7 @@ TextContent, TextResourceContents, ) -from pydantic import AnyUrl +from pydantic import AnyUrl, BaseModel mcp = FastMCP('Pydantic AI MCP Server') log_level = 'unset' @@ -186,7 +186,7 @@ async def echo_deps(ctx: Context[ServerSessionT, LifespanContextT, RequestT]) -> @mcp.tool() -async def use_sampling(ctx: Context, foo: str) -> str: # type: ignore +async def use_sampling(ctx: Context[ServerSessionT, LifespanContextT, RequestT], foo: str) -> str: """Use sampling callback.""" result = await ctx.session.create_message( @@ -202,6 +202,22 @@ async def use_sampling(ctx: Context, foo: str) -> str: # type: ignore return result.model_dump_json(indent=2) +class UserResponse(BaseModel): + response: str + + +@mcp.tool() +async def use_elicitation(ctx: Context[ServerSessionT, LifespanContextT, RequestT], question: str) -> str: + """Use elicitation callback to ask the user a question.""" + + result = await ctx.elicit(message=question, schema=UserResponse) + + if result.action == 'accept' and result.data: + return f'User responded: {result.data.response}' + else: + return f'User {result.action}ed the elicitation' + + @mcp._mcp_server.set_logging_level() # pyright: 
ignore[reportPrivateUsage] async def set_logging_level(level: str) -> None: global log_level diff --git a/tests/test_mcp.py b/tests/test_mcp.py index 178d0b3627..c73184adea 100644 --- a/tests/test_mcp.py +++ b/tests/test_mcp.py @@ -34,7 +34,9 @@ with try_import() as imports_successful: from mcp import ErrorData, McpError, SamplingMessage - from mcp.types import CreateMessageRequestParams, ImageContent, TextContent + from mcp.client.session import ClientSession + from mcp.shared.context import RequestContext + from mcp.types import CreateMessageRequestParams, ElicitRequestParams, ElicitResult, ImageContent, TextContent from pydantic_ai._mcp import map_from_mcp_params, map_from_model_response from pydantic_ai.mcp import CallToolFunc, MCPServerSSE, MCPServerStdio, ToolResult @@ -74,7 +76,7 @@ async def test_stdio_server(run_context: RunContext[int]): server = MCPServerStdio('python', ['-m', 'tests.mcp_server']) async with server: tools = [tool.tool_def for tool in (await server.get_tools(run_context)).values()] - assert len(tools) == snapshot(16) + assert len(tools) == snapshot(17) assert tools[0].name == 'celsius_to_fahrenheit' assert isinstance(tools[0].description, str) assert tools[0].description.startswith('Convert Celsius to Fahrenheit.') @@ -122,7 +124,7 @@ async def test_stdio_server_with_cwd(run_context: RunContext[int]): server = MCPServerStdio('python', ['mcp_server.py'], cwd=test_dir) async with server: tools = await server.get_tools(run_context) - assert len(tools) == snapshot(16) + assert len(tools) == snapshot(17) async def test_process_tool_call(run_context: RunContext[int]) -> int: @@ -297,7 +299,7 @@ async def test_log_level_unset(run_context: RunContext[int]): assert server.log_level is None async with server: tools = [tool.tool_def for tool in (await server.get_tools(run_context)).values()] - assert len(tools) == snapshot(16) + assert len(tools) == snapshot(17) assert tools[13].name == 'get_log_level' result = await 
server.direct_call_tool('get_log_level', {}) @@ -1322,3 +1324,40 @@ def test_map_from_mcp_params_model_response(): def test_map_from_model_response(): with pytest.raises(UnexpectedModelBehavior, match='Unexpected part type: ThinkingPart, expected TextPart'): map_from_model_response(ModelResponse(parts=[ThinkingPart(content='Thinking...')])) + + +async def test_elicitation_callback_functionality(run_context: RunContext[int]): + """Test that elicitation callback is actually called and works.""" + # Track callback execution + callback_called = False + callback_message = None + callback_response = 'Yes, proceed with the action' + + async def mock_elicitation_callback( + context: RequestContext[ClientSession, Any, Any], params: ElicitRequestParams + ) -> ElicitResult: + nonlocal callback_called, callback_message + callback_called = True + callback_message = params.message + return ElicitResult(action='accept', content={'response': callback_response}) + + server = MCPServerStdio('python', ['-m', 'tests.mcp_server'], elicitation_callback=mock_elicitation_callback) + + async with server: + # Call the tool that uses elicitation + result = await server.direct_call_tool('use_elicitation', {'question': 'Should I continue?'}) + + # Verify the callback was called + assert callback_called, 'Elicitation callback should have been called' + assert callback_message == 'Should I continue?', 'Callback should receive the question' + assert result == f'User responded: {callback_response}', 'Tool should return the callback response' + + +async def test_elicitation_callback_not_set(run_context: RunContext[int]): + """Test that elicitation fails when no callback is set.""" + server = MCPServerStdio('python', ['-m', 'tests.mcp_server']) + + async with server: + # Should raise an error when elicitation is attempted without callback + with pytest.raises(ModelRetry, match='Elicitation not supported'): + await server.direct_call_tool('use_elicitation', {'question': 'Should I continue?'}) diff 
--git a/uv.lock b/uv.lock index e8c1988de8..a1a505433b 100644 --- a/uv.lock +++ b/uv.lock @@ -1979,7 +1979,7 @@ wheels = [ [[package]] name = "mcp" -version = "1.12.1" +version = "1.12.3" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "anyio", marker = "python_full_version >= '3.10'" }, @@ -1994,9 +1994,9 @@ dependencies = [ { name = "starlette", marker = "python_full_version >= '3.10'" }, { name = "uvicorn", marker = "python_full_version >= '3.10' and sys_platform != 'emscripten'" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/5c/5a/16cef13b2e60d5f865fbc96372efb23dc8b0591f102dd55003b4ae62f9b1/mcp-1.12.1.tar.gz", hash = "sha256:d1d0bdeb09e4b17c1a72b356248bf3baf75ab10db7008ef865c4afbeb0eb810e", size = 425768, upload-time = "2025-07-22T16:51:41.66Z" } +sdist = { url = "https://files.pythonhosted.org/packages/4d/19/9955e2df5384ff5dd25d38f8e88aaf89d2d3d9d39f27e7383eaf0b293836/mcp-1.12.3.tar.gz", hash = "sha256:ab2e05f5e5c13e1dc90a4a9ef23ac500a6121362a564447855ef0ab643a99fed", size = 427203, upload-time = "2025-07-31T18:36:36.795Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/b9/04/9a967a575518fc958bda1e34a52eae0c7f6accf3534811914fdaf57b0689/mcp-1.12.1-py3-none-any.whl", hash = "sha256:34147f62891417f8b000c39718add844182ba424c8eb2cea250b4267bda4b08b", size = 158463, upload-time = "2025-07-22T16:51:40.086Z" }, + { url = "https://files.pythonhosted.org/packages/8f/8b/0be74e3308a486f1d127f3f6767de5f9f76454c9b4183210c61cc50999b6/mcp-1.12.3-py3-none-any.whl", hash = "sha256:5483345bf39033b858920a5b6348a303acacf45b23936972160ff152107b850e", size = 158810, upload-time = "2025-07-31T18:36:34.915Z" }, ] [package.optional-dependencies] @@ -3554,8 +3554,8 @@ requires-dist = [ { name = "groq", marker = "extra == 'groq'", specifier = ">=0.25.0" }, { name = "httpx", specifier = ">=0.27" }, { name = "huggingface-hub", extras = ["inference"], marker = "extra == 'huggingface'", specifier = ">=0.33.5" }, + { name = "mcp", 
marker = "python_full_version >= '3.10' and extra == 'mcp'", specifier = ">=1.12.3" }, { name = "logfire", marker = "extra == 'logfire'", specifier = ">=3.14.1" }, - { name = "mcp", marker = "python_full_version >= '3.10' and extra == 'mcp'", specifier = ">=1.10.0" }, { name = "mistralai", marker = "extra == 'mistral'", specifier = ">=1.9.2" }, { name = "openai", marker = "extra == 'openai'", specifier = ">=1.99.9" }, { name = "opentelemetry-api", specifier = ">=1.28.0" },