Skip to content

Commit 2907252

Browse files
devcrocode5l
andauthored
atomic and persistent collections for thread safety (#143)
* Refactor handlers and collections to use atomic references: use persistent collections for thread safety * refactor remove and add public tools/roots/prompts for server --------- Co-authored-by: Leonid Stashevsky <[email protected]>
1 parent bd3cc4e commit 2907252

File tree

6 files changed

+171
-118
lines changed

6 files changed

+171
-118
lines changed

api/kotlin-sdk.api

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2872,6 +2872,9 @@ public class io/modelcontextprotocol/kotlin/sdk/server/Server : io/modelcontextp
28722872
public static synthetic fun createMessage$default (Lio/modelcontextprotocol/kotlin/sdk/server/Server;Lio/modelcontextprotocol/kotlin/sdk/CreateMessageRequest;Lio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;ILjava/lang/Object;)Ljava/lang/Object;
28732873
public final fun getClientCapabilities ()Lio/modelcontextprotocol/kotlin/sdk/ClientCapabilities;
28742874
public final fun getClientVersion ()Lio/modelcontextprotocol/kotlin/sdk/Implementation;
2875+
public final fun getPrompts ()Ljava/util/Map;
2876+
public final fun getResources ()Ljava/util/Map;
2877+
public final fun getTools ()Ljava/util/Map;
28752878
public final fun listRoots (Lkotlinx/serialization/json/JsonObject;Lio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
28762879
public static synthetic fun listRoots$default (Lio/modelcontextprotocol/kotlin/sdk/server/Server;Lkotlinx/serialization/json/JsonObject;Lio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;ILjava/lang/Object;)Ljava/lang/Object;
28772880
public fun onClose ()V

build.gradle.kts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ import org.jreleaser.model.Active
1414
plugins {
1515
alias(libs.plugins.kotlin.multiplatform)
1616
alias(libs.plugins.kotlin.serialization)
17+
alias(libs.plugins.kotlin.atomicfu)
1718
alias(libs.plugins.dokka)
1819
alias(libs.plugins.jreleaser)
1920
`maven-publish`
@@ -246,6 +247,7 @@ kotlin {
246247
kotlin.srcDir(generateLibVersionTask.map { it.sourcesDir })
247248
dependencies {
248249
api(libs.kotlinx.serialization.json)
250+
api(libs.kotlinx.collections.immutable)
249251
api(libs.ktor.client.cio)
250252
api(libs.ktor.server.cio)
251253
api(libs.ktor.server.sse)

gradle/libs.versions.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,11 @@
22
# plugins version
33
kotlin = "2.2.0"
44
dokka = "2.0.0"
5+
atomicfu = "0.29.0"
56

67
# libraries version
78
serialization = "1.9.0"
9+
collections-immutable = "0.4.0"
810
coroutines = "1.10.2"
911
ktor = "3.2.1"
1012
mockk = "1.14.4"
@@ -17,6 +19,7 @@ kotest = "5.9.1"
1719
[libraries]
1820
# Kotlinx libraries
1921
kotlinx-serialization-json = { group = "org.jetbrains.kotlinx", name = "kotlinx-serialization-json", version.ref = "serialization" }
22+
kotlinx-collections-immutable = { group = "org.jetbrains.kotlinx", name = "kotlinx-collections-immutable", version.ref = "collections-immutable" }
2023
kotlin-logging = { group = "io.github.oshai", name = "kotlin-logging", version.ref = "logging" }
2124

2225
# Ktor
@@ -36,6 +39,7 @@ kotest-assertions-json = { group = "io.kotest", name = "kotest-assertions-json",
3639
[plugins]
3740
kotlin-multiplatform = { id = "org.jetbrains.kotlin.multiplatform", version.ref = "kotlin" }
3841
kotlin-serialization = { id = "org.jetbrains.kotlin.plugin.serialization", version.ref = "kotlin" }
42+
kotlin-atomicfu = { id = "org.jetbrains.kotlinx.atomicfu", version.ref = "atomicfu" }
3943
dokka = { id = "org.jetbrains.dokka", version.ref = "dokka" }
4044
jreleaser = { id = "org.jreleaser", version.ref = "jreleaser"}
4145
kotlinx-binary-compatibility-validator = { id = "org.jetbrains.kotlinx.binary-compatibility-validator", version.ref = "binaryCompatibilityValidatorPlugin" }

src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/Client.kt

Lines changed: 19 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ import io.modelcontextprotocol.kotlin.sdk.ListResourceTemplatesRequest
2222
import io.modelcontextprotocol.kotlin.sdk.ListResourceTemplatesResult
2323
import io.modelcontextprotocol.kotlin.sdk.ListResourcesRequest
2424
import io.modelcontextprotocol.kotlin.sdk.ListResourcesResult
25+
import io.modelcontextprotocol.kotlin.sdk.ListRootsRequest
2526
import io.modelcontextprotocol.kotlin.sdk.ListRootsResult
2627
import io.modelcontextprotocol.kotlin.sdk.ListToolsRequest
2728
import io.modelcontextprotocol.kotlin.sdk.ListToolsResult
@@ -41,6 +42,12 @@ import io.modelcontextprotocol.kotlin.sdk.shared.Protocol
4142
import io.modelcontextprotocol.kotlin.sdk.shared.ProtocolOptions
4243
import io.modelcontextprotocol.kotlin.sdk.shared.RequestOptions
4344
import io.modelcontextprotocol.kotlin.sdk.shared.Transport
45+
import kotlinx.atomicfu.atomic
46+
import kotlinx.atomicfu.getAndUpdate
47+
import kotlinx.atomicfu.update
48+
import kotlinx.collections.immutable.minus
49+
import kotlinx.collections.immutable.persistentMapOf
50+
import kotlinx.collections.immutable.toPersistentSet
4451
import kotlinx.serialization.json.JsonElement
4552
import kotlinx.serialization.json.JsonNull
4653
import kotlinx.serialization.json.JsonObject
@@ -94,14 +101,14 @@ public open class Client(
94101

95102
private val capabilities: ClientCapabilities = options.capabilities
96103

97-
private val roots = mutableMapOf<String, Root>()
104+
private val roots = atomic(persistentMapOf<String, Root>())
98105

99106
init {
100107
logger.debug { "Initializing MCP client with capabilities: $capabilities" }
101108

102109
// Internal handlers for roots
103110
if (capabilities.roots != null) {
104-
setRequestHandler<ListToolsRequest>(Method.Defined.RootsList) { _, _ ->
111+
setRequestHandler<ListRootsRequest>(Method.Defined.RootsList) { _, _ ->
105112
handleListRoots()
106113
}
107114
}
@@ -483,7 +490,7 @@ public open class Client(
483490
throw IllegalStateException("Client does not support roots capability.")
484491
}
485492
logger.info { "Adding root: $name ($uri)" }
486-
roots[uri] = Root(uri, name)
493+
roots.update { current -> current.put(uri, Root(uri, name)) }
487494
}
488495

489496
/**
@@ -498,10 +505,7 @@ public open class Client(
498505
throw IllegalStateException("Client does not support roots capability.")
499506
}
500507
logger.info { "Adding ${rootsToAdd.size} roots" }
501-
for (r in rootsToAdd) {
502-
logger.info { "Adding root: ${r.name} (${r.uri})" }
503-
roots[r.uri] = r
504-
}
508+
roots.update { current -> current.putAll(rootsToAdd.associateBy { it.uri }) }
505509
}
506510

507511
/**
@@ -517,7 +521,8 @@ public open class Client(
517521
throw IllegalStateException("Client does not support roots capability.")
518522
}
519523
logger.info { "Removing root: $uri" }
520-
val removed = roots.remove(uri) != null
524+
val oldMap = roots.getAndUpdate { current -> current.remove(uri) }
525+
val removed = uri in oldMap
521526
logger.debug {
522527
if (removed) {
523528
"Root removed: $uri"
@@ -541,13 +546,11 @@ public open class Client(
541546
throw IllegalStateException("Client does not support roots capability.")
542547
}
543548
logger.info { "Removing ${uris.size} roots" }
544-
var removedCount = 0
545-
for (uri in uris) {
546-
logger.debug { "Removing root: $uri" }
547-
if (roots.remove(uri) != null) {
548-
removedCount++
549-
}
550-
}
549+
550+
val oldMap = roots.getAndUpdate { current -> current - uris.toPersistentSet() }
551+
552+
val removedCount = uris.count { it in oldMap }
553+
551554
logger.info {
552555
if (removedCount > 0) {
553556
"Removed $removedCount roots"
@@ -571,7 +574,7 @@ public open class Client(
571574
// --- Internal Handlers ---
572575

573576
private suspend fun handleListRoots(): ListRootsResult {
574-
val rootList = roots.values.toList()
577+
val rootList = roots.value.values.toList()
575578
return ListRootsResult(rootList)
576579
}
577580
}

src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/Server.kt

Lines changed: 57 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,12 @@ import io.modelcontextprotocol.kotlin.sdk.ToolListChangedNotification
4444
import io.modelcontextprotocol.kotlin.sdk.shared.Protocol
4545
import io.modelcontextprotocol.kotlin.sdk.shared.ProtocolOptions
4646
import io.modelcontextprotocol.kotlin.sdk.shared.RequestOptions
47+
import kotlinx.atomicfu.atomic
48+
import kotlinx.atomicfu.getAndUpdate
49+
import kotlinx.atomicfu.update
50+
import kotlinx.collections.immutable.minus
51+
import kotlinx.collections.immutable.persistentMapOf
52+
import kotlinx.collections.immutable.toPersistentSet
4753
import kotlinx.coroutines.CompletableDeferred
4854
import kotlinx.serialization.json.JsonObject
4955

@@ -91,9 +97,15 @@ public open class Server(
9197

9298
private val capabilities: ServerCapabilities = options.capabilities
9399

94-
private val tools = mutableMapOf<String, RegisteredTool>()
95-
private val prompts = mutableMapOf<String, RegisteredPrompt>()
96-
private val resources = mutableMapOf<String, RegisteredResource>()
100+
private val _tools = atomic(persistentMapOf<String, RegisteredTool>())
101+
private val _prompts = atomic(persistentMapOf<String, RegisteredPrompt>())
102+
private val _resources = atomic(persistentMapOf<String, RegisteredResource>())
103+
public val tools: Map<String, RegisteredTool>
104+
get() = _tools.value
105+
public val prompts: Map<String, RegisteredPrompt>
106+
get() = _prompts.value
107+
public val resources: Map<String, RegisteredResource>
108+
get() = _resources.value
97109

98110
init {
99111
logger.debug { "Initializing MCP server with capabilities: $capabilities" }
@@ -192,7 +204,9 @@ public open class Server(
192204
throw IllegalStateException("Server does not support tools capability. Enable it in ServerOptions.")
193205
}
194206
logger.info { "Registering tool: $name" }
195-
tools[name] = RegisteredTool(Tool(name, description, inputSchema, toolAnnotations), handler)
207+
_tools.update { current ->
208+
current.put(name, RegisteredTool(Tool(name, description, inputSchema, toolAnnotations), handler))
209+
}
196210
}
197211

198212
/**
@@ -207,10 +221,7 @@ public open class Server(
207221
throw IllegalStateException("Server does not support tools capability.")
208222
}
209223
logger.info { "Registering ${toolsToAdd.size} tools" }
210-
for (rt in toolsToAdd) {
211-
logger.debug { "Registering tool: ${rt.tool.name}" }
212-
tools[rt.tool.name] = rt
213-
}
224+
_tools.update { current -> current.putAll(toolsToAdd.associateBy { it.tool.name }) }
214225
}
215226

216227
/**
@@ -226,7 +237,10 @@ public open class Server(
226237
throw IllegalStateException("Server does not support tools capability.")
227238
}
228239
logger.info { "Removing tool: $name" }
229-
val removed = tools.remove(name) != null
240+
241+
val oldMap = _tools.getAndUpdate { current -> current.remove(name) }
242+
243+
val removed = name in oldMap
230244
logger.debug {
231245
if (removed) {
232246
"Tool removed: $name"
@@ -250,18 +264,15 @@ public open class Server(
250264
throw IllegalStateException("Server does not support tools capability.")
251265
}
252266
logger.info { "Removing ${toolNames.size} tools" }
253-
var removedCount = 0
254-
for (name in toolNames) {
255-
logger.debug { "Removing tool: $name" }
256-
if (tools.remove(name) != null) {
257-
removedCount++
258-
}
259-
}
267+
268+
val oldMap = _tools.getAndUpdate { current -> current - toolNames.toPersistentSet() }
269+
270+
val removedCount = toolNames.count { it in oldMap }
260271
logger.info {
261272
if (removedCount > 0) {
262-
"Removed $removedCount tools"
273+
"Removed $removedCount tools"
263274
} else {
264-
"No tools were removed"
275+
"No tools were removed"
265276
}
266277
}
267278
return removedCount
@@ -280,7 +291,7 @@ public open class Server(
280291
throw IllegalStateException("Server does not support prompts capability.")
281292
}
282293
logger.info { "Registering prompt: ${prompt.name}" }
283-
prompts[prompt.name] = RegisteredPrompt(prompt, promptProvider)
294+
_prompts.update { current -> current.put(prompt.name, RegisteredPrompt(prompt, promptProvider)) }
284295
}
285296

286297
/**
@@ -314,10 +325,7 @@ public open class Server(
314325
throw IllegalStateException("Server does not support prompts capability.")
315326
}
316327
logger.info { "Registering ${promptsToAdd.size} prompts" }
317-
for (rp in promptsToAdd) {
318-
logger.debug { "Registering prompt: ${rp.prompt.name}" }
319-
prompts[rp.prompt.name] = rp
320-
}
328+
_prompts.update { current -> current.putAll(promptsToAdd.associateBy { it.prompt.name }) }
321329
}
322330

323331
/**
@@ -333,7 +341,10 @@ public open class Server(
333341
throw IllegalStateException("Server does not support prompts capability.")
334342
}
335343
logger.info { "Removing prompt: $name" }
336-
val removed = prompts.remove(name) != null
344+
345+
val oldMap = _prompts.getAndUpdate { current -> current.remove(name) }
346+
347+
val removed = name in oldMap
337348
logger.debug {
338349
if (removed) {
339350
"Prompt removed: $name"
@@ -357,13 +368,11 @@ public open class Server(
357368
throw IllegalStateException("Server does not support prompts capability.")
358369
}
359370
logger.info { "Removing ${promptNames.size} prompts" }
360-
var removedCount = 0
361-
for (name in promptNames) {
362-
logger.debug { "Removing prompt: $name" }
363-
if (prompts.remove(name) != null) {
364-
removedCount++
365-
}
366-
}
371+
372+
val oldMap = _prompts.getAndUpdate { current -> current - promptNames.toPersistentSet() }
373+
374+
val removedCount = promptNames.count { it in oldMap }
375+
367376
logger.info {
368377
if (removedCount > 0) {
369378
"Removed $removedCount prompts"
@@ -396,7 +405,12 @@ public open class Server(
396405
throw IllegalStateException("Server does not support resources capability.")
397406
}
398407
logger.info { "Registering resource: $name ($uri)" }
399-
resources[uri] = RegisteredResource(Resource(uri, name, description, mimeType), readHandler)
408+
_resources.update { current ->
409+
current.put(
410+
uri,
411+
RegisteredResource(Resource(uri, name, description, mimeType), readHandler)
412+
)
413+
}
400414
}
401415

402416
/**
@@ -411,10 +425,7 @@ public open class Server(
411425
throw IllegalStateException("Server does not support resources capability.")
412426
}
413427
logger.info { "Registering ${resourcesToAdd.size} resources" }
414-
for (r in resourcesToAdd) {
415-
logger.debug { "Registering resource: ${r.resource.name} (${r.resource.uri})" }
416-
resources[r.resource.uri] = r
417-
}
428+
_resources.update { current -> current.putAll(resourcesToAdd.associateBy { it.resource.uri }) }
418429
}
419430

420431
/**
@@ -430,7 +441,10 @@ public open class Server(
430441
throw IllegalStateException("Server does not support resources capability.")
431442
}
432443
logger.info { "Removing resource: $uri" }
433-
val removed = resources.remove(uri) != null
444+
445+
val oldMap = _resources.getAndUpdate { current -> current.remove(uri) }
446+
447+
val removed = uri in oldMap
434448
logger.debug {
435449
if (removed) {
436450
"Resource removed: $uri"
@@ -454,13 +468,11 @@ public open class Server(
454468
throw IllegalStateException("Server does not support resources capability.")
455469
}
456470
logger.info { "Removing ${uris.size} resources" }
457-
var removedCount = 0
458-
for (uri in uris) {
459-
logger.debug { "Removing resource: $uri" }
460-
if (resources.remove(uri) != null) {
461-
removedCount++
462-
}
463-
}
471+
472+
val oldMap = _resources.getAndUpdate { current -> current - uris.toPersistentSet() }
473+
474+
val removedCount = uris.count { it in oldMap }
475+
464476
logger.info {
465477
if (removedCount > 0) {
466478
"Removed $removedCount resources"
@@ -586,7 +598,7 @@ public open class Server(
586598

587599
private suspend fun handleCallTool(request: CallToolRequest): CallToolResult {
588600
logger.debug { "Handling tool call request for tool: ${request.name}" }
589-
val tool = tools[request.name]
601+
val tool = _tools.value[request.name]
590602
?: run {
591603
logger.error { "Tool not found: ${request.name}" }
592604
throw IllegalArgumentException("Tool not found: ${request.name}")

0 commit comments

Comments
 (0)