@@ -16,6 +16,7 @@ import io.ktor.http.protocolWithAuthority
16
16
import io.modelcontextprotocol.kotlin.sdk.JSONRPCMessage
17
17
import io.modelcontextprotocol.kotlin.sdk.shared.AbstractTransport
18
18
import io.modelcontextprotocol.kotlin.sdk.shared.McpJson
19
+ import kotlinx.coroutines.CancellationException
19
20
import kotlinx.coroutines.CompletableDeferred
20
21
import kotlinx.coroutines.CoroutineName
21
22
import kotlinx.coroutines.CoroutineScope
@@ -24,10 +25,11 @@ import kotlinx.coroutines.Job
24
25
import kotlinx.coroutines.SupervisorJob
25
26
import kotlinx.coroutines.cancel
26
27
import kotlinx.coroutines.cancelAndJoin
28
+ import kotlinx.coroutines.ensureActive
27
29
import kotlinx.coroutines.launch
30
+ import kotlinx.serialization.SerializationException
28
31
import kotlin.concurrent.atomics.AtomicBoolean
29
32
import kotlin.concurrent.atomics.ExperimentalAtomicApi
30
- import kotlin.properties.Delegates
31
33
import kotlin.time.Duration
32
34
33
35
@Deprecated(" Use SseClientTransport instead" , ReplaceWith (" SseClientTransport" ), DeprecationLevel .WARNING )
@@ -44,97 +46,59 @@ public class SseClientTransport(
44
46
private val reconnectionTime : Duration ? = null ,
45
47
private val requestBuilder : HttpRequestBuilder .() -> Unit = {},
46
48
) : AbstractTransport() {
47
- private val scope by lazy {
48
- CoroutineScope (session.coroutineContext + SupervisorJob ())
49
- }
50
-
51
49
private val initialized: AtomicBoolean = AtomicBoolean (false )
52
- private var session: ClientSSESession by Delegates .notNull()
53
50
private val endpoint = CompletableDeferred <String >()
54
51
52
+ private lateinit var session: ClientSSESession
53
+ private lateinit var scope: CoroutineScope
55
54
private var job: Job ? = null
56
55
57
- private val baseUrl by lazy {
58
- val requestUrl = session.call.request.url.toString()
59
- val url = Url (requestUrl)
60
- var path = url.encodedPath
61
- if (path.isEmpty()) {
62
- url.protocolWithAuthority
63
- } else if (path.endsWith(" /" )) {
64
- url.protocolWithAuthority + path.removeSuffix(" /" )
65
- } else {
66
- // the last item is not a directory, so will not be taken into account
67
- path = path.substring(0 , path.lastIndexOf(" /" ))
68
- url.protocolWithAuthority + path
56
+ private val baseUrl: String by lazy {
57
+ session.call.request.url.let { url ->
58
+ val path = url.encodedPath
59
+ when {
60
+ path.isEmpty() -> url.protocolWithAuthority
61
+ path.endsWith(" /" ) -> url.protocolWithAuthority + path.removeSuffix(" /" )
62
+ else -> url.protocolWithAuthority + path.take(path.lastIndexOf(" /" ))
63
+ }
69
64
}
70
65
}
71
66
72
67
override suspend fun start () {
73
- if (! initialized.compareAndSet(expectedValue = false , newValue = true )) {
74
- error(
75
- " SSEClientTransport already started! " +
76
- " If using Client class, note that connect() calls start() automatically." ,
77
- )
68
+ check(initialized.compareAndSet(expectedValue = false , newValue = true )) {
69
+ " SSEClientTransport already started! If using Client class, note that connect() calls start() automatically."
78
70
}
79
71
80
- session = urlString?.let {
81
- client.sseSession(
82
- urlString = it,
72
+ try {
73
+ session = urlString?.let {
74
+ client.sseSession(
75
+ urlString = it,
76
+ reconnectionTime = reconnectionTime,
77
+ block = requestBuilder,
78
+ )
79
+ } ? : client.sseSession(
83
80
reconnectionTime = reconnectionTime,
84
81
block = requestBuilder,
85
82
)
86
- } ? : client.sseSession(
87
- reconnectionTime = reconnectionTime,
88
- block = requestBuilder,
89
- )
90
-
91
- job = scope.launch(CoroutineName (" SseMcpClientTransport.collect#${hashCode()} " )) {
92
- session.incoming.collect { event ->
93
- when (event.event) {
94
- " error" -> {
95
- val e = IllegalStateException (" SSE error: ${event.data} " )
96
- _onError (e)
97
- throw e
98
- }
99
-
100
- " open" -> {
101
- // The connection is open, but we need to wait for the endpoint to be received.
102
- }
103
-
104
- " endpoint" -> {
105
- try {
106
- val eventData = event.data ? : " "
107
-
108
- // check url correctness
109
- val maybeEndpoint = Url (" $baseUrl /${if (eventData.startsWith(" /" )) eventData.substring(1 ) else eventData} " )
110
- endpoint.complete(maybeEndpoint.toString())
111
- } catch (e: Exception ) {
112
- _onError (e)
113
- close()
114
- error(e)
115
- }
116
- }
83
+ scope = CoroutineScope (session.coroutineContext + SupervisorJob ())
117
84
118
- else -> {
119
- try {
120
- val message = McpJson .decodeFromString<JSONRPCMessage >(event.data ? : " " )
121
- _onMessage (message)
122
- } catch (e: Exception ) {
123
- _onError (e)
124
- }
125
- }
126
- }
85
+ job = scope.launch(CoroutineName (" SseMcpClientTransport.connect#${hashCode()} " )) {
86
+ collectMessages()
127
87
}
128
- }
129
88
130
- endpoint.await()
89
+ endpoint.await()
90
+ } catch (e: Exception ) {
91
+ closeResources()
92
+ initialized.store(false )
93
+ throw e
94
+ }
131
95
}
132
96
133
97
@OptIn(ExperimentalCoroutinesApi ::class )
134
98
override suspend fun send (message : JSONRPCMessage ) {
135
- if ( ! endpoint.isCompleted) {
136
- error( " Not connected " )
137
- }
99
+ check(initialized.load()) { " SseClientTransport is not initialized! " }
100
+ check(job?.isActive == true ) { " SseClientTransport is closed! " }
101
+ check(endpoint.isCompleted) { " Not connected! " }
138
102
139
103
try {
140
104
val response = client.post(endpoint.getCompleted()) {
@@ -147,19 +111,80 @@ public class SseClientTransport(
147
111
val text = response.bodyAsText()
148
112
error(" Error POSTing to endpoint (HTTP ${response.status} ): $text " )
149
113
}
150
- } catch (e: Exception ) {
114
+ } catch (e: Throwable ) {
151
115
_onError (e)
152
116
throw e
153
117
}
154
118
}
155
119
156
120
override suspend fun close () {
157
- if (! initialized.load()) {
158
- error(" SSEClientTransport is not initialized!" )
121
+ check(initialized.load()) { " SseClientTransport is not initialized!" }
122
+ closeResources()
123
+ }
124
+
125
+ private suspend fun CoroutineScope.collectMessages () {
126
+ try {
127
+ session.incoming.collect { event ->
128
+ ensureActive()
129
+
130
+ when (event.event) {
131
+ " error" -> {
132
+ val error = IllegalStateException (" SSE error: ${event.data} " )
133
+ _onError (error)
134
+ throw error
135
+ }
136
+
137
+ " open" -> {
138
+ // The connection is open, but we need to wait for the endpoint to be received.
139
+ }
140
+
141
+ " endpoint" -> handleEndpoint(event.data.orEmpty())
142
+ else -> handleMessage(event.data.orEmpty())
143
+ }
144
+ }
145
+ } catch (e: CancellationException ) {
146
+ throw e
147
+ } catch (e: Throwable ) {
148
+ _onError (e)
149
+ throw e
150
+ } finally {
151
+ closeResources()
159
152
}
153
+ }
154
+
155
+ private fun handleEndpoint (eventData : String ) {
156
+ try {
157
+ val path = if (eventData.startsWith(" /" )) eventData.substring(1 ) else eventData
158
+ val endpointUrl = Url (" $baseUrl /$path " )
159
+ endpoint.complete(endpointUrl.toString())
160
+ } catch (e: Throwable ) {
161
+ _onError (e)
162
+ endpoint.completeExceptionally(e)
163
+ throw e
164
+ }
165
+ }
166
+
167
+ private suspend fun handleMessage (data : String ) {
168
+ try {
169
+ val message = McpJson .decodeFromString<JSONRPCMessage >(data)
170
+ _onMessage (message)
171
+ } catch (e: SerializationException ) {
172
+ _onError (e)
173
+ }
174
+ }
175
+
176
+ private suspend fun closeResources () {
177
+ if (! initialized.compareAndSet(expectedValue = true , newValue = false )) return
160
178
161
- session.cancel()
162
- _onClose ()
163
179
job?.cancelAndJoin()
180
+ try {
181
+ if (::session.isInitialized) session.cancel()
182
+ if (::scope.isInitialized) scope.cancel()
183
+ endpoint.cancel()
184
+ } catch (e: Throwable ) {
185
+ _onError (e)
186
+ }
187
+
188
+ _onClose ()
164
189
}
165
190
}
0 commit comments