From deb2b9d1c81b11b72042eeb37f57fc310e7e3210 Mon Sep 17 00:00:00 2001 From: Maria Tigina Date: Thu, 24 Jul 2025 19:16:51 +0200 Subject: [PATCH 1/6] Introduce server session --- .../kotlin/sdk/client/SSEClientTransport.kt | 7 + .../kotlin/sdk/server/KtorServer.kt | 36 +- .../kotlin/sdk/server/Server.kt | 336 ++++++++++++------ .../kotlin/sdk/server/ServerSession.kt | 0 .../sdk/integration/SseBugReproductionTest.kt | 0 5 files changed, 260 insertions(+), 119 deletions(-) create mode 100644 kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/ServerSession.kt create mode 100644 kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/SseBugReproductionTest.kt diff --git a/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/SSEClientTransport.kt b/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/SSEClientTransport.kt index d30f5288..20b7382f 100644 --- a/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/SSEClientTransport.kt +++ b/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/SSEClientTransport.kt @@ -1,5 +1,6 @@ package io.modelcontextprotocol.kotlin.sdk.client +import io.github.oshai.kotlinlogging.KotlinLogging import io.ktor.client.HttpClient import io.ktor.client.plugins.sse.ClientSSESession import io.ktor.client.plugins.sse.sseSession @@ -46,6 +47,8 @@ public class SseClientTransport( private val reconnectionTime: Duration? = null, private val requestBuilder: HttpRequestBuilder.() -> Unit = {}, ) : AbstractTransport() { + private val logger = KotlinLogging.logger {} + private val initialized: AtomicBoolean = AtomicBoolean(false) private val endpoint = CompletableDeferred() @@ -68,6 +71,7 @@ public class SseClientTransport( check(initialized.compareAndSet(expectedValue = false, newValue = true)) { "SSEClientTransport already started! If using Client class, note that connect() calls start() automatically." } + logger.info { "Starting SseClientTransport..." } try { session = urlString?.let { @@ -111,6 +115,8 @@ public class SseClientTransport( val text = response.bodyAsText() error("Error POSTing to endpoint (HTTP ${response.status}): $text") } + + logger.debug { "Client successfully sent message via SSE $endpoint" } } catch (e: Throwable) { _onError(e) throw e @@ -157,6 +163,7 @@ public class SseClientTransport( val path = if (eventData.startsWith("/")) eventData.substring(1) else eventData val endpointUrl = Url("$baseUrl/$path") endpoint.complete(endpointUrl.toString()) + logger.debug { "Client connected to endpoint: $endpointUrl" } } catch (e: Throwable) { _onError(e) endpoint.completeExceptionally(e) diff --git a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/KtorServer.kt b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/KtorServer.kt index 056c7854..882d4500 100644 --- a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/KtorServer.kt +++ b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/KtorServer.kt @@ -13,8 +13,12 @@ import io.ktor.server.routing.routing import io.ktor.server.sse.SSE import io.ktor.server.sse.ServerSSESession import io.ktor.server.sse.sse -import io.ktor.util.collections.ConcurrentMap import io.ktor.utils.io.KtorDsl +import kotlinx.atomicfu.AtomicRef +import kotlinx.atomicfu.atomic +import kotlinx.atomicfu.update +import kotlinx.collections.immutable.PersistentMap +import kotlinx.collections.immutable.persistentMapOf private val logger = KotlinLogging.logger {} @@ -30,7 +34,7 @@ public fun Routing.mcp(path: String, block: () -> Server) { */ @KtorDsl public fun Routing.mcp(block: () -> Server) { - val transports = ConcurrentMap() + val transports = atomic(persistentMapOf()) sse { mcpSseEndpoint("", transports, block) @@ -49,24 +53,16 @@ public fun Application.MCP(block: () -> Server) { @KtorDsl public fun Application.mcp(block: () -> Server) { - val transports = ConcurrentMap() - install(SSE) routing { - sse("/sse") { - mcpSseEndpoint("/message", transports, block) - } - - post("/message") { - mcpPostEndpoint(transports) - } + mcp(block) } } -private suspend fun ServerSSESession.mcpSseEndpoint( +internal suspend fun ServerSSESession.mcpSseEndpoint( postEndpoint: String, - transports: ConcurrentMap, + transports: AtomicRef>, block: () -> Server, ) { val transport = mcpSseTransport(postEndpoint, transports) @@ -75,27 +71,27 @@ private suspend fun ServerSSESession.mcpSseEndpoint( server.onClose { logger.info { "Server connection closed for sessionId: ${transport.sessionId}" } - transports.remove(transport.sessionId) + transports.update { it.remove(transport.sessionId) } } - server.connect(transport) + server.connectSession(transport) + logger.debug { "Server connected to transport for sessionId: ${transport.sessionId}" } } internal fun ServerSSESession.mcpSseTransport( postEndpoint: String, - transports: ConcurrentMap, + transports: AtomicRef>, ): SseServerTransport { val transport = SseServerTransport(postEndpoint, this) - transports[transport.sessionId] = transport - + transports.update { it.put(transport.sessionId, transport) } logger.info { "New SSE connection established and stored with sessionId: ${transport.sessionId}" } return transport } internal suspend fun RoutingContext.mcpPostEndpoint( - transports: ConcurrentMap, + transports: AtomicRef>, ) { val sessionId: String = call.request.queryParameters["sessionId"] ?: run { @@ -105,7 +101,7 @@ internal suspend fun RoutingContext.mcpPostEndpoint( logger.debug { "Received message for sessionId: $sessionId" } - val transport = transports[sessionId] + val transport = transports.value[sessionId] if (transport == null) { logger.warn { "Session not found for sessionId: $sessionId" } call.respond(HttpStatusCode.NotFound, "Session not found") diff --git a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/Server.kt b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/Server.kt index f0655fd9..205c59a1 100644 --- a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/Server.kt +++ b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/Server.kt @@ -1,57 +1,18 @@ package io.modelcontextprotocol.kotlin.sdk.server import io.github.oshai.kotlinlogging.KotlinLogging -import io.modelcontextprotocol.kotlin.sdk.CallToolRequest -import io.modelcontextprotocol.kotlin.sdk.CallToolResult -import io.modelcontextprotocol.kotlin.sdk.ClientCapabilities -import io.modelcontextprotocol.kotlin.sdk.CreateElicitationRequest +import io.modelcontextprotocol.kotlin.sdk.* import io.modelcontextprotocol.kotlin.sdk.CreateElicitationRequest.RequestedSchema -import io.modelcontextprotocol.kotlin.sdk.CreateElicitationResult -import io.modelcontextprotocol.kotlin.sdk.CreateMessageRequest -import io.modelcontextprotocol.kotlin.sdk.CreateMessageResult -import io.modelcontextprotocol.kotlin.sdk.EmptyJsonObject -import io.modelcontextprotocol.kotlin.sdk.EmptyRequestResult -import io.modelcontextprotocol.kotlin.sdk.GetPromptRequest -import io.modelcontextprotocol.kotlin.sdk.GetPromptResult -import io.modelcontextprotocol.kotlin.sdk.Implementation -import io.modelcontextprotocol.kotlin.sdk.InitializeRequest -import io.modelcontextprotocol.kotlin.sdk.InitializeResult -import io.modelcontextprotocol.kotlin.sdk.InitializedNotification -import io.modelcontextprotocol.kotlin.sdk.LATEST_PROTOCOL_VERSION -import io.modelcontextprotocol.kotlin.sdk.ListPromptsRequest -import io.modelcontextprotocol.kotlin.sdk.ListPromptsResult -import io.modelcontextprotocol.kotlin.sdk.ListResourceTemplatesRequest -import io.modelcontextprotocol.kotlin.sdk.ListResourceTemplatesResult -import io.modelcontextprotocol.kotlin.sdk.ListResourcesRequest -import io.modelcontextprotocol.kotlin.sdk.ListResourcesResult -import io.modelcontextprotocol.kotlin.sdk.ListRootsRequest -import io.modelcontextprotocol.kotlin.sdk.ListRootsResult -import io.modelcontextprotocol.kotlin.sdk.ListToolsRequest -import io.modelcontextprotocol.kotlin.sdk.ListToolsResult -import io.modelcontextprotocol.kotlin.sdk.LoggingMessageNotification -import io.modelcontextprotocol.kotlin.sdk.Method -import io.modelcontextprotocol.kotlin.sdk.PingRequest -import io.modelcontextprotocol.kotlin.sdk.Prompt -import io.modelcontextprotocol.kotlin.sdk.PromptArgument -import io.modelcontextprotocol.kotlin.sdk.PromptListChangedNotification -import io.modelcontextprotocol.kotlin.sdk.ReadResourceRequest -import io.modelcontextprotocol.kotlin.sdk.ReadResourceResult -import io.modelcontextprotocol.kotlin.sdk.Resource -import io.modelcontextprotocol.kotlin.sdk.ResourceListChangedNotification -import io.modelcontextprotocol.kotlin.sdk.ResourceUpdatedNotification -import io.modelcontextprotocol.kotlin.sdk.SUPPORTED_PROTOCOL_VERSIONS -import io.modelcontextprotocol.kotlin.sdk.ServerCapabilities -import io.modelcontextprotocol.kotlin.sdk.Tool -import io.modelcontextprotocol.kotlin.sdk.ToolAnnotations -import io.modelcontextprotocol.kotlin.sdk.ToolListChangedNotification import io.modelcontextprotocol.kotlin.sdk.shared.Protocol import io.modelcontextprotocol.kotlin.sdk.shared.ProtocolOptions import io.modelcontextprotocol.kotlin.sdk.shared.RequestOptions +import io.modelcontextprotocol.kotlin.sdk.shared.Transport import kotlinx.atomicfu.atomic import kotlinx.atomicfu.getAndUpdate import kotlinx.atomicfu.update import kotlinx.collections.immutable.minus import kotlinx.collections.immutable.persistentMapOf +import kotlinx.collections.immutable.persistentSetOf import kotlinx.collections.immutable.toPersistentSet import kotlinx.coroutines.CompletableDeferred import kotlinx.serialization.json.JsonObject @@ -83,23 +44,35 @@ public open class Server( private val serverInfo: Implementation, options: ServerOptions, ) : Protocol(options) { + private val sessions = atomic(persistentSetOf()) + private val serverOptions = options + private var _onInitialized: (() -> Unit) = {} + private var _onConnect: (() -> Unit) = {} private var _onClose: () -> Unit = {} /** * The client's reported capabilities after initialization. */ + @Deprecated( + "Moved to ServerSession", + ReplaceWith("ServerSession.clientCapabilities"), + DeprecationLevel.WARNING + ) public var clientCapabilities: ClientCapabilities? = null private set /** * The client's version information after initialization. */ + @Deprecated( + "Moved to ServerSession", + ReplaceWith("ServerSession.clientVersion"), + DeprecationLevel.WARNING + ) public var clientVersion: Implementation? = null private set - private val capabilities: ServerCapabilities = options.capabilities - private val _tools = atomic(persistentMapOf()) private val _prompts = atomic(persistentMapOf()) private val _resources = atomic(persistentMapOf()) @@ -111,8 +84,9 @@ public open class Server( get() = _resources.value init { - logger.debug { "Initializing MCP server with capabilities: $capabilities" } + logger.debug { "Initializing MCP server with option: $options" } + // TODO: Remove all after Protocol inheritance // Core protocol handlers setRequestHandler(Method.Defined.Initialize) { request, _ -> handleInitialize(request) @@ -123,7 +97,7 @@ public open class Server( } // Internal handlers for tools - if (capabilities.tools != null) { + if (serverOptions.capabilities.tools != null) { setRequestHandler(Method.Defined.ToolsList) { _, _ -> handleListTools() } @@ -133,7 +107,7 @@ public open class Server( } // Internal handlers for prompts - if (capabilities.prompts != null) { + if (serverOptions.capabilities.prompts != null) { setRequestHandler(Method.Defined.PromptsList) { _, _ -> handleListPrompts() } @@ -143,7 +117,7 @@ public open class Server( } // Internal handlers for resources - if (capabilities.resources != null) { + if (serverOptions.capabilities.resources != null) { setRequestHandler(Method.Defined.ResourcesList) { _, _ -> handleListResources() } @@ -156,9 +130,102 @@ public open class Server( } } + @Deprecated( + "Will be removed with Protocol inheritance. Use connectSession instead.", + ReplaceWith("connectSession"), + DeprecationLevel.WARNING + ) + public override fun onClose() { + logger.debug { "Closing MCP server" } + _onClose() + } + + // TODO: Rename closeSessions to close after the full onClose deprecation + public suspend fun closeSessions() { + logger.debug { "Closing MCP server" } + sessions.value.forEach { it.close() } + _onClose() + } + + /** + * Starts a new server session with the given transport and initializes + * internal request handlers based on the server's capabilities. + * + * @param transport The transport layer to connect the session with. + * @return The initialized and connected server session. + */ + // TODO: Rename connectSession to connect after the full connect deprecation + public suspend fun connectSession(transport: Transport): ServerSession { + val session = ServerSession(serverInfo, serverOptions) + + // Internal handlers for tools + if (serverOptions.capabilities.tools != null) { + session.setRequestHandler(Method.Defined.ToolsList) { _, _ -> + handleListTools() + } + session.setRequestHandler(Method.Defined.ToolsCall) { request, _ -> + handleCallTool(request) + } + } + + // Internal handlers for prompts + if (serverOptions.capabilities.prompts != null) { + session.setRequestHandler(Method.Defined.PromptsList) { _, _ -> + handleListPrompts() + } + session.setRequestHandler(Method.Defined.PromptsGet) { request, _ -> + handleGetPrompt(request) + } + } + + // Internal handlers for resources + if (serverOptions.capabilities.resources != null) { + session.setRequestHandler(Method.Defined.ResourcesList) { _, _ -> + handleListResources() + } + session.setRequestHandler(Method.Defined.ResourcesRead) { request, _ -> + handleReadResource(request) + } + session.setRequestHandler(Method.Defined.ResourcesTemplatesList) { _, _ -> + handleListResourceTemplates() + } + } + + session.connect(transport) + sessions.update { it.add(session) } + + _onConnect() + return session + } + + @Deprecated( + "Will be removed with Protocol inheritance. Use connectSession instead.", + ReplaceWith("connectSession"), + DeprecationLevel.WARNING + ) + public override suspend fun connect(transport: Transport) { + super.connect(transport) + } + + /** + * Registers a callback to be invoked when the new server session connected. + */ + public fun onConnect(block: () -> Unit) { + val old = _onConnect + _onConnect = { + old() + block() + } + } + /** * Registers a callback to be invoked when the server has completed initialization. */ + @Deprecated( + "Will be removed with Protocol inheritance. Use onConnect instead.", + ReplaceWith("onConnect"), + DeprecationLevel.WARNING + ) public fun onInitialized(block: () -> Unit) { val old = _onInitialized _onInitialized = { @@ -178,14 +245,6 @@ public open class Server( } } - /** - * Called when the server connection is closing. - */ - override fun onClose() { - logger.info { "Server connection closing" } - _onClose() - } - /** * Registers a single tool. The client can then call this tool. * @@ -194,7 +253,7 @@ public open class Server( * @throws IllegalStateException If the server does not support tools. */ public fun addTool(tool: Tool, handler: suspend (CallToolRequest) -> CallToolResult) { - if (capabilities.tools == null) { + if (serverOptions.capabilities.tools == null) { logger.error { "Failed to add tool '${tool.name}': Server does not support tools capability" } throw IllegalStateException("Server does not support tools capability. Enable it in ServerOptions.") } @@ -234,7 +293,7 @@ public open class Server( * @throws IllegalStateException If the server does not support tools. */ public fun addTools(toolsToAdd: List) { - if (capabilities.tools == null) { + if (serverOptions.capabilities.tools == null) { logger.error { "Failed to add tools: Server does not support tools capability" } throw IllegalStateException("Server does not support tools capability.") } @@ -250,7 +309,7 @@ public open class Server( * @throws IllegalStateException If the server does not support tools. */ public fun removeTool(name: String): Boolean { - if (capabilities.tools == null) { + if (serverOptions.capabilities.tools == null) { logger.error { "Failed to remove tool '$name': Server does not support tools capability" } throw IllegalStateException("Server does not support tools capability.") } @@ -277,7 +336,7 @@ public open class Server( * @throws IllegalStateException If the server does not support tools. */ public fun removeTools(toolNames: List): Int { - if (capabilities.tools == null) { + if (serverOptions.capabilities.tools == null) { logger.error { "Failed to remove tools: Server does not support tools capability" } throw IllegalStateException("Server does not support tools capability.") } @@ -304,7 +363,7 @@ public open class Server( * @throws IllegalStateException If the server does not support prompts. */ public fun addPrompt(prompt: Prompt, promptProvider: suspend (GetPromptRequest) -> GetPromptResult) { - if (capabilities.prompts == null) { + if (serverOptions.capabilities.prompts == null) { logger.error { "Failed to add prompt '${prompt.name}': Server does not support prompts capability" } throw IllegalStateException("Server does not support prompts capability.") } @@ -338,7 +397,7 @@ public open class Server( * @throws IllegalStateException If the server does not support prompts. */ public fun addPrompts(promptsToAdd: List) { - if (capabilities.prompts == null) { + if (serverOptions.capabilities.prompts == null) { logger.error { "Failed to add prompts: Server does not support prompts capability" } throw IllegalStateException("Server does not support prompts capability.") } @@ -354,7 +413,7 @@ public open class Server( * @throws IllegalStateException If the server does not support prompts. */ public fun removePrompt(name: String): Boolean { - if (capabilities.prompts == null) { + if (serverOptions.capabilities.prompts == null) { logger.error { "Failed to remove prompt '$name': Server does not support prompts capability" } throw IllegalStateException("Server does not support prompts capability.") } @@ -381,7 +440,7 @@ public open class Server( * @throws IllegalStateException If the server does not support prompts. */ public fun removePrompts(promptNames: List): Int { - if (capabilities.prompts == null) { + if (serverOptions.capabilities.prompts == null) { logger.error { "Failed to remove prompts: Server does not support prompts capability" } throw IllegalStateException("Server does not support prompts capability.") } @@ -418,7 +477,7 @@ public open class Server( mimeType: String = "text/html", readHandler: suspend (ReadResourceRequest) -> ReadResourceResult ) { - if (capabilities.resources == null) { + if (serverOptions.capabilities.resources == null) { logger.error { "Failed to add resource '$name': Server does not support resources capability" } throw IllegalStateException("Server does not support resources capability.") } @@ -438,7 +497,7 @@ public open class Server( * @throws IllegalStateException If the server does not support resources. */ public fun addResources(resourcesToAdd: List) { - if (capabilities.resources == null) { + if (serverOptions.capabilities.resources == null) { logger.error { "Failed to add resources: Server does not support resources capability" } throw IllegalStateException("Server does not support resources capability.") } @@ -454,7 +513,7 @@ public open class Server( * @throws IllegalStateException If the server does not support resources. */ public fun removeResource(uri: String): Boolean { - if (capabilities.resources == null) { + if (serverOptions.capabilities.resources == null) { logger.error { "Failed to remove resource '$uri': Server does not support resources capability" } throw IllegalStateException("Server does not support resources capability.") } @@ -481,7 +540,7 @@ public open class Server( * @throws IllegalStateException If the server does not support resources. */ public fun removeResources(uris: List): Int { - if (capabilities.resources == null) { + if (serverOptions.capabilities.resources == null) { logger.error { "Failed to remove resources: Server does not support resources capability" } throw IllegalStateException("Server does not support resources capability.") } @@ -507,6 +566,11 @@ public open class Server( * @return The result of the ping request. * @throws IllegalStateException If for some reason the method is not supported or the connection is closed. */ + @Deprecated( + "Will be removed with Protocol inheritance. Use session.ping instead.", + ReplaceWith("session.ping"), + DeprecationLevel.WARNING + ) public suspend fun ping(): EmptyRequestResult { return request(PingRequest()) } @@ -519,6 +583,11 @@ public open class Server( * @return The created message result. * @throws IllegalStateException If the server does not support sampling or if the request fails. */ + @Deprecated( + "Will be removed with Protocol inheritance. Use session.createMessage instead.", + ReplaceWith("session.createMessage"), + DeprecationLevel.WARNING + ) public suspend fun createMessage( params: CreateMessageRequest, options: RequestOptions? = null @@ -535,6 +604,11 @@ public open class Server( * @return The list of roots. * @throws IllegalStateException If the server or client does not support roots. */ + @Deprecated( + "Will be removed with Protocol inheritance. Use session.listRoots instead.", + ReplaceWith("session.listRoots"), + DeprecationLevel.WARNING + ) public suspend fun listRoots( params: JsonObject = EmptyJsonObject, options: RequestOptions? = null @@ -543,6 +617,19 @@ public open class Server( return request(ListRootsRequest(params), options) } + /** + * Creates an elicitation request with the specified message and schema. + * + * @param message The message to be used for the elicitation. + * @param requestedSchema The schema defining the structure of the requested elicitation. + * @param options Optional parameters to customize the elicitation request. + * @return Returns the result of the elicitation creation process. + */ + @Deprecated( + "Will be removed with Protocol inheritance. Use session.createElicitation instead.", + ReplaceWith("session.createElicitation"), + DeprecationLevel.WARNING + ) public suspend fun createElicitation( message: String, requestedSchema: RequestedSchema, @@ -557,7 +644,14 @@ public open class Server( * * @param params The logging message notification parameters. */ - public suspend fun sendLoggingMessage(params: LoggingMessageNotification) { + @Deprecated( + "Will be removed with Protocol inheritance. Use session.sendLoggingMessage instead.", + ReplaceWith("session.sendLoggingMessage"), + DeprecationLevel.WARNING + ) + public suspend fun sendLoggingMessage( + params: LoggingMessageNotification + ) { logger.trace { "Sending logging message: ${params.params.data}" } notification(params) } @@ -567,6 +661,11 @@ public open class Server( * * @param params Details of the updated resource. */ + @Deprecated( + "Will be removed with Protocol inheritance. Use session.sendResourceUpdated instead.", + ReplaceWith("session.sendResourceUpdated"), + DeprecationLevel.WARNING + ) public suspend fun sendResourceUpdated(params: ResourceUpdatedNotification) { logger.debug { "Sending resource updated notification for: ${params.params.uri}" } notification(params) @@ -575,6 +674,11 @@ public open class Server( /** * Sends a notification to the client indicating that the list of resources has changed. */ + @Deprecated( + "Will be removed with Protocol inheritance. Use session.sendResourceListChanged instead.", + ReplaceWith("session.sendResourceListChanged"), + DeprecationLevel.WARNING + ) public suspend fun sendResourceListChanged() { logger.debug { "Sending resource list changed notification" } notification(ResourceListChangedNotification()) @@ -583,6 +687,11 @@ public open class Server( /** * Sends a notification to the client indicating that the list of tools has changed. */ + @Deprecated( + "Will be removed with Protocol inheritance. Use session.sendToolListChanged instead.", + ReplaceWith("session.sendToolListChanged"), + DeprecationLevel.WARNING + ) public suspend fun sendToolListChanged() { logger.debug { "Sending tool list changed notification" } notification(ToolListChangedNotification()) @@ -591,33 +700,20 @@ public open class Server( /** * Sends a notification to the client indicating that the list of prompts has changed. */ + /** + * Sends a notification to the client indicating that the list of tools has changed. + */ + @Deprecated( + "Will be removed with Protocol inheritance. Use session.sendPromptListChanged instead.", + ReplaceWith("session.sendPromptListChanged"), + DeprecationLevel.WARNING + ) public suspend fun sendPromptListChanged() { logger.debug { "Sending prompt list changed notification" } notification(PromptListChangedNotification()) } // --- Internal Handlers --- - - private suspend fun handleInitialize(request: InitializeRequest): InitializeResult { - logger.info { "Handling initialize request from client ${request.clientInfo}" } - clientCapabilities = request.capabilities - clientVersion = request.clientInfo - - val requestedVersion = request.protocolVersion - val protocolVersion = if (SUPPORTED_PROTOCOL_VERSIONS.contains(requestedVersion)) { - requestedVersion - } else { - logger.warn { "Client requested unsupported protocol version $requestedVersion, falling back to $LATEST_PROTOCOL_VERSION" } - LATEST_PROTOCOL_VERSION - } - - return InitializeResult( - protocolVersion = protocolVersion, - capabilities = capabilities, - serverInfo = serverInfo - ) - } - private suspend fun handleListTools(): ListToolsResult { val toolList = tools.values.map { it.tool } return ListToolsResult(tools = toolList, nextCursor = null) @@ -669,6 +765,33 @@ public open class Server( return ListResourceTemplatesResult(listOf()) } + + @Deprecated( + "Will be removed with Protocol inheritance. Use session.handleInitialize instead.", + ReplaceWith("session.handleInitialize"), + DeprecationLevel.WARNING + ) + private suspend fun handleInitialize(request: InitializeRequest): InitializeResult { + logger.info { "Handling initialize request from client ${request.clientInfo}" } + clientCapabilities = request.capabilities + clientVersion = request.clientInfo + + val requestedVersion = request.protocolVersion + val protocolVersion = if (SUPPORTED_PROTOCOL_VERSIONS.contains(requestedVersion)) { + requestedVersion + } else { + logger.warn { "Client requested unsupported protocol version $requestedVersion, falling back to $LATEST_PROTOCOL_VERSION" } + LATEST_PROTOCOL_VERSION + } + + return InitializeResult( + protocolVersion = protocolVersion, + capabilities = serverOptions.capabilities, + serverInfo = serverInfo + ) + } + + /** * Asserts that the client supports the capability required for the given [method]. * @@ -677,6 +800,11 @@ public open class Server( * * @param method The method for which we are asserting capability. */ + @Deprecated( + "Will be removed with Protocol inheritance. Use session.assertCapabilityForMethod instead.", + ReplaceWith("session.assertCapabilityForMethod"), + DeprecationLevel.WARNING + ) override fun assertCapabilityForMethod(method: Method) { logger.trace { "Asserting capability for method: ${method.value}" } when (method.value) { @@ -712,11 +840,16 @@ public open class Server( * * @param method The notification method. */ + @Deprecated( + "Will be removed with Protocol inheritance. Use session.assertNotificationCapability instead.", + ReplaceWith("session.assertNotificationCapability"), + DeprecationLevel.WARNING + ) override fun assertNotificationCapability(method: Method) { logger.trace { "Asserting notification capability for method: ${method.value}" } when (method.value) { "notifications/message" -> { - if (capabilities.logging == null) { + if (serverOptions.capabilities.logging == null) { logger.error { "Server capability assertion failed: logging not supported" } throw IllegalStateException("Server does not support logging (required for ${method.value})") } @@ -724,19 +857,19 @@ public open class Server( "notifications/resources/updated", "notifications/resources/list_changed" -> { - if (capabilities.resources == null) { + if (serverOptions.capabilities.resources == null) { throw IllegalStateException("Server does not support notifying about resources (required for ${method.value})") } } "notifications/tools/list_changed" -> { - if (capabilities.tools == null) { + if (serverOptions.capabilities.tools == null) { throw IllegalStateException("Server does not support notifying of tool list changes (required for ${method.value})") } } "notifications/prompts/list_changed" -> { - if (capabilities.prompts == null) { + if (serverOptions.capabilities.prompts == null) { throw IllegalStateException("Server does not support notifying of prompt list changes (required for ${method.value})") } } @@ -755,25 +888,30 @@ public open class Server( * * @param method The request method. */ + @Deprecated( + "Will be removed with Protocol inheritance. Use session.assertRequestHandlerCapability instead.", + ReplaceWith("session.assertRequestHandlerCapability"), + DeprecationLevel.WARNING + ) override fun assertRequestHandlerCapability(method: Method) { logger.trace { "Asserting request handler capability for method: ${method.value}" } when (method.value) { "sampling/createMessage" -> { - if (capabilities.sampling == null) { + if (serverOptions.capabilities.sampling == null) { logger.error { "Server capability assertion failed: sampling not supported" } throw IllegalStateException("Server does not support sampling (required for $method)") } } "logging/setLevel" -> { - if (capabilities.logging == null) { + if (serverOptions.capabilities.logging == null) { throw IllegalStateException("Server does not support logging (required for $method)") } } "prompts/get", "prompts/list" -> { - if (capabilities.prompts == null) { + if (serverOptions.capabilities.prompts == null) { throw IllegalStateException("Server does not support prompts (required for $method)") } } @@ -781,14 +919,14 @@ public open class Server( "resources/list", "resources/templates/list", "resources/read" -> { - if (capabilities.resources == null) { + if (serverOptions.capabilities.resources == null) { throw IllegalStateException("Server does not support resources (required for $method)") } } "tools/call", "tools/list" -> { - if (capabilities.tools == null) { + if (serverOptions.capabilities.tools == null) { throw IllegalStateException("Server does not support tools (required for $method)") } } diff --git a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/ServerSession.kt b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/ServerSession.kt new file mode 100644 index 00000000..e69de29b diff --git a/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/SseBugReproductionTest.kt b/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/SseBugReproductionTest.kt new file mode 100644 index 00000000..e69de29b From d50c661940253f4bf4adfa93c0933d7c43c420d1 Mon Sep 17 00:00:00 2001 From: Maria Tigina Date: Fri, 25 Jul 2025 19:21:35 +0200 Subject: [PATCH 2/6] Fix websocket ktor server implementation, add test and logs --- .../kotlin/sdk/client/Client.kt | 7 +- .../sdk/client/WebSocketClientTransport.kt | 18 +- .../WebSocketMcpKtorClientExtensions.kt | 6 + .../kotlin/sdk/shared/Protocol.kt | 39 +- .../sdk/shared/WebSocketMcpTransport.kt | 11 +- .../kotlin/sdk/server/KtorServer.kt | 50 +-- .../kotlin/sdk/server/Server.kt | 2 + .../WebSocketMcpKtorServerExtensions.kt | 132 ++++++- .../sdk/server/WebSocketMcpServerTransport.kt | 5 + .../kotlin/sdk/client/SseTransportTest.kt | 8 +- .../sdk/integration/SseIntegrationTest.kt | 168 ++++++++- .../integration/WebSocketIntegrationTest.kt | 207 +++++++++++ .../kotlin/sdk/server/ServerSession.kt | 343 ++++++++++++++++++ 13 files changed, 908 insertions(+), 88 deletions(-) create mode 100644 kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/WebSocketIntegrationTest.kt create mode 100644 src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/ServerSession.kt diff --git a/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/Client.kt b/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/Client.kt index 75d0b221..a836858d 100644 --- a/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/Client.kt +++ b/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/Client.kt @@ -158,14 +158,13 @@ public open class Client( serverVersion = result.serverInfo notification(InitializedNotification()) + } catch (error: CancellationException) { + throw IllegalStateException("Error connecting to transport: ${error.message}") } catch (error: Throwable) { + logger.error(error) { "Failed to initialize client" } close() - if (error !is CancellationException) { - throw IllegalStateException("Error connecting to transport: ${error.message}") - } throw error - } } diff --git a/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/WebSocketClientTransport.kt b/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/WebSocketClientTransport.kt index 45719073..9c7fab67 100644 --- a/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/WebSocketClientTransport.kt +++ b/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/WebSocketClientTransport.kt @@ -1,15 +1,17 @@ package io.modelcontextprotocol.kotlin.sdk.client -import io.ktor.client.HttpClient -import io.ktor.client.plugins.websocket.webSocketSession -import io.ktor.client.request.HttpRequestBuilder -import io.ktor.client.request.header -import io.ktor.http.HttpHeaders -import io.ktor.websocket.WebSocketSession +import io.github.oshai.kotlinlogging.KotlinLogging +import io.ktor.client.* +import io.ktor.client.plugins.websocket.* +import io.ktor.client.request.* +import io.ktor.http.* +import io.ktor.websocket.* import io.modelcontextprotocol.kotlin.sdk.shared.MCP_SUBPROTOCOL import io.modelcontextprotocol.kotlin.sdk.shared.WebSocketMcpTransport import kotlin.properties.Delegates +private val logger = KotlinLogging.logger {} + /** * Client transport for WebSocket: this will connect to a server over the WebSocket protocol. */ @@ -21,6 +23,8 @@ public class WebSocketClientTransport( override var session: WebSocketSession by Delegates.notNull() override suspend fun initializeSession() { + logger.debug { "Websocket session initialization started..." } + session = urlString?.let { client.webSocketSession(it) { requestBuilder() @@ -32,5 +36,7 @@ public class WebSocketClientTransport( header(HttpHeaders.SecWebSocketProtocol, MCP_SUBPROTOCOL) } + + logger.debug { "Websocket session initialization finished" } } } diff --git a/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/WebSocketMcpKtorClientExtensions.kt b/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/WebSocketMcpKtorClientExtensions.kt index 9d70d6c0..7b63420a 100644 --- a/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/WebSocketMcpKtorClientExtensions.kt +++ b/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/WebSocketMcpKtorClientExtensions.kt @@ -1,11 +1,15 @@ package io.modelcontextprotocol.kotlin.sdk.client +import io.github.oshai.kotlinlogging.KotlinLogging import io.ktor.client.HttpClient import io.ktor.client.request.HttpRequestBuilder import io.modelcontextprotocol.kotlin.sdk.Implementation import io.modelcontextprotocol.kotlin.sdk.LIB_VERSION import io.modelcontextprotocol.kotlin.sdk.shared.IMPLEMENTATION_NAME +private val logger = KotlinLogging.logger {} + + /** * Returns a new WebSocket transport for the Model Context Protocol using the provided HttpClient. * @@ -36,6 +40,8 @@ public suspend fun HttpClient.mcpWebSocket( version = LIB_VERSION ) ) + logger.debug { "Client started to connect to server" } client.connect(transport) + logger.debug { "Client finished to connect to server" } return client } diff --git a/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/Protocol.kt b/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/Protocol.kt index b5f15751..2f85dd6f 100644 --- a/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/Protocol.kt +++ b/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/Protocol.kt @@ -42,7 +42,7 @@ import kotlin.reflect.typeOf import kotlin.time.Duration import kotlin.time.Duration.Companion.milliseconds -private val LOGGER = KotlinLogging.logger { } +private val logger = KotlinLogging.logger { } public const val IMPLEMENTATION_NAME: String = "mcp-ktor" @@ -204,6 +204,7 @@ public abstract class Protocol( } } + logger.info { "Starting transport" } return transport.start() } @@ -221,29 +222,29 @@ public abstract class Protocol( } private suspend fun onNotification(notification: JSONRPCNotification) { - LOGGER.trace { "Received notification: ${notification.method}" } + logger.trace { "Received notification: ${notification.method}" } val handler = notificationHandlers[notification.method] ?: fallbackNotificationHandler if (handler == null) { - LOGGER.trace { "No handler found for notification: ${notification.method}" } + logger.trace { "No handler found for notification: ${notification.method}" } return } try { handler(notification) } catch (cause: Throwable) { - LOGGER.error(cause) { "Error handling notification: ${notification.method}" } + logger.error(cause) { "Error handling notification: ${notification.method}" } onError(cause) } } private suspend fun onRequest(request: JSONRPCRequest) { - LOGGER.trace { "Received request: ${request.method} (id: ${request.id})" } + logger.trace { "Received request: ${request.method} (id: ${request.id})" } val handler = requestHandlers[request.method] ?: fallbackRequestHandler if (handler === null) { - LOGGER.trace { "No handler found for request: ${request.method}" } + logger.trace { "No handler found for request: ${request.method}" } try { transport?.send( JSONRPCResponse( @@ -255,7 +256,7 @@ public abstract class Protocol( ) ) } catch (cause: Throwable) { - LOGGER.error(cause) { "Error sending method not found response" } + logger.error(cause) { "Error sending method not found response" } onError(cause) } return @@ -263,7 +264,7 @@ public abstract class Protocol( try { val result = handler(request, RequestHandlerExtra()) - LOGGER.trace { "Request handled successfully: ${request.method} (id: ${request.id})" } + logger.trace { "Request handled successfully: ${request.method} (id: ${request.id})" } transport?.send( JSONRPCResponse( @@ -273,7 +274,7 @@ public abstract class Protocol( ) } catch (cause: Throwable) { - LOGGER.error(cause) { "Error handling request: ${request.method} (id: ${request.id})" } + logger.error(cause) { "Error handling request: ${request.method} (id: ${request.id})" } try { transport?.send( @@ -286,14 +287,14 @@ public abstract class Protocol( ) ) } catch (sendError: Throwable) { - LOGGER.error(sendError) { "Failed to send error response for request: ${request.method} (id: ${request.id})" } + logger.error(sendError) { "Failed to send error response for request: ${request.method} (id: ${request.id})" } // Optionally implement fallback behavior here } } } private fun onProgress(notification: ProgressNotification) { - LOGGER.trace { "Received progress notification: token=${notification.params.progressToken}, progress=${notification.params.progress}/${notification.params.total}" } + logger.trace { "Received progress notification: token=${notification.params.progressToken}, progress=${notification.params.progress}/${notification.params.total}" } val progress = notification.params.progress val total = notification.params.total val message = notification.params.message @@ -304,7 +305,7 @@ public abstract class Protocol( val error = Error( "Received a progress notification for an unknown token: ${McpJson.encodeToString(notification)}", ) - LOGGER.error { error.message } + logger.error { error.message } onError(error) return } @@ -382,9 +383,9 @@ public abstract class Protocol( request: Request, options: RequestOptions? = null, ): T { - LOGGER.trace { "Sending request: ${request.method}" } + logger.trace { "Sending request: ${request.method}" } val result = CompletableDeferred() - val transport = this@Protocol.transport ?: throw Error("Not connected") + val transport = transport ?: throw Error("Not connected") if (this@Protocol.options?.enforceStrictCapabilities == true) { assertCapabilityForMethod(request.method) @@ -394,7 +395,7 @@ public abstract class Protocol( val messageId = message.id if (options?.onProgress != null) { - LOGGER.trace { "Registering progress handler for request id: $messageId" } + logger.trace { "Registering progress handler for request id: $messageId" } _progressHandlers.update { current -> current.put(messageId, options.onProgress) } @@ -427,7 +428,7 @@ public abstract class Protocol( val notification = CancelledNotification( params = CancelledNotification.Params( - requestId = messageId, + requestId = messageId, reason = reason.message ?: "Unknown" ) ) @@ -444,12 +445,12 @@ public abstract class Protocol( val timeout = options?.timeout ?: DEFAULT_REQUEST_TIMEOUT try { withTimeout(timeout) { - LOGGER.trace { "Sending request message with id: $messageId" } + logger.trace { "Sending request message with id: $messageId" } this@Protocol.transport?.send(message) } return result.await() } catch (cause: TimeoutCancellationException) { - LOGGER.error { "Request timed out after ${timeout.inWholeMilliseconds}ms: ${request.method}" } + logger.error { "Request timed out after ${timeout.inWholeMilliseconds}ms: ${request.method}" } cancel( McpError( ErrorCode.Defined.RequestTimeout.code, @@ -466,7 +467,7 @@ public abstract class Protocol( * Emits a notification, which is a one-way message that does not expect a response. */ public suspend fun notification(notification: Notification) { - LOGGER.trace { "Sending notification: ${notification.method}" } + logger.trace { "Sending notification: ${notification.method}" } val transport = this.transport ?: error("Not connected") assertNotificationCapability(notification.method) diff --git a/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/WebSocketMcpTransport.kt b/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/WebSocketMcpTransport.kt index 29e7b866..3eea2d0f 100644 --- a/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/WebSocketMcpTransport.kt +++ b/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/WebSocketMcpTransport.kt @@ -1,5 +1,6 @@ package io.modelcontextprotocol.kotlin.sdk.shared +import io.github.oshai.kotlinlogging.KotlinLogging import io.ktor.websocket.Frame import io.ktor.websocket.WebSocketSession import io.ktor.websocket.close @@ -17,6 +18,9 @@ import kotlin.concurrent.atomics.ExperimentalAtomicApi public const val MCP_SUBPROTOCOL: String = "mcp" +private val logger = KotlinLogging.logger {} + + /** * Abstract class representing a WebSocket transport for the Model Context Protocol (MCP). * Handles communication over a WebSocket session. @@ -40,6 +44,8 @@ public abstract class WebSocketMcpTransport : AbstractTransport() { protected abstract suspend fun initializeSession() override suspend fun start() { + logger.debug { "Starting websocket transport" } + if (!initialized.compareAndSet(expectedValue = false, newValue = true)) { error( "WebSocketClientTransport already started! " + @@ -53,7 +59,8 @@ public abstract class WebSocketMcpTransport : AbstractTransport() { while (true) { val message = try { session.incoming.receive() - } catch (_: ClosedReceiveChannelException) { + } catch (e: ClosedReceiveChannelException) { + logger.debug { "Closed receive channel, exiting" } return@launch } @@ -84,6 +91,7 @@ public abstract class WebSocketMcpTransport : AbstractTransport() { } override suspend fun send(message: JSONRPCMessage) { + logger.debug { "Sending message" } if (!initialized.load()) { error("Not connected") } @@ -96,6 +104,7 @@ public abstract class WebSocketMcpTransport : AbstractTransport() { error("Not connected") } + logger.debug { "Closing websocket session" } session.close() session.coroutineContext.job.join() } diff --git a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/KtorServer.kt b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/KtorServer.kt index 882d4500..729636a8 100644 --- a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/KtorServer.kt +++ b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/KtorServer.kt @@ -8,7 +8,6 @@ import io.ktor.server.response.respond import io.ktor.server.routing.Routing import io.ktor.server.routing.RoutingContext import io.ktor.server.routing.post -import io.ktor.server.routing.route import io.ktor.server.routing.routing import io.ktor.server.sse.SSE import io.ktor.server.sse.ServerSSESession @@ -18,39 +17,40 @@ import kotlinx.atomicfu.AtomicRef import kotlinx.atomicfu.atomic import kotlinx.atomicfu.update import kotlinx.collections.immutable.PersistentMap -import kotlinx.collections.immutable.persistentMapOf +import kotlinx.collections.immutable.toPersistentMap private val logger = KotlinLogging.logger {} -@KtorDsl -public fun Routing.mcp(path: String, block: () -> Server) { - route(path) { - mcp(block) +internal class SseTransportManager(transports: Map = emptyMap()) { + private val transports: AtomicRef> = atomic(transports.toPersistentMap()) + + fun getTransport(sessionId: String): SseServerTransport? = transports.value[sessionId] + + fun addTransport(transport: SseServerTransport) { + transports.update { it.put(transport.sessionId, transport) } + } + + fun removeTransport(sessionId: String) { + transports.update { it.remove(sessionId) } } } -/** - * Configures the Ktor Application to handle Model Context Protocol (MCP) over Server-Sent Events (SSE). - */ +/* +* Configures the Ktor Application to handle Model Context Protocol (MCP) over Server-Sent Events (SSE). +*/ @KtorDsl public fun Routing.mcp(block: () -> Server) { - val transports = atomic(persistentMapOf()) + val sseTransportManager = SseTransportManager() sse { - mcpSseEndpoint("", transports, block) + mcpSseEndpoint("", sseTransportManager, block) } post { - mcpPostEndpoint(transports) + mcpPostEndpoint(sseTransportManager) } } -@Suppress("FunctionName") -@Deprecated("Use mcp() instead", ReplaceWith("mcp(block)"), DeprecationLevel.WARNING) -public fun Application.MCP(block: () -> Server) { - mcp(block) -} - @KtorDsl public fun Application.mcp(block: () -> Server) { install(SSE) @@ -62,16 +62,16 @@ public fun Application.mcp(block: () -> Server) { internal suspend fun ServerSSESession.mcpSseEndpoint( postEndpoint: String, - transports: AtomicRef>, + sseTransportManager: SseTransportManager, block: () -> Server, ) { - val transport = mcpSseTransport(postEndpoint, transports) + val transport = mcpSseTransport(postEndpoint, sseTransportManager) val server = block() server.onClose { logger.info { "Server connection closed for sessionId: ${transport.sessionId}" } - transports.update { it.remove(transport.sessionId) } + sseTransportManager.removeTransport(transport.sessionId) } server.connectSession(transport) @@ -81,17 +81,17 @@ internal suspend fun ServerSSESession.mcpSseEndpoint( internal fun ServerSSESession.mcpSseTransport( postEndpoint: String, - transports: AtomicRef>, + sseTransportManager: SseTransportManager, ): SseServerTransport { val transport = SseServerTransport(postEndpoint, this) - transports.update { it.put(transport.sessionId, transport) } + sseTransportManager.addTransport(transport) logger.info { "New SSE connection established and stored with sessionId: ${transport.sessionId}" } return transport } internal suspend fun RoutingContext.mcpPostEndpoint( - transports: AtomicRef>, + sseTransportManager: SseTransportManager, ) { val sessionId: String = call.request.queryParameters["sessionId"] ?: run { @@ -101,7 +101,7 @@ internal suspend fun RoutingContext.mcpPostEndpoint( logger.debug { "Received message for sessionId: $sessionId" } - val transport = transports.value[sessionId] + val transport = sseTransportManager.getTransport(sessionId) if (transport == null) { logger.warn { "Session not found for sessionId: $sessionId" } call.respond(HttpStatusCode.NotFound, "Session not found") diff --git a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/Server.kt b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/Server.kt index 205c59a1..f4a8ab5c 100644 --- a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/Server.kt +++ b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/Server.kt @@ -191,7 +191,9 @@ public open class Server( } } + logger.debug { "Server session connecting to transport" } session.connect(transport) + logger.debug { "Server session successfully connected to transport" } sessions.update { it.add(session) } _onConnect() diff --git a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/WebSocketMcpKtorServerExtensions.kt b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/WebSocketMcpKtorServerExtensions.kt index 9301749b..6d93a87d 100644 --- a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/WebSocketMcpKtorServerExtensions.kt +++ b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/WebSocketMcpKtorServerExtensions.kt @@ -1,12 +1,93 @@ package io.modelcontextprotocol.kotlin.sdk.server -import io.ktor.server.routing.Route -import io.ktor.server.websocket.WebSocketServerSession -import io.ktor.server.websocket.webSocket +import io.github.oshai.kotlinlogging.KotlinLogging +import io.ktor.server.application.* +import io.ktor.server.routing.* +import io.ktor.server.websocket.* +import io.ktor.utils.io.* import io.modelcontextprotocol.kotlin.sdk.Implementation import io.modelcontextprotocol.kotlin.sdk.LIB_VERSION import io.modelcontextprotocol.kotlin.sdk.ServerCapabilities import io.modelcontextprotocol.kotlin.sdk.shared.IMPLEMENTATION_NAME +import kotlinx.coroutines.awaitCancellation + +private val logger = KotlinLogging.logger {} + +/** + * Configures the Ktor Application to handle Model Context Protocol (MCP) over WebSocket. + */ +@KtorDsl +public fun Routing.mcpWebSocket( + block: () -> Server +) { + webSocket { + mcpWebSocketEndpoint(block) + } +} + +/** + * Configures the Ktor Application to handle Model Context Protocol (MCP) over WebSocket. + */ +@KtorDsl +public fun Routing.mcpWebSocket( + path: String, + block: () -> Server +) { + + webSocket(path) { + mcpWebSocketEndpoint(block) + } +} + +/** + * Configures the Ktor Application to handle Model Context Protocol (MCP) over WebSocket. + */ +@KtorDsl +public fun Application.mcpWebSocket( + block: () -> Server +) { + install(WebSockets) + + routing { + mcpWebSocket(block) + } +} + +/** + * Configures the Ktor Application to handle Model Context Protocol (MCP) over WebSocket at the specified path. + */ +@KtorDsl +public fun Application.mcpWebSocket( + path: String, + block: () -> Server +) { + install(WebSockets) + + routing { + mcpWebSocket(path, block) + } +} + +internal suspend fun WebSocketServerSession.mcpWebSocketEndpoint( + block: () -> Server +) { + logger.info { "Ktor Server establishing new connection" } + val transport = createMcpTransport(this) + val server = block() + var session: ServerSession? = null + try { + session = server.connectSession(transport) + awaitCancellation() + } catch (e: CancellationException) { + session?.close() + } +} + +private fun createMcpTransport( + webSocketSession: WebSocketServerSession, +): WebSocketMcpServerTransport { + return WebSocketMcpServerTransport(webSocketSession) +} /** * Registers a WebSocket route that establishes an MCP (Model Context Protocol) server session. @@ -14,6 +95,11 @@ import io.modelcontextprotocol.kotlin.sdk.shared.IMPLEMENTATION_NAME * @param options Optional server configuration settings for the MCP server. * @param handler A suspend function that defines the server's behavior. */ +@Deprecated( + "Use mcpWebSocket with a lambda that returns a Server instance instead", + ReplaceWith("mcpWebSocket { Server(Implementation(name = IMPLEMENTATION_NAME, version = LIB_VERSION), options ?: ServerOptions(capabilities = ServerCapabilities())) }"), + DeprecationLevel.WARNING +) public fun Route.mcpWebSocket( options: ServerOptions? = null, handler: suspend Server.() -> Unit = {}, @@ -23,6 +109,19 @@ public fun Route.mcpWebSocket( } } +@Deprecated( + "Use mcpWebSocket with a lambda that returns a Server instance instead", + ReplaceWith("mcpWebSocket(block)"), + DeprecationLevel.WARNING +) +public fun Route.mcpWebSocket( + block: () -> Server +) { + webSocket { + block().connect(createMcpTransport(this)) + } +} + /** * Registers a WebSocket route at the specified [path] that establishes an MCP server session. * @@ -30,6 +129,11 @@ public fun Route.mcpWebSocket( * @param options Optional server configuration settings for the MCP server. * @param handler A suspend function that defines the server's behavior. */ +@Deprecated( + "Use mcpWebSocket with a path and a lambda that returns a Server instance instead", + ReplaceWith("mcpWebSocket(path) { Server(Implementation(name = IMPLEMENTATION_NAME, version = LIB_VERSION), options ?: ServerOptions(capabilities = ServerCapabilities())) }"), + DeprecationLevel.WARNING +) public fun Route.mcpWebSocket( path: String, options: ServerOptions? = null, @@ -45,6 +149,11 @@ public fun Route.mcpWebSocket( * * @param handler A suspend function that defines the behavior of the transport layer. */ +@Deprecated( + "Use mcpWebSocket with a lambda that returns a Server instance instead", + ReplaceWith("mcpWebSocket { Server(Implementation(name = IMPLEMENTATION_NAME, version = LIB_VERSION), ServerOptions(capabilities = ServerCapabilities())) }"), + DeprecationLevel.WARNING +) public fun Route.mcpWebSocketTransport( handler: suspend WebSocketMcpServerTransport.() -> Unit = {}, ) { @@ -62,6 +171,11 @@ public fun Route.mcpWebSocketTransport( * @param path The URL path at which to register the WebSocket route. * @param handler A suspend function that defines the behavior of the transport layer. */ +@Deprecated( + "Use mcpWebSocket with a path and a lambda that returns a Server instance instead", + ReplaceWith("mcpWebSocket(path) { Server(Implementation(name = IMPLEMENTATION_NAME, version = LIB_VERSION), ServerOptions(capabilities = ServerCapabilities())) }"), + DeprecationLevel.WARNING +) public fun Route.mcpWebSocketTransport( path: String, handler: suspend WebSocketMcpServerTransport.() -> Unit = {}, @@ -74,7 +188,11 @@ public fun Route.mcpWebSocketTransport( } } - +@Deprecated( + "Use mcpWebSocket with a lambda that returns a Server instance instead", + ReplaceWith("mcpWebSocket { Server(Implementation(name = IMPLEMENTATION_NAME, version = LIB_VERSION), options ?: ServerOptions(capabilities = ServerCapabilities())) }"), + DeprecationLevel.WARNING +) private suspend fun Route.createMcpServer( session: WebSocketServerSession, options: ServerOptions?, @@ -100,9 +218,3 @@ private suspend fun Route.createMcpServer( handler(server) server.close() } - -private fun createMcpTransport( - session: WebSocketServerSession, -): WebSocketMcpServerTransport { - return WebSocketMcpServerTransport(session) -} diff --git a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/WebSocketMcpServerTransport.kt b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/WebSocketMcpServerTransport.kt index 45cb4df9..35885cb5 100644 --- a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/WebSocketMcpServerTransport.kt +++ b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/WebSocketMcpServerTransport.kt @@ -1,10 +1,14 @@ package io.modelcontextprotocol.kotlin.sdk.server +import io.github.oshai.kotlinlogging.KotlinLogging import io.ktor.http.HttpHeaders import io.ktor.server.websocket.WebSocketServerSession import io.modelcontextprotocol.kotlin.sdk.shared.MCP_SUBPROTOCOL import io.modelcontextprotocol.kotlin.sdk.shared.WebSocketMcpTransport + +private val logger = KotlinLogging.logger {} + /** * Server-side implementation of the MCP (Model Context Protocol) transport over WebSocket. * @@ -14,6 +18,7 @@ public class WebSocketMcpServerTransport( override val session: WebSocketServerSession, ) : WebSocketMcpTransport() { override suspend fun initializeSession() { + logger.debug { "Checking session headers" } val subprotocol = session.call.request.headers[HttpHeaders.SecWebSocketProtocol] if (subprotocol != MCP_SUBPROTOCOL) { error("Invalid subprotocol: $subprotocol, expected $MCP_SUBPROTOCOL") diff --git a/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/SseTransportTest.kt b/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/SseTransportTest.kt index 23ddadf1..ce110b3a 100644 --- a/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/SseTransportTest.kt +++ b/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/SseTransportTest.kt @@ -71,7 +71,7 @@ class SseTransportTest : BaseTransportTest() { routing { mcp { mcpServer } // sse { -// mcpSseTransport("", transports).apply { +// mcpSseTransport("", transportManager).apply { // onMessage { // send(it) // } @@ -81,7 +81,7 @@ class SseTransportTest : BaseTransportTest() { // } // // post { -// mcpPostEndpoint(transports) +// mcpPostEndpoint(transportManager) // } } }.startSuspend(wait = false) @@ -113,7 +113,7 @@ class SseTransportTest : BaseTransportTest() { mcp("/sse") { mcpServer } // route("/sse") { // sse { -// mcpSseTransport("", transports).apply { +// mcpSseTransport("", transportManager).apply { // onMessage { // send(it) // } @@ -123,7 +123,7 @@ class SseTransportTest : BaseTransportTest() { // } // // post { -// mcpPostEndpoint(transports) +// mcpPostEndpoint(transportManager) // } // } } diff --git a/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/SseIntegrationTest.kt b/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/SseIntegrationTest.kt index 19d84589..cf5f0fe1 100644 --- a/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/SseIntegrationTest.kt +++ b/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/SseIntegrationTest.kt @@ -2,68 +2,198 @@ package io.modelcontextprotocol.kotlin.sdk.integration import io.ktor.client.HttpClient import io.ktor.client.plugins.sse.SSE +import io.ktor.server.application.ApplicationStopped import io.ktor.server.application.install import io.ktor.server.cio.CIOApplicationEngine import io.ktor.server.engine.EmbeddedServer import io.ktor.server.engine.embeddedServer import io.ktor.server.routing.routing +import io.modelcontextprotocol.kotlin.sdk.GetPromptRequest +import io.modelcontextprotocol.kotlin.sdk.GetPromptResult import io.modelcontextprotocol.kotlin.sdk.Implementation +import io.modelcontextprotocol.kotlin.sdk.PromptArgument +import io.modelcontextprotocol.kotlin.sdk.PromptMessage +import io.modelcontextprotocol.kotlin.sdk.Role import io.modelcontextprotocol.kotlin.sdk.ServerCapabilities +import io.modelcontextprotocol.kotlin.sdk.TextContent import io.modelcontextprotocol.kotlin.sdk.client.Client -import io.modelcontextprotocol.kotlin.sdk.client.mcpSse +import io.modelcontextprotocol.kotlin.sdk.client.mcpSseTransport import io.modelcontextprotocol.kotlin.sdk.server.Server import io.modelcontextprotocol.kotlin.sdk.server.ServerOptions import io.modelcontextprotocol.kotlin.sdk.server.mcp import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.test.runTest import kotlinx.coroutines.withContext +import kotlinx.coroutines.withTimeout import kotlin.test.Test -import kotlin.test.fail +import kotlin.test.assertTrue import io.ktor.client.engine.cio.CIO as ClientCIO import io.ktor.server.cio.CIO as ServerCIO class SseIntegrationTest { @Test fun `client should be able to connect to sse server`() = runTest { - val serverEngine = initServer() + var server: EmbeddedServer? = null var client: Client? = null + try { withContext(Dispatchers.Default) { - assertDoesNotThrow { client = initClient() } + withTimeout(1000) { + server = initServer() + client = initClient() + } } - } catch (e: Exception) { - fail("Failed to connect client: $e") } finally { client?.close() - // Make sure to stop the server - serverEngine.stopSuspend(1000, 2000) + server?.stop(1000, 2000) } } - private inline fun assertDoesNotThrow(block: () -> T): T { - return try { - block() - } catch (e: Throwable) { - fail("Expected no exception, but got: $e") + /** + * Test Case #1: One opened connection, a client gets a prompt + * + * 1. Open SSE from Client A. + * 2. Send a POST request from Client A to POST /prompts/get. + * 3. Observe that Client A receives a response related to it. + */ + @Test + fun `single sse connection`() = runTest { + var server: EmbeddedServer? = null + var client: Client? = null + + try { + withContext(Dispatchers.Default) { + withTimeout(1000) { + server = initServer() + client = initClient("Client A") + + val promptA = getPrompt(client, "Client A") + assertTrue { "Client A" in promptA } + } + } + } finally { + client?.close() + server?.stop(1000, 2000) } } - private suspend fun initClient(): Client { - return HttpClient(ClientCIO) { install(SSE) }.mcpSse("http://$URL:$PORT") + /** + * Test Case #1: Two open connections, each client gets a client-specific prompt + * + * 1. Open SSE connection #1 from Client A and note the sessionId= value. + * 2. Open SSE connection #2 from Client B and note the sessionId= value. + * 3. Send a POST request to POST /message with the corresponding sessionId#1. + * 4. Observe that Client B (connection #2) receives a response related to sessionId#1. + */ + @Test + fun `multiple sse connections`() = runTest { + var server: EmbeddedServer? = null + var clientA: Client? = null + var clientB: Client? = null + + try { + withContext(Dispatchers.Default) { + withTimeout(1000) { + server = initServer() + clientA = initClient("Client A") + clientB = initClient("Client B") + + // Step 3: Send a prompt request from Client A + val promptA = getPrompt(clientA, "Client A") + // Step 4: Send a prompt request from Client B + val promptB = getPrompt(clientB, "Client B") + + assertTrue { "Client A" in promptA } + assertTrue { "Client B" in promptB } + } + } + } finally { + clientA?.close() + clientB?.close() + server?.stop(1000, 2000) + } + } + + private suspend fun initClient(name: String = ""): Client { + val client = Client( + Implementation(name = name, version = "1.0.0") + ) + + val httpClient = HttpClient(ClientCIO) { + install(SSE) + } + + // Create a transport wrapper that captures the session ID and received messages + val transport = httpClient.mcpSseTransport { + url { + host = URL + port = PORT + } + } + + client.connect(transport) + + return client } private suspend fun initServer(): EmbeddedServer { val server = Server( - Implementation(name = "sse-e2e-test", version = "1.0.0"), - ServerOptions(capabilities = ServerCapabilities()), + Implementation(name = "sse-server", version = "1.0.0"), + ServerOptions( + capabilities = + ServerCapabilities(prompts = ServerCapabilities.Prompts(listChanged = true)) + ), ) - return embeddedServer(ServerCIO, host = URL, port = PORT) { + server.addPrompt( + name = "prompt", + description = "Prompt description", + arguments = listOf( + PromptArgument( + name = "client", + description = "Client name who requested a prompt", + required = true + ) + ) + ) { request -> + GetPromptResult( + "Prompt for ${request.name}", + messages = listOf( + PromptMessage( + role = Role.user, + content = TextContent("Prompt for client ${request.arguments?.get("client")}") + ) + ) + ) + } + + val ktorServer = embeddedServer(ServerCIO, host = URL, port = PORT) { install(io.ktor.server.sse.SSE) routing { mcp { server } } - }.startSuspend(wait = false) + } + + ktorServer.monitor.subscribe(ApplicationStopped) { + println("SD -- [T] ktor server has been stopped") + } + + return ktorServer.startSuspend(wait = false) + } + + /** + * Retrieves a prompt result using the provided client and client name. + */ + private suspend fun getPrompt(client: Client, clientName: String): String { + val response = client.getPrompt( + GetPromptRequest( + "prompt", + arguments = mapOf("client" to clientName) + ) + ) + + return (response?.messages?.first()?.content as? TextContent)?.text + ?: error("Failed to receive prompt for Client $clientName") } companion object { diff --git a/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/WebSocketIntegrationTest.kt b/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/WebSocketIntegrationTest.kt new file mode 100644 index 00000000..4e55f695 --- /dev/null +++ b/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/WebSocketIntegrationTest.kt @@ -0,0 +1,207 @@ +package io.modelcontextprotocol.kotlin.sdk.integration + +import io.ktor.client.HttpClient +import io.ktor.server.application.ApplicationStopped +import io.ktor.server.application.install +import io.ktor.server.cio.CIOApplicationEngine +import io.ktor.server.engine.EmbeddedServer +import io.ktor.server.engine.embeddedServer +import io.ktor.server.routing.routing +import io.modelcontextprotocol.kotlin.sdk.GetPromptRequest +import io.modelcontextprotocol.kotlin.sdk.GetPromptResult +import io.modelcontextprotocol.kotlin.sdk.Implementation +import io.modelcontextprotocol.kotlin.sdk.PromptArgument +import io.modelcontextprotocol.kotlin.sdk.PromptMessage +import io.modelcontextprotocol.kotlin.sdk.Role +import io.modelcontextprotocol.kotlin.sdk.ServerCapabilities +import io.modelcontextprotocol.kotlin.sdk.TextContent +import io.modelcontextprotocol.kotlin.sdk.client.Client +import io.modelcontextprotocol.kotlin.sdk.client.mcpWebSocketTransport +import io.modelcontextprotocol.kotlin.sdk.server.Server +import io.modelcontextprotocol.kotlin.sdk.server.ServerOptions +import io.modelcontextprotocol.kotlin.sdk.server.mcpWebSocket +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.test.runTest +import kotlinx.coroutines.withContext +import kotlinx.coroutines.withTimeout +import kotlin.test.Test +import kotlin.test.assertTrue +import io.ktor.client.engine.cio.CIO as ClientCIO +import io.ktor.server.cio.CIO as ServerCIO +import io.ktor.server.websocket.WebSockets as ServerWebSockets +import io.ktor.client.plugins.websocket.WebSockets as ClientWebSocket + +class WebSocketIntegrationTest { + + @Test + fun `client should be able to connect to websocket server 2`() = runTest { + var server: EmbeddedServer? = null + var client: Client? = null + + try { + withContext(Dispatchers.Default) { + withTimeout(1000) { + server = initServer() + client = initClient() + } + } + } finally { + client?.close() + server?.stop(1000, 2000) + } + } + + /** + * Test Case #1: One opened connection, a client gets a prompt + * + * 1. Open WebSocket from Client A. + * 2. Send a POST request from Client A to POST /prompts/get. + * 3. Observe that Client A receives a response related to it. + */ + @Test + fun `single websocket connection`() = runTest { + var server: EmbeddedServer? = null + var client: Client? = null + + try { + withContext(Dispatchers.Default) { + withTimeout(1000) { + server = initServer() + client = initClient() + + val promptA = getPrompt(client, "Client A") + assertTrue { "Client A" in promptA } + } + } + } finally { + client?.close() + server?.stop(1000, 2000) + } + } + + /** + * Test Case #1: Two open connections, each client gets a client-specific prompt + * + * 1. Open WebSocket connection #1 from Client A and note the sessionId= value. + * 2. Open WebSocket connection #2 from Client B and note the sessionId= value. + * 3. Send a POST request to POST /message with the corresponding sessionId#1. + * 4. Observe that Client B (connection #2) receives a response related to sessionId#1. + */ + @Test + fun `multiple websocket connections`() = runTest { + var server: EmbeddedServer? = null + var clientA: Client? = null + var clientB: Client? = null + + try { + withContext(Dispatchers.Default) { + withTimeout(1000) { + server = initServer() + clientA = initClient() + clientB = initClient() + + // Step 3: Send a prompt request from Client A + val promptA = getPrompt(clientA, "Client A") + // Step 4: Send a prompt request from Client B + val promptB = getPrompt(clientB, "Client B") + + assertTrue { "Client A" in promptA } + assertTrue { "Client B" in promptB } + } + } + } finally { + clientA?.close() + clientB?.close() + server?.stop(1000, 2000) + } + } + + + private suspend fun initClient(name: String = ""): Client { + val client = Client( + Implementation(name = name, version = "1.0.0") + ) + + val httpClient = HttpClient(ClientCIO) { + install(ClientWebSocket) + } + + // Create a transport wrapper that captures the session ID and received messages + val transport = httpClient.mcpWebSocketTransport { + url { + host = URL + port = PORT + } + } + + client.connect(transport) + + return client + } + + + private suspend fun initServer(): EmbeddedServer { + val server = Server( + Implementation(name = "websocket-server", version = "1.0.0"), + ServerOptions( + capabilities = + ServerCapabilities(prompts = ServerCapabilities.Prompts(listChanged = true)) + ), + ) + + server.addPrompt( + name = "prompt", + description = "Prompt description", + arguments = listOf( + PromptArgument( + name = "client", + description = "Client name who requested a prompt", + required = true + ) + ) + ) { request -> + GetPromptResult( + "Prompt for ${request.name}", + messages = listOf( + PromptMessage( + role = Role.user, + content = TextContent("Prompt for client ${request.arguments?.get("client")}") + ) + ) + ) + } + + val ktorServer = embeddedServer(ServerCIO, host = URL, port = PORT) { + install(ServerWebSockets) + routing { + mcpWebSocket(block = { server }) + } + } + + ktorServer.monitor.subscribe(ApplicationStopped) { + println("SD -- [T] ktor server has been stopped") + } + + return ktorServer.startSuspend(wait = false) + } + + /** + * Retrieves a prompt result using the provided client and client name. + */ + private suspend fun getPrompt(client: Client, clientName: String): String { + val response = client.getPrompt( + GetPromptRequest( + "prompt", + arguments = mapOf("client" to clientName) + ) + ) + + return (response?.messages?.first()?.content as? TextContent)?.text + ?: error("Failed to receive prompt for Client $clientName") + } + + companion object { + private const val PORT = 3002 + private const val URL = "localhost" + } +} \ No newline at end of file diff --git a/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/ServerSession.kt b/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/ServerSession.kt new file mode 100644 index 00000000..2c430991 --- /dev/null +++ b/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/ServerSession.kt @@ -0,0 +1,343 @@ +package io.modelcontextprotocol.kotlin.sdk.server + +import io.github.oshai.kotlinlogging.KotlinLogging +import io.modelcontextprotocol.kotlin.sdk.ClientCapabilities +import io.modelcontextprotocol.kotlin.sdk.CreateElicitationRequest +import io.modelcontextprotocol.kotlin.sdk.CreateElicitationRequest.RequestedSchema +import io.modelcontextprotocol.kotlin.sdk.CreateElicitationResult +import io.modelcontextprotocol.kotlin.sdk.CreateMessageRequest +import io.modelcontextprotocol.kotlin.sdk.CreateMessageResult +import io.modelcontextprotocol.kotlin.sdk.EmptyJsonObject +import io.modelcontextprotocol.kotlin.sdk.EmptyRequestResult +import io.modelcontextprotocol.kotlin.sdk.Implementation +import io.modelcontextprotocol.kotlin.sdk.InitializeRequest +import io.modelcontextprotocol.kotlin.sdk.InitializeResult +import io.modelcontextprotocol.kotlin.sdk.InitializedNotification +import io.modelcontextprotocol.kotlin.sdk.LATEST_PROTOCOL_VERSION +import io.modelcontextprotocol.kotlin.sdk.ListRootsRequest +import io.modelcontextprotocol.kotlin.sdk.ListRootsResult +import io.modelcontextprotocol.kotlin.sdk.LoggingMessageNotification +import io.modelcontextprotocol.kotlin.sdk.Method +import io.modelcontextprotocol.kotlin.sdk.PingRequest +import io.modelcontextprotocol.kotlin.sdk.PromptListChangedNotification +import io.modelcontextprotocol.kotlin.sdk.ResourceListChangedNotification +import io.modelcontextprotocol.kotlin.sdk.ResourceUpdatedNotification +import io.modelcontextprotocol.kotlin.sdk.SUPPORTED_PROTOCOL_VERSIONS +import io.modelcontextprotocol.kotlin.sdk.ToolListChangedNotification +import io.modelcontextprotocol.kotlin.sdk.shared.Protocol +import io.modelcontextprotocol.kotlin.sdk.shared.RequestOptions +import kotlinx.coroutines.CompletableDeferred +import kotlinx.serialization.json.JsonObject + +private val logger = KotlinLogging.logger {} + +public open class ServerSession( + private val serverInfo: Implementation, + options: ServerOptions, +) : Protocol(options) { + private var _onInitialized: (() -> Unit) = {} + private var _onClose: () -> Unit = {} + + init { + // Core protocol handlers + setRequestHandler(Method.Defined.Initialize) { request, _ -> + handleInitialize(request) + } + setNotificationHandler(Method.Defined.NotificationsInitialized) { + _onInitialized() + CompletableDeferred(Unit) + } + } + + /** + * The capabilities supported by the server, related to the session. + */ + private val serverCapabilities = options.capabilities + + /** + * The client's reported capabilities after initialization. + */ + public var clientCapabilities: ClientCapabilities? = null + private set + + /** + * The client's version information after initialization. + */ + public var clientVersion: Implementation? = null + private set + + /** + * Registers a callback to be invoked when the server has completed initialization. + */ + public fun onInitialized(block: () -> Unit) { + val old = _onInitialized + _onInitialized = { + old() + block() + } + } + + /** + * Registers a callback to be invoked when the server session is closing. + */ + public fun onClose(block: () -> Unit) { + val old = _onClose + _onClose = { + old() + block() + } + } + + /** + * Called when the server session is closing. + */ + override fun onClose() { + logger.debug { "Server connection closing" } + _onClose() + } + + /** + * Sends a ping request to the client to check connectivity. + * + * @return The result of the ping request. + * @throws IllegalStateException If for some reason the method is not supported or the connection is closed. + */ + public suspend fun ping(): EmptyRequestResult { + return request(PingRequest()) + } + + /** + * Creates a message using the server's sampling capability. + * + * @param params The parameters for creating a message. + * @param options Optional request options. + * @return The created message result. + * @throws IllegalStateException If the server does not support sampling or if the request fails. + */ + public suspend fun createMessage( + params: CreateMessageRequest, + options: RequestOptions? = null + ): CreateMessageResult { + logger.debug { "Creating message with params: $params" } + return request(params, options) + } + + /** + * Lists the available "roots" from the client's perspective (if supported). + * + * @param params JSON parameters for the request, usually empty. + * @param options Optional request options. + * @return The list of roots. + * @throws IllegalStateException If the server or client does not support roots. + */ + public suspend fun listRoots( + params: JsonObject = EmptyJsonObject, + options: RequestOptions? = null + ): ListRootsResult { + logger.debug { "Listing roots with params: $params" } + return request(ListRootsRequest(params), options) + } + + public suspend fun createElicitation( + message: String, + requestedSchema: RequestedSchema, + options: RequestOptions? = null + ): CreateElicitationResult { + logger.debug { "Creating elicitation with message: $message" } + return request(CreateElicitationRequest(message, requestedSchema), options) + } + + /** + * Sends a logging message notification to the client. + * + * @param params The logging message notification parameters. + */ + public suspend fun sendLoggingMessage(params: LoggingMessageNotification) { + logger.trace { "Sending logging message: ${params.data}" } + notification(params) + } + + /** + * Sends a resource-updated notification to the client, indicating that a specific resource has changed. + * + * @param params Details of the updated resource. + */ + public suspend fun sendResourceUpdated(params: ResourceUpdatedNotification) { + logger.debug { "Sending resource updated notification for: ${params.uri}" } + notification(params) + } + + /** + * Sends a notification to the client indicating that the list of resources has changed. + */ + public suspend fun sendResourceListChanged() { + logger.debug { "Sending resource list changed notification" } + notification(ResourceListChangedNotification()) + } + + /** + * Sends a notification to the client indicating that the list of tools has changed. + */ + public suspend fun sendToolListChanged() { + logger.debug { "Sending tool list changed notification" } + notification(ToolListChangedNotification()) + } + + /** + * Sends a notification to the client indicating that the list of prompts has changed. + */ + public suspend fun sendPromptListChanged() { + logger.debug { "Sending prompt list changed notification" } + notification(PromptListChangedNotification()) + } + + /** + * Asserts that the client supports the capability required for the given [method]. + * + * This method is automatically called by the [Protocol] framework before handling requests. + * Throws [IllegalStateException] if the capability is not supported. + * + * @param method The method for which we are asserting capability. + */ + override fun assertCapabilityForMethod(method: Method) { + logger.trace { "Asserting capability for method: ${method.value}" } + when (method.value) { + "sampling/createMessage" -> { + if (clientCapabilities?.sampling == null) { + logger.error { "Client capability assertion failed: sampling not supported" } + throw IllegalStateException("Client does not support sampling (required for ${method.value})") + } + } + + "roots/list" -> { + if (clientCapabilities?.roots == null) { + throw IllegalStateException("Client does not support listing roots (required for ${method.value})") + } + } + + "elicitation/create" -> { + if (clientCapabilities?.elicitation == null) { + throw IllegalStateException("Client does not support elicitation (required for ${method.value})") + } + } + + "ping" -> { + // No specific capability required + } + } + } + + /** + * Asserts that the server can handle the specified notification method. + * + * Throws [IllegalStateException] if the server does not have the capabilities required to handle this notification. + * + * @param method The notification method. + */ + override fun assertNotificationCapability(method: Method) { + logger.trace { "Asserting notification capability for method: ${method.value}" } + when (method.value) { + "notifications/message" -> { + if (serverCapabilities.logging == null) { + logger.error { "Server capability assertion failed: logging not supported" } + throw IllegalStateException("Server does not support logging (required for ${method.value})") + } + } + + "notifications/resources/updated", + "notifications/resources/list_changed" -> { + if (serverCapabilities.resources == null) { + throw IllegalStateException("Server does not support notifying about resources (required for ${method.value})") + } + } + + "notifications/tools/list_changed" -> { + if (serverCapabilities.tools == null) { + throw IllegalStateException("Server does not support notifying of tool list changes (required for ${method.value})") + } + } + + "notifications/prompts/list_changed" -> { + if (serverCapabilities.prompts == null) { + throw IllegalStateException("Server does not support notifying of prompt list changes (required for ${method.value})") + } + } + + "notifications/cancelled", + "notifications/progress" -> { + // Always allowed + } + } + } + + /** + * Asserts that the server can handle the specified request method. + * + * Throws [IllegalStateException] if the server does not have the capabilities required to handle this request. + * + * @param method The request method. + */ + override fun assertRequestHandlerCapability(method: Method) { + logger.trace { "Asserting request handler capability for method: ${method.value}" } + when (method.value) { + "sampling/createMessage" -> { + if (serverCapabilities.sampling == null) { + logger.error { "Server capability assertion failed: sampling not supported" } + throw IllegalStateException("Server does not support sampling (required for $method)") + } + } + + "logging/setLevel" -> { + if (serverCapabilities.logging == null) { + throw IllegalStateException("Server does not support logging (required for $method)") + } + } + + "prompts/get", + "prompts/list" -> { + if (serverCapabilities.prompts == null) { + throw IllegalStateException("Server does not support prompts (required for $method)") + } + } + + "resources/list", + "resources/templates/list", + "resources/read" -> { + if (serverCapabilities.resources == null) { + throw IllegalStateException("Server does not support resources (required for $method)") + } + } + + "tools/call", + "tools/list" -> { + if (serverCapabilities.tools == null) { + throw IllegalStateException("Server does not support tools (required for $method)") + } + } + + "ping", "initialize" -> { + // No capability required + } + } + } + + private suspend fun handleInitialize(request: InitializeRequest): InitializeResult { + logger.debug { "Handling initialization request from client" } + clientCapabilities = request.capabilities + clientVersion = request.clientInfo + + val requestedVersion = request.protocolVersion + val protocolVersion = if (SUPPORTED_PROTOCOL_VERSIONS.contains(requestedVersion)) { + requestedVersion + } else { + logger.warn { "Client requested unsupported protocol version $requestedVersion, falling back to $LATEST_PROTOCOL_VERSION" } + LATEST_PROTOCOL_VERSION + } + + return InitializeResult( + protocolVersion = protocolVersion, + capabilities = serverCapabilities, + serverInfo = serverInfo + ) + } +} From 57da4cb119ad374d7b40750df9e99467ef272f04 Mon Sep 17 00:00:00 2001 From: Maria Tigina Date: Fri, 25 Jul 2025 22:45:32 +0200 Subject: [PATCH 3/6] Fix imports --- .../kotlin/sdk/client/SSEClientTransport.kt | 1 - .../sdk/client/WebSocketClientTransport.kt | 11 ++--- .../kotlin/sdk/server/Server.kt | 43 ++++++++++++++++++- .../WebSocketMcpKtorServerExtensions.kt | 14 ++++-- 4 files changed, 58 insertions(+), 11 deletions(-) diff --git a/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/SSEClientTransport.kt b/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/SSEClientTransport.kt index 20b7382f..778d0481 100644 --- a/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/SSEClientTransport.kt +++ b/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/SSEClientTransport.kt @@ -71,7 +71,6 @@ public class SseClientTransport( check(initialized.compareAndSet(expectedValue = false, newValue = true)) { "SSEClientTransport already started! If using Client class, note that connect() calls start() automatically." } - logger.info { "Starting SseClientTransport..." } try { session = urlString?.let { diff --git a/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/WebSocketClientTransport.kt b/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/WebSocketClientTransport.kt index 9c7fab67..ec3f9470 100644 --- a/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/WebSocketClientTransport.kt +++ b/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/WebSocketClientTransport.kt @@ -1,11 +1,12 @@ package io.modelcontextprotocol.kotlin.sdk.client import io.github.oshai.kotlinlogging.KotlinLogging -import io.ktor.client.* -import io.ktor.client.plugins.websocket.* -import io.ktor.client.request.* -import io.ktor.http.* -import io.ktor.websocket.* +import io.ktor.client.HttpClient +import io.ktor.client.plugins.websocket.webSocketSession +import io.ktor.client.request.HttpRequestBuilder +import io.ktor.client.request.header +import io.ktor.http.HttpHeaders +import io.ktor.websocket.WebSocketSession import io.modelcontextprotocol.kotlin.sdk.shared.MCP_SUBPROTOCOL import io.modelcontextprotocol.kotlin.sdk.shared.WebSocketMcpTransport import kotlin.properties.Delegates diff --git a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/Server.kt b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/Server.kt index f4a8ab5c..09684a12 100644 --- a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/Server.kt +++ b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/Server.kt @@ -1,8 +1,49 @@ package io.modelcontextprotocol.kotlin.sdk.server import io.github.oshai.kotlinlogging.KotlinLogging -import io.modelcontextprotocol.kotlin.sdk.* +import io.modelcontextprotocol.kotlin.sdk.CallToolRequest +import io.modelcontextprotocol.kotlin.sdk.CallToolResult +import io.modelcontextprotocol.kotlin.sdk.ClientCapabilities +import io.modelcontextprotocol.kotlin.sdk.CreateElicitationRequest import io.modelcontextprotocol.kotlin.sdk.CreateElicitationRequest.RequestedSchema +import io.modelcontextprotocol.kotlin.sdk.CreateElicitationResult +import io.modelcontextprotocol.kotlin.sdk.CreateMessageRequest +import io.modelcontextprotocol.kotlin.sdk.CreateMessageResult +import io.modelcontextprotocol.kotlin.sdk.EmptyJsonObject +import io.modelcontextprotocol.kotlin.sdk.EmptyRequestResult +import io.modelcontextprotocol.kotlin.sdk.GetPromptRequest +import io.modelcontextprotocol.kotlin.sdk.GetPromptResult +import io.modelcontextprotocol.kotlin.sdk.Implementation +import io.modelcontextprotocol.kotlin.sdk.InitializeRequest +import io.modelcontextprotocol.kotlin.sdk.InitializeResult +import io.modelcontextprotocol.kotlin.sdk.InitializedNotification +import io.modelcontextprotocol.kotlin.sdk.LATEST_PROTOCOL_VERSION +import io.modelcontextprotocol.kotlin.sdk.ListPromptsRequest +import io.modelcontextprotocol.kotlin.sdk.ListPromptsResult +import io.modelcontextprotocol.kotlin.sdk.ListResourceTemplatesRequest +import io.modelcontextprotocol.kotlin.sdk.ListResourceTemplatesResult +import io.modelcontextprotocol.kotlin.sdk.ListResourcesRequest +import io.modelcontextprotocol.kotlin.sdk.ListResourcesResult +import io.modelcontextprotocol.kotlin.sdk.ListRootsRequest +import io.modelcontextprotocol.kotlin.sdk.ListRootsResult +import io.modelcontextprotocol.kotlin.sdk.ListToolsRequest +import io.modelcontextprotocol.kotlin.sdk.ListToolsResult +import io.modelcontextprotocol.kotlin.sdk.LoggingMessageNotification +import io.modelcontextprotocol.kotlin.sdk.Method +import io.modelcontextprotocol.kotlin.sdk.PingRequest +import io.modelcontextprotocol.kotlin.sdk.Prompt +import io.modelcontextprotocol.kotlin.sdk.PromptArgument +import io.modelcontextprotocol.kotlin.sdk.PromptListChangedNotification +import io.modelcontextprotocol.kotlin.sdk.ReadResourceRequest +import io.modelcontextprotocol.kotlin.sdk.ReadResourceResult +import io.modelcontextprotocol.kotlin.sdk.Resource +import io.modelcontextprotocol.kotlin.sdk.ResourceListChangedNotification +import io.modelcontextprotocol.kotlin.sdk.ResourceUpdatedNotification +import io.modelcontextprotocol.kotlin.sdk.SUPPORTED_PROTOCOL_VERSIONS +import io.modelcontextprotocol.kotlin.sdk.ServerCapabilities +import io.modelcontextprotocol.kotlin.sdk.Tool +import io.modelcontextprotocol.kotlin.sdk.ToolAnnotations +import io.modelcontextprotocol.kotlin.sdk.ToolListChangedNotification import io.modelcontextprotocol.kotlin.sdk.shared.Protocol import io.modelcontextprotocol.kotlin.sdk.shared.ProtocolOptions import io.modelcontextprotocol.kotlin.sdk.shared.RequestOptions diff --git a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/WebSocketMcpKtorServerExtensions.kt b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/WebSocketMcpKtorServerExtensions.kt index 6d93a87d..bea21edc 100644 --- a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/WebSocketMcpKtorServerExtensions.kt +++ b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/WebSocketMcpKtorServerExtensions.kt @@ -1,10 +1,16 @@ package io.modelcontextprotocol.kotlin.sdk.server import io.github.oshai.kotlinlogging.KotlinLogging -import io.ktor.server.application.* -import io.ktor.server.routing.* -import io.ktor.server.websocket.* -import io.ktor.utils.io.* +import io.ktor.server.application.Application +import io.ktor.server.application.install +import io.ktor.server.routing.Route +import io.ktor.server.routing.Routing +import io.ktor.server.routing.routing +import io.ktor.server.websocket.WebSocketServerSession +import io.ktor.server.websocket.WebSockets +import io.ktor.server.websocket.webSocket +import io.ktor.utils.io.CancellationException +import io.ktor.utils.io.KtorDsl import io.modelcontextprotocol.kotlin.sdk.Implementation import io.modelcontextprotocol.kotlin.sdk.LIB_VERSION import io.modelcontextprotocol.kotlin.sdk.ServerCapabilities From 3b19e6f2e6a7a6fd54a47e33f45abd7047dbcf1a Mon Sep 17 00:00:00 2001 From: Maria Tigina Date: Fri, 25 Jul 2025 22:57:27 +0200 Subject: [PATCH 4/6] Change api --- kotlin-sdk-core/api/kotlin-sdk-core.api | 297 +++++++++++++++++++++++- 1 file changed, 295 insertions(+), 2 deletions(-) diff --git a/kotlin-sdk-core/api/kotlin-sdk-core.api b/kotlin-sdk-core/api/kotlin-sdk-core.api index 569dfb3e..643e7c82 100644 --- a/kotlin-sdk-core/api/kotlin-sdk-core.api +++ b/kotlin-sdk-core/api/kotlin-sdk-core.api @@ -3200,8 +3200,301 @@ public final class io/modelcontextprotocol/kotlin/sdk/WithMeta$Companion { public final fun serializer ()Lkotlinx/serialization/KSerializer; } -public final class io/modelcontextprotocol/kotlin/sdk/internal/Utils_jvmKt { - public static final fun getIODispatcher ()Lkotlinx/coroutines/CoroutineDispatcher; +public class io/modelcontextprotocol/kotlin/sdk/client/Client : io/modelcontextprotocol/kotlin/sdk/shared/Protocol { + public fun (Lio/modelcontextprotocol/kotlin/sdk/Implementation;Lio/modelcontextprotocol/kotlin/sdk/client/ClientOptions;)V + public synthetic fun (Lio/modelcontextprotocol/kotlin/sdk/Implementation;Lio/modelcontextprotocol/kotlin/sdk/client/ClientOptions;ILkotlin/jvm/internal/DefaultConstructorMarker;)V + public final fun addRoot (Ljava/lang/String;Ljava/lang/String;)V + public final fun addRoots (Ljava/util/List;)V + protected final fun assertCapability (Ljava/lang/String;Ljava/lang/String;)V + protected fun assertCapabilityForMethod (Lio/modelcontextprotocol/kotlin/sdk/Method;)V + protected fun assertNotificationCapability (Lio/modelcontextprotocol/kotlin/sdk/Method;)V + public fun assertRequestHandlerCapability (Lio/modelcontextprotocol/kotlin/sdk/Method;)V + public final fun callTool (Lio/modelcontextprotocol/kotlin/sdk/CallToolRequest;ZLio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public final fun callTool (Ljava/lang/String;Ljava/util/Map;ZLio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public static synthetic fun callTool$default (Lio/modelcontextprotocol/kotlin/sdk/client/Client;Lio/modelcontextprotocol/kotlin/sdk/CallToolRequest;ZLio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;ILjava/lang/Object;)Ljava/lang/Object; + public static synthetic fun callTool$default (Lio/modelcontextprotocol/kotlin/sdk/client/Client;Ljava/lang/String;Ljava/util/Map;ZLio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;ILjava/lang/Object;)Ljava/lang/Object; + public final fun complete (Lio/modelcontextprotocol/kotlin/sdk/CompleteRequest;Lio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public static synthetic fun complete$default (Lio/modelcontextprotocol/kotlin/sdk/client/Client;Lio/modelcontextprotocol/kotlin/sdk/CompleteRequest;Lio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;ILjava/lang/Object;)Ljava/lang/Object; + public fun connect (Lio/modelcontextprotocol/kotlin/sdk/shared/Transport;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public final fun getPrompt (Lio/modelcontextprotocol/kotlin/sdk/GetPromptRequest;Lio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public static synthetic fun getPrompt$default (Lio/modelcontextprotocol/kotlin/sdk/client/Client;Lio/modelcontextprotocol/kotlin/sdk/GetPromptRequest;Lio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;ILjava/lang/Object;)Ljava/lang/Object; + public final fun getServerCapabilities ()Lio/modelcontextprotocol/kotlin/sdk/ServerCapabilities; + public final fun getServerVersion ()Lio/modelcontextprotocol/kotlin/sdk/Implementation; + public final fun listPrompts (Lio/modelcontextprotocol/kotlin/sdk/ListPromptsRequest;Lio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public static synthetic fun listPrompts$default (Lio/modelcontextprotocol/kotlin/sdk/client/Client;Lio/modelcontextprotocol/kotlin/sdk/ListPromptsRequest;Lio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;ILjava/lang/Object;)Ljava/lang/Object; + public final fun listResourceTemplates (Lio/modelcontextprotocol/kotlin/sdk/ListResourceTemplatesRequest;Lio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public static synthetic fun listResourceTemplates$default (Lio/modelcontextprotocol/kotlin/sdk/client/Client;Lio/modelcontextprotocol/kotlin/sdk/ListResourceTemplatesRequest;Lio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;ILjava/lang/Object;)Ljava/lang/Object; + public final fun listResources (Lio/modelcontextprotocol/kotlin/sdk/ListResourcesRequest;Lio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public static synthetic fun listResources$default (Lio/modelcontextprotocol/kotlin/sdk/client/Client;Lio/modelcontextprotocol/kotlin/sdk/ListResourcesRequest;Lio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;ILjava/lang/Object;)Ljava/lang/Object; + public final fun listTools (Lio/modelcontextprotocol/kotlin/sdk/ListToolsRequest;Lio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public static synthetic fun listTools$default (Lio/modelcontextprotocol/kotlin/sdk/client/Client;Lio/modelcontextprotocol/kotlin/sdk/ListToolsRequest;Lio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;ILjava/lang/Object;)Ljava/lang/Object; + public final fun ping (Lio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public static synthetic fun ping$default (Lio/modelcontextprotocol/kotlin/sdk/client/Client;Lio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;ILjava/lang/Object;)Ljava/lang/Object; + public final fun readResource (Lio/modelcontextprotocol/kotlin/sdk/ReadResourceRequest;Lio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public static synthetic fun readResource$default (Lio/modelcontextprotocol/kotlin/sdk/client/Client;Lio/modelcontextprotocol/kotlin/sdk/ReadResourceRequest;Lio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;ILjava/lang/Object;)Ljava/lang/Object; + public final fun removeRoot (Ljava/lang/String;)Z + public final fun removeRoots (Ljava/util/List;)I + public final fun sendRootsListChanged (Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public final fun setElicitationHandler (Lkotlin/jvm/functions/Function1;)V + public final fun setLoggingLevel (Lio/modelcontextprotocol/kotlin/sdk/LoggingLevel;Lio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public static synthetic fun setLoggingLevel$default (Lio/modelcontextprotocol/kotlin/sdk/client/Client;Lio/modelcontextprotocol/kotlin/sdk/LoggingLevel;Lio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;ILjava/lang/Object;)Ljava/lang/Object; + public final fun subscribeResource (Lio/modelcontextprotocol/kotlin/sdk/SubscribeRequest;Lio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public static synthetic fun subscribeResource$default (Lio/modelcontextprotocol/kotlin/sdk/client/Client;Lio/modelcontextprotocol/kotlin/sdk/SubscribeRequest;Lio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;ILjava/lang/Object;)Ljava/lang/Object; + public final fun unsubscribeResource (Lio/modelcontextprotocol/kotlin/sdk/UnsubscribeRequest;Lio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public static synthetic fun unsubscribeResource$default (Lio/modelcontextprotocol/kotlin/sdk/client/Client;Lio/modelcontextprotocol/kotlin/sdk/UnsubscribeRequest;Lio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;ILjava/lang/Object;)Ljava/lang/Object; +} + +public final class io/modelcontextprotocol/kotlin/sdk/client/ClientOptions : io/modelcontextprotocol/kotlin/sdk/shared/ProtocolOptions { + public fun ()V + public fun (Lio/modelcontextprotocol/kotlin/sdk/ClientCapabilities;Z)V + public synthetic fun (Lio/modelcontextprotocol/kotlin/sdk/ClientCapabilities;ZILkotlin/jvm/internal/DefaultConstructorMarker;)V + public final fun getCapabilities ()Lio/modelcontextprotocol/kotlin/sdk/ClientCapabilities; +} + +public final class io/modelcontextprotocol/kotlin/sdk/client/KtorClientKt { + public static final fun mcpSse-BZiP2OM (Lio/ktor/client/HttpClient;Ljava/lang/String;Lkotlin/time/Duration;Lkotlin/jvm/functions/Function1;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public static synthetic fun mcpSse-BZiP2OM$default (Lio/ktor/client/HttpClient;Ljava/lang/String;Lkotlin/time/Duration;Lkotlin/jvm/functions/Function1;Lkotlin/coroutines/Continuation;ILjava/lang/Object;)Ljava/lang/Object; + public static final fun mcpSseTransport-5_5nbZA (Lio/ktor/client/HttpClient;Ljava/lang/String;Lkotlin/time/Duration;Lkotlin/jvm/functions/Function1;)Lio/modelcontextprotocol/kotlin/sdk/client/SseClientTransport; + public static synthetic fun mcpSseTransport-5_5nbZA$default (Lio/ktor/client/HttpClient;Ljava/lang/String;Lkotlin/time/Duration;Lkotlin/jvm/functions/Function1;ILjava/lang/Object;)Lio/modelcontextprotocol/kotlin/sdk/client/SseClientTransport; +} + +public final class io/modelcontextprotocol/kotlin/sdk/client/MainKt { + public static final fun main ()V + public static synthetic fun main ([Ljava/lang/String;)V +} + +public final class io/modelcontextprotocol/kotlin/sdk/client/SseClientTransport : io/modelcontextprotocol/kotlin/sdk/shared/AbstractTransport { + public synthetic fun (Lio/ktor/client/HttpClient;Ljava/lang/String;Lkotlin/time/Duration;Lkotlin/jvm/functions/Function1;ILkotlin/jvm/internal/DefaultConstructorMarker;)V + public synthetic fun (Lio/ktor/client/HttpClient;Ljava/lang/String;Lkotlin/time/Duration;Lkotlin/jvm/functions/Function1;Lkotlin/jvm/internal/DefaultConstructorMarker;)V + public fun close (Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public fun send (Lio/modelcontextprotocol/kotlin/sdk/JSONRPCMessage;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public fun start (Lkotlin/coroutines/Continuation;)Ljava/lang/Object; +} + +public final class io/modelcontextprotocol/kotlin/sdk/client/StdioClientTransport : io/modelcontextprotocol/kotlin/sdk/shared/AbstractTransport { + public fun (Lkotlinx/io/Source;Lkotlinx/io/Sink;)V + public fun close (Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public fun send (Lio/modelcontextprotocol/kotlin/sdk/JSONRPCMessage;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public fun start (Lkotlin/coroutines/Continuation;)Ljava/lang/Object; +} + +public final class io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpClientTransport : io/modelcontextprotocol/kotlin/sdk/shared/AbstractTransport { + public synthetic fun (Lio/ktor/client/HttpClient;Ljava/lang/String;Lkotlin/time/Duration;Lkotlin/jvm/functions/Function1;ILkotlin/jvm/internal/DefaultConstructorMarker;)V + public synthetic fun (Lio/ktor/client/HttpClient;Ljava/lang/String;Lkotlin/time/Duration;Lkotlin/jvm/functions/Function1;Lkotlin/jvm/internal/DefaultConstructorMarker;)V + public fun close (Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public final fun getProtocolVersion ()Ljava/lang/String; + public final fun getSessionId ()Ljava/lang/String; + public final fun send (Lio/modelcontextprotocol/kotlin/sdk/JSONRPCMessage;Ljava/lang/String;Lkotlin/jvm/functions/Function1;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public fun send (Lio/modelcontextprotocol/kotlin/sdk/JSONRPCMessage;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public static synthetic fun send$default (Lio/modelcontextprotocol/kotlin/sdk/client/StreamableHttpClientTransport;Lio/modelcontextprotocol/kotlin/sdk/JSONRPCMessage;Ljava/lang/String;Lkotlin/jvm/functions/Function1;Lkotlin/coroutines/Continuation;ILjava/lang/Object;)Ljava/lang/Object; + public final fun setProtocolVersion (Ljava/lang/String;)V + public fun start (Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public final fun terminateSession (Lkotlin/coroutines/Continuation;)Ljava/lang/Object; +} + +public final class io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpError : java/lang/Exception { + public fun ()V + public fun (Ljava/lang/Integer;Ljava/lang/String;)V + public synthetic fun (Ljava/lang/Integer;Ljava/lang/String;ILkotlin/jvm/internal/DefaultConstructorMarker;)V + public final fun getCode ()Ljava/lang/Integer; +} + +public final class io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpMcpKtorClientExtensionsKt { + public static final fun mcpStreamableHttp-BZiP2OM (Lio/ktor/client/HttpClient;Ljava/lang/String;Lkotlin/time/Duration;Lkotlin/jvm/functions/Function1;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public static synthetic fun mcpStreamableHttp-BZiP2OM$default (Lio/ktor/client/HttpClient;Ljava/lang/String;Lkotlin/time/Duration;Lkotlin/jvm/functions/Function1;Lkotlin/coroutines/Continuation;ILjava/lang/Object;)Ljava/lang/Object; + public static final fun mcpStreamableHttpTransport-5_5nbZA (Lio/ktor/client/HttpClient;Ljava/lang/String;Lkotlin/time/Duration;Lkotlin/jvm/functions/Function1;)Lio/modelcontextprotocol/kotlin/sdk/client/StreamableHttpClientTransport; + public static synthetic fun mcpStreamableHttpTransport-5_5nbZA$default (Lio/ktor/client/HttpClient;Ljava/lang/String;Lkotlin/time/Duration;Lkotlin/jvm/functions/Function1;ILjava/lang/Object;)Lio/modelcontextprotocol/kotlin/sdk/client/StreamableHttpClientTransport; +} + +public final class io/modelcontextprotocol/kotlin/sdk/client/WebSocketClientTransport : io/modelcontextprotocol/kotlin/sdk/shared/WebSocketMcpTransport { + public fun (Lio/ktor/client/HttpClient;Ljava/lang/String;Lkotlin/jvm/functions/Function1;)V + public synthetic fun (Lio/ktor/client/HttpClient;Ljava/lang/String;Lkotlin/jvm/functions/Function1;ILkotlin/jvm/internal/DefaultConstructorMarker;)V +} + +public final class io/modelcontextprotocol/kotlin/sdk/client/WebSocketMcpKtorClientExtensionsKt { + public static final fun mcpWebSocket (Lio/ktor/client/HttpClient;Ljava/lang/String;Lkotlin/jvm/functions/Function1;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public static synthetic fun mcpWebSocket$default (Lio/ktor/client/HttpClient;Ljava/lang/String;Lkotlin/jvm/functions/Function1;Lkotlin/coroutines/Continuation;ILjava/lang/Object;)Ljava/lang/Object; + public static final fun mcpWebSocketTransport (Lio/ktor/client/HttpClient;Ljava/lang/String;Lkotlin/jvm/functions/Function1;)Lio/modelcontextprotocol/kotlin/sdk/client/WebSocketClientTransport; + public static synthetic fun mcpWebSocketTransport$default (Lio/ktor/client/HttpClient;Ljava/lang/String;Lkotlin/jvm/functions/Function1;ILjava/lang/Object;)Lio/modelcontextprotocol/kotlin/sdk/client/WebSocketClientTransport; +} + +public final class io/modelcontextprotocol/kotlin/sdk/internal/MainKt { + public static final fun initClient (Ljava/lang/String;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public static synthetic fun initClient$default (Ljava/lang/String;Lkotlin/coroutines/Continuation;ILjava/lang/Object;)Ljava/lang/Object; + public static final fun main ()V + public static synthetic fun main ([Ljava/lang/String;)V +} + +public final class io/modelcontextprotocol/kotlin/sdk/server/ClientSession { + public fun (Lio/modelcontextprotocol/kotlin/sdk/shared/Transport;Lio/modelcontextprotocol/kotlin/sdk/ClientCapabilities;Lio/modelcontextprotocol/kotlin/sdk/Implementation;)V + public synthetic fun (Lio/modelcontextprotocol/kotlin/sdk/shared/Transport;Lio/modelcontextprotocol/kotlin/sdk/ClientCapabilities;Lio/modelcontextprotocol/kotlin/sdk/Implementation;ILkotlin/jvm/internal/DefaultConstructorMarker;)V + public final fun getClientCapabilities ()Lio/modelcontextprotocol/kotlin/sdk/ClientCapabilities; + public final fun getClientVersion ()Lio/modelcontextprotocol/kotlin/sdk/Implementation; + public final fun getTransport ()Lio/modelcontextprotocol/kotlin/sdk/shared/Transport; + public final fun onClose (Lkotlin/jvm/functions/Function0;)V + public final fun onInitialized (Lkotlin/jvm/functions/Function0;)V + public final fun setClientCapabilities (Lio/modelcontextprotocol/kotlin/sdk/ClientCapabilities;)V + public final fun setClientVersion (Lio/modelcontextprotocol/kotlin/sdk/Implementation;)V +} + +public final class io/modelcontextprotocol/kotlin/sdk/server/KtorServerKt { + public static final fun mcp (Lio/ktor/server/application/Application;Lkotlin/jvm/functions/Function0;)V + public static final fun mcp (Lio/ktor/server/routing/Routing;Lkotlin/jvm/functions/Function0;)V +} + +public final class io/modelcontextprotocol/kotlin/sdk/server/RegisteredPrompt { + public fun (Lio/modelcontextprotocol/kotlin/sdk/Prompt;Lkotlin/jvm/functions/Function2;)V + public final fun component1 ()Lio/modelcontextprotocol/kotlin/sdk/Prompt; + public final fun component2 ()Lkotlin/jvm/functions/Function2; + public final fun copy (Lio/modelcontextprotocol/kotlin/sdk/Prompt;Lkotlin/jvm/functions/Function2;)Lio/modelcontextprotocol/kotlin/sdk/server/RegisteredPrompt; + public static synthetic fun copy$default (Lio/modelcontextprotocol/kotlin/sdk/server/RegisteredPrompt;Lio/modelcontextprotocol/kotlin/sdk/Prompt;Lkotlin/jvm/functions/Function2;ILjava/lang/Object;)Lio/modelcontextprotocol/kotlin/sdk/server/RegisteredPrompt; + public fun equals (Ljava/lang/Object;)Z + public final fun getMessageProvider ()Lkotlin/jvm/functions/Function2; + public final fun getPrompt ()Lio/modelcontextprotocol/kotlin/sdk/Prompt; + public fun hashCode ()I + public fun toString ()Ljava/lang/String; +} + +public final class io/modelcontextprotocol/kotlin/sdk/server/RegisteredResource { + public fun (Lio/modelcontextprotocol/kotlin/sdk/Resource;Lkotlin/jvm/functions/Function2;)V + public final fun component1 ()Lio/modelcontextprotocol/kotlin/sdk/Resource; + public final fun component2 ()Lkotlin/jvm/functions/Function2; + public final fun copy (Lio/modelcontextprotocol/kotlin/sdk/Resource;Lkotlin/jvm/functions/Function2;)Lio/modelcontextprotocol/kotlin/sdk/server/RegisteredResource; + public static synthetic fun copy$default (Lio/modelcontextprotocol/kotlin/sdk/server/RegisteredResource;Lio/modelcontextprotocol/kotlin/sdk/Resource;Lkotlin/jvm/functions/Function2;ILjava/lang/Object;)Lio/modelcontextprotocol/kotlin/sdk/server/RegisteredResource; + public fun equals (Ljava/lang/Object;)Z + public final fun getReadHandler ()Lkotlin/jvm/functions/Function2; + public final fun getResource ()Lio/modelcontextprotocol/kotlin/sdk/Resource; + public fun hashCode ()I + public fun toString ()Ljava/lang/String; +} + +public final class io/modelcontextprotocol/kotlin/sdk/server/RegisteredTool { + public fun (Lio/modelcontextprotocol/kotlin/sdk/Tool;Lkotlin/jvm/functions/Function2;)V + public final fun component1 ()Lio/modelcontextprotocol/kotlin/sdk/Tool; + public final fun component2 ()Lkotlin/jvm/functions/Function2; + public final fun copy (Lio/modelcontextprotocol/kotlin/sdk/Tool;Lkotlin/jvm/functions/Function2;)Lio/modelcontextprotocol/kotlin/sdk/server/RegisteredTool; + public static synthetic fun copy$default (Lio/modelcontextprotocol/kotlin/sdk/server/RegisteredTool;Lio/modelcontextprotocol/kotlin/sdk/Tool;Lkotlin/jvm/functions/Function2;ILjava/lang/Object;)Lio/modelcontextprotocol/kotlin/sdk/server/RegisteredTool; + public fun equals (Ljava/lang/Object;)Z + public final fun getHandler ()Lkotlin/jvm/functions/Function2; + public final fun getTool ()Lio/modelcontextprotocol/kotlin/sdk/Tool; + public fun hashCode ()I + public fun toString ()Ljava/lang/String; +} + +public class io/modelcontextprotocol/kotlin/sdk/server/Server : io/modelcontextprotocol/kotlin/sdk/shared/Protocol { + public fun (Lio/modelcontextprotocol/kotlin/sdk/Implementation;Lio/modelcontextprotocol/kotlin/sdk/server/ServerOptions;)V + public final fun addPrompt (Lio/modelcontextprotocol/kotlin/sdk/Prompt;Lkotlin/jvm/functions/Function2;)V + public final fun addPrompt (Ljava/lang/String;Ljava/lang/String;Ljava/util/List;Lkotlin/jvm/functions/Function2;)V + public static synthetic fun addPrompt$default (Lio/modelcontextprotocol/kotlin/sdk/server/Server;Ljava/lang/String;Ljava/lang/String;Ljava/util/List;Lkotlin/jvm/functions/Function2;ILjava/lang/Object;)V + public final fun addPrompts (Ljava/util/List;)V + public final fun addResource (Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;Lkotlin/jvm/functions/Function2;)V + public static synthetic fun addResource$default (Lio/modelcontextprotocol/kotlin/sdk/server/Server;Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;Lkotlin/jvm/functions/Function2;ILjava/lang/Object;)V + public final fun addResources (Ljava/util/List;)V + public final fun addTool (Lio/modelcontextprotocol/kotlin/sdk/Tool;Lkotlin/jvm/functions/Function2;)V + public final fun addTool (Ljava/lang/String;Ljava/lang/String;Lio/modelcontextprotocol/kotlin/sdk/Tool$Input;Lio/modelcontextprotocol/kotlin/sdk/Tool$Output;Lio/modelcontextprotocol/kotlin/sdk/ToolAnnotations;Lkotlin/jvm/functions/Function2;)V + public static synthetic fun addTool$default (Lio/modelcontextprotocol/kotlin/sdk/server/Server;Ljava/lang/String;Ljava/lang/String;Lio/modelcontextprotocol/kotlin/sdk/Tool$Input;Lio/modelcontextprotocol/kotlin/sdk/Tool$Output;Lio/modelcontextprotocol/kotlin/sdk/ToolAnnotations;Lkotlin/jvm/functions/Function2;ILjava/lang/Object;)V + public final fun addTools (Ljava/util/List;)V + protected fun assertCapabilityForMethod (Lio/modelcontextprotocol/kotlin/sdk/Method;)V + protected fun assertNotificationCapability (Lio/modelcontextprotocol/kotlin/sdk/Method;)V + public fun assertRequestHandlerCapability (Lio/modelcontextprotocol/kotlin/sdk/Method;)V + public final fun closeSessions (Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public fun connect (Lio/modelcontextprotocol/kotlin/sdk/shared/Transport;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public final fun connectSession (Lio/modelcontextprotocol/kotlin/sdk/shared/Transport;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public final fun createElicitation (Ljava/lang/String;Lio/modelcontextprotocol/kotlin/sdk/CreateElicitationRequest$RequestedSchema;Lio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public static synthetic fun createElicitation$default (Lio/modelcontextprotocol/kotlin/sdk/server/Server;Ljava/lang/String;Lio/modelcontextprotocol/kotlin/sdk/CreateElicitationRequest$RequestedSchema;Lio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;ILjava/lang/Object;)Ljava/lang/Object; + public final fun createMessage (Lio/modelcontextprotocol/kotlin/sdk/CreateMessageRequest;Lio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public static synthetic fun createMessage$default (Lio/modelcontextprotocol/kotlin/sdk/server/Server;Lio/modelcontextprotocol/kotlin/sdk/CreateMessageRequest;Lio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;ILjava/lang/Object;)Ljava/lang/Object; + public final fun getClientCapabilities ()Lio/modelcontextprotocol/kotlin/sdk/ClientCapabilities; + public final fun getClientVersion ()Lio/modelcontextprotocol/kotlin/sdk/Implementation; + public final fun getPrompts ()Ljava/util/Map; + public final fun getResources ()Ljava/util/Map; + public final fun getTools ()Ljava/util/Map; + public final fun listRoots (Lkotlinx/serialization/json/JsonObject;Lio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public static synthetic fun listRoots$default (Lio/modelcontextprotocol/kotlin/sdk/server/Server;Lkotlinx/serialization/json/JsonObject;Lio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;ILjava/lang/Object;)Ljava/lang/Object; + public fun onClose ()V + public final fun onClose (Lkotlin/jvm/functions/Function0;)V + public final fun onConnect (Lkotlin/jvm/functions/Function0;)V + public final fun onInitialized (Lkotlin/jvm/functions/Function0;)V + public final fun ping (Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public final fun removePrompt (Ljava/lang/String;)Z + public final fun removePrompts (Ljava/util/List;)I + public final fun removeResource (Ljava/lang/String;)Z + public final fun removeResources (Ljava/util/List;)I + public final fun removeTool (Ljava/lang/String;)Z + public final fun removeTools (Ljava/util/List;)I + public final fun sendLoggingMessage (Lio/modelcontextprotocol/kotlin/sdk/LoggingMessageNotification;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public final fun sendPromptListChanged (Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public final fun sendResourceListChanged (Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public final fun sendResourceUpdated (Lio/modelcontextprotocol/kotlin/sdk/ResourceUpdatedNotification;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public final fun sendToolListChanged (Lkotlin/coroutines/Continuation;)Ljava/lang/Object; +} + +public final class io/modelcontextprotocol/kotlin/sdk/server/ServerOptions : io/modelcontextprotocol/kotlin/sdk/shared/ProtocolOptions { + public fun (Lio/modelcontextprotocol/kotlin/sdk/ServerCapabilities;Z)V + public synthetic fun (Lio/modelcontextprotocol/kotlin/sdk/ServerCapabilities;ZILkotlin/jvm/internal/DefaultConstructorMarker;)V + public final fun getCapabilities ()Lio/modelcontextprotocol/kotlin/sdk/ServerCapabilities; +} + +public class io/modelcontextprotocol/kotlin/sdk/server/ServerSession : io/modelcontextprotocol/kotlin/sdk/shared/Protocol { + public fun (Lio/modelcontextprotocol/kotlin/sdk/Implementation;Lio/modelcontextprotocol/kotlin/sdk/server/ServerOptions;)V + protected fun assertCapabilityForMethod (Lio/modelcontextprotocol/kotlin/sdk/Method;)V + protected fun assertNotificationCapability (Lio/modelcontextprotocol/kotlin/sdk/Method;)V + public fun assertRequestHandlerCapability (Lio/modelcontextprotocol/kotlin/sdk/Method;)V + public final fun createElicitation (Ljava/lang/String;Lio/modelcontextprotocol/kotlin/sdk/CreateElicitationRequest$RequestedSchema;Lio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public static synthetic fun createElicitation$default (Lio/modelcontextprotocol/kotlin/sdk/server/ServerSession;Ljava/lang/String;Lio/modelcontextprotocol/kotlin/sdk/CreateElicitationRequest$RequestedSchema;Lio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;ILjava/lang/Object;)Ljava/lang/Object; + public final fun createMessage (Lio/modelcontextprotocol/kotlin/sdk/CreateMessageRequest;Lio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public static synthetic fun createMessage$default (Lio/modelcontextprotocol/kotlin/sdk/server/ServerSession;Lio/modelcontextprotocol/kotlin/sdk/CreateMessageRequest;Lio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;ILjava/lang/Object;)Ljava/lang/Object; + public final fun getClientCapabilities ()Lio/modelcontextprotocol/kotlin/sdk/ClientCapabilities; + public final fun getClientVersion ()Lio/modelcontextprotocol/kotlin/sdk/Implementation; + public final fun listRoots (Lkotlinx/serialization/json/JsonObject;Lio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public static synthetic fun listRoots$default (Lio/modelcontextprotocol/kotlin/sdk/server/ServerSession;Lkotlinx/serialization/json/JsonObject;Lio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;ILjava/lang/Object;)Ljava/lang/Object; + public fun onClose ()V + public final fun onClose (Lkotlin/jvm/functions/Function0;)V + public final fun onInitialized (Lkotlin/jvm/functions/Function0;)V + public final fun ping (Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public final fun sendLoggingMessage (Lio/modelcontextprotocol/kotlin/sdk/LoggingMessageNotification;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public final fun sendPromptListChanged (Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public final fun sendResourceListChanged (Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public final fun sendResourceUpdated (Lio/modelcontextprotocol/kotlin/sdk/ResourceUpdatedNotification;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public final fun sendToolListChanged (Lkotlin/coroutines/Continuation;)Ljava/lang/Object; +} + +public final class io/modelcontextprotocol/kotlin/sdk/server/SseServerTransport : io/modelcontextprotocol/kotlin/sdk/shared/AbstractTransport { + public fun (Ljava/lang/String;Lio/ktor/server/sse/ServerSSESession;)V + public fun close (Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public final fun getSessionId ()Ljava/lang/String; + public final fun handleMessage (Ljava/lang/String;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public final fun handlePostMessage (Lio/ktor/server/application/ApplicationCall;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public fun send (Lio/modelcontextprotocol/kotlin/sdk/JSONRPCMessage;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public fun start (Lkotlin/coroutines/Continuation;)Ljava/lang/Object; +} + +public final class io/modelcontextprotocol/kotlin/sdk/server/StdioServerTransport : io/modelcontextprotocol/kotlin/sdk/shared/AbstractTransport { + public fun (Lkotlinx/io/Source;Lkotlinx/io/Sink;)V + public fun close (Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public fun send (Lio/modelcontextprotocol/kotlin/sdk/JSONRPCMessage;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public fun start (Lkotlin/coroutines/Continuation;)Ljava/lang/Object; +} + +public final class io/modelcontextprotocol/kotlin/sdk/server/WebSocketMcpKtorServerExtensionsKt { + public static final fun mcpWebSocket (Lio/ktor/server/application/Application;Ljava/lang/String;Lkotlin/jvm/functions/Function0;)V + public static final fun mcpWebSocket (Lio/ktor/server/application/Application;Lkotlin/jvm/functions/Function0;)V + public static final fun mcpWebSocket (Lio/ktor/server/routing/Route;Lio/modelcontextprotocol/kotlin/sdk/server/ServerOptions;Lkotlin/jvm/functions/Function2;)V + public static final fun mcpWebSocket (Lio/ktor/server/routing/Route;Ljava/lang/String;Lio/modelcontextprotocol/kotlin/sdk/server/ServerOptions;Lkotlin/jvm/functions/Function2;)V + public static final fun mcpWebSocket (Lio/ktor/server/routing/Route;Lkotlin/jvm/functions/Function0;)V + public static final fun mcpWebSocket (Lio/ktor/server/routing/Routing;Ljava/lang/String;Lkotlin/jvm/functions/Function0;)V + public static final fun mcpWebSocket (Lio/ktor/server/routing/Routing;Lkotlin/jvm/functions/Function0;)V + public static synthetic fun mcpWebSocket$default (Lio/ktor/server/routing/Route;Lio/modelcontextprotocol/kotlin/sdk/server/ServerOptions;Lkotlin/jvm/functions/Function2;ILjava/lang/Object;)V + public static synthetic fun mcpWebSocket$default (Lio/ktor/server/routing/Route;Ljava/lang/String;Lio/modelcontextprotocol/kotlin/sdk/server/ServerOptions;Lkotlin/jvm/functions/Function2;ILjava/lang/Object;)V + public static final fun mcpWebSocketTransport (Lio/ktor/server/routing/Route;Ljava/lang/String;Lkotlin/jvm/functions/Function2;)V + public static final fun mcpWebSocketTransport (Lio/ktor/server/routing/Route;Lkotlin/jvm/functions/Function2;)V + public static synthetic fun mcpWebSocketTransport$default (Lio/ktor/server/routing/Route;Ljava/lang/String;Lkotlin/jvm/functions/Function2;ILjava/lang/Object;)V + public static synthetic fun mcpWebSocketTransport$default (Lio/ktor/server/routing/Route;Lkotlin/jvm/functions/Function2;ILjava/lang/Object;)V +} + +public final class io/modelcontextprotocol/kotlin/sdk/server/WebSocketMcpServerTransport : io/modelcontextprotocol/kotlin/sdk/shared/WebSocketMcpTransport { + public fun (Lio/ktor/server/websocket/WebSocketServerSession;)V + public synthetic fun getSession ()Lio/ktor/websocket/WebSocketSession; } public abstract class io/modelcontextprotocol/kotlin/sdk/shared/AbstractTransport : io/modelcontextprotocol/kotlin/sdk/shared/Transport { From f9f1a1554fb058929c0f9b4506a7d6b63826bcf9 Mon Sep 17 00:00:00 2001 From: Maria Tigina Date: Mon, 4 Aug 2025 17:58:36 +0200 Subject: [PATCH 5/6] Changes after rebase --- .../kotlin/sdk/server/ServerSession.kt | 343 ++++++++++++++++++ .../kotlin/sdk/client/SseTransportTest.kt | 4 +- .../kotlin/sdk/server/ServerSession.kt | 343 ------------------ 3 files changed, 346 insertions(+), 344 deletions(-) delete mode 100644 src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/ServerSession.kt diff --git a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/ServerSession.kt b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/ServerSession.kt index e69de29b..5952bf25 100644 --- a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/ServerSession.kt +++ b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/ServerSession.kt @@ -0,0 +1,343 @@ +package io.modelcontextprotocol.kotlin.sdk.server + +import io.github.oshai.kotlinlogging.KotlinLogging +import io.modelcontextprotocol.kotlin.sdk.ClientCapabilities +import io.modelcontextprotocol.kotlin.sdk.CreateElicitationRequest +import io.modelcontextprotocol.kotlin.sdk.CreateElicitationRequest.RequestedSchema +import io.modelcontextprotocol.kotlin.sdk.CreateElicitationResult +import io.modelcontextprotocol.kotlin.sdk.CreateMessageRequest +import io.modelcontextprotocol.kotlin.sdk.CreateMessageResult +import io.modelcontextprotocol.kotlin.sdk.EmptyJsonObject +import io.modelcontextprotocol.kotlin.sdk.EmptyRequestResult +import io.modelcontextprotocol.kotlin.sdk.Implementation +import io.modelcontextprotocol.kotlin.sdk.InitializeRequest +import io.modelcontextprotocol.kotlin.sdk.InitializeResult +import io.modelcontextprotocol.kotlin.sdk.InitializedNotification +import io.modelcontextprotocol.kotlin.sdk.LATEST_PROTOCOL_VERSION +import io.modelcontextprotocol.kotlin.sdk.ListRootsRequest +import io.modelcontextprotocol.kotlin.sdk.ListRootsResult +import io.modelcontextprotocol.kotlin.sdk.LoggingMessageNotification +import io.modelcontextprotocol.kotlin.sdk.Method +import io.modelcontextprotocol.kotlin.sdk.PingRequest +import io.modelcontextprotocol.kotlin.sdk.PromptListChangedNotification +import io.modelcontextprotocol.kotlin.sdk.ResourceListChangedNotification +import io.modelcontextprotocol.kotlin.sdk.ResourceUpdatedNotification +import io.modelcontextprotocol.kotlin.sdk.SUPPORTED_PROTOCOL_VERSIONS +import io.modelcontextprotocol.kotlin.sdk.ToolListChangedNotification +import io.modelcontextprotocol.kotlin.sdk.shared.Protocol +import io.modelcontextprotocol.kotlin.sdk.shared.RequestOptions +import kotlinx.coroutines.CompletableDeferred +import kotlinx.serialization.json.JsonObject + +private val logger = KotlinLogging.logger {} + +public open class ServerSession( + private val serverInfo: Implementation, + options: ServerOptions, +) : Protocol(options) { + private var _onInitialized: (() -> Unit) = {} + private var _onClose: () -> Unit = {} + + init { + // Core protocol handlers + setRequestHandler(Method.Defined.Initialize) { request, _ -> + handleInitialize(request) + } + setNotificationHandler(Method.Defined.NotificationsInitialized) { + _onInitialized() + CompletableDeferred(Unit) + } + } + + /** + * The capabilities supported by the server, related to the session. + */ + private val serverCapabilities = options.capabilities + + /** + * The client's reported capabilities after initialization. + */ + public var clientCapabilities: ClientCapabilities? = null + private set + + /** + * The client's version information after initialization. + */ + public var clientVersion: Implementation? = null + private set + + /** + * Registers a callback to be invoked when the server has completed initialization. + */ + public fun onInitialized(block: () -> Unit) { + val old = _onInitialized + _onInitialized = { + old() + block() + } + } + + /** + * Registers a callback to be invoked when the server session is closing. + */ + public fun onClose(block: () -> Unit) { + val old = _onClose + _onClose = { + old() + block() + } + } + + /** + * Called when the server session is closing. + */ + override fun onClose() { + logger.debug { "Server connection closing" } + _onClose() + } + + /** + * Sends a ping request to the client to check connectivity. + * + * @return The result of the ping request. + * @throws IllegalStateException If for some reason the method is not supported or the connection is closed. + */ + public suspend fun ping(): EmptyRequestResult { + return request(PingRequest()) + } + + /** + * Creates a message using the server's sampling capability. + * + * @param params The parameters for creating a message. + * @param options Optional request options. + * @return The created message result. + * @throws IllegalStateException If the server does not support sampling or if the request fails. + */ + public suspend fun createMessage( + params: CreateMessageRequest, + options: RequestOptions? = null + ): CreateMessageResult { + logger.debug { "Creating message with params: $params" } + return request(params, options) + } + + /** + * Lists the available "roots" from the client's perspective (if supported). + * + * @param params JSON parameters for the request, usually empty. + * @param options Optional request options. + * @return The list of roots. + * @throws IllegalStateException If the server or client does not support roots. + */ + public suspend fun listRoots( + params: JsonObject = EmptyJsonObject, + options: RequestOptions? = null + ): ListRootsResult { + logger.debug { "Listing roots with params: $params" } + return request(ListRootsRequest(params), options) + } + + public suspend fun createElicitation( + message: String, + requestedSchema: RequestedSchema, + options: RequestOptions? = null + ): CreateElicitationResult { + logger.debug { "Creating elicitation with message: $message" } + return request(CreateElicitationRequest(message, requestedSchema), options) + } + + /** + * Sends a logging message notification to the client. + * + * @param notification The logging message notification. + */ + public suspend fun sendLoggingMessage(notification: LoggingMessageNotification) { + logger.trace { "Sending logging message: ${notification.params.data}" } + notification(notification) + } + + /** + * Sends a resource-updated notification to the client, indicating that a specific resource has changed. + * + * @param notification Details of the updated resource. + */ + public suspend fun sendResourceUpdated(notification: ResourceUpdatedNotification) { + logger.debug { "Sending resource updated notification for: ${notification.params.uri}" } + notification(notification) + } + + /** + * Sends a notification to the client indicating that the list of resources has changed. + */ + public suspend fun sendResourceListChanged() { + logger.debug { "Sending resource list changed notification" } + notification(ResourceListChangedNotification()) + } + + /** + * Sends a notification to the client indicating that the list of tools has changed. + */ + public suspend fun sendToolListChanged() { + logger.debug { "Sending tool list changed notification" } + notification(ToolListChangedNotification()) + } + + /** + * Sends a notification to the client indicating that the list of prompts has changed. + */ + public suspend fun sendPromptListChanged() { + logger.debug { "Sending prompt list changed notification" } + notification(PromptListChangedNotification()) + } + + /** + * Asserts that the client supports the capability required for the given [method]. + * + * This method is automatically called by the [Protocol] framework before handling requests. + * Throws [IllegalStateException] if the capability is not supported. + * + * @param method The method for which we are asserting capability. + */ + override fun assertCapabilityForMethod(method: Method) { + logger.trace { "Asserting capability for method: ${method.value}" } + when (method.value) { + "sampling/createMessage" -> { + if (clientCapabilities?.sampling == null) { + logger.error { "Client capability assertion failed: sampling not supported" } + throw IllegalStateException("Client does not support sampling (required for ${method.value})") + } + } + + "roots/list" -> { + if (clientCapabilities?.roots == null) { + throw IllegalStateException("Client does not support listing roots (required for ${method.value})") + } + } + + "elicitation/create" -> { + if (clientCapabilities?.elicitation == null) { + throw IllegalStateException("Client does not support elicitation (required for ${method.value})") + } + } + + "ping" -> { + // No specific capability required + } + } + } + + /** + * Asserts that the server can handle the specified notification method. + * + * Throws [IllegalStateException] if the server does not have the capabilities required to handle this notification. + * + * @param method The notification method. + */ + override fun assertNotificationCapability(method: Method) { + logger.trace { "Asserting notification capability for method: ${method.value}" } + when (method.value) { + "notifications/message" -> { + if (serverCapabilities.logging == null) { + logger.error { "Server capability assertion failed: logging not supported" } + throw IllegalStateException("Server does not support logging (required for ${method.value})") + } + } + + "notifications/resources/updated", + "notifications/resources/list_changed" -> { + if (serverCapabilities.resources == null) { + throw IllegalStateException("Server does not support notifying about resources (required for ${method.value})") + } + } + + "notifications/tools/list_changed" -> { + if (serverCapabilities.tools == null) { + throw IllegalStateException("Server does not support notifying of tool list changes (required for ${method.value})") + } + } + + "notifications/prompts/list_changed" -> { + if (serverCapabilities.prompts == null) { + throw IllegalStateException("Server does not support notifying of prompt list changes (required for ${method.value})") + } + } + + "notifications/cancelled", + "notifications/progress" -> { + // Always allowed + } + } + } + + /** + * Asserts that the server can handle the specified request method. + * + * Throws [IllegalStateException] if the server does not have the capabilities required to handle this request. + * + * @param method The request method. + */ + override fun assertRequestHandlerCapability(method: Method) { + logger.trace { "Asserting request handler capability for method: ${method.value}" } + when (method.value) { + "sampling/createMessage" -> { + if (serverCapabilities.sampling == null) { + logger.error { "Server capability assertion failed: sampling not supported" } + throw IllegalStateException("Server does not support sampling (required for $method)") + } + } + + "logging/setLevel" -> { + if (serverCapabilities.logging == null) { + throw IllegalStateException("Server does not support logging (required for $method)") + } + } + + "prompts/get", + "prompts/list" -> { + if (serverCapabilities.prompts == null) { + throw IllegalStateException("Server does not support prompts (required for $method)") + } + } + + "resources/list", + "resources/templates/list", + "resources/read" -> { + if (serverCapabilities.resources == null) { + throw IllegalStateException("Server does not support resources (required for $method)") + } + } + + "tools/call", + "tools/list" -> { + if (serverCapabilities.tools == null) { + throw IllegalStateException("Server does not support tools (required for $method)") + } + } + + "ping", "initialize" -> { + // No capability required + } + } + } + + private suspend fun handleInitialize(request: InitializeRequest): InitializeResult { + logger.debug { "Handling initialization request from client" } + clientCapabilities = request.capabilities + clientVersion = request.clientInfo + + val requestedVersion = request.protocolVersion + val protocolVersion = if (SUPPORTED_PROTOCOL_VERSIONS.contains(requestedVersion)) { + requestedVersion + } else { + logger.warn { "Client requested unsupported protocol version $requestedVersion, falling back to $LATEST_PROTOCOL_VERSION" } + LATEST_PROTOCOL_VERSION + } + + return InitializeResult( + protocolVersion = protocolVersion, + capabilities = serverCapabilities, + serverInfo = serverInfo + ) + } +} diff --git a/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/SseTransportTest.kt b/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/SseTransportTest.kt index ce110b3a..8a49b29b 100644 --- a/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/SseTransportTest.kt +++ b/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/SseTransportTest.kt @@ -5,6 +5,8 @@ import io.ktor.server.application.install import io.ktor.server.cio.CIO import io.ktor.server.engine.EmbeddedServer import io.ktor.server.engine.embeddedServer +import io.ktor.server.routing.post +import io.ktor.server.routing.route import io.ktor.server.routing.routing import io.modelcontextprotocol.kotlin.sdk.Implementation import io.modelcontextprotocol.kotlin.sdk.ServerCapabilities @@ -110,7 +112,7 @@ class SseTransportTest : BaseTransportTest() { val server = embeddedServer(CIO, port = 0) { install(ServerSSE) routing { - mcp("/sse") { mcpServer } + mcp { mcpServer } // route("/sse") { // sse { // mcpSseTransport("", transportManager).apply { diff --git a/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/ServerSession.kt b/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/ServerSession.kt deleted file mode 100644 index 2c430991..00000000 --- a/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/ServerSession.kt +++ /dev/null @@ -1,343 +0,0 @@ -package io.modelcontextprotocol.kotlin.sdk.server - -import io.github.oshai.kotlinlogging.KotlinLogging -import io.modelcontextprotocol.kotlin.sdk.ClientCapabilities -import io.modelcontextprotocol.kotlin.sdk.CreateElicitationRequest -import io.modelcontextprotocol.kotlin.sdk.CreateElicitationRequest.RequestedSchema -import io.modelcontextprotocol.kotlin.sdk.CreateElicitationResult -import io.modelcontextprotocol.kotlin.sdk.CreateMessageRequest -import io.modelcontextprotocol.kotlin.sdk.CreateMessageResult -import io.modelcontextprotocol.kotlin.sdk.EmptyJsonObject -import io.modelcontextprotocol.kotlin.sdk.EmptyRequestResult -import io.modelcontextprotocol.kotlin.sdk.Implementation -import io.modelcontextprotocol.kotlin.sdk.InitializeRequest -import io.modelcontextprotocol.kotlin.sdk.InitializeResult -import io.modelcontextprotocol.kotlin.sdk.InitializedNotification -import io.modelcontextprotocol.kotlin.sdk.LATEST_PROTOCOL_VERSION -import io.modelcontextprotocol.kotlin.sdk.ListRootsRequest -import io.modelcontextprotocol.kotlin.sdk.ListRootsResult -import io.modelcontextprotocol.kotlin.sdk.LoggingMessageNotification -import io.modelcontextprotocol.kotlin.sdk.Method -import io.modelcontextprotocol.kotlin.sdk.PingRequest -import io.modelcontextprotocol.kotlin.sdk.PromptListChangedNotification -import io.modelcontextprotocol.kotlin.sdk.ResourceListChangedNotification -import io.modelcontextprotocol.kotlin.sdk.ResourceUpdatedNotification -import io.modelcontextprotocol.kotlin.sdk.SUPPORTED_PROTOCOL_VERSIONS -import io.modelcontextprotocol.kotlin.sdk.ToolListChangedNotification -import io.modelcontextprotocol.kotlin.sdk.shared.Protocol -import io.modelcontextprotocol.kotlin.sdk.shared.RequestOptions -import kotlinx.coroutines.CompletableDeferred -import kotlinx.serialization.json.JsonObject - -private val logger = KotlinLogging.logger {} - -public open class ServerSession( - private val serverInfo: Implementation, - options: ServerOptions, -) : Protocol(options) { - private var _onInitialized: (() -> Unit) = {} - private var _onClose: () -> Unit = {} - - init { - // Core protocol handlers - setRequestHandler(Method.Defined.Initialize) { request, _ -> - handleInitialize(request) - } - setNotificationHandler(Method.Defined.NotificationsInitialized) { - _onInitialized() - CompletableDeferred(Unit) - } - } - - /** - * The capabilities supported by the server, related to the session. - */ - private val serverCapabilities = options.capabilities - - /** - * The client's reported capabilities after initialization. - */ - public var clientCapabilities: ClientCapabilities? = null - private set - - /** - * The client's version information after initialization. - */ - public var clientVersion: Implementation? = null - private set - - /** - * Registers a callback to be invoked when the server has completed initialization. - */ - public fun onInitialized(block: () -> Unit) { - val old = _onInitialized - _onInitialized = { - old() - block() - } - } - - /** - * Registers a callback to be invoked when the server session is closing. - */ - public fun onClose(block: () -> Unit) { - val old = _onClose - _onClose = { - old() - block() - } - } - - /** - * Called when the server session is closing. - */ - override fun onClose() { - logger.debug { "Server connection closing" } - _onClose() - } - - /** - * Sends a ping request to the client to check connectivity. - * - * @return The result of the ping request. - * @throws IllegalStateException If for some reason the method is not supported or the connection is closed. - */ - public suspend fun ping(): EmptyRequestResult { - return request(PingRequest()) - } - - /** - * Creates a message using the server's sampling capability. - * - * @param params The parameters for creating a message. - * @param options Optional request options. - * @return The created message result. - * @throws IllegalStateException If the server does not support sampling or if the request fails. - */ - public suspend fun createMessage( - params: CreateMessageRequest, - options: RequestOptions? = null - ): CreateMessageResult { - logger.debug { "Creating message with params: $params" } - return request(params, options) - } - - /** - * Lists the available "roots" from the client's perspective (if supported). - * - * @param params JSON parameters for the request, usually empty. - * @param options Optional request options. - * @return The list of roots. - * @throws IllegalStateException If the server or client does not support roots. - */ - public suspend fun listRoots( - params: JsonObject = EmptyJsonObject, - options: RequestOptions? = null - ): ListRootsResult { - logger.debug { "Listing roots with params: $params" } - return request(ListRootsRequest(params), options) - } - - public suspend fun createElicitation( - message: String, - requestedSchema: RequestedSchema, - options: RequestOptions? = null - ): CreateElicitationResult { - logger.debug { "Creating elicitation with message: $message" } - return request(CreateElicitationRequest(message, requestedSchema), options) - } - - /** - * Sends a logging message notification to the client. - * - * @param params The logging message notification parameters. - */ - public suspend fun sendLoggingMessage(params: LoggingMessageNotification) { - logger.trace { "Sending logging message: ${params.data}" } - notification(params) - } - - /** - * Sends a resource-updated notification to the client, indicating that a specific resource has changed. - * - * @param params Details of the updated resource. - */ - public suspend fun sendResourceUpdated(params: ResourceUpdatedNotification) { - logger.debug { "Sending resource updated notification for: ${params.uri}" } - notification(params) - } - - /** - * Sends a notification to the client indicating that the list of resources has changed. - */ - public suspend fun sendResourceListChanged() { - logger.debug { "Sending resource list changed notification" } - notification(ResourceListChangedNotification()) - } - - /** - * Sends a notification to the client indicating that the list of tools has changed. - */ - public suspend fun sendToolListChanged() { - logger.debug { "Sending tool list changed notification" } - notification(ToolListChangedNotification()) - } - - /** - * Sends a notification to the client indicating that the list of prompts has changed. - */ - public suspend fun sendPromptListChanged() { - logger.debug { "Sending prompt list changed notification" } - notification(PromptListChangedNotification()) - } - - /** - * Asserts that the client supports the capability required for the given [method]. - * - * This method is automatically called by the [Protocol] framework before handling requests. - * Throws [IllegalStateException] if the capability is not supported. - * - * @param method The method for which we are asserting capability. - */ - override fun assertCapabilityForMethod(method: Method) { - logger.trace { "Asserting capability for method: ${method.value}" } - when (method.value) { - "sampling/createMessage" -> { - if (clientCapabilities?.sampling == null) { - logger.error { "Client capability assertion failed: sampling not supported" } - throw IllegalStateException("Client does not support sampling (required for ${method.value})") - } - } - - "roots/list" -> { - if (clientCapabilities?.roots == null) { - throw IllegalStateException("Client does not support listing roots (required for ${method.value})") - } - } - - "elicitation/create" -> { - if (clientCapabilities?.elicitation == null) { - throw IllegalStateException("Client does not support elicitation (required for ${method.value})") - } - } - - "ping" -> { - // No specific capability required - } - } - } - - /** - * Asserts that the server can handle the specified notification method. - * - * Throws [IllegalStateException] if the server does not have the capabilities required to handle this notification. - * - * @param method The notification method. - */ - override fun assertNotificationCapability(method: Method) { - logger.trace { "Asserting notification capability for method: ${method.value}" } - when (method.value) { - "notifications/message" -> { - if (serverCapabilities.logging == null) { - logger.error { "Server capability assertion failed: logging not supported" } - throw IllegalStateException("Server does not support logging (required for ${method.value})") - } - } - - "notifications/resources/updated", - "notifications/resources/list_changed" -> { - if (serverCapabilities.resources == null) { - throw IllegalStateException("Server does not support notifying about resources (required for ${method.value})") - } - } - - "notifications/tools/list_changed" -> { - if (serverCapabilities.tools == null) { - throw IllegalStateException("Server does not support notifying of tool list changes (required for ${method.value})") - } - } - - "notifications/prompts/list_changed" -> { - if (serverCapabilities.prompts == null) { - throw IllegalStateException("Server does not support notifying of prompt list changes (required for ${method.value})") - } - } - - "notifications/cancelled", - "notifications/progress" -> { - // Always allowed - } - } - } - - /** - * Asserts that the server can handle the specified request method. - * - * Throws [IllegalStateException] if the server does not have the capabilities required to handle this request. - * - * @param method The request method. - */ - override fun assertRequestHandlerCapability(method: Method) { - logger.trace { "Asserting request handler capability for method: ${method.value}" } - when (method.value) { - "sampling/createMessage" -> { - if (serverCapabilities.sampling == null) { - logger.error { "Server capability assertion failed: sampling not supported" } - throw IllegalStateException("Server does not support sampling (required for $method)") - } - } - - "logging/setLevel" -> { - if (serverCapabilities.logging == null) { - throw IllegalStateException("Server does not support logging (required for $method)") - } - } - - "prompts/get", - "prompts/list" -> { - if (serverCapabilities.prompts == null) { - throw IllegalStateException("Server does not support prompts (required for $method)") - } - } - - "resources/list", - "resources/templates/list", - "resources/read" -> { - if (serverCapabilities.resources == null) { - throw IllegalStateException("Server does not support resources (required for $method)") - } - } - - "tools/call", - "tools/list" -> { - if (serverCapabilities.tools == null) { - throw IllegalStateException("Server does not support tools (required for $method)") - } - } - - "ping", "initialize" -> { - // No capability required - } - } - } - - private suspend fun handleInitialize(request: InitializeRequest): InitializeResult { - logger.debug { "Handling initialization request from client" } - clientCapabilities = request.capabilities - clientVersion = request.clientInfo - - val requestedVersion = request.protocolVersion - val protocolVersion = if (SUPPORTED_PROTOCOL_VERSIONS.contains(requestedVersion)) { - requestedVersion - } else { - logger.warn { "Client requested unsupported protocol version $requestedVersion, falling back to $LATEST_PROTOCOL_VERSION" } - LATEST_PROTOCOL_VERSION - } - - return InitializeResult( - protocolVersion = protocolVersion, - capabilities = serverCapabilities, - serverInfo = serverInfo - ) - } -} From fcf25b6622f140ae018eb66d26e5773257a76bec Mon Sep 17 00:00:00 2001 From: Maria Tigina Date: Tue, 5 Aug 2025 14:32:43 +0200 Subject: [PATCH 6/6] Remove Protocol from Server --- .../kotlin/sdk/client/Client.kt | 8 +- kotlin-sdk-core/api/kotlin-sdk-core.api | 5 - .../kotlin/sdk/server/KtorServer.kt | 8 +- .../kotlin/sdk/server/Server.kt | 485 +----------------- .../WebSocketMcpKtorServerExtensions.kt | 2 +- .../kotlin/sdk/client/ClientTest.kt | 264 ++++++---- .../kotlin/sdk/client/SseTransportTest.kt | 30 -- 7 files changed, 198 insertions(+), 604 deletions(-) diff --git a/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/Client.kt b/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/Client.kt index a836858d..6631f5e9 100644 --- a/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/Client.kt +++ b/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/Client.kt @@ -158,12 +158,14 @@ public open class Client( serverVersion = result.serverInfo notification(InitializedNotification()) - } catch (error: CancellationException) { - throw IllegalStateException("Error connecting to transport: ${error.message}") } catch (error: Throwable) { - logger.error(error) { "Failed to initialize client" } + logger.error(error) { "Failed to initialize client: ${error.message}" } close() + if (error !is CancellationException) { + throw IllegalStateException("Error connecting to transport: ${error.message}") + } + throw error } } diff --git a/kotlin-sdk-core/api/kotlin-sdk-core.api b/kotlin-sdk-core/api/kotlin-sdk-core.api index 643e7c82..b753a5cd 100644 --- a/kotlin-sdk-core/api/kotlin-sdk-core.api +++ b/kotlin-sdk-core/api/kotlin-sdk-core.api @@ -3258,11 +3258,6 @@ public final class io/modelcontextprotocol/kotlin/sdk/client/KtorClientKt { public static synthetic fun mcpSseTransport-5_5nbZA$default (Lio/ktor/client/HttpClient;Ljava/lang/String;Lkotlin/time/Duration;Lkotlin/jvm/functions/Function1;ILjava/lang/Object;)Lio/modelcontextprotocol/kotlin/sdk/client/SseClientTransport; } -public final class io/modelcontextprotocol/kotlin/sdk/client/MainKt { - public static final fun main ()V - public static synthetic fun main ([Ljava/lang/String;)V -} - public final class io/modelcontextprotocol/kotlin/sdk/client/SseClientTransport : io/modelcontextprotocol/kotlin/sdk/shared/AbstractTransport { public synthetic fun (Lio/ktor/client/HttpClient;Ljava/lang/String;Lkotlin/time/Duration;Lkotlin/jvm/functions/Function1;ILkotlin/jvm/internal/DefaultConstructorMarker;)V public synthetic fun (Lio/ktor/client/HttpClient;Ljava/lang/String;Lkotlin/time/Duration;Lkotlin/jvm/functions/Function1;Lkotlin/jvm/internal/DefaultConstructorMarker;)V diff --git a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/KtorServer.kt b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/KtorServer.kt index 729636a8..88bdd94d 100644 --- a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/KtorServer.kt +++ b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/KtorServer.kt @@ -51,6 +51,12 @@ public fun Routing.mcp(block: () -> Server) { } } +@Suppress("FunctionName") +@Deprecated("Use mcp() instead", ReplaceWith("mcp(block)"), DeprecationLevel.ERROR) +public fun Application.MCP(block: () -> Server) { + mcp(block) +} + @KtorDsl public fun Application.mcp(block: () -> Server) { install(SSE) @@ -74,7 +80,7 @@ internal suspend fun ServerSSESession.mcpSseEndpoint( sseTransportManager.removeTransport(transport.sessionId) } - server.connectSession(transport) + server.connect(transport) logger.debug { "Server connected to transport for sessionId: ${transport.sessionId}" } } diff --git a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/Server.kt b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/Server.kt index 09684a12..03277347 100644 --- a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/Server.kt +++ b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/Server.kt @@ -3,60 +3,35 @@ package io.modelcontextprotocol.kotlin.sdk.server import io.github.oshai.kotlinlogging.KotlinLogging import io.modelcontextprotocol.kotlin.sdk.CallToolRequest import io.modelcontextprotocol.kotlin.sdk.CallToolResult -import io.modelcontextprotocol.kotlin.sdk.ClientCapabilities -import io.modelcontextprotocol.kotlin.sdk.CreateElicitationRequest -import io.modelcontextprotocol.kotlin.sdk.CreateElicitationRequest.RequestedSchema -import io.modelcontextprotocol.kotlin.sdk.CreateElicitationResult -import io.modelcontextprotocol.kotlin.sdk.CreateMessageRequest -import io.modelcontextprotocol.kotlin.sdk.CreateMessageResult -import io.modelcontextprotocol.kotlin.sdk.EmptyJsonObject -import io.modelcontextprotocol.kotlin.sdk.EmptyRequestResult import io.modelcontextprotocol.kotlin.sdk.GetPromptRequest import io.modelcontextprotocol.kotlin.sdk.GetPromptResult import io.modelcontextprotocol.kotlin.sdk.Implementation -import io.modelcontextprotocol.kotlin.sdk.InitializeRequest -import io.modelcontextprotocol.kotlin.sdk.InitializeResult -import io.modelcontextprotocol.kotlin.sdk.InitializedNotification -import io.modelcontextprotocol.kotlin.sdk.LATEST_PROTOCOL_VERSION import io.modelcontextprotocol.kotlin.sdk.ListPromptsRequest import io.modelcontextprotocol.kotlin.sdk.ListPromptsResult import io.modelcontextprotocol.kotlin.sdk.ListResourceTemplatesRequest import io.modelcontextprotocol.kotlin.sdk.ListResourceTemplatesResult import io.modelcontextprotocol.kotlin.sdk.ListResourcesRequest import io.modelcontextprotocol.kotlin.sdk.ListResourcesResult -import io.modelcontextprotocol.kotlin.sdk.ListRootsRequest -import io.modelcontextprotocol.kotlin.sdk.ListRootsResult import io.modelcontextprotocol.kotlin.sdk.ListToolsRequest import io.modelcontextprotocol.kotlin.sdk.ListToolsResult -import io.modelcontextprotocol.kotlin.sdk.LoggingMessageNotification import io.modelcontextprotocol.kotlin.sdk.Method -import io.modelcontextprotocol.kotlin.sdk.PingRequest import io.modelcontextprotocol.kotlin.sdk.Prompt import io.modelcontextprotocol.kotlin.sdk.PromptArgument -import io.modelcontextprotocol.kotlin.sdk.PromptListChangedNotification import io.modelcontextprotocol.kotlin.sdk.ReadResourceRequest import io.modelcontextprotocol.kotlin.sdk.ReadResourceResult import io.modelcontextprotocol.kotlin.sdk.Resource -import io.modelcontextprotocol.kotlin.sdk.ResourceListChangedNotification -import io.modelcontextprotocol.kotlin.sdk.ResourceUpdatedNotification -import io.modelcontextprotocol.kotlin.sdk.SUPPORTED_PROTOCOL_VERSIONS import io.modelcontextprotocol.kotlin.sdk.ServerCapabilities import io.modelcontextprotocol.kotlin.sdk.Tool import io.modelcontextprotocol.kotlin.sdk.ToolAnnotations -import io.modelcontextprotocol.kotlin.sdk.ToolListChangedNotification -import io.modelcontextprotocol.kotlin.sdk.shared.Protocol import io.modelcontextprotocol.kotlin.sdk.shared.ProtocolOptions -import io.modelcontextprotocol.kotlin.sdk.shared.RequestOptions import io.modelcontextprotocol.kotlin.sdk.shared.Transport import kotlinx.atomicfu.atomic import kotlinx.atomicfu.getAndUpdate import kotlinx.atomicfu.update import kotlinx.collections.immutable.minus +import kotlinx.collections.immutable.persistentListOf import kotlinx.collections.immutable.persistentMapOf -import kotlinx.collections.immutable.persistentSetOf import kotlinx.collections.immutable.toPersistentSet -import kotlinx.coroutines.CompletableDeferred -import kotlinx.serialization.json.JsonObject private val logger = KotlinLogging.logger {} @@ -83,37 +58,14 @@ public class ServerOptions( */ public open class Server( private val serverInfo: Implementation, - options: ServerOptions, -) : Protocol(options) { - private val sessions = atomic(persistentSetOf()) - private val serverOptions = options + private val options: ServerOptions, +) { + private val sessions = atomic(persistentListOf()) private var _onInitialized: (() -> Unit) = {} private var _onConnect: (() -> Unit) = {} private var _onClose: () -> Unit = {} - /** - * The client's reported capabilities after initialization. - */ - @Deprecated( - "Moved to ServerSession", - ReplaceWith("ServerSession.clientCapabilities"), - DeprecationLevel.WARNING - ) - public var clientCapabilities: ClientCapabilities? = null - private set - - /** - * The client's version information after initialization. - */ - @Deprecated( - "Moved to ServerSession", - ReplaceWith("ServerSession.clientVersion"), - DeprecationLevel.WARNING - ) - public var clientVersion: Implementation? = null - private set - private val _tools = atomic(persistentMapOf()) private val _prompts = atomic(persistentMapOf()) private val _resources = atomic(persistentMapOf()) @@ -124,65 +76,7 @@ public open class Server( public val resources: Map get() = _resources.value - init { - logger.debug { "Initializing MCP server with option: $options" } - - // TODO: Remove all after Protocol inheritance - // Core protocol handlers - setRequestHandler(Method.Defined.Initialize) { request, _ -> - handleInitialize(request) - } - setNotificationHandler(Method.Defined.NotificationsInitialized) { - _onInitialized() - CompletableDeferred(Unit) - } - - // Internal handlers for tools - if (serverOptions.capabilities.tools != null) { - setRequestHandler(Method.Defined.ToolsList) { _, _ -> - handleListTools() - } - setRequestHandler(Method.Defined.ToolsCall) { request, _ -> - handleCallTool(request) - } - } - - // Internal handlers for prompts - if (serverOptions.capabilities.prompts != null) { - setRequestHandler(Method.Defined.PromptsList) { _, _ -> - handleListPrompts() - } - setRequestHandler(Method.Defined.PromptsGet) { request, _ -> - handleGetPrompt(request) - } - } - - // Internal handlers for resources - if (serverOptions.capabilities.resources != null) { - setRequestHandler(Method.Defined.ResourcesList) { _, _ -> - handleListResources() - } - setRequestHandler(Method.Defined.ResourcesRead) { request, _ -> - handleReadResource(request) - } - setRequestHandler(Method.Defined.ResourcesTemplatesList) { _, _ -> - handleListResourceTemplates() - } - } - } - - @Deprecated( - "Will be removed with Protocol inheritance. Use connectSession instead.", - ReplaceWith("connectSession"), - DeprecationLevel.WARNING - ) - public override fun onClose() { - logger.debug { "Closing MCP server" } - _onClose() - } - - // TODO: Rename closeSessions to close after the full onClose deprecation - public suspend fun closeSessions() { + public suspend fun close() { logger.debug { "Closing MCP server" } sessions.value.forEach { it.close() } _onClose() @@ -195,12 +89,11 @@ public open class Server( * @param transport The transport layer to connect the session with. * @return The initialized and connected server session. */ - // TODO: Rename connectSession to connect after the full connect deprecation - public suspend fun connectSession(transport: Transport): ServerSession { - val session = ServerSession(serverInfo, serverOptions) + public suspend fun connect(transport: Transport): ServerSession { + val session = ServerSession(serverInfo, options) // Internal handlers for tools - if (serverOptions.capabilities.tools != null) { + if (options.capabilities.tools != null) { session.setRequestHandler(Method.Defined.ToolsList) { _, _ -> handleListTools() } @@ -210,7 +103,7 @@ public open class Server( } // Internal handlers for prompts - if (serverOptions.capabilities.prompts != null) { + if (options.capabilities.prompts != null) { session.setRequestHandler(Method.Defined.PromptsList) { _, _ -> handleListPrompts() } @@ -220,7 +113,7 @@ public open class Server( } // Internal handlers for resources - if (serverOptions.capabilities.resources != null) { + if (options.capabilities.resources != null) { session.setRequestHandler(Method.Defined.ResourcesList) { _, _ -> handleListResources() } @@ -241,15 +134,6 @@ public open class Server( return session } - @Deprecated( - "Will be removed with Protocol inheritance. Use connectSession instead.", - ReplaceWith("connectSession"), - DeprecationLevel.WARNING - ) - public override suspend fun connect(transport: Transport) { - super.connect(transport) - } - /** * Registers a callback to be invoked when the new server session connected. */ @@ -296,7 +180,7 @@ public open class Server( * @throws IllegalStateException If the server does not support tools. */ public fun addTool(tool: Tool, handler: suspend (CallToolRequest) -> CallToolResult) { - if (serverOptions.capabilities.tools == null) { + if (options.capabilities.tools == null) { logger.error { "Failed to add tool '${tool.name}': Server does not support tools capability" } throw IllegalStateException("Server does not support tools capability. Enable it in ServerOptions.") } @@ -336,7 +220,7 @@ public open class Server( * @throws IllegalStateException If the server does not support tools. */ public fun addTools(toolsToAdd: List) { - if (serverOptions.capabilities.tools == null) { + if (options.capabilities.tools == null) { logger.error { "Failed to add tools: Server does not support tools capability" } throw IllegalStateException("Server does not support tools capability.") } @@ -352,7 +236,7 @@ public open class Server( * @throws IllegalStateException If the server does not support tools. */ public fun removeTool(name: String): Boolean { - if (serverOptions.capabilities.tools == null) { + if (options.capabilities.tools == null) { logger.error { "Failed to remove tool '$name': Server does not support tools capability" } throw IllegalStateException("Server does not support tools capability.") } @@ -379,7 +263,7 @@ public open class Server( * @throws IllegalStateException If the server does not support tools. */ public fun removeTools(toolNames: List): Int { - if (serverOptions.capabilities.tools == null) { + if (options.capabilities.tools == null) { logger.error { "Failed to remove tools: Server does not support tools capability" } throw IllegalStateException("Server does not support tools capability.") } @@ -406,7 +290,7 @@ public open class Server( * @throws IllegalStateException If the server does not support prompts. */ public fun addPrompt(prompt: Prompt, promptProvider: suspend (GetPromptRequest) -> GetPromptResult) { - if (serverOptions.capabilities.prompts == null) { + if (options.capabilities.prompts == null) { logger.error { "Failed to add prompt '${prompt.name}': Server does not support prompts capability" } throw IllegalStateException("Server does not support prompts capability.") } @@ -440,7 +324,7 @@ public open class Server( * @throws IllegalStateException If the server does not support prompts. */ public fun addPrompts(promptsToAdd: List) { - if (serverOptions.capabilities.prompts == null) { + if (options.capabilities.prompts == null) { logger.error { "Failed to add prompts: Server does not support prompts capability" } throw IllegalStateException("Server does not support prompts capability.") } @@ -456,7 +340,7 @@ public open class Server( * @throws IllegalStateException If the server does not support prompts. */ public fun removePrompt(name: String): Boolean { - if (serverOptions.capabilities.prompts == null) { + if (options.capabilities.prompts == null) { logger.error { "Failed to remove prompt '$name': Server does not support prompts capability" } throw IllegalStateException("Server does not support prompts capability.") } @@ -483,7 +367,7 @@ public open class Server( * @throws IllegalStateException If the server does not support prompts. */ public fun removePrompts(promptNames: List): Int { - if (serverOptions.capabilities.prompts == null) { + if (options.capabilities.prompts == null) { logger.error { "Failed to remove prompts: Server does not support prompts capability" } throw IllegalStateException("Server does not support prompts capability.") } @@ -520,7 +404,7 @@ public open class Server( mimeType: String = "text/html", readHandler: suspend (ReadResourceRequest) -> ReadResourceResult ) { - if (serverOptions.capabilities.resources == null) { + if (options.capabilities.resources == null) { logger.error { "Failed to add resource '$name': Server does not support resources capability" } throw IllegalStateException("Server does not support resources capability.") } @@ -540,7 +424,7 @@ public open class Server( * @throws IllegalStateException If the server does not support resources. */ public fun addResources(resourcesToAdd: List) { - if (serverOptions.capabilities.resources == null) { + if (options.capabilities.resources == null) { logger.error { "Failed to add resources: Server does not support resources capability" } throw IllegalStateException("Server does not support resources capability.") } @@ -556,7 +440,7 @@ public open class Server( * @throws IllegalStateException If the server does not support resources. */ public fun removeResource(uri: String): Boolean { - if (serverOptions.capabilities.resources == null) { + if (options.capabilities.resources == null) { logger.error { "Failed to remove resource '$uri': Server does not support resources capability" } throw IllegalStateException("Server does not support resources capability.") } @@ -583,7 +467,7 @@ public open class Server( * @throws IllegalStateException If the server does not support resources. */ public fun removeResources(uris: List): Int { - if (serverOptions.capabilities.resources == null) { + if (options.capabilities.resources == null) { logger.error { "Failed to remove resources: Server does not support resources capability" } throw IllegalStateException("Server does not support resources capability.") } @@ -603,159 +487,6 @@ public open class Server( return removedCount } - /** - * Sends a ping request to the client to check connectivity. - * - * @return The result of the ping request. - * @throws IllegalStateException If for some reason the method is not supported or the connection is closed. - */ - @Deprecated( - "Will be removed with Protocol inheritance. Use session.ping instead.", - ReplaceWith("session.ping"), - DeprecationLevel.WARNING - ) - public suspend fun ping(): EmptyRequestResult { - return request(PingRequest()) - } - - /** - * Creates a message using the server's sampling capability. - * - * @param params The parameters for creating a message. - * @param options Optional request options. - * @return The created message result. - * @throws IllegalStateException If the server does not support sampling or if the request fails. - */ - @Deprecated( - "Will be removed with Protocol inheritance. Use session.createMessage instead.", - ReplaceWith("session.createMessage"), - DeprecationLevel.WARNING - ) - public suspend fun createMessage( - params: CreateMessageRequest, - options: RequestOptions? = null - ): CreateMessageResult { - logger.debug { "Creating message with params: $params" } - return request(params, options) - } - - /** - * Lists the available "roots" from the client's perspective (if supported). - * - * @param params JSON parameters for the request, usually empty. - * @param options Optional request options. - * @return The list of roots. - * @throws IllegalStateException If the server or client does not support roots. - */ - @Deprecated( - "Will be removed with Protocol inheritance. Use session.listRoots instead.", - ReplaceWith("session.listRoots"), - DeprecationLevel.WARNING - ) - public suspend fun listRoots( - params: JsonObject = EmptyJsonObject, - options: RequestOptions? = null - ): ListRootsResult { - logger.debug { "Listing roots with params: $params" } - return request(ListRootsRequest(params), options) - } - - /** - * Creates an elicitation request with the specified message and schema. - * - * @param message The message to be used for the elicitation. - * @param requestedSchema The schema defining the structure of the requested elicitation. - * @param options Optional parameters to customize the elicitation request. - * @return Returns the result of the elicitation creation process. - */ - @Deprecated( - "Will be removed with Protocol inheritance. Use session.createElicitation instead.", - ReplaceWith("session.createElicitation"), - DeprecationLevel.WARNING - ) - public suspend fun createElicitation( - message: String, - requestedSchema: RequestedSchema, - options: RequestOptions? = null - ): CreateElicitationResult { - logger.debug { "Creating elicitation with message: $message" } - return request(CreateElicitationRequest(message, requestedSchema), options) - } - - /** - * Sends a logging message notification to the client. - * - * @param params The logging message notification parameters. - */ - @Deprecated( - "Will be removed with Protocol inheritance. Use session.sendLoggingMessage instead.", - ReplaceWith("session.sendLoggingMessage"), - DeprecationLevel.WARNING - ) - public suspend fun sendLoggingMessage( - params: LoggingMessageNotification - ) { - logger.trace { "Sending logging message: ${params.params.data}" } - notification(params) - } - - /** - * Sends a resource-updated notification to the client, indicating that a specific resource has changed. - * - * @param params Details of the updated resource. - */ - @Deprecated( - "Will be removed with Protocol inheritance. Use session.sendResourceUpdated instead.", - ReplaceWith("session.sendResourceUpdated"), - DeprecationLevel.WARNING - ) - public suspend fun sendResourceUpdated(params: ResourceUpdatedNotification) { - logger.debug { "Sending resource updated notification for: ${params.params.uri}" } - notification(params) - } - - /** - * Sends a notification to the client indicating that the list of resources has changed. - */ - @Deprecated( - "Will be removed with Protocol inheritance. Use session.sendResourceListChanged instead.", - ReplaceWith("session.sendResourceListChanged"), - DeprecationLevel.WARNING - ) - public suspend fun sendResourceListChanged() { - logger.debug { "Sending resource list changed notification" } - notification(ResourceListChangedNotification()) - } - - /** - * Sends a notification to the client indicating that the list of tools has changed. - */ - @Deprecated( - "Will be removed with Protocol inheritance. Use session.sendToolListChanged instead.", - ReplaceWith("session.sendToolListChanged"), - DeprecationLevel.WARNING - ) - public suspend fun sendToolListChanged() { - logger.debug { "Sending tool list changed notification" } - notification(ToolListChangedNotification()) - } - - /** - * Sends a notification to the client indicating that the list of prompts has changed. - */ - /** - * Sends a notification to the client indicating that the list of tools has changed. - */ - @Deprecated( - "Will be removed with Protocol inheritance. Use session.sendPromptListChanged instead.", - ReplaceWith("session.sendPromptListChanged"), - DeprecationLevel.WARNING - ) - public suspend fun sendPromptListChanged() { - logger.debug { "Sending prompt list changed notification" } - notification(PromptListChangedNotification()) - } - // --- Internal Handlers --- private suspend fun handleListTools(): ListToolsResult { val toolList = tools.values.map { it.tool } @@ -807,178 +538,6 @@ public open class Server( // If you have resource templates, return them here. For now, return empty. return ListResourceTemplatesResult(listOf()) } - - - @Deprecated( - "Will be removed with Protocol inheritance. Use session.handleInitialize instead.", - ReplaceWith("session.handleInitialize"), - DeprecationLevel.WARNING - ) - private suspend fun handleInitialize(request: InitializeRequest): InitializeResult { - logger.info { "Handling initialize request from client ${request.clientInfo}" } - clientCapabilities = request.capabilities - clientVersion = request.clientInfo - - val requestedVersion = request.protocolVersion - val protocolVersion = if (SUPPORTED_PROTOCOL_VERSIONS.contains(requestedVersion)) { - requestedVersion - } else { - logger.warn { "Client requested unsupported protocol version $requestedVersion, falling back to $LATEST_PROTOCOL_VERSION" } - LATEST_PROTOCOL_VERSION - } - - return InitializeResult( - protocolVersion = protocolVersion, - capabilities = serverOptions.capabilities, - serverInfo = serverInfo - ) - } - - - /** - * Asserts that the client supports the capability required for the given [method]. - * - * This method is automatically called by the [Protocol] framework before handling requests. - * Throws [IllegalStateException] if the capability is not supported. - * - * @param method The method for which we are asserting capability. - */ - @Deprecated( - "Will be removed with Protocol inheritance. Use session.assertCapabilityForMethod instead.", - ReplaceWith("session.assertCapabilityForMethod"), - DeprecationLevel.WARNING - ) - override fun assertCapabilityForMethod(method: Method) { - logger.trace { "Asserting capability for method: ${method.value}" } - when (method.value) { - "sampling/createMessage" -> { - if (clientCapabilities?.sampling == null) { - logger.error { "Client capability assertion failed: sampling not supported" } - throw IllegalStateException("Client does not support sampling (required for ${method.value})") - } - } - - "roots/list" -> { - if (clientCapabilities?.roots == null) { - throw IllegalStateException("Client does not support listing roots (required for ${method.value})") - } - } - - "elicitation/create" -> { - if (clientCapabilities?.elicitation == null) { - throw IllegalStateException("Client does not support elicitation (required for ${method.value})") - } - } - - "ping" -> { - // No specific capability required - } - } - } - - /** - * Asserts that the server can handle the specified notification method. - * - * Throws [IllegalStateException] if the server does not have the capabilities required to handle this notification. - * - * @param method The notification method. - */ - @Deprecated( - "Will be removed with Protocol inheritance. Use session.assertNotificationCapability instead.", - ReplaceWith("session.assertNotificationCapability"), - DeprecationLevel.WARNING - ) - override fun assertNotificationCapability(method: Method) { - logger.trace { "Asserting notification capability for method: ${method.value}" } - when (method.value) { - "notifications/message" -> { - if (serverOptions.capabilities.logging == null) { - logger.error { "Server capability assertion failed: logging not supported" } - throw IllegalStateException("Server does not support logging (required for ${method.value})") - } - } - - "notifications/resources/updated", - "notifications/resources/list_changed" -> { - if (serverOptions.capabilities.resources == null) { - throw IllegalStateException("Server does not support notifying about resources (required for ${method.value})") - } - } - - "notifications/tools/list_changed" -> { - if (serverOptions.capabilities.tools == null) { - throw IllegalStateException("Server does not support notifying of tool list changes (required for ${method.value})") - } - } - - "notifications/prompts/list_changed" -> { - if (serverOptions.capabilities.prompts == null) { - throw IllegalStateException("Server does not support notifying of prompt list changes (required for ${method.value})") - } - } - - "notifications/cancelled", - "notifications/progress" -> { - // Always allowed - } - } - } - - /** - * Asserts that the server can handle the specified request method. - * - * Throws [IllegalStateException] if the server does not have the capabilities required to handle this request. - * - * @param method The request method. - */ - @Deprecated( - "Will be removed with Protocol inheritance. Use session.assertRequestHandlerCapability instead.", - ReplaceWith("session.assertRequestHandlerCapability"), - DeprecationLevel.WARNING - ) - override fun assertRequestHandlerCapability(method: Method) { - logger.trace { "Asserting request handler capability for method: ${method.value}" } - when (method.value) { - "sampling/createMessage" -> { - if (serverOptions.capabilities.sampling == null) { - logger.error { "Server capability assertion failed: sampling not supported" } - throw IllegalStateException("Server does not support sampling (required for $method)") - } - } - - "logging/setLevel" -> { - if (serverOptions.capabilities.logging == null) { - throw IllegalStateException("Server does not support logging (required for $method)") - } - } - - "prompts/get", - "prompts/list" -> { - if (serverOptions.capabilities.prompts == null) { - throw IllegalStateException("Server does not support prompts (required for $method)") - } - } - - "resources/list", - "resources/templates/list", - "resources/read" -> { - if (serverOptions.capabilities.resources == null) { - throw IllegalStateException("Server does not support resources (required for $method)") - } - } - - "tools/call", - "tools/list" -> { - if (serverOptions.capabilities.tools == null) { - throw IllegalStateException("Server does not support tools (required for $method)") - } - } - - "ping", "initialize" -> { - // No capability required - } - } - } } /** diff --git a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/WebSocketMcpKtorServerExtensions.kt b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/WebSocketMcpKtorServerExtensions.kt index bea21edc..14cc81b5 100644 --- a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/WebSocketMcpKtorServerExtensions.kt +++ b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/WebSocketMcpKtorServerExtensions.kt @@ -82,7 +82,7 @@ internal suspend fun WebSocketServerSession.mcpWebSocketEndpoint( val server = block() var session: ServerSession? = null try { - session = server.connectSession(transport) + session = server.connect(transport) awaitCancellation() } catch (e: CancellationException) { session?.close() diff --git a/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/ClientTest.kt b/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/ClientTest.kt index 26330ce1..c2ca13fa 100644 --- a/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/ClientTest.kt +++ b/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/ClientTest.kt @@ -30,8 +30,16 @@ import io.modelcontextprotocol.kotlin.sdk.TextContent import io.modelcontextprotocol.kotlin.sdk.Tool import io.modelcontextprotocol.kotlin.sdk.server.Server import io.modelcontextprotocol.kotlin.sdk.server.ServerOptions +import io.modelcontextprotocol.kotlin.sdk.server.ServerSession import io.modelcontextprotocol.kotlin.sdk.shared.AbstractTransport import io.modelcontextprotocol.kotlin.sdk.shared.InMemoryTransport +import kotlin.coroutines.cancellation.CancellationException +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertFailsWith +import kotlin.test.assertIs +import kotlin.test.assertTrue +import kotlin.test.fail import kotlinx.coroutines.CompletableDeferred import kotlinx.coroutines.TimeoutCancellationException import kotlinx.coroutines.cancel @@ -43,13 +51,6 @@ import kotlinx.coroutines.withTimeout import kotlinx.serialization.json.buildJsonObject import kotlinx.serialization.json.put import kotlinx.serialization.json.putJsonObject -import kotlin.coroutines.cancellation.CancellationException -import kotlin.test.Test -import kotlin.test.assertEquals -import kotlin.test.assertFailsWith -import kotlin.test.assertIs -import kotlin.test.assertTrue -import kotlin.test.fail class ClientTest { @Test @@ -241,25 +242,6 @@ class ClientTest { serverOptions ) - server.setRequestHandler(Method.Defined.Initialize) { _, _ -> - InitializeResult( - protocolVersion = LATEST_PROTOCOL_VERSION, - capabilities = ServerCapabilities( - resources = ServerCapabilities.Resources(null, null), - tools = ServerCapabilities.Tools(null) - ), - serverInfo = Implementation(name = "test", version = "1.0") - ) - } - - server.setRequestHandler(Method.Defined.ResourcesList) { _, _ -> - ListResourcesResult(resources = emptyList(), nextCursor = null) - } - - server.setRequestHandler(Method.Defined.ToolsList) { _, _ -> - ListToolsResult(tools = emptyList(), nextCursor = null) - } - val (clientTransport, serverTransport) = InMemoryTransport.createLinkedPair() val client = Client( @@ -269,15 +251,36 @@ class ClientTest { ) ) + val serverSessionResult = CompletableDeferred() + listOf( launch { client.connect(clientTransport) }, launch { - server.connect(serverTransport) + serverSessionResult.complete(server.connect(serverTransport)) } ).joinAll() + val serverSession = serverSessionResult.await() + serverSession.setRequestHandler(Method.Defined.Initialize) { _, _ -> + InitializeResult( + protocolVersion = LATEST_PROTOCOL_VERSION, + capabilities = ServerCapabilities( + resources = ServerCapabilities.Resources(null, null), + tools = ServerCapabilities.Tools(null) + ), + serverInfo = Implementation(name = "test", version = "1.0") + ) + } + + serverSession.setRequestHandler(Method.Defined.ResourcesList) { _, _ -> + ListResourcesResult(resources = emptyList(), nextCursor = null) + } + + serverSession.setRequestHandler(Method.Defined.ToolsList) { _, _ -> + ListToolsResult(tools = emptyList(), nextCursor = null) + } // Server supports resources and tools, but not prompts val caps = client.serverCapabilities assertEquals(ServerCapabilities.Resources(null, null), caps?.resources) @@ -368,24 +371,27 @@ class ClientTest { val (clientTransport, serverTransport) = InMemoryTransport.createLinkedPair() + val serverSessionResult = CompletableDeferred() + listOf( launch { client.connect(clientTransport) println("Client connected") }, launch { - server.connect(serverTransport) + serverSessionResult.complete(server.connect(serverTransport)) println("Server connected") } ).joinAll() + val serverSession = serverSessionResult.await() // These should not throw val jsonObject = buildJsonObject { put("name", "John") put("age", 30) put("isStudent", false) } - server.sendLoggingMessage( + serverSession.sendLoggingMessage( LoggingMessageNotification( params = LoggingMessageNotification.Params( level = LoggingLevel.info, @@ -393,11 +399,11 @@ class ClientTest { ) ) ) - server.sendResourceListChanged() + serverSession.sendResourceListChanged() // This should fail because the server doesn't have the tools capability val ex = assertFailsWith { - server.sendToolListChanged() + serverSession.sendToolListChanged() } assertTrue(ex.message?.contains("Server does not support notifying of tool list changes") == true) } @@ -418,19 +424,6 @@ class ClientTest { val def = CompletableDeferred() val defTimeOut = CompletableDeferred() - server.setRequestHandler(Method.Defined.ResourcesList) { _, _ -> - // Simulate delay - def.complete(Unit) - try { - delay(1000) - } catch (e: CancellationException) { - defTimeOut.complete(Unit) - throw e - } - ListResourcesResult(resources = emptyList()) - fail("Shouldn't have been called") - } - val (clientTransport, serverTransport) = InMemoryTransport.createLinkedPair() val client = Client( @@ -438,17 +431,35 @@ class ClientTest { options = ClientOptions(capabilities = ClientCapabilities()) ) + val serverSessionResult = CompletableDeferred() + listOf( launch { client.connect(clientTransport) println("Client connected") }, launch { - server.connect(serverTransport) + serverSessionResult.complete(server.connect(serverTransport)) println("Server connected") } ).joinAll() + val serverSession = serverSessionResult.await() + + serverSession.setRequestHandler(Method.Defined.ResourcesList) { _, _ -> + // Simulate delay + def.complete(Unit) + try { + delay(1000) + } catch (e: CancellationException) { + defTimeOut.complete(Unit) + throw e + } + ListResourcesResult(resources = emptyList()) + fail("Shouldn't have been called") + } + + val defCancel = CompletableDeferred() val job = launch { try { @@ -478,37 +489,40 @@ class ClientTest { ) ) - server.setRequestHandler(Method.Defined.ResourcesList) { _, _ -> - // Simulate a delayed response - // Wait ~100ms unless canceled - try { - withTimeout(100L) { - // Just delay here, if timeout is 0 on the client side, this won't return in time - delay(100) - } - } catch (_: Exception) { - // If aborted, just rethrow or return early - } - ListResourcesResult(resources = emptyList()) - } - val (clientTransport, serverTransport) = InMemoryTransport.createLinkedPair() val client = Client( clientInfo = Implementation(name = "test client", version = "1.0"), options = ClientOptions(capabilities = ClientCapabilities()) ) + val serverSessionResult = CompletableDeferred() + listOf( launch { client.connect(clientTransport) println("Client connected") }, launch { - server.connect(serverTransport) + serverSessionResult.complete(server.connect(serverTransport)) println("Server connected") } ).joinAll() + val serverSession = serverSessionResult.await() + serverSession.setRequestHandler(Method.Defined.ResourcesList) { _, _ -> + // Simulate a delayed response + // Wait ~100ms unless canceled + try { + withTimeout(100L) { + // Just delay here, if timeout is 0 on the client side, this won't return in time + delay(100) + } + } catch (_: Exception) { + // If aborted, just rethrow or return early + } + ListResourcesResult(resources = emptyList()) + } + // Request with 1 msec timeout should fail immediately val ex = assertFailsWith { withTimeout(1) { @@ -559,7 +573,36 @@ class ClientTest { serverOptions ) - server.setRequestHandler(Method.Defined.Initialize) { _, _ -> + val (clientTransport, serverTransport) = InMemoryTransport.createLinkedPair() + + val client = Client( + clientInfo = Implementation(name = "test client", version = "1.0"), + options = ClientOptions( + capabilities = ClientCapabilities(sampling = EmptyJsonObject), + ) + ) + + var receivedMessage: JSONRPCMessage? = null + clientTransport.onMessage { msg -> + receivedMessage = msg + } + + val serverSessionResult = CompletableDeferred() + + listOf( + launch { + client.connect(clientTransport) + println("Client connected") + }, + launch { + serverSessionResult.complete(server.connect(serverTransport)) + println("Server connected") + } + ).joinAll() + + val serverSession = serverSessionResult.await() + + serverSession.setRequestHandler(Method.Defined.Initialize) { _, _ -> InitializeResult( protocolVersion = LATEST_PROTOCOL_VERSION, capabilities = ServerCapabilities( @@ -569,6 +612,7 @@ class ClientTest { serverInfo = Implementation(name = "test", version = "1.0") ) } + val serverListToolsResult = ListToolsResult( tools = listOf( Tool( @@ -582,33 +626,10 @@ class ClientTest { ), nextCursor = null ) - server.setRequestHandler(Method.Defined.ToolsList) { _, _ -> + serverSession.setRequestHandler(Method.Defined.ToolsList) { _, _ -> serverListToolsResult } - val (clientTransport, serverTransport) = InMemoryTransport.createLinkedPair() - - val client = Client( - clientInfo = Implementation(name = "test client", version = "1.0"), - options = ClientOptions( - capabilities = ClientCapabilities(sampling = EmptyJsonObject), - ) - ) - - var receivedMessage: JSONRPCMessage? = null - clientTransport.onMessage { msg -> - receivedMessage = msg - } - - listOf( - launch { - client.connect(clientTransport) - }, - launch { - server.connect(serverTransport) - } - ).joinAll() - val serverCapabilities = client.serverCapabilities assertEquals(ServerCapabilities.Tools(null), serverCapabilities?.tools) @@ -651,15 +672,25 @@ class ClientTest { ) ) + val serverSessionResult = CompletableDeferred() + listOf( - launch { client.connect(clientTransport) }, - launch { server.connect(serverTransport) } + launch { + client.connect(clientTransport) + println("Client connected") + }, + launch { + serverSessionResult.complete(server.connect(serverTransport)) + println("Server connected") + } ).joinAll() - val clientCapabilities = server.clientCapabilities + val serverSession = serverSessionResult.await() + + val clientCapabilities = serverSession.clientCapabilities assertEquals(ClientCapabilities.Roots(null), clientCapabilities?.roots) - val listRootsResult = server.listRoots() + val listRootsResult = serverSession.listRoots() assertEquals(listRootsResult.roots, clientRoots) } @@ -772,16 +803,27 @@ class ClientTest { // Track notifications var rootListChangedNotificationReceived = false - server.setNotificationHandler(Method.Defined.NotificationsRootsListChanged) { - rootListChangedNotificationReceived = true - CompletableDeferred(Unit) - } + + + val serverSessionResult = CompletableDeferred() listOf( - launch { client.connect(clientTransport) }, - launch { server.connect(serverTransport) } + launch { + client.connect(clientTransport) + println("Client connected") + }, + launch { + serverSessionResult.complete(server.connect(serverTransport)) + println("Server connected") + } ).joinAll() + val serverSession = serverSessionResult.await() + serverSession.setNotificationHandler(Method.Defined.NotificationsRootsListChanged) { + rootListChangedNotificationReceived = true + CompletableDeferred(Unit) + } + client.sendRootsListChanged() assertTrue( @@ -808,14 +850,24 @@ class ClientTest { ) ) + val serverSessionResult = CompletableDeferred() + listOf( - launch { client.connect(clientTransport) }, - launch { server.connect(serverTransport) } + launch { + client.connect(clientTransport) + println("Client connected") + }, + launch { + serverSessionResult.complete(server.connect(serverTransport)) + println("Server connected") + } ).joinAll() + val serverSession = serverSessionResult.await() + // Verify that creating an elicitation throws an exception val exception = assertFailsWith { - server.createElicitation( + serverSession.createElicitation( message = "Please provide your GitHub username", requestedSchema = CreateElicitationRequest.RequestedSchema( properties = buildJsonObject { @@ -878,12 +930,22 @@ class ClientTest { ) ) + val serverSessionResult = CompletableDeferred() + listOf( - launch { client.connect(clientTransport) }, - launch { server.connect(serverTransport) } + launch { + client.connect(clientTransport) + println("Client connected") + }, + launch { + serverSessionResult.complete(server.connect(serverTransport)) + println("Server connected") + } ).joinAll() - val result = server.createElicitation( + val serverSession = serverSessionResult.await() + + val result = serverSession.createElicitation( message = elicitationMessage, requestedSchema = requestedSchema ) diff --git a/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/SseTransportTest.kt b/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/SseTransportTest.kt index 8a49b29b..839c13ab 100644 --- a/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/SseTransportTest.kt +++ b/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/SseTransportTest.kt @@ -5,8 +5,6 @@ import io.ktor.server.application.install import io.ktor.server.cio.CIO import io.ktor.server.engine.EmbeddedServer import io.ktor.server.engine.embeddedServer -import io.ktor.server.routing.post -import io.ktor.server.routing.route import io.ktor.server.routing.routing import io.modelcontextprotocol.kotlin.sdk.Implementation import io.modelcontextprotocol.kotlin.sdk.ServerCapabilities @@ -72,19 +70,6 @@ class SseTransportTest : BaseTransportTest() { install(ServerSSE) routing { mcp { mcpServer } -// sse { -// mcpSseTransport("", transportManager).apply { -// onMessage { -// send(it) -// } -// -// start() -// } -// } -// -// post { -// mcpPostEndpoint(transportManager) -// } } }.startSuspend(wait = false) @@ -113,21 +98,6 @@ class SseTransportTest : BaseTransportTest() { install(ServerSSE) routing { mcp { mcpServer } -// route("/sse") { -// sse { -// mcpSseTransport("", transportManager).apply { -// onMessage { -// send(it) -// } -// -// start() -// } -// } -// -// post { -// mcpPostEndpoint(transportManager) -// } -// } } }.startSuspend(wait = false)