Skip to content

Introduce server sessions #198

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 6 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -159,13 +159,14 @@ public open class Client(

notification(InitializedNotification())
} catch (error: Throwable) {
logger.error(error) { "Failed to initialize client: ${error.message}" }
close()

if (error !is CancellationException) {
throw IllegalStateException("Error connecting to transport: ${error.message}")
}

throw error

}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package io.modelcontextprotocol.kotlin.sdk.client

import io.github.oshai.kotlinlogging.KotlinLogging
import io.ktor.client.HttpClient
import io.ktor.client.plugins.sse.ClientSSESession
import io.ktor.client.plugins.sse.sseSession
Expand Down Expand Up @@ -46,6 +47,8 @@ public class SseClientTransport(
private val reconnectionTime: Duration? = null,
private val requestBuilder: HttpRequestBuilder.() -> Unit = {},
) : AbstractTransport() {
private val logger = KotlinLogging.logger {}

private val initialized: AtomicBoolean = AtomicBoolean(false)
private val endpoint = CompletableDeferred<String>()

Expand Down Expand Up @@ -111,6 +114,8 @@ public class SseClientTransport(
val text = response.bodyAsText()
error("Error POSTing to endpoint (HTTP ${response.status}): $text")
}

logger.debug { "Client successfully sent message via SSE $endpoint" }
} catch (e: Throwable) {
_onError(e)
throw e
Expand Down Expand Up @@ -157,6 +162,7 @@ public class SseClientTransport(
val path = if (eventData.startsWith("/")) eventData.substring(1) else eventData
val endpointUrl = Url("$baseUrl/$path")
endpoint.complete(endpointUrl.toString())
logger.debug { "Client connected to endpoint: $endpointUrl" }
} catch (e: Throwable) {
_onError(e)
endpoint.completeExceptionally(e)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package io.modelcontextprotocol.kotlin.sdk.client

import io.github.oshai.kotlinlogging.KotlinLogging
import io.ktor.client.HttpClient
import io.ktor.client.plugins.websocket.webSocketSession
import io.ktor.client.request.HttpRequestBuilder
Expand All @@ -10,6 +11,8 @@ import io.modelcontextprotocol.kotlin.sdk.shared.MCP_SUBPROTOCOL
import io.modelcontextprotocol.kotlin.sdk.shared.WebSocketMcpTransport
import kotlin.properties.Delegates

private val logger = KotlinLogging.logger {}

/**
* Client transport for WebSocket: this will connect to a server over the WebSocket protocol.
*/
Expand All @@ -21,6 +24,8 @@ public class WebSocketClientTransport(
override var session: WebSocketSession by Delegates.notNull()

override suspend fun initializeSession() {
logger.debug { "Websocket session initialization started..." }

session = urlString?.let {
client.webSocketSession(it) {
requestBuilder()
Expand All @@ -32,5 +37,7 @@ public class WebSocketClientTransport(

header(HttpHeaders.SecWebSocketProtocol, MCP_SUBPROTOCOL)
}

logger.debug { "Websocket session initialization finished" }
}
}
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
package io.modelcontextprotocol.kotlin.sdk.client

import io.github.oshai.kotlinlogging.KotlinLogging
import io.ktor.client.HttpClient
import io.ktor.client.request.HttpRequestBuilder
import io.modelcontextprotocol.kotlin.sdk.Implementation
import io.modelcontextprotocol.kotlin.sdk.LIB_VERSION
import io.modelcontextprotocol.kotlin.sdk.shared.IMPLEMENTATION_NAME

private val logger = KotlinLogging.logger {}


/**
* Returns a new WebSocket transport for the Model Context Protocol using the provided HttpClient.
*
Expand Down Expand Up @@ -36,6 +40,8 @@ public suspend fun HttpClient.mcpWebSocket(
version = LIB_VERSION
)
)
logger.debug { "Client started to connect to server" }
client.connect(transport)
logger.debug { "Client finished to connect to server" }
return client
}
292 changes: 290 additions & 2 deletions kotlin-sdk-core/api/kotlin-sdk-core.api

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ import kotlin.reflect.typeOf
import kotlin.time.Duration
import kotlin.time.Duration.Companion.milliseconds

private val LOGGER = KotlinLogging.logger { }
private val logger = KotlinLogging.logger { }

public const val IMPLEMENTATION_NAME: String = "mcp-ktor"

Expand Down Expand Up @@ -204,6 +204,7 @@ public abstract class Protocol(
}
}

logger.info { "Starting transport" }
return transport.start()
}

Expand All @@ -221,29 +222,29 @@ public abstract class Protocol(
}

private suspend fun onNotification(notification: JSONRPCNotification) {
LOGGER.trace { "Received notification: ${notification.method}" }
logger.trace { "Received notification: ${notification.method}" }

val handler = notificationHandlers[notification.method] ?: fallbackNotificationHandler

if (handler == null) {
LOGGER.trace { "No handler found for notification: ${notification.method}" }
logger.trace { "No handler found for notification: ${notification.method}" }
return
}
try {
handler(notification)
} catch (cause: Throwable) {
LOGGER.error(cause) { "Error handling notification: ${notification.method}" }
logger.error(cause) { "Error handling notification: ${notification.method}" }
onError(cause)
}
}

private suspend fun onRequest(request: JSONRPCRequest) {
LOGGER.trace { "Received request: ${request.method} (id: ${request.id})" }
logger.trace { "Received request: ${request.method} (id: ${request.id})" }

val handler = requestHandlers[request.method] ?: fallbackRequestHandler

if (handler === null) {
LOGGER.trace { "No handler found for request: ${request.method}" }
logger.trace { "No handler found for request: ${request.method}" }
try {
transport?.send(
JSONRPCResponse(
Expand All @@ -255,15 +256,15 @@ public abstract class Protocol(
)
)
} catch (cause: Throwable) {
LOGGER.error(cause) { "Error sending method not found response" }
logger.error(cause) { "Error sending method not found response" }
onError(cause)
}
return
}

try {
val result = handler(request, RequestHandlerExtra())
LOGGER.trace { "Request handled successfully: ${request.method} (id: ${request.id})" }
logger.trace { "Request handled successfully: ${request.method} (id: ${request.id})" }

transport?.send(
JSONRPCResponse(
Expand All @@ -273,7 +274,7 @@ public abstract class Protocol(
)

} catch (cause: Throwable) {
LOGGER.error(cause) { "Error handling request: ${request.method} (id: ${request.id})" }
logger.error(cause) { "Error handling request: ${request.method} (id: ${request.id})" }

try {
transport?.send(
Expand All @@ -286,14 +287,14 @@ public abstract class Protocol(
)
)
} catch (sendError: Throwable) {
LOGGER.error(sendError) { "Failed to send error response for request: ${request.method} (id: ${request.id})" }
logger.error(sendError) { "Failed to send error response for request: ${request.method} (id: ${request.id})" }
// Optionally implement fallback behavior here
}
}
}

private fun onProgress(notification: ProgressNotification) {
LOGGER.trace { "Received progress notification: token=${notification.params.progressToken}, progress=${notification.params.progress}/${notification.params.total}" }
logger.trace { "Received progress notification: token=${notification.params.progressToken}, progress=${notification.params.progress}/${notification.params.total}" }
val progress = notification.params.progress
val total = notification.params.total
val message = notification.params.message
Expand All @@ -304,7 +305,7 @@ public abstract class Protocol(
val error = Error(
"Received a progress notification for an unknown token: ${McpJson.encodeToString(notification)}",
)
LOGGER.error { error.message }
logger.error { error.message }
onError(error)
return
}
Expand Down Expand Up @@ -382,9 +383,9 @@ public abstract class Protocol(
request: Request,
options: RequestOptions? = null,
): T {
LOGGER.trace { "Sending request: ${request.method}" }
logger.trace { "Sending request: ${request.method}" }
val result = CompletableDeferred<T>()
val transport = this@Protocol.transport ?: throw Error("Not connected")
val transport = transport ?: throw Error("Not connected")

if ([email protected]?.enforceStrictCapabilities == true) {
assertCapabilityForMethod(request.method)
Expand All @@ -394,7 +395,7 @@ public abstract class Protocol(
val messageId = message.id

if (options?.onProgress != null) {
LOGGER.trace { "Registering progress handler for request id: $messageId" }
logger.trace { "Registering progress handler for request id: $messageId" }
_progressHandlers.update { current ->
current.put(messageId, options.onProgress)
}
Expand Down Expand Up @@ -427,7 +428,7 @@ public abstract class Protocol(

val notification = CancelledNotification(
params = CancelledNotification.Params(
requestId = messageId,
requestId = messageId,
reason = reason.message ?: "Unknown"
)
)
Expand All @@ -444,12 +445,12 @@ public abstract class Protocol(
val timeout = options?.timeout ?: DEFAULT_REQUEST_TIMEOUT
try {
withTimeout(timeout) {
LOGGER.trace { "Sending request message with id: $messageId" }
logger.trace { "Sending request message with id: $messageId" }
[email protected]?.send(message)
}
return result.await()
} catch (cause: TimeoutCancellationException) {
LOGGER.error { "Request timed out after ${timeout.inWholeMilliseconds}ms: ${request.method}" }
logger.error { "Request timed out after ${timeout.inWholeMilliseconds}ms: ${request.method}" }
cancel(
McpError(
ErrorCode.Defined.RequestTimeout.code,
Expand All @@ -466,7 +467,7 @@ public abstract class Protocol(
* Emits a notification, which is a one-way message that does not expect a response.
*/
public suspend fun notification(notification: Notification) {
LOGGER.trace { "Sending notification: ${notification.method}" }
logger.trace { "Sending notification: ${notification.method}" }
val transport = this.transport ?: error("Not connected")
assertNotificationCapability(notification.method)

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package io.modelcontextprotocol.kotlin.sdk.shared

import io.github.oshai.kotlinlogging.KotlinLogging
import io.ktor.websocket.Frame
import io.ktor.websocket.WebSocketSession
import io.ktor.websocket.close
Expand All @@ -17,6 +18,9 @@ import kotlin.concurrent.atomics.ExperimentalAtomicApi

public const val MCP_SUBPROTOCOL: String = "mcp"

private val logger = KotlinLogging.logger {}


/**
* Abstract class representing a WebSocket transport for the Model Context Protocol (MCP).
* Handles communication over a WebSocket session.
Expand All @@ -40,6 +44,8 @@ public abstract class WebSocketMcpTransport : AbstractTransport() {
protected abstract suspend fun initializeSession()

override suspend fun start() {
logger.debug { "Starting websocket transport" }

if (!initialized.compareAndSet(expectedValue = false, newValue = true)) {
error(
"WebSocketClientTransport already started! " +
Expand All @@ -53,7 +59,8 @@ public abstract class WebSocketMcpTransport : AbstractTransport() {
while (true) {
val message = try {
session.incoming.receive()
} catch (_: ClosedReceiveChannelException) {
} catch (e: ClosedReceiveChannelException) {
logger.debug { "Closed receive channel, exiting" }
return@launch
}

Expand Down Expand Up @@ -84,6 +91,7 @@ public abstract class WebSocketMcpTransport : AbstractTransport() {
}

override suspend fun send(message: JSONRPCMessage) {
logger.debug { "Sending message" }
if (!initialized.load()) {
error("Not connected")
}
Expand All @@ -96,6 +104,7 @@ public abstract class WebSocketMcpTransport : AbstractTransport() {
error("Not connected")
}

logger.debug { "Closing websocket session" }
session.close()
session.coroutineContext.job.join()
}
Expand Down
Loading
Loading