|
| 1 | +from __future__ import annotations |
| 2 | + |
| 3 | +from collections.abc import Iterable |
| 4 | +from typing import TYPE_CHECKING, Any, Generic, Literal |
| 5 | + |
| 6 | +from pydantic import AnyUrl, BaseModel |
| 7 | + |
| 8 | +from mcp.server.context import LifespanContextT, RequestT, ServerRequestContext |
| 9 | +from mcp.server.elicitation import ( |
| 10 | + ElicitationResult, |
| 11 | + ElicitSchemaModelT, |
| 12 | + UrlElicitationResult, |
| 13 | + elicit_url, |
| 14 | + elicit_with_validation, |
| 15 | +) |
| 16 | +from mcp.server.lowlevel.helper_types import ReadResourceContents |
| 17 | + |
| 18 | +if TYPE_CHECKING: |
| 19 | + from mcp.server.mcpserver.server import MCPServer |
| 20 | + |
| 21 | + |
| 22 | +class Context(BaseModel, Generic[LifespanContextT, RequestT]): |
| 23 | + """Context object providing access to MCP capabilities. |
| 24 | +
|
| 25 | + This provides a cleaner interface to MCP's RequestContext functionality. |
| 26 | + It gets injected into tool and resource functions that request it via type hints. |
| 27 | +
|
| 28 | + To use context in a tool function, add a parameter with the Context type annotation: |
| 29 | +
|
| 30 | + ```python |
| 31 | + @server.tool() |
| 32 | + async def my_tool(x: int, ctx: Context) -> str: |
| 33 | + # Log messages to the client |
| 34 | + await ctx.info(f"Processing {x}") |
| 35 | + await ctx.debug("Debug info") |
| 36 | + await ctx.warning("Warning message") |
| 37 | + await ctx.error("Error message") |
| 38 | +
|
| 39 | + # Report progress |
| 40 | + await ctx.report_progress(50, 100) |
| 41 | +
|
| 42 | + # Access resources |
| 43 | + data = await ctx.read_resource("resource://data") |
| 44 | +
|
| 45 | + # Get request info |
| 46 | + request_id = ctx.request_id |
| 47 | + client_id = ctx.client_id |
| 48 | +
|
| 49 | + return str(x) |
| 50 | + ``` |
| 51 | +
|
| 52 | + The context parameter name can be anything as long as it's annotated with Context. |
| 53 | + The context is optional - tools that don't need it can omit the parameter. |
| 54 | + """ |
| 55 | + |
| 56 | + _request_context: ServerRequestContext[LifespanContextT, RequestT] | None |
| 57 | + _mcp_server: MCPServer | None |
| 58 | + |
| 59 | + # TODO(maxisbey): Consider making request_context/mcp_server required, or refactor Context entirely. |
| 60 | + def __init__( |
| 61 | + self, |
| 62 | + *, |
| 63 | + request_context: ServerRequestContext[LifespanContextT, RequestT] | None = None, |
| 64 | + mcp_server: MCPServer | None = None, |
| 65 | + # TODO(Marcelo): We should drop this kwargs parameter. |
| 66 | + **kwargs: Any, |
| 67 | + ): |
| 68 | + super().__init__(**kwargs) |
| 69 | + self._request_context = request_context |
| 70 | + self._mcp_server = mcp_server |
| 71 | + |
| 72 | + @property |
| 73 | + def mcp_server(self) -> MCPServer: |
| 74 | + """Access to the MCPServer instance.""" |
| 75 | + if self._mcp_server is None: # pragma: no cover |
| 76 | + raise ValueError("Context is not available outside of a request") |
| 77 | + return self._mcp_server # pragma: no cover |
| 78 | + |
| 79 | + @property |
| 80 | + def request_context(self) -> ServerRequestContext[LifespanContextT, RequestT]: |
| 81 | + """Access to the underlying request context.""" |
| 82 | + if self._request_context is None: # pragma: no cover |
| 83 | + raise ValueError("Context is not available outside of a request") |
| 84 | + return self._request_context |
| 85 | + |
| 86 | + async def report_progress(self, progress: float, total: float | None = None, message: str | None = None) -> None: |
| 87 | + """Report progress for the current operation. |
| 88 | +
|
| 89 | + Args: |
| 90 | + progress: Current progress value (e.g., 24) |
| 91 | + total: Optional total value (e.g., 100) |
| 92 | + message: Optional message (e.g., "Starting render...") |
| 93 | + """ |
| 94 | + progress_token = self.request_context.meta.get("progress_token") if self.request_context.meta else None |
| 95 | + |
| 96 | + if progress_token is None: # pragma: no cover |
| 97 | + return |
| 98 | + |
| 99 | + await self.request_context.session.send_progress_notification( |
| 100 | + progress_token=progress_token, |
| 101 | + progress=progress, |
| 102 | + total=total, |
| 103 | + message=message, |
| 104 | + related_request_id=self.request_id, |
| 105 | + ) |
| 106 | + |
| 107 | + async def read_resource(self, uri: str | AnyUrl) -> Iterable[ReadResourceContents]: |
| 108 | + """Read a resource by URI. |
| 109 | +
|
| 110 | + Args: |
| 111 | + uri: Resource URI to read |
| 112 | +
|
| 113 | + Returns: |
| 114 | + The resource content as either text or bytes |
| 115 | + """ |
| 116 | + assert self._mcp_server is not None, "Context is not available outside of a request" |
| 117 | + return await self._mcp_server.read_resource(uri, self) |
| 118 | + |
| 119 | + async def elicit( |
| 120 | + self, |
| 121 | + message: str, |
| 122 | + schema: type[ElicitSchemaModelT], |
| 123 | + ) -> ElicitationResult[ElicitSchemaModelT]: |
| 124 | + """Elicit information from the client/user. |
| 125 | +
|
| 126 | + This method can be used to interactively ask for additional information from the |
| 127 | + client within a tool's execution. The client might display the message to the |
| 128 | + user and collect a response according to the provided schema. If the client |
| 129 | + is an agent, it might decide how to handle the elicitation -- either by asking |
| 130 | + the user or automatically generating a response. |
| 131 | +
|
| 132 | + Args: |
| 133 | + message: Message to present to the user |
| 134 | + schema: A Pydantic model class defining the expected response structure. |
| 135 | + According to the specification, only primitive types are allowed. |
| 136 | +
|
| 137 | + Returns: |
| 138 | + An ElicitationResult containing the action taken and the data if accepted |
| 139 | +
|
| 140 | + Note: |
| 141 | + Check the result.action to determine if the user accepted, declined, or cancelled. |
| 142 | + The result.data will only be populated if action is "accept" and validation succeeded. |
| 143 | + """ |
| 144 | + |
| 145 | + return await elicit_with_validation( |
| 146 | + session=self.request_context.session, |
| 147 | + message=message, |
| 148 | + schema=schema, |
| 149 | + related_request_id=self.request_id, |
| 150 | + ) |
| 151 | + |
| 152 | + async def elicit_url( |
| 153 | + self, |
| 154 | + message: str, |
| 155 | + url: str, |
| 156 | + elicitation_id: str, |
| 157 | + ) -> UrlElicitationResult: |
| 158 | + """Request URL mode elicitation from the client. |
| 159 | +
|
| 160 | + This directs the user to an external URL for out-of-band interactions |
| 161 | + that must not pass through the MCP client. Use this for: |
| 162 | + - Collecting sensitive credentials (API keys, passwords) |
| 163 | + - OAuth authorization flows with third-party services |
| 164 | + - Payment and subscription flows |
| 165 | + - Any interaction where data should not pass through the LLM context |
| 166 | +
|
| 167 | + The response indicates whether the user consented to navigate to the URL. |
| 168 | + The actual interaction happens out-of-band. When the elicitation completes, |
| 169 | + call `ctx.session.send_elicit_complete(elicitation_id)` to notify the client. |
| 170 | +
|
| 171 | + Args: |
| 172 | + message: Human-readable explanation of why the interaction is needed |
| 173 | + url: The URL the user should navigate to |
| 174 | + elicitation_id: Unique identifier for tracking this elicitation |
| 175 | +
|
| 176 | + Returns: |
| 177 | + UrlElicitationResult indicating accept, decline, or cancel |
| 178 | + """ |
| 179 | + return await elicit_url( |
| 180 | + session=self.request_context.session, |
| 181 | + message=message, |
| 182 | + url=url, |
| 183 | + elicitation_id=elicitation_id, |
| 184 | + related_request_id=self.request_id, |
| 185 | + ) |
| 186 | + |
| 187 | + async def log( |
| 188 | + self, |
| 189 | + level: Literal["debug", "info", "warning", "error"], |
| 190 | + message: str, |
| 191 | + *, |
| 192 | + logger_name: str | None = None, |
| 193 | + extra: dict[str, Any] | None = None, |
| 194 | + ) -> None: |
| 195 | + """Send a log message to the client. |
| 196 | +
|
| 197 | + Args: |
| 198 | + level: Log level (debug, info, warning, error) |
| 199 | + message: Log message |
| 200 | + logger_name: Optional logger name |
| 201 | + extra: Optional dictionary with additional structured data to include |
| 202 | + """ |
| 203 | + |
| 204 | + if extra: |
| 205 | + log_data = {"message": message, **extra} |
| 206 | + else: |
| 207 | + log_data = message |
| 208 | + |
| 209 | + await self.request_context.session.send_log_message( |
| 210 | + level=level, |
| 211 | + data=log_data, |
| 212 | + logger=logger_name, |
| 213 | + related_request_id=self.request_id, |
| 214 | + ) |
| 215 | + |
| 216 | + @property |
| 217 | + def client_id(self) -> str | None: |
| 218 | + """Get the client ID if available.""" |
| 219 | + return self.request_context.meta.get("client_id") if self.request_context.meta else None # pragma: no cover |
| 220 | + |
| 221 | + @property |
| 222 | + def request_id(self) -> str: |
| 223 | + """Get the unique ID for this request.""" |
| 224 | + return str(self.request_context.request_id) |
| 225 | + |
| 226 | + @property |
| 227 | + def session(self): |
| 228 | + """Access to the underlying session for advanced usage.""" |
| 229 | + return self.request_context.session |
| 230 | + |
| 231 | + async def close_sse_stream(self) -> None: |
| 232 | + """Close the SSE stream to trigger client reconnection. |
| 233 | +
|
| 234 | + This method closes the HTTP connection for the current request, triggering |
| 235 | + client reconnection. Events continue to be stored in the event store and will |
| 236 | + be replayed when the client reconnects with Last-Event-ID. |
| 237 | +
|
| 238 | + Use this to implement polling behavior during long-running operations - |
| 239 | + the client will reconnect after the retry interval specified in the priming event. |
| 240 | +
|
| 241 | + Note: |
| 242 | + This is a no-op if not using StreamableHTTP transport with event_store. |
| 243 | + The callback is only available when event_store is configured. |
| 244 | + """ |
| 245 | + if self._request_context and self._request_context.close_sse_stream: # pragma: no cover |
| 246 | + await self._request_context.close_sse_stream() |
| 247 | + |
| 248 | + async def close_standalone_sse_stream(self) -> None: |
| 249 | + """Close the standalone GET SSE stream to trigger client reconnection. |
| 250 | +
|
| 251 | + This method closes the HTTP connection for the standalone GET stream used |
| 252 | + for unsolicited server-to-client notifications. The client SHOULD reconnect |
| 253 | + with Last-Event-ID to resume receiving notifications. |
| 254 | +
|
| 255 | + Note: |
| 256 | + This is a no-op if not using StreamableHTTP transport with event_store. |
| 257 | + Currently, client reconnection for standalone GET streams is NOT |
| 258 | + implemented - this is a known gap. |
| 259 | + """ |
| 260 | + if self._request_context and self._request_context.close_standalone_sse_stream: # pragma: no cover |
| 261 | + await self._request_context.close_standalone_sse_stream() |
| 262 | + |
| 263 | + # Convenience methods for common log levels |
| 264 | + async def debug(self, message: str, *, logger_name: str | None = None, extra: dict[str, Any] | None = None) -> None: |
| 265 | + """Send a debug log message.""" |
| 266 | + await self.log("debug", message, logger_name=logger_name, extra=extra) |
| 267 | + |
| 268 | + async def info(self, message: str, *, logger_name: str | None = None, extra: dict[str, Any] | None = None) -> None: |
| 269 | + """Send an info log message.""" |
| 270 | + await self.log("info", message, logger_name=logger_name, extra=extra) |
| 271 | + |
| 272 | + async def warning( |
| 273 | + self, message: str, *, logger_name: str | None = None, extra: dict[str, Any] | None = None |
| 274 | + ) -> None: |
| 275 | + """Send a warning log message.""" |
| 276 | + await self.log("warning", message, logger_name=logger_name, extra=extra) |
| 277 | + |
| 278 | + async def error(self, message: str, *, logger_name: str | None = None, extra: dict[str, Any] | None = None) -> None: |
| 279 | + """Send an error log message.""" |
| 280 | + await self.log("error", message, logger_name=logger_name, extra=extra) |
0 commit comments