diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index 8f071021d..227048bac 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -139,7 +139,12 @@ def __init__( async def initialize(self) -> types.InitializeResult: sampling = types.SamplingCapability() if self._sampling_callback is not _default_sampling_callback else None elicitation = ( - types.ElicitationCapability() if self._elicitation_callback is not _default_elicitation_callback else None + types.ElicitationCapability( + form=types.FormElicitationCapability(), + url=types.UrlElicitationCapability(), + ) + if self._elicitation_callback is not _default_elicitation_callback + else None ) roots = ( # TODO: Should this be based on whether we @@ -501,6 +506,29 @@ async def list_tools( return result + async def track_elicitation( + self, + elicitation_id: str, + progress_token: types.ProgressToken | None = None, + ) -> types.ElicitTrackResult: + """Send an elicitation/track request to monitor URL mode elicitation progress. + + Args: + elicitation_id: The unique identifier of the elicitation to track + progress_token: Optional token for receiving progress notifications + + Returns: + ElicitTrackResult indicating the status of the elicitation + """ + params = types.ElicitTrackRequestParams(elicitationId=elicitation_id) + if progress_token is not None: + params.meta = types.RequestParams.Meta(progressToken=progress_token) + + return await self.send_request( + types.ClientRequest(types.ElicitTrackRequest(params=params)), + types.ElicitTrackResult, + ) + async def send_roots_list_changed(self) -> None: """Send a roots/list_changed notification.""" await self.send_notification(types.ClientNotification(types.RootsListChangedNotification())) diff --git a/src/mcp/server/elicitation.py b/src/mcp/server/elicitation.py index 39e3212e9..fbc0ab5be 100644 --- a/src/mcp/server/elicitation.py +++ b/src/mcp/server/elicitation.py @@ -36,6 +36,15 @@ class CancelledElicitation(BaseModel): ElicitationResult = AcceptedElicitation[ElicitSchemaModelT] | DeclinedElicitation | CancelledElicitation +class AcceptedUrlElicitation(BaseModel): + """Result when user accepts a URL mode elicitation.""" + + action: Literal["accept"] = "accept" + + +UrlElicitationResult = AcceptedUrlElicitation | DeclinedElicitation | CancelledElicitation + + # Primitive types allowed in elicitation schemas _ELICITATION_PRIMITIVE_TYPES = (str, int, float, bool) @@ -79,20 +88,22 @@ async def elicit_with_validation( schema: type[ElicitSchemaModelT], related_request_id: RequestId | None = None, ) -> ElicitationResult[ElicitSchemaModelT]: - """Elicit information from the client/user with schema validation. + """Elicit information from the client/user with schema validation (form mode). This method can be used to interactively ask for additional information from the client within a tool's execution. The client might display the message to the user and collect a response according to the provided schema. Or in case a client is an agent, it might decide how to handle the elicitation -- either by asking the user or automatically generating a response. + + For sensitive data like credentials or OAuth flows, use elicit_url() instead. """ # Validate that schema only contains primitive types and fail loudly if not _validate_elicitation_schema(schema) json_schema = schema.model_json_schema() - result = await session.elicit( + result = await session.elicit_form( message=message, requestedSchema=json_schema, related_request_id=related_request_id, @@ -109,3 +120,51 @@ async def elicit_with_validation( else: # This should never happen, but handle it just in case raise ValueError(f"Unexpected elicitation action: {result.action}") + + +async def elicit_url( + session: ServerSession, + message: str, + url: str, + elicitation_id: str, + related_request_id: RequestId | None = None, +) -> UrlElicitationResult: + """Elicit information from the user via out-of-band URL navigation (URL mode). + + This method directs the user to an external URL where sensitive interactions can + occur without passing data through the MCP client. Use this for: + - Collecting sensitive credentials (API keys, passwords) + - OAuth authorization flows with third-party services + - Payment and subscription flows + - Any interaction where data should not pass through the LLM context + + The response indicates whether the user consented to navigate to the URL. + The actual interaction happens out-of-band, and you can track progress using + session.track_elicitation(). + + Args: + session: The server session + message: Human-readable explanation of why the interaction is needed + url: The URL the user should navigate to + elicitation_id: Unique identifier for tracking this elicitation + related_request_id: Optional ID of the request that triggered this elicitation + + Returns: + UrlElicitationResult indicating accept, decline, or cancel + """ + result = await session.elicit_url( + message=message, + url=url, + elicitation_id=elicitation_id, + related_request_id=related_request_id, + ) + + if result.action == "accept": + return AcceptedUrlElicitation() + elif result.action == "decline": + return DeclinedElicitation() + elif result.action == "cancel": + return CancelledElicitation() + else: + # This should never happen, but handle it just in case + raise ValueError(f"Unexpected elicitation action: {result.action}") diff --git a/src/mcp/server/session.py b/src/mcp/server/session.py index 7a99218fa..626cee252 100644 --- a/src/mcp/server/session.py +++ b/src/mcp/server/session.py @@ -260,19 +260,43 @@ async def elicit( requestedSchema: types.ElicitRequestedSchema, related_request_id: types.RequestId | None = None, ) -> types.ElicitResult: - """Send an elicitation/create request. + """Send a form mode elicitation/create request. Args: message: The message to present to the user requestedSchema: Schema defining the expected response structure + related_request_id: Optional ID of the request that triggered this elicitation Returns: The client's response + + Note: + This method is deprecated in favor of elicit_form(). It remains for + backward compatibility but new code should use elicit_form(). + """ + return await self.elicit_form(message, requestedSchema, related_request_id) + + async def elicit_form( + self, + message: str, + requestedSchema: types.ElicitRequestedSchema, + related_request_id: types.RequestId | None = None, + ) -> types.ElicitResult: + """Send a form mode elicitation/create request. + + Args: + message: The message to present to the user + requestedSchema: Schema defining the expected response structure + related_request_id: Optional ID of the request that triggered this elicitation + + Returns: + The client's response with form data """ return await self.send_request( types.ServerRequest( types.ElicitRequest( params=types.ElicitRequestParams( + mode="form", message=message, requestedSchema=requestedSchema, ), @@ -282,6 +306,42 @@ async def elicit( metadata=ServerMessageMetadata(related_request_id=related_request_id), ) + async def elicit_url( + self, + message: str, + url: str, + elicitation_id: str, + related_request_id: types.RequestId | None = None, + ) -> types.ElicitResult: + """Send a URL mode elicitation/create request. + + This directs the user to an external URL for out-of-band interactions + like OAuth flows, credential collection, or payment processing. + + Args: + message: Human-readable explanation of why the interaction is needed + url: The URL the user should navigate to + elicitation_id: Unique identifier for tracking this elicitation + related_request_id: Optional ID of the request that triggered this elicitation + + Returns: + The client's response indicating acceptance, decline, or cancellation + """ + return await self.send_request( + types.ServerRequest( + types.ElicitRequest( + params=types.ElicitRequestParams( + mode="url", + message=message, + url=url, + elicitationId=elicitation_id, + ), + ) + ), + types.ElicitResult, + metadata=ServerMessageMetadata(related_request_id=related_request_id), + ) + async def send_ping(self) -> types.EmptyResult: """Send a ping request.""" return await self.send_request( diff --git a/src/mcp/types.py b/src/mcp/types.py index 871322740..b0be491b2 100644 --- a/src/mcp/types.py +++ b/src/mcp/types.py @@ -146,9 +146,13 @@ class JSONRPCResponse(BaseModel): model_config = ConfigDict(extra="allow") +# MCP-specific error codes +ELICITATION_REQUIRED = -32000 +"""Error code indicating that an elicitation is required before the request can be processed.""" + # SDK error codes -CONNECTION_CLOSED = -32000 -# REQUEST_TIMEOUT = -32001 # the typescript sdk uses this +CONNECTION_CLOSED = -32001 +# REQUEST_TIMEOUT = -32002 # the typescript sdk uses this # Standard JSON-RPC error codes PARSE_ERROR = -32700 @@ -256,8 +260,29 @@ class SamplingCapability(BaseModel): model_config = ConfigDict(extra="allow") +class FormElicitationCapability(BaseModel): + """Capability for form mode elicitation.""" + + model_config = ConfigDict(extra="allow") + + +class UrlElicitationCapability(BaseModel): + """Capability for URL mode elicitation.""" + + model_config = ConfigDict(extra="allow") + + class ElicitationCapability(BaseModel): - """Capability for elicitation operations.""" + """Capability for elicitation operations. + + Clients must support at least one mode (form or url). + """ + + form: FormElicitationCapability | None = None + """Present if the client supports form mode elicitation.""" + + url: UrlElicitationCapability | None = None + """Present if the client supports URL mode elicitation.""" model_config = ConfigDict(extra="allow") @@ -1247,6 +1272,22 @@ class CancelledNotification(Notification[CancelledNotificationParams, Literal["n params: CancelledNotificationParams +class ElicitTrackRequestParams(RequestParams): + """Parameters for elicitation tracking requests.""" + + elicitationId: str + """The unique identifier of the elicitation to track.""" + + model_config = ConfigDict(extra="allow") + + +class ElicitTrackRequest(Request[ElicitTrackRequestParams, Literal["elicitation/track"]]): + """A request from the client to track progress of a URL mode elicitation.""" + + method: Literal["elicitation/track"] = "elicitation/track" + params: ElicitTrackRequestParams + + class ClientRequest( RootModel[ PingRequest @@ -1262,6 +1303,7 @@ class ClientRequest( | UnsubscribeRequest | CallToolRequest | ListToolsRequest + | ElicitTrackRequest ] ): pass @@ -1279,10 +1321,30 @@ class ClientNotification( class ElicitRequestParams(RequestParams): - """Parameters for elicitation requests.""" + """Parameters for elicitation requests. + + The mode field determines the type of elicitation: + - "form": In-band structured data collection with optional schema validation + - "url": Out-of-band interaction via URL navigation + """ + + mode: Literal["form", "url"] + """The mode of elicitation (form or url).""" message: str - requestedSchema: ElicitRequestedSchema + """A human-readable message explaining why the interaction is needed.""" + + # Form mode fields + requestedSchema: ElicitRequestedSchema | None = None + """JSON Schema defining the structure of expected response (form mode only).""" + + # URL mode fields + url: str | None = None + """The URL that the user should navigate to (url mode only).""" + + elicitationId: str | None = None + """A unique identifier for the elicitation (url mode only).""" + model_config = ConfigDict(extra="allow") @@ -1299,19 +1361,64 @@ class ElicitResult(Result): action: Literal["accept", "decline", "cancel"] """ The user action in response to the elicitation. - - "accept": User submitted the form/confirmed the action + - "accept": User submitted the form/confirmed the action (or consented to URL navigation) - "decline": User explicitly declined the action - "cancel": User dismissed without making an explicit choice """ content: dict[str, str | int | float | bool | None] | None = None """ - The submitted form data, only present when action is "accept". + The submitted form data, only present when action is "accept" in form mode. Contains values matching the requested schema. + For URL mode, this field is omitted. + """ + + +class ElicitTrackResult(Result): + """The server's response to an elicitation tracking request.""" + + status: Literal["pending", "complete"] + """ + The status of the elicitation. + - "pending": The elicitation is still in progress + - "complete": The elicitation has been completed """ + model_config = ConfigDict(extra="allow") + + +class UrlElicitationInfo(BaseModel): + """Information about a URL mode elicitation embedded in an ElicitationRequired error.""" + + mode: Literal["url"] = "url" + """The mode of elicitation (must be "url").""" + + elicitationId: str + """A unique identifier for the elicitation.""" + + url: str + """The URL that the user should navigate to.""" + + message: str + """A human-readable message explaining why the interaction is needed.""" + + model_config = ConfigDict(extra="allow") + + +class ElicitationRequiredErrorData(BaseModel): + """Error data for ElicitationRequired errors. + + Servers return this when a request cannot be processed until one or more + URL mode elicitations are completed. + """ + + elicitations: list[UrlElicitationInfo] + """List of URL mode elicitations that must be completed.""" + + model_config = ConfigDict(extra="allow") + -class ClientResult(RootModel[EmptyResult | CreateMessageResult | ListRootsResult | ElicitResult]): +class ClientResult(RootModel[EmptyResult | CreateMessageResult | ListRootsResult | ElicitResult | ElicitTrackResult]): pass diff --git a/tests/server/fastmcp/test_elicitation.py b/tests/server/fastmcp/test_elicitation.py index 896eb1f80..2b7fc9c4b 100644 --- a/tests/server/fastmcp/test_elicitation.py +++ b/tests/server/fastmcp/test_elicitation.py @@ -242,6 +242,7 @@ async def defaults_tool(ctx: Context[ServerSession, None]) -> str: async def callback_schema_verify(context: RequestContext[ClientSession, None], params: ElicitRequestParams): # Verify the schema includes defaults schema = params.requestedSchema + assert schema is not None, "Schema should not be None for form mode elicitation" props = schema["properties"] assert props["name"]["default"] == "Guest" diff --git a/tests/server/fastmcp/test_url_elicitation.py b/tests/server/fastmcp/test_url_elicitation.py new file mode 100644 index 000000000..791f441c1 --- /dev/null +++ b/tests/server/fastmcp/test_url_elicitation.py @@ -0,0 +1,231 @@ +"""Test URL mode elicitation feature (SEP 1036).""" + +import pytest + +from mcp.client.session import ClientSession +from mcp.server.elicitation import AcceptedUrlElicitation, DeclinedElicitation +from mcp.server.fastmcp import Context, FastMCP +from mcp.server.session import ServerSession +from mcp.shared.context import RequestContext +from mcp.shared.memory import create_connected_server_and_client_session +from mcp.types import ElicitRequestParams, ElicitResult, TextContent + + +@pytest.mark.anyio +async def test_url_elicitation_accept(): + """Test URL mode elicitation with user acceptance.""" + mcp = FastMCP(name="URLElicitationServer") + + @mcp.tool(description="A tool that uses URL elicitation") + async def request_api_key(ctx: Context[ServerSession, None]) -> str: + result = await ctx.session.elicit_url( + message="Please provide your API key to continue.", + url="https://example.com/api_key_setup", + elicitation_id="test-elicitation-001", + ) + + if result.action == "accept": + return "User consented to navigate to URL" + elif result.action == "decline": + return "User declined" + else: + return "User cancelled" + + # Create elicitation callback that accepts URL mode + async def elicitation_callback(context: RequestContext[ClientSession, None], params: ElicitRequestParams): + assert params.mode == "url" + assert params.url == "https://example.com/api_key_setup" + assert params.elicitationId == "test-elicitation-001" + assert params.message == "Please provide your API key to continue." + return ElicitResult(action="accept") + + async with create_connected_server_and_client_session( + mcp._mcp_server, elicitation_callback=elicitation_callback + ) as client_session: + await client_session.initialize() + + result = await client_session.call_tool("request_api_key", {}) + assert len(result.content) == 1 + assert isinstance(result.content[0], TextContent) + assert result.content[0].text == "User consented to navigate to URL" + + +@pytest.mark.anyio +async def test_url_elicitation_decline(): + """Test URL mode elicitation with user declining.""" + mcp = FastMCP(name="URLElicitationDeclineServer") + + @mcp.tool(description="A tool that uses URL elicitation") + async def oauth_flow(ctx: Context[ServerSession, None]) -> str: + result = await ctx.session.elicit_url( + message="Authorize access to your files.", + url="https://example.com/oauth/authorize", + elicitation_id="oauth-001", + ) + + if result.action == "accept": + return "User consented" + elif result.action == "decline": + return "User declined authorization" + else: + return "User cancelled" + + async def elicitation_callback(context: RequestContext[ClientSession, None], params: ElicitRequestParams): + assert params.mode == "url" + return ElicitResult(action="decline") + + async with create_connected_server_and_client_session( + mcp._mcp_server, elicitation_callback=elicitation_callback + ) as client_session: + await client_session.initialize() + + result = await client_session.call_tool("oauth_flow", {}) + assert len(result.content) == 1 + assert isinstance(result.content[0], TextContent) + assert result.content[0].text == "User declined authorization" + + +@pytest.mark.anyio +async def test_url_elicitation_cancel(): + """Test URL mode elicitation with user cancelling.""" + mcp = FastMCP(name="URLElicitationCancelServer") + + @mcp.tool(description="A tool that uses URL elicitation") + async def payment_flow(ctx: Context[ServerSession, None]) -> str: + result = await ctx.session.elicit_url( + message="Complete payment to proceed.", + url="https://example.com/payment", + elicitation_id="payment-001", + ) + + if result.action == "accept": + return "User consented" + elif result.action == "decline": + return "User declined" + else: + return "User cancelled payment" + + async def elicitation_callback(context: RequestContext[ClientSession, None], params: ElicitRequestParams): + assert params.mode == "url" + return ElicitResult(action="cancel") + + async with create_connected_server_and_client_session( + mcp._mcp_server, elicitation_callback=elicitation_callback + ) as client_session: + await client_session.initialize() + + result = await client_session.call_tool("payment_flow", {}) + assert len(result.content) == 1 + assert isinstance(result.content[0], TextContent) + assert result.content[0].text == "User cancelled payment" + + +@pytest.mark.anyio +async def test_url_elicitation_helper_function(): + """Test the elicit_url helper function.""" + from mcp.server.elicitation import elicit_url + + mcp = FastMCP(name="URLElicitationHelperServer") + + @mcp.tool(description="Tool using elicit_url helper") + async def setup_credentials(ctx: Context[ServerSession, None]) -> str: + result = await elicit_url( + session=ctx.session, + message="Set up your credentials", + url="https://example.com/setup", + elicitation_id="setup-001", + ) + + if isinstance(result, AcceptedUrlElicitation): + return "Accepted" + elif isinstance(result, DeclinedElicitation): + return "Declined" + else: + # Must be CancelledElicitation + return "Cancelled" + + async def elicitation_callback(context: RequestContext[ClientSession, None], params: ElicitRequestParams): + return ElicitResult(action="accept") + + async with create_connected_server_and_client_session( + mcp._mcp_server, elicitation_callback=elicitation_callback + ) as client_session: + await client_session.initialize() + + result = await client_session.call_tool("setup_credentials", {}) + assert len(result.content) == 1 + assert isinstance(result.content[0], TextContent) + assert result.content[0].text == "Accepted" + + +@pytest.mark.anyio +async def test_url_no_content_in_response(): + """Test that URL mode elicitation responses don't include content field.""" + mcp = FastMCP(name="URLContentCheckServer") + + @mcp.tool(description="Check URL response format") + async def check_url_response(ctx: Context[ServerSession, None]) -> str: + result = await ctx.session.elicit_url( + message="Test message", + url="https://example.com/test", + elicitation_id="test-001", + ) + + # URL mode responses should not have content + assert result.content is None + return f"Action: {result.action}, Content: {result.content}" + + async def elicitation_callback(context: RequestContext[ClientSession, None], params: ElicitRequestParams): + # Verify that no content field is expected for URL mode + assert params.mode == "url" + assert params.requestedSchema is None + # Return without content - this is correct for URL mode + return ElicitResult(action="accept") + + async with create_connected_server_and_client_session( + mcp._mcp_server, elicitation_callback=elicitation_callback + ) as client_session: + await client_session.initialize() + + result = await client_session.call_tool("check_url_response", {}) + assert len(result.content) == 1 + assert isinstance(result.content[0], TextContent) + assert "Content: None" in result.content[0].text + + +@pytest.mark.anyio +async def test_form_mode_still_works(): + """Ensure form mode elicitation still works after SEP 1036.""" + from pydantic import BaseModel, Field + + mcp = FastMCP(name="FormModeBackwardCompatServer") + + class NameSchema(BaseModel): + name: str = Field(description="Your name") + + @mcp.tool(description="Test form mode") + async def ask_name(ctx: Context[ServerSession, None]) -> str: + result = await ctx.elicit(message="What is your name?", schema=NameSchema) + + if result.action == "accept" and result.data: + return f"Hello, {result.data.name}!" + else: + return "No name provided" + + async def elicitation_callback(context: RequestContext[ClientSession, None], params: ElicitRequestParams): + # Verify form mode parameters + assert params.mode == "form" + assert params.requestedSchema is not None + assert params.url is None + assert params.elicitationId is None + return ElicitResult(action="accept", content={"name": "Alice"}) + + async with create_connected_server_and_client_session( + mcp._mcp_server, elicitation_callback=elicitation_callback + ) as client_session: + await client_session.initialize() + + result = await client_session.call_tool("ask_name", {}) + assert len(result.content) == 1 + assert isinstance(result.content[0], TextContent) + assert result.content[0].text == "Hello, Alice!"