Skip to content

Commit f9f1a15

Browse files
committed
Changes after rebase
1 parent 3b19e6f commit f9f1a15

File tree

3 files changed

+346
-344
lines changed
  • kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server
  • kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client
  • src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server

3 files changed

+346
-344
lines changed
Lines changed: 343 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,343 @@
1+
package io.modelcontextprotocol.kotlin.sdk.server
2+
3+
import io.github.oshai.kotlinlogging.KotlinLogging
4+
import io.modelcontextprotocol.kotlin.sdk.ClientCapabilities
5+
import io.modelcontextprotocol.kotlin.sdk.CreateElicitationRequest
6+
import io.modelcontextprotocol.kotlin.sdk.CreateElicitationRequest.RequestedSchema
7+
import io.modelcontextprotocol.kotlin.sdk.CreateElicitationResult
8+
import io.modelcontextprotocol.kotlin.sdk.CreateMessageRequest
9+
import io.modelcontextprotocol.kotlin.sdk.CreateMessageResult
10+
import io.modelcontextprotocol.kotlin.sdk.EmptyJsonObject
11+
import io.modelcontextprotocol.kotlin.sdk.EmptyRequestResult
12+
import io.modelcontextprotocol.kotlin.sdk.Implementation
13+
import io.modelcontextprotocol.kotlin.sdk.InitializeRequest
14+
import io.modelcontextprotocol.kotlin.sdk.InitializeResult
15+
import io.modelcontextprotocol.kotlin.sdk.InitializedNotification
16+
import io.modelcontextprotocol.kotlin.sdk.LATEST_PROTOCOL_VERSION
17+
import io.modelcontextprotocol.kotlin.sdk.ListRootsRequest
18+
import io.modelcontextprotocol.kotlin.sdk.ListRootsResult
19+
import io.modelcontextprotocol.kotlin.sdk.LoggingMessageNotification
20+
import io.modelcontextprotocol.kotlin.sdk.Method
21+
import io.modelcontextprotocol.kotlin.sdk.PingRequest
22+
import io.modelcontextprotocol.kotlin.sdk.PromptListChangedNotification
23+
import io.modelcontextprotocol.kotlin.sdk.ResourceListChangedNotification
24+
import io.modelcontextprotocol.kotlin.sdk.ResourceUpdatedNotification
25+
import io.modelcontextprotocol.kotlin.sdk.SUPPORTED_PROTOCOL_VERSIONS
26+
import io.modelcontextprotocol.kotlin.sdk.ToolListChangedNotification
27+
import io.modelcontextprotocol.kotlin.sdk.shared.Protocol
28+
import io.modelcontextprotocol.kotlin.sdk.shared.RequestOptions
29+
import kotlinx.coroutines.CompletableDeferred
30+
import kotlinx.serialization.json.JsonObject
31+
32+
private val logger = KotlinLogging.logger {}
33+
34+
public open class ServerSession(
35+
private val serverInfo: Implementation,
36+
options: ServerOptions,
37+
) : Protocol(options) {
38+
private var _onInitialized: (() -> Unit) = {}
39+
private var _onClose: () -> Unit = {}
40+
41+
init {
42+
// Core protocol handlers
43+
setRequestHandler<InitializeRequest>(Method.Defined.Initialize) { request, _ ->
44+
handleInitialize(request)
45+
}
46+
setNotificationHandler<InitializedNotification>(Method.Defined.NotificationsInitialized) {
47+
_onInitialized()
48+
CompletableDeferred(Unit)
49+
}
50+
}
51+
52+
/**
53+
* The capabilities supported by the server, related to the session.
54+
*/
55+
private val serverCapabilities = options.capabilities
56+
57+
/**
58+
* The client's reported capabilities after initialization.
59+
*/
60+
public var clientCapabilities: ClientCapabilities? = null
61+
private set
62+
63+
/**
64+
* The client's version information after initialization.
65+
*/
66+
public var clientVersion: Implementation? = null
67+
private set
68+
69+
/**
70+
* Registers a callback to be invoked when the server has completed initialization.
71+
*/
72+
public fun onInitialized(block: () -> Unit) {
73+
val old = _onInitialized
74+
_onInitialized = {
75+
old()
76+
block()
77+
}
78+
}
79+
80+
/**
81+
* Registers a callback to be invoked when the server session is closing.
82+
*/
83+
public fun onClose(block: () -> Unit) {
84+
val old = _onClose
85+
_onClose = {
86+
old()
87+
block()
88+
}
89+
}
90+
91+
/**
92+
* Called when the server session is closing.
93+
*/
94+
override fun onClose() {
95+
logger.debug { "Server connection closing" }
96+
_onClose()
97+
}
98+
99+
/**
100+
* Sends a ping request to the client to check connectivity.
101+
*
102+
* @return The result of the ping request.
103+
* @throws IllegalStateException If for some reason the method is not supported or the connection is closed.
104+
*/
105+
public suspend fun ping(): EmptyRequestResult {
106+
return request<EmptyRequestResult>(PingRequest())
107+
}
108+
109+
/**
110+
* Creates a message using the server's sampling capability.
111+
*
112+
* @param params The parameters for creating a message.
113+
* @param options Optional request options.
114+
* @return The created message result.
115+
* @throws IllegalStateException If the server does not support sampling or if the request fails.
116+
*/
117+
public suspend fun createMessage(
118+
params: CreateMessageRequest,
119+
options: RequestOptions? = null
120+
): CreateMessageResult {
121+
logger.debug { "Creating message with params: $params" }
122+
return request<CreateMessageResult>(params, options)
123+
}
124+
125+
/**
126+
* Lists the available "roots" from the client's perspective (if supported).
127+
*
128+
* @param params JSON parameters for the request, usually empty.
129+
* @param options Optional request options.
130+
* @return The list of roots.
131+
* @throws IllegalStateException If the server or client does not support roots.
132+
*/
133+
public suspend fun listRoots(
134+
params: JsonObject = EmptyJsonObject,
135+
options: RequestOptions? = null
136+
): ListRootsResult {
137+
logger.debug { "Listing roots with params: $params" }
138+
return request<ListRootsResult>(ListRootsRequest(params), options)
139+
}
140+
141+
public suspend fun createElicitation(
142+
message: String,
143+
requestedSchema: RequestedSchema,
144+
options: RequestOptions? = null
145+
): CreateElicitationResult {
146+
logger.debug { "Creating elicitation with message: $message" }
147+
return request(CreateElicitationRequest(message, requestedSchema), options)
148+
}
149+
150+
/**
151+
* Sends a logging message notification to the client.
152+
*
153+
* @param notification The logging message notification.
154+
*/
155+
public suspend fun sendLoggingMessage(notification: LoggingMessageNotification) {
156+
logger.trace { "Sending logging message: ${notification.params.data}" }
157+
notification(notification)
158+
}
159+
160+
/**
161+
* Sends a resource-updated notification to the client, indicating that a specific resource has changed.
162+
*
163+
* @param notification Details of the updated resource.
164+
*/
165+
public suspend fun sendResourceUpdated(notification: ResourceUpdatedNotification) {
166+
logger.debug { "Sending resource updated notification for: ${notification.params.uri}" }
167+
notification(notification)
168+
}
169+
170+
/**
171+
* Sends a notification to the client indicating that the list of resources has changed.
172+
*/
173+
public suspend fun sendResourceListChanged() {
174+
logger.debug { "Sending resource list changed notification" }
175+
notification(ResourceListChangedNotification())
176+
}
177+
178+
/**
179+
* Sends a notification to the client indicating that the list of tools has changed.
180+
*/
181+
public suspend fun sendToolListChanged() {
182+
logger.debug { "Sending tool list changed notification" }
183+
notification(ToolListChangedNotification())
184+
}
185+
186+
/**
187+
* Sends a notification to the client indicating that the list of prompts has changed.
188+
*/
189+
public suspend fun sendPromptListChanged() {
190+
logger.debug { "Sending prompt list changed notification" }
191+
notification(PromptListChangedNotification())
192+
}
193+
194+
/**
195+
* Asserts that the client supports the capability required for the given [method].
196+
*
197+
* This method is automatically called by the [Protocol] framework before handling requests.
198+
* Throws [IllegalStateException] if the capability is not supported.
199+
*
200+
* @param method The method for which we are asserting capability.
201+
*/
202+
override fun assertCapabilityForMethod(method: Method) {
203+
logger.trace { "Asserting capability for method: ${method.value}" }
204+
when (method.value) {
205+
"sampling/createMessage" -> {
206+
if (clientCapabilities?.sampling == null) {
207+
logger.error { "Client capability assertion failed: sampling not supported" }
208+
throw IllegalStateException("Client does not support sampling (required for ${method.value})")
209+
}
210+
}
211+
212+
"roots/list" -> {
213+
if (clientCapabilities?.roots == null) {
214+
throw IllegalStateException("Client does not support listing roots (required for ${method.value})")
215+
}
216+
}
217+
218+
"elicitation/create" -> {
219+
if (clientCapabilities?.elicitation == null) {
220+
throw IllegalStateException("Client does not support elicitation (required for ${method.value})")
221+
}
222+
}
223+
224+
"ping" -> {
225+
// No specific capability required
226+
}
227+
}
228+
}
229+
230+
/**
231+
* Asserts that the server can handle the specified notification method.
232+
*
233+
* Throws [IllegalStateException] if the server does not have the capabilities required to handle this notification.
234+
*
235+
* @param method The notification method.
236+
*/
237+
override fun assertNotificationCapability(method: Method) {
238+
logger.trace { "Asserting notification capability for method: ${method.value}" }
239+
when (method.value) {
240+
"notifications/message" -> {
241+
if (serverCapabilities.logging == null) {
242+
logger.error { "Server capability assertion failed: logging not supported" }
243+
throw IllegalStateException("Server does not support logging (required for ${method.value})")
244+
}
245+
}
246+
247+
"notifications/resources/updated",
248+
"notifications/resources/list_changed" -> {
249+
if (serverCapabilities.resources == null) {
250+
throw IllegalStateException("Server does not support notifying about resources (required for ${method.value})")
251+
}
252+
}
253+
254+
"notifications/tools/list_changed" -> {
255+
if (serverCapabilities.tools == null) {
256+
throw IllegalStateException("Server does not support notifying of tool list changes (required for ${method.value})")
257+
}
258+
}
259+
260+
"notifications/prompts/list_changed" -> {
261+
if (serverCapabilities.prompts == null) {
262+
throw IllegalStateException("Server does not support notifying of prompt list changes (required for ${method.value})")
263+
}
264+
}
265+
266+
"notifications/cancelled",
267+
"notifications/progress" -> {
268+
// Always allowed
269+
}
270+
}
271+
}
272+
273+
/**
274+
* Asserts that the server can handle the specified request method.
275+
*
276+
* Throws [IllegalStateException] if the server does not have the capabilities required to handle this request.
277+
*
278+
* @param method The request method.
279+
*/
280+
override fun assertRequestHandlerCapability(method: Method) {
281+
logger.trace { "Asserting request handler capability for method: ${method.value}" }
282+
when (method.value) {
283+
"sampling/createMessage" -> {
284+
if (serverCapabilities.sampling == null) {
285+
logger.error { "Server capability assertion failed: sampling not supported" }
286+
throw IllegalStateException("Server does not support sampling (required for $method)")
287+
}
288+
}
289+
290+
"logging/setLevel" -> {
291+
if (serverCapabilities.logging == null) {
292+
throw IllegalStateException("Server does not support logging (required for $method)")
293+
}
294+
}
295+
296+
"prompts/get",
297+
"prompts/list" -> {
298+
if (serverCapabilities.prompts == null) {
299+
throw IllegalStateException("Server does not support prompts (required for $method)")
300+
}
301+
}
302+
303+
"resources/list",
304+
"resources/templates/list",
305+
"resources/read" -> {
306+
if (serverCapabilities.resources == null) {
307+
throw IllegalStateException("Server does not support resources (required for $method)")
308+
}
309+
}
310+
311+
"tools/call",
312+
"tools/list" -> {
313+
if (serverCapabilities.tools == null) {
314+
throw IllegalStateException("Server does not support tools (required for $method)")
315+
}
316+
}
317+
318+
"ping", "initialize" -> {
319+
// No capability required
320+
}
321+
}
322+
}
323+
324+
private suspend fun handleInitialize(request: InitializeRequest): InitializeResult {
325+
logger.debug { "Handling initialization request from client" }
326+
clientCapabilities = request.capabilities
327+
clientVersion = request.clientInfo
328+
329+
val requestedVersion = request.protocolVersion
330+
val protocolVersion = if (SUPPORTED_PROTOCOL_VERSIONS.contains(requestedVersion)) {
331+
requestedVersion
332+
} else {
333+
logger.warn { "Client requested unsupported protocol version $requestedVersion, falling back to $LATEST_PROTOCOL_VERSION" }
334+
LATEST_PROTOCOL_VERSION
335+
}
336+
337+
return InitializeResult(
338+
protocolVersion = protocolVersion,
339+
capabilities = serverCapabilities,
340+
serverInfo = serverInfo
341+
)
342+
}
343+
}

kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/SseTransportTest.kt

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@ import io.ktor.server.application.install
55
import io.ktor.server.cio.CIO
66
import io.ktor.server.engine.EmbeddedServer
77
import io.ktor.server.engine.embeddedServer
8+
import io.ktor.server.routing.post
9+
import io.ktor.server.routing.route
810
import io.ktor.server.routing.routing
911
import io.modelcontextprotocol.kotlin.sdk.Implementation
1012
import io.modelcontextprotocol.kotlin.sdk.ServerCapabilities
@@ -110,7 +112,7 @@ class SseTransportTest : BaseTransportTest() {
110112
val server = embeddedServer(CIO, port = 0) {
111113
install(ServerSSE)
112114
routing {
113-
mcp("/sse") { mcpServer }
115+
mcp { mcpServer }
114116
// route("/sse") {
115117
// sse {
116118
// mcpSseTransport("", transportManager).apply {

0 commit comments

Comments
 (0)