diff --git a/src/inMemory.ts b/src/inMemory.ts index 5dd6e81e0..b997965dd 100644 --- a/src/inMemory.ts +++ b/src/inMemory.ts @@ -1,10 +1,10 @@ import { Transport } from "./shared/transport.js"; -import { JSONRPCMessage, RequestId } from "./types.js"; +import { JSONRPCMessage, RequestId, MessageExtraInfo } from "./types.js"; import { AuthInfo } from "./server/auth/types.js"; interface QueuedMessage { message: JSONRPCMessage; - extra?: { authInfo?: AuthInfo }; + extra?: MessageExtraInfo; } /** @@ -13,10 +13,11 @@ interface QueuedMessage { export class InMemoryTransport implements Transport { private _otherTransport?: InMemoryTransport; private _messageQueue: QueuedMessage[] = []; + private _customContext?: Record; onclose?: () => void; onerror?: (error: Error) => void; - onmessage?: (message: JSONRPCMessage, extra?: { authInfo?: AuthInfo }) => void; + onmessage?: (message: JSONRPCMessage, extra?: MessageExtraInfo) => void; sessionId?: string; /** @@ -34,7 +35,12 @@ export class InMemoryTransport implements Transport { // Process any messages that were queued before start was called while (this._messageQueue.length > 0) { const queuedMessage = this._messageQueue.shift()!; - this.onmessage?.(queuedMessage.message, queuedMessage.extra); + // Merge custom context with queued extra info + const enhancedExtra: MessageExtraInfo = { + ...queuedMessage.extra, + customContext: this._customContext + }; + this.onmessage?.(queuedMessage.message, enhancedExtra); } } @@ -46,18 +52,45 @@ export class InMemoryTransport implements Transport { } /** - * Sends a message with optional auth info. - * This is useful for testing authentication scenarios. + * Sends a message with optional extra info. + * This is useful for testing authentication scenarios and custom context. + * + * @deprecated The authInfo parameter is deprecated. Use MessageExtraInfo instead. */ - async send(message: JSONRPCMessage, options?: { relatedRequestId?: RequestId, authInfo?: AuthInfo }): Promise { + async send(message: JSONRPCMessage, options?: { relatedRequestId?: RequestId, authInfo?: AuthInfo } | MessageExtraInfo): Promise { if (!this._otherTransport) { throw new Error("Not connected"); } + // Handle both old and new API formats + let extra: MessageExtraInfo | undefined; + if (options && 'authInfo' in options && !('requestInfo' in options)) { + // Old API format - convert to new format + extra = { authInfo: options.authInfo }; + } else if (options && ('requestInfo' in options || 'customContext' in options || 'authInfo' in options)) { + // New API format + extra = options as MessageExtraInfo; + } else if (options && 'authInfo' in options) { + // Old API with authInfo + extra = { authInfo: options.authInfo }; + } + if (this._otherTransport.onmessage) { - this._otherTransport.onmessage(message, { authInfo: options?.authInfo }); + // Merge the other transport's custom context with the extra info + const enhancedExtra: MessageExtraInfo = { + ...extra, + customContext: this._otherTransport._customContext + }; + this._otherTransport.onmessage(message, enhancedExtra); } else { - this._otherTransport._messageQueue.push({ message, extra: { authInfo: options?.authInfo } }); + this._otherTransport._messageQueue.push({ message, extra }); } } + + /** + * Sets custom context data that will be passed to all message handlers. + */ + setCustomContext(context: Record): void { + this._customContext = context; + } } diff --git a/src/server/mcp.test.ts b/src/server/mcp.test.ts index 10e550df4..4a7629af2 100644 --- a/src/server/mcp.test.ts +++ b/src/server/mcp.test.ts @@ -1433,6 +1433,64 @@ describe("tool()", () => { expect(result.content && result.content[0].text).toContain("Received request ID:"); }); + /*** + * Test: Pass Custom Context to Tool Callback + */ + test("should pass customContext to tool callback via RequestHandlerExtra", async () => { + const mcpServer = new McpServer({ + name: "test server", + version: "1.0", + }); + + const client = new Client({ + name: "test client", + version: "1.0", + }); + + let receivedCustomContext: Record | undefined; + mcpServer.tool("custom-context-test", async (extra) => { + receivedCustomContext = extra.customContext; + return { + content: [ + { + type: "text", + text: `Custom context: ${JSON.stringify(extra.customContext)}`, + }, + ], + }; + }); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + // Use the new setCustomContext method to inject custom context + serverTransport.setCustomContext({ + tenantId: "test-tenant-123", + featureFlags: { newFeature: true }, + customData: "test-value" + }); + + await Promise.all([ + client.connect(clientTransport), + mcpServer.server.connect(serverTransport), + ]); + + const result = await client.request( + { + method: "tools/call", + params: { + name: "custom-context-test", + }, + }, + CallToolResultSchema, + ); + + expect(receivedCustomContext).toBeDefined(); + expect(receivedCustomContext?.tenantId).toBe("test-tenant-123"); + expect(receivedCustomContext?.featureFlags).toEqual({ newFeature: true }); + expect(receivedCustomContext?.customData).toBe("test-value"); + expect(result.content && result.content[0].text).toContain("test-tenant-123"); + }); + /*** * Test: Send Notification within Tool Call */ diff --git a/src/server/sse.ts b/src/server/sse.ts index e07256867..af623cc66 100644 --- a/src/server/sse.ts +++ b/src/server/sse.ts @@ -41,6 +41,7 @@ export class SSEServerTransport implements Transport { private _sseResponse?: ServerResponse; private _sessionId: string; private _options: SSEServerTransportOptions; + private _customContext?: Record; onclose?: () => void; onerror?: (error: Error) => void; onmessage?: (message: JSONRPCMessage, extra?: MessageExtraInfo) => void; @@ -191,7 +192,12 @@ export class SSEServerTransport implements Transport { throw error; } - this.onmessage?.(parsedMessage, extra); + // Merge custom context with the extra info + const enhancedExtra: MessageExtraInfo = { + ...extra, + customContext: this._customContext + }; + this.onmessage?.(parsedMessage, enhancedExtra); } async close(): Promise { @@ -218,4 +224,11 @@ export class SSEServerTransport implements Transport { get sessionId(): string { return this._sessionId; } + + /** + * Sets custom context data that will be passed to all message handlers. + */ + setCustomContext(context: Record): void { + this._customContext = context; + } } diff --git a/src/server/stdio.ts b/src/server/stdio.ts index 30c80012e..42411df49 100644 --- a/src/server/stdio.ts +++ b/src/server/stdio.ts @@ -1,7 +1,7 @@ import process from "node:process"; import { Readable, Writable } from "node:stream"; import { ReadBuffer, serializeMessage } from "../shared/stdio.js"; -import { JSONRPCMessage } from "../types.js"; +import { JSONRPCMessage, MessageExtraInfo } from "../types.js"; import { Transport } from "../shared/transport.js"; /** @@ -12,6 +12,7 @@ import { Transport } from "../shared/transport.js"; export class StdioServerTransport implements Transport { private _readBuffer: ReadBuffer = new ReadBuffer(); private _started = false; + private _customContext?: Record; constructor( private _stdin: Readable = process.stdin, @@ -20,7 +21,7 @@ export class StdioServerTransport implements Transport { onclose?: () => void; onerror?: (error: Error) => void; - onmessage?: (message: JSONRPCMessage) => void; + onmessage?: (message: JSONRPCMessage, extra?: MessageExtraInfo) => void; // Arrow functions to bind `this` properly, while maintaining function identity. _ondata = (chunk: Buffer) => { @@ -54,7 +55,11 @@ export class StdioServerTransport implements Transport { break; } - this.onmessage?.(message); + // Pass custom context to message handlers + const extra: MessageExtraInfo = { + customContext: this._customContext + }; + this.onmessage?.(message, extra); } catch (error) { this.onerror?.(error as Error); } @@ -89,4 +94,11 @@ export class StdioServerTransport implements Transport { } }); } + + /** + * Sets custom context data that will be passed to all message handlers. + */ + setCustomContext(context: Record): void { + this._customContext = context; + } } diff --git a/src/server/streamableHttp.ts b/src/server/streamableHttp.ts index 3bf84e430..02283ba44 100644 --- a/src/server/streamableHttp.ts +++ b/src/server/streamableHttp.ts @@ -143,6 +143,7 @@ export class StreamableHTTPServerTransport implements Transport { private _allowedHosts?: string[]; private _allowedOrigins?: string[]; private _enableDnsRebindingProtection: boolean; + private _customContext?: Record; sessionId?: string; onclose?: () => void; @@ -487,7 +488,12 @@ export class StreamableHTTPServerTransport implements Transport { // handle each message for (const message of messages) { - this.onmessage?.(message, { authInfo, requestInfo }); + const enhancedExtra: MessageExtraInfo = { + authInfo, + requestInfo, + customContext: this._customContext + }; + this.onmessage?.(message, enhancedExtra); } } else if (hasRequests) { // The default behavior is to use SSE streaming @@ -522,7 +528,12 @@ export class StreamableHTTPServerTransport implements Transport { // handle each message for (const message of messages) { - this.onmessage?.(message, { authInfo, requestInfo }); + const enhancedExtra: MessageExtraInfo = { + authInfo, + requestInfo, + customContext: this._customContext + }; + this.onmessage?.(message, enhancedExtra); } // The server SHOULD NOT close the SSE stream before sending all JSON-RPC responses // This will be handled by the send() method when responses are ready @@ -748,5 +759,12 @@ export class StreamableHTTPServerTransport implements Transport { } } } + + /** + * Sets custom context data that will be passed to all message handlers. + */ + setCustomContext(context: Record): void { + this._customContext = context; + } } diff --git a/src/shared/protocol.ts b/src/shared/protocol.ts index 7df190ba1..5152f335d 100644 --- a/src/shared/protocol.ts +++ b/src/shared/protocol.ts @@ -141,6 +141,13 @@ export type RequestHandlerExtra; + /** * Sends a notification that relates to the current request being handled. * @@ -405,7 +412,8 @@ export abstract class Protocol< this.request(r, resultSchema, { ...options, relatedRequestId: request.id }), authInfo: extra?.authInfo, requestId: request.id, - requestInfo: extra?.requestInfo + requestInfo: extra?.requestInfo, + customContext: extra?.customContext }; // Starting with Promise.resolve() puts any synchronous errors into the monad as well. diff --git a/src/shared/transport.ts b/src/shared/transport.ts index 386b6bae5..8dc671a39 100644 --- a/src/shared/transport.ts +++ b/src/shared/transport.ts @@ -82,4 +82,10 @@ export interface Transport { * Sets the protocol version used for the connection (called when the initialize response is received). */ setProtocolVersion?: (version: string) => void; + + /** + * Sets custom context data that will be passed to all message handlers. + * This context will be included in the MessageExtraInfo passed to handlers. + */ + setCustomContext?: (context: Record) => void; } diff --git a/src/types.ts b/src/types.ts index 323e37389..ef87f3665 100644 --- a/src/types.ts +++ b/src/types.ts @@ -1512,6 +1512,13 @@ export interface MessageExtraInfo { * The authentication information. */ authInfo?: AuthInfo; + + /** + * Custom context data that can be passed through the message handling pipeline. + * This allows transport implementations to attach arbitrary data that will be + * available to request handlers. + */ + customContext?: Record; } /* JSON-RPC types */