Skip to content

Commit dc908e0

Browse files
committed
Fix websocket ktor server implementation, add test and logs
1 parent 2413805 commit dc908e0

File tree

14 files changed

+572
-287
lines changed

14 files changed

+572
-287
lines changed

src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/Client.kt

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -158,14 +158,13 @@ public open class Client(
158158
serverVersion = result.serverInfo
159159

160160
notification(InitializedNotification())
161+
} catch (error: CancellationException) {
162+
throw IllegalStateException("Error connecting to transport: ${error.message}")
161163
} catch (error: Throwable) {
164+
logger.error(error) { "Failed to initialize client" }
162165
close()
163-
if (error !is CancellationException) {
164-
throw IllegalStateException("Error connecting to transport: ${error.message}")
165-
}
166166

167167
throw error
168-
169168
}
170169
}
171170

src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/WebSocketClientTransport.kt

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,17 @@
11
package io.modelcontextprotocol.kotlin.sdk.client
22

3-
import io.ktor.client.HttpClient
4-
import io.ktor.client.plugins.websocket.webSocketSession
5-
import io.ktor.client.request.HttpRequestBuilder
6-
import io.ktor.client.request.header
7-
import io.ktor.http.HttpHeaders
8-
import io.ktor.websocket.WebSocketSession
3+
import io.github.oshai.kotlinlogging.KotlinLogging
4+
import io.ktor.client.*
5+
import io.ktor.client.plugins.websocket.*
6+
import io.ktor.client.request.*
7+
import io.ktor.http.*
8+
import io.ktor.websocket.*
99
import io.modelcontextprotocol.kotlin.sdk.shared.MCP_SUBPROTOCOL
1010
import io.modelcontextprotocol.kotlin.sdk.shared.WebSocketMcpTransport
1111
import kotlin.properties.Delegates
1212

13+
private val logger = KotlinLogging.logger {}
14+
1315
/**
1416
* Client transport for WebSocket: this will connect to a server over the WebSocket protocol.
1517
*/
@@ -21,6 +23,8 @@ public class WebSocketClientTransport(
2123
override var session: WebSocketSession by Delegates.notNull()
2224

2325
override suspend fun initializeSession() {
26+
logger.debug { "Websocket session initialization started..." }
27+
2428
session = urlString?.let {
2529
client.webSocketSession(it) {
2630
requestBuilder()
@@ -32,5 +36,7 @@ public class WebSocketClientTransport(
3236

3337
header(HttpHeaders.SecWebSocketProtocol, MCP_SUBPROTOCOL)
3438
}
39+
40+
logger.debug { "Websocket session initialization finished" }
3541
}
3642
}

src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/WebSocketMcpKtorClientExtensions.kt

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,15 @@
11
package io.modelcontextprotocol.kotlin.sdk.client
22

3+
import io.github.oshai.kotlinlogging.KotlinLogging
34
import io.ktor.client.HttpClient
45
import io.ktor.client.request.HttpRequestBuilder
56
import io.modelcontextprotocol.kotlin.sdk.Implementation
67
import io.modelcontextprotocol.kotlin.sdk.LIB_VERSION
78
import io.modelcontextprotocol.kotlin.sdk.shared.IMPLEMENTATION_NAME
89

10+
private val logger = KotlinLogging.logger {}
11+
12+
913
/**
1014
* Returns a new WebSocket transport for the Model Context Protocol using the provided HttpClient.
1115
*
@@ -36,6 +40,8 @@ public suspend fun HttpClient.mcpWebSocket(
3640
version = LIB_VERSION
3741
)
3842
)
43+
logger.debug { "Client started to connect to server" }
3944
client.connect(transport)
45+
logger.debug { "Client finished to connect to server" }
4046
return client
4147
}

src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/KtorServer.kt

Lines changed: 25 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@ import io.ktor.server.response.respond
88
import io.ktor.server.routing.Routing
99
import io.ktor.server.routing.RoutingContext
1010
import io.ktor.server.routing.post
11-
import io.ktor.server.routing.route
1211
import io.ktor.server.routing.routing
1312
import io.ktor.server.sse.SSE
1413
import io.ktor.server.sse.ServerSSESession
@@ -18,39 +17,40 @@ import kotlinx.atomicfu.AtomicRef
1817
import kotlinx.atomicfu.atomic
1918
import kotlinx.atomicfu.update
2019
import kotlinx.collections.immutable.PersistentMap
21-
import kotlinx.collections.immutable.persistentMapOf
20+
import kotlinx.collections.immutable.toPersistentMap
2221

2322
private val logger = KotlinLogging.logger {}
2423

25-
@KtorDsl
26-
public fun Routing.mcp(path: String, block: () -> Server) {
27-
route(path) {
28-
mcp(block)
24+
internal class SseTransportManager(transports: Map<String, SseServerTransport> = emptyMap()) {
25+
private val transports: AtomicRef<PersistentMap<String, SseServerTransport>> = atomic(transports.toPersistentMap())
26+
27+
fun getTransport(sessionId: String): SseServerTransport? = transports.value[sessionId]
28+
29+
fun addTransport(transport: SseServerTransport) {
30+
transports.update { it.put(transport.sessionId, transport) }
31+
}
32+
33+
fun removeTransport(sessionId: String) {
34+
transports.update { it.remove(sessionId) }
2935
}
3036
}
3137

32-
/**
33-
* Configures the Ktor Application to handle Model Context Protocol (MCP) over Server-Sent Events (SSE).
34-
*/
38+
/*
39+
* Configures the Ktor Application to handle Model Context Protocol (MCP) over Server-Sent Events (SSE).
40+
*/
3541
@KtorDsl
3642
public fun Routing.mcp(block: () -> Server) {
37-
val transports = atomic(persistentMapOf<String, SseServerTransport>())
43+
val sseTransportManager = SseTransportManager()
3844

3945
sse {
40-
mcpSseEndpoint("", transports, block)
46+
mcpSseEndpoint("", sseTransportManager, block)
4147
}
4248

4349
post {
44-
mcpPostEndpoint(transports)
50+
mcpPostEndpoint(sseTransportManager)
4551
}
4652
}
4753

48-
@Suppress("FunctionName")
49-
@Deprecated("Use mcp() instead", ReplaceWith("mcp(block)"), DeprecationLevel.WARNING)
50-
public fun Application.MCP(block: () -> Server) {
51-
mcp(block)
52-
}
53-
5454
@KtorDsl
5555
public fun Application.mcp(block: () -> Server) {
5656
install(SSE)
@@ -62,16 +62,16 @@ public fun Application.mcp(block: () -> Server) {
6262

6363
internal suspend fun ServerSSESession.mcpSseEndpoint(
6464
postEndpoint: String,
65-
transports: AtomicRef<PersistentMap<String, SseServerTransport>>,
65+
sseTransportManager: SseTransportManager,
6666
block: () -> Server,
6767
) {
68-
val transport = mcpSseTransport(postEndpoint, transports)
68+
val transport = mcpSseTransport(postEndpoint, sseTransportManager)
6969

7070
val server = block()
7171

7272
server.onClose {
7373
logger.info { "Server connection closed for sessionId: ${transport.sessionId}" }
74-
transports.update { it.remove(transport.sessionId) }
74+
sseTransportManager.removeTransport(transport.sessionId)
7575
}
7676

7777
server.connectSession(transport)
@@ -81,17 +81,17 @@ internal suspend fun ServerSSESession.mcpSseEndpoint(
8181

8282
internal fun ServerSSESession.mcpSseTransport(
8383
postEndpoint: String,
84-
transports: AtomicRef<PersistentMap<String, SseServerTransport>>,
84+
sseTransportManager: SseTransportManager,
8585
): SseServerTransport {
8686
val transport = SseServerTransport(postEndpoint, this)
87-
transports.update { it.put(transport.sessionId, transport) }
87+
sseTransportManager.addTransport(transport)
8888
logger.info { "New SSE connection established and stored with sessionId: ${transport.sessionId}" }
8989

9090
return transport
9191
}
9292

9393
internal suspend fun RoutingContext.mcpPostEndpoint(
94-
transports: AtomicRef<PersistentMap<String, SseServerTransport>>,
94+
sseTransportManager: SseTransportManager,
9595
) {
9696
val sessionId: String = call.request.queryParameters["sessionId"]
9797
?: run {
@@ -101,7 +101,7 @@ internal suspend fun RoutingContext.mcpPostEndpoint(
101101

102102
logger.debug { "Received message for sessionId: $sessionId" }
103103

104-
val transport = transports.value[sessionId]
104+
val transport = sseTransportManager.getTransport(sessionId)
105105
if (transport == null) {
106106
logger.warn { "Session not found for sessionId: $sessionId" }
107107
call.respond(HttpStatusCode.NotFound, "Session not found")

src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/Server.kt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,9 @@ public open class Server(
191191
}
192192
}
193193

194+
logger.debug { "Server session connecting to transport" }
194195
session.connect(transport)
196+
logger.debug { "Server session successfully connected to transport" }
195197
sessions.update { it.add(session) }
196198

197199
_onConnect()

src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/ServerSession.kt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ public open class ServerSession(
9292
* Called when the server session is closing.
9393
*/
9494
override fun onClose() {
95-
logger.info { "Server connection closing" }
95+
logger.debug { "Server connection closing" }
9696
_onClose()
9797
}
9898

@@ -322,7 +322,7 @@ public open class ServerSession(
322322
}
323323

324324
private suspend fun handleInitialize(request: InitializeRequest): InitializeResult {
325-
logger.info { "Handling initialize request from client ${request.clientInfo}" }
325+
logger.debug { "Handling initialization request from client" }
326326
clientCapabilities = request.capabilities
327327
clientVersion = request.clientInfo
328328

0 commit comments

Comments
 (0)