diff --git a/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/Client.kt b/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/Client.kt index 75d0b221..6631f5e9 100644 --- a/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/Client.kt +++ b/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/Client.kt @@ -159,13 +159,14 @@ public open class Client( notification(InitializedNotification()) } catch (error: Throwable) { + logger.error(error) { "Failed to initialize client: ${error.message}" } close() + if (error !is CancellationException) { throw IllegalStateException("Error connecting to transport: ${error.message}") } throw error - } } diff --git a/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/SSEClientTransport.kt b/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/SSEClientTransport.kt index d30f5288..778d0481 100644 --- a/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/SSEClientTransport.kt +++ b/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/SSEClientTransport.kt @@ -1,5 +1,6 @@ package io.modelcontextprotocol.kotlin.sdk.client +import io.github.oshai.kotlinlogging.KotlinLogging import io.ktor.client.HttpClient import io.ktor.client.plugins.sse.ClientSSESession import io.ktor.client.plugins.sse.sseSession @@ -46,6 +47,8 @@ public class SseClientTransport( private val reconnectionTime: Duration? = null, private val requestBuilder: HttpRequestBuilder.() -> Unit = {}, ) : AbstractTransport() { + private val logger = KotlinLogging.logger {} + private val initialized: AtomicBoolean = AtomicBoolean(false) private val endpoint = CompletableDeferred() @@ -111,6 +114,8 @@ public class SseClientTransport( val text = response.bodyAsText() error("Error POSTing to endpoint (HTTP ${response.status}): $text") } + + logger.debug { "Client successfully sent message via SSE $endpoint" } } catch (e: Throwable) { _onError(e) throw e @@ -157,6 +162,7 @@ public class SseClientTransport( val path = if (eventData.startsWith("/")) eventData.substring(1) else eventData val endpointUrl = Url("$baseUrl/$path") endpoint.complete(endpointUrl.toString()) + logger.debug { "Client connected to endpoint: $endpointUrl" } } catch (e: Throwable) { _onError(e) endpoint.completeExceptionally(e) diff --git a/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/WebSocketClientTransport.kt b/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/WebSocketClientTransport.kt index 45719073..ec3f9470 100644 --- a/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/WebSocketClientTransport.kt +++ b/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/WebSocketClientTransport.kt @@ -1,5 +1,6 @@ package io.modelcontextprotocol.kotlin.sdk.client +import io.github.oshai.kotlinlogging.KotlinLogging import io.ktor.client.HttpClient import io.ktor.client.plugins.websocket.webSocketSession import io.ktor.client.request.HttpRequestBuilder @@ -10,6 +11,8 @@ import io.modelcontextprotocol.kotlin.sdk.shared.MCP_SUBPROTOCOL import io.modelcontextprotocol.kotlin.sdk.shared.WebSocketMcpTransport import kotlin.properties.Delegates +private val logger = KotlinLogging.logger {} + /** * Client transport for WebSocket: this will connect to a server over the WebSocket protocol. */ @@ -21,6 +24,8 @@ public class WebSocketClientTransport( override var session: WebSocketSession by Delegates.notNull() override suspend fun initializeSession() { + logger.debug { "Websocket session initialization started..." } + session = urlString?.let { client.webSocketSession(it) { requestBuilder() @@ -32,5 +37,7 @@ public class WebSocketClientTransport( header(HttpHeaders.SecWebSocketProtocol, MCP_SUBPROTOCOL) } + + logger.debug { "Websocket session initialization finished" } } } diff --git a/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/WebSocketMcpKtorClientExtensions.kt b/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/WebSocketMcpKtorClientExtensions.kt index 9d70d6c0..7b63420a 100644 --- a/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/WebSocketMcpKtorClientExtensions.kt +++ b/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/WebSocketMcpKtorClientExtensions.kt @@ -1,11 +1,15 @@ package io.modelcontextprotocol.kotlin.sdk.client +import io.github.oshai.kotlinlogging.KotlinLogging import io.ktor.client.HttpClient import io.ktor.client.request.HttpRequestBuilder import io.modelcontextprotocol.kotlin.sdk.Implementation import io.modelcontextprotocol.kotlin.sdk.LIB_VERSION import io.modelcontextprotocol.kotlin.sdk.shared.IMPLEMENTATION_NAME +private val logger = KotlinLogging.logger {} + + /** * Returns a new WebSocket transport for the Model Context Protocol using the provided HttpClient. * @@ -36,6 +40,8 @@ public suspend fun HttpClient.mcpWebSocket( version = LIB_VERSION ) ) + logger.debug { "Client started to connect to server" } client.connect(transport) + logger.debug { "Client finished to connect to server" } return client } diff --git a/kotlin-sdk-core/api/kotlin-sdk-core.api b/kotlin-sdk-core/api/kotlin-sdk-core.api index 569dfb3e..b753a5cd 100644 --- a/kotlin-sdk-core/api/kotlin-sdk-core.api +++ b/kotlin-sdk-core/api/kotlin-sdk-core.api @@ -3200,8 +3200,296 @@ public final class io/modelcontextprotocol/kotlin/sdk/WithMeta$Companion { public final fun serializer ()Lkotlinx/serialization/KSerializer; } -public final class io/modelcontextprotocol/kotlin/sdk/internal/Utils_jvmKt { - public static final fun getIODispatcher ()Lkotlinx/coroutines/CoroutineDispatcher; +public class io/modelcontextprotocol/kotlin/sdk/client/Client : io/modelcontextprotocol/kotlin/sdk/shared/Protocol { + public fun (Lio/modelcontextprotocol/kotlin/sdk/Implementation;Lio/modelcontextprotocol/kotlin/sdk/client/ClientOptions;)V + public synthetic fun (Lio/modelcontextprotocol/kotlin/sdk/Implementation;Lio/modelcontextprotocol/kotlin/sdk/client/ClientOptions;ILkotlin/jvm/internal/DefaultConstructorMarker;)V + public final fun addRoot (Ljava/lang/String;Ljava/lang/String;)V + public final fun addRoots (Ljava/util/List;)V + protected final fun assertCapability (Ljava/lang/String;Ljava/lang/String;)V + protected fun assertCapabilityForMethod (Lio/modelcontextprotocol/kotlin/sdk/Method;)V + protected fun assertNotificationCapability (Lio/modelcontextprotocol/kotlin/sdk/Method;)V + public fun assertRequestHandlerCapability (Lio/modelcontextprotocol/kotlin/sdk/Method;)V + public final fun callTool (Lio/modelcontextprotocol/kotlin/sdk/CallToolRequest;ZLio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public final fun callTool (Ljava/lang/String;Ljava/util/Map;ZLio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public static synthetic fun callTool$default (Lio/modelcontextprotocol/kotlin/sdk/client/Client;Lio/modelcontextprotocol/kotlin/sdk/CallToolRequest;ZLio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;ILjava/lang/Object;)Ljava/lang/Object; + public static synthetic fun callTool$default (Lio/modelcontextprotocol/kotlin/sdk/client/Client;Ljava/lang/String;Ljava/util/Map;ZLio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;ILjava/lang/Object;)Ljava/lang/Object; + public final fun complete (Lio/modelcontextprotocol/kotlin/sdk/CompleteRequest;Lio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public static synthetic fun complete$default (Lio/modelcontextprotocol/kotlin/sdk/client/Client;Lio/modelcontextprotocol/kotlin/sdk/CompleteRequest;Lio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;ILjava/lang/Object;)Ljava/lang/Object; + public fun connect (Lio/modelcontextprotocol/kotlin/sdk/shared/Transport;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public final fun getPrompt (Lio/modelcontextprotocol/kotlin/sdk/GetPromptRequest;Lio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public static synthetic fun getPrompt$default (Lio/modelcontextprotocol/kotlin/sdk/client/Client;Lio/modelcontextprotocol/kotlin/sdk/GetPromptRequest;Lio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;ILjava/lang/Object;)Ljava/lang/Object; + public final fun getServerCapabilities ()Lio/modelcontextprotocol/kotlin/sdk/ServerCapabilities; + public final fun getServerVersion ()Lio/modelcontextprotocol/kotlin/sdk/Implementation; + public final fun listPrompts (Lio/modelcontextprotocol/kotlin/sdk/ListPromptsRequest;Lio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public static synthetic fun listPrompts$default (Lio/modelcontextprotocol/kotlin/sdk/client/Client;Lio/modelcontextprotocol/kotlin/sdk/ListPromptsRequest;Lio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;ILjava/lang/Object;)Ljava/lang/Object; + public final fun listResourceTemplates (Lio/modelcontextprotocol/kotlin/sdk/ListResourceTemplatesRequest;Lio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public static synthetic fun listResourceTemplates$default (Lio/modelcontextprotocol/kotlin/sdk/client/Client;Lio/modelcontextprotocol/kotlin/sdk/ListResourceTemplatesRequest;Lio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;ILjava/lang/Object;)Ljava/lang/Object; + public final fun listResources (Lio/modelcontextprotocol/kotlin/sdk/ListResourcesRequest;Lio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public static synthetic fun listResources$default (Lio/modelcontextprotocol/kotlin/sdk/client/Client;Lio/modelcontextprotocol/kotlin/sdk/ListResourcesRequest;Lio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;ILjava/lang/Object;)Ljava/lang/Object; + public final fun listTools (Lio/modelcontextprotocol/kotlin/sdk/ListToolsRequest;Lio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public static synthetic fun listTools$default (Lio/modelcontextprotocol/kotlin/sdk/client/Client;Lio/modelcontextprotocol/kotlin/sdk/ListToolsRequest;Lio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;ILjava/lang/Object;)Ljava/lang/Object; + public final fun ping (Lio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public static synthetic fun ping$default (Lio/modelcontextprotocol/kotlin/sdk/client/Client;Lio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;ILjava/lang/Object;)Ljava/lang/Object; + public final fun readResource (Lio/modelcontextprotocol/kotlin/sdk/ReadResourceRequest;Lio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public static synthetic fun readResource$default (Lio/modelcontextprotocol/kotlin/sdk/client/Client;Lio/modelcontextprotocol/kotlin/sdk/ReadResourceRequest;Lio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;ILjava/lang/Object;)Ljava/lang/Object; + public final fun removeRoot (Ljava/lang/String;)Z + public final fun removeRoots (Ljava/util/List;)I + public final fun sendRootsListChanged (Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public final fun setElicitationHandler (Lkotlin/jvm/functions/Function1;)V + public final fun setLoggingLevel (Lio/modelcontextprotocol/kotlin/sdk/LoggingLevel;Lio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public static synthetic fun setLoggingLevel$default (Lio/modelcontextprotocol/kotlin/sdk/client/Client;Lio/modelcontextprotocol/kotlin/sdk/LoggingLevel;Lio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;ILjava/lang/Object;)Ljava/lang/Object; + public final fun subscribeResource (Lio/modelcontextprotocol/kotlin/sdk/SubscribeRequest;Lio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public static synthetic fun subscribeResource$default (Lio/modelcontextprotocol/kotlin/sdk/client/Client;Lio/modelcontextprotocol/kotlin/sdk/SubscribeRequest;Lio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;ILjava/lang/Object;)Ljava/lang/Object; + public final fun unsubscribeResource (Lio/modelcontextprotocol/kotlin/sdk/UnsubscribeRequest;Lio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public static synthetic fun unsubscribeResource$default (Lio/modelcontextprotocol/kotlin/sdk/client/Client;Lio/modelcontextprotocol/kotlin/sdk/UnsubscribeRequest;Lio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;ILjava/lang/Object;)Ljava/lang/Object; +} + +public final class io/modelcontextprotocol/kotlin/sdk/client/ClientOptions : io/modelcontextprotocol/kotlin/sdk/shared/ProtocolOptions { + public fun ()V + public fun (Lio/modelcontextprotocol/kotlin/sdk/ClientCapabilities;Z)V + public synthetic fun (Lio/modelcontextprotocol/kotlin/sdk/ClientCapabilities;ZILkotlin/jvm/internal/DefaultConstructorMarker;)V + public final fun getCapabilities ()Lio/modelcontextprotocol/kotlin/sdk/ClientCapabilities; +} + +public final class io/modelcontextprotocol/kotlin/sdk/client/KtorClientKt { + public static final fun mcpSse-BZiP2OM (Lio/ktor/client/HttpClient;Ljava/lang/String;Lkotlin/time/Duration;Lkotlin/jvm/functions/Function1;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public static synthetic fun mcpSse-BZiP2OM$default (Lio/ktor/client/HttpClient;Ljava/lang/String;Lkotlin/time/Duration;Lkotlin/jvm/functions/Function1;Lkotlin/coroutines/Continuation;ILjava/lang/Object;)Ljava/lang/Object; + public static final fun mcpSseTransport-5_5nbZA (Lio/ktor/client/HttpClient;Ljava/lang/String;Lkotlin/time/Duration;Lkotlin/jvm/functions/Function1;)Lio/modelcontextprotocol/kotlin/sdk/client/SseClientTransport; + public static synthetic fun mcpSseTransport-5_5nbZA$default (Lio/ktor/client/HttpClient;Ljava/lang/String;Lkotlin/time/Duration;Lkotlin/jvm/functions/Function1;ILjava/lang/Object;)Lio/modelcontextprotocol/kotlin/sdk/client/SseClientTransport; +} + +public final class io/modelcontextprotocol/kotlin/sdk/client/SseClientTransport : io/modelcontextprotocol/kotlin/sdk/shared/AbstractTransport { + public synthetic fun (Lio/ktor/client/HttpClient;Ljava/lang/String;Lkotlin/time/Duration;Lkotlin/jvm/functions/Function1;ILkotlin/jvm/internal/DefaultConstructorMarker;)V + public synthetic fun (Lio/ktor/client/HttpClient;Ljava/lang/String;Lkotlin/time/Duration;Lkotlin/jvm/functions/Function1;Lkotlin/jvm/internal/DefaultConstructorMarker;)V + public fun close (Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public fun send (Lio/modelcontextprotocol/kotlin/sdk/JSONRPCMessage;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public fun start (Lkotlin/coroutines/Continuation;)Ljava/lang/Object; +} + +public final class io/modelcontextprotocol/kotlin/sdk/client/StdioClientTransport : io/modelcontextprotocol/kotlin/sdk/shared/AbstractTransport { + public fun (Lkotlinx/io/Source;Lkotlinx/io/Sink;)V + public fun close (Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public fun send (Lio/modelcontextprotocol/kotlin/sdk/JSONRPCMessage;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public fun start (Lkotlin/coroutines/Continuation;)Ljava/lang/Object; +} + +public final class io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpClientTransport : io/modelcontextprotocol/kotlin/sdk/shared/AbstractTransport { + public synthetic fun (Lio/ktor/client/HttpClient;Ljava/lang/String;Lkotlin/time/Duration;Lkotlin/jvm/functions/Function1;ILkotlin/jvm/internal/DefaultConstructorMarker;)V + public synthetic fun (Lio/ktor/client/HttpClient;Ljava/lang/String;Lkotlin/time/Duration;Lkotlin/jvm/functions/Function1;Lkotlin/jvm/internal/DefaultConstructorMarker;)V + public fun close (Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public final fun getProtocolVersion ()Ljava/lang/String; + public final fun getSessionId ()Ljava/lang/String; + public final fun send (Lio/modelcontextprotocol/kotlin/sdk/JSONRPCMessage;Ljava/lang/String;Lkotlin/jvm/functions/Function1;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public fun send (Lio/modelcontextprotocol/kotlin/sdk/JSONRPCMessage;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public static synthetic fun send$default (Lio/modelcontextprotocol/kotlin/sdk/client/StreamableHttpClientTransport;Lio/modelcontextprotocol/kotlin/sdk/JSONRPCMessage;Ljava/lang/String;Lkotlin/jvm/functions/Function1;Lkotlin/coroutines/Continuation;ILjava/lang/Object;)Ljava/lang/Object; + public final fun setProtocolVersion (Ljava/lang/String;)V + public fun start (Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public final fun terminateSession (Lkotlin/coroutines/Continuation;)Ljava/lang/Object; +} + +public final class io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpError : java/lang/Exception { + public fun ()V + public fun (Ljava/lang/Integer;Ljava/lang/String;)V + public synthetic fun (Ljava/lang/Integer;Ljava/lang/String;ILkotlin/jvm/internal/DefaultConstructorMarker;)V + public final fun getCode ()Ljava/lang/Integer; +} + +public final class io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpMcpKtorClientExtensionsKt { + public static final fun mcpStreamableHttp-BZiP2OM (Lio/ktor/client/HttpClient;Ljava/lang/String;Lkotlin/time/Duration;Lkotlin/jvm/functions/Function1;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public static synthetic fun mcpStreamableHttp-BZiP2OM$default (Lio/ktor/client/HttpClient;Ljava/lang/String;Lkotlin/time/Duration;Lkotlin/jvm/functions/Function1;Lkotlin/coroutines/Continuation;ILjava/lang/Object;)Ljava/lang/Object; + public static final fun mcpStreamableHttpTransport-5_5nbZA (Lio/ktor/client/HttpClient;Ljava/lang/String;Lkotlin/time/Duration;Lkotlin/jvm/functions/Function1;)Lio/modelcontextprotocol/kotlin/sdk/client/StreamableHttpClientTransport; + public static synthetic fun mcpStreamableHttpTransport-5_5nbZA$default (Lio/ktor/client/HttpClient;Ljava/lang/String;Lkotlin/time/Duration;Lkotlin/jvm/functions/Function1;ILjava/lang/Object;)Lio/modelcontextprotocol/kotlin/sdk/client/StreamableHttpClientTransport; +} + +public final class io/modelcontextprotocol/kotlin/sdk/client/WebSocketClientTransport : io/modelcontextprotocol/kotlin/sdk/shared/WebSocketMcpTransport { + public fun (Lio/ktor/client/HttpClient;Ljava/lang/String;Lkotlin/jvm/functions/Function1;)V + public synthetic fun (Lio/ktor/client/HttpClient;Ljava/lang/String;Lkotlin/jvm/functions/Function1;ILkotlin/jvm/internal/DefaultConstructorMarker;)V +} + +public final class io/modelcontextprotocol/kotlin/sdk/client/WebSocketMcpKtorClientExtensionsKt { + public static final fun mcpWebSocket (Lio/ktor/client/HttpClient;Ljava/lang/String;Lkotlin/jvm/functions/Function1;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public static synthetic fun mcpWebSocket$default (Lio/ktor/client/HttpClient;Ljava/lang/String;Lkotlin/jvm/functions/Function1;Lkotlin/coroutines/Continuation;ILjava/lang/Object;)Ljava/lang/Object; + public static final fun mcpWebSocketTransport (Lio/ktor/client/HttpClient;Ljava/lang/String;Lkotlin/jvm/functions/Function1;)Lio/modelcontextprotocol/kotlin/sdk/client/WebSocketClientTransport; + public static synthetic fun mcpWebSocketTransport$default (Lio/ktor/client/HttpClient;Ljava/lang/String;Lkotlin/jvm/functions/Function1;ILjava/lang/Object;)Lio/modelcontextprotocol/kotlin/sdk/client/WebSocketClientTransport; +} + +public final class io/modelcontextprotocol/kotlin/sdk/internal/MainKt { + public static final fun initClient (Ljava/lang/String;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public static synthetic fun initClient$default (Ljava/lang/String;Lkotlin/coroutines/Continuation;ILjava/lang/Object;)Ljava/lang/Object; + public static final fun main ()V + public static synthetic fun main ([Ljava/lang/String;)V +} + +public final class io/modelcontextprotocol/kotlin/sdk/server/ClientSession { + public fun (Lio/modelcontextprotocol/kotlin/sdk/shared/Transport;Lio/modelcontextprotocol/kotlin/sdk/ClientCapabilities;Lio/modelcontextprotocol/kotlin/sdk/Implementation;)V + public synthetic fun (Lio/modelcontextprotocol/kotlin/sdk/shared/Transport;Lio/modelcontextprotocol/kotlin/sdk/ClientCapabilities;Lio/modelcontextprotocol/kotlin/sdk/Implementation;ILkotlin/jvm/internal/DefaultConstructorMarker;)V + public final fun getClientCapabilities ()Lio/modelcontextprotocol/kotlin/sdk/ClientCapabilities; + public final fun getClientVersion ()Lio/modelcontextprotocol/kotlin/sdk/Implementation; + public final fun getTransport ()Lio/modelcontextprotocol/kotlin/sdk/shared/Transport; + public final fun onClose (Lkotlin/jvm/functions/Function0;)V + public final fun onInitialized (Lkotlin/jvm/functions/Function0;)V + public final fun setClientCapabilities (Lio/modelcontextprotocol/kotlin/sdk/ClientCapabilities;)V + public final fun setClientVersion (Lio/modelcontextprotocol/kotlin/sdk/Implementation;)V +} + +public final class io/modelcontextprotocol/kotlin/sdk/server/KtorServerKt { + public static final fun mcp (Lio/ktor/server/application/Application;Lkotlin/jvm/functions/Function0;)V + public static final fun mcp (Lio/ktor/server/routing/Routing;Lkotlin/jvm/functions/Function0;)V +} + +public final class io/modelcontextprotocol/kotlin/sdk/server/RegisteredPrompt { + public fun (Lio/modelcontextprotocol/kotlin/sdk/Prompt;Lkotlin/jvm/functions/Function2;)V + public final fun component1 ()Lio/modelcontextprotocol/kotlin/sdk/Prompt; + public final fun component2 ()Lkotlin/jvm/functions/Function2; + public final fun copy (Lio/modelcontextprotocol/kotlin/sdk/Prompt;Lkotlin/jvm/functions/Function2;)Lio/modelcontextprotocol/kotlin/sdk/server/RegisteredPrompt; + public static synthetic fun copy$default (Lio/modelcontextprotocol/kotlin/sdk/server/RegisteredPrompt;Lio/modelcontextprotocol/kotlin/sdk/Prompt;Lkotlin/jvm/functions/Function2;ILjava/lang/Object;)Lio/modelcontextprotocol/kotlin/sdk/server/RegisteredPrompt; + public fun equals (Ljava/lang/Object;)Z + public final fun getMessageProvider ()Lkotlin/jvm/functions/Function2; + public final fun getPrompt ()Lio/modelcontextprotocol/kotlin/sdk/Prompt; + public fun hashCode ()I + public fun toString ()Ljava/lang/String; +} + +public final class io/modelcontextprotocol/kotlin/sdk/server/RegisteredResource { + public fun (Lio/modelcontextprotocol/kotlin/sdk/Resource;Lkotlin/jvm/functions/Function2;)V + public final fun component1 ()Lio/modelcontextprotocol/kotlin/sdk/Resource; + public final fun component2 ()Lkotlin/jvm/functions/Function2; + public final fun copy (Lio/modelcontextprotocol/kotlin/sdk/Resource;Lkotlin/jvm/functions/Function2;)Lio/modelcontextprotocol/kotlin/sdk/server/RegisteredResource; + public static synthetic fun copy$default (Lio/modelcontextprotocol/kotlin/sdk/server/RegisteredResource;Lio/modelcontextprotocol/kotlin/sdk/Resource;Lkotlin/jvm/functions/Function2;ILjava/lang/Object;)Lio/modelcontextprotocol/kotlin/sdk/server/RegisteredResource; + public fun equals (Ljava/lang/Object;)Z + public final fun getReadHandler ()Lkotlin/jvm/functions/Function2; + public final fun getResource ()Lio/modelcontextprotocol/kotlin/sdk/Resource; + public fun hashCode ()I + public fun toString ()Ljava/lang/String; +} + +public final class io/modelcontextprotocol/kotlin/sdk/server/RegisteredTool { + public fun (Lio/modelcontextprotocol/kotlin/sdk/Tool;Lkotlin/jvm/functions/Function2;)V + public final fun component1 ()Lio/modelcontextprotocol/kotlin/sdk/Tool; + public final fun component2 ()Lkotlin/jvm/functions/Function2; + public final fun copy (Lio/modelcontextprotocol/kotlin/sdk/Tool;Lkotlin/jvm/functions/Function2;)Lio/modelcontextprotocol/kotlin/sdk/server/RegisteredTool; + public static synthetic fun copy$default (Lio/modelcontextprotocol/kotlin/sdk/server/RegisteredTool;Lio/modelcontextprotocol/kotlin/sdk/Tool;Lkotlin/jvm/functions/Function2;ILjava/lang/Object;)Lio/modelcontextprotocol/kotlin/sdk/server/RegisteredTool; + public fun equals (Ljava/lang/Object;)Z + public final fun getHandler ()Lkotlin/jvm/functions/Function2; + public final fun getTool ()Lio/modelcontextprotocol/kotlin/sdk/Tool; + public fun hashCode ()I + public fun toString ()Ljava/lang/String; +} + +public class io/modelcontextprotocol/kotlin/sdk/server/Server : io/modelcontextprotocol/kotlin/sdk/shared/Protocol { + public fun (Lio/modelcontextprotocol/kotlin/sdk/Implementation;Lio/modelcontextprotocol/kotlin/sdk/server/ServerOptions;)V + public final fun addPrompt (Lio/modelcontextprotocol/kotlin/sdk/Prompt;Lkotlin/jvm/functions/Function2;)V + public final fun addPrompt (Ljava/lang/String;Ljava/lang/String;Ljava/util/List;Lkotlin/jvm/functions/Function2;)V + public static synthetic fun addPrompt$default (Lio/modelcontextprotocol/kotlin/sdk/server/Server;Ljava/lang/String;Ljava/lang/String;Ljava/util/List;Lkotlin/jvm/functions/Function2;ILjava/lang/Object;)V + public final fun addPrompts (Ljava/util/List;)V + public final fun addResource (Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;Lkotlin/jvm/functions/Function2;)V + public static synthetic fun addResource$default (Lio/modelcontextprotocol/kotlin/sdk/server/Server;Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;Lkotlin/jvm/functions/Function2;ILjava/lang/Object;)V + public final fun addResources (Ljava/util/List;)V + public final fun addTool (Lio/modelcontextprotocol/kotlin/sdk/Tool;Lkotlin/jvm/functions/Function2;)V + public final fun addTool (Ljava/lang/String;Ljava/lang/String;Lio/modelcontextprotocol/kotlin/sdk/Tool$Input;Lio/modelcontextprotocol/kotlin/sdk/Tool$Output;Lio/modelcontextprotocol/kotlin/sdk/ToolAnnotations;Lkotlin/jvm/functions/Function2;)V + public static synthetic fun addTool$default (Lio/modelcontextprotocol/kotlin/sdk/server/Server;Ljava/lang/String;Ljava/lang/String;Lio/modelcontextprotocol/kotlin/sdk/Tool$Input;Lio/modelcontextprotocol/kotlin/sdk/Tool$Output;Lio/modelcontextprotocol/kotlin/sdk/ToolAnnotations;Lkotlin/jvm/functions/Function2;ILjava/lang/Object;)V + public final fun addTools (Ljava/util/List;)V + protected fun assertCapabilityForMethod (Lio/modelcontextprotocol/kotlin/sdk/Method;)V + protected fun assertNotificationCapability (Lio/modelcontextprotocol/kotlin/sdk/Method;)V + public fun assertRequestHandlerCapability (Lio/modelcontextprotocol/kotlin/sdk/Method;)V + public final fun closeSessions (Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public fun connect (Lio/modelcontextprotocol/kotlin/sdk/shared/Transport;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public final fun connectSession (Lio/modelcontextprotocol/kotlin/sdk/shared/Transport;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public final fun createElicitation (Ljava/lang/String;Lio/modelcontextprotocol/kotlin/sdk/CreateElicitationRequest$RequestedSchema;Lio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public static synthetic fun createElicitation$default (Lio/modelcontextprotocol/kotlin/sdk/server/Server;Ljava/lang/String;Lio/modelcontextprotocol/kotlin/sdk/CreateElicitationRequest$RequestedSchema;Lio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;ILjava/lang/Object;)Ljava/lang/Object; + public final fun createMessage (Lio/modelcontextprotocol/kotlin/sdk/CreateMessageRequest;Lio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public static synthetic fun createMessage$default (Lio/modelcontextprotocol/kotlin/sdk/server/Server;Lio/modelcontextprotocol/kotlin/sdk/CreateMessageRequest;Lio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;ILjava/lang/Object;)Ljava/lang/Object; + public final fun getClientCapabilities ()Lio/modelcontextprotocol/kotlin/sdk/ClientCapabilities; + public final fun getClientVersion ()Lio/modelcontextprotocol/kotlin/sdk/Implementation; + public final fun getPrompts ()Ljava/util/Map; + public final fun getResources ()Ljava/util/Map; + public final fun getTools ()Ljava/util/Map; + public final fun listRoots (Lkotlinx/serialization/json/JsonObject;Lio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public static synthetic fun listRoots$default (Lio/modelcontextprotocol/kotlin/sdk/server/Server;Lkotlinx/serialization/json/JsonObject;Lio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;ILjava/lang/Object;)Ljava/lang/Object; + public fun onClose ()V + public final fun onClose (Lkotlin/jvm/functions/Function0;)V + public final fun onConnect (Lkotlin/jvm/functions/Function0;)V + public final fun onInitialized (Lkotlin/jvm/functions/Function0;)V + public final fun ping (Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public final fun removePrompt (Ljava/lang/String;)Z + public final fun removePrompts (Ljava/util/List;)I + public final fun removeResource (Ljava/lang/String;)Z + public final fun removeResources (Ljava/util/List;)I + public final fun removeTool (Ljava/lang/String;)Z + public final fun removeTools (Ljava/util/List;)I + public final fun sendLoggingMessage (Lio/modelcontextprotocol/kotlin/sdk/LoggingMessageNotification;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public final fun sendPromptListChanged (Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public final fun sendResourceListChanged (Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public final fun sendResourceUpdated (Lio/modelcontextprotocol/kotlin/sdk/ResourceUpdatedNotification;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public final fun sendToolListChanged (Lkotlin/coroutines/Continuation;)Ljava/lang/Object; +} + +public final class io/modelcontextprotocol/kotlin/sdk/server/ServerOptions : io/modelcontextprotocol/kotlin/sdk/shared/ProtocolOptions { + public fun (Lio/modelcontextprotocol/kotlin/sdk/ServerCapabilities;Z)V + public synthetic fun (Lio/modelcontextprotocol/kotlin/sdk/ServerCapabilities;ZILkotlin/jvm/internal/DefaultConstructorMarker;)V + public final fun getCapabilities ()Lio/modelcontextprotocol/kotlin/sdk/ServerCapabilities; +} + +public class io/modelcontextprotocol/kotlin/sdk/server/ServerSession : io/modelcontextprotocol/kotlin/sdk/shared/Protocol { + public fun (Lio/modelcontextprotocol/kotlin/sdk/Implementation;Lio/modelcontextprotocol/kotlin/sdk/server/ServerOptions;)V + protected fun assertCapabilityForMethod (Lio/modelcontextprotocol/kotlin/sdk/Method;)V + protected fun assertNotificationCapability (Lio/modelcontextprotocol/kotlin/sdk/Method;)V + public fun assertRequestHandlerCapability (Lio/modelcontextprotocol/kotlin/sdk/Method;)V + public final fun createElicitation (Ljava/lang/String;Lio/modelcontextprotocol/kotlin/sdk/CreateElicitationRequest$RequestedSchema;Lio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public static synthetic fun createElicitation$default (Lio/modelcontextprotocol/kotlin/sdk/server/ServerSession;Ljava/lang/String;Lio/modelcontextprotocol/kotlin/sdk/CreateElicitationRequest$RequestedSchema;Lio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;ILjava/lang/Object;)Ljava/lang/Object; + public final fun createMessage (Lio/modelcontextprotocol/kotlin/sdk/CreateMessageRequest;Lio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public static synthetic fun createMessage$default (Lio/modelcontextprotocol/kotlin/sdk/server/ServerSession;Lio/modelcontextprotocol/kotlin/sdk/CreateMessageRequest;Lio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;ILjava/lang/Object;)Ljava/lang/Object; + public final fun getClientCapabilities ()Lio/modelcontextprotocol/kotlin/sdk/ClientCapabilities; + public final fun getClientVersion ()Lio/modelcontextprotocol/kotlin/sdk/Implementation; + public final fun listRoots (Lkotlinx/serialization/json/JsonObject;Lio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public static synthetic fun listRoots$default (Lio/modelcontextprotocol/kotlin/sdk/server/ServerSession;Lkotlinx/serialization/json/JsonObject;Lio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;ILjava/lang/Object;)Ljava/lang/Object; + public fun onClose ()V + public final fun onClose (Lkotlin/jvm/functions/Function0;)V + public final fun onInitialized (Lkotlin/jvm/functions/Function0;)V + public final fun ping (Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public final fun sendLoggingMessage (Lio/modelcontextprotocol/kotlin/sdk/LoggingMessageNotification;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public final fun sendPromptListChanged (Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public final fun sendResourceListChanged (Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public final fun sendResourceUpdated (Lio/modelcontextprotocol/kotlin/sdk/ResourceUpdatedNotification;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public final fun sendToolListChanged (Lkotlin/coroutines/Continuation;)Ljava/lang/Object; +} + +public final class io/modelcontextprotocol/kotlin/sdk/server/SseServerTransport : io/modelcontextprotocol/kotlin/sdk/shared/AbstractTransport { + public fun (Ljava/lang/String;Lio/ktor/server/sse/ServerSSESession;)V + public fun close (Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public final fun getSessionId ()Ljava/lang/String; + public final fun handleMessage (Ljava/lang/String;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public final fun handlePostMessage (Lio/ktor/server/application/ApplicationCall;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public fun send (Lio/modelcontextprotocol/kotlin/sdk/JSONRPCMessage;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public fun start (Lkotlin/coroutines/Continuation;)Ljava/lang/Object; +} + +public final class io/modelcontextprotocol/kotlin/sdk/server/StdioServerTransport : io/modelcontextprotocol/kotlin/sdk/shared/AbstractTransport { + public fun (Lkotlinx/io/Source;Lkotlinx/io/Sink;)V + public fun close (Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public fun send (Lio/modelcontextprotocol/kotlin/sdk/JSONRPCMessage;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public fun start (Lkotlin/coroutines/Continuation;)Ljava/lang/Object; +} + +public final class io/modelcontextprotocol/kotlin/sdk/server/WebSocketMcpKtorServerExtensionsKt { + public static final fun mcpWebSocket (Lio/ktor/server/application/Application;Ljava/lang/String;Lkotlin/jvm/functions/Function0;)V + public static final fun mcpWebSocket (Lio/ktor/server/application/Application;Lkotlin/jvm/functions/Function0;)V + public static final fun mcpWebSocket (Lio/ktor/server/routing/Route;Lio/modelcontextprotocol/kotlin/sdk/server/ServerOptions;Lkotlin/jvm/functions/Function2;)V + public static final fun mcpWebSocket (Lio/ktor/server/routing/Route;Ljava/lang/String;Lio/modelcontextprotocol/kotlin/sdk/server/ServerOptions;Lkotlin/jvm/functions/Function2;)V + public static final fun mcpWebSocket (Lio/ktor/server/routing/Route;Lkotlin/jvm/functions/Function0;)V + public static final fun mcpWebSocket (Lio/ktor/server/routing/Routing;Ljava/lang/String;Lkotlin/jvm/functions/Function0;)V + public static final fun mcpWebSocket (Lio/ktor/server/routing/Routing;Lkotlin/jvm/functions/Function0;)V + public static synthetic fun mcpWebSocket$default (Lio/ktor/server/routing/Route;Lio/modelcontextprotocol/kotlin/sdk/server/ServerOptions;Lkotlin/jvm/functions/Function2;ILjava/lang/Object;)V + public static synthetic fun mcpWebSocket$default (Lio/ktor/server/routing/Route;Ljava/lang/String;Lio/modelcontextprotocol/kotlin/sdk/server/ServerOptions;Lkotlin/jvm/functions/Function2;ILjava/lang/Object;)V + public static final fun mcpWebSocketTransport (Lio/ktor/server/routing/Route;Ljava/lang/String;Lkotlin/jvm/functions/Function2;)V + public static final fun mcpWebSocketTransport (Lio/ktor/server/routing/Route;Lkotlin/jvm/functions/Function2;)V + public static synthetic fun mcpWebSocketTransport$default (Lio/ktor/server/routing/Route;Ljava/lang/String;Lkotlin/jvm/functions/Function2;ILjava/lang/Object;)V + public static synthetic fun mcpWebSocketTransport$default (Lio/ktor/server/routing/Route;Lkotlin/jvm/functions/Function2;ILjava/lang/Object;)V +} + +public final class io/modelcontextprotocol/kotlin/sdk/server/WebSocketMcpServerTransport : io/modelcontextprotocol/kotlin/sdk/shared/WebSocketMcpTransport { + public fun (Lio/ktor/server/websocket/WebSocketServerSession;)V + public synthetic fun getSession ()Lio/ktor/websocket/WebSocketSession; } public abstract class io/modelcontextprotocol/kotlin/sdk/shared/AbstractTransport : io/modelcontextprotocol/kotlin/sdk/shared/Transport { diff --git a/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/Protocol.kt b/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/Protocol.kt index b5f15751..2f85dd6f 100644 --- a/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/Protocol.kt +++ b/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/Protocol.kt @@ -42,7 +42,7 @@ import kotlin.reflect.typeOf import kotlin.time.Duration import kotlin.time.Duration.Companion.milliseconds -private val LOGGER = KotlinLogging.logger { } +private val logger = KotlinLogging.logger { } public const val IMPLEMENTATION_NAME: String = "mcp-ktor" @@ -204,6 +204,7 @@ public abstract class Protocol( } } + logger.info { "Starting transport" } return transport.start() } @@ -221,29 +222,29 @@ public abstract class Protocol( } private suspend fun onNotification(notification: JSONRPCNotification) { - LOGGER.trace { "Received notification: ${notification.method}" } + logger.trace { "Received notification: ${notification.method}" } val handler = notificationHandlers[notification.method] ?: fallbackNotificationHandler if (handler == null) { - LOGGER.trace { "No handler found for notification: ${notification.method}" } + logger.trace { "No handler found for notification: ${notification.method}" } return } try { handler(notification) } catch (cause: Throwable) { - LOGGER.error(cause) { "Error handling notification: ${notification.method}" } + logger.error(cause) { "Error handling notification: ${notification.method}" } onError(cause) } } private suspend fun onRequest(request: JSONRPCRequest) { - LOGGER.trace { "Received request: ${request.method} (id: ${request.id})" } + logger.trace { "Received request: ${request.method} (id: ${request.id})" } val handler = requestHandlers[request.method] ?: fallbackRequestHandler if (handler === null) { - LOGGER.trace { "No handler found for request: ${request.method}" } + logger.trace { "No handler found for request: ${request.method}" } try { transport?.send( JSONRPCResponse( @@ -255,7 +256,7 @@ public abstract class Protocol( ) ) } catch (cause: Throwable) { - LOGGER.error(cause) { "Error sending method not found response" } + logger.error(cause) { "Error sending method not found response" } onError(cause) } return @@ -263,7 +264,7 @@ public abstract class Protocol( try { val result = handler(request, RequestHandlerExtra()) - LOGGER.trace { "Request handled successfully: ${request.method} (id: ${request.id})" } + logger.trace { "Request handled successfully: ${request.method} (id: ${request.id})" } transport?.send( JSONRPCResponse( @@ -273,7 +274,7 @@ public abstract class Protocol( ) } catch (cause: Throwable) { - LOGGER.error(cause) { "Error handling request: ${request.method} (id: ${request.id})" } + logger.error(cause) { "Error handling request: ${request.method} (id: ${request.id})" } try { transport?.send( @@ -286,14 +287,14 @@ public abstract class Protocol( ) ) } catch (sendError: Throwable) { - LOGGER.error(sendError) { "Failed to send error response for request: ${request.method} (id: ${request.id})" } + logger.error(sendError) { "Failed to send error response for request: ${request.method} (id: ${request.id})" } // Optionally implement fallback behavior here } } } private fun onProgress(notification: ProgressNotification) { - LOGGER.trace { "Received progress notification: token=${notification.params.progressToken}, progress=${notification.params.progress}/${notification.params.total}" } + logger.trace { "Received progress notification: token=${notification.params.progressToken}, progress=${notification.params.progress}/${notification.params.total}" } val progress = notification.params.progress val total = notification.params.total val message = notification.params.message @@ -304,7 +305,7 @@ public abstract class Protocol( val error = Error( "Received a progress notification for an unknown token: ${McpJson.encodeToString(notification)}", ) - LOGGER.error { error.message } + logger.error { error.message } onError(error) return } @@ -382,9 +383,9 @@ public abstract class Protocol( request: Request, options: RequestOptions? = null, ): T { - LOGGER.trace { "Sending request: ${request.method}" } + logger.trace { "Sending request: ${request.method}" } val result = CompletableDeferred() - val transport = this@Protocol.transport ?: throw Error("Not connected") + val transport = transport ?: throw Error("Not connected") if (this@Protocol.options?.enforceStrictCapabilities == true) { assertCapabilityForMethod(request.method) @@ -394,7 +395,7 @@ public abstract class Protocol( val messageId = message.id if (options?.onProgress != null) { - LOGGER.trace { "Registering progress handler for request id: $messageId" } + logger.trace { "Registering progress handler for request id: $messageId" } _progressHandlers.update { current -> current.put(messageId, options.onProgress) } @@ -427,7 +428,7 @@ public abstract class Protocol( val notification = CancelledNotification( params = CancelledNotification.Params( - requestId = messageId, + requestId = messageId, reason = reason.message ?: "Unknown" ) ) @@ -444,12 +445,12 @@ public abstract class Protocol( val timeout = options?.timeout ?: DEFAULT_REQUEST_TIMEOUT try { withTimeout(timeout) { - LOGGER.trace { "Sending request message with id: $messageId" } + logger.trace { "Sending request message with id: $messageId" } this@Protocol.transport?.send(message) } return result.await() } catch (cause: TimeoutCancellationException) { - LOGGER.error { "Request timed out after ${timeout.inWholeMilliseconds}ms: ${request.method}" } + logger.error { "Request timed out after ${timeout.inWholeMilliseconds}ms: ${request.method}" } cancel( McpError( ErrorCode.Defined.RequestTimeout.code, @@ -466,7 +467,7 @@ public abstract class Protocol( * Emits a notification, which is a one-way message that does not expect a response. */ public suspend fun notification(notification: Notification) { - LOGGER.trace { "Sending notification: ${notification.method}" } + logger.trace { "Sending notification: ${notification.method}" } val transport = this.transport ?: error("Not connected") assertNotificationCapability(notification.method) diff --git a/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/WebSocketMcpTransport.kt b/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/WebSocketMcpTransport.kt index 29e7b866..3eea2d0f 100644 --- a/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/WebSocketMcpTransport.kt +++ b/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/WebSocketMcpTransport.kt @@ -1,5 +1,6 @@ package io.modelcontextprotocol.kotlin.sdk.shared +import io.github.oshai.kotlinlogging.KotlinLogging import io.ktor.websocket.Frame import io.ktor.websocket.WebSocketSession import io.ktor.websocket.close @@ -17,6 +18,9 @@ import kotlin.concurrent.atomics.ExperimentalAtomicApi public const val MCP_SUBPROTOCOL: String = "mcp" +private val logger = KotlinLogging.logger {} + + /** * Abstract class representing a WebSocket transport for the Model Context Protocol (MCP). * Handles communication over a WebSocket session. @@ -40,6 +44,8 @@ public abstract class WebSocketMcpTransport : AbstractTransport() { protected abstract suspend fun initializeSession() override suspend fun start() { + logger.debug { "Starting websocket transport" } + if (!initialized.compareAndSet(expectedValue = false, newValue = true)) { error( "WebSocketClientTransport already started! " + @@ -53,7 +59,8 @@ public abstract class WebSocketMcpTransport : AbstractTransport() { while (true) { val message = try { session.incoming.receive() - } catch (_: ClosedReceiveChannelException) { + } catch (e: ClosedReceiveChannelException) { + logger.debug { "Closed receive channel, exiting" } return@launch } @@ -84,6 +91,7 @@ public abstract class WebSocketMcpTransport : AbstractTransport() { } override suspend fun send(message: JSONRPCMessage) { + logger.debug { "Sending message" } if (!initialized.load()) { error("Not connected") } @@ -96,6 +104,7 @@ public abstract class WebSocketMcpTransport : AbstractTransport() { error("Not connected") } + logger.debug { "Closing websocket session" } session.close() session.coroutineContext.job.join() } diff --git a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/KtorServer.kt b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/KtorServer.kt index 056c7854..88bdd94d 100644 --- a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/KtorServer.kt +++ b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/KtorServer.kt @@ -8,94 +8,96 @@ import io.ktor.server.response.respond import io.ktor.server.routing.Routing import io.ktor.server.routing.RoutingContext import io.ktor.server.routing.post -import io.ktor.server.routing.route import io.ktor.server.routing.routing import io.ktor.server.sse.SSE import io.ktor.server.sse.ServerSSESession import io.ktor.server.sse.sse -import io.ktor.util.collections.ConcurrentMap import io.ktor.utils.io.KtorDsl +import kotlinx.atomicfu.AtomicRef +import kotlinx.atomicfu.atomic +import kotlinx.atomicfu.update +import kotlinx.collections.immutable.PersistentMap +import kotlinx.collections.immutable.toPersistentMap private val logger = KotlinLogging.logger {} -@KtorDsl -public fun Routing.mcp(path: String, block: () -> Server) { - route(path) { - mcp(block) +internal class SseTransportManager(transports: Map = emptyMap()) { + private val transports: AtomicRef> = atomic(transports.toPersistentMap()) + + fun getTransport(sessionId: String): SseServerTransport? = transports.value[sessionId] + + fun addTransport(transport: SseServerTransport) { + transports.update { it.put(transport.sessionId, transport) } + } + + fun removeTransport(sessionId: String) { + transports.update { it.remove(sessionId) } } } -/** - * Configures the Ktor Application to handle Model Context Protocol (MCP) over Server-Sent Events (SSE). - */ +/* +* Configures the Ktor Application to handle Model Context Protocol (MCP) over Server-Sent Events (SSE). +*/ @KtorDsl public fun Routing.mcp(block: () -> Server) { - val transports = ConcurrentMap() + val sseTransportManager = SseTransportManager() sse { - mcpSseEndpoint("", transports, block) + mcpSseEndpoint("", sseTransportManager, block) } post { - mcpPostEndpoint(transports) + mcpPostEndpoint(sseTransportManager) } } @Suppress("FunctionName") -@Deprecated("Use mcp() instead", ReplaceWith("mcp(block)"), DeprecationLevel.WARNING) +@Deprecated("Use mcp() instead", ReplaceWith("mcp(block)"), DeprecationLevel.ERROR) public fun Application.MCP(block: () -> Server) { mcp(block) } @KtorDsl public fun Application.mcp(block: () -> Server) { - val transports = ConcurrentMap() - install(SSE) routing { - sse("/sse") { - mcpSseEndpoint("/message", transports, block) - } - - post("/message") { - mcpPostEndpoint(transports) - } + mcp(block) } } -private suspend fun ServerSSESession.mcpSseEndpoint( +internal suspend fun ServerSSESession.mcpSseEndpoint( postEndpoint: String, - transports: ConcurrentMap, + sseTransportManager: SseTransportManager, block: () -> Server, ) { - val transport = mcpSseTransport(postEndpoint, transports) + val transport = mcpSseTransport(postEndpoint, sseTransportManager) val server = block() server.onClose { logger.info { "Server connection closed for sessionId: ${transport.sessionId}" } - transports.remove(transport.sessionId) + sseTransportManager.removeTransport(transport.sessionId) } server.connect(transport) + logger.debug { "Server connected to transport for sessionId: ${transport.sessionId}" } } internal fun ServerSSESession.mcpSseTransport( postEndpoint: String, - transports: ConcurrentMap, + sseTransportManager: SseTransportManager, ): SseServerTransport { val transport = SseServerTransport(postEndpoint, this) - transports[transport.sessionId] = transport - + sseTransportManager.addTransport(transport) logger.info { "New SSE connection established and stored with sessionId: ${transport.sessionId}" } return transport } internal suspend fun RoutingContext.mcpPostEndpoint( - transports: ConcurrentMap, + sseTransportManager: SseTransportManager, ) { val sessionId: String = call.request.queryParameters["sessionId"] ?: run { @@ -105,7 +107,7 @@ internal suspend fun RoutingContext.mcpPostEndpoint( logger.debug { "Received message for sessionId: $sessionId" } - val transport = transports[sessionId] + val transport = sseTransportManager.getTransport(sessionId) if (transport == null) { logger.warn { "Session not found for sessionId: $sessionId" } call.respond(HttpStatusCode.NotFound, "Session not found") diff --git a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/Server.kt b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/Server.kt index f0655fd9..03277347 100644 --- a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/Server.kt +++ b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/Server.kt @@ -3,58 +3,35 @@ package io.modelcontextprotocol.kotlin.sdk.server import io.github.oshai.kotlinlogging.KotlinLogging import io.modelcontextprotocol.kotlin.sdk.CallToolRequest import io.modelcontextprotocol.kotlin.sdk.CallToolResult -import io.modelcontextprotocol.kotlin.sdk.ClientCapabilities -import io.modelcontextprotocol.kotlin.sdk.CreateElicitationRequest -import io.modelcontextprotocol.kotlin.sdk.CreateElicitationRequest.RequestedSchema -import io.modelcontextprotocol.kotlin.sdk.CreateElicitationResult -import io.modelcontextprotocol.kotlin.sdk.CreateMessageRequest -import io.modelcontextprotocol.kotlin.sdk.CreateMessageResult -import io.modelcontextprotocol.kotlin.sdk.EmptyJsonObject -import io.modelcontextprotocol.kotlin.sdk.EmptyRequestResult import io.modelcontextprotocol.kotlin.sdk.GetPromptRequest import io.modelcontextprotocol.kotlin.sdk.GetPromptResult import io.modelcontextprotocol.kotlin.sdk.Implementation -import io.modelcontextprotocol.kotlin.sdk.InitializeRequest -import io.modelcontextprotocol.kotlin.sdk.InitializeResult -import io.modelcontextprotocol.kotlin.sdk.InitializedNotification -import io.modelcontextprotocol.kotlin.sdk.LATEST_PROTOCOL_VERSION import io.modelcontextprotocol.kotlin.sdk.ListPromptsRequest import io.modelcontextprotocol.kotlin.sdk.ListPromptsResult import io.modelcontextprotocol.kotlin.sdk.ListResourceTemplatesRequest import io.modelcontextprotocol.kotlin.sdk.ListResourceTemplatesResult import io.modelcontextprotocol.kotlin.sdk.ListResourcesRequest import io.modelcontextprotocol.kotlin.sdk.ListResourcesResult -import io.modelcontextprotocol.kotlin.sdk.ListRootsRequest -import io.modelcontextprotocol.kotlin.sdk.ListRootsResult import io.modelcontextprotocol.kotlin.sdk.ListToolsRequest import io.modelcontextprotocol.kotlin.sdk.ListToolsResult -import io.modelcontextprotocol.kotlin.sdk.LoggingMessageNotification import io.modelcontextprotocol.kotlin.sdk.Method -import io.modelcontextprotocol.kotlin.sdk.PingRequest import io.modelcontextprotocol.kotlin.sdk.Prompt import io.modelcontextprotocol.kotlin.sdk.PromptArgument -import io.modelcontextprotocol.kotlin.sdk.PromptListChangedNotification import io.modelcontextprotocol.kotlin.sdk.ReadResourceRequest import io.modelcontextprotocol.kotlin.sdk.ReadResourceResult import io.modelcontextprotocol.kotlin.sdk.Resource -import io.modelcontextprotocol.kotlin.sdk.ResourceListChangedNotification -import io.modelcontextprotocol.kotlin.sdk.ResourceUpdatedNotification -import io.modelcontextprotocol.kotlin.sdk.SUPPORTED_PROTOCOL_VERSIONS import io.modelcontextprotocol.kotlin.sdk.ServerCapabilities import io.modelcontextprotocol.kotlin.sdk.Tool import io.modelcontextprotocol.kotlin.sdk.ToolAnnotations -import io.modelcontextprotocol.kotlin.sdk.ToolListChangedNotification -import io.modelcontextprotocol.kotlin.sdk.shared.Protocol import io.modelcontextprotocol.kotlin.sdk.shared.ProtocolOptions -import io.modelcontextprotocol.kotlin.sdk.shared.RequestOptions +import io.modelcontextprotocol.kotlin.sdk.shared.Transport import kotlinx.atomicfu.atomic import kotlinx.atomicfu.getAndUpdate import kotlinx.atomicfu.update import kotlinx.collections.immutable.minus +import kotlinx.collections.immutable.persistentListOf import kotlinx.collections.immutable.persistentMapOf import kotlinx.collections.immutable.toPersistentSet -import kotlinx.coroutines.CompletableDeferred -import kotlinx.serialization.json.JsonObject private val logger = KotlinLogging.logger {} @@ -81,25 +58,14 @@ public class ServerOptions( */ public open class Server( private val serverInfo: Implementation, - options: ServerOptions, -) : Protocol(options) { + private val options: ServerOptions, +) { + private val sessions = atomic(persistentListOf()) + private var _onInitialized: (() -> Unit) = {} + private var _onConnect: (() -> Unit) = {} private var _onClose: () -> Unit = {} - /** - * The client's reported capabilities after initialization. - */ - public var clientCapabilities: ClientCapabilities? = null - private set - - /** - * The client's version information after initialization. - */ - public var clientVersion: Implementation? = null - private set - - private val capabilities: ServerCapabilities = options.capabilities - private val _tools = atomic(persistentMapOf()) private val _prompts = atomic(persistentMapOf()) private val _resources = atomic(persistentMapOf()) @@ -110,55 +76,83 @@ public open class Server( public val resources: Map get() = _resources.value - init { - logger.debug { "Initializing MCP server with capabilities: $capabilities" } + public suspend fun close() { + logger.debug { "Closing MCP server" } + sessions.value.forEach { it.close() } + _onClose() + } - // Core protocol handlers - setRequestHandler(Method.Defined.Initialize) { request, _ -> - handleInitialize(request) - } - setNotificationHandler(Method.Defined.NotificationsInitialized) { - _onInitialized() - CompletableDeferred(Unit) - } + /** + * Starts a new server session with the given transport and initializes + * internal request handlers based on the server's capabilities. + * + * @param transport The transport layer to connect the session with. + * @return The initialized and connected server session. + */ + public suspend fun connect(transport: Transport): ServerSession { + val session = ServerSession(serverInfo, options) // Internal handlers for tools - if (capabilities.tools != null) { - setRequestHandler(Method.Defined.ToolsList) { _, _ -> + if (options.capabilities.tools != null) { + session.setRequestHandler(Method.Defined.ToolsList) { _, _ -> handleListTools() } - setRequestHandler(Method.Defined.ToolsCall) { request, _ -> + session.setRequestHandler(Method.Defined.ToolsCall) { request, _ -> handleCallTool(request) } } // Internal handlers for prompts - if (capabilities.prompts != null) { - setRequestHandler(Method.Defined.PromptsList) { _, _ -> + if (options.capabilities.prompts != null) { + session.setRequestHandler(Method.Defined.PromptsList) { _, _ -> handleListPrompts() } - setRequestHandler(Method.Defined.PromptsGet) { request, _ -> + session.setRequestHandler(Method.Defined.PromptsGet) { request, _ -> handleGetPrompt(request) } } // Internal handlers for resources - if (capabilities.resources != null) { - setRequestHandler(Method.Defined.ResourcesList) { _, _ -> + if (options.capabilities.resources != null) { + session.setRequestHandler(Method.Defined.ResourcesList) { _, _ -> handleListResources() } - setRequestHandler(Method.Defined.ResourcesRead) { request, _ -> + session.setRequestHandler(Method.Defined.ResourcesRead) { request, _ -> handleReadResource(request) } - setRequestHandler(Method.Defined.ResourcesTemplatesList) { _, _ -> + session.setRequestHandler(Method.Defined.ResourcesTemplatesList) { _, _ -> handleListResourceTemplates() } } + + logger.debug { "Server session connecting to transport" } + session.connect(transport) + logger.debug { "Server session successfully connected to transport" } + sessions.update { it.add(session) } + + _onConnect() + return session + } + + /** + * Registers a callback to be invoked when the new server session connected. + */ + public fun onConnect(block: () -> Unit) { + val old = _onConnect + _onConnect = { + old() + block() + } } /** * Registers a callback to be invoked when the server has completed initialization. */ + @Deprecated( + "Will be removed with Protocol inheritance. Use onConnect instead.", + ReplaceWith("onConnect"), + DeprecationLevel.WARNING + ) public fun onInitialized(block: () -> Unit) { val old = _onInitialized _onInitialized = { @@ -178,14 +172,6 @@ public open class Server( } } - /** - * Called when the server connection is closing. - */ - override fun onClose() { - logger.info { "Server connection closing" } - _onClose() - } - /** * Registers a single tool. The client can then call this tool. * @@ -194,7 +180,7 @@ public open class Server( * @throws IllegalStateException If the server does not support tools. */ public fun addTool(tool: Tool, handler: suspend (CallToolRequest) -> CallToolResult) { - if (capabilities.tools == null) { + if (options.capabilities.tools == null) { logger.error { "Failed to add tool '${tool.name}': Server does not support tools capability" } throw IllegalStateException("Server does not support tools capability. Enable it in ServerOptions.") } @@ -234,7 +220,7 @@ public open class Server( * @throws IllegalStateException If the server does not support tools. */ public fun addTools(toolsToAdd: List) { - if (capabilities.tools == null) { + if (options.capabilities.tools == null) { logger.error { "Failed to add tools: Server does not support tools capability" } throw IllegalStateException("Server does not support tools capability.") } @@ -250,7 +236,7 @@ public open class Server( * @throws IllegalStateException If the server does not support tools. */ public fun removeTool(name: String): Boolean { - if (capabilities.tools == null) { + if (options.capabilities.tools == null) { logger.error { "Failed to remove tool '$name': Server does not support tools capability" } throw IllegalStateException("Server does not support tools capability.") } @@ -277,7 +263,7 @@ public open class Server( * @throws IllegalStateException If the server does not support tools. */ public fun removeTools(toolNames: List): Int { - if (capabilities.tools == null) { + if (options.capabilities.tools == null) { logger.error { "Failed to remove tools: Server does not support tools capability" } throw IllegalStateException("Server does not support tools capability.") } @@ -304,7 +290,7 @@ public open class Server( * @throws IllegalStateException If the server does not support prompts. */ public fun addPrompt(prompt: Prompt, promptProvider: suspend (GetPromptRequest) -> GetPromptResult) { - if (capabilities.prompts == null) { + if (options.capabilities.prompts == null) { logger.error { "Failed to add prompt '${prompt.name}': Server does not support prompts capability" } throw IllegalStateException("Server does not support prompts capability.") } @@ -338,7 +324,7 @@ public open class Server( * @throws IllegalStateException If the server does not support prompts. */ public fun addPrompts(promptsToAdd: List) { - if (capabilities.prompts == null) { + if (options.capabilities.prompts == null) { logger.error { "Failed to add prompts: Server does not support prompts capability" } throw IllegalStateException("Server does not support prompts capability.") } @@ -354,7 +340,7 @@ public open class Server( * @throws IllegalStateException If the server does not support prompts. */ public fun removePrompt(name: String): Boolean { - if (capabilities.prompts == null) { + if (options.capabilities.prompts == null) { logger.error { "Failed to remove prompt '$name': Server does not support prompts capability" } throw IllegalStateException("Server does not support prompts capability.") } @@ -381,7 +367,7 @@ public open class Server( * @throws IllegalStateException If the server does not support prompts. */ public fun removePrompts(promptNames: List): Int { - if (capabilities.prompts == null) { + if (options.capabilities.prompts == null) { logger.error { "Failed to remove prompts: Server does not support prompts capability" } throw IllegalStateException("Server does not support prompts capability.") } @@ -418,7 +404,7 @@ public open class Server( mimeType: String = "text/html", readHandler: suspend (ReadResourceRequest) -> ReadResourceResult ) { - if (capabilities.resources == null) { + if (options.capabilities.resources == null) { logger.error { "Failed to add resource '$name': Server does not support resources capability" } throw IllegalStateException("Server does not support resources capability.") } @@ -438,7 +424,7 @@ public open class Server( * @throws IllegalStateException If the server does not support resources. */ public fun addResources(resourcesToAdd: List) { - if (capabilities.resources == null) { + if (options.capabilities.resources == null) { logger.error { "Failed to add resources: Server does not support resources capability" } throw IllegalStateException("Server does not support resources capability.") } @@ -454,7 +440,7 @@ public open class Server( * @throws IllegalStateException If the server does not support resources. */ public fun removeResource(uri: String): Boolean { - if (capabilities.resources == null) { + if (options.capabilities.resources == null) { logger.error { "Failed to remove resource '$uri': Server does not support resources capability" } throw IllegalStateException("Server does not support resources capability.") } @@ -481,7 +467,7 @@ public open class Server( * @throws IllegalStateException If the server does not support resources. */ public fun removeResources(uris: List): Int { - if (capabilities.resources == null) { + if (options.capabilities.resources == null) { logger.error { "Failed to remove resources: Server does not support resources capability" } throw IllegalStateException("Server does not support resources capability.") } @@ -501,123 +487,7 @@ public open class Server( return removedCount } - /** - * Sends a ping request to the client to check connectivity. - * - * @return The result of the ping request. - * @throws IllegalStateException If for some reason the method is not supported or the connection is closed. - */ - public suspend fun ping(): EmptyRequestResult { - return request(PingRequest()) - } - - /** - * Creates a message using the server's sampling capability. - * - * @param params The parameters for creating a message. - * @param options Optional request options. - * @return The created message result. - * @throws IllegalStateException If the server does not support sampling or if the request fails. - */ - public suspend fun createMessage( - params: CreateMessageRequest, - options: RequestOptions? = null - ): CreateMessageResult { - logger.debug { "Creating message with params: $params" } - return request(params, options) - } - - /** - * Lists the available "roots" from the client's perspective (if supported). - * - * @param params JSON parameters for the request, usually empty. - * @param options Optional request options. - * @return The list of roots. - * @throws IllegalStateException If the server or client does not support roots. - */ - public suspend fun listRoots( - params: JsonObject = EmptyJsonObject, - options: RequestOptions? = null - ): ListRootsResult { - logger.debug { "Listing roots with params: $params" } - return request(ListRootsRequest(params), options) - } - - public suspend fun createElicitation( - message: String, - requestedSchema: RequestedSchema, - options: RequestOptions? = null - ): CreateElicitationResult { - logger.debug { "Creating elicitation with message: $message" } - return request(CreateElicitationRequest(message, requestedSchema), options) - } - - /** - * Sends a logging message notification to the client. - * - * @param params The logging message notification parameters. - */ - public suspend fun sendLoggingMessage(params: LoggingMessageNotification) { - logger.trace { "Sending logging message: ${params.params.data}" } - notification(params) - } - - /** - * Sends a resource-updated notification to the client, indicating that a specific resource has changed. - * - * @param params Details of the updated resource. - */ - public suspend fun sendResourceUpdated(params: ResourceUpdatedNotification) { - logger.debug { "Sending resource updated notification for: ${params.params.uri}" } - notification(params) - } - - /** - * Sends a notification to the client indicating that the list of resources has changed. - */ - public suspend fun sendResourceListChanged() { - logger.debug { "Sending resource list changed notification" } - notification(ResourceListChangedNotification()) - } - - /** - * Sends a notification to the client indicating that the list of tools has changed. - */ - public suspend fun sendToolListChanged() { - logger.debug { "Sending tool list changed notification" } - notification(ToolListChangedNotification()) - } - - /** - * Sends a notification to the client indicating that the list of prompts has changed. - */ - public suspend fun sendPromptListChanged() { - logger.debug { "Sending prompt list changed notification" } - notification(PromptListChangedNotification()) - } - // --- Internal Handlers --- - - private suspend fun handleInitialize(request: InitializeRequest): InitializeResult { - logger.info { "Handling initialize request from client ${request.clientInfo}" } - clientCapabilities = request.capabilities - clientVersion = request.clientInfo - - val requestedVersion = request.protocolVersion - val protocolVersion = if (SUPPORTED_PROTOCOL_VERSIONS.contains(requestedVersion)) { - requestedVersion - } else { - logger.warn { "Client requested unsupported protocol version $requestedVersion, falling back to $LATEST_PROTOCOL_VERSION" } - LATEST_PROTOCOL_VERSION - } - - return InitializeResult( - protocolVersion = protocolVersion, - capabilities = capabilities, - serverInfo = serverInfo - ) - } - private suspend fun handleListTools(): ListToolsResult { val toolList = tools.values.map { it.tool } return ListToolsResult(tools = toolList, nextCursor = null) @@ -668,136 +538,6 @@ public open class Server( // If you have resource templates, return them here. For now, return empty. return ListResourceTemplatesResult(listOf()) } - - /** - * Asserts that the client supports the capability required for the given [method]. - * - * This method is automatically called by the [Protocol] framework before handling requests. - * Throws [IllegalStateException] if the capability is not supported. - * - * @param method The method for which we are asserting capability. - */ - override fun assertCapabilityForMethod(method: Method) { - logger.trace { "Asserting capability for method: ${method.value}" } - when (method.value) { - "sampling/createMessage" -> { - if (clientCapabilities?.sampling == null) { - logger.error { "Client capability assertion failed: sampling not supported" } - throw IllegalStateException("Client does not support sampling (required for ${method.value})") - } - } - - "roots/list" -> { - if (clientCapabilities?.roots == null) { - throw IllegalStateException("Client does not support listing roots (required for ${method.value})") - } - } - - "elicitation/create" -> { - if (clientCapabilities?.elicitation == null) { - throw IllegalStateException("Client does not support elicitation (required for ${method.value})") - } - } - - "ping" -> { - // No specific capability required - } - } - } - - /** - * Asserts that the server can handle the specified notification method. - * - * Throws [IllegalStateException] if the server does not have the capabilities required to handle this notification. - * - * @param method The notification method. - */ - override fun assertNotificationCapability(method: Method) { - logger.trace { "Asserting notification capability for method: ${method.value}" } - when (method.value) { - "notifications/message" -> { - if (capabilities.logging == null) { - logger.error { "Server capability assertion failed: logging not supported" } - throw IllegalStateException("Server does not support logging (required for ${method.value})") - } - } - - "notifications/resources/updated", - "notifications/resources/list_changed" -> { - if (capabilities.resources == null) { - throw IllegalStateException("Server does not support notifying about resources (required for ${method.value})") - } - } - - "notifications/tools/list_changed" -> { - if (capabilities.tools == null) { - throw IllegalStateException("Server does not support notifying of tool list changes (required for ${method.value})") - } - } - - "notifications/prompts/list_changed" -> { - if (capabilities.prompts == null) { - throw IllegalStateException("Server does not support notifying of prompt list changes (required for ${method.value})") - } - } - - "notifications/cancelled", - "notifications/progress" -> { - // Always allowed - } - } - } - - /** - * Asserts that the server can handle the specified request method. - * - * Throws [IllegalStateException] if the server does not have the capabilities required to handle this request. - * - * @param method The request method. - */ - override fun assertRequestHandlerCapability(method: Method) { - logger.trace { "Asserting request handler capability for method: ${method.value}" } - when (method.value) { - "sampling/createMessage" -> { - if (capabilities.sampling == null) { - logger.error { "Server capability assertion failed: sampling not supported" } - throw IllegalStateException("Server does not support sampling (required for $method)") - } - } - - "logging/setLevel" -> { - if (capabilities.logging == null) { - throw IllegalStateException("Server does not support logging (required for $method)") - } - } - - "prompts/get", - "prompts/list" -> { - if (capabilities.prompts == null) { - throw IllegalStateException("Server does not support prompts (required for $method)") - } - } - - "resources/list", - "resources/templates/list", - "resources/read" -> { - if (capabilities.resources == null) { - throw IllegalStateException("Server does not support resources (required for $method)") - } - } - - "tools/call", - "tools/list" -> { - if (capabilities.tools == null) { - throw IllegalStateException("Server does not support tools (required for $method)") - } - } - - "ping", "initialize" -> { - // No capability required - } - } - } } /** diff --git a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/ServerSession.kt b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/ServerSession.kt new file mode 100644 index 00000000..5952bf25 --- /dev/null +++ b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/ServerSession.kt @@ -0,0 +1,343 @@ +package io.modelcontextprotocol.kotlin.sdk.server + +import io.github.oshai.kotlinlogging.KotlinLogging +import io.modelcontextprotocol.kotlin.sdk.ClientCapabilities +import io.modelcontextprotocol.kotlin.sdk.CreateElicitationRequest +import io.modelcontextprotocol.kotlin.sdk.CreateElicitationRequest.RequestedSchema +import io.modelcontextprotocol.kotlin.sdk.CreateElicitationResult +import io.modelcontextprotocol.kotlin.sdk.CreateMessageRequest +import io.modelcontextprotocol.kotlin.sdk.CreateMessageResult +import io.modelcontextprotocol.kotlin.sdk.EmptyJsonObject +import io.modelcontextprotocol.kotlin.sdk.EmptyRequestResult +import io.modelcontextprotocol.kotlin.sdk.Implementation +import io.modelcontextprotocol.kotlin.sdk.InitializeRequest +import io.modelcontextprotocol.kotlin.sdk.InitializeResult +import io.modelcontextprotocol.kotlin.sdk.InitializedNotification +import io.modelcontextprotocol.kotlin.sdk.LATEST_PROTOCOL_VERSION +import io.modelcontextprotocol.kotlin.sdk.ListRootsRequest +import io.modelcontextprotocol.kotlin.sdk.ListRootsResult +import io.modelcontextprotocol.kotlin.sdk.LoggingMessageNotification +import io.modelcontextprotocol.kotlin.sdk.Method +import io.modelcontextprotocol.kotlin.sdk.PingRequest +import io.modelcontextprotocol.kotlin.sdk.PromptListChangedNotification +import io.modelcontextprotocol.kotlin.sdk.ResourceListChangedNotification +import io.modelcontextprotocol.kotlin.sdk.ResourceUpdatedNotification +import io.modelcontextprotocol.kotlin.sdk.SUPPORTED_PROTOCOL_VERSIONS +import io.modelcontextprotocol.kotlin.sdk.ToolListChangedNotification +import io.modelcontextprotocol.kotlin.sdk.shared.Protocol +import io.modelcontextprotocol.kotlin.sdk.shared.RequestOptions +import kotlinx.coroutines.CompletableDeferred +import kotlinx.serialization.json.JsonObject + +private val logger = KotlinLogging.logger {} + +public open class ServerSession( + private val serverInfo: Implementation, + options: ServerOptions, +) : Protocol(options) { + private var _onInitialized: (() -> Unit) = {} + private var _onClose: () -> Unit = {} + + init { + // Core protocol handlers + setRequestHandler(Method.Defined.Initialize) { request, _ -> + handleInitialize(request) + } + setNotificationHandler(Method.Defined.NotificationsInitialized) { + _onInitialized() + CompletableDeferred(Unit) + } + } + + /** + * The capabilities supported by the server, related to the session. + */ + private val serverCapabilities = options.capabilities + + /** + * The client's reported capabilities after initialization. + */ + public var clientCapabilities: ClientCapabilities? = null + private set + + /** + * The client's version information after initialization. + */ + public var clientVersion: Implementation? = null + private set + + /** + * Registers a callback to be invoked when the server has completed initialization. + */ + public fun onInitialized(block: () -> Unit) { + val old = _onInitialized + _onInitialized = { + old() + block() + } + } + + /** + * Registers a callback to be invoked when the server session is closing. + */ + public fun onClose(block: () -> Unit) { + val old = _onClose + _onClose = { + old() + block() + } + } + + /** + * Called when the server session is closing. + */ + override fun onClose() { + logger.debug { "Server connection closing" } + _onClose() + } + + /** + * Sends a ping request to the client to check connectivity. + * + * @return The result of the ping request. + * @throws IllegalStateException If for some reason the method is not supported or the connection is closed. + */ + public suspend fun ping(): EmptyRequestResult { + return request(PingRequest()) + } + + /** + * Creates a message using the server's sampling capability. + * + * @param params The parameters for creating a message. + * @param options Optional request options. + * @return The created message result. + * @throws IllegalStateException If the server does not support sampling or if the request fails. + */ + public suspend fun createMessage( + params: CreateMessageRequest, + options: RequestOptions? = null + ): CreateMessageResult { + logger.debug { "Creating message with params: $params" } + return request(params, options) + } + + /** + * Lists the available "roots" from the client's perspective (if supported). + * + * @param params JSON parameters for the request, usually empty. + * @param options Optional request options. + * @return The list of roots. + * @throws IllegalStateException If the server or client does not support roots. + */ + public suspend fun listRoots( + params: JsonObject = EmptyJsonObject, + options: RequestOptions? = null + ): ListRootsResult { + logger.debug { "Listing roots with params: $params" } + return request(ListRootsRequest(params), options) + } + + public suspend fun createElicitation( + message: String, + requestedSchema: RequestedSchema, + options: RequestOptions? = null + ): CreateElicitationResult { + logger.debug { "Creating elicitation with message: $message" } + return request(CreateElicitationRequest(message, requestedSchema), options) + } + + /** + * Sends a logging message notification to the client. + * + * @param notification The logging message notification. + */ + public suspend fun sendLoggingMessage(notification: LoggingMessageNotification) { + logger.trace { "Sending logging message: ${notification.params.data}" } + notification(notification) + } + + /** + * Sends a resource-updated notification to the client, indicating that a specific resource has changed. + * + * @param notification Details of the updated resource. + */ + public suspend fun sendResourceUpdated(notification: ResourceUpdatedNotification) { + logger.debug { "Sending resource updated notification for: ${notification.params.uri}" } + notification(notification) + } + + /** + * Sends a notification to the client indicating that the list of resources has changed. + */ + public suspend fun sendResourceListChanged() { + logger.debug { "Sending resource list changed notification" } + notification(ResourceListChangedNotification()) + } + + /** + * Sends a notification to the client indicating that the list of tools has changed. + */ + public suspend fun sendToolListChanged() { + logger.debug { "Sending tool list changed notification" } + notification(ToolListChangedNotification()) + } + + /** + * Sends a notification to the client indicating that the list of prompts has changed. + */ + public suspend fun sendPromptListChanged() { + logger.debug { "Sending prompt list changed notification" } + notification(PromptListChangedNotification()) + } + + /** + * Asserts that the client supports the capability required for the given [method]. + * + * This method is automatically called by the [Protocol] framework before handling requests. + * Throws [IllegalStateException] if the capability is not supported. + * + * @param method The method for which we are asserting capability. + */ + override fun assertCapabilityForMethod(method: Method) { + logger.trace { "Asserting capability for method: ${method.value}" } + when (method.value) { + "sampling/createMessage" -> { + if (clientCapabilities?.sampling == null) { + logger.error { "Client capability assertion failed: sampling not supported" } + throw IllegalStateException("Client does not support sampling (required for ${method.value})") + } + } + + "roots/list" -> { + if (clientCapabilities?.roots == null) { + throw IllegalStateException("Client does not support listing roots (required for ${method.value})") + } + } + + "elicitation/create" -> { + if (clientCapabilities?.elicitation == null) { + throw IllegalStateException("Client does not support elicitation (required for ${method.value})") + } + } + + "ping" -> { + // No specific capability required + } + } + } + + /** + * Asserts that the server can handle the specified notification method. + * + * Throws [IllegalStateException] if the server does not have the capabilities required to handle this notification. + * + * @param method The notification method. + */ + override fun assertNotificationCapability(method: Method) { + logger.trace { "Asserting notification capability for method: ${method.value}" } + when (method.value) { + "notifications/message" -> { + if (serverCapabilities.logging == null) { + logger.error { "Server capability assertion failed: logging not supported" } + throw IllegalStateException("Server does not support logging (required for ${method.value})") + } + } + + "notifications/resources/updated", + "notifications/resources/list_changed" -> { + if (serverCapabilities.resources == null) { + throw IllegalStateException("Server does not support notifying about resources (required for ${method.value})") + } + } + + "notifications/tools/list_changed" -> { + if (serverCapabilities.tools == null) { + throw IllegalStateException("Server does not support notifying of tool list changes (required for ${method.value})") + } + } + + "notifications/prompts/list_changed" -> { + if (serverCapabilities.prompts == null) { + throw IllegalStateException("Server does not support notifying of prompt list changes (required for ${method.value})") + } + } + + "notifications/cancelled", + "notifications/progress" -> { + // Always allowed + } + } + } + + /** + * Asserts that the server can handle the specified request method. + * + * Throws [IllegalStateException] if the server does not have the capabilities required to handle this request. + * + * @param method The request method. + */ + override fun assertRequestHandlerCapability(method: Method) { + logger.trace { "Asserting request handler capability for method: ${method.value}" } + when (method.value) { + "sampling/createMessage" -> { + if (serverCapabilities.sampling == null) { + logger.error { "Server capability assertion failed: sampling not supported" } + throw IllegalStateException("Server does not support sampling (required for $method)") + } + } + + "logging/setLevel" -> { + if (serverCapabilities.logging == null) { + throw IllegalStateException("Server does not support logging (required for $method)") + } + } + + "prompts/get", + "prompts/list" -> { + if (serverCapabilities.prompts == null) { + throw IllegalStateException("Server does not support prompts (required for $method)") + } + } + + "resources/list", + "resources/templates/list", + "resources/read" -> { + if (serverCapabilities.resources == null) { + throw IllegalStateException("Server does not support resources (required for $method)") + } + } + + "tools/call", + "tools/list" -> { + if (serverCapabilities.tools == null) { + throw IllegalStateException("Server does not support tools (required for $method)") + } + } + + "ping", "initialize" -> { + // No capability required + } + } + } + + private suspend fun handleInitialize(request: InitializeRequest): InitializeResult { + logger.debug { "Handling initialization request from client" } + clientCapabilities = request.capabilities + clientVersion = request.clientInfo + + val requestedVersion = request.protocolVersion + val protocolVersion = if (SUPPORTED_PROTOCOL_VERSIONS.contains(requestedVersion)) { + requestedVersion + } else { + logger.warn { "Client requested unsupported protocol version $requestedVersion, falling back to $LATEST_PROTOCOL_VERSION" } + LATEST_PROTOCOL_VERSION + } + + return InitializeResult( + protocolVersion = protocolVersion, + capabilities = serverCapabilities, + serverInfo = serverInfo + ) + } +} diff --git a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/WebSocketMcpKtorServerExtensions.kt b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/WebSocketMcpKtorServerExtensions.kt index 9301749b..14cc81b5 100644 --- a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/WebSocketMcpKtorServerExtensions.kt +++ b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/WebSocketMcpKtorServerExtensions.kt @@ -1,12 +1,99 @@ package io.modelcontextprotocol.kotlin.sdk.server +import io.github.oshai.kotlinlogging.KotlinLogging +import io.ktor.server.application.Application +import io.ktor.server.application.install import io.ktor.server.routing.Route +import io.ktor.server.routing.Routing +import io.ktor.server.routing.routing import io.ktor.server.websocket.WebSocketServerSession +import io.ktor.server.websocket.WebSockets import io.ktor.server.websocket.webSocket +import io.ktor.utils.io.CancellationException +import io.ktor.utils.io.KtorDsl import io.modelcontextprotocol.kotlin.sdk.Implementation import io.modelcontextprotocol.kotlin.sdk.LIB_VERSION import io.modelcontextprotocol.kotlin.sdk.ServerCapabilities import io.modelcontextprotocol.kotlin.sdk.shared.IMPLEMENTATION_NAME +import kotlinx.coroutines.awaitCancellation + +private val logger = KotlinLogging.logger {} + +/** + * Configures the Ktor Application to handle Model Context Protocol (MCP) over WebSocket. + */ +@KtorDsl +public fun Routing.mcpWebSocket( + block: () -> Server +) { + webSocket { + mcpWebSocketEndpoint(block) + } +} + +/** + * Configures the Ktor Application to handle Model Context Protocol (MCP) over WebSocket. + */ +@KtorDsl +public fun Routing.mcpWebSocket( + path: String, + block: () -> Server +) { + + webSocket(path) { + mcpWebSocketEndpoint(block) + } +} + +/** + * Configures the Ktor Application to handle Model Context Protocol (MCP) over WebSocket. + */ +@KtorDsl +public fun Application.mcpWebSocket( + block: () -> Server +) { + install(WebSockets) + + routing { + mcpWebSocket(block) + } +} + +/** + * Configures the Ktor Application to handle Model Context Protocol (MCP) over WebSocket at the specified path. + */ +@KtorDsl +public fun Application.mcpWebSocket( + path: String, + block: () -> Server +) { + install(WebSockets) + + routing { + mcpWebSocket(path, block) + } +} + +internal suspend fun WebSocketServerSession.mcpWebSocketEndpoint( + block: () -> Server +) { + logger.info { "Ktor Server establishing new connection" } + val transport = createMcpTransport(this) + val server = block() + var session: ServerSession? = null + try { + session = server.connect(transport) + awaitCancellation() + } catch (e: CancellationException) { + session?.close() + } +} + +private fun createMcpTransport( + webSocketSession: WebSocketServerSession, +): WebSocketMcpServerTransport { + return WebSocketMcpServerTransport(webSocketSession) +} /** * Registers a WebSocket route that establishes an MCP (Model Context Protocol) server session. @@ -14,6 +101,11 @@ import io.modelcontextprotocol.kotlin.sdk.shared.IMPLEMENTATION_NAME * @param options Optional server configuration settings for the MCP server. * @param handler A suspend function that defines the server's behavior. */ +@Deprecated( + "Use mcpWebSocket with a lambda that returns a Server instance instead", + ReplaceWith("mcpWebSocket { Server(Implementation(name = IMPLEMENTATION_NAME, version = LIB_VERSION), options ?: ServerOptions(capabilities = ServerCapabilities())) }"), + DeprecationLevel.WARNING +) public fun Route.mcpWebSocket( options: ServerOptions? = null, handler: suspend Server.() -> Unit = {}, @@ -23,6 +115,19 @@ public fun Route.mcpWebSocket( } } +@Deprecated( + "Use mcpWebSocket with a lambda that returns a Server instance instead", + ReplaceWith("mcpWebSocket(block)"), + DeprecationLevel.WARNING +) +public fun Route.mcpWebSocket( + block: () -> Server +) { + webSocket { + block().connect(createMcpTransport(this)) + } +} + /** * Registers a WebSocket route at the specified [path] that establishes an MCP server session. * @@ -30,6 +135,11 @@ public fun Route.mcpWebSocket( * @param options Optional server configuration settings for the MCP server. * @param handler A suspend function that defines the server's behavior. */ +@Deprecated( + "Use mcpWebSocket with a path and a lambda that returns a Server instance instead", + ReplaceWith("mcpWebSocket(path) { Server(Implementation(name = IMPLEMENTATION_NAME, version = LIB_VERSION), options ?: ServerOptions(capabilities = ServerCapabilities())) }"), + DeprecationLevel.WARNING +) public fun Route.mcpWebSocket( path: String, options: ServerOptions? = null, @@ -45,6 +155,11 @@ public fun Route.mcpWebSocket( * * @param handler A suspend function that defines the behavior of the transport layer. */ +@Deprecated( + "Use mcpWebSocket with a lambda that returns a Server instance instead", + ReplaceWith("mcpWebSocket { Server(Implementation(name = IMPLEMENTATION_NAME, version = LIB_VERSION), ServerOptions(capabilities = ServerCapabilities())) }"), + DeprecationLevel.WARNING +) public fun Route.mcpWebSocketTransport( handler: suspend WebSocketMcpServerTransport.() -> Unit = {}, ) { @@ -62,6 +177,11 @@ public fun Route.mcpWebSocketTransport( * @param path The URL path at which to register the WebSocket route. * @param handler A suspend function that defines the behavior of the transport layer. */ +@Deprecated( + "Use mcpWebSocket with a path and a lambda that returns a Server instance instead", + ReplaceWith("mcpWebSocket(path) { Server(Implementation(name = IMPLEMENTATION_NAME, version = LIB_VERSION), ServerOptions(capabilities = ServerCapabilities())) }"), + DeprecationLevel.WARNING +) public fun Route.mcpWebSocketTransport( path: String, handler: suspend WebSocketMcpServerTransport.() -> Unit = {}, @@ -74,7 +194,11 @@ public fun Route.mcpWebSocketTransport( } } - +@Deprecated( + "Use mcpWebSocket with a lambda that returns a Server instance instead", + ReplaceWith("mcpWebSocket { Server(Implementation(name = IMPLEMENTATION_NAME, version = LIB_VERSION), options ?: ServerOptions(capabilities = ServerCapabilities())) }"), + DeprecationLevel.WARNING +) private suspend fun Route.createMcpServer( session: WebSocketServerSession, options: ServerOptions?, @@ -100,9 +224,3 @@ private suspend fun Route.createMcpServer( handler(server) server.close() } - -private fun createMcpTransport( - session: WebSocketServerSession, -): WebSocketMcpServerTransport { - return WebSocketMcpServerTransport(session) -} diff --git a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/WebSocketMcpServerTransport.kt b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/WebSocketMcpServerTransport.kt index 45cb4df9..35885cb5 100644 --- a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/WebSocketMcpServerTransport.kt +++ b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/WebSocketMcpServerTransport.kt @@ -1,10 +1,14 @@ package io.modelcontextprotocol.kotlin.sdk.server +import io.github.oshai.kotlinlogging.KotlinLogging import io.ktor.http.HttpHeaders import io.ktor.server.websocket.WebSocketServerSession import io.modelcontextprotocol.kotlin.sdk.shared.MCP_SUBPROTOCOL import io.modelcontextprotocol.kotlin.sdk.shared.WebSocketMcpTransport + +private val logger = KotlinLogging.logger {} + /** * Server-side implementation of the MCP (Model Context Protocol) transport over WebSocket. * @@ -14,6 +18,7 @@ public class WebSocketMcpServerTransport( override val session: WebSocketServerSession, ) : WebSocketMcpTransport() { override suspend fun initializeSession() { + logger.debug { "Checking session headers" } val subprotocol = session.call.request.headers[HttpHeaders.SecWebSocketProtocol] if (subprotocol != MCP_SUBPROTOCOL) { error("Invalid subprotocol: $subprotocol, expected $MCP_SUBPROTOCOL") diff --git a/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/ClientTest.kt b/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/ClientTest.kt index 26330ce1..c2ca13fa 100644 --- a/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/ClientTest.kt +++ b/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/ClientTest.kt @@ -30,8 +30,16 @@ import io.modelcontextprotocol.kotlin.sdk.TextContent import io.modelcontextprotocol.kotlin.sdk.Tool import io.modelcontextprotocol.kotlin.sdk.server.Server import io.modelcontextprotocol.kotlin.sdk.server.ServerOptions +import io.modelcontextprotocol.kotlin.sdk.server.ServerSession import io.modelcontextprotocol.kotlin.sdk.shared.AbstractTransport import io.modelcontextprotocol.kotlin.sdk.shared.InMemoryTransport +import kotlin.coroutines.cancellation.CancellationException +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertFailsWith +import kotlin.test.assertIs +import kotlin.test.assertTrue +import kotlin.test.fail import kotlinx.coroutines.CompletableDeferred import kotlinx.coroutines.TimeoutCancellationException import kotlinx.coroutines.cancel @@ -43,13 +51,6 @@ import kotlinx.coroutines.withTimeout import kotlinx.serialization.json.buildJsonObject import kotlinx.serialization.json.put import kotlinx.serialization.json.putJsonObject -import kotlin.coroutines.cancellation.CancellationException -import kotlin.test.Test -import kotlin.test.assertEquals -import kotlin.test.assertFailsWith -import kotlin.test.assertIs -import kotlin.test.assertTrue -import kotlin.test.fail class ClientTest { @Test @@ -241,25 +242,6 @@ class ClientTest { serverOptions ) - server.setRequestHandler(Method.Defined.Initialize) { _, _ -> - InitializeResult( - protocolVersion = LATEST_PROTOCOL_VERSION, - capabilities = ServerCapabilities( - resources = ServerCapabilities.Resources(null, null), - tools = ServerCapabilities.Tools(null) - ), - serverInfo = Implementation(name = "test", version = "1.0") - ) - } - - server.setRequestHandler(Method.Defined.ResourcesList) { _, _ -> - ListResourcesResult(resources = emptyList(), nextCursor = null) - } - - server.setRequestHandler(Method.Defined.ToolsList) { _, _ -> - ListToolsResult(tools = emptyList(), nextCursor = null) - } - val (clientTransport, serverTransport) = InMemoryTransport.createLinkedPair() val client = Client( @@ -269,15 +251,36 @@ class ClientTest { ) ) + val serverSessionResult = CompletableDeferred() + listOf( launch { client.connect(clientTransport) }, launch { - server.connect(serverTransport) + serverSessionResult.complete(server.connect(serverTransport)) } ).joinAll() + val serverSession = serverSessionResult.await() + serverSession.setRequestHandler(Method.Defined.Initialize) { _, _ -> + InitializeResult( + protocolVersion = LATEST_PROTOCOL_VERSION, + capabilities = ServerCapabilities( + resources = ServerCapabilities.Resources(null, null), + tools = ServerCapabilities.Tools(null) + ), + serverInfo = Implementation(name = "test", version = "1.0") + ) + } + + serverSession.setRequestHandler(Method.Defined.ResourcesList) { _, _ -> + ListResourcesResult(resources = emptyList(), nextCursor = null) + } + + serverSession.setRequestHandler(Method.Defined.ToolsList) { _, _ -> + ListToolsResult(tools = emptyList(), nextCursor = null) + } // Server supports resources and tools, but not prompts val caps = client.serverCapabilities assertEquals(ServerCapabilities.Resources(null, null), caps?.resources) @@ -368,24 +371,27 @@ class ClientTest { val (clientTransport, serverTransport) = InMemoryTransport.createLinkedPair() + val serverSessionResult = CompletableDeferred() + listOf( launch { client.connect(clientTransport) println("Client connected") }, launch { - server.connect(serverTransport) + serverSessionResult.complete(server.connect(serverTransport)) println("Server connected") } ).joinAll() + val serverSession = serverSessionResult.await() // These should not throw val jsonObject = buildJsonObject { put("name", "John") put("age", 30) put("isStudent", false) } - server.sendLoggingMessage( + serverSession.sendLoggingMessage( LoggingMessageNotification( params = LoggingMessageNotification.Params( level = LoggingLevel.info, @@ -393,11 +399,11 @@ class ClientTest { ) ) ) - server.sendResourceListChanged() + serverSession.sendResourceListChanged() // This should fail because the server doesn't have the tools capability val ex = assertFailsWith { - server.sendToolListChanged() + serverSession.sendToolListChanged() } assertTrue(ex.message?.contains("Server does not support notifying of tool list changes") == true) } @@ -418,19 +424,6 @@ class ClientTest { val def = CompletableDeferred() val defTimeOut = CompletableDeferred() - server.setRequestHandler(Method.Defined.ResourcesList) { _, _ -> - // Simulate delay - def.complete(Unit) - try { - delay(1000) - } catch (e: CancellationException) { - defTimeOut.complete(Unit) - throw e - } - ListResourcesResult(resources = emptyList()) - fail("Shouldn't have been called") - } - val (clientTransport, serverTransport) = InMemoryTransport.createLinkedPair() val client = Client( @@ -438,17 +431,35 @@ class ClientTest { options = ClientOptions(capabilities = ClientCapabilities()) ) + val serverSessionResult = CompletableDeferred() + listOf( launch { client.connect(clientTransport) println("Client connected") }, launch { - server.connect(serverTransport) + serverSessionResult.complete(server.connect(serverTransport)) println("Server connected") } ).joinAll() + val serverSession = serverSessionResult.await() + + serverSession.setRequestHandler(Method.Defined.ResourcesList) { _, _ -> + // Simulate delay + def.complete(Unit) + try { + delay(1000) + } catch (e: CancellationException) { + defTimeOut.complete(Unit) + throw e + } + ListResourcesResult(resources = emptyList()) + fail("Shouldn't have been called") + } + + val defCancel = CompletableDeferred() val job = launch { try { @@ -478,37 +489,40 @@ class ClientTest { ) ) - server.setRequestHandler(Method.Defined.ResourcesList) { _, _ -> - // Simulate a delayed response - // Wait ~100ms unless canceled - try { - withTimeout(100L) { - // Just delay here, if timeout is 0 on the client side, this won't return in time - delay(100) - } - } catch (_: Exception) { - // If aborted, just rethrow or return early - } - ListResourcesResult(resources = emptyList()) - } - val (clientTransport, serverTransport) = InMemoryTransport.createLinkedPair() val client = Client( clientInfo = Implementation(name = "test client", version = "1.0"), options = ClientOptions(capabilities = ClientCapabilities()) ) + val serverSessionResult = CompletableDeferred() + listOf( launch { client.connect(clientTransport) println("Client connected") }, launch { - server.connect(serverTransport) + serverSessionResult.complete(server.connect(serverTransport)) println("Server connected") } ).joinAll() + val serverSession = serverSessionResult.await() + serverSession.setRequestHandler(Method.Defined.ResourcesList) { _, _ -> + // Simulate a delayed response + // Wait ~100ms unless canceled + try { + withTimeout(100L) { + // Just delay here, if timeout is 0 on the client side, this won't return in time + delay(100) + } + } catch (_: Exception) { + // If aborted, just rethrow or return early + } + ListResourcesResult(resources = emptyList()) + } + // Request with 1 msec timeout should fail immediately val ex = assertFailsWith { withTimeout(1) { @@ -559,7 +573,36 @@ class ClientTest { serverOptions ) - server.setRequestHandler(Method.Defined.Initialize) { _, _ -> + val (clientTransport, serverTransport) = InMemoryTransport.createLinkedPair() + + val client = Client( + clientInfo = Implementation(name = "test client", version = "1.0"), + options = ClientOptions( + capabilities = ClientCapabilities(sampling = EmptyJsonObject), + ) + ) + + var receivedMessage: JSONRPCMessage? = null + clientTransport.onMessage { msg -> + receivedMessage = msg + } + + val serverSessionResult = CompletableDeferred() + + listOf( + launch { + client.connect(clientTransport) + println("Client connected") + }, + launch { + serverSessionResult.complete(server.connect(serverTransport)) + println("Server connected") + } + ).joinAll() + + val serverSession = serverSessionResult.await() + + serverSession.setRequestHandler(Method.Defined.Initialize) { _, _ -> InitializeResult( protocolVersion = LATEST_PROTOCOL_VERSION, capabilities = ServerCapabilities( @@ -569,6 +612,7 @@ class ClientTest { serverInfo = Implementation(name = "test", version = "1.0") ) } + val serverListToolsResult = ListToolsResult( tools = listOf( Tool( @@ -582,33 +626,10 @@ class ClientTest { ), nextCursor = null ) - server.setRequestHandler(Method.Defined.ToolsList) { _, _ -> + serverSession.setRequestHandler(Method.Defined.ToolsList) { _, _ -> serverListToolsResult } - val (clientTransport, serverTransport) = InMemoryTransport.createLinkedPair() - - val client = Client( - clientInfo = Implementation(name = "test client", version = "1.0"), - options = ClientOptions( - capabilities = ClientCapabilities(sampling = EmptyJsonObject), - ) - ) - - var receivedMessage: JSONRPCMessage? = null - clientTransport.onMessage { msg -> - receivedMessage = msg - } - - listOf( - launch { - client.connect(clientTransport) - }, - launch { - server.connect(serverTransport) - } - ).joinAll() - val serverCapabilities = client.serverCapabilities assertEquals(ServerCapabilities.Tools(null), serverCapabilities?.tools) @@ -651,15 +672,25 @@ class ClientTest { ) ) + val serverSessionResult = CompletableDeferred() + listOf( - launch { client.connect(clientTransport) }, - launch { server.connect(serverTransport) } + launch { + client.connect(clientTransport) + println("Client connected") + }, + launch { + serverSessionResult.complete(server.connect(serverTransport)) + println("Server connected") + } ).joinAll() - val clientCapabilities = server.clientCapabilities + val serverSession = serverSessionResult.await() + + val clientCapabilities = serverSession.clientCapabilities assertEquals(ClientCapabilities.Roots(null), clientCapabilities?.roots) - val listRootsResult = server.listRoots() + val listRootsResult = serverSession.listRoots() assertEquals(listRootsResult.roots, clientRoots) } @@ -772,16 +803,27 @@ class ClientTest { // Track notifications var rootListChangedNotificationReceived = false - server.setNotificationHandler(Method.Defined.NotificationsRootsListChanged) { - rootListChangedNotificationReceived = true - CompletableDeferred(Unit) - } + + + val serverSessionResult = CompletableDeferred() listOf( - launch { client.connect(clientTransport) }, - launch { server.connect(serverTransport) } + launch { + client.connect(clientTransport) + println("Client connected") + }, + launch { + serverSessionResult.complete(server.connect(serverTransport)) + println("Server connected") + } ).joinAll() + val serverSession = serverSessionResult.await() + serverSession.setNotificationHandler(Method.Defined.NotificationsRootsListChanged) { + rootListChangedNotificationReceived = true + CompletableDeferred(Unit) + } + client.sendRootsListChanged() assertTrue( @@ -808,14 +850,24 @@ class ClientTest { ) ) + val serverSessionResult = CompletableDeferred() + listOf( - launch { client.connect(clientTransport) }, - launch { server.connect(serverTransport) } + launch { + client.connect(clientTransport) + println("Client connected") + }, + launch { + serverSessionResult.complete(server.connect(serverTransport)) + println("Server connected") + } ).joinAll() + val serverSession = serverSessionResult.await() + // Verify that creating an elicitation throws an exception val exception = assertFailsWith { - server.createElicitation( + serverSession.createElicitation( message = "Please provide your GitHub username", requestedSchema = CreateElicitationRequest.RequestedSchema( properties = buildJsonObject { @@ -878,12 +930,22 @@ class ClientTest { ) ) + val serverSessionResult = CompletableDeferred() + listOf( - launch { client.connect(clientTransport) }, - launch { server.connect(serverTransport) } + launch { + client.connect(clientTransport) + println("Client connected") + }, + launch { + serverSessionResult.complete(server.connect(serverTransport)) + println("Server connected") + } ).joinAll() - val result = server.createElicitation( + val serverSession = serverSessionResult.await() + + val result = serverSession.createElicitation( message = elicitationMessage, requestedSchema = requestedSchema ) diff --git a/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/SseTransportTest.kt b/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/SseTransportTest.kt index 23ddadf1..839c13ab 100644 --- a/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/SseTransportTest.kt +++ b/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/SseTransportTest.kt @@ -70,19 +70,6 @@ class SseTransportTest : BaseTransportTest() { install(ServerSSE) routing { mcp { mcpServer } -// sse { -// mcpSseTransport("", transports).apply { -// onMessage { -// send(it) -// } -// -// start() -// } -// } -// -// post { -// mcpPostEndpoint(transports) -// } } }.startSuspend(wait = false) @@ -110,22 +97,7 @@ class SseTransportTest : BaseTransportTest() { val server = embeddedServer(CIO, port = 0) { install(ServerSSE) routing { - mcp("/sse") { mcpServer } -// route("/sse") { -// sse { -// mcpSseTransport("", transports).apply { -// onMessage { -// send(it) -// } -// -// start() -// } -// } -// -// post { -// mcpPostEndpoint(transports) -// } -// } + mcp { mcpServer } } }.startSuspend(wait = false) diff --git a/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/SseBugReproductionTest.kt b/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/SseBugReproductionTest.kt new file mode 100644 index 00000000..e69de29b diff --git a/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/SseIntegrationTest.kt b/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/SseIntegrationTest.kt index 19d84589..cf5f0fe1 100644 --- a/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/SseIntegrationTest.kt +++ b/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/SseIntegrationTest.kt @@ -2,68 +2,198 @@ package io.modelcontextprotocol.kotlin.sdk.integration import io.ktor.client.HttpClient import io.ktor.client.plugins.sse.SSE +import io.ktor.server.application.ApplicationStopped import io.ktor.server.application.install import io.ktor.server.cio.CIOApplicationEngine import io.ktor.server.engine.EmbeddedServer import io.ktor.server.engine.embeddedServer import io.ktor.server.routing.routing +import io.modelcontextprotocol.kotlin.sdk.GetPromptRequest +import io.modelcontextprotocol.kotlin.sdk.GetPromptResult import io.modelcontextprotocol.kotlin.sdk.Implementation +import io.modelcontextprotocol.kotlin.sdk.PromptArgument +import io.modelcontextprotocol.kotlin.sdk.PromptMessage +import io.modelcontextprotocol.kotlin.sdk.Role import io.modelcontextprotocol.kotlin.sdk.ServerCapabilities +import io.modelcontextprotocol.kotlin.sdk.TextContent import io.modelcontextprotocol.kotlin.sdk.client.Client -import io.modelcontextprotocol.kotlin.sdk.client.mcpSse +import io.modelcontextprotocol.kotlin.sdk.client.mcpSseTransport import io.modelcontextprotocol.kotlin.sdk.server.Server import io.modelcontextprotocol.kotlin.sdk.server.ServerOptions import io.modelcontextprotocol.kotlin.sdk.server.mcp import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.test.runTest import kotlinx.coroutines.withContext +import kotlinx.coroutines.withTimeout import kotlin.test.Test -import kotlin.test.fail +import kotlin.test.assertTrue import io.ktor.client.engine.cio.CIO as ClientCIO import io.ktor.server.cio.CIO as ServerCIO class SseIntegrationTest { @Test fun `client should be able to connect to sse server`() = runTest { - val serverEngine = initServer() + var server: EmbeddedServer? = null var client: Client? = null + try { withContext(Dispatchers.Default) { - assertDoesNotThrow { client = initClient() } + withTimeout(1000) { + server = initServer() + client = initClient() + } } - } catch (e: Exception) { - fail("Failed to connect client: $e") } finally { client?.close() - // Make sure to stop the server - serverEngine.stopSuspend(1000, 2000) + server?.stop(1000, 2000) } } - private inline fun assertDoesNotThrow(block: () -> T): T { - return try { - block() - } catch (e: Throwable) { - fail("Expected no exception, but got: $e") + /** + * Test Case #1: One opened connection, a client gets a prompt + * + * 1. Open SSE from Client A. + * 2. Send a POST request from Client A to POST /prompts/get. + * 3. Observe that Client A receives a response related to it. + */ + @Test + fun `single sse connection`() = runTest { + var server: EmbeddedServer? = null + var client: Client? = null + + try { + withContext(Dispatchers.Default) { + withTimeout(1000) { + server = initServer() + client = initClient("Client A") + + val promptA = getPrompt(client, "Client A") + assertTrue { "Client A" in promptA } + } + } + } finally { + client?.close() + server?.stop(1000, 2000) } } - private suspend fun initClient(): Client { - return HttpClient(ClientCIO) { install(SSE) }.mcpSse("http://$URL:$PORT") + /** + * Test Case #1: Two open connections, each client gets a client-specific prompt + * + * 1. Open SSE connection #1 from Client A and note the sessionId= value. + * 2. Open SSE connection #2 from Client B and note the sessionId= value. + * 3. Send a POST request to POST /message with the corresponding sessionId#1. + * 4. Observe that Client B (connection #2) receives a response related to sessionId#1. + */ + @Test + fun `multiple sse connections`() = runTest { + var server: EmbeddedServer? = null + var clientA: Client? = null + var clientB: Client? = null + + try { + withContext(Dispatchers.Default) { + withTimeout(1000) { + server = initServer() + clientA = initClient("Client A") + clientB = initClient("Client B") + + // Step 3: Send a prompt request from Client A + val promptA = getPrompt(clientA, "Client A") + // Step 4: Send a prompt request from Client B + val promptB = getPrompt(clientB, "Client B") + + assertTrue { "Client A" in promptA } + assertTrue { "Client B" in promptB } + } + } + } finally { + clientA?.close() + clientB?.close() + server?.stop(1000, 2000) + } + } + + private suspend fun initClient(name: String = ""): Client { + val client = Client( + Implementation(name = name, version = "1.0.0") + ) + + val httpClient = HttpClient(ClientCIO) { + install(SSE) + } + + // Create a transport wrapper that captures the session ID and received messages + val transport = httpClient.mcpSseTransport { + url { + host = URL + port = PORT + } + } + + client.connect(transport) + + return client } private suspend fun initServer(): EmbeddedServer { val server = Server( - Implementation(name = "sse-e2e-test", version = "1.0.0"), - ServerOptions(capabilities = ServerCapabilities()), + Implementation(name = "sse-server", version = "1.0.0"), + ServerOptions( + capabilities = + ServerCapabilities(prompts = ServerCapabilities.Prompts(listChanged = true)) + ), ) - return embeddedServer(ServerCIO, host = URL, port = PORT) { + server.addPrompt( + name = "prompt", + description = "Prompt description", + arguments = listOf( + PromptArgument( + name = "client", + description = "Client name who requested a prompt", + required = true + ) + ) + ) { request -> + GetPromptResult( + "Prompt for ${request.name}", + messages = listOf( + PromptMessage( + role = Role.user, + content = TextContent("Prompt for client ${request.arguments?.get("client")}") + ) + ) + ) + } + + val ktorServer = embeddedServer(ServerCIO, host = URL, port = PORT) { install(io.ktor.server.sse.SSE) routing { mcp { server } } - }.startSuspend(wait = false) + } + + ktorServer.monitor.subscribe(ApplicationStopped) { + println("SD -- [T] ktor server has been stopped") + } + + return ktorServer.startSuspend(wait = false) + } + + /** + * Retrieves a prompt result using the provided client and client name. + */ + private suspend fun getPrompt(client: Client, clientName: String): String { + val response = client.getPrompt( + GetPromptRequest( + "prompt", + arguments = mapOf("client" to clientName) + ) + ) + + return (response?.messages?.first()?.content as? TextContent)?.text + ?: error("Failed to receive prompt for Client $clientName") } companion object { diff --git a/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/WebSocketIntegrationTest.kt b/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/WebSocketIntegrationTest.kt new file mode 100644 index 00000000..4e55f695 --- /dev/null +++ b/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/WebSocketIntegrationTest.kt @@ -0,0 +1,207 @@ +package io.modelcontextprotocol.kotlin.sdk.integration + +import io.ktor.client.HttpClient +import io.ktor.server.application.ApplicationStopped +import io.ktor.server.application.install +import io.ktor.server.cio.CIOApplicationEngine +import io.ktor.server.engine.EmbeddedServer +import io.ktor.server.engine.embeddedServer +import io.ktor.server.routing.routing +import io.modelcontextprotocol.kotlin.sdk.GetPromptRequest +import io.modelcontextprotocol.kotlin.sdk.GetPromptResult +import io.modelcontextprotocol.kotlin.sdk.Implementation +import io.modelcontextprotocol.kotlin.sdk.PromptArgument +import io.modelcontextprotocol.kotlin.sdk.PromptMessage +import io.modelcontextprotocol.kotlin.sdk.Role +import io.modelcontextprotocol.kotlin.sdk.ServerCapabilities +import io.modelcontextprotocol.kotlin.sdk.TextContent +import io.modelcontextprotocol.kotlin.sdk.client.Client +import io.modelcontextprotocol.kotlin.sdk.client.mcpWebSocketTransport +import io.modelcontextprotocol.kotlin.sdk.server.Server +import io.modelcontextprotocol.kotlin.sdk.server.ServerOptions +import io.modelcontextprotocol.kotlin.sdk.server.mcpWebSocket +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.test.runTest +import kotlinx.coroutines.withContext +import kotlinx.coroutines.withTimeout +import kotlin.test.Test +import kotlin.test.assertTrue +import io.ktor.client.engine.cio.CIO as ClientCIO +import io.ktor.server.cio.CIO as ServerCIO +import io.ktor.server.websocket.WebSockets as ServerWebSockets +import io.ktor.client.plugins.websocket.WebSockets as ClientWebSocket + +class WebSocketIntegrationTest { + + @Test + fun `client should be able to connect to websocket server 2`() = runTest { + var server: EmbeddedServer? = null + var client: Client? = null + + try { + withContext(Dispatchers.Default) { + withTimeout(1000) { + server = initServer() + client = initClient() + } + } + } finally { + client?.close() + server?.stop(1000, 2000) + } + } + + /** + * Test Case #1: One opened connection, a client gets a prompt + * + * 1. Open WebSocket from Client A. + * 2. Send a POST request from Client A to POST /prompts/get. + * 3. Observe that Client A receives a response related to it. + */ + @Test + fun `single websocket connection`() = runTest { + var server: EmbeddedServer? = null + var client: Client? = null + + try { + withContext(Dispatchers.Default) { + withTimeout(1000) { + server = initServer() + client = initClient() + + val promptA = getPrompt(client, "Client A") + assertTrue { "Client A" in promptA } + } + } + } finally { + client?.close() + server?.stop(1000, 2000) + } + } + + /** + * Test Case #1: Two open connections, each client gets a client-specific prompt + * + * 1. Open WebSocket connection #1 from Client A and note the sessionId= value. + * 2. Open WebSocket connection #2 from Client B and note the sessionId= value. + * 3. Send a POST request to POST /message with the corresponding sessionId#1. + * 4. Observe that Client B (connection #2) receives a response related to sessionId#1. + */ + @Test + fun `multiple websocket connections`() = runTest { + var server: EmbeddedServer? = null + var clientA: Client? = null + var clientB: Client? = null + + try { + withContext(Dispatchers.Default) { + withTimeout(1000) { + server = initServer() + clientA = initClient() + clientB = initClient() + + // Step 3: Send a prompt request from Client A + val promptA = getPrompt(clientA, "Client A") + // Step 4: Send a prompt request from Client B + val promptB = getPrompt(clientB, "Client B") + + assertTrue { "Client A" in promptA } + assertTrue { "Client B" in promptB } + } + } + } finally { + clientA?.close() + clientB?.close() + server?.stop(1000, 2000) + } + } + + + private suspend fun initClient(name: String = ""): Client { + val client = Client( + Implementation(name = name, version = "1.0.0") + ) + + val httpClient = HttpClient(ClientCIO) { + install(ClientWebSocket) + } + + // Create a transport wrapper that captures the session ID and received messages + val transport = httpClient.mcpWebSocketTransport { + url { + host = URL + port = PORT + } + } + + client.connect(transport) + + return client + } + + + private suspend fun initServer(): EmbeddedServer { + val server = Server( + Implementation(name = "websocket-server", version = "1.0.0"), + ServerOptions( + capabilities = + ServerCapabilities(prompts = ServerCapabilities.Prompts(listChanged = true)) + ), + ) + + server.addPrompt( + name = "prompt", + description = "Prompt description", + arguments = listOf( + PromptArgument( + name = "client", + description = "Client name who requested a prompt", + required = true + ) + ) + ) { request -> + GetPromptResult( + "Prompt for ${request.name}", + messages = listOf( + PromptMessage( + role = Role.user, + content = TextContent("Prompt for client ${request.arguments?.get("client")}") + ) + ) + ) + } + + val ktorServer = embeddedServer(ServerCIO, host = URL, port = PORT) { + install(ServerWebSockets) + routing { + mcpWebSocket(block = { server }) + } + } + + ktorServer.monitor.subscribe(ApplicationStopped) { + println("SD -- [T] ktor server has been stopped") + } + + return ktorServer.startSuspend(wait = false) + } + + /** + * Retrieves a prompt result using the provided client and client name. + */ + private suspend fun getPrompt(client: Client, clientName: String): String { + val response = client.getPrompt( + GetPromptRequest( + "prompt", + arguments = mapOf("client" to clientName) + ) + ) + + return (response?.messages?.first()?.content as? TextContent)?.text + ?: error("Failed to receive prompt for Client $clientName") + } + + companion object { + private const val PORT = 3002 + private const val URL = "localhost" + } +} \ No newline at end of file