@@ -13,8 +13,12 @@ import io.ktor.server.routing.routing
13
13
import io.ktor.server.sse.SSE
14
14
import io.ktor.server.sse.ServerSSESession
15
15
import io.ktor.server.sse.sse
16
- import io.ktor.util.collections.ConcurrentMap
17
16
import io.ktor.utils.io.KtorDsl
17
+ import kotlinx.atomicfu.AtomicRef
18
+ import kotlinx.atomicfu.atomic
19
+ import kotlinx.atomicfu.update
20
+ import kotlinx.collections.immutable.PersistentMap
21
+ import kotlinx.collections.immutable.persistentMapOf
18
22
19
23
private val logger = KotlinLogging .logger {}
20
24
@@ -30,7 +34,7 @@ public fun Routing.mcp(path: String, block: () -> Server) {
30
34
*/
31
35
@KtorDsl
32
36
public fun Routing.mcp (block : () -> Server ) {
33
- val transports = ConcurrentMap <String , SseServerTransport >()
37
+ val transports = atomic(persistentMapOf <String , SseServerTransport >() )
34
38
35
39
sse {
36
40
mcpSseEndpoint(" " , transports, block)
@@ -49,24 +53,16 @@ public fun Application.MCP(block: () -> Server) {
49
53
50
54
@KtorDsl
51
55
public fun Application.mcp (block : () -> Server ) {
52
- val transports = ConcurrentMap <String , SseServerTransport >()
53
-
54
56
install(SSE )
55
57
56
58
routing {
57
- sse(" /sse" ) {
58
- mcpSseEndpoint(" /message" , transports, block)
59
- }
60
-
61
- post(" /message" ) {
62
- mcpPostEndpoint(transports)
63
- }
59
+ mcp(block)
64
60
}
65
61
}
66
62
67
- private suspend fun ServerSSESession.mcpSseEndpoint (
63
+ internal suspend fun ServerSSESession.mcpSseEndpoint (
68
64
postEndpoint : String ,
69
- transports : ConcurrentMap < String , SseServerTransport >,
65
+ transports : AtomicRef < PersistentMap < String , SseServerTransport > >,
70
66
block : () -> Server ,
71
67
) {
72
68
val transport = mcpSseTransport(postEndpoint, transports)
@@ -75,27 +71,27 @@ private suspend fun ServerSSESession.mcpSseEndpoint(
75
71
76
72
server.onClose {
77
73
logger.info { " Server connection closed for sessionId: ${transport.sessionId} " }
78
- transports.remove(transport.sessionId)
74
+ transports.update { it. remove(transport.sessionId) }
79
75
}
80
76
81
- server.connect(transport)
77
+ server.connectSession(transport)
78
+
82
79
logger.debug { " Server connected to transport for sessionId: ${transport.sessionId} " }
83
80
}
84
81
85
82
internal fun ServerSSESession.mcpSseTransport (
86
83
postEndpoint : String ,
87
- transports : ConcurrentMap < String , SseServerTransport >,
84
+ transports : AtomicRef < PersistentMap < String , SseServerTransport > >,
88
85
): SseServerTransport {
89
86
val transport = SseServerTransport (postEndpoint, this )
90
- transports[transport.sessionId] = transport
91
-
87
+ transports.update { it.put(transport.sessionId, transport) }
92
88
logger.info { " New SSE connection established and stored with sessionId: ${transport.sessionId} " }
93
89
94
90
return transport
95
91
}
96
92
97
93
internal suspend fun RoutingContext.mcpPostEndpoint (
98
- transports : ConcurrentMap < String , SseServerTransport >,
94
+ transports : AtomicRef < PersistentMap < String , SseServerTransport > >,
99
95
) {
100
96
val sessionId: String = call.request.queryParameters[" sessionId" ]
101
97
? : run {
@@ -105,7 +101,7 @@ internal suspend fun RoutingContext.mcpPostEndpoint(
105
101
106
102
logger.debug { " Received message for sessionId: $sessionId " }
107
103
108
- val transport = transports[sessionId]
104
+ val transport = transports.value [sessionId]
109
105
if (transport == null ) {
110
106
logger.warn { " Session not found for sessionId: $sessionId " }
111
107
call.respond(HttpStatusCode .NotFound , " Session not found" )
0 commit comments