Skip to content

Commit deb2b9d

Browse files
committed
Introduce server session
1 parent 8698b96 commit deb2b9d

File tree

5 files changed

+260
-119
lines changed

5 files changed

+260
-119
lines changed

kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/SSEClientTransport.kt

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
package io.modelcontextprotocol.kotlin.sdk.client
22

3+
import io.github.oshai.kotlinlogging.KotlinLogging
34
import io.ktor.client.HttpClient
45
import io.ktor.client.plugins.sse.ClientSSESession
56
import io.ktor.client.plugins.sse.sseSession
@@ -46,6 +47,8 @@ public class SseClientTransport(
4647
private val reconnectionTime: Duration? = null,
4748
private val requestBuilder: HttpRequestBuilder.() -> Unit = {},
4849
) : AbstractTransport() {
50+
private val logger = KotlinLogging.logger {}
51+
4952
private val initialized: AtomicBoolean = AtomicBoolean(false)
5053
private val endpoint = CompletableDeferred<String>()
5154

@@ -68,6 +71,7 @@ public class SseClientTransport(
6871
check(initialized.compareAndSet(expectedValue = false, newValue = true)) {
6972
"SSEClientTransport already started! If using Client class, note that connect() calls start() automatically."
7073
}
74+
logger.info { "Starting SseClientTransport..." }
7175

7276
try {
7377
session = urlString?.let {
@@ -111,6 +115,8 @@ public class SseClientTransport(
111115
val text = response.bodyAsText()
112116
error("Error POSTing to endpoint (HTTP ${response.status}): $text")
113117
}
118+
119+
logger.debug { "Client successfully sent message via SSE $endpoint" }
114120
} catch (e: Throwable) {
115121
_onError(e)
116122
throw e
@@ -157,6 +163,7 @@ public class SseClientTransport(
157163
val path = if (eventData.startsWith("/")) eventData.substring(1) else eventData
158164
val endpointUrl = Url("$baseUrl/$path")
159165
endpoint.complete(endpointUrl.toString())
166+
logger.debug { "Client connected to endpoint: $endpointUrl" }
160167
} catch (e: Throwable) {
161168
_onError(e)
162169
endpoint.completeExceptionally(e)

kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/KtorServer.kt

Lines changed: 16 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,12 @@ import io.ktor.server.routing.routing
1313
import io.ktor.server.sse.SSE
1414
import io.ktor.server.sse.ServerSSESession
1515
import io.ktor.server.sse.sse
16-
import io.ktor.util.collections.ConcurrentMap
1716
import io.ktor.utils.io.KtorDsl
17+
import kotlinx.atomicfu.AtomicRef
18+
import kotlinx.atomicfu.atomic
19+
import kotlinx.atomicfu.update
20+
import kotlinx.collections.immutable.PersistentMap
21+
import kotlinx.collections.immutable.persistentMapOf
1822

1923
private val logger = KotlinLogging.logger {}
2024

@@ -30,7 +34,7 @@ public fun Routing.mcp(path: String, block: () -> Server) {
3034
*/
3135
@KtorDsl
3236
public fun Routing.mcp(block: () -> Server) {
33-
val transports = ConcurrentMap<String, SseServerTransport>()
37+
val transports = atomic(persistentMapOf<String, SseServerTransport>())
3438

3539
sse {
3640
mcpSseEndpoint("", transports, block)
@@ -49,24 +53,16 @@ public fun Application.MCP(block: () -> Server) {
4953

5054
@KtorDsl
5155
public fun Application.mcp(block: () -> Server) {
52-
val transports = ConcurrentMap<String, SseServerTransport>()
53-
5456
install(SSE)
5557

5658
routing {
57-
sse("/sse") {
58-
mcpSseEndpoint("/message", transports, block)
59-
}
60-
61-
post("/message") {
62-
mcpPostEndpoint(transports)
63-
}
59+
mcp(block)
6460
}
6561
}
6662

67-
private suspend fun ServerSSESession.mcpSseEndpoint(
63+
internal suspend fun ServerSSESession.mcpSseEndpoint(
6864
postEndpoint: String,
69-
transports: ConcurrentMap<String, SseServerTransport>,
65+
transports: AtomicRef<PersistentMap<String, SseServerTransport>>,
7066
block: () -> Server,
7167
) {
7268
val transport = mcpSseTransport(postEndpoint, transports)
@@ -75,27 +71,27 @@ private suspend fun ServerSSESession.mcpSseEndpoint(
7571

7672
server.onClose {
7773
logger.info { "Server connection closed for sessionId: ${transport.sessionId}" }
78-
transports.remove(transport.sessionId)
74+
transports.update { it.remove(transport.sessionId) }
7975
}
8076

81-
server.connect(transport)
77+
server.connectSession(transport)
78+
8279
logger.debug { "Server connected to transport for sessionId: ${transport.sessionId}" }
8380
}
8481

8582
internal fun ServerSSESession.mcpSseTransport(
8683
postEndpoint: String,
87-
transports: ConcurrentMap<String, SseServerTransport>,
84+
transports: AtomicRef<PersistentMap<String, SseServerTransport>>,
8885
): SseServerTransport {
8986
val transport = SseServerTransport(postEndpoint, this)
90-
transports[transport.sessionId] = transport
91-
87+
transports.update { it.put(transport.sessionId, transport) }
9288
logger.info { "New SSE connection established and stored with sessionId: ${transport.sessionId}" }
9389

9490
return transport
9591
}
9692

9793
internal suspend fun RoutingContext.mcpPostEndpoint(
98-
transports: ConcurrentMap<String, SseServerTransport>,
94+
transports: AtomicRef<PersistentMap<String, SseServerTransport>>,
9995
) {
10096
val sessionId: String = call.request.queryParameters["sessionId"]
10197
?: run {
@@ -105,7 +101,7 @@ internal suspend fun RoutingContext.mcpPostEndpoint(
105101

106102
logger.debug { "Received message for sessionId: $sessionId" }
107103

108-
val transport = transports[sessionId]
104+
val transport = transports.value[sessionId]
109105
if (transport == null) {
110106
logger.warn { "Session not found for sessionId: $sessionId" }
111107
call.respond(HttpStatusCode.NotFound, "Session not found")

0 commit comments

Comments
 (0)