
feat: add elicitation callback support to MCP servers #2373


Open · wants to merge 14 commits into main
142 changes: 141 additions & 1 deletion docs/mcp/client.md
@@ -280,7 +280,7 @@ async def main():
```

1. When you supply `http_client`, Pydantic AI re-uses this client for every
request. Anything supported by **httpx** (`verify`, `cert`, custom
request. Anything supported by **httpx** (`verify`, `cert`, custom
proxies, timeouts, etc.) therefore applies to all MCP traffic.

## MCP Sampling
@@ -391,3 +391,143 @@ server = MCPServerStdio(
allow_sampling=False,
)
```

## Elicitation

In MCP, [elicitation](https://modelcontextprotocol.io/docs/concepts/elicitation) allows a server to request [structured input](https://modelcontextprotocol.io/specification/2025-06-18/client/elicitation#supported-schema-types) from the client when it needs missing or additional context during a session.

Elicitation lets a server essentially say "Hold on - I need to know X before I can continue" rather than requiring everything upfront or taking a shot in the dark.

### How Elicitation works

Elicitation introduces a new protocol message type called [`ElicitRequest`](https://modelcontextprotocol.io/specification/2025-06-18/schema#elicitrequest), which is sent from the server to the client when it needs additional information. The client can then respond with an [`ElicitResult`](https://modelcontextprotocol.io/specification/2025-06-18/schema#elicitresult) or an `ErrorData` message.
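
Concretely, both payloads are small. Here is a rough sketch of what they carry, written as Python dictionaries that mirror the JSON shapes in the 2025-06-18 schema (the values are invented for illustration; consult the spec for the authoritative definitions):

```python {test="skip"}
# Server -> client: the params of an ElicitRequest
elicit_request_params = {
    'message': 'Please provide your booking details:',
    'requestedSchema': {  # a flat JSON Schema object with primitive properties only
        'type': 'object',
        'properties': {
            'restaurant': {'type': 'string', 'description': 'Choose a restaurant'},
            'party_size': {'type': 'integer', 'minimum': 1, 'maximum': 8},
        },
        'required': ['restaurant', 'party_size'],
    },
}

# Client -> server: the three possible ElicitResult responses
accepted = {'action': 'accept', 'content': {'restaurant': 'Trattoria Roma', 'party_size': 2}}
declined = {'action': 'decline'}  # the user explicitly said no
cancelled = {'action': 'cancel'}  # the user dismissed the prompt without answering
```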

Here's a typical interaction:

- The user makes a request to the MCP server (e.g. "Book a table at that Italian place").
- The server realizes it needs more information (e.g. "Which Italian place?", "What date and time?").
- The server sends an `ElicitRequest` to the client asking for the missing information.
- The client receives the request and presents it to the user (e.g. via a terminal prompt, GUI dialog, or web interface).
- The user provides the requested information, or declines or cancels the request.
- The client sends an `ElicitResult` back to the server with the user's response.
- With the structured data in hand, the server continues processing the original request.

This allows for a more interactive and user-friendly experience, especially in multi-stage workflows. Instead of requiring all information upfront, the server can ask for it as needed, making the interaction feel more natural.

### Setting up Elicitation

To enable elicitation, provide an [`elicitation_callback`][pydantic_ai.mcp.MCPServer.elicitation_callback] function when creating your MCP server instance:

```python {title="restaurant_server.py" py="3.10"}
from mcp.server.fastmcp import Context, FastMCP
from pydantic import BaseModel, Field

mcp = FastMCP(name='Restaurant Booking')


class BookingDetails(BaseModel):
"""Schema for restaurant booking information."""

restaurant: str = Field(description='Choose a restaurant')
party_size: int = Field(description='Number of people', ge=1, le=8)
date: str = Field(description='Reservation date (DD-MM-YYYY)')


@mcp.tool()
async def book_table(ctx: Context) -> str:
"""Book a restaurant table with user input."""
# Ask user for booking details using Pydantic schema
result = await ctx.elicit(message='Please provide your booking details:', schema=BookingDetails)

if result.action == 'accept' and result.data:
booking = result.data
return f'✅ Booked table for {booking.party_size} at {booking.restaurant} on {booking.date}'
elif result.action == 'decline':
return 'No problem! Maybe another time.'
else: # cancel
return 'Booking cancelled.'


if __name__ == '__main__':
mcp.run(transport='stdio')
```

This server demonstrates elicitation by requesting structured booking details from the client when the `book_table` tool is called. Here's how to create a client that handles these elicitation requests:

```python {title="client_example.py" py="3.10" requires="restaurant_server.py" test="skip"}
import asyncio
from typing import Any

from mcp.client.session import ClientSession
from mcp.shared.context import RequestContext
from mcp.types import ElicitRequestParams, ElicitResult

from pydantic_ai import Agent
from pydantic_ai.mcp import MCPServerStdio


async def handle_elicitation(
context: RequestContext[ClientSession, Any, Any],
params: ElicitRequestParams,
) -> ElicitResult:
"""Handle elicitation requests from MCP server."""
print(f'\n{params.message}')

if not params.requestedSchema:
response = input('Response: ')
return ElicitResult(action='accept', content={'response': response})

# Collect data for each field
properties = params.requestedSchema['properties']
data = {}

for field, info in properties.items():
description = info.get('description', field)

value = input(f'{description}: ')

# Convert to proper type based on JSON schema
if info.get('type') == 'integer':
data[field] = int(value)
else:
data[field] = value

# Confirm
confirm = input('\nConfirm booking? (y/n/c): ').lower()

if confirm == 'y':
print('Booking details:', data)
return ElicitResult(action='accept', content=data)
elif confirm == 'n':
return ElicitResult(action='decline')
else:
return ElicitResult(action='cancel')


# Set up MCP server connection
restaurant_server = MCPServerStdio(
command='python', args=['restaurant_server.py'], elicitation_callback=handle_elicitation
)

# Create agent
agent = Agent('openai:gpt-4o', toolsets=[restaurant_server])


async def main():
"""Run the agent to book a restaurant table."""
async with agent:
result = await agent.run('Book me a table')
print(f'\nResult: {result.output}')


if __name__ == '__main__':
asyncio.run(main())
```

### Supported Schema Types

MCP elicitation supports string, number, boolean, and enum types with flat object structures only. These limitations ensure reliable cross-client compatibility. See [supported schema types](https://modelcontextprotocol.io/specification/2025-06-18/client/elicitation#supported-schema-types) for details.
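
As an illustration, here is a sketch of what does and does not fit within those limits, using Pydantic models as in the server example above (the model and field names are invented for this example):

```python {py="3.10" test="skip"}
from typing import Literal

from pydantic import BaseModel, Field


class TableRequest(BaseModel):
    """A flat object with primitive fields - fine for elicitation."""

    name: str = Field(description='Name for the reservation')
    party_size: int = Field(ge=1, le=8)
    outdoor: bool = False
    seating: Literal['bar', 'booth', 'table'] = 'table'  # typically mapped to a JSON Schema enum


class Address(BaseModel):
    street: str
    city: str


class NestedRequest(BaseModel):
    """Nested objects and arrays fall outside the flat-structure limit."""

    address: Address   # nested object: not supported
    guests: list[str]  # array: not supported
```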
Comment on lines +527 to +529

Member:

Is this needed?

Contributor Author:

I think this section is valuable, since the schema type limitations are a key constraint that everybody should know about. Should we move it earlier, or remove it completely?


### Security

MCP Elicitation requires careful handling - servers must not request sensitive information, and clients must implement user approval controls with clear explanations. See [security considerations](https://modelcontextprotocol.io/specification/2025-06-18/client/elicitation#security-considerations) for details.
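
For example, a client-side callback could refuse any request whose schema appears to ask for sensitive data before the user ever sees a prompt. A minimal, self-contained sketch (the field-name heuristic is purely illustrative, not part of any API):

```python {py="3.10" test="skip"}
from typing import Any

from mcp.client.session import ClientSession
from mcp.shared.context import RequestContext
from mcp.types import ElicitRequestParams, ElicitResult

SENSITIVE_HINTS = ('password', 'ssn', 'credit_card', 'token', 'secret')


async def guarded_elicitation(
    context: RequestContext[ClientSession, Any, Any],
    params: ElicitRequestParams,
) -> ElicitResult:
    """Decline elicitation requests that appear to ask for sensitive data."""
    if params.requestedSchema:
        properties = params.requestedSchema['properties']
        if any(hint in field.lower() for field in properties for hint in SENSITIVE_HINTS):
            return ElicitResult(action='decline')
    # ...otherwise prompt the user for each field and confirm, as in client_example.py
    return ElicitResult(action='cancel')
```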
21 changes: 18 additions & 3 deletions pydantic_ai_slim/pydantic_ai/mcp.py
@@ -18,14 +18,13 @@
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from typing_extensions import Self, assert_never, deprecated

from pydantic_ai._run_context import RunContext
from pydantic_ai.tools import ToolDefinition
from pydantic_ai.tools import RunContext, ToolDefinition

from .toolsets.abstract import AbstractToolset, ToolsetTool

try:
from mcp import types as mcp_types
from mcp.client.session import ClientSession, LoggingFnT
from mcp.client.session import ClientSession, ElicitationFnT, LoggingFnT
from mcp.client.sse import sse_client
from mcp.client.stdio import StdioServerParameters, stdio_client
from mcp.client.streamable_http import GetSessionIdCallback, streamablehttp_client
@@ -65,6 +64,7 @@ class MCPServer(AbstractToolset[Any], ABC):
allow_sampling: bool
sampling_model: models.Model | None
max_retries: int
elicitation_callback: ElicitationFnT | None = None

_id: str | None

@@ -87,6 +87,7 @@ def __init__(
allow_sampling: bool = True,
sampling_model: models.Model | None = None,
max_retries: int = 1,
elicitation_callback: ElicitationFnT | None = None,
*,
id: str | None = None,
):
@@ -99,6 +100,7 @@ def __init__(
self.allow_sampling = allow_sampling
self.sampling_model = sampling_model
self.max_retries = max_retries
self.elicitation_callback = elicitation_callback

self._id = id or tool_prefix

@@ -247,6 +249,7 @@ async def __aenter__(self) -> Self:
read_stream=self._read_stream,
write_stream=self._write_stream,
sampling_callback=self._sampling_callback if self.allow_sampling else None,
elicitation_callback=self.elicitation_callback,
logging_callback=self.log_handler,
read_timeout_seconds=timedelta(seconds=self.read_timeout),
)
@@ -445,6 +448,9 @@ async def main():
max_retries: int
"""The maximum number of times to retry a tool call."""

elicitation_callback: ElicitationFnT | None = None
"""Callback function to handle elicitation requests from the server."""

def __init__(
self,
command: str,
@@ -460,6 +466,7 @@ def __init__(
allow_sampling: bool = True,
sampling_model: models.Model | None = None,
max_retries: int = 1,
elicitation_callback: ElicitationFnT | None = None,
*,
id: str | None = None,
):
@@ -479,6 +486,7 @@ def __init__(
allow_sampling: Whether to allow MCP sampling through this client.
sampling_model: The model to use for sampling.
max_retries: The maximum number of times to retry a tool call.
elicitation_callback: Callback function to handle elicitation requests from the server.
id: An optional unique ID for the MCP server. An MCP server needs to have an ID in order to be used in a durable execution environment like Temporal, in which case the ID will be used to identify the server's activities within the workflow.
"""
self.command = command
@@ -496,6 +504,7 @@ def __init__(
allow_sampling,
sampling_model,
max_retries,
elicitation_callback,
id=id,
)

@@ -605,6 +614,9 @@ class _MCPServerHTTP(MCPServer):
max_retries: int
"""The maximum number of times to retry a tool call."""

elicitation_callback: ElicitationFnT | None = None
"""Callback function to handle elicitation requests from the server."""

def __init__(
self,
*,
@@ -621,6 +633,7 @@ def __init__(
allow_sampling: bool = True,
sampling_model: models.Model | None = None,
max_retries: int = 1,
elicitation_callback: ElicitationFnT | None = None,
**_deprecated_kwargs: Any,
):
"""Build a new MCP server.
@@ -639,6 +652,7 @@ def __init__(
allow_sampling: Whether to allow MCP sampling through this client.
sampling_model: The model to use for sampling.
max_retries: The maximum number of times to retry a tool call.
elicitation_callback: Callback function to handle elicitation requests from the server.
"""
if 'sse_read_timeout' in _deprecated_kwargs:
if read_timeout is not None:
@@ -668,6 +682,7 @@ def __init__(
allow_sampling,
sampling_model,
max_retries,
elicitation_callback,
id=id,
)

2 changes: 1 addition & 1 deletion pydantic_ai_slim/pyproject.toml
@@ -77,7 +77,7 @@ tavily = ["tavily-python>=0.5.0"]
# CLI
cli = ["rich>=13", "prompt-toolkit>=3", "argcomplete>=3.5.0"]
# MCP
mcp = ["mcp>=1.10.0; python_version >= '3.10'"]
mcp = ["mcp>=1.12.3; python_version >= '3.10'"]
# Evals
evals = ["pydantic-evals=={{ version }}"]
# A2A
20 changes: 18 additions & 2 deletions tests/mcp_server.py
@@ -13,7 +13,7 @@
TextContent,
TextResourceContents,
)
from pydantic import AnyUrl
from pydantic import AnyUrl, BaseModel

mcp = FastMCP('Pydantic AI MCP Server')
log_level = 'unset'
@@ -186,7 +186,7 @@ async def echo_deps(ctx: Context[ServerSessionT, LifespanContextT, RequestT]) ->


@mcp.tool()
async def use_sampling(ctx: Context, foo: str) -> str: # type: ignore
async def use_sampling(ctx: Context[ServerSessionT, LifespanContextT, RequestT], foo: str) -> str:
"""Use sampling callback."""

result = await ctx.session.create_message(
@@ -202,6 +202,22 @@ async def use_sampling(ctx: Context, foo: str) -> str: # type: ignore
return result.model_dump_json(indent=2)


class UserResponse(BaseModel):
response: str


@mcp.tool()
async def use_elicitation(ctx: Context[ServerSessionT, LifespanContextT, RequestT], question: str) -> str:
Member:

This is not the correct type. It's probably this:

Suggested change:
-async def use_elicitation(ctx: Context[ServerSessionT, LifespanContextT, RequestT], question: str) -> str:
+async def use_elicitation(ctx: Context[ServerSession, None], question: str) -> str:

Contributor Author:

I noticed the other MCP test functions (echo_deps, use_sampling) all use Context[ServerSessionT, LifespanContextT, RequestT] as their types. For consistency, should we update the other functions, or is there a reason use_elicitation should use concrete types instead of generics?

"""Use elicitation callback to ask the user a question."""

result = await ctx.elicit(message=question, schema=UserResponse)

if result.action == 'accept' and result.data:
return f'User responded: {result.data.response}'
else:
return f'User {result.action}ed the elicitation'


@mcp._mcp_server.set_logging_level() # pyright: ignore[reportPrivateUsage]
async def set_logging_level(level: str) -> None:
global log_level
47 changes: 43 additions & 4 deletions tests/test_mcp.py
@@ -34,7 +34,9 @@

with try_import() as imports_successful:
from mcp import ErrorData, McpError, SamplingMessage
from mcp.types import CreateMessageRequestParams, ImageContent, TextContent
from mcp.client.session import ClientSession
from mcp.shared.context import RequestContext
from mcp.types import CreateMessageRequestParams, ElicitRequestParams, ElicitResult, ImageContent, TextContent

from pydantic_ai._mcp import map_from_mcp_params, map_from_model_response
from pydantic_ai.mcp import CallToolFunc, MCPServerSSE, MCPServerStdio, ToolResult
@@ -74,7 +76,7 @@ async def test_stdio_server(run_context: RunContext[int]):
server = MCPServerStdio('python', ['-m', 'tests.mcp_server'])
async with server:
tools = [tool.tool_def for tool in (await server.get_tools(run_context)).values()]
assert len(tools) == snapshot(16)
assert len(tools) == snapshot(17)
assert tools[0].name == 'celsius_to_fahrenheit'
assert isinstance(tools[0].description, str)
assert tools[0].description.startswith('Convert Celsius to Fahrenheit.')
@@ -122,7 +124,7 @@ async def test_stdio_server_with_cwd(run_context: RunContext[int]):
server = MCPServerStdio('python', ['mcp_server.py'], cwd=test_dir)
async with server:
tools = await server.get_tools(run_context)
assert len(tools) == snapshot(16)
assert len(tools) == snapshot(17)


async def test_process_tool_call(run_context: RunContext[int]) -> int:
@@ -297,7 +299,7 @@ async def test_log_level_unset(run_context: RunContext[int]):
assert server.log_level is None
async with server:
tools = [tool.tool_def for tool in (await server.get_tools(run_context)).values()]
assert len(tools) == snapshot(16)
assert len(tools) == snapshot(17)
assert tools[13].name == 'get_log_level'

result = await server.direct_call_tool('get_log_level', {})
@@ -1322,3 +1324,40 @@ def test_map_from_mcp_params_model_response():
def test_map_from_model_response():
with pytest.raises(UnexpectedModelBehavior, match='Unexpected part type: ThinkingPart, expected TextPart'):
map_from_model_response(ModelResponse(parts=[ThinkingPart(content='Thinking...')]))


async def test_elicitation_callback_functionality(run_context: RunContext[int]):
"""Test that elicitation callback is actually called and works."""
# Track callback execution
callback_called = False
callback_message = None
callback_response = 'Yes, proceed with the action'

async def mock_elicitation_callback(
context: RequestContext[ClientSession, Any, Any], params: ElicitRequestParams
) -> ElicitResult:
nonlocal callback_called, callback_message
callback_called = True
callback_message = params.message
return ElicitResult(action='accept', content={'response': callback_response})

server = MCPServerStdio('python', ['-m', 'tests.mcp_server'], elicitation_callback=mock_elicitation_callback)

async with server:
# Call the tool that uses elicitation
result = await server.direct_call_tool('use_elicitation', {'question': 'Should I continue?'})

# Verify the callback was called
assert callback_called, 'Elicitation callback should have been called'
assert callback_message == 'Should I continue?', 'Callback should receive the question'
assert result == f'User responded: {callback_response}', 'Tool should return the callback response'


async def test_elicitation_callback_not_set(run_context: RunContext[int]):
"""Test that elicitation fails when no callback is set."""
server = MCPServerStdio('python', ['-m', 'tests.mcp_server'])

async with server:
# Should raise an error when elicitation is attempted without callback
with pytest.raises(ModelRetry, match='Elicitation not supported'):
await server.direct_call_tool('use_elicitation', {'question': 'Should I continue?'})