diff --git a/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransport.kt b/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransport.kt new file mode 100644 index 00000000..abdd42b0 --- /dev/null +++ b/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransport.kt @@ -0,0 +1,396 @@ +package io.modelcontextprotocol.kotlin.sdk.server + +import io.ktor.http.* +import io.ktor.server.application.* +import io.ktor.server.request.* +import io.ktor.server.response.* +import io.ktor.server.sse.* +import io.modelcontextprotocol.kotlin.sdk.* +import io.modelcontextprotocol.kotlin.sdk.shared.AbstractTransport +import io.modelcontextprotocol.kotlin.sdk.shared.MCP_SESSION_ID +import io.modelcontextprotocol.kotlin.sdk.shared.McpJson +import kotlinx.serialization.encodeToString +import kotlinx.serialization.json.JsonArray +import kotlinx.serialization.json.JsonElement +import kotlinx.serialization.json.JsonObject +import kotlinx.serialization.json.decodeFromJsonElement +import kotlin.collections.HashMap +import kotlin.concurrent.atomics.AtomicBoolean +import kotlin.concurrent.atomics.ExperimentalAtomicApi +import kotlin.uuid.ExperimentalUuidApi +import kotlin.uuid.Uuid + +/** + * Server transport for StreamableHttp: this allows server to respond to GET, POST and DELETE requests. Server can optionally make use of Server-Sent Events (SSE) to stream multiple server messages. + * + * Creates a new StreamableHttp server transport. + * + * @param isStateful If true, the server will generate and maintain session IDs for each client connection. + * Session IDs are included in response headers and must be provided by clients in subsequent requests. + * @param enableJSONResponse If true, the server will return JSON responses instead of starting an SSE stream. + * This can be useful for simple request/response scenarios without streaming. + */ +@OptIn(ExperimentalAtomicApi::class) +public class StreamableHttpServerTransport( + private val isStateful: Boolean = false, + private val enableJSONResponse: Boolean = false, +): AbstractTransport() { + private val standalone = "standalone" + private val streamMapping: HashMap = hashMapOf() + private val requestToStreamMapping: HashMap = hashMapOf() + private val requestResponseMapping: HashMap = hashMapOf() + private val callMapping: HashMap = hashMapOf() + private val started: AtomicBoolean = AtomicBoolean(false) + private val initialized: AtomicBoolean = AtomicBoolean(false) + + /** + * The current session ID for this transport instance. + * This is set during initialization when isStateful is true. + * Clients must include this ID in the Mcp-Session-Id header for all subsequent requests. + */ + public var sessionId: String? = null + private set + + override suspend fun start() { + if (!started.compareAndSet(false, true)) { + error("StreamableHttpServerTransport already started! If using Server class, note that connect() calls start() automatically.") + } + } + + override suspend fun send(message: JSONRPCMessage) { + var requestId: RequestId? = null + + if (message is JSONRPCResponse) { + requestId = message.id + } + + if (requestId == null) { + val standaloneSSE = streamMapping[standalone] ?: return + + standaloneSSE.send( + event = "message", + data = McpJson.encodeToString(message), + ) + return + } + + val streamId = requestToStreamMapping[requestId] ?: error("No connection established for request id $requestId") + val correspondingStream = + streamMapping[streamId] ?: error("No connection established for request id $requestId") + val correspondingCall = callMapping[streamId] ?: error("No connection established for request id $requestId") + + if (!enableJSONResponse) { + correspondingStream.send( + event = "message", + data = McpJson.encodeToString(message), + ) + } + + requestResponseMapping[requestId] = message + val relatedIds = + requestToStreamMapping.entries.filter { streamMapping[it.value] == correspondingStream }.map { it.key } + val allResponsesReady = relatedIds.all { requestResponseMapping[it] != null } + + if (!allResponsesReady) return + + if (enableJSONResponse) { + correspondingCall.response.headers.append(HttpHeaders.ContentType, ContentType.Application.Json.toString()) + correspondingCall.response.status(HttpStatusCode.OK) + if (sessionId != null) { + correspondingCall.response.header(MCP_SESSION_ID, sessionId!!) + } + val responses = relatedIds.map { requestResponseMapping[it] } + if (responses.size == 1) { + correspondingCall.respond(responses[0]!!) + } else { + correspondingCall.respond(responses) + } + callMapping.remove(streamId) + } else { + correspondingStream.close() + streamMapping.remove(streamId) + } + + for (id in relatedIds) { + requestToStreamMapping.remove(id) + requestResponseMapping.remove(id) + } + + } + + override suspend fun close() { + streamMapping.values.forEach { + it.close() + } + streamMapping.clear() + requestToStreamMapping.clear() + requestResponseMapping.clear() + _onClose.invoke() + } + + /** + * Handles HTTP POST requests for the StreamableHttp transport. + * This method processes JSON-RPC messages sent by clients and manages SSE sessions. + * + * @param call The Ktor ApplicationCall representing the incoming HTTP request + * @param session The ServerSSESession for streaming responses back to the client + * + * The method performs the following: + * - Validates required headers (Accept and Content-Type) + * - Parses JSON-RPC messages from the request body + * - Handles initialization requests and session management + * - Sets up SSE streams for request/response communication + * - Sends appropriate error responses for invalid requests + */ + @OptIn(ExperimentalUuidApi::class) + public suspend fun handlePostRequest(call: ApplicationCall, session: ServerSSESession) { + try { + if (!validateHeaders(call)) return + + val messages = parseBody(call) + + if (messages.isEmpty()) return + + val hasInitializationRequest = messages.any { it is JSONRPCRequest && it.method == Method.Defined.Initialize.value } + if (hasInitializationRequest) { + if (initialized.load() && sessionId != null) { + call.response.status(HttpStatusCode.BadRequest) + call.respond( + JSONRPCResponse( + id = null, + error = JSONRPCError( + code = ErrorCode.Defined.InvalidRequest, + message = "Invalid Request: Server already initialized" + ) + ) + ) + return + } + + if (messages.size > 1) { + call.response.status(HttpStatusCode.BadRequest) + call.respond( + JSONRPCResponse( + id = null, + error = JSONRPCError( + code = ErrorCode.Defined.InvalidRequest, + message = "Invalid Request: Only one initialization request is allowed" + ) + ) + ) + return + } + + if (isStateful) { + sessionId = Uuid.random().toString() + } + initialized.store(true) + } + + if (!validateSession(call)) { + return + } + + val hasRequests = messages.any { it is JSONRPCRequest } + val streamId = Uuid.random().toString() + + if (!hasRequests) { + call.respondNullable(HttpStatusCode.Accepted) + } else { + if (!enableJSONResponse) { + call.response.headers.append(HttpHeaders.ContentType, ContentType.Text.EventStream.toString()) + + if (sessionId != null) { + call.response.header(MCP_SESSION_ID, sessionId!!) + } + } + + for (message in messages) { + if (message is JSONRPCRequest) { + streamMapping[streamId] = session + callMapping[streamId] = call + requestToStreamMapping[message.id] = streamId + } + } + } + for (message in messages) { + _onMessage.invoke(message) + } + } catch (e: Exception) { + call.response.status(HttpStatusCode.BadRequest) + call.respond( + JSONRPCResponse( + id = null, + error = JSONRPCError( + code = ErrorCode.Unknown(-32000), + message = e.message ?: "Parse error" + ) + ) + ) + _onError.invoke(e) + } + } + + /** + * Handles HTTP GET requests for establishing standalone SSE streams. + * This method sets up a persistent SSE connection for server-initiated messages. + * + * @param call The Ktor ApplicationCall representing the incoming HTTP request + * @param session The ServerSSESession for streaming messages to the client + * + * The method: + * - Validates the Accept header includes text/event-stream + * - Validates session if stateful mode is enabled + * - Ensures only one standalone SSE stream per session + * - Sets up the stream for server-initiated notifications + */ + public suspend fun handleGetRequest(call: ApplicationCall, session: ServerSSESession) { + val acceptHeader = call.request.headers["Accept"]?.split(",")?.map { it.trim() } ?: listOf() + val acceptsEventStream = acceptHeader.any { it == "text/event-stream" || it.startsWith("text/event-stream;") } + + if (!acceptsEventStream) { + call.response.status(HttpStatusCode.NotAcceptable) + call.respond( + JSONRPCResponse( + id = null, + error = JSONRPCError( + code = ErrorCode.Unknown(-32000), + message = "Not Acceptable: Client must accept text/event-stream" + ) + ) + ) + return + } + + if (!validateSession(call)) { + return + } + + if (sessionId != null) { + call.response.header(MCP_SESSION_ID, sessionId!!) + } + + if (streamMapping[standalone] != null) { + call.response.status(HttpStatusCode.Conflict) + call.respond( + JSONRPCResponse( + id = null, + error = JSONRPCError( + code = ErrorCode.Unknown(-32000), + message = "Conflict: Only one SSE stream is allowed per session" + ) + ) + ) + session.close() + return + } + + // TODO: Equivalent of typescript res.writeHead(200, headers).flushHeaders(); + streamMapping[standalone] = session + } + + /** + * Handles HTTP DELETE requests to terminate the transport session. + * This method allows clients to gracefully close their connection and clean up server resources. + * + * @param call The Ktor ApplicationCall representing the incoming HTTP request + * + * The method: + * - Validates the session if stateful mode is enabled + * - Closes all active SSE streams + * - Clears all internal mappings and resources + * - Responds with 200 OK on successful termination + */ + public suspend fun handleDeleteRequest(call: ApplicationCall) { + if (!validateSession(call)) { + return + } + close() + call.respondNullable(HttpStatusCode.OK) + } + + private suspend fun validateSession(call: ApplicationCall): Boolean { + if (sessionId == null) { + return true + } + + if (!initialized.load()) { + call.response.status(HttpStatusCode.BadRequest) + call.respond( + JSONRPCResponse( + id = null, + error = JSONRPCError( + code = ErrorCode.Unknown(-32000), + message = "Bad Request: Server not initialized" + ) + ) + ) + return false + } + return true + } + + private suspend fun validateHeaders(call: ApplicationCall): Boolean { + val acceptHeader = call.request.headers["Accept"]?.split(",")?.map { it.trim() } ?: listOf() + + val acceptsEventStream = acceptHeader.any { it == "text/event-stream" || it.startsWith("text/event-stream;") } + val acceptsJson = acceptHeader.any { it == "application/json" || it.startsWith("application/json;") } + + if (!acceptsEventStream || !acceptsJson) { + call.response.status(HttpStatusCode.NotAcceptable) + call.respond( + JSONRPCResponse( + id = null, + error = JSONRPCError( + code = ErrorCode.Unknown(-32000), + message = "Not Acceptable: Client must accept both application/json and text/event-stream" + ) + ) + ) + return false + } + + val contentType = call.request.contentType() + if (contentType != ContentType.Application.Json) { + call.response.status(HttpStatusCode.UnsupportedMediaType) + call.respond( + JSONRPCResponse( + id = null, + error = JSONRPCError( + code = ErrorCode.Unknown(-32000), + message = "Unsupported Media Type: Content-Type must be application/json" + ) + ) + ) + return false + } + + return true + } + + private suspend fun parseBody( + call: ApplicationCall, + ): List { + val messages = mutableListOf() + when (val body = call.receive()) { + is JsonObject -> messages.add(McpJson.decodeFromJsonElement(body)) + is JsonArray -> messages.addAll(McpJson.decodeFromJsonElement>(body)) + else -> { + call.response.status(HttpStatusCode.BadRequest) + call.respond( + JSONRPCResponse( + id = null, + error = JSONRPCError( + code = ErrorCode.Defined.InvalidRequest, + message = "Invalid Request: Body must be a JSON object or array" + ) + ) + ) + return listOf() + } + } + return messages + } + + +} diff --git a/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/Constants.kt b/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/Constants.kt new file mode 100644 index 00000000..727aebec --- /dev/null +++ b/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/Constants.kt @@ -0,0 +1,3 @@ +package io.modelcontextprotocol.kotlin.sdk.shared + +internal const val MCP_SESSION_ID = "mcp-session-id" \ No newline at end of file diff --git a/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/types.kt b/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/types.kt index c93fa941..71af0cf9 100644 --- a/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/types.kt +++ b/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/types.kt @@ -239,7 +239,7 @@ public data class JSONRPCNotification( */ @Serializable public class JSONRPCResponse( - public val id: RequestId, + public val id: RequestId?, public val jsonrpc: String = JSONRPC_VERSION, public val result: RequestResult? = null, public val error: JSONRPCError? = null, diff --git a/src/jvmTest/kotlin/server/StreamableHttpServerTransportTest.kt b/src/jvmTest/kotlin/server/StreamableHttpServerTransportTest.kt new file mode 100644 index 00000000..7c6e5370 --- /dev/null +++ b/src/jvmTest/kotlin/server/StreamableHttpServerTransportTest.kt @@ -0,0 +1,159 @@ +package server + +import io.modelcontextprotocol.kotlin.sdk.* +import io.modelcontextprotocol.kotlin.sdk.server.StreamableHttpServerTransport +import kotlinx.coroutines.runBlocking +import org.junit.jupiter.api.Assertions.* +import org.junit.jupiter.api.Test + +class StreamableHttpServerTransportTest { + + @Test + fun `should start and close cleanly`() = runBlocking { + val transport = StreamableHttpServerTransport(isStateful = false) + + var didClose = false + transport.onClose { + didClose = true + } + + transport.start() + assertFalse(didClose, "Should not have closed yet") + + transport.close() + assertTrue(didClose, "Should have closed after calling close()") + } + + @Test + fun `should initialize with stateful mode`() = runBlocking { + val transport = StreamableHttpServerTransport(isStateful = true) + transport.start() + + assertNull(transport.sessionId, "Session ID should be null before initialization") + + transport.close() + } + + @Test + fun `should initialize with stateless mode`() = runBlocking { + val transport = StreamableHttpServerTransport(isStateful = false) + transport.start() + + assertNull(transport.sessionId, "Session ID should be null in stateless mode") + + transport.close() + } + + @Test + fun `should not allow double start`() = runBlocking { + val transport = StreamableHttpServerTransport() + transport.start() + + val exception = assertThrows(IllegalStateException::class.java) { + runBlocking { transport.start() } + } + + assertTrue(exception.message?.contains("already started") == true) + + transport.close() + } + + @Test + fun `should handle message callbacks`() = runBlocking { + val transport = StreamableHttpServerTransport() + var receivedMessage: JSONRPCMessage? = null + + transport.onMessage { message -> + receivedMessage = message + } + + transport.start() + + // Test that message handler can be called + assertTrue(receivedMessage == null) // Verify initially null + + transport.close() + } + + @Test + fun `should handle error callbacks`() = runBlocking { + val transport = StreamableHttpServerTransport() + var receivedException: Throwable? = null + + transport.onError { error -> + receivedException = error + } + + transport.start() + + // Test that error handler can be called + assertTrue(receivedException == null) // Verify initially null + + transport.close() + } + + @Test + fun `should clear all mappings on close`() = runBlocking { + val transport = StreamableHttpServerTransport() + transport.start() + + // After close, all internal state should be cleared + transport.close() + + // Verify close was called by checking the close handler + var closeHandlerCalled = false + transport.onClose { closeHandlerCalled = true } + + // Since we already closed, setting a new handler won't trigger it + assertFalse(closeHandlerCalled) + } + + @Test + fun `should support enableJSONResponse flag`() { + val transportWithJson = StreamableHttpServerTransport(enableJSONResponse = true) + val transportWithoutJson = StreamableHttpServerTransport(enableJSONResponse = false) + + // Just verify the transports can be created with different flags + assertNotNull(transportWithJson) + assertNotNull(transportWithoutJson) + } + + @Test + fun `should support isStateful flag`() { + val statefulTransport = StreamableHttpServerTransport(isStateful = true) + val statelessTransport = StreamableHttpServerTransport(isStateful = false) + + // Just verify the transports can be created with different flags + assertNotNull(statefulTransport) + assertNotNull(statelessTransport) + } + + @Test + fun `should handle close without error callbacks`() = runBlocking { + val transport = StreamableHttpServerTransport() + + transport.start() + + // Should not throw even without error handler + assertDoesNotThrow { + runBlocking { transport.close() } + } + } + + @Test + fun `should handle multiple close calls`() = runBlocking { + val transport = StreamableHttpServerTransport() + var closeCount = 0 + + transport.onClose { + closeCount++ + } + + transport.start() + transport.close() + assertEquals(1, closeCount, "Close handler should be called once after first close") + + transport.close() // Second close should be safe + assertEquals(2, closeCount, "Close handler should be called again on second close") + } +} \ No newline at end of file