Skip to content

Commit d50c661

Browse files
committed
Fix websocket ktor server implementation, add test and logs
1 parent deb2b9d commit d50c661

File tree

13 files changed

+908
-88
lines changed

13 files changed

+908
-88
lines changed

kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/Client.kt

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -158,14 +158,13 @@ public open class Client(
158158
serverVersion = result.serverInfo
159159

160160
notification(InitializedNotification())
161+
} catch (error: CancellationException) {
162+
throw IllegalStateException("Error connecting to transport: ${error.message}")
161163
} catch (error: Throwable) {
164+
logger.error(error) { "Failed to initialize client" }
162165
close()
163-
if (error !is CancellationException) {
164-
throw IllegalStateException("Error connecting to transport: ${error.message}")
165-
}
166166

167167
throw error
168-
169168
}
170169
}
171170

kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/WebSocketClientTransport.kt

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,17 @@
11
package io.modelcontextprotocol.kotlin.sdk.client
22

3-
import io.ktor.client.HttpClient
4-
import io.ktor.client.plugins.websocket.webSocketSession
5-
import io.ktor.client.request.HttpRequestBuilder
6-
import io.ktor.client.request.header
7-
import io.ktor.http.HttpHeaders
8-
import io.ktor.websocket.WebSocketSession
3+
import io.github.oshai.kotlinlogging.KotlinLogging
4+
import io.ktor.client.*
5+
import io.ktor.client.plugins.websocket.*
6+
import io.ktor.client.request.*
7+
import io.ktor.http.*
8+
import io.ktor.websocket.*
99
import io.modelcontextprotocol.kotlin.sdk.shared.MCP_SUBPROTOCOL
1010
import io.modelcontextprotocol.kotlin.sdk.shared.WebSocketMcpTransport
1111
import kotlin.properties.Delegates
1212

13+
private val logger = KotlinLogging.logger {}
14+
1315
/**
1416
* Client transport for WebSocket: this will connect to a server over the WebSocket protocol.
1517
*/
@@ -21,6 +23,8 @@ public class WebSocketClientTransport(
2123
override var session: WebSocketSession by Delegates.notNull()
2224

2325
override suspend fun initializeSession() {
26+
logger.debug { "Websocket session initialization started..." }
27+
2428
session = urlString?.let {
2529
client.webSocketSession(it) {
2630
requestBuilder()
@@ -32,5 +36,7 @@ public class WebSocketClientTransport(
3236

3337
header(HttpHeaders.SecWebSocketProtocol, MCP_SUBPROTOCOL)
3438
}
39+
40+
logger.debug { "Websocket session initialization finished" }
3541
}
3642
}

kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/WebSocketMcpKtorClientExtensions.kt

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,15 @@
11
package io.modelcontextprotocol.kotlin.sdk.client
22

3+
import io.github.oshai.kotlinlogging.KotlinLogging
34
import io.ktor.client.HttpClient
45
import io.ktor.client.request.HttpRequestBuilder
56
import io.modelcontextprotocol.kotlin.sdk.Implementation
67
import io.modelcontextprotocol.kotlin.sdk.LIB_VERSION
78
import io.modelcontextprotocol.kotlin.sdk.shared.IMPLEMENTATION_NAME
89

10+
private val logger = KotlinLogging.logger {}
11+
12+
913
/**
1014
* Returns a new WebSocket transport for the Model Context Protocol using the provided HttpClient.
1115
*
@@ -36,6 +40,8 @@ public suspend fun HttpClient.mcpWebSocket(
3640
version = LIB_VERSION
3741
)
3842
)
43+
logger.debug { "Client started to connect to server" }
3944
client.connect(transport)
45+
logger.debug { "Client finished to connect to server" }
4046
return client
4147
}

kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/Protocol.kt

Lines changed: 20 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ import kotlin.reflect.typeOf
4242
import kotlin.time.Duration
4343
import kotlin.time.Duration.Companion.milliseconds
4444

45-
private val LOGGER = KotlinLogging.logger { }
45+
private val logger = KotlinLogging.logger { }
4646

4747
public const val IMPLEMENTATION_NAME: String = "mcp-ktor"
4848

@@ -204,6 +204,7 @@ public abstract class Protocol(
204204
}
205205
}
206206

207+
logger.info { "Starting transport" }
207208
return transport.start()
208209
}
209210

@@ -221,29 +222,29 @@ public abstract class Protocol(
221222
}
222223

223224
private suspend fun onNotification(notification: JSONRPCNotification) {
224-
LOGGER.trace { "Received notification: ${notification.method}" }
225+
logger.trace { "Received notification: ${notification.method}" }
225226

226227
val handler = notificationHandlers[notification.method] ?: fallbackNotificationHandler
227228

228229
if (handler == null) {
229-
LOGGER.trace { "No handler found for notification: ${notification.method}" }
230+
logger.trace { "No handler found for notification: ${notification.method}" }
230231
return
231232
}
232233
try {
233234
handler(notification)
234235
} catch (cause: Throwable) {
235-
LOGGER.error(cause) { "Error handling notification: ${notification.method}" }
236+
logger.error(cause) { "Error handling notification: ${notification.method}" }
236237
onError(cause)
237238
}
238239
}
239240

240241
private suspend fun onRequest(request: JSONRPCRequest) {
241-
LOGGER.trace { "Received request: ${request.method} (id: ${request.id})" }
242+
logger.trace { "Received request: ${request.method} (id: ${request.id})" }
242243

243244
val handler = requestHandlers[request.method] ?: fallbackRequestHandler
244245

245246
if (handler === null) {
246-
LOGGER.trace { "No handler found for request: ${request.method}" }
247+
logger.trace { "No handler found for request: ${request.method}" }
247248
try {
248249
transport?.send(
249250
JSONRPCResponse(
@@ -255,15 +256,15 @@ public abstract class Protocol(
255256
)
256257
)
257258
} catch (cause: Throwable) {
258-
LOGGER.error(cause) { "Error sending method not found response" }
259+
logger.error(cause) { "Error sending method not found response" }
259260
onError(cause)
260261
}
261262
return
262263
}
263264

264265
try {
265266
val result = handler(request, RequestHandlerExtra())
266-
LOGGER.trace { "Request handled successfully: ${request.method} (id: ${request.id})" }
267+
logger.trace { "Request handled successfully: ${request.method} (id: ${request.id})" }
267268

268269
transport?.send(
269270
JSONRPCResponse(
@@ -273,7 +274,7 @@ public abstract class Protocol(
273274
)
274275

275276
} catch (cause: Throwable) {
276-
LOGGER.error(cause) { "Error handling request: ${request.method} (id: ${request.id})" }
277+
logger.error(cause) { "Error handling request: ${request.method} (id: ${request.id})" }
277278

278279
try {
279280
transport?.send(
@@ -286,14 +287,14 @@ public abstract class Protocol(
286287
)
287288
)
288289
} catch (sendError: Throwable) {
289-
LOGGER.error(sendError) { "Failed to send error response for request: ${request.method} (id: ${request.id})" }
290+
logger.error(sendError) { "Failed to send error response for request: ${request.method} (id: ${request.id})" }
290291
// Optionally implement fallback behavior here
291292
}
292293
}
293294
}
294295

295296
private fun onProgress(notification: ProgressNotification) {
296-
LOGGER.trace { "Received progress notification: token=${notification.params.progressToken}, progress=${notification.params.progress}/${notification.params.total}" }
297+
logger.trace { "Received progress notification: token=${notification.params.progressToken}, progress=${notification.params.progress}/${notification.params.total}" }
297298
val progress = notification.params.progress
298299
val total = notification.params.total
299300
val message = notification.params.message
@@ -304,7 +305,7 @@ public abstract class Protocol(
304305
val error = Error(
305306
"Received a progress notification for an unknown token: ${McpJson.encodeToString(notification)}",
306307
)
307-
LOGGER.error { error.message }
308+
logger.error { error.message }
308309
onError(error)
309310
return
310311
}
@@ -382,9 +383,9 @@ public abstract class Protocol(
382383
request: Request,
383384
options: RequestOptions? = null,
384385
): T {
385-
LOGGER.trace { "Sending request: ${request.method}" }
386+
logger.trace { "Sending request: ${request.method}" }
386387
val result = CompletableDeferred<T>()
387-
val transport = this@Protocol.transport ?: throw Error("Not connected")
388+
val transport = transport ?: throw Error("Not connected")
388389

389390
if (this@Protocol.options?.enforceStrictCapabilities == true) {
390391
assertCapabilityForMethod(request.method)
@@ -394,7 +395,7 @@ public abstract class Protocol(
394395
val messageId = message.id
395396

396397
if (options?.onProgress != null) {
397-
LOGGER.trace { "Registering progress handler for request id: $messageId" }
398+
logger.trace { "Registering progress handler for request id: $messageId" }
398399
_progressHandlers.update { current ->
399400
current.put(messageId, options.onProgress)
400401
}
@@ -427,7 +428,7 @@ public abstract class Protocol(
427428

428429
val notification = CancelledNotification(
429430
params = CancelledNotification.Params(
430-
requestId = messageId,
431+
requestId = messageId,
431432
reason = reason.message ?: "Unknown"
432433
)
433434
)
@@ -444,12 +445,12 @@ public abstract class Protocol(
444445
val timeout = options?.timeout ?: DEFAULT_REQUEST_TIMEOUT
445446
try {
446447
withTimeout(timeout) {
447-
LOGGER.trace { "Sending request message with id: $messageId" }
448+
logger.trace { "Sending request message with id: $messageId" }
448449
this@Protocol.transport?.send(message)
449450
}
450451
return result.await()
451452
} catch (cause: TimeoutCancellationException) {
452-
LOGGER.error { "Request timed out after ${timeout.inWholeMilliseconds}ms: ${request.method}" }
453+
logger.error { "Request timed out after ${timeout.inWholeMilliseconds}ms: ${request.method}" }
453454
cancel(
454455
McpError(
455456
ErrorCode.Defined.RequestTimeout.code,
@@ -466,7 +467,7 @@ public abstract class Protocol(
466467
* Emits a notification, which is a one-way message that does not expect a response.
467468
*/
468469
public suspend fun notification(notification: Notification) {
469-
LOGGER.trace { "Sending notification: ${notification.method}" }
470+
logger.trace { "Sending notification: ${notification.method}" }
470471
val transport = this.transport ?: error("Not connected")
471472
assertNotificationCapability(notification.method)
472473

kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/WebSocketMcpTransport.kt

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
package io.modelcontextprotocol.kotlin.sdk.shared
22

3+
import io.github.oshai.kotlinlogging.KotlinLogging
34
import io.ktor.websocket.Frame
45
import io.ktor.websocket.WebSocketSession
56
import io.ktor.websocket.close
@@ -17,6 +18,9 @@ import kotlin.concurrent.atomics.ExperimentalAtomicApi
1718

1819
public const val MCP_SUBPROTOCOL: String = "mcp"
1920

21+
private val logger = KotlinLogging.logger {}
22+
23+
2024
/**
2125
* Abstract class representing a WebSocket transport for the Model Context Protocol (MCP).
2226
* Handles communication over a WebSocket session.
@@ -40,6 +44,8 @@ public abstract class WebSocketMcpTransport : AbstractTransport() {
4044
protected abstract suspend fun initializeSession()
4145

4246
override suspend fun start() {
47+
logger.debug { "Starting websocket transport" }
48+
4349
if (!initialized.compareAndSet(expectedValue = false, newValue = true)) {
4450
error(
4551
"WebSocketClientTransport already started! " +
@@ -53,7 +59,8 @@ public abstract class WebSocketMcpTransport : AbstractTransport() {
5359
while (true) {
5460
val message = try {
5561
session.incoming.receive()
56-
} catch (_: ClosedReceiveChannelException) {
62+
} catch (e: ClosedReceiveChannelException) {
63+
logger.debug { "Closed receive channel, exiting" }
5764
return@launch
5865
}
5966

@@ -84,6 +91,7 @@ public abstract class WebSocketMcpTransport : AbstractTransport() {
8491
}
8592

8693
override suspend fun send(message: JSONRPCMessage) {
94+
logger.debug { "Sending message" }
8795
if (!initialized.load()) {
8896
error("Not connected")
8997
}
@@ -96,6 +104,7 @@ public abstract class WebSocketMcpTransport : AbstractTransport() {
96104
error("Not connected")
97105
}
98106

107+
logger.debug { "Closing websocket session" }
99108
session.close()
100109
session.coroutineContext.job.join()
101110
}

0 commit comments

Comments
 (0)