Skip to content

Commit bd3cc4e

Browse files
devcrocode5lCopilot
authored
refactor SseClientTransport (#142)
* refactor `SseClientTransport` * Apply suggestion from @Copilot Co-authored-by: Copilot <[email protected]> * replace `Exception` with `Throwable` in `SseClientTransport` error handling and try/catch `cancel` in closeResources --------- Co-authored-by: Leonid Stashevsky <[email protected]> Co-authored-by: Copilot <[email protected]>
1 parent e9eb109 commit bd3cc4e

File tree

3 files changed

+141
-97
lines changed

3 files changed

+141
-97
lines changed

src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/SSEClientTransport.kt

Lines changed: 101 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ import io.ktor.http.protocolWithAuthority
1616
import io.modelcontextprotocol.kotlin.sdk.JSONRPCMessage
1717
import io.modelcontextprotocol.kotlin.sdk.shared.AbstractTransport
1818
import io.modelcontextprotocol.kotlin.sdk.shared.McpJson
19+
import kotlinx.coroutines.CancellationException
1920
import kotlinx.coroutines.CompletableDeferred
2021
import kotlinx.coroutines.CoroutineName
2122
import kotlinx.coroutines.CoroutineScope
@@ -24,10 +25,11 @@ import kotlinx.coroutines.Job
2425
import kotlinx.coroutines.SupervisorJob
2526
import kotlinx.coroutines.cancel
2627
import kotlinx.coroutines.cancelAndJoin
28+
import kotlinx.coroutines.ensureActive
2729
import kotlinx.coroutines.launch
30+
import kotlinx.serialization.SerializationException
2831
import kotlin.concurrent.atomics.AtomicBoolean
2932
import kotlin.concurrent.atomics.ExperimentalAtomicApi
30-
import kotlin.properties.Delegates
3133
import kotlin.time.Duration
3234

3335
@Deprecated("Use SseClientTransport instead", ReplaceWith("SseClientTransport"), DeprecationLevel.WARNING)
@@ -44,97 +46,59 @@ public class SseClientTransport(
4446
private val reconnectionTime: Duration? = null,
4547
private val requestBuilder: HttpRequestBuilder.() -> Unit = {},
4648
) : AbstractTransport() {
47-
private val scope by lazy {
48-
CoroutineScope(session.coroutineContext + SupervisorJob())
49-
}
50-
5149
private val initialized: AtomicBoolean = AtomicBoolean(false)
52-
private var session: ClientSSESession by Delegates.notNull()
5350
private val endpoint = CompletableDeferred<String>()
5451

52+
private lateinit var session: ClientSSESession
53+
private lateinit var scope: CoroutineScope
5554
private var job: Job? = null
5655

57-
private val baseUrl by lazy {
58-
val requestUrl = session.call.request.url.toString()
59-
val url = Url(requestUrl)
60-
var path = url.encodedPath
61-
if (path.isEmpty()) {
62-
url.protocolWithAuthority
63-
} else if (path.endsWith("/")) {
64-
url.protocolWithAuthority + path.removeSuffix("/")
65-
} else {
66-
// the last item is not a directory, so will not be taken into account
67-
path = path.substring(0, path.lastIndexOf("/"))
68-
url.protocolWithAuthority + path
56+
private val baseUrl: String by lazy {
57+
session.call.request.url.let { url ->
58+
val path = url.encodedPath
59+
when {
60+
path.isEmpty() -> url.protocolWithAuthority
61+
path.endsWith("/") -> url.protocolWithAuthority + path.removeSuffix("/")
62+
else -> url.protocolWithAuthority + path.take(path.lastIndexOf("/"))
63+
}
6964
}
7065
}
7166

7267
override suspend fun start() {
73-
if (!initialized.compareAndSet(expectedValue = false, newValue = true)) {
74-
error(
75-
"SSEClientTransport already started! " +
76-
"If using Client class, note that connect() calls start() automatically.",
77-
)
68+
check(initialized.compareAndSet(expectedValue = false, newValue = true)) {
69+
"SSEClientTransport already started! If using Client class, note that connect() calls start() automatically."
7870
}
7971

80-
session = urlString?.let {
81-
client.sseSession(
82-
urlString = it,
72+
try {
73+
session = urlString?.let {
74+
client.sseSession(
75+
urlString = it,
76+
reconnectionTime = reconnectionTime,
77+
block = requestBuilder,
78+
)
79+
} ?: client.sseSession(
8380
reconnectionTime = reconnectionTime,
8481
block = requestBuilder,
8582
)
86-
} ?: client.sseSession(
87-
reconnectionTime = reconnectionTime,
88-
block = requestBuilder,
89-
)
90-
91-
job = scope.launch(CoroutineName("SseMcpClientTransport.collect#${hashCode()}")) {
92-
session.incoming.collect { event ->
93-
when (event.event) {
94-
"error" -> {
95-
val e = IllegalStateException("SSE error: ${event.data}")
96-
_onError(e)
97-
throw e
98-
}
99-
100-
"open" -> {
101-
// The connection is open, but we need to wait for the endpoint to be received.
102-
}
103-
104-
"endpoint" -> {
105-
try {
106-
val eventData = event.data ?: ""
107-
108-
// check url correctness
109-
val maybeEndpoint = Url("$baseUrl/${if (eventData.startsWith("/")) eventData.substring(1) else eventData}")
110-
endpoint.complete(maybeEndpoint.toString())
111-
} catch (e: Exception) {
112-
_onError(e)
113-
close()
114-
error(e)
115-
}
116-
}
83+
scope = CoroutineScope(session.coroutineContext + SupervisorJob())
11784

118-
else -> {
119-
try {
120-
val message = McpJson.decodeFromString<JSONRPCMessage>(event.data ?: "")
121-
_onMessage(message)
122-
} catch (e: Exception) {
123-
_onError(e)
124-
}
125-
}
126-
}
85+
job = scope.launch(CoroutineName("SseMcpClientTransport.connect#${hashCode()}")) {
86+
collectMessages()
12787
}
128-
}
12988

130-
endpoint.await()
89+
endpoint.await()
90+
} catch (e: Exception) {
91+
closeResources()
92+
initialized.store(false)
93+
throw e
94+
}
13195
}
13296

13397
@OptIn(ExperimentalCoroutinesApi::class)
13498
override suspend fun send(message: JSONRPCMessage) {
135-
if (!endpoint.isCompleted) {
136-
error("Not connected")
137-
}
99+
check(initialized.load()) { "SseClientTransport is not initialized!" }
100+
check(job?.isActive == true) { "SseClientTransport is closed!" }
101+
check(endpoint.isCompleted) { "Not connected!" }
138102

139103
try {
140104
val response = client.post(endpoint.getCompleted()) {
@@ -147,19 +111,80 @@ public class SseClientTransport(
147111
val text = response.bodyAsText()
148112
error("Error POSTing to endpoint (HTTP ${response.status}): $text")
149113
}
150-
} catch (e: Exception) {
114+
} catch (e: Throwable) {
151115
_onError(e)
152116
throw e
153117
}
154118
}
155119

156120
override suspend fun close() {
157-
if (!initialized.load()) {
158-
error("SSEClientTransport is not initialized!")
121+
check(initialized.load()) { "SseClientTransport is not initialized!" }
122+
closeResources()
123+
}
124+
125+
private suspend fun CoroutineScope.collectMessages() {
126+
try {
127+
session.incoming.collect { event ->
128+
ensureActive()
129+
130+
when (event.event) {
131+
"error" -> {
132+
val error = IllegalStateException("SSE error: ${event.data}")
133+
_onError(error)
134+
throw error
135+
}
136+
137+
"open" -> {
138+
// The connection is open, but we need to wait for the endpoint to be received.
139+
}
140+
141+
"endpoint" -> handleEndpoint(event.data.orEmpty())
142+
else -> handleMessage(event.data.orEmpty())
143+
}
144+
}
145+
} catch (e: CancellationException) {
146+
throw e
147+
} catch (e: Throwable) {
148+
_onError(e)
149+
throw e
150+
} finally {
151+
closeResources()
159152
}
153+
}
154+
155+
private fun handleEndpoint(eventData: String) {
156+
try {
157+
val path = if (eventData.startsWith("/")) eventData.substring(1) else eventData
158+
val endpointUrl = Url("$baseUrl/$path")
159+
endpoint.complete(endpointUrl.toString())
160+
} catch (e: Throwable) {
161+
_onError(e)
162+
endpoint.completeExceptionally(e)
163+
throw e
164+
}
165+
}
166+
167+
private suspend fun handleMessage(data: String) {
168+
try {
169+
val message = McpJson.decodeFromString<JSONRPCMessage>(data)
170+
_onMessage(message)
171+
} catch (e: SerializationException) {
172+
_onError(e)
173+
}
174+
}
175+
176+
private suspend fun closeResources() {
177+
if (!initialized.compareAndSet(expectedValue = true, newValue = false)) return
160178

161-
session.cancel()
162-
_onClose()
163179
job?.cancelAndJoin()
180+
try {
181+
if (::session.isInitialized) session.cancel()
182+
if (::scope.isInitialized) scope.cancel()
183+
endpoint.cancel()
184+
} catch (e: Throwable) {
185+
_onError(e)
186+
}
187+
188+
_onClose()
164189
}
165190
}

src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/SseTransportTest.kt

Lines changed: 31 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import io.ktor.client.HttpClient
44
import io.ktor.client.plugins.sse.SSE
55
import io.ktor.server.application.install
66
import io.ktor.server.cio.CIO
7+
import io.ktor.server.engine.EmbeddedServer
78
import io.ktor.server.engine.embeddedServer
89
import io.ktor.server.routing.post
910
import io.ktor.server.routing.route
@@ -17,10 +18,12 @@ import kotlinx.coroutines.test.runTest
1718
import kotlin.test.Test
1819

1920
class SseTransportTest : BaseTransportTest() {
21+
22+
private suspend fun EmbeddedServer<*, *>.actualPort() = engine.resolvedConnectors().first().port
23+
2024
@Test
2125
fun `should start then close cleanly`() = runTest {
22-
val port = 8080
23-
val server = embeddedServer(CIO, port = port) {
26+
val server = embeddedServer(CIO, port = 0) {
2427
install(io.ktor.server.sse.SSE)
2528
val transports = ConcurrentMap<String, SseServerTransport>()
2629
routing {
@@ -34,24 +37,27 @@ class SseTransportTest : BaseTransportTest() {
3437
}
3538
}.startSuspend(wait = false)
3639

40+
val actualPort = server.actualPort()
41+
3742
val client = HttpClient {
3843
install(SSE)
3944
}.mcpSseTransport {
4045
url {
4146
host = "localhost"
42-
this.port = port
47+
this.port = actualPort
4348
}
4449
}
4550

46-
testClientOpenClose(client)
47-
48-
server.stopSuspend()
51+
try {
52+
testClientOpenClose(client)
53+
} finally {
54+
server.stopSuspend()
55+
}
4956
}
5057

5158
@Test
5259
fun `should read messages`() = runTest {
53-
val port = 3003
54-
val server = embeddedServer(CIO, port = port) {
60+
val server = embeddedServer(CIO, port = 0) {
5561
install(io.ktor.server.sse.SSE)
5662
val transports = ConcurrentMap<String, SseServerTransport>()
5763
routing {
@@ -71,23 +77,27 @@ class SseTransportTest : BaseTransportTest() {
7177
}
7278
}.startSuspend(wait = false)
7379

80+
val actualPort = server.actualPort()
81+
7482
val client = HttpClient {
7583
install(SSE)
7684
}.mcpSseTransport {
7785
url {
7886
host = "localhost"
79-
this.port = port
87+
this.port = actualPort
8088
}
8189
}
8290

83-
testClientRead(client)
84-
server.stopSuspend()
91+
try {
92+
testClientRead(client)
93+
} finally {
94+
server.stopSuspend()
95+
}
8596
}
8697

8798
@Test
8899
fun `test sse path not root path`() = runTest {
89-
val port = 3007
90-
val server = embeddedServer(CIO, port = port) {
100+
val server = embeddedServer(CIO, port = 0) {
91101
install(io.ktor.server.sse.SSE)
92102
val transports = ConcurrentMap<String, SseServerTransport>()
93103
routing {
@@ -109,17 +119,22 @@ class SseTransportTest : BaseTransportTest() {
109119
}
110120
}.startSuspend(wait = false)
111121

122+
val actualPort = server.actualPort()
123+
112124
val client = HttpClient {
113125
install(SSE)
114126
}.mcpSseTransport {
115127
url {
116128
host = "localhost"
117-
this.port = port
129+
this.port = actualPort
118130
pathSegments = listOf("sse")
119131
}
120132
}
121133

122-
testClientRead(client)
123-
server.stopSuspend()
134+
try {
135+
testClientRead(client)
136+
} finally {
137+
server.stopSuspend()
138+
}
124139
}
125140
}

src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/SseIntegrationTest.kt

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,15 @@ class SseIntegrationTest {
2626
@Test
2727
fun `client should be able to connect to sse server`() = runTest {
2828
val serverEngine = initServer()
29+
var client: Client? = null
2930
try {
3031
withContext(Dispatchers.Default) {
31-
assertDoesNotThrow { initClient() }
32+
assertDoesNotThrow { client = initClient() }
3233
}
34+
} catch (e: Exception) {
35+
fail("Failed to connect client: $e")
3336
} finally {
37+
client?.close()
3438
// Make sure to stop the server
3539
serverEngine.stopSuspend(1000, 2000)
3640
}
@@ -54,11 +58,11 @@ class SseIntegrationTest {
5458
ServerOptions(capabilities = ServerCapabilities()),
5559
)
5660

57-
return embeddedServer(ServerCIO, host = URL, port = PORT) {
61+
return embeddedServer(ServerCIO, host = URL, port = PORT) {
5862
install(io.ktor.server.sse.SSE)
59-
routing {
60-
mcp { server }
61-
}
63+
routing {
64+
mcp { server }
65+
}
6266
}.startSuspend(wait = false)
6367
}
6468

0 commit comments

Comments
 (0)