From 9f816bfef1118b7a54ea974f9143d4e6ecd8f04f Mon Sep 17 00:00:00 2001 From: devcrocod Date: Fri, 25 Jul 2025 14:35:46 +0200 Subject: [PATCH] add ktlint and cleanup code --- .editorconfig | 37 +++ build.gradle.kts | 14 +- gradle/libs.versions.toml | 4 +- settings.gradle.kts | 1 - .../kotlin/sdk/client/Client.kt | 97 +++---- .../kotlin/sdk/client/KtorClient.kt | 6 +- .../kotlin/sdk/client/SSEClientTransport.kt | 1 + .../kotlin/sdk/client/StdioClientTransport.kt | 5 +- .../client/StreamableHttpClientTransport.kt | 29 +- .../StreamableHttpMcpKtorClientExtensions.kt | 2 +- .../WebSocketMcpKtorClientExtensions.kt | 4 +- .../kotlin/sdk/internal/utils.kt | 2 +- .../kotlin/sdk/server/KtorServer.kt | 6 +- .../kotlin/sdk/server/SSEServerTransport.kt | 13 +- .../kotlin/sdk/server/Server.kt | 78 ++--- .../kotlin/sdk/server/StdioServerTransport.kt | 5 +- .../WebSocketMcpKtorServerExtensions.kt | 32 +-- .../sdk/server/WebSocketMcpServerTransport.kt | 4 +- .../kotlin/sdk/shared/Protocol.kt | 65 +++-- .../kotlin/sdk/shared/ReadBuffer.kt | 11 +- .../sdk/shared/WebSocketMcpTransport.kt | 3 +- .../modelcontextprotocol/kotlin/sdk/types.kt | 200 ++++++------- .../kotlin/sdk/types.util.kt | 122 ++++---- .../kotlin/AudioContentSerializationTest.kt | 5 +- .../kotlin/CallToolResultUtilsTest.kt | 12 +- .../kotlin/sdk/ToolSerializationTest.kt | 271 +++++++++++------- .../kotlin/sdk/client/BaseTransportTest.kt | 5 +- .../sdk/client/InMemoryTransportTest.kt | 2 +- .../kotlin/sdk/client/TypesTest.kt | 2 +- .../sdk/integration/SseIntegrationTest.kt | 16 +- .../kotlin/sdk/shared/ReadBufferTest.kt | 1 - .../kotlin/sdk/internal/utils.ios.kt | 2 +- .../kotlin/sdk/internal/utils.js.kt | 2 +- .../kotlin/sdk/internal/utils.jvm.kt | 2 +- .../kotlin/client/ClientIntegrationTest.kt | 8 +- src/jvmTest/kotlin/client/ClientTest.kt | 267 ++++++++--------- .../kotlin/client/StdioClientTransportTest.kt | 6 +- .../StreamableHttpClientTransportTest.kt | 94 +++--- src/jvmTest/kotlin/server/ServerTest.kt | 111 +++---- .../kotlin/server/StdioServerTransportTest.kt | 16 +- .../kotlin/sdk/internal/utils.wasmJs.kt | 2 +- 41 files changed, 809 insertions(+), 756 deletions(-) create mode 100644 .editorconfig diff --git a/.editorconfig b/.editorconfig new file mode 100644 index 00000000..5c283ed0 --- /dev/null +++ b/.editorconfig @@ -0,0 +1,37 @@ +root = true + +[*] +charset = utf-8 +end_of_line = lf +insert_final_newline = true +indent_style = space +indent_size = 4 +max_line_length = 120 + +[*.json] +indent_size = 2 + +[{*.yaml,*.yml}] +indent_size = 2 + +[*.{kt,kts}] +ij_kotlin_code_style_defaults = KOTLIN_OFFICIAL + +# Disable wildcard imports entirely +ij_kotlin_name_count_to_use_star_import = 2147483647 +ij_kotlin_name_count_to_use_star_import_for_members = 2147483647 +ij_kotlin_packages_to_use_import_on_demand = unset + +ktlint_code_style = intellij_idea +ktlint_experimental = enabled +ktlint_standard_filename = disabled +ktlint_standard_no-empty-first-line-in-class-body = disabled +ktlint_class_signature_rule_force_multiline_when_parameter_count_greater_or_equal_than = 4 +ktlint_function_signature_rule_force_multiline_when_parameter_count_greater_or_equal_than = 4 +ktlint_standard_chain-method-continuation = disabled +ktlint_ignore_back_ticked_identifier = true +ktlint_standard_multiline-expression-wrapping = disabled +ktlint_standard_when-entry-bracing = disabled + +[*/build/**/*] +ktlint = disabled \ No newline at end of file diff --git a/build.gradle.kts b/build.gradle.kts index ec12f22f..2525d0b7 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -12,6 +12,7 @@ plugins { alias(libs.plugins.kotlin.atomicfu) alias(libs.plugins.dokka) alias(libs.plugins.jreleaser) + alias(libs.plugins.ktlint) `maven-publish` signing alias(libs.plugins.kotlinx.binary.compatibility.validator) @@ -66,9 +67,9 @@ jreleaser { if (publication is MavenPublication) { val pubName = publication.name - if (!pubName.contains("jvm", ignoreCase = true) - && !pubName.contains("metadata", ignoreCase = true) - && !pubName.contains("kotlinMultiplatform", ignoreCase = true) + if (!pubName.contains("jvm", ignoreCase = true) && + !pubName.contains("metadata", ignoreCase = true) && + !pubName.contains("kotlinMultiplatform", ignoreCase = true) ) { artifactOverride { @@ -188,7 +189,7 @@ abstract class GenerateLibVersionTask @Inject constructor( public const val LIB_VERSION: String = "$libVersion" - """.trimIndent() + """.trimIndent(), ) } } @@ -276,3 +277,8 @@ kotlin { } } } + +ktlint { + verbose = true + outputToConsole = true +} diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml index 201a74f8..fcfcef9f 100644 --- a/gradle/libs.versions.toml +++ b/gradle/libs.versions.toml @@ -3,6 +3,7 @@ kotlin = "2.2.0" dokka = "2.0.0" atomicfu = "0.29.0" +ktlint = "13.0.0" # libraries version serialization = "1.9.0" @@ -42,5 +43,6 @@ kotlin-multiplatform = { id = "org.jetbrains.kotlin.multiplatform", version.ref kotlin-serialization = { id = "org.jetbrains.kotlin.plugin.serialization", version.ref = "kotlin" } kotlin-atomicfu = { id = "org.jetbrains.kotlinx.atomicfu", version.ref = "atomicfu" } dokka = { id = "org.jetbrains.dokka", version.ref = "dokka" } -jreleaser = { id = "org.jreleaser", version.ref = "jreleaser"} +ktlint = { id = "org.jlleitschuh.gradle.ktlint", version.ref = "ktlint" } +jreleaser = { id = "org.jreleaser", version.ref = "jreleaser" } kotlinx-binary-compatibility-validator = { id = "org.jetbrains.kotlinx.binary-compatibility-validator", version.ref = "binaryCompatibilityValidatorPlugin" } diff --git a/settings.gradle.kts b/settings.gradle.kts index c0ad7518..5284eaa6 100644 --- a/settings.gradle.kts +++ b/settings.gradle.kts @@ -16,4 +16,3 @@ dependencyResolutionManagement { } rootProject.name = "kotlin-sdk" - diff --git a/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/Client.kt b/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/Client.kt index 59abee5b..b929ca36 100644 --- a/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/Client.kt +++ b/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/Client.kt @@ -80,10 +80,8 @@ public class ClientOptions( * @param clientInfo Information about the client implementation (name, version). * @param options Configuration options for this client. */ -public open class Client( - private val clientInfo: Implementation, - options: ClientOptions = ClientOptions(), -) : Protocol(options) { +public open class Client(private val clientInfo: Implementation, options: ClientOptions = ClientOptions()) : + Protocol(options) { /** * Retrieves the server's reported capabilities after the initialization process completes. @@ -144,13 +142,13 @@ public open class Client( val message = InitializeRequest( protocolVersion = LATEST_PROTOCOL_VERSION, capabilities = capabilities, - clientInfo = clientInfo + clientInfo = clientInfo, ) val result = request(message) if (!SUPPORTED_PROTOCOL_VERSIONS.contains(result.protocolVersion)) { throw IllegalStateException( - "Server's protocol version is not supported: ${result.protocolVersion}" + "Server's protocol version is not supported: ${result.protocolVersion}", ) } @@ -165,11 +163,9 @@ public open class Client( } throw error - } } - override fun assertCapabilityForMethod(method: Method) { when (method) { Method.Defined.LoggingSetLevel -> { @@ -181,7 +177,7 @@ public open class Client( Method.Defined.PromptsGet, Method.Defined.PromptsList, Method.Defined.CompletionComplete, - -> { + -> { if (serverCapabilities?.prompts == null) { throw IllegalStateException("Server does not support prompts (required for $method)") } @@ -192,20 +188,20 @@ public open class Client( Method.Defined.ResourcesRead, Method.Defined.ResourcesSubscribe, Method.Defined.ResourcesUnsubscribe, - -> { + -> { val resCaps = serverCapabilities?.resources ?: error("Server does not support resources (required for $method)") if (method == Method.Defined.ResourcesSubscribe && resCaps.subscribe != true) { throw IllegalStateException( - "Server does not support resource subscriptions (required for $method)" + "Server does not support resource subscriptions (required for $method)", ) } } Method.Defined.ToolsCall, Method.Defined.ToolsList, - -> { + -> { if (serverCapabilities?.tools == null) { throw IllegalStateException("Server does not support tools (required for $method)") } @@ -213,7 +209,7 @@ public open class Client( Method.Defined.Initialize, Method.Defined.Ping, - -> { + -> { // No specific capability required } @@ -228,7 +224,7 @@ public open class Client( Method.Defined.NotificationsRootsListChanged -> { if (capabilities.roots?.listChanged != true) { throw IllegalStateException( - "Client does not support roots list changed notifications (required for $method)" + "Client does not support roots list changed notifications (required for $method)", ) } } @@ -236,7 +232,7 @@ public open class Client( Method.Defined.NotificationsInitialized, Method.Defined.NotificationsCancelled, Method.Defined.NotificationsProgress, - -> { + -> { // Always allowed } @@ -251,7 +247,7 @@ public open class Client( Method.Defined.SamplingCreateMessage -> { if (capabilities.sampling == null) { throw IllegalStateException( - "Client does not support sampling capability (required for $method)" + "Client does not support sampling capability (required for $method)", ) } } @@ -259,7 +255,7 @@ public open class Client( Method.Defined.RootsList -> { if (capabilities.roots == null) { throw IllegalStateException( - "Client does not support roots capability (required for $method)" + "Client does not support roots capability (required for $method)", ) } } @@ -267,7 +263,7 @@ public open class Client( Method.Defined.ElicitationCreate -> { if (capabilities.elicitation == null) { throw IllegalStateException( - "Client does not support elicitation capability (required for $method)" + "Client does not support elicitation capability (required for $method)", ) } } @@ -280,16 +276,13 @@ public open class Client( } } - /** * Sends a ping request to the server to check connectivity. * * @param options Optional request options. * @throws IllegalStateException If the server does not support the ping method (unlikely). */ - public suspend fun ping(options: RequestOptions? = null): EmptyRequestResult { - return request(PingRequest(), options) - } + public suspend fun ping(options: RequestOptions? = null): EmptyRequestResult = request(PingRequest(), options) /** * Sends a completion request to the server, typically to generate or complete some content. @@ -299,9 +292,8 @@ public open class Client( * @return The completion result returned by the server, or `null` if none. * @throws IllegalStateException If the server does not support prompts or completion. */ - public suspend fun complete(params: CompleteRequest, options: RequestOptions? = null): CompleteResult? { - return request(params, options) - } + public suspend fun complete(params: CompleteRequest, options: RequestOptions? = null): CompleteResult = + request(params, options) /** * Sets the logging level on the server. @@ -310,9 +302,8 @@ public open class Client( * @param options Optional request options. * @throws IllegalStateException If the server does not support logging. */ - public suspend fun setLoggingLevel(level: LoggingLevel, options: RequestOptions? = null): EmptyRequestResult { - return request(SetLevelRequest(level), options) - } + public suspend fun setLoggingLevel(level: LoggingLevel, options: RequestOptions? = null): EmptyRequestResult = + request(SetLevelRequest(level), options) /** * Retrieves a prompt by name from the server. @@ -322,9 +313,8 @@ public open class Client( * @return The requested prompt details, or `null` if not found. * @throws IllegalStateException If the server does not support prompts. */ - public suspend fun getPrompt(request: GetPromptRequest, options: RequestOptions? = null): GetPromptResult? { - return request(request, options) - } + public suspend fun getPrompt(request: GetPromptRequest, options: RequestOptions? = null): GetPromptResult = + request(request, options) /** * Lists all available prompts from the server. @@ -337,9 +327,7 @@ public open class Client( public suspend fun listPrompts( request: ListPromptsRequest = ListPromptsRequest(), options: RequestOptions? = null, - ): ListPromptsResult? { - return request(request, options) - } + ): ListPromptsResult = request(request, options) /** * Lists all available resources from the server. @@ -352,9 +340,7 @@ public open class Client( public suspend fun listResources( request: ListResourcesRequest = ListResourcesRequest(), options: RequestOptions? = null, - ): ListResourcesResult? { - return request(request, options) - } + ): ListResourcesResult = request(request, options) /** * Lists resource templates available on the server. @@ -367,9 +353,7 @@ public open class Client( public suspend fun listResourceTemplates( request: ListResourceTemplatesRequest, options: RequestOptions? = null, - ): ListResourceTemplatesResult? { - return request(request, options) - } + ): ListResourceTemplatesResult = request(request, options) /** * Reads a resource from the server by its URI. @@ -382,9 +366,7 @@ public open class Client( public suspend fun readResource( request: ReadResourceRequest, options: RequestOptions? = null, - ): ReadResourceResult? { - return request(request, options) - } + ): ReadResourceResult = request(request, options) /** * Subscribes to resource changes on the server. @@ -396,9 +378,7 @@ public open class Client( public suspend fun subscribeResource( request: SubscribeRequest, options: RequestOptions? = null, - ): EmptyRequestResult { - return request(request, options) - } + ): EmptyRequestResult = request(request, options) /** * Unsubscribes from resource changes on the server. @@ -410,9 +390,7 @@ public open class Client( public suspend fun unsubscribeResource( request: UnsubscribeRequest, options: RequestOptions? = null, - ): EmptyRequestResult { - return request(request, options) - } + ): EmptyRequestResult = request(request, options) /** * Calls a tool on the server by name, passing the specified arguments. @@ -443,7 +421,7 @@ public open class Client( val request = CallToolRequest( name = name, - arguments = JsonObject(jsonArguments) + arguments = JsonObject(jsonArguments), ) return callTool(request, compatibility, options) } @@ -461,12 +439,10 @@ public open class Client( request: CallToolRequest, compatibility: Boolean = false, options: RequestOptions? = null, - ): CallToolResultBase? { - return if (compatibility) { - request(request, options) - } else { - request(request, options) - } + ): CallToolResultBase? = if (compatibility) { + request(request, options) + } else { + request(request, options) } /** @@ -480,9 +456,7 @@ public open class Client( public suspend fun listTools( request: ListToolsRequest = ListToolsRequest(), options: RequestOptions? = null, - ): ListToolsResult? { - return request(request, options) - } + ): ListToolsResult = request(request, options) /** * Registers a single root. @@ -491,10 +465,7 @@ public open class Client( * @param name A human-readable name for the root. * @throws IllegalStateException If the client does not support roots. */ - public fun addRoot( - uri: String, - name: String, - ) { + public fun addRoot(uri: String, name: String) { if (capabilities.roots == null) { logger.error { "Failed to add root '$name': Client does not support roots capability" } throw IllegalStateException("Client does not support roots capability.") diff --git a/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/KtorClient.kt b/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/KtorClient.kt index 2ccc223d..ccc1496e 100644 --- a/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/KtorClient.kt +++ b/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/KtorClient.kt @@ -12,7 +12,7 @@ import kotlin.time.Duration * * @param urlString Optional URL of the MCP server. * @param reconnectionTime Optional duration to wait before attempting to reconnect. - * @param requestBuilder Optional lambda to configure the HTTP request. + * @param requestBuilder Optional lambda to configure the HTTP request. * @return A [SSEClientTransport] configured for MCP communication. */ public fun HttpClient.mcpSseTransport( @@ -38,8 +38,8 @@ public suspend fun HttpClient.mcpSse( val client = Client( Implementation( name = IMPLEMENTATION_NAME, - version = LIB_VERSION - ) + version = LIB_VERSION, + ), ) client.connect(transport) return client diff --git a/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/SSEClientTransport.kt b/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/SSEClientTransport.kt index d30f5288..950f37fa 100644 --- a/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/SSEClientTransport.kt +++ b/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/SSEClientTransport.kt @@ -139,6 +139,7 @@ public class SseClientTransport( } "endpoint" -> handleEndpoint(event.data.orEmpty()) + else -> handleMessage(event.data.orEmpty()) } } diff --git a/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StdioClientTransport.kt b/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StdioClientTransport.kt index 8ffbb752..583cec63 100644 --- a/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StdioClientTransport.kt +++ b/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StdioClientTransport.kt @@ -35,10 +35,7 @@ import kotlin.coroutines.CoroutineContext * @param output The output stream where messages are sent. */ @OptIn(ExperimentalAtomicApi::class) -public class StdioClientTransport( - private val input: Source, - private val output: Sink -) : AbstractTransport() { +public class StdioClientTransport(private val input: Source, private val output: Sink) : AbstractTransport() { private val logger = KotlinLogging.logger {} private val ioCoroutineContext: CoroutineContext = IODispatcher private val scope by lazy { diff --git a/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpClientTransport.kt b/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpClientTransport.kt index 6584bc12..de313de2 100644 --- a/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpClientTransport.kt +++ b/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpClientTransport.kt @@ -50,10 +50,8 @@ private const val MCP_RESUMPTION_TOKEN_HEADER = "Last-Event-ID" /** * Error class for Streamable HTTP transport errors. */ -public class StreamableHttpError( - public val code: Int? = null, - message: String? = null -) : Exception("Streamable HTTP error: $message") +public class StreamableHttpError(public val code: Int? = null, message: String? = null) : + Exception("Streamable HTTP error: $message") /** * Client transport for Streamable HTTP: this implements the MCP Streamable HTTP transport specification. @@ -102,15 +100,16 @@ public class StreamableHttpClientTransport( public suspend fun send( message: JSONRPCMessage, resumptionToken: String?, - onResumptionToken: ((String) -> Unit)? = null + onResumptionToken: ((String) -> Unit)? = null, ) { logger.debug { "Client sending message via POST to $url: ${McpJson.encodeToString(message)}" } // If we have a resumption token, reconnect the SSE stream with it resumptionToken?.let { token -> startSseSession( - resumptionToken = token, onResumptionToken = onResumptionToken, - replayMessageId = if (message is JSONRPCRequest) message.id else null + resumptionToken = token, + onResumptionToken = onResumptionToken, + replayMessageId = if (message is JSONRPCRequest) message.id else null, ) return } @@ -147,9 +146,11 @@ public class StreamableHttpClientTransport( } ContentType.Text.EventStream -> handleInlineSse( - response, onResumptionToken = onResumptionToken, - replayMessageId = if (message is JSONRPCRequest) message.id else null + response, + onResumptionToken = onResumptionToken, + replayMessageId = if (message is JSONRPCRequest) message.id else null, ) + else -> { val body = response.bodyAsText() if (response.contentType() == null && body.isBlank()) return @@ -167,7 +168,7 @@ public class StreamableHttpClientTransport( logger.debug { "Client transport closing." } try { - // Try to terminate session if we have one + // Try to terminate a session if we have one terminateSession() sseSession?.cancel() @@ -196,7 +197,7 @@ public class StreamableHttpClientTransport( if (!response.status.isSuccess() && response.status != HttpStatusCode.MethodNotAllowed) { val error = StreamableHttpError( response.status.value, - "Failed to terminate session: ${response.status.description}" + "Failed to terminate session: ${response.status.description}", ) logger.error(error) { "Failed to terminate session" } _onError(error) @@ -211,7 +212,7 @@ public class StreamableHttpClientTransport( private suspend fun startSseSession( resumptionToken: String? = null, replayMessageId: RequestId? = null, - onResumptionToken: ((String) -> Unit)? = null + onResumptionToken: ((String) -> Unit)? = null, ) { sseSession?.cancel() sseJob?.cancelAndJoin() @@ -253,7 +254,7 @@ public class StreamableHttpClientTransport( private suspend fun collectSse( session: ClientSSESession, replayMessageId: RequestId?, - onResumptionToken: ((String) -> Unit)? + onResumptionToken: ((String) -> Unit)?, ) { try { session.incoming.collect { event -> @@ -289,7 +290,7 @@ public class StreamableHttpClientTransport( private suspend fun handleInlineSse( response: HttpResponse, replayMessageId: RequestId?, - onResumptionToken: ((String) -> Unit)? + onResumptionToken: ((String) -> Unit)?, ) { logger.trace { "Handling inline SSE from POST response" } val channel = response.bodyAsChannel() diff --git a/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpMcpKtorClientExtensions.kt b/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpMcpKtorClientExtensions.kt index c2454e1f..1a600a3a 100644 --- a/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpMcpKtorClientExtensions.kt +++ b/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpMcpKtorClientExtensions.kt @@ -39,4 +39,4 @@ public suspend fun HttpClient.mcpStreamableHttp( val client = Client(Implementation(name = IMPLEMENTATION_NAME, version = LIB_VERSION)) client.connect(transport) return client -} \ No newline at end of file +} diff --git a/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/WebSocketMcpKtorClientExtensions.kt b/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/WebSocketMcpKtorClientExtensions.kt index 9d70d6c0..77062ab1 100644 --- a/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/WebSocketMcpKtorClientExtensions.kt +++ b/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/WebSocketMcpKtorClientExtensions.kt @@ -33,8 +33,8 @@ public suspend fun HttpClient.mcpWebSocket( val client = Client( Implementation( name = IMPLEMENTATION_NAME, - version = LIB_VERSION - ) + version = LIB_VERSION, + ), ) client.connect(transport) return client diff --git a/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/internal/utils.kt b/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/internal/utils.kt index 49436f93..1a071a71 100644 --- a/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/internal/utils.kt +++ b/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/internal/utils.kt @@ -2,4 +2,4 @@ package io.modelcontextprotocol.kotlin.sdk.internal import kotlinx.coroutines.CoroutineDispatcher -internal expect val IODispatcher: CoroutineDispatcher \ No newline at end of file +internal expect val IODispatcher: CoroutineDispatcher diff --git a/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/KtorServer.kt b/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/KtorServer.kt index f3683497..40324fcb 100644 --- a/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/KtorServer.kt +++ b/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/KtorServer.kt @@ -69,7 +69,7 @@ private suspend fun ServerSSESession.mcpSseEndpoint( transports: ConcurrentMap, block: () -> Server, ) { - val transport = mcpSseTransport(postEndpoint, transports) + val transport = mcpSseTransport(postEndpoint, transports) val server = block() @@ -94,9 +94,7 @@ internal fun ServerSSESession.mcpSseTransport( return transport } -internal suspend fun RoutingContext.mcpPostEndpoint( - transports: ConcurrentMap, -) { +internal suspend fun RoutingContext.mcpPostEndpoint(transports: ConcurrentMap) { val sessionId: String = call.request.queryParameters["sessionId"] ?: run { call.respond(HttpStatusCode.BadRequest, "sessionId query parameter is not provided") diff --git a/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/SSEServerTransport.kt b/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/SSEServerTransport.kt index c5b59702..8b5c6be2 100644 --- a/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/SSEServerTransport.kt +++ b/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/SSEServerTransport.kt @@ -12,7 +12,6 @@ import io.modelcontextprotocol.kotlin.sdk.JSONRPCMessage import io.modelcontextprotocol.kotlin.sdk.shared.AbstractTransport import io.modelcontextprotocol.kotlin.sdk.shared.McpJson import kotlinx.coroutines.job -import kotlinx.serialization.encodeToString import kotlin.concurrent.atomics.AtomicBoolean import kotlin.concurrent.atomics.ExperimentalAtomicApi import kotlin.uuid.ExperimentalUuidApi @@ -29,10 +28,8 @@ public typealias SSEServerTransport = SseServerTransport * Creates a new SSE server transport, which will direct the client to POST messages to the relative or absolute URL identified by `_endpoint`. */ @OptIn(ExperimentalAtomicApi::class) -public class SseServerTransport( - private val endpoint: String, - private val session: ServerSSESession, -) : AbstractTransport() { +public class SseServerTransport(private val endpoint: String, private val session: ServerSSESession) : + AbstractTransport() { private val initialized: AtomicBoolean = AtomicBoolean(false) @OptIn(ExperimentalUuidApi::class) @@ -45,13 +42,15 @@ public class SseServerTransport( */ override suspend fun start() { if (!initialized.compareAndSet(expectedValue = false, newValue = true)) { - error("SSEServerTransport already started! If using Server class, note that connect() calls start() automatically.") + error( + "SSEServerTransport already started! If using Server class, note that connect() calls start() automatically.", + ) } // Send the endpoint event session.send( event = "endpoint", - data = "${endpoint.encodeURLPath()}?$SESSION_ID_PARAM=${sessionId}", + data = "${endpoint.encodeURLPath()}?$SESSION_ID_PARAM=$sessionId", ) try { diff --git a/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/Server.kt b/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/Server.kt index 1230b895..169cfc4f 100644 --- a/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/Server.kt +++ b/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/Server.kt @@ -64,10 +64,8 @@ private val logger = KotlinLogging.logger {} * @property capabilities The capabilities this server supports. * @property enforceStrictCapabilities Whether to strictly enforce capabilities when interacting with clients. */ -public class ServerOptions( - public val capabilities: ServerCapabilities, - enforceStrictCapabilities: Boolean = true, -) : ProtocolOptions(enforceStrictCapabilities = enforceStrictCapabilities) +public class ServerOptions(public val capabilities: ServerCapabilities, enforceStrictCapabilities: Boolean = true) : + ProtocolOptions(enforceStrictCapabilities = enforceStrictCapabilities) /** * An MCP server on top of a pluggable transport. @@ -79,11 +77,11 @@ public class ServerOptions( * @param serverInfo Information about this server implementation (name, version). * @param options Configuration options for the server. */ -public open class Server( - private val serverInfo: Implementation, - options: ServerOptions, -) : Protocol(options) { +public open class Server(private val serverInfo: Implementation, options: ServerOptions) : Protocol(options) { + @Suppress("ktlint:standard:backing-property-naming") private var _onInitialized: (() -> Unit) = {} + + @Suppress("ktlint:standard:backing-property-naming") private var _onClose: () -> Unit = {} /** @@ -221,7 +219,7 @@ public open class Server( title: String? = null, outputSchema: Tool.Output? = null, toolAnnotations: ToolAnnotations? = null, - handler: suspend (CallToolRequest) -> CallToolResult + handler: suspend (CallToolRequest) -> CallToolResult, ) { val tool = Tool(name, title, description, inputSchema, outputSchema, toolAnnotations) addTool(tool, handler) @@ -325,7 +323,7 @@ public open class Server( name: String, description: String? = null, arguments: List? = null, - promptProvider: suspend (GetPromptRequest) -> GetPromptResult + promptProvider: suspend (GetPromptRequest) -> GetPromptResult, ) { val prompt = Prompt(name = name, description = description, arguments = arguments) addPrompt(prompt, promptProvider) @@ -416,7 +414,7 @@ public open class Server( name: String, description: String, mimeType: String = "text/html", - readHandler: suspend (ReadResourceRequest) -> ReadResourceResult + readHandler: suspend (ReadResourceRequest) -> ReadResourceResult, ) { if (capabilities.resources == null) { logger.error { "Failed to add resource '$name': Server does not support resources capability" } @@ -426,7 +424,7 @@ public open class Server( _resources.update { current -> current.put( uri, - RegisteredResource(Resource(uri, name, description, mimeType), readHandler) + RegisteredResource(Resource(uri, name, description, mimeType), readHandler), ) } } @@ -507,9 +505,7 @@ public open class Server( * @return The result of the ping request. * @throws IllegalStateException If for some reason the method is not supported or the connection is closed. */ - public suspend fun ping(): EmptyRequestResult { - return request(PingRequest()) - } + public suspend fun ping(): EmptyRequestResult = request(PingRequest()) /** * Creates a message using the server's sampling capability. @@ -521,10 +517,10 @@ public open class Server( */ public suspend fun createMessage( params: CreateMessageRequest, - options: RequestOptions? = null + options: RequestOptions? = null, ): CreateMessageResult { logger.debug { "Creating message with params: $params" } - return request(params, options) + return request(params, options) } /** @@ -537,16 +533,16 @@ public open class Server( */ public suspend fun listRoots( params: JsonObject = EmptyJsonObject, - options: RequestOptions? = null + options: RequestOptions? = null, ): ListRootsResult { logger.debug { "Listing roots with params: $params" } - return request(ListRootsRequest(params), options) + return request(ListRootsRequest(params), options) } public suspend fun createElicitation( message: String, requestedSchema: RequestedSchema, - options: RequestOptions? = null + options: RequestOptions? = null, ): CreateElicitationResult { logger.debug { "Creating elicitation with message: $message" } return request(CreateElicitationRequest(message, requestedSchema), options) @@ -607,14 +603,16 @@ public open class Server( val protocolVersion = if (SUPPORTED_PROTOCOL_VERSIONS.contains(requestedVersion)) { requestedVersion } else { - logger.warn { "Client requested unsupported protocol version $requestedVersion, falling back to $LATEST_PROTOCOL_VERSION" } + logger.warn { + "Client requested unsupported protocol version $requestedVersion, falling back to $LATEST_PROTOCOL_VERSION" + } LATEST_PROTOCOL_VERSION } return InitializeResult( protocolVersion = protocolVersion, capabilities = capabilities, - serverInfo = serverInfo + serverInfo = serverInfo, ) } @@ -723,26 +721,34 @@ public open class Server( } "notifications/resources/updated", - "notifications/resources/list_changed" -> { + "notifications/resources/list_changed", + -> { if (capabilities.resources == null) { - throw IllegalStateException("Server does not support notifying about resources (required for ${method.value})") + throw IllegalStateException( + "Server does not support notifying about resources (required for ${method.value})", + ) } } "notifications/tools/list_changed" -> { if (capabilities.tools == null) { - throw IllegalStateException("Server does not support notifying of tool list changes (required for ${method.value})") + throw IllegalStateException( + "Server does not support notifying of tool list changes (required for ${method.value})", + ) } } "notifications/prompts/list_changed" -> { if (capabilities.prompts == null) { - throw IllegalStateException("Server does not support notifying of prompt list changes (required for ${method.value})") + throw IllegalStateException( + "Server does not support notifying of prompt list changes (required for ${method.value})", + ) } } "notifications/cancelled", - "notifications/progress" -> { + "notifications/progress", + -> { // Always allowed } } @@ -772,7 +778,8 @@ public open class Server( } "prompts/get", - "prompts/list" -> { + "prompts/list", + -> { if (capabilities.prompts == null) { throw IllegalStateException("Server does not support prompts (required for $method)") } @@ -780,14 +787,16 @@ public open class Server( "resources/list", "resources/templates/list", - "resources/read" -> { + "resources/read", + -> { if (capabilities.resources == null) { throw IllegalStateException("Server does not support resources (required for $method)") } } "tools/call", - "tools/list" -> { + "tools/list", + -> { if (capabilities.tools == null) { throw IllegalStateException("Server does not support tools (required for $method)") } @@ -806,10 +815,7 @@ public open class Server( * @property tool The tool definition. * @property handler A suspend function to handle the tool call requests. */ -public data class RegisteredTool( - val tool: Tool, - val handler: suspend (CallToolRequest) -> CallToolResult -) +public data class RegisteredTool(val tool: Tool, val handler: suspend (CallToolRequest) -> CallToolResult) /** * A wrapper class representing a registered prompt on the server. @@ -819,7 +825,7 @@ public data class RegisteredTool( */ public data class RegisteredPrompt( val prompt: Prompt, - val messageProvider: suspend (GetPromptRequest) -> GetPromptResult + val messageProvider: suspend (GetPromptRequest) -> GetPromptResult, ) /** @@ -830,5 +836,5 @@ public data class RegisteredPrompt( */ public data class RegisteredResource( val resource: Resource, - val readHandler: suspend (ReadResourceRequest) -> ReadResourceResult + val readHandler: suspend (ReadResourceRequest) -> ReadResourceResult, ) diff --git a/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StdioServerTransport.kt b/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StdioServerTransport.kt index c515ddac..0fa18b8f 100644 --- a/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StdioServerTransport.kt +++ b/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StdioServerTransport.kt @@ -31,10 +31,7 @@ import kotlin.coroutines.CoroutineContext * Reads from System.in and writes to System.out. */ @OptIn(ExperimentalAtomicApi::class) -public class StdioServerTransport( - private val inputStream: Source, - outputStream: Sink -) : AbstractTransport() { +public class StdioServerTransport(private val inputStream: Source, outputStream: Sink) : AbstractTransport() { private val logger = KotlinLogging.logger {} private val readBuffer = ReadBuffer() diff --git a/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/WebSocketMcpKtorServerExtensions.kt b/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/WebSocketMcpKtorServerExtensions.kt index 9301749b..a3d2fd34 100644 --- a/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/WebSocketMcpKtorServerExtensions.kt +++ b/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/WebSocketMcpKtorServerExtensions.kt @@ -14,10 +14,7 @@ import io.modelcontextprotocol.kotlin.sdk.shared.IMPLEMENTATION_NAME * @param options Optional server configuration settings for the MCP server. * @param handler A suspend function that defines the server's behavior. */ -public fun Route.mcpWebSocket( - options: ServerOptions? = null, - handler: suspend Server.() -> Unit = {}, -) { +public fun Route.mcpWebSocket(options: ServerOptions? = null, handler: suspend Server.() -> Unit = {}) { webSocket { createMcpServer(this, options, handler) } @@ -30,11 +27,7 @@ public fun Route.mcpWebSocket( * @param options Optional server configuration settings for the MCP server. * @param handler A suspend function that defines the server's behavior. */ -public fun Route.mcpWebSocket( - path: String, - options: ServerOptions? = null, - handler: suspend Server.() -> Unit = {}, -) { +public fun Route.mcpWebSocket(path: String, options: ServerOptions? = null, handler: suspend Server.() -> Unit = {}) { webSocket(path) { createMcpServer(this, options, handler) } @@ -45,9 +38,7 @@ public fun Route.mcpWebSocket( * * @param handler A suspend function that defines the behavior of the transport layer. */ -public fun Route.mcpWebSocketTransport( - handler: suspend WebSocketMcpServerTransport.() -> Unit = {}, -) { +public fun Route.mcpWebSocketTransport(handler: suspend WebSocketMcpServerTransport.() -> Unit = {}) { webSocket { val transport = createMcpTransport(this) transport.start() @@ -62,10 +53,7 @@ public fun Route.mcpWebSocketTransport( * @param path The URL path at which to register the WebSocket route. * @param handler A suspend function that defines the behavior of the transport layer. */ -public fun Route.mcpWebSocketTransport( - path: String, - handler: suspend WebSocketMcpServerTransport.() -> Unit = {}, -) { +public fun Route.mcpWebSocketTransport(path: String, handler: suspend WebSocketMcpServerTransport.() -> Unit = {}) { webSocket(path) { val transport = createMcpTransport(this) transport.start() @@ -74,7 +62,6 @@ public fun Route.mcpWebSocketTransport( } } - private suspend fun Route.createMcpServer( session: WebSocketServerSession, options: ServerOptions?, @@ -85,14 +72,14 @@ private suspend fun Route.createMcpServer( val server = Server( serverInfo = Implementation( name = IMPLEMENTATION_NAME, - version = LIB_VERSION + version = LIB_VERSION, ), options = options ?: ServerOptions( capabilities = ServerCapabilities( prompts = ServerCapabilities.Prompts(listChanged = null), resources = ServerCapabilities.Resources(subscribe = null, listChanged = null), tools = ServerCapabilities.Tools(listChanged = null), - ) + ), ), ) @@ -101,8 +88,5 @@ private suspend fun Route.createMcpServer( server.close() } -private fun createMcpTransport( - session: WebSocketServerSession, -): WebSocketMcpServerTransport { - return WebSocketMcpServerTransport(session) -} +private fun createMcpTransport(session: WebSocketServerSession): WebSocketMcpServerTransport = + WebSocketMcpServerTransport(session) diff --git a/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/WebSocketMcpServerTransport.kt b/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/WebSocketMcpServerTransport.kt index 45cb4df9..877fda58 100644 --- a/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/WebSocketMcpServerTransport.kt +++ b/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/WebSocketMcpServerTransport.kt @@ -10,9 +10,7 @@ import io.modelcontextprotocol.kotlin.sdk.shared.WebSocketMcpTransport * * @property session The WebSocket server session used for communication. */ -public class WebSocketMcpServerTransport( - override val session: WebSocketServerSession, -) : WebSocketMcpTransport() { +public class WebSocketMcpServerTransport(override val session: WebSocketServerSession) : WebSocketMcpTransport() { override suspend fun initializeSession() { val subprotocol = session.call.request.headers[HttpHeaders.SecWebSocketProtocol] if (subprotocol != MCP_SUBPROTOCOL) { diff --git a/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/Protocol.kt b/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/Protocol.kt index 06ce6baa..dfe224ab 100644 --- a/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/Protocol.kt +++ b/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/Protocol.kt @@ -108,7 +108,7 @@ public data class RequestOptions( /** * Extra data given to request handlers. */ -public class RequestHandlerExtra() +public class RequestHandlerExtra internal val COMPLETED = CompletableDeferred(Unit).also { it.complete(Unit) } @@ -116,15 +116,20 @@ internal val COMPLETED = CompletableDeferred(Unit).also { it.complete(Unit) } * Implements MCP protocol framing on top of a pluggable transport, including * features like request/response linking, notifications, and progress. */ -public abstract class Protocol( - @PublishedApi internal val options: ProtocolOptions?, -) { +public abstract class Protocol(@PublishedApi internal val options: ProtocolOptions?) { public var transport: Transport? = null private set - private val _requestHandlers: AtomicRef RequestResult?>> = + private val _requestHandlers: + AtomicRef RequestResult?>> = atomic(persistentMapOf()) - public val requestHandlers: Map RequestResult?> + public val requestHandlers: Map< + String, + suspend ( + request: JSONRPCRequest, + extra: RequestHandlerExtra, + ) -> RequestResult?, + > get() = _requestHandlers.value private val _notificationHandlers = @@ -132,7 +137,8 @@ public abstract class Protocol( public val notificationHandlers: Map Unit> get() = _notificationHandlers.value - private val _responseHandlers: AtomicRef Unit>> = + private val _responseHandlers: + AtomicRef Unit>> = atomic(persistentMapOf()) public val responseHandlers: Map Unit> get() = _responseHandlers.value @@ -160,7 +166,9 @@ public abstract class Protocol( /** * A handler to invoke for any request types that do not have their own handler installed. */ - public var fallbackRequestHandler: (suspend (request: JSONRPCRequest, extra: RequestHandlerExtra) -> RequestResult?)? = + public var fallbackRequestHandler: ( + suspend (request: JSONRPCRequest, extra: RequestHandlerExtra) -> RequestResult? + )? = null /** @@ -250,8 +258,8 @@ public abstract class Protocol( error = JSONRPCError( ErrorCode.Defined.MethodNotFound, message = "Server does not support ${request.method}", - ) - ) + ), + ), ) } catch (cause: Throwable) { LOGGER.error(cause) { "Error sending method not found response" } @@ -267,10 +275,9 @@ public abstract class Protocol( transport?.send( JSONRPCResponse( id = request.id, - result = result - ) + result = result, + ), ) - } catch (cause: Throwable) { LOGGER.error(cause) { "Error handling request: ${request.method} (id: ${request.id})" } @@ -280,19 +287,23 @@ public abstract class Protocol( id = request.id, error = JSONRPCError( code = ErrorCode.Defined.InternalError, - message = cause.message ?: "Internal error" - ) - ) + message = cause.message ?: "Internal error", + ), + ), ) } catch (sendError: Throwable) { - LOGGER.error(sendError) { "Failed to send error response for request: ${request.method} (id: ${request.id})" } + LOGGER.error(sendError) { + "Failed to send error response for request: ${request.method} (id: ${request.id})" + } // Optionally implement fallback behavior here } } } private fun onProgress(notification: ProgressNotification) { - LOGGER.trace { "Received progress notification: token=${notification.progressToken}, progress=${notification.progress}/${notification.total}" } + LOGGER.trace { + "Received progress notification: token=${notification.progressToken}, progress=${notification.progress}/${notification.total}" + } val progress = notification.progress val total = notification.total val message = notification.message @@ -354,21 +365,21 @@ public abstract class Protocol( /** * A method to check if a capability is supported by the remote side, for the given method to be called. * - * This should be implemented by subclasses. + * Subclasses should implement this. */ protected abstract fun assertCapabilityForMethod(method: Method) /** * A method to check if a notification is supported by the local side, for the given method to be sent. * - * This should be implemented by subclasses. + * Subclasses should implement this. */ protected abstract fun assertNotificationCapability(method: Method) /** - * A method to check if a request handler is supported by the local side, for the given method to be handled. + * A method to check if the local side supports a request handler for the given method to be handled. * - * This should be implemented by subclasses. + * Subclasses should implement this. */ public abstract fun assertRequestHandlerCapability(method: Method) @@ -377,10 +388,7 @@ public abstract class Protocol( * * Do not use this method to emit notifications! Use notification() instead. */ - public suspend fun request( - request: Request, - options: RequestOptions? = null, - ): T { + public suspend fun request(request: Request, options: RequestOptions? = null): T { LOGGER.trace { "Sending request: ${request.method}" } val result = CompletableDeferred() val transport = this@Protocol.transport ?: throw Error("Not connected") @@ -428,12 +436,11 @@ public abstract class Protocol( val serialized = JSONRPCNotification( notification.method.value, - params = McpJson.encodeToJsonElement(notification) + params = McpJson.encodeToJsonElement(notification), ) transport.send(serialized) result.completeExceptionally(reason) - Unit } val timeout = options?.timeout ?: DEFAULT_REQUEST_TIMEOUT @@ -449,7 +456,7 @@ public abstract class Protocol( McpError( ErrorCode.Defined.RequestTimeout.code, "Request timed out", - JsonObject(mutableMapOf("timeout" to JsonPrimitive(timeout.inWholeMilliseconds))) + JsonObject(mutableMapOf("timeout" to JsonPrimitive(timeout.inWholeMilliseconds))), ), ) result.cancel(cause) diff --git a/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/ReadBuffer.kt b/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/ReadBuffer.kt index ddffaa99..ee52f7e0 100644 --- a/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/ReadBuffer.kt +++ b/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/ReadBuffer.kt @@ -5,7 +5,6 @@ import io.modelcontextprotocol.kotlin.sdk.JSONRPCMessage import kotlinx.io.Buffer import kotlinx.io.indexOf import kotlinx.io.readString -import kotlinx.serialization.encodeToString /** * Buffers a continuous stdio stream into discrete JSON-RPC messages. @@ -22,6 +21,7 @@ public class ReadBuffer { var lfIndex = buffer.indexOf('\n'.code.toByte()) val line = when (lfIndex) { -1L -> return null + 0L -> { buffer.skip(1) return null @@ -46,11 +46,6 @@ public class ReadBuffer { } } -internal fun deserializeMessage(line: String): JSONRPCMessage { - return McpJson.decodeFromString(line) -} - -internal fun serializeMessage(message: JSONRPCMessage): String { - return McpJson.encodeToString(message) + "\n" -} +internal fun deserializeMessage(line: String): JSONRPCMessage = McpJson.decodeFromString(line) +internal fun serializeMessage(message: JSONRPCMessage): String = McpJson.encodeToString(message) + "\n" diff --git a/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/WebSocketMcpTransport.kt b/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/WebSocketMcpTransport.kt index ff601ed3..d9027752 100644 --- a/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/WebSocketMcpTransport.kt +++ b/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/WebSocketMcpTransport.kt @@ -12,7 +12,6 @@ import kotlinx.coroutines.SupervisorJob import kotlinx.coroutines.channels.ClosedReceiveChannelException import kotlinx.coroutines.job import kotlinx.coroutines.launch -import kotlinx.serialization.encodeToString import kotlin.concurrent.atomics.AtomicBoolean import kotlin.concurrent.atomics.ExperimentalAtomicApi @@ -44,7 +43,7 @@ public abstract class WebSocketMcpTransport : AbstractTransport() { if (!initialized.compareAndSet(expectedValue = false, newValue = true)) { error( "WebSocketClientTransport already started! " + - "If using Client class, note that connect() calls start() automatically.", + "If using Client class, note that connect() calls start() automatically.", ) } diff --git a/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/types.kt b/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/types.kt index d5b45f6f..b8a9634c 100644 --- a/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/types.kt +++ b/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/types.kt @@ -170,7 +170,7 @@ internal fun Notification.toJSON(): JSONRPCNotification { val encoded = JsonObject(McpJson.encodeToJsonElement(this).jsonObject.minus("method")) return JSONRPCNotification( method.value, - params = encoded + params = encoded, ) } @@ -196,9 +196,9 @@ public sealed interface RequestResult : WithMeta * @param _meta Additional metadata for the response. Defaults to an empty JSON object. */ @Serializable -public data class EmptyRequestResult( - override val _meta: JsonObject = EmptyJsonObject, -) : ServerResult, ClientResult +public data class EmptyRequestResult(override val _meta: JsonObject = EmptyJsonObject) : + ServerResult, + ClientResult /** * A uniquely identifying ID for a request in JSON-RPC. @@ -278,7 +278,6 @@ public sealed interface ErrorCode { MethodNotFound(-32601), InvalidParams(-32602), InternalError(-32603), - ; } @Serializable @@ -289,11 +288,8 @@ public sealed interface ErrorCode { * A response to a request that indicates an error occurred. */ @Serializable -public data class JSONRPCError( - val code: ErrorCode, - val message: String, - val data: JsonObject = EmptyJsonObject, -) : JSONRPCMessage +public data class JSONRPCError(val code: ErrorCode, val message: String, val data: JsonObject = EmptyJsonObject) : + JSONRPCMessage /* Cancellation */ /** @@ -318,7 +314,9 @@ public data class CancelledNotification( */ val reason: String?, override val _meta: JsonObject = EmptyJsonObject, -) : ClientNotification, ServerNotification, WithMeta { +) : ClientNotification, + ServerNotification, + WithMeta { override val method: Method = Method.Defined.NotificationsCancelled } @@ -327,10 +325,7 @@ public data class CancelledNotification( * Describes the name and version of an MCP implementation. */ @Serializable -public data class Implementation( - val name: String, - val version: String, -) +public data class Implementation(val name: String, val version: String) /** * Capabilities a client may support. @@ -368,7 +363,7 @@ public data class ClientCapabilities( /** * Represents a request sent by the client. */ -//@Serializable(with = ClientRequestPolymorphicSerializer::class) +// @Serializable(with = ClientRequestPolymorphicSerializer::class) public interface ClientRequest : Request /** @@ -386,7 +381,7 @@ public sealed interface ClientResult : RequestResult /** * Represents a request sent by the server. */ -//@Serializable(with = ServerRequestPolymorphicSerializer::class) +// @Serializable(with = ServerRequestPolymorphicSerializer::class) public sealed interface ServerRequest : Request /** @@ -407,9 +402,11 @@ public sealed interface ServerResult : RequestResult * @param method The method that is unknown. */ @Serializable -public data class UnknownMethodRequestOrNotification( - override val method: Method, -) : ClientNotification, ClientRequest, ServerNotification, ServerRequest +public data class UnknownMethodRequestOrNotification(override val method: Method) : + ClientNotification, + ClientRequest, + ServerNotification, + ServerRequest /** * This request is sent from the client to the server when it first connects, asking it to begin initialization. @@ -420,7 +417,8 @@ public data class InitializeRequest( val capabilities: ClientCapabilities, val clientInfo: Implementation, override val _meta: JsonObject = EmptyJsonObject, -) : ClientRequest, WithMeta { +) : ClientRequest, + WithMeta { override val method: Method = Method.Defined.Initialize } @@ -516,7 +514,9 @@ public class InitializedNotification : ClientNotification { * The receiver must promptly respond, or else it may be disconnected. */ @Serializable -public class PingRequest : ServerRequest, ClientRequest { +public class PingRequest : + ServerRequest, + ClientRequest { override val method: Method = Method.Defined.Ping } @@ -581,7 +581,9 @@ public data class ProgressNotification( @Suppress("PropertyName") val _meta: JsonObject = EmptyJsonObject, override val total: Double?, override val message: String?, -) : ClientNotification, ServerNotification, ProgressBase { +) : ClientNotification, + ServerNotification, + ProgressBase { override val method: Method = Method.Defined.NotificationsProgress } @@ -590,7 +592,9 @@ public data class ProgressNotification( * Represents a request supporting pagination. */ @Serializable -public sealed interface PaginatedRequest : Request, WithMeta { +public sealed interface PaginatedRequest : + Request, + WithMeta { /** * The cursor indicating the pagination position. */ @@ -633,11 +637,8 @@ public sealed interface ResourceContents { * @property text The text of the item. This must only be set if the item can actually be represented as text (not binary data). */ @Serializable -public data class TextResourceContents( - val text: String, - override val uri: String, - override val mimeType: String?, -) : ResourceContents +public data class TextResourceContents(val text: String, override val uri: String, override val mimeType: String?) : + ResourceContents /** * Represents the binary contents of a resource encoded as a base64 string. @@ -645,20 +646,14 @@ public data class TextResourceContents( * @property blob A base64-encoded string representing the binary data of the item. */ @Serializable -public data class BlobResourceContents( - val blob: String, - override val uri: String, - override val mimeType: String?, -) : ResourceContents +public data class BlobResourceContents(val blob: String, override val uri: String, override val mimeType: String?) : + ResourceContents /** * Represents resource contents with unknown or unspecified data. */ @Serializable -public data class UnknownResourceContents( - override val uri: String, - override val mimeType: String?, -) : ResourceContents +public data class UnknownResourceContents(override val uri: String, override val mimeType: String?) : ResourceContents /** * A known resource that the server is capable of reading. @@ -722,8 +717,9 @@ public data class ResourceTemplate( @Serializable public data class ListResourcesRequest( override val cursor: Cursor? = null, - override val _meta: JsonObject = EmptyJsonObject -) : ClientRequest, PaginatedRequest { + override val _meta: JsonObject = EmptyJsonObject, +) : ClientRequest, + PaginatedRequest { override val method: Method = Method.Defined.ResourcesList } @@ -735,7 +731,8 @@ public class ListResourcesResult( public val resources: List, override val nextCursor: Cursor? = null, override val _meta: JsonObject = EmptyJsonObject, -) : ServerResult, PaginatedResult +) : ServerResult, + PaginatedResult /** * Sent from the client to request a list of resource templates the server has. @@ -743,8 +740,9 @@ public class ListResourcesResult( @Serializable public data class ListResourceTemplatesRequest( override val cursor: Cursor?, - override val _meta: JsonObject = EmptyJsonObject -) : ClientRequest, PaginatedRequest { + override val _meta: JsonObject = EmptyJsonObject, +) : ClientRequest, + PaginatedRequest { override val method: Method = Method.Defined.ResourcesTemplatesList } @@ -756,16 +754,16 @@ public class ListResourceTemplatesResult( public val resourceTemplates: List, override val nextCursor: Cursor? = null, override val _meta: JsonObject = EmptyJsonObject, -) : ServerResult, PaginatedResult +) : ServerResult, + PaginatedResult /** * Sent from the client to the server to read a specific resource URI. */ @Serializable -public data class ReadResourceRequest( - val uri: String, - override val _meta: JsonObject = EmptyJsonObject, -) : ClientRequest, WithMeta { +public data class ReadResourceRequest(val uri: String, override val _meta: JsonObject = EmptyJsonObject) : + ClientRequest, + WithMeta { override val method: Method = Method.Defined.ResourcesRead } @@ -798,7 +796,8 @@ public data class SubscribeRequest( */ val uri: String, override val _meta: JsonObject = EmptyJsonObject, -) : ClientRequest, WithMeta { +) : ClientRequest, + WithMeta { override val method: Method = Method.Defined.ResourcesSubscribe } @@ -812,7 +811,8 @@ public data class UnsubscribeRequest( */ val uri: String, override val _meta: JsonObject = EmptyJsonObject, -) : ClientRequest, WithMeta { +) : ClientRequest, + WithMeta { override val method: Method = Method.Defined.ResourcesUnsubscribe } @@ -826,7 +826,8 @@ public data class ResourceUpdatedNotification( */ val uri: String, override val _meta: JsonObject = EmptyJsonObject, -) : ServerNotification, WithMeta { +) : ServerNotification, + WithMeta { override val method: Method = Method.Defined.NotificationsResourcesUpdated } @@ -875,8 +876,9 @@ public class Prompt( @Serializable public data class ListPromptsRequest( override val cursor: Cursor? = null, - override val _meta: JsonObject = EmptyJsonObject -) : ClientRequest, PaginatedRequest { + override val _meta: JsonObject = EmptyJsonObject, +) : ClientRequest, + PaginatedRequest { override val method: Method = Method.Defined.PromptsList } @@ -888,7 +890,8 @@ public class ListPromptsResult( public val prompts: List, override val nextCursor: Cursor? = null, override val _meta: JsonObject = EmptyJsonObject, -) : ServerResult, PaginatedResult +) : ServerResult, + PaginatedResult /** * Used by the client to get a prompt provided by the server. @@ -906,7 +909,8 @@ public data class GetPromptRequest( val arguments: Map?, override val _meta: JsonObject = EmptyJsonObject, -) : ClientRequest, WithMeta { +) : ClientRequest, + WithMeta { override val method: Method = Method.Defined.PromptsGet } @@ -985,22 +989,17 @@ public data class AudioContent( } } - /** * Unknown content provided to or from an LLM. */ @Serializable -public data class UnknownContent( - override val type: String, -) : PromptMessageContentMultimodal +public data class UnknownContent(override val type: String) : PromptMessageContentMultimodal /** * The contents of a resource, embedded into a prompt or tool call result. */ @Serializable -public data class EmbeddedResource( - val resource: ResourceContents, -) : PromptMessageContent { +public data class EmbeddedResource(val resource: ResourceContents) : PromptMessageContent { override val type: String = TYPE public companion object { @@ -1014,17 +1013,15 @@ public data class EmbeddedResource( @Suppress("EnumEntryName") @Serializable public enum class Role { - user, assistant, + user, + assistant, } /** * Describes a message returned as part of a prompt. */ @Serializable -public data class PromptMessage( - val role: Role, - val content: PromptMessageContent, -) +public data class PromptMessage(val role: Role, val content: PromptMessageContent) /** * The server's response to a prompts/get request from the client. @@ -1082,7 +1079,7 @@ public data class ToolAnnotations( val destructiveHint: Boolean? = true, /** * If true, calling the tool repeatedly with the same arguments - * will have no additional effect on the its environment. + * will have no additional effect on its environment. * * (This property is meaningful only when `readOnlyHint == false`) * @@ -1100,7 +1097,6 @@ public data class ToolAnnotations( val openWorldHint: Boolean? = true, ) - /** * Definition for a tool the client can call. */ @@ -1132,20 +1128,14 @@ public data class Tool( val annotations: ToolAnnotations?, ) { @Serializable - public data class Input( - val properties: JsonObject = EmptyJsonObject, - val required: List? = null, - ) { + public data class Input(val properties: JsonObject = EmptyJsonObject, val required: List? = null) { @OptIn(ExperimentalSerializationApi::class) @EncodeDefault val type: String = "object" } @Serializable - public data class Output( - val properties: JsonObject = EmptyJsonObject, - val required: List? = null, - ) { + public data class Output(val properties: JsonObject = EmptyJsonObject, val required: List? = null) { @OptIn(ExperimentalSerializationApi::class) @EncodeDefault val type: String = "object" @@ -1158,8 +1148,9 @@ public data class Tool( @Serializable public data class ListToolsRequest( override val cursor: Cursor? = null, - override val _meta: JsonObject = EmptyJsonObject -) : ClientRequest, PaginatedRequest { + override val _meta: JsonObject = EmptyJsonObject, +) : ClientRequest, + PaginatedRequest { override val method: Method = Method.Defined.ToolsList } @@ -1171,7 +1162,8 @@ public class ListToolsResult( public val tools: List, override val nextCursor: Cursor?, override val _meta: JsonObject = EmptyJsonObject, -) : ServerResult, PaginatedResult +) : ServerResult, + PaginatedResult /** * The server's response to a tool call. @@ -1214,7 +1206,8 @@ public data class CallToolRequest( val name: String, val arguments: JsonObject = EmptyJsonObject, override val _meta: JsonObject = EmptyJsonObject, -) : ClientRequest, WithMeta { +) : ClientRequest, + WithMeta { override val method: Method = Method.Defined.ToolsCall } @@ -1242,7 +1235,6 @@ public enum class LoggingLevel { critical, alert, emergency, - ; } /** @@ -1267,7 +1259,8 @@ public data class LoggingMessageNotification( */ val data: JsonObject = EmptyJsonObject, override val _meta: JsonObject = EmptyJsonObject, -) : ServerNotification, WithMeta { +) : ServerNotification, + WithMeta { /** * A request from the client to the server to enable or adjust logging. */ @@ -1278,7 +1271,8 @@ public data class LoggingMessageNotification( */ val level: LoggingLevel, override val _meta: JsonObject = EmptyJsonObject, - ) : ClientRequest, WithMeta { + ) : ClientRequest, + WithMeta { override val method: Method = Method.Defined.LoggingSetLevel } @@ -1339,10 +1333,7 @@ public class ModelPreferences( * Describes a message issued to or received from an LLM API. */ @Serializable -public data class SamplingMessage( - val role: Role, - val content: PromptMessageContentMultimodal, -) +public data class SamplingMessage(val role: Role, val content: PromptMessageContentMultimodal) /** * A request from the server to sample an LLM via the client. @@ -1376,7 +1367,8 @@ public data class CreateMessageRequest( */ val modelPreferences: ModelPreferences?, override val _meta: JsonObject = EmptyJsonObject, -) : ServerRequest, WithMeta { +) : ServerRequest, + WithMeta { override val method: Method = Method.Defined.SamplingCreateMessage @Serializable @@ -1471,9 +1463,7 @@ public data class PromptReference( * Identifies a prompt. */ @Serializable -public data class UnknownReference( - override val type: String, -) : Reference +public data class UnknownReference(override val type: String) : Reference /** * A request from the client to the server to ask for completion options. @@ -1486,7 +1476,8 @@ public data class CompleteRequest( */ val argument: Argument, override val _meta: JsonObject = EmptyJsonObject, -) : ClientRequest, WithMeta { +) : ClientRequest, + WithMeta { override val method: Method = Method.Defined.CompletionComplete @Serializable @@ -1506,10 +1497,8 @@ public data class CompleteRequest( * The server's response to a completion/complete request */ @Serializable -public data class CompleteResult( - val completion: Completion, - override val _meta: JsonObject = EmptyJsonObject, -) : ServerResult { +public data class CompleteResult(val completion: Completion, override val _meta: JsonObject = EmptyJsonObject) : + ServerResult { @Suppress("CanBeParameter") @Serializable public class Completion( @@ -1561,7 +1550,9 @@ public data class Root( * Sent from the server to request a list of root URIs from the client. */ @Serializable -public class ListRootsRequest(override val _meta: JsonObject = EmptyJsonObject) : ServerRequest, WithMeta { +public class ListRootsRequest(override val _meta: JsonObject = EmptyJsonObject) : + ServerRequest, + WithMeta { override val method: Method = Method.Defined.RootsList } @@ -1569,10 +1560,8 @@ public class ListRootsRequest(override val _meta: JsonObject = EmptyJsonObject) * The client's response to a roots/list request from the server. */ @Serializable -public class ListRootsResult( - public val roots: List, - override val _meta: JsonObject = EmptyJsonObject, -) : ClientResult +public class ListRootsResult(public val roots: List, override val _meta: JsonObject = EmptyJsonObject) : + ClientResult /** * A notification from the client to the server, informing it that the list of roots has changed. @@ -1583,14 +1572,15 @@ public class RootsListChangedNotification : ClientNotification { } /** - * Sent from the server to create an elicitation from the client. + * Sent from the server to create elicitation from the client. */ @Serializable public data class CreateElicitationRequest( public val message: String, public val requestedSchema: RequestedSchema, override val _meta: JsonObject = EmptyJsonObject, -) : ServerRequest, WithMeta { +) : ServerRequest, + WithMeta { override val method: Method = Method.Defined.ElicitationCreate @Serializable @@ -1633,5 +1623,5 @@ public data class CreateElicitationResult( */ public class McpError(public val code: Int, message: String, public val data: JsonObject = EmptyJsonObject) : Exception() { - override val message: String = "MCP error ${code}: $message" + override val message: String = "MCP error $code: $message" } diff --git a/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/types.util.kt b/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/types.util.kt index 67b6b9f3..d050413f 100644 --- a/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/types.util.kt +++ b/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/types.util.kt @@ -61,50 +61,44 @@ internal object StopReasonSerializer : KSerializer { encoder.encodeString(value.value) } - override fun deserialize(decoder: Decoder): StopReason { - val decodedString = decoder.decodeString() - return when (decodedString) { - StopReason.StopSequence.value -> StopReason.StopSequence - StopReason.MaxTokens.value -> StopReason.MaxTokens - StopReason.EndTurn.value -> StopReason.EndTurn - else -> StopReason.Other(decodedString) - } + override fun deserialize(decoder: Decoder): StopReason = when (val decodedString = decoder.decodeString()) { + StopReason.StopSequence.value -> StopReason.StopSequence + StopReason.MaxTokens.value -> StopReason.MaxTokens + StopReason.EndTurn.value -> StopReason.EndTurn + else -> StopReason.Other(decodedString) } } internal object ReferencePolymorphicSerializer : JsonContentPolymorphicSerializer(Reference::class) { - override fun selectDeserializer(element: JsonElement): DeserializationStrategy { - return when (element.jsonObject.getValue("type").jsonPrimitive.content) { + override fun selectDeserializer(element: JsonElement): DeserializationStrategy = + when (element.jsonObject.getValue("type").jsonPrimitive.content) { ResourceReference.TYPE -> ResourceReference.serializer() PromptReference.TYPE -> PromptReference.serializer() else -> UnknownReference.serializer() } - } } internal object PromptMessageContentPolymorphicSerializer : JsonContentPolymorphicSerializer(PromptMessageContent::class) { - override fun selectDeserializer(element: JsonElement): DeserializationStrategy { - return when (element.jsonObject.getValue("type").jsonPrimitive.content) { + override fun selectDeserializer(element: JsonElement): DeserializationStrategy = + when (element.jsonObject.getValue("type").jsonPrimitive.content) { ImageContent.TYPE -> ImageContent.serializer() TextContent.TYPE -> TextContent.serializer() EmbeddedResource.TYPE -> EmbeddedResource.serializer() AudioContent.TYPE -> AudioContent.serializer() else -> UnknownContent.serializer() } - } } internal object PromptMessageContentMultimodalPolymorphicSerializer : JsonContentPolymorphicSerializer(PromptMessageContentMultimodal::class) { - override fun selectDeserializer(element: JsonElement): DeserializationStrategy { - return when (element.jsonObject.getValue("type").jsonPrimitive.content) { + override fun selectDeserializer(element: JsonElement): DeserializationStrategy = + when (element.jsonObject.getValue("type").jsonPrimitive.content) { ImageContent.TYPE -> ImageContent.serializer() TextContent.TYPE -> TextContent.serializer() AudioContent.TYPE -> AudioContent.serializer() else -> UnknownContent.serializer() } - } } internal object ResourceContentsPolymorphicSerializer : @@ -125,54 +119,48 @@ internal fun selectRequestDeserializer(method: String): DeserializationStrategy< return CustomRequest.serializer() } -internal fun selectClientRequestDeserializer(method: String): DeserializationStrategy? { - return when (method) { - Method.Defined.Ping.value -> PingRequest.serializer() - Method.Defined.Initialize.value -> InitializeRequest.serializer() - Method.Defined.CompletionComplete.value -> CompleteRequest.serializer() - Method.Defined.LoggingSetLevel.value -> SetLevelRequest.serializer() - Method.Defined.PromptsGet.value -> GetPromptRequest.serializer() - Method.Defined.PromptsList.value -> ListPromptsRequest.serializer() - Method.Defined.ResourcesList.value -> ListResourcesRequest.serializer() - Method.Defined.ResourcesTemplatesList.value -> ListResourceTemplatesRequest.serializer() - Method.Defined.ResourcesRead.value -> ReadResourceRequest.serializer() - Method.Defined.ResourcesSubscribe.value -> SubscribeRequest.serializer() - Method.Defined.ResourcesUnsubscribe.value -> UnsubscribeRequest.serializer() - Method.Defined.ToolsCall.value -> CallToolRequest.serializer() - Method.Defined.ToolsList.value -> ListToolsRequest.serializer() - else -> null - } +internal fun selectClientRequestDeserializer(method: String): DeserializationStrategy? = when (method) { + Method.Defined.Ping.value -> PingRequest.serializer() + Method.Defined.Initialize.value -> InitializeRequest.serializer() + Method.Defined.CompletionComplete.value -> CompleteRequest.serializer() + Method.Defined.LoggingSetLevel.value -> SetLevelRequest.serializer() + Method.Defined.PromptsGet.value -> GetPromptRequest.serializer() + Method.Defined.PromptsList.value -> ListPromptsRequest.serializer() + Method.Defined.ResourcesList.value -> ListResourcesRequest.serializer() + Method.Defined.ResourcesTemplatesList.value -> ListResourceTemplatesRequest.serializer() + Method.Defined.ResourcesRead.value -> ReadResourceRequest.serializer() + Method.Defined.ResourcesSubscribe.value -> SubscribeRequest.serializer() + Method.Defined.ResourcesUnsubscribe.value -> UnsubscribeRequest.serializer() + Method.Defined.ToolsCall.value -> CallToolRequest.serializer() + Method.Defined.ToolsList.value -> ListToolsRequest.serializer() + else -> null } -private fun selectClientNotificationDeserializer(element: JsonElement): DeserializationStrategy? { - return when (element.jsonObject.getValue("method").jsonPrimitive.content) { +private fun selectClientNotificationDeserializer(element: JsonElement): DeserializationStrategy? = + when (element.jsonObject.getValue("method").jsonPrimitive.content) { Method.Defined.NotificationsCancelled.value -> CancelledNotification.serializer() Method.Defined.NotificationsProgress.value -> ProgressNotification.serializer() Method.Defined.NotificationsInitialized.value -> InitializedNotification.serializer() Method.Defined.NotificationsRootsListChanged.value -> RootsListChangedNotification.serializer() else -> null } -} internal object ClientNotificationPolymorphicSerializer : JsonContentPolymorphicSerializer(ClientNotification::class) { - override fun selectDeserializer(element: JsonElement): DeserializationStrategy { - return selectClientNotificationDeserializer(element) + override fun selectDeserializer(element: JsonElement): DeserializationStrategy = + selectClientNotificationDeserializer(element) ?: UnknownMethodRequestOrNotification.serializer() - } } -internal fun selectServerRequestDeserializer(method: String): DeserializationStrategy? { - return when (method) { - Method.Defined.Ping.value -> PingRequest.serializer() - Method.Defined.SamplingCreateMessage.value -> CreateMessageRequest.serializer() - Method.Defined.RootsList.value -> ListRootsRequest.serializer() - else -> null - } +internal fun selectServerRequestDeserializer(method: String): DeserializationStrategy? = when (method) { + Method.Defined.Ping.value -> PingRequest.serializer() + Method.Defined.SamplingCreateMessage.value -> CreateMessageRequest.serializer() + Method.Defined.RootsList.value -> ListRootsRequest.serializer() + else -> null } -internal fun selectServerNotificationDeserializer(element: JsonElement): DeserializationStrategy? { - return when (element.jsonObject.getValue("method").jsonPrimitive.content) { +internal fun selectServerNotificationDeserializer(element: JsonElement): DeserializationStrategy? = + when (element.jsonObject.getValue("method").jsonPrimitive.content) { Method.Defined.NotificationsCancelled.value -> CancelledNotification.serializer() Method.Defined.NotificationsProgress.value -> ProgressNotification.serializer() Method.Defined.NotificationsMessage.value -> LoggingMessageNotification.serializer() @@ -182,23 +170,20 @@ internal fun selectServerNotificationDeserializer(element: JsonElement): Deseria Method.Defined.PromptsList.value -> PromptListChangedNotification.serializer() else -> null } -} internal object ServerNotificationPolymorphicSerializer : JsonContentPolymorphicSerializer(ServerNotification::class) { - override fun selectDeserializer(element: JsonElement): DeserializationStrategy { - return selectServerNotificationDeserializer(element) + override fun selectDeserializer(element: JsonElement): DeserializationStrategy = + selectServerNotificationDeserializer(element) ?: UnknownMethodRequestOrNotification.serializer() - } } internal object NotificationPolymorphicSerializer : JsonContentPolymorphicSerializer(Notification::class) { - override fun selectDeserializer(element: JsonElement): DeserializationStrategy { - return selectClientNotificationDeserializer(element) + override fun selectDeserializer(element: JsonElement): DeserializationStrategy = + selectClientNotificationDeserializer(element) ?: selectServerNotificationDeserializer(element) ?: UnknownMethodRequestOrNotification.serializer() - } } internal object RequestPolymorphicSerializer : @@ -243,27 +228,24 @@ private fun selectClientResultDeserializer(element: JsonElement): Deserializatio internal object ServerResultPolymorphicSerializer : JsonContentPolymorphicSerializer(ServerResult::class) { - override fun selectDeserializer(element: JsonElement): DeserializationStrategy { - return selectServerResultDeserializer(element) + override fun selectDeserializer(element: JsonElement): DeserializationStrategy = + selectServerResultDeserializer(element) ?: EmptyRequestResult.serializer() - } } internal object ClientResultPolymorphicSerializer : JsonContentPolymorphicSerializer(ClientResult::class) { - override fun selectDeserializer(element: JsonElement): DeserializationStrategy { - return selectClientResultDeserializer(element) + override fun selectDeserializer(element: JsonElement): DeserializationStrategy = + selectClientResultDeserializer(element) ?: EmptyRequestResult.serializer() - } } internal object RequestResultPolymorphicSerializer : JsonContentPolymorphicSerializer(RequestResult::class) { - override fun selectDeserializer(element: JsonElement): DeserializationStrategy { - return selectClientResultDeserializer(element) + override fun selectDeserializer(element: JsonElement): DeserializationStrategy = + selectClientResultDeserializer(element) ?: selectServerResultDeserializer(element) ?: EmptyRequestResult.serializer() - } } internal object JSONRPCMessagePolymorphicSerializer : @@ -286,9 +268,7 @@ public class RequestIdSerializer : KSerializer { override fun deserialize(decoder: Decoder): RequestId { val jsonDecoder = decoder as? JsonDecoder ?: error("Can only deserialize JSON") - val element = jsonDecoder.decodeJsonElement() - - return when (element) { + return when (val element = jsonDecoder.decodeJsonElement()) { is JsonPrimitive -> when { element.isString -> RequestId.StringId(element.content) element.longOrNull != null -> RequestId.NumberId(element.long) @@ -315,7 +295,7 @@ public fun CallToolResult.Companion.ok(content: String, meta: JsonObject = Empty CallToolResult( content = listOf(TextContent(content)), isError = false, - _meta = meta + _meta = meta, ) /** @@ -325,5 +305,5 @@ public fun CallToolResult.Companion.error(content: String, meta: JsonObject = Em CallToolResult( content = listOf(TextContent(content)), isError = true, - _meta = meta - ) \ No newline at end of file + _meta = meta, + ) diff --git a/src/commonTest/kotlin/AudioContentSerializationTest.kt b/src/commonTest/kotlin/AudioContentSerializationTest.kt index 6745972c..247388f7 100644 --- a/src/commonTest/kotlin/AudioContentSerializationTest.kt +++ b/src/commonTest/kotlin/AudioContentSerializationTest.kt @@ -2,7 +2,6 @@ package io.modelcontextprotocol.kotlin.sdk import io.kotest.assertions.json.shouldEqualJson import io.modelcontextprotocol.kotlin.sdk.shared.McpJson -import kotlinx.serialization.encodeToString import kotlin.test.Test import kotlin.test.assertEquals @@ -18,7 +17,7 @@ class AudioContentSerializationTest { private val audioContent = AudioContent( data = "base64-encoded-audio-data", - mimeType = "audio/wav" + mimeType = "audio/wav", ) @Test @@ -31,4 +30,4 @@ class AudioContentSerializationTest { val content = McpJson.decodeFromString(audioContentJson) assertEquals(expected = audioContent, actual = content) } -} \ No newline at end of file +} diff --git a/src/commonTest/kotlin/CallToolResultUtilsTest.kt b/src/commonTest/kotlin/CallToolResultUtilsTest.kt index 4d5ad2dd..5bc55556 100644 --- a/src/commonTest/kotlin/CallToolResultUtilsTest.kt +++ b/src/commonTest/kotlin/CallToolResultUtilsTest.kt @@ -4,8 +4,7 @@ import kotlinx.serialization.json.JsonPrimitive import kotlinx.serialization.json.buildJsonObject import kotlin.test.Test import kotlin.test.assertEquals -import kotlin.test.assertFalse -import kotlin.test.assertTrue +import kotlin.test.assertNotEquals class CallToolResultUtilsTest { @@ -16,7 +15,7 @@ class CallToolResultUtilsTest { assertEquals(1, result.content.size) assertEquals(content, (result.content[0] as TextContent).text) - assertFalse(result.isError == true) + assertNotEquals(result.isError, true) assertEquals(EmptyJsonObject, result._meta) } @@ -31,7 +30,7 @@ class CallToolResultUtilsTest { assertEquals(1, result.content.size) assertEquals(content, (result.content[0] as TextContent).text) - assertFalse(result.isError == true) + assertNotEquals(result.isError, true) assertEquals(meta, result._meta) } @@ -42,7 +41,7 @@ class CallToolResultUtilsTest { assertEquals(1, result.content.size) assertEquals(content, (result.content[0] as TextContent).text) - assertTrue(result.isError == true) + assertEquals(result.isError, true) assertEquals(EmptyJsonObject, result._meta) } @@ -57,8 +56,7 @@ class CallToolResultUtilsTest { assertEquals(1, result.content.size) assertEquals(content, (result.content[0] as TextContent).text) - assertTrue(result.isError == true) + assertEquals(result.isError, true) assertEquals(meta, result._meta) } } - diff --git a/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/ToolSerializationTest.kt b/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/ToolSerializationTest.kt index 0e1f704a..0664351f 100644 --- a/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/ToolSerializationTest.kt +++ b/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/ToolSerializationTest.kt @@ -55,30 +55,42 @@ class ToolSerializationTest { annotations = null, inputSchema = Tool.Input( properties = buildJsonObject { - put("location", buildJsonObject { - put("type", JsonPrimitive("string")) - put("description", JsonPrimitive("The city and state, e.g. San Francisco, CA")) - }) + put( + "location", + buildJsonObject { + put("type", JsonPrimitive("string")) + put("description", JsonPrimitive("The city and state, e.g. San Francisco, CA")) + }, + ) }, - required = listOf("location") + required = listOf("location"), ), outputSchema = Tool.Output( properties = buildJsonObject { - put("temperature", buildJsonObject { - put("type", JsonPrimitive("number")) - put("description", JsonPrimitive("Temperature in celsius")) - }) - put("conditions", buildJsonObject { - put("type", JsonPrimitive("string")) - put("description", JsonPrimitive("Weather conditions description")) - }) - put("humidity", buildJsonObject { - put("type", JsonPrimitive("number")) - put("description", JsonPrimitive("Humidity percentage")) - }) + put( + "temperature", + buildJsonObject { + put("type", JsonPrimitive("number")) + put("description", JsonPrimitive("Temperature in celsius")) + }, + ) + put( + "conditions", + buildJsonObject { + put("type", JsonPrimitive("string")) + put("description", JsonPrimitive("Weather conditions description")) + }, + ) + put( + "humidity", + buildJsonObject { + put("type", JsonPrimitive("number")) + put("description", JsonPrimitive("Humidity percentage")) + }, + ) }, - required = listOf("temperature", "conditions", "humidity") - ) + required = listOf("temperature", "conditions", "humidity"), + ), ) //region Serialize @@ -120,23 +132,35 @@ class ToolSerializationTest { name = "get_weather", outputSchema = Tool.Output( properties = buildJsonObject { - put("temperature", buildJsonObject { - put("type", JsonPrimitive("number")) - put("description", JsonPrimitive("Temperature in celsius")) - }) - put("conditions", buildJsonObject { - put("type", JsonPrimitive("string")) - put("description", JsonPrimitive("Weather conditions description")) - }) - put("humidity", buildJsonObject { - put("type", JsonPrimitive("number")) - put("description", JsonPrimitive("Humidity percentage")) - }) + put( + "temperature", + buildJsonObject { + put("type", JsonPrimitive("number")) + put("description", JsonPrimitive("Temperature in celsius")) + }, + ) + put( + "conditions", + buildJsonObject { + put("type", JsonPrimitive("string")) + put("description", JsonPrimitive("Weather conditions description")) + }, + ) + put( + "humidity", + buildJsonObject { + put("type", JsonPrimitive("number")) + put("description", JsonPrimitive("Humidity percentage")) + }, + ) }, - required = listOf("temperature", "conditions", "humidity") - ) + required = listOf("temperature", "conditions", "humidity"), + ), ) - val expectedJson = createWeatherToolJson(name = "get_weather", outputSchema = """ + val expectedJson = + createWeatherToolJson( + name = "get_weather", + outputSchema = """ { "type": "object", "properties": { @@ -155,7 +179,8 @@ class ToolSerializationTest { }, "required": ["temperature", "conditions", "humidity"] } - """.trimIndent()) + """.trimIndent(), + ) val actualJson = McpJson.encodeToString(weatherTool) @@ -169,21 +194,30 @@ class ToolSerializationTest { title = "Get weather", outputSchema = Tool.Output( properties = buildJsonObject { - put("temperature", buildJsonObject { - put("type", JsonPrimitive("number")) - put("description", JsonPrimitive("Temperature in celsius")) - }) - put("conditions", buildJsonObject { - put("type", JsonPrimitive("string")) - put("description", JsonPrimitive("Weather conditions description")) - }) - put("humidity", buildJsonObject { - put("type", JsonPrimitive("number")) - put("description", JsonPrimitive("Humidity percentage")) - }) + put( + "temperature", + buildJsonObject { + put("type", JsonPrimitive("number")) + put("description", JsonPrimitive("Temperature in celsius")) + }, + ) + put( + "conditions", + buildJsonObject { + put("type", JsonPrimitive("string")) + put("description", JsonPrimitive("Weather conditions description")) + }, + ) + put( + "humidity", + buildJsonObject { + put("type", JsonPrimitive("number")) + put("description", JsonPrimitive("Humidity percentage")) + }, + ) }, - required = listOf("temperature", "conditions", "humidity") - ) + required = listOf("temperature", "conditions", "humidity"), + ), ) val expectedJson = createWeatherToolJson( name = "get_weather", @@ -207,7 +241,8 @@ class ToolSerializationTest { }, "required": ["temperature", "conditions", "humidity"] } - """.trimIndent()) + """.trimIndent(), + ) val actualJson = McpJson.encodeToString(weatherTool) @@ -245,7 +280,10 @@ class ToolSerializationTest { @Test fun `should deserialize get_weather tool with outputSchema optional property specified`() { - val toolJson = createWeatherToolJson(name = "get_weather", outputSchema = """ + val toolJson = + createWeatherToolJson( + name = "get_weather", + outputSchema = """ { "type": "object", "properties": { @@ -264,27 +302,37 @@ class ToolSerializationTest { }, "required": ["temperature", "conditions", "humidity"] } - """.trimIndent()) + """.trimIndent(), + ) val expectedTool = createWeatherTool( name = "get_weather", outputSchema = Tool.Output( properties = buildJsonObject { - put("temperature", buildJsonObject { - put("type", JsonPrimitive("number")) - put("description", JsonPrimitive("Temperature in celsius")) - }) - put("conditions", buildJsonObject { - put("type", JsonPrimitive("string")) - put("description", JsonPrimitive("Weather conditions description")) - }) - put("humidity", buildJsonObject { - put("type", JsonPrimitive("number")) - put("description", JsonPrimitive("Humidity percentage")) - }) + put( + "temperature", + buildJsonObject { + put("type", JsonPrimitive("number")) + put("description", JsonPrimitive("Temperature in celsius")) + }, + ) + put( + "conditions", + buildJsonObject { + put("type", JsonPrimitive("string")) + put("description", JsonPrimitive("Weather conditions description")) + }, + ) + put( + "humidity", + buildJsonObject { + put("type", JsonPrimitive("number")) + put("description", JsonPrimitive("Humidity percentage")) + }, + ) }, - required = listOf("temperature", "conditions", "humidity") - ) + required = listOf("temperature", "conditions", "humidity"), + ), ) val actualTool = McpJson.decodeFromString(toolJson) @@ -316,28 +364,38 @@ class ToolSerializationTest { }, "required": ["temperature", "conditions", "humidity"] } - """.trimIndent()) + """.trimIndent(), + ) val expectedTool = createWeatherTool( name = "get_weather", title = "Get weather", outputSchema = Tool.Output( properties = buildJsonObject { - put("temperature", buildJsonObject { - put("type", JsonPrimitive("number")) - put("description", JsonPrimitive("Temperature in celsius")) - }) - put("conditions", buildJsonObject { - put("type", JsonPrimitive("string")) - put("description", JsonPrimitive("Weather conditions description")) - }) - put("humidity", buildJsonObject { - put("type", JsonPrimitive("number")) - put("description", JsonPrimitive("Humidity percentage")) - }) + put( + "temperature", + buildJsonObject { + put("type", JsonPrimitive("number")) + put("description", JsonPrimitive("Temperature in celsius")) + }, + ) + put( + "conditions", + buildJsonObject { + put("type", JsonPrimitive("string")) + put("description", JsonPrimitive("Weather conditions description")) + }, + ) + put( + "humidity", + buildJsonObject { + put("type", JsonPrimitive("number")) + put("description", JsonPrimitive("Humidity percentage")) + }, + ) }, - required = listOf("temperature", "conditions", "humidity") - ) + required = listOf("temperature", "conditions", "humidity"), + ), ) val actualTool = McpJson.decodeFromString(toolJson) @@ -352,9 +410,8 @@ class ToolSerializationTest { private fun createWeatherToolJson( name: String = "get_weather", title: String? = null, - outputSchema: String? = null + outputSchema: String? = null, ): String { - val stringBuilder = StringBuilder() stringBuilder @@ -371,7 +428,8 @@ class ToolSerializationTest { .appendLine(",") .append(" \"description\": \"Get the current weather in a given location\"") .appendLine(",") - .append(""" + .append( + """ "inputSchema": { "type": "object", "properties": { @@ -382,46 +440,49 @@ class ToolSerializationTest { }, "required": ["location"] } - """.trimIndent()) + """.trimIndent(), + ) if (outputSchema != null) { stringBuilder .appendLine(",") - .append(""" + .append( + """ "outputSchema": $outputSchema - """.trimIndent()) + """.trimIndent(), + ) } stringBuilder .appendLine() .appendLine("}") - return stringBuilder.toString().trimIndent() } private fun createWeatherTool( name: String = "get_weather", title: String? = null, - outputSchema: Tool.Output? = null - ): Tool { - return Tool( - name = name, - title = title, - description = "Get the current weather in a given location", - annotations = null, - inputSchema = Tool.Input( - properties = buildJsonObject { - put("location", buildJsonObject { + outputSchema: Tool.Output? = null, + ): Tool = Tool( + name = name, + title = title, + description = "Get the current weather in a given location", + annotations = null, + inputSchema = Tool.Input( + properties = buildJsonObject { + put( + "location", + buildJsonObject { put("type", JsonPrimitive("string")) put("description", JsonPrimitive("The city and state, e.g. San Francisco, CA")) - }) - }, - required = listOf("location") - ), - outputSchema = outputSchema - ) - } + }, + ) + }, + required = listOf("location"), + ), + outputSchema = outputSchema, + ) //endregion Private Methods } diff --git a/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/BaseTransportTest.kt b/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/BaseTransportTest.kt index 2c82ff72..8fa1e71c 100644 --- a/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/BaseTransportTest.kt +++ b/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/BaseTransportTest.kt @@ -33,8 +33,9 @@ abstract class BaseTransportTest { fail("Unexpected error: $error") } - val messages = listOf( - PingRequest().toJSON(), InitializedNotification().toJSON() + val messages = listOf( + PingRequest().toJSON(), + InitializedNotification().toJSON(), ) val readMessages = mutableListOf() diff --git a/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/InMemoryTransportTest.kt b/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/InMemoryTransportTest.kt index 6ab3feaf..c1dcc51b 100644 --- a/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/InMemoryTransportTest.kt +++ b/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/InMemoryTransportTest.kt @@ -88,7 +88,7 @@ class InMemoryTransportTest { assertFailsWith { clientTransport.send( - InitializedNotification().toJSON() + InitializedNotification().toJSON(), ) } } diff --git a/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/TypesTest.kt b/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/TypesTest.kt index 1714ded6..f2312690 100644 --- a/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/TypesTest.kt +++ b/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/TypesTest.kt @@ -46,4 +46,4 @@ class TypesTest { """.trimIndent() McpJson.decodeFromString(line) } -} \ No newline at end of file +} diff --git a/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/SseIntegrationTest.kt b/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/SseIntegrationTest.kt index 19d84589..bb722983 100644 --- a/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/SseIntegrationTest.kt +++ b/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/SseIntegrationTest.kt @@ -40,17 +40,13 @@ class SseIntegrationTest { } } - private inline fun assertDoesNotThrow(block: () -> T): T { - return try { - block() - } catch (e: Throwable) { - fail("Expected no exception, but got: $e") - } + private inline fun assertDoesNotThrow(block: () -> T): T = try { + block() + } catch (e: Throwable) { + fail("Expected no exception, but got: $e") } - private suspend fun initClient(): Client { - return HttpClient(ClientCIO) { install(SSE) }.mcpSse("http://$URL:$PORT") - } + private suspend fun initClient(): Client = HttpClient(ClientCIO) { install(SSE) }.mcpSse("http://$URL:$PORT") private suspend fun initServer(): EmbeddedServer { val server = Server( @@ -70,4 +66,4 @@ class SseIntegrationTest { private const val PORT = 3001 private const val URL = "localhost" } -} \ No newline at end of file +} diff --git a/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/ReadBufferTest.kt b/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/ReadBufferTest.kt index 6890aef6..8e6f4f65 100644 --- a/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/ReadBufferTest.kt +++ b/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/ReadBufferTest.kt @@ -4,7 +4,6 @@ import io.ktor.utils.io.charsets.Charsets import io.ktor.utils.io.core.toByteArray import io.modelcontextprotocol.kotlin.sdk.JSONRPCMessage import io.modelcontextprotocol.kotlin.sdk.JSONRPCNotification -import kotlinx.serialization.encodeToString import kotlinx.serialization.json.Json import kotlin.test.Test import kotlin.test.assertEquals diff --git a/src/iosMain/kotlin/io/modelcontextprotocol/kotlin/sdk/internal/utils.ios.kt b/src/iosMain/kotlin/io/modelcontextprotocol/kotlin/sdk/internal/utils.ios.kt index 17f6555d..e87df8d0 100644 --- a/src/iosMain/kotlin/io/modelcontextprotocol/kotlin/sdk/internal/utils.ios.kt +++ b/src/iosMain/kotlin/io/modelcontextprotocol/kotlin/sdk/internal/utils.ios.kt @@ -5,4 +5,4 @@ import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.IO internal actual val IODispatcher: CoroutineDispatcher - get() = Dispatchers.IO \ No newline at end of file + get() = Dispatchers.IO diff --git a/src/jsMain/kotlin/io/modelcontextprotocol/kotlin/sdk/internal/utils.js.kt b/src/jsMain/kotlin/io/modelcontextprotocol/kotlin/sdk/internal/utils.js.kt index 1ecad771..b84cbeac 100644 --- a/src/jsMain/kotlin/io/modelcontextprotocol/kotlin/sdk/internal/utils.js.kt +++ b/src/jsMain/kotlin/io/modelcontextprotocol/kotlin/sdk/internal/utils.js.kt @@ -4,4 +4,4 @@ import kotlinx.coroutines.CoroutineDispatcher import kotlinx.coroutines.Dispatchers internal actual val IODispatcher: CoroutineDispatcher - get() = Dispatchers.Default \ No newline at end of file + get() = Dispatchers.Default diff --git a/src/jvmMain/java/io/modelcontextprotocol/kotlin/sdk/internal/utils.jvm.kt b/src/jvmMain/java/io/modelcontextprotocol/kotlin/sdk/internal/utils.jvm.kt index 2c44eec8..ea112244 100644 --- a/src/jvmMain/java/io/modelcontextprotocol/kotlin/sdk/internal/utils.jvm.kt +++ b/src/jvmMain/java/io/modelcontextprotocol/kotlin/sdk/internal/utils.jvm.kt @@ -4,4 +4,4 @@ import kotlinx.coroutines.CoroutineDispatcher import kotlinx.coroutines.Dispatchers internal actual val IODispatcher: CoroutineDispatcher - get() = Dispatchers.IO \ No newline at end of file + get() = Dispatchers.IO diff --git a/src/jvmTest/kotlin/client/ClientIntegrationTest.kt b/src/jvmTest/kotlin/client/ClientIntegrationTest.kt index 5eee8179..311ed12e 100644 --- a/src/jvmTest/kotlin/client/ClientIntegrationTest.kt +++ b/src/jvmTest/kotlin/client/ClientIntegrationTest.kt @@ -19,7 +19,7 @@ class ClientIntegrationTest { return StdioClientTransport( socket.inputStream.asSource().buffered(), - socket.outputStream.asSink().buffered() + socket.outputStream.asSink().buffered(), ) } @@ -34,12 +34,10 @@ class ClientIntegrationTest { try { client.connect(transport) - val response: ListToolsResult? = client.listTools() - println(response?.tools) - + val response: ListToolsResult = client.listTools() + println(response.tools) } finally { transport.close() } } - } diff --git a/src/jvmTest/kotlin/client/ClientTest.kt b/src/jvmTest/kotlin/client/ClientTest.kt index 91bd12de..fee8a4b4 100644 --- a/src/jvmTest/kotlin/client/ClientTest.kt +++ b/src/jvmTest/kotlin/client/ClientTest.kt @@ -69,13 +69,13 @@ class ClientTest { capabilities = ServerCapabilities(), serverInfo = Implementation( name = "test", - version = "1.0" - ) + version = "1.0", + ), ) val response = JSONRPCResponse( id = message.id, - result = result + result = result, ) _onMessage.invoke(response) @@ -88,13 +88,13 @@ class ClientTest { val client = Client( clientInfo = Implementation( name = "test client", - version = "1.0" + version = "1.0", ), options = ClientOptions( capabilities = ClientCapabilities( - sampling = EmptyJsonObject - ) - ) + sampling = EmptyJsonObject, + ), + ), ) client.connect(clientTransport) @@ -103,7 +103,7 @@ class ClientTest { @Test fun `should initialize with supported older protocol version`() = runTest { - val OLD_VERSION = SUPPORTED_PROTOCOL_VERSIONS[1] + val oldVersion = SUPPORTED_PROTOCOL_VERSIONS[1] val clientTransport = object : AbstractTransport() { override suspend fun start() {} @@ -112,17 +112,17 @@ class ClientTest { check(message.method == Method.Defined.Initialize.value) val result = InitializeResult( - protocolVersion = OLD_VERSION, + protocolVersion = oldVersion, capabilities = ServerCapabilities(), serverInfo = Implementation( name = "test", - version = "1.0" - ) + version = "1.0", + ), ) val response = JSONRPCResponse( id = message.id, - result = result + result = result, ) _onMessage.invoke(response) } @@ -134,19 +134,19 @@ class ClientTest { val client = Client( clientInfo = Implementation( name = "test client", - version = "1.0" + version = "1.0", ), options = ClientOptions( capabilities = ClientCapabilities( - sampling = EmptyJsonObject - ) - ) + sampling = EmptyJsonObject, + ), + ), ) client.connect(clientTransport) assertEquals( Implementation("test", "1.0"), - client.serverVersion + client.serverVersion, ) } @@ -165,13 +165,13 @@ class ClientTest { capabilities = ServerCapabilities(), serverInfo = Implementation( name = "test", - version = "1.0" - ) + version = "1.0", + ), ) val response = JSONRPCResponse( id = message.id, - result = result + result = result, ) _onMessage.invoke(response) @@ -185,9 +185,9 @@ class ClientTest { val client = Client( clientInfo = Implementation( name = "test client", - version = "1.0" + version = "1.0", ), - options = ClientOptions() + options = ClientOptions(), ) assertFailsWith("Server's protocol version is not supported: invalid-version") { @@ -212,13 +212,13 @@ class ClientTest { capabilities = ServerCapabilities(), serverInfo = Implementation( name = "test", - version = "1.0" - ) + version = "1.0", + ), ) val response = JSONRPCResponse( id = message.id, - result = result + result = result, ) _onMessage.invoke(response) @@ -233,13 +233,13 @@ class ClientTest { Client( clientInfo = Implementation( name = "test client", - version = "1.0" + version = "1.0", ), - options = ClientOptions() - ) + options = ClientOptions(), + ), ) - coEvery{ + coEvery { mockClient.request(any()) } throws IllegalStateException("Test error") @@ -257,30 +257,30 @@ class ClientTest { val serverOptions = ServerOptions( capabilities = ServerCapabilities( resources = ServerCapabilities.Resources(null, null), - tools = ServerCapabilities.Tools(null) - ) + tools = ServerCapabilities.Tools(null), + ), ) val server = Server( Implementation(name = "test server", version = "1.0"), - serverOptions + serverOptions, ) - server.setRequestHandler(Method.Defined.Initialize) { request, _ -> + server.setRequestHandler(Method.Defined.Initialize) { _, _ -> InitializeResult( protocolVersion = LATEST_PROTOCOL_VERSION, capabilities = ServerCapabilities( resources = ServerCapabilities.Resources(null, null), - tools = ServerCapabilities.Tools(null) + tools = ServerCapabilities.Tools(null), ), - serverInfo = Implementation(name = "test", version = "1.0") + serverInfo = Implementation(name = "test", version = "1.0"), ) } - server.setRequestHandler(Method.Defined.ResourcesList) { request, _ -> + server.setRequestHandler(Method.Defined.ResourcesList) { _, _ -> ListResourcesResult(resources = emptyList(), nextCursor = null) } - server.setRequestHandler(Method.Defined.ToolsList) { request, _ -> + server.setRequestHandler(Method.Defined.ToolsList) { _, _ -> ListToolsResult(tools = emptyList(), nextCursor = null) } @@ -290,7 +290,7 @@ class ClientTest { clientInfo = Implementation(name = "test client", version = "1.0"), options = ClientOptions( capabilities = ClientCapabilities(sampling = EmptyJsonObject), - ) + ), ) listOf( @@ -299,14 +299,14 @@ class ClientTest { }, launch { server.connect(serverTransport) - } + }, ).joinAll() // Server supports resources and tools, but not prompts val caps = client.serverCapabilities assertEquals(ServerCapabilities.Resources(null, null), caps?.resources) assertEquals(ServerCapabilities.Tools(null), caps?.tools) - assertTrue(caps?.prompts == null) // or check that prompts are absent + assertEquals(caps?.prompts, null) // or check that prompts are absent // These should not throw client.listResources() @@ -316,23 +316,23 @@ class ClientTest { val ex = assertFailsWith { client.listPrompts() } - assertTrue(ex.message?.contains("Server does not support prompts") == true) + assertTrue(ex.message?.contains("Server does not support prompts") ?: false) } @Test fun `should respect client notification capabilities`() = runTest { val server = Server( Implementation(name = "test server", version = "1.0"), - ServerOptions(capabilities = ServerCapabilities()) + ServerOptions(capabilities = ServerCapabilities()), ) val client = Client( clientInfo = Implementation(name = "test client", version = "1.0"), options = ClientOptions( capabilities = ClientCapabilities( - roots = ClientCapabilities.Roots(listChanged = true) - ) - ) + roots = ClientCapabilities.Roots(listChanged = true), + ), + ), ) val (clientTransport, serverTransport) = InMemoryTransport.createLinkedPair() @@ -345,7 +345,7 @@ class ClientTest { launch { server.connect(serverTransport) println("Server connected") - } + }, ).joinAll() // This should not throw because the client supports roots.listChanged @@ -357,7 +357,7 @@ class ClientTest { options = ClientOptions( capabilities = ClientCapabilities(), // enforceStrictCapabilities = true // TODO() - ) + ), ) clientWithoutCapability.connect(clientTransport) @@ -368,7 +368,7 @@ class ClientTest { val ex = assertFailsWith { clientWithoutCapability.sendRootsListChanged() } - assertTrue(ex.message?.startsWith("Client does not support") == true) + assertTrue(ex.message?.startsWith("Client does not support") ?: false) } @Test @@ -378,16 +378,16 @@ class ClientTest { ServerOptions( capabilities = ServerCapabilities( logging = EmptyJsonObject, - resources = ServerCapabilities.Resources(listChanged = true, subscribe = null) - ) - ) + resources = ServerCapabilities.Resources(listChanged = true, subscribe = null), + ), + ), ) val client = Client( clientInfo = Implementation(name = "test client", version = "1.0"), options = ClientOptions( - capabilities = ClientCapabilities() - ) + capabilities = ClientCapabilities(), + ), ) val (clientTransport, serverTransport) = InMemoryTransport.createLinkedPair() @@ -400,7 +400,7 @@ class ClientTest { launch { server.connect(serverTransport) println("Server connected") - } + }, ).joinAll() // These should not throw @@ -412,8 +412,8 @@ class ClientTest { server.sendLoggingMessage( LoggingMessageNotification( level = LoggingLevel.info, - data = jsonObject - ) + data = jsonObject, + ), ) server.sendResourceListChanged() @@ -421,7 +421,7 @@ class ClientTest { val ex = assertFailsWith { server.sendToolListChanged() } - assertTrue(ex.message?.contains("Server does not support notifying of tool list changes") == true) + assertTrue(ex.message?.contains("Server does not support notifying of tool list changes") ?: false) } @Test @@ -429,13 +429,15 @@ class ClientTest { val server = Server( Implementation(name = "test server", version = "1.0"), ServerOptions( - capabilities = ServerCapabilities(resources = ServerCapabilities.Resources(listChanged = null, subscribe = null)) - ) + capabilities = ServerCapabilities( + resources = ServerCapabilities.Resources(listChanged = null, subscribe = null), + ), + ), ) val def = CompletableDeferred() val defTimeOut = CompletableDeferred() - server.setRequestHandler(Method.Defined.ResourcesList) { _, extra -> + server.setRequestHandler(Method.Defined.ResourcesList) { _, _ -> // Simulate delay def.complete(Unit) try { @@ -444,15 +446,15 @@ class ClientTest { defTimeOut.complete(Unit) throw e } - fail("Shouldn't have been called") ListResourcesResult(resources = emptyList()) + fail("Shouldn't have been called") } val (clientTransport, serverTransport) = InMemoryTransport.createLinkedPair() val client = Client( clientInfo = Implementation(name = "test client", version = "1.0"), - options = ClientOptions(capabilities = ClientCapabilities()) + options = ClientOptions(capabilities = ClientCapabilities()), ) listOf( @@ -463,7 +465,7 @@ class ClientTest { launch { server.connect(serverTransport) println("Server connected") - } + }, ).joinAll() val defCancel = CompletableDeferred() @@ -486,11 +488,13 @@ class ClientTest { val server = Server( Implementation(name = "test server", version = "1.0"), ServerOptions( - capabilities = ServerCapabilities(resources = ServerCapabilities.Resources(listChanged = null, subscribe = null)) - ) + capabilities = ServerCapabilities( + resources = ServerCapabilities.Resources(listChanged = null, subscribe = null), + ), + ), ) - server.setRequestHandler(Method.Defined.ResourcesList) { _, extra -> + server.setRequestHandler(Method.Defined.ResourcesList) { _, _ -> // Simulate a delayed response // Wait ~100ms unless canceled try { @@ -507,7 +511,7 @@ class ClientTest { val (clientTransport, serverTransport) = InMemoryTransport.createLinkedPair() val client = Client( clientInfo = Implementation(name = "test client", version = "1.0"), - options = ClientOptions(capabilities = ClientCapabilities()) + options = ClientOptions(capabilities = ClientCapabilities()), ) listOf( @@ -518,7 +522,7 @@ class ClientTest { launch { server.connect(serverTransport) println("Server connected") - } + }, ).joinAll() // Request with 1 msec timeout should fail immediately @@ -535,22 +539,22 @@ class ClientTest { val client = Client( clientInfo = Implementation( name = "test client", - version = "1.0" + version = "1.0", ), options = ClientOptions( capabilities = ClientCapabilities( - sampling = EmptyJsonObject - ) - ) + sampling = EmptyJsonObject, + ), + ), ) - client.setRequestHandler(Method.Defined.SamplingCreateMessage) { request, _ -> + client.setRequestHandler(Method.Defined.SamplingCreateMessage) { _, _ -> CreateMessageResult( model = "test-model", role = Role.assistant, content = TextContent( - text = "Test response" - ) + text = "Test response", + ), ) } @@ -563,22 +567,22 @@ class ClientTest { fun `JSONRPCRequest with ToolsList method and default params returns list of tools`() = runTest { val serverOptions = ServerOptions( capabilities = ServerCapabilities( - tools = ServerCapabilities.Tools(null) - ) + tools = ServerCapabilities.Tools(null), + ), ) val server = Server( Implementation(name = "test server", version = "1.0"), - serverOptions + serverOptions, ) - server.setRequestHandler(Method.Defined.Initialize) { request, _ -> + server.setRequestHandler(Method.Defined.Initialize) { _, _ -> InitializeResult( protocolVersion = LATEST_PROTOCOL_VERSION, capabilities = ServerCapabilities( resources = ServerCapabilities.Resources(null, null), - tools = ServerCapabilities.Tools(null) + tools = ServerCapabilities.Tools(null), ), - serverInfo = Implementation(name = "test", version = "1.0") + serverInfo = Implementation(name = "test", version = "1.0"), ) } val serverListToolsResult = ListToolsResult( @@ -589,12 +593,13 @@ class ClientTest { description = "testTool description", annotations = null, inputSchema = Tool.Input(), - outputSchema = null - ) - ), nextCursor = null + outputSchema = null, + ), + ), + nextCursor = null, ) - server.setRequestHandler(Method.Defined.ToolsList) { request, _ -> + server.setRequestHandler(Method.Defined.ToolsList) { _, _ -> serverListToolsResult } @@ -604,7 +609,7 @@ class ClientTest { clientInfo = Implementation(name = "test client", version = "1.0"), options = ClientOptions( capabilities = ClientCapabilities(sampling = EmptyJsonObject), - ) + ), ) var receivedMessage: JSONRPCMessage? = null @@ -618,14 +623,14 @@ class ClientTest { }, launch { server.connect(serverTransport) - } + }, ).joinAll() val serverCapabilities = client.serverCapabilities assertEquals(ServerCapabilities.Tools(null), serverCapabilities?.tools) val request = JSONRPCRequest( - method = Method.Defined.ToolsList.value + method = Method.Defined.ToolsList.value, ) clientTransport.send(request) @@ -643,13 +648,13 @@ class ClientTest { Implementation(name = "test client", version = "1.0"), ClientOptions( capabilities = ClientCapabilities( - roots = ClientCapabilities.Roots(null) - ) - ) + roots = ClientCapabilities.Roots(null), + ), + ), ) val clientRoots = listOf( - Root(uri = "file:///test-root", name = "testRoot") + Root(uri = "file:///test-root", name = "testRoot"), ) client.addRoots(clientRoots) @@ -659,13 +664,13 @@ class ClientTest { val server = Server( serverInfo = Implementation(name = "test server", version = "1.0"), options = ServerOptions( - capabilities = ServerCapabilities() - ) + capabilities = ServerCapabilities(), + ), ) listOf( launch { client.connect(clientTransport) }, - launch { server.connect(serverTransport) } + launch { server.connect(serverTransport) }, ).joinAll() val clientCapabilities = server.clientCapabilities @@ -681,8 +686,8 @@ class ClientTest { val client = Client( Implementation(name = "test client", version = "1.0"), ClientOptions( - capabilities = ClientCapabilities() - ) + capabilities = ClientCapabilities(), + ), ) // Verify that adding a root throws an exception @@ -697,8 +702,8 @@ class ClientTest { val client = Client( Implementation(name = "test client", version = "1.0"), ClientOptions( - capabilities = ClientCapabilities() - ) + capabilities = ClientCapabilities(), + ), ) // Verify that removing a root throws an exception @@ -714,9 +719,9 @@ class ClientTest { Implementation(name = "test client", version = "1.0"), ClientOptions( capabilities = ClientCapabilities( - roots = ClientCapabilities.Roots(null) - ) - ) + roots = ClientCapabilities.Roots(null), + ), + ), ) // Add some roots @@ -724,7 +729,7 @@ class ClientTest { listOf( Root(uri = "file:///test-root1", name = "testRoot1"), Root(uri = "file:///test-root2", name = "testRoot2"), - ) + ), ) // Remove a root @@ -740,9 +745,9 @@ class ClientTest { Implementation(name = "test client", version = "1.0"), ClientOptions( capabilities = ClientCapabilities( - roots = ClientCapabilities.Roots(null) - ) - ) + roots = ClientCapabilities.Roots(null), + ), + ), ) // Add some roots @@ -750,12 +755,12 @@ class ClientTest { listOf( Root(uri = "file:///test-root1", name = "testRoot1"), Root(uri = "file:///test-root2", name = "testRoot2"), - ) + ), ) // Remove multiple roots val result = client.removeRoots( - listOf("file:///test-root1", "file:///test-root2") + listOf("file:///test-root1", "file:///test-root2"), ) // Verify the root was removed @@ -768,9 +773,9 @@ class ClientTest { Implementation(name = "test client", version = "1.0"), ClientOptions( capabilities = ClientCapabilities( - roots = ClientCapabilities.Roots(listChanged = true) - ) - ) + roots = ClientCapabilities.Roots(listChanged = true), + ), + ), ) val (clientTransport, serverTransport) = InMemoryTransport.createLinkedPair() @@ -778,8 +783,8 @@ class ClientTest { val server = Server( serverInfo = Implementation(name = "test server", version = "1.0"), options = ServerOptions( - capabilities = ServerCapabilities() - ) + capabilities = ServerCapabilities(), + ), ) // Track notifications @@ -791,14 +796,14 @@ class ClientTest { listOf( launch { client.connect(clientTransport) }, - launch { server.connect(serverTransport) } + launch { server.connect(serverTransport) }, ).joinAll() client.sendRootsListChanged() assertTrue( rootListChangedNotificationReceived, - "Notification should be sent when sendRootsListChanged is called" + "Notification should be sent when sendRootsListChanged is called", ) } @@ -807,8 +812,8 @@ class ClientTest { val client = Client( Implementation(name = "test client", version = "1.0"), ClientOptions( - capabilities = ClientCapabilities() - ) + capabilities = ClientCapabilities(), + ), ) val (clientTransport, serverTransport) = InMemoryTransport.createLinkedPair() @@ -816,13 +821,13 @@ class ClientTest { val server = Server( serverInfo = Implementation(name = "test server", version = "1.0"), options = ServerOptions( - capabilities = ServerCapabilities() - ) + capabilities = ServerCapabilities(), + ), ) listOf( launch { client.connect(clientTransport) }, - launch { server.connect(serverTransport) } + launch { server.connect(serverTransport) }, ).joinAll() // Verify that creating an elicitation throws an exception @@ -835,13 +840,13 @@ class ClientTest { put("type", "string") } }, - required = listOf("name") - ) + required = listOf("name"), + ), ) } assertEquals( "Client does not support elicitation (required for elicitation/create)", - exception.message + exception.message, ) } @@ -852,8 +857,8 @@ class ClientTest { ClientOptions( capabilities = ClientCapabilities( elicitation = EmptyJsonObject, - ) - ) + ), + ), ) val elicitationMessage = "Please provide your GitHub username" @@ -863,7 +868,7 @@ class ClientTest { put("type", "string") } }, - required = listOf("name") + required = listOf("name"), ) val elicitationResultAction = CreateElicitationResult.Action.accept @@ -877,7 +882,7 @@ class ClientTest { CreateElicitationResult( action = elicitationResultAction, - content = elicitationResultContent + content = elicitationResultContent, ) } @@ -886,18 +891,18 @@ class ClientTest { val server = Server( serverInfo = Implementation(name = "test server", version = "1.0"), options = ServerOptions( - capabilities = ServerCapabilities() - ) + capabilities = ServerCapabilities(), + ), ) listOf( launch { client.connect(clientTransport) }, - launch { server.connect(serverTransport) } + launch { server.connect(serverTransport) }, ).joinAll() val result = server.createElicitation( message = elicitationMessage, - requestedSchema = requestedSchema + requestedSchema = requestedSchema, ) assertEquals(elicitationResultAction, result.action) diff --git a/src/jvmTest/kotlin/client/StdioClientTransportTest.kt b/src/jvmTest/kotlin/client/StdioClientTransportTest.kt index 15defaed..686c39c3 100644 --- a/src/jvmTest/kotlin/client/StdioClientTransportTest.kt +++ b/src/jvmTest/kotlin/client/StdioClientTransportTest.kt @@ -1,8 +1,8 @@ package client import io.modelcontextprotocol.kotlin.sdk.client.BaseTransportTest -import kotlinx.coroutines.test.runTest import io.modelcontextprotocol.kotlin.sdk.client.StdioClientTransport +import kotlinx.coroutines.test.runTest import kotlinx.io.asSink import kotlinx.io.asSource import kotlinx.io.buffered @@ -20,7 +20,7 @@ class StdioClientTransportTest : BaseTransportTest() { val client = StdioClientTransport( input = input, - output = output + output = output, ) testClientOpenClose(client) @@ -38,7 +38,7 @@ class StdioClientTransportTest : BaseTransportTest() { val client = StdioClientTransport( input = input, - output = output + output = output, ) testClientRead(client) diff --git a/src/jvmTest/kotlin/client/StreamableHttpClientTransportTest.kt b/src/jvmTest/kotlin/client/StreamableHttpClientTransportTest.kt index ca255f4a..f7cf4174 100644 --- a/src/jvmTest/kotlin/client/StreamableHttpClientTransportTest.kt +++ b/src/jvmTest/kotlin/client/StreamableHttpClientTransportTest.kt @@ -49,7 +49,7 @@ class StreamableHttpClientTransportTest { val message = JSONRPCRequest( id = RequestId.StringId("test-id"), method = "test", - params = buildJsonObject { } + params = buildJsonObject { }, ) val transport = createTransport { request -> @@ -63,7 +63,7 @@ class StreamableHttpClientTransportTest { respond( content = "", - status = HttpStatusCode.Accepted + status = HttpStatusCode.Accepted, ) } @@ -78,26 +78,30 @@ class StreamableHttpClientTransportTest { id = RequestId.StringId("test-id"), method = "initialize", params = buildJsonObject { - put("clientInfo", buildJsonObject { - put("name", JsonPrimitive("test-client")) - put("version", JsonPrimitive("1.0")) - }) + put( + "clientInfo", + buildJsonObject { + put("name", JsonPrimitive("test-client")) + put("version", JsonPrimitive("1.0")) + }, + ) put("protocolVersion", JsonPrimitive("2025-06-18")) - } + }, ) val transport = createTransport { request -> when (val msg = McpJson.decodeFromString((request.body as TextContent).text)) { is JSONRPCRequest if msg.method == "initialize" -> respond( - content = "", status = HttpStatusCode.OK, - headers = headersOf("mcp-session-id", "test-session-id") + content = "", + status = HttpStatusCode.OK, + headers = headersOf("mcp-session-id", "test-session-id"), ) is JSONRPCNotification if msg.method == "test" -> { assertEquals("test-session-id", request.headers["mcp-session-id"]) respond( content = "", - status = HttpStatusCode.Accepted + status = HttpStatusCode.Accepted, ) } @@ -124,7 +128,7 @@ class StreamableHttpClientTransportTest { assertEquals("test-session-id", request.headers["mcp-session-id"]) respond( content = "", - status = HttpStatusCode.OK + status = HttpStatusCode.OK, ) } @@ -143,7 +147,7 @@ class StreamableHttpClientTransportTest { assertEquals(HttpMethod.Delete, request.method) respond( content = "", - status = HttpStatusCode.MethodNotAllowed + status = HttpStatusCode.MethodNotAllowed, ) } @@ -164,7 +168,7 @@ class StreamableHttpClientTransportTest { assertEquals("2025-06-18", request.headers["mcp-protocol-version"]) respond( content = "", - status = HttpStatusCode.Accepted + status = HttpStatusCode.Accepted, ) } transport.protocolVersion = "2025-06-18" @@ -174,7 +178,9 @@ class StreamableHttpClientTransportTest { transport.close() } - @Ignore("Engine doesn't support SSECapability: https://youtrack.jetbrains.com/issue/KTOR-8177/MockEngine-Add-SSE-support") + @Ignore( + "Engine doesn't support SSECapability: https://youtrack.jetbrains.com/issue/KTOR-8177/MockEngine-Add-SSE-support", + ) @Test fun testNotificationSchemaE2E() = runTest { val receivedMessages = mutableListOf() @@ -186,7 +192,7 @@ class StreamableHttpClientTransportTest { respond( content = "", status = HttpStatusCode.Accepted, - headers = headersOf("mcp-session-id", "notification-test-session") + headers = headersOf("mcp-session-id", "notification-test-session"), ) } @@ -197,7 +203,9 @@ class StreamableHttpClientTransportTest { // Server sends various notifications appendLine("event: message") appendLine("id: 1") - appendLine("""data: {"jsonrpc":"2.0","method":"notifications/progress","params":{"progressToken":"upload-123","progress":50,"total":100}}""") + appendLine( + """data: {"jsonrpc":"2.0","method":"notifications/progress","params":{"progressToken":"upload-123","progress":50,"total":100}}""", + ) appendLine() appendLine("event: message") @@ -214,8 +222,9 @@ class StreamableHttpClientTransportTest { content = ByteReadChannel(sseContent), status = HttpStatusCode.OK, headers = headersOf( - HttpHeaders.ContentType, ContentType.Text.EventStream.toString() - ) + HttpHeaders.ContentType, + ContentType.Text.EventStream.toString(), + ), ) } @@ -223,7 +232,7 @@ class StreamableHttpClientTransportTest { HttpMethod.Post -> { respond( content = "", - status = HttpStatusCode.Accepted + status = HttpStatusCode.Accepted, ) } @@ -242,11 +251,14 @@ class StreamableHttpClientTransportTest { method = "notifications/initialized", params = buildJsonObject { put("protocolVersion", JsonPrimitive("1.0")) - put("capabilities", buildJsonObject { - put("tools", JsonPrimitive(true)) - put("resources", JsonPrimitive(true)) - }) - } + put( + "capabilities", + buildJsonObject { + put("tools", JsonPrimitive(true)) + put("resources", JsonPrimitive(true)) + }, + ) + }, ) transport.send(initializedNotification) @@ -278,25 +290,28 @@ class StreamableHttpClientTransportTest { params = buildJsonObject { put("progressToken", JsonPrimitive("download-456")) put("progress", JsonPrimitive(75)) - } + }, ), JSONRPCNotification( method = "notifications/cancelled", params = buildJsonObject { put("requestId", JsonPrimitive("req-789")) put("reason", JsonPrimitive("user_cancelled")) - } + }, ), JSONRPCNotification( method = "notifications/message", params = buildJsonObject { put("level", JsonPrimitive("info")) put("message", JsonPrimitive("Operation completed")) - put("data", buildJsonObject { - put("duration", JsonPrimitive(1234)) - }) - } - ) + put( + "data", + buildJsonObject { + put("duration", JsonPrimitive(1234)) + }, + ) + }, + ), ) // Send all client notifications @@ -309,7 +324,9 @@ class StreamableHttpClientTransportTest { transport.close() } - @Ignore("Engine doesn't support SSECapability: https://youtrack.jetbrains.com/issue/KTOR-8177/MockEngine-Add-SSE-support") + @Ignore( + "Engine doesn't support SSECapability: https://youtrack.jetbrains.com/issue/KTOR-8177/MockEngine-Add-SSE-support", + ) @Test fun testNotificationWithResumptionToken() = runTest { var resumptionTokenReceived: String? = null @@ -324,15 +341,18 @@ class StreamableHttpClientTransportTest { val sseContent = buildString { appendLine("event: message") appendLine("id: resume-100") - appendLine("""data: {"jsonrpc":"2.0","method":"notifications/resumed","params":{"fromToken":"${lastEventIdSent}"}}""") + appendLine( + """data: {"jsonrpc":"2.0","method":"notifications/resumed","params":{"fromToken":"$lastEventIdSent"}}""", + ) appendLine() } respond( content = ByteReadChannel(sseContent), status = HttpStatusCode.OK, headers = headersOf( - HttpHeaders.ContentType, ContentType.Text.EventStream.toString() - ) + HttpHeaders.ContentType, + ContentType.Text.EventStream.toString(), + ), ) } @@ -348,12 +368,12 @@ class StreamableHttpClientTransportTest { method = "notifications/test", params = buildJsonObject { put("data", JsonPrimitive("test-data")) - } + }, ), resumptionToken = "previous-token-99", onResumptionToken = { token -> resumptionTokenReceived = token - } + }, ) // Wait for response diff --git a/src/jvmTest/kotlin/server/ServerTest.kt b/src/jvmTest/kotlin/server/ServerTest.kt index 35e07741..04839861 100644 --- a/src/jvmTest/kotlin/server/ServerTest.kt +++ b/src/jvmTest/kotlin/server/ServerTest.kt @@ -33,16 +33,16 @@ class ServerTest { // Create server with tools capability val serverOptions = ServerOptions( capabilities = ServerCapabilities( - tools = ServerCapabilities.Tools(null) - ) + tools = ServerCapabilities.Tools(null), + ), ) val server = Server( Implementation(name = "test server", version = "1.0"), - serverOptions + serverOptions, ) // Add a tool - server.addTool("test-tool", "Test Tool", Tool.Input()) { request -> + server.addTool("test-tool", "Test Tool", Tool.Input()) { _ -> CallToolResult(listOf(TextContent("Test result"))) } @@ -68,12 +68,12 @@ class ServerTest { // Create server with tools capability val serverOptions = ServerOptions( capabilities = ServerCapabilities( - tools = ServerCapabilities.Tools(null) - ) + tools = ServerCapabilities.Tools(null), + ), ) val server = Server( Implementation(name = "test server", version = "1.0"), - serverOptions + serverOptions, ) // Setup client @@ -105,11 +105,11 @@ class ServerTest { fun `removeTool should throw when tools capability is not supported`() = runTest { // Create server without tools capability val serverOptions = ServerOptions( - capabilities = ServerCapabilities() + capabilities = ServerCapabilities(), ) val server = Server( Implementation(name = "test server", version = "1.0"), - serverOptions + serverOptions, ) // Verify that removing a tool throws an exception @@ -124,19 +124,19 @@ class ServerTest { // Create server with tools capability val serverOptions = ServerOptions( capabilities = ServerCapabilities( - tools = ServerCapabilities.Tools(null) - ) + tools = ServerCapabilities.Tools(null), + ), ) val server = Server( Implementation(name = "test server", version = "1.0"), - serverOptions + serverOptions, ) // Add tools - server.addTool("test-tool-1", "Test Tool 1") { request -> + server.addTool("test-tool-1", "Test Tool 1") { _ -> CallToolResult(listOf(TextContent("Test result 1"))) } - server.addTool("test-tool-2", "Test Tool 2") { request -> + server.addTool("test-tool-2", "Test Tool 2") { _ -> CallToolResult(listOf(TextContent("Test result 2"))) } @@ -162,20 +162,20 @@ class ServerTest { // Create server with prompts capability val serverOptions = ServerOptions( capabilities = ServerCapabilities( - prompts = ServerCapabilities.Prompts(listChanged = false) - ) + prompts = ServerCapabilities.Prompts(listChanged = false), + ), ) val server = Server( Implementation(name = "test server", version = "1.0"), - serverOptions + serverOptions, ) // Add a prompt val testPrompt = Prompt("test-prompt", "Test Prompt", null) - server.addPrompt(testPrompt) { request -> + server.addPrompt(testPrompt) { _ -> GetPromptResult( description = "Test prompt description", - messages = listOf() + messages = listOf(), ) } @@ -201,27 +201,27 @@ class ServerTest { // Create server with prompts capability val serverOptions = ServerOptions( capabilities = ServerCapabilities( - prompts = ServerCapabilities.Prompts(listChanged = false) - ) + prompts = ServerCapabilities.Prompts(listChanged = false), + ), ) val server = Server( Implementation(name = "test server", version = "1.0"), - serverOptions + serverOptions, ) // Add prompts val testPrompt1 = Prompt("test-prompt-1", "Test Prompt 1", null) val testPrompt2 = Prompt("test-prompt-2", "Test Prompt 2", null) - server.addPrompt(testPrompt1) { request -> + server.addPrompt(testPrompt1) { _ -> GetPromptResult( description = "Test prompt description 1", - messages = listOf() + messages = listOf(), ) } - server.addPrompt(testPrompt2) { request -> + server.addPrompt(testPrompt2) { _ -> GetPromptResult( description = "Test prompt description 2", - messages = listOf() + messages = listOf(), ) } @@ -247,12 +247,12 @@ class ServerTest { // Create server with resources capability val serverOptions = ServerOptions( capabilities = ServerCapabilities( - resources = ServerCapabilities.Resources(null, null) - ) + resources = ServerCapabilities.Resources(null, null), + ), ) val server = Server( Implementation(name = "test server", version = "1.0"), - serverOptions + serverOptions, ) // Add a resource @@ -261,8 +261,8 @@ class ServerTest { uri = testResourceUri, name = "Test Resource", description = "A test resource", - mimeType = "text/plain" - ) { request -> + mimeType = "text/plain", + ) { _ -> ReadResourceResult( contents = listOf( TextResourceContents( @@ -270,7 +270,7 @@ class ServerTest { uri = testResourceUri, mimeType = "text/plain", ), - ) + ), ) } @@ -296,12 +296,12 @@ class ServerTest { // Create server with resources capability val serverOptions = ServerOptions( capabilities = ServerCapabilities( - resources = ServerCapabilities.Resources(null, null) - ) + resources = ServerCapabilities.Resources(null, null), + ), ) val server = Server( Implementation(name = "test server", version = "1.0"), - serverOptions + serverOptions, ) // Add resources @@ -311,8 +311,8 @@ class ServerTest { uri = testResourceUri1, name = "Test Resource 1", description = "A test resource 1", - mimeType = "text/plain" - ) { request -> + mimeType = "text/plain", + ) { _ -> ReadResourceResult( contents = listOf( TextResourceContents( @@ -320,15 +320,15 @@ class ServerTest { uri = testResourceUri1, mimeType = "text/plain", ), - ) + ), ) } server.addResource( uri = testResourceUri2, name = "Test Resource 2", description = "A test resource 2", - mimeType = "text/plain" - ) { request -> + mimeType = "text/plain", + ) { _ -> ReadResourceResult( contents = listOf( TextResourceContents( @@ -336,7 +336,7 @@ class ServerTest { uri = testResourceUri2, mimeType = "text/plain", ), - ) + ), ) } @@ -362,12 +362,12 @@ class ServerTest { // Create server with prompts capability val serverOptions = ServerOptions( capabilities = ServerCapabilities( - prompts = ServerCapabilities.Prompts(listChanged = false) - ) + prompts = ServerCapabilities.Prompts(listChanged = false), + ), ) val server = Server( Implementation(name = "test server", version = "1.0"), - serverOptions + serverOptions, ) // Setup client @@ -399,11 +399,11 @@ class ServerTest { fun `removePrompt should throw when prompts capability is not supported`() = runTest { // Create server without prompts capability val serverOptions = ServerOptions( - capabilities = ServerCapabilities() + capabilities = ServerCapabilities(), ) val server = Server( Implementation(name = "test server", version = "1.0"), - serverOptions + serverOptions, ) // Verify that removing a prompt throws an exception @@ -418,12 +418,12 @@ class ServerTest { // Create server with resources capability val serverOptions = ServerOptions( capabilities = ServerCapabilities( - resources = ServerCapabilities.Resources(null, null) - ) + resources = ServerCapabilities.Resources(null, null), + ), ) val server = Server( Implementation(name = "test server", version = "1.0"), - serverOptions + serverOptions, ) // Setup client @@ -434,7 +434,9 @@ class ServerTest { // Track notifications var resourceListChangedNotificationReceived = false - client.setNotificationHandler(Method.Defined.NotificationsResourcesListChanged) { + client.setNotificationHandler( + Method.Defined.NotificationsResourcesListChanged, + ) { resourceListChangedNotificationReceived = true CompletableDeferred(Unit) } @@ -448,18 +450,21 @@ class ServerTest { // Verify the result assertFalse(result, "Removing non-existent resource should return false") - assertFalse(resourceListChangedNotificationReceived, "No notification should be sent when resource doesn't exist") + assertFalse( + resourceListChangedNotificationReceived, + "No notification should be sent when resource doesn't exist", + ) } @Test fun `removeResource should throw when resources capability is not supported`() = runTest { // Create server without resources capability val serverOptions = ServerOptions( - capabilities = ServerCapabilities() + capabilities = ServerCapabilities(), ) val server = Server( Implementation(name = "test server", version = "1.0"), - serverOptions + serverOptions, ) // Verify that removing a resource throws an exception diff --git a/src/jvmTest/kotlin/server/StdioServerTransportTest.kt b/src/jvmTest/kotlin/server/StdioServerTransportTest.kt index 8c865aa2..5764b203 100644 --- a/src/jvmTest/kotlin/server/StdioServerTransportTest.kt +++ b/src/jvmTest/kotlin/server/StdioServerTransportTest.kt @@ -3,21 +3,25 @@ package server import io.modelcontextprotocol.kotlin.sdk.InitializedNotification import io.modelcontextprotocol.kotlin.sdk.JSONRPCMessage import io.modelcontextprotocol.kotlin.sdk.PingRequest -import kotlinx.coroutines.CompletableDeferred -import kotlinx.coroutines.runBlocking import io.modelcontextprotocol.kotlin.sdk.server.StdioServerTransport -import org.junit.jupiter.api.Assertions.* -import org.junit.jupiter.api.BeforeEach -import org.junit.jupiter.api.Test import io.modelcontextprotocol.kotlin.sdk.shared.ReadBuffer import io.modelcontextprotocol.kotlin.sdk.shared.serializeMessage import io.modelcontextprotocol.kotlin.sdk.toJSON +import kotlinx.coroutines.CompletableDeferred +import kotlinx.coroutines.runBlocking import kotlinx.io.Sink import kotlinx.io.Source import kotlinx.io.asSink import kotlinx.io.asSource import kotlinx.io.buffered -import java.io.* +import org.junit.jupiter.api.BeforeEach +import org.junit.jupiter.api.Test +import java.io.ByteArrayOutputStream +import java.io.PipedInputStream +import java.io.PipedOutputStream +import kotlin.test.assertEquals +import kotlin.test.assertFalse +import kotlin.test.assertTrue class StdioServerTransportTest { private lateinit var input: PipedInputStream diff --git a/src/wasmJsMain/kotlin/io/modelcontextprotocol/kotlin/sdk/internal/utils.wasmJs.kt b/src/wasmJsMain/kotlin/io/modelcontextprotocol/kotlin/sdk/internal/utils.wasmJs.kt index 1ecad771..b84cbeac 100644 --- a/src/wasmJsMain/kotlin/io/modelcontextprotocol/kotlin/sdk/internal/utils.wasmJs.kt +++ b/src/wasmJsMain/kotlin/io/modelcontextprotocol/kotlin/sdk/internal/utils.wasmJs.kt @@ -4,4 +4,4 @@ import kotlinx.coroutines.CoroutineDispatcher import kotlinx.coroutines.Dispatchers internal actual val IODispatcher: CoroutineDispatcher - get() = Dispatchers.Default \ No newline at end of file + get() = Dispatchers.Default