From 89cea045c7f45b94a171019ab8cac96c756d2977 Mon Sep 17 00:00:00 2001 From: ghostflyby Date: Fri, 29 May 2026 23:54:16 +0800 Subject: [PATCH 01/22] feat: concurrent message dispatch across all transports Each transport now dispatches incoming messages via scope.launch instead of calling _onMessage synchronously. Message handlers run as independent coroutines on the transport's dispatcher, enabling concurrent processing. Transport changes: - StdioClientTransport / StdioServerTransport: replace synchronous _onMessage with launchMessageHandler - StreamableHttpClientTransport: same, in both SSE collector and inline-SSE POST response handler - SseClientTransport: same, in SSE event handler - WebSocketMcpTransport: same, in WebSocket frame handler - ChannelTransport: same + yield() between dispatches Close resources fixes: - Remove blocking scopeJob.join() after scope.cancel() in all transports; blocking I/O (readAtMostTo, SSE session) does not respond to cooperative cancellation and caused permanent hangs - invokeOnCloseCallback() is now called before scope.cancel() so pending response handlers (CompletableDeferred in Protocol.request) are completed with CONNECTION_CLOSED before any handler coroutine is cancelled Server batch processing: - StreamableHttpServerTransport.dispatchMessagesConcurrently: supervisorScope + async/awaitAll for parallel message handling - supervisorScope ensures a failure in one handler does not cancel siblings; first non-null error is reported after all handlers complete New concurrency tests: - ChannelTransportTest: 'receive loop continues while previous handler is suspended', 'error in one handler does not affect other messages' - StdioServerTransportTest: same tests via real process I/O - WebSocketMcpTransportTest: same tests via mock WebSocket session - StreamableHttpServerTransportTest: batch concurrent request/notification handling, handler error isolation, error-after-response-committed Test assertion fixes: - BaseTransportTest.testTransportRead: Semaphore-based wait for all messages instead of last-message equality check (order no longer guaranteed with concurrent dispatch); use shouldContainExactlyInAnyOrder - StreamableHttpClientTransportTest: unordered notification method comparison --- .../kotlin/sdk/shared/BaseTransportTest.kt | 13 +- .../kotlin/sdk/client/SseClientTransport.kt | 20 +- .../kotlin/sdk/client/StdioClientTransport.kt | 21 +- .../client/StreamableHttpClientTransport.kt | 27 ++- .../http/StreamableHttpClientTransportTest.kt | 10 +- .../sdk/shared/WebSocketMcpTransport.kt | 30 ++- .../sdk/shared/WebSocketMcpTransportTest.kt | 154 ++++++++++++ .../kotlin/sdk/server/SSEServerTransport.kt | 3 + .../kotlin/sdk/server/StdioServerTransport.kt | 12 +- .../server/StreamableHttpServerTransport.kt | 42 +++- .../sdk/server/StdioServerTransportTest.kt | 42 +++- .../StreamableHttpServerTransportTest.kt | 228 +++++++++++++++++- .../kotlin/sdk/testing/ChannelTransport.kt | 35 +-- .../sdk/testing/ChannelTransportTest.kt | 42 +++- 14 files changed, 612 insertions(+), 67 deletions(-) create mode 100644 kotlin-sdk-core/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/WebSocketMcpTransportTest.kt diff --git a/integration-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/BaseTransportTest.kt b/integration-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/BaseTransportTest.kt index a6c6e3df2..f831c03e9 100644 --- a/integration-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/BaseTransportTest.kt +++ b/integration-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/BaseTransportTest.kt @@ -1,12 +1,13 @@ package io.modelcontextprotocol.kotlin.sdk.shared import io.kotest.assertions.nondeterministic.eventually +import io.kotest.matchers.collections.shouldContainExactlyInAnyOrder import io.kotest.matchers.shouldBe import io.modelcontextprotocol.kotlin.sdk.types.InitializedNotification import io.modelcontextprotocol.kotlin.sdk.types.JSONRPCMessage import io.modelcontextprotocol.kotlin.sdk.types.PingRequest import io.modelcontextprotocol.kotlin.sdk.types.toJSON -import kotlinx.coroutines.CompletableDeferred +import kotlinx.coroutines.sync.Semaphore import kotlin.test.fail import kotlin.time.Duration.Companion.seconds @@ -47,13 +48,11 @@ abstract class BaseTransportTest { ) val readMessages = mutableListOf() - val finished = CompletableDeferred() + val semaphore = Semaphore(messages.size, messages.size) transport.onMessage { message -> readMessages.add(message) - if (message == messages.last()) { - finished.complete(Unit) - } + semaphore.release() } transport.start() @@ -62,9 +61,9 @@ abstract class BaseTransportTest { transport.send(message) } - finished.await() + repeat(messages.size) { semaphore.acquire() } - messages shouldBe readMessages + messages shouldContainExactlyInAnyOrder readMessages transport.close() } diff --git a/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/SseClientTransport.kt b/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/SseClientTransport.kt index 0295eb9a3..b0777b548 100644 --- a/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/SseClientTransport.kt +++ b/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/SseClientTransport.kt @@ -80,7 +80,7 @@ public class SseClientTransport( reconnectionTime = reconnectionTime, block = requestBuilder, ) - scope = CoroutineScope(session.coroutineContext + SupervisorJob()) + scope = CoroutineScope(session.coroutineContext + SupervisorJob(session.coroutineContext[Job])) job = scope.launch(CoroutineName("SseMcpClientTransport.connect#${hashCode()}")) { collectMessages() @@ -163,17 +163,31 @@ public class SseClientTransport( } } - private suspend fun handleMessage(data: String) { + private fun handleMessage(data: String) { try { val message = McpJson.decodeFromString(data) - _onMessage(message) + launchMessageHandler(message) } catch (e: SerializationException) { _onError(e) } } + private fun launchMessageHandler(message: JSONRPCMessage) { + scope.launch(CoroutineName("SseMcpClientTransport.message#${hashCode()}")) { + try { + _onMessage(message) + } catch (e: CancellationException) { + throw e + } catch (e: Throwable) { + logger.error(e) { "Error processing message" } + _onError(e) + } + } + } + override suspend fun closeResources() { withContext(NonCancellable) { + invokeOnCloseCallback() job?.cancel() try { if (::session.isInitialized) session.cancel() diff --git a/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StdioClientTransport.kt b/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StdioClientTransport.kt index 63f681c5c..798005608 100644 --- a/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StdioClientTransport.kt +++ b/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StdioClientTransport.kt @@ -172,7 +172,7 @@ public class StdioClientTransport @JvmOverloads public constructor( .collect { event -> when (event) { is Event.JsonRpc -> { - handleJSONRPCMessage(event.message) + launchMessageHandler(event.message) } is Event.StderrEvent -> { @@ -244,7 +244,6 @@ public class StdioClientTransport @JvmOverloads public constructor( override suspend fun closeResources() { withContext(NonCancellable) { scope.stopProcessing("Closed") - scope.coroutineContext[Job]?.join() // Wait for all coroutines to complete } } @@ -264,14 +263,16 @@ public class StdioClientTransport @JvmOverloads public constructor( } } - private suspend fun handleJSONRPCMessage(msg: JSONRPCMessage) { - try { - _onMessage.invoke(msg) - } catch (e: CancellationException) { - throw e - } catch (e: Throwable) { - logger.error(e) { "Error processing message." } - runCatching { _onError.invoke(e) } + private fun launchMessageHandler(message: JSONRPCMessage) { + scope.launch(CoroutineName("StdioClientTransport.message#${hashCode()}")) { + try { + _onMessage.invoke(message) + } catch (e: CancellationException) { + throw e + } catch (e: Throwable) { + logger.error(e) { "Error processing message." } + runCatching { _onError.invoke(e) } + } } } diff --git a/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpClientTransport.kt b/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpClientTransport.kt index c27ed79e1..3d604ba54 100644 --- a/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpClientTransport.kt +++ b/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpClientTransport.kt @@ -37,7 +37,6 @@ import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.Job import kotlinx.coroutines.SupervisorJob import kotlinx.coroutines.cancel -import kotlinx.coroutines.cancelAndJoin import kotlinx.coroutines.delay import kotlinx.coroutines.isActive import kotlinx.coroutines.launch @@ -218,7 +217,8 @@ public class StreamableHttpClientTransport( override suspend fun closeResources() { logger.debug { "Client transport closing." } - sseJob?.cancelAndJoin() + invokeOnCloseCallback() + sseJob?.cancel() scope.cancel() } @@ -392,9 +392,9 @@ public class StreamableHttpClientTransport( .onSuccess { msg -> if (msg is JSONRPCResponse) receivedResponse = true if (replayMessageId != null && msg is JSONRPCResponse) { - _onMessage(msg.copy(id = replayMessageId)) + launchMessageHandler(msg.copy(id = replayMessageId)) } else { - _onMessage(msg) + launchMessageHandler(msg) } } .onFailure(_onError) @@ -427,7 +427,7 @@ public class StreamableHttpClientTransport( var id: String? = null var eventName: String? = null - suspend fun dispatch(id: String?, eventName: String?, data: String) { + fun dispatch(id: String?, eventName: String?, data: String) { id?.let { localLastEventId = it hasPrimingEvent = true @@ -441,9 +441,9 @@ public class StreamableHttpClientTransport( .onSuccess { msg -> if (msg is JSONRPCResponse) receivedResponse = true if (replayMessageId != null && msg is JSONRPCResponse) { - _onMessage(msg.copy(id = replayMessageId)) + launchMessageHandler(msg.copy(id = replayMessageId)) } else { - _onMessage(msg) + launchMessageHandler(msg) } } .onFailure { @@ -481,4 +481,17 @@ public class StreamableHttpClientTransport( } return SseStreamResult(hasPrimingEvent, receivedResponse, localLastEventId, localServerRetryDelay) } + + private fun launchMessageHandler(message: JSONRPCMessage) { + scope.launch(CoroutineName("StreamableHttpTransport.message#${hashCode()}")) { + try { + _onMessage(message) + } catch (e: CancellationException) { + throw e + } catch (e: Throwable) { + logger.error(e) { "Error processing message" } + _onError(e) + } + } + } } diff --git a/kotlin-sdk-client/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/streamable/http/StreamableHttpClientTransportTest.kt b/kotlin-sdk-client/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/streamable/http/StreamableHttpClientTransportTest.kt index b380f64a6..63011136f 100644 --- a/kotlin-sdk-client/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/streamable/http/StreamableHttpClientTransportTest.kt +++ b/kotlin-sdk-client/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/streamable/http/StreamableHttpClientTransportTest.kt @@ -2,6 +2,7 @@ package io.modelcontextprotocol.kotlin.sdk.client.streamable.http import io.kotest.assertions.fail import io.kotest.assertions.nondeterministic.eventually +import io.kotest.matchers.collections.shouldContainExactlyInAnyOrder import io.kotest.matchers.collections.shouldHaveSize import io.kotest.matchers.shouldBe import io.kotest.matchers.types.shouldBeInstanceOf @@ -611,11 +612,10 @@ class StreamableHttpClientTransportTest { receivedMessages shouldHaveSize 2 - val firstNotification = receivedMessages[0] as JSONRPCNotification - firstNotification.method shouldBe "notifications/progress" - - val secondNotification = receivedMessages[1] as JSONRPCNotification - secondNotification.method shouldBe "notifications/tools/list_changed" + receivedMessages + .filterIsInstance() + .map { it.method } + .shouldContainExactlyInAnyOrder("notifications/progress", "notifications/tools/list_changed") transport.close() } diff --git a/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/WebSocketMcpTransport.kt b/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/WebSocketMcpTransport.kt index baf37233e..6fa7bcec2 100644 --- a/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/WebSocketMcpTransport.kt +++ b/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/WebSocketMcpTransport.kt @@ -10,10 +10,13 @@ import io.modelcontextprotocol.kotlin.sdk.types.McpJson import kotlinx.coroutines.CoroutineName import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.InternalCoroutinesApi +import kotlinx.coroutines.NonCancellable import kotlinx.coroutines.SupervisorJob +import kotlinx.coroutines.cancel import kotlinx.coroutines.channels.ClosedReceiveChannelException import kotlinx.coroutines.job import kotlinx.coroutines.launch +import kotlinx.coroutines.withContext import kotlin.concurrent.atomics.AtomicBoolean import kotlin.concurrent.atomics.ExperimentalAtomicApi import kotlin.coroutines.cancellation.CancellationException @@ -30,7 +33,7 @@ private val logger = KotlinLogging.logger {} @OptIn(ExperimentalAtomicApi::class) public abstract class WebSocketMcpTransport : AbstractTransport() { private val scope by lazy { - CoroutineScope(session.coroutineContext + SupervisorJob()) + CoroutineScope(session.coroutineContext + SupervisorJob(session.coroutineContext.job)) } private val initialized: AtomicBoolean = AtomicBoolean(false) @@ -73,8 +76,8 @@ public abstract class WebSocketMcpTransport : AbstractTransport() { } try { - val message = McpJson.decodeFromString(message.readText()) - _onMessage.invoke(message) + val parsedMessage = McpJson.decodeFromString(message.readText()) + launchMessageHandler(parsedMessage) } catch (e: CancellationException) { throw e } catch (e: Exception) { @@ -86,6 +89,7 @@ public abstract class WebSocketMcpTransport : AbstractTransport() { @OptIn(InternalCoroutinesApi::class) session.coroutineContext.job.invokeOnCompletion { + scope.cancel() if (it != null) { _onError.invoke(it) } else { @@ -94,6 +98,19 @@ public abstract class WebSocketMcpTransport : AbstractTransport() { } } + private fun launchMessageHandler(message: JSONRPCMessage) { + scope.launch(CoroutineName("WebSocketMcpTransport.message#${hashCode()}")) { + try { + _onMessage.invoke(message) + } catch (e: CancellationException) { + throw e + } catch (e: Throwable) { + logger.error(e) { "Error processing message" } + _onError.invoke(e) + } + } + } + override suspend fun send(message: JSONRPCMessage, options: TransportSendOptions?) { logger.debug { "Sending message" } if (!initialized.load()) { @@ -109,7 +126,10 @@ public abstract class WebSocketMcpTransport : AbstractTransport() { } logger.debug { "Closing websocket session" } - session.close() - session.coroutineContext.job.join() + withContext(NonCancellable) { + invokeOnCloseCallback() + session.close() + scope.cancel() + } } } diff --git a/kotlin-sdk-core/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/WebSocketMcpTransportTest.kt b/kotlin-sdk-core/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/WebSocketMcpTransportTest.kt new file mode 100644 index 000000000..eb52cecae --- /dev/null +++ b/kotlin-sdk-core/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/WebSocketMcpTransportTest.kt @@ -0,0 +1,154 @@ +package io.modelcontextprotocol.kotlin.sdk.shared + +import io.ktor.websocket.Frame +import io.ktor.websocket.WebSocketExtension +import io.ktor.websocket.WebSocketSession +import io.modelcontextprotocol.kotlin.sdk.types.InitializedNotification +import io.modelcontextprotocol.kotlin.sdk.types.PingRequest +import io.modelcontextprotocol.kotlin.sdk.types.toJSON +import kotlinx.coroutines.CompletableDeferred +import kotlinx.coroutines.Job +import kotlinx.coroutines.channels.Channel +import kotlinx.coroutines.test.runTest +import kotlin.coroutines.CoroutineContext +import kotlin.coroutines.cancellation.CancellationException +import kotlin.test.Test +import kotlin.test.assertSame +import kotlin.test.fail + +class WebSocketMcpTransportTest { + @Test + fun `should continue receiving while previous handler is suspended`() = runTest { + val session = TestWebSocketSession(coroutineContext) + val transport = TestWebSocketMcpTransport(session) + val firstMessage = PingRequest().toJSON() + val secondMessage = InitializedNotification().toJSON() + val releaseFirstHandler = CompletableDeferred() + val secondMessageProcessed = CompletableDeferred() + + transport.onMessage { message -> + if (message == firstMessage) { + releaseFirstHandler.await() + } + if (message == secondMessage) { + secondMessageProcessed.complete(Unit) + } + } + + transport.start() + + session.incoming.send(Frame.Text(serializeMessage(firstMessage))) + session.incoming.send(Frame.Text(serializeMessage(secondMessage))) + + secondMessageProcessed.await() + releaseFirstHandler.complete(Unit) + transport.close() + } + + @Test + fun `close cancels suspended message handler`() = runTest { + val session = TestWebSocketSession(coroutineContext) + val transport = TestWebSocketMcpTransport(session) + val handlerStarted = CompletableDeferred() + val handlerCancelled = CompletableDeferred() + + transport.onMessage { + handlerStarted.complete(Unit) + try { + CompletableDeferred().await() + } catch (e: CancellationException) { + handlerCancelled.complete(Unit) + throw e + } + } + + transport.start() + + session.incoming.send(Frame.Text(serializeMessage(PingRequest().toJSON()))) + handlerStarted.await() + transport.close() + + handlerCancelled.await() + } + + @Test + fun `handler errors are reported via onError`() = runTest { + val session = TestWebSocketSession(coroutineContext) + val transport = TestWebSocketMcpTransport(session) + val expected = IllegalStateException("handler failed") + val capturedError = CompletableDeferred() + + transport.onError { capturedError.complete(it) } + transport.onMessage { throw expected } + transport.start() + + session.incoming.send(Frame.Text(serializeMessage(PingRequest().toJSON()))) + + assertSame(expected, capturedError.await()) + transport.close() + } + + @Test + fun `handler cancellation is not reported via onError`() = runTest { + val session = TestWebSocketSession(coroutineContext) + val transport = TestWebSocketMcpTransport(session) + val capturedError = CompletableDeferred() + val handlerStarted = CompletableDeferred() + + transport.onError { capturedError.complete(it) } + transport.onMessage { + handlerStarted.complete(Unit) + throw CancellationException("handler cancelled") + } + transport.start() + + session.incoming.send(Frame.Text(serializeMessage(PingRequest().toJSON()))) + handlerStarted.await() + + if (capturedError.isCompleted) { + fail("CancellationException should not be reported via onError") + } + transport.close() + } + + private class TestWebSocketMcpTransport( + override val session: WebSocketSession, + ) : WebSocketMcpTransport() { + override suspend fun initializeSession() = Unit + } + + private class TestWebSocketSession( + parentContext: CoroutineContext, + ) : WebSocketSession { + private val job = Job(parentContext[Job]) + + override val coroutineContext: CoroutineContext = parentContext + job + override var masking: Boolean = false + override var maxFrameSize: Long = Long.MAX_VALUE + override val incoming = Channel(Channel.UNLIMITED) + override val outgoing = Channel(Channel.UNLIMITED) + override val extensions: List> = emptyList() + + override suspend fun flush() = Unit + + override suspend fun send(frame: Frame) { + outgoing.send(frame) + if (frame is Frame.Close) { + job.cancel() + incoming.close() + outgoing.close() + } + } + + @Deprecated( + "Use cancel() instead.", + replaceWith = ReplaceWith("cancel()", "kotlinx.coroutines.cancel"), + level = DeprecationLevel.ERROR + ) + override fun terminate() { + job.cancel() + incoming.close() + outgoing.close() + } + } +} diff --git a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/SSEServerTransport.kt b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/SSEServerTransport.kt index a6d461833..41816d9b0 100644 --- a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/SSEServerTransport.kt +++ b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/SSEServerTransport.kt @@ -109,6 +109,9 @@ public class SseServerTransport(private val endpoint: String, private val sessio /** * Handle a client message, regardless of how it arrived. * This can be used to inform the server of messages that arrive via a means different from HTTP POST. + * + * Ktor provides concurrency across POST requests by invoking [handlePostMessage] in a separate + * [ApplicationCall] coroutine, so this path keeps handler errors synchronous to the current POST. */ public suspend fun handleMessage(message: String) { try { diff --git a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StdioServerTransport.kt b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StdioServerTransport.kt index a4d588cc5..e516659a2 100644 --- a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StdioServerTransport.kt +++ b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StdioServerTransport.kt @@ -8,10 +8,12 @@ import io.modelcontextprotocol.kotlin.sdk.shared.TransportSendOptions import io.modelcontextprotocol.kotlin.sdk.shared.serializeMessage import io.modelcontextprotocol.kotlin.sdk.types.JSONRPCMessage import kotlinx.coroutines.CancellationException +import kotlinx.coroutines.CoroutineName import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.Job import kotlinx.coroutines.NonCancellable import kotlinx.coroutines.SupervisorJob +import kotlinx.coroutines.cancel import kotlinx.coroutines.cancelAndJoin import kotlinx.coroutines.channels.Channel import kotlinx.coroutines.isActive @@ -143,7 +145,7 @@ public class StdioServerTransport(private val inputStream: Source, outputStream: return job } - private suspend fun processReadBuffer() { + private fun processReadBuffer() { while (true) { val message = try { readBuffer.readMessage() @@ -153,7 +155,12 @@ public class StdioServerTransport(private val inputStream: Source, outputStream: } if (message == null) break - // Async invocation broke delivery order + launchMessageHandler(message) + } + } + + private fun launchMessageHandler(message: JSONRPCMessage) { + scope.launch(CoroutineName("StdioServerTransport.message#${hashCode()}")) { try { _onMessage.invoke(message) } catch (e: CancellationException) { @@ -197,6 +204,7 @@ public class StdioServerTransport(private val inputStream: Source, outputStream: processingJob?.cancelAndJoin() readBuffer.clear() + scope.cancel() runCatching { outputSink.flush() diff --git a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransport.kt b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransport.kt index ae8a39d1d..2bc70c8b9 100644 --- a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransport.kt +++ b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransport.kt @@ -29,8 +29,11 @@ import io.modelcontextprotocol.kotlin.sdk.types.RPCError.ErrorCode.REQUEST_TIMEO import io.modelcontextprotocol.kotlin.sdk.types.RequestId import io.modelcontextprotocol.kotlin.sdk.types.SUPPORTED_PROTOCOL_VERSIONS import kotlinx.coroutines.NonCancellable +import kotlinx.coroutines.async +import kotlinx.coroutines.awaitAll import kotlinx.coroutines.awaitCancellation import kotlinx.coroutines.job +import kotlinx.coroutines.supervisorScope import kotlinx.coroutines.sync.Mutex import kotlinx.coroutines.sync.withLock import kotlinx.coroutines.withContext @@ -408,7 +411,7 @@ public class StreamableHttpServerTransport(private val configuration: Configurat val hasRequest = messages.any { it is JSONRPCRequest } if (!hasRequest) { call.respondNullable(status = HttpStatusCode.Accepted, message = null) - messages.forEach { message -> _onMessage(message) } + dispatchMessagesConcurrently(messages) return } @@ -437,16 +440,39 @@ public class StreamableHttpServerTransport(private val configuration: Configurat } call.coroutineContext.job.invokeOnCompletion { streamsMapping.remove(streamId) } - messages.forEach { message -> _onMessage(message) } + // Batch dispatch with parallel message processing. Non-cancellation + // handler failures are reported after every already-started handler + // has completed, so one failure does not cancel sibling handlers. + dispatchMessagesConcurrently(messages) } catch (e: CancellationException) { throw e - } catch (e: Exception) { - call.reject( - HttpStatusCode.BadRequest, - RPCError.ErrorCode.PARSE_ERROR, - "Parse error: ${e.message}", - ) + } catch (e: Throwable) { _onError(e) + runCatching { + call.reject( + HttpStatusCode.BadRequest, + RPCError.ErrorCode.PARSE_ERROR, + "Parse error: ${e.message}", + ) + } + } + } + + private suspend fun dispatchMessagesConcurrently(messages: List) { + supervisorScope { + val handlerErrors = messages.map { message -> + async { + try { + _onMessage(message) + null + } catch (e: CancellationException) { + throw e + } catch (e: Throwable) { + e + } + } + }.awaitAll() + handlerErrors.firstOrNull { it != null }?.let { throw it } } } diff --git a/kotlin-sdk-server/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StdioServerTransportTest.kt b/kotlin-sdk-server/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StdioServerTransportTest.kt index 32123a693..9806c125b 100644 --- a/kotlin-sdk-server/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StdioServerTransportTest.kt +++ b/kotlin-sdk-server/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StdioServerTransportTest.kt @@ -4,6 +4,7 @@ import io.kotest.assertions.nondeterministic.eventually import io.kotest.assertions.throwables.shouldThrow import io.kotest.assertions.withClue import io.kotest.matchers.collections.shouldContain +import io.kotest.matchers.collections.shouldContainExactlyInAnyOrder import io.kotest.matchers.shouldBe import io.modelcontextprotocol.kotlin.sdk.shared.ReadBuffer import io.modelcontextprotocol.kotlin.sdk.shared.serializeMessage @@ -154,7 +155,39 @@ class StdioServerTransportTest { server.start() finished.await() - readMessages shouldBe messages + readMessages.shouldContainExactlyInAnyOrder(messages) + } + + @Test + fun `should continue receiving while previous handler is suspended`() = runIntegrationTest { + val server = StdioServerTransport(bufferedInput, printOutput) + server.onError { error -> + throw error + } + + val firstMessage = PingRequest().toJSON() + val secondMessage = InitializedNotification().toJSON() + val releaseFirstHandler = CompletableDeferred() + val secondMessageProcessed = CompletableDeferred() + + server.onMessage { message -> + if (message == firstMessage) { + releaseFirstHandler.await() + } + if (message == secondMessage) { + secondMessageProcessed.complete(Unit) + } + } + + server.start() + + inputWriter.write(serializeMessage(firstMessage)) + inputWriter.write(serializeMessage(secondMessage)) + inputWriter.flush() + + secondMessageProcessed.await() + releaseFirstHandler.complete(Unit) + server.close() } // region: Exception handling @@ -224,11 +257,15 @@ class StdioServerTransportTest { val capturedErrors = mutableListOf() val receivedMessages = mutableListOf() val secondMessageProcessed = CompletableDeferred() + val handlerErrorCaptured = CompletableDeferred() val message1 = PingRequest().toJSON() val message2 = InitializedNotification().toJSON() - server.onError { capturedErrors.add(it) } + server.onError { + capturedErrors.add(it) + handlerErrorCaptured.complete(Unit) + } server.onMessage { message -> if (message == message1) { throw throwable @@ -245,6 +282,7 @@ class StdioServerTransportTest { inputWriter.flush() secondMessageProcessed.await() + handlerErrorCaptured.await() capturedErrors shouldContain throwable receivedMessages shouldBe listOf(message2) diff --git a/kotlin-sdk-server/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransportTest.kt b/kotlin-sdk-server/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransportTest.kt index f6dda1520..08603d8cb 100644 --- a/kotlin-sdk-server/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransportTest.kt +++ b/kotlin-sdk-server/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransportTest.kt @@ -31,6 +31,7 @@ import io.modelcontextprotocol.kotlin.sdk.types.Implementation import io.modelcontextprotocol.kotlin.sdk.types.InitializeRequest import io.modelcontextprotocol.kotlin.sdk.types.InitializeRequestParams import io.modelcontextprotocol.kotlin.sdk.types.InitializedNotification +import io.modelcontextprotocol.kotlin.sdk.types.JSONRPCNotification import io.modelcontextprotocol.kotlin.sdk.types.JSONRPCMessage import io.modelcontextprotocol.kotlin.sdk.types.JSONRPCRequest import io.modelcontextprotocol.kotlin.sdk.types.JSONRPCResponse @@ -44,6 +45,9 @@ import io.modelcontextprotocol.kotlin.sdk.types.ServerCapabilities import io.modelcontextprotocol.kotlin.sdk.types.Tool import io.modelcontextprotocol.kotlin.sdk.types.ToolSchema import io.modelcontextprotocol.kotlin.sdk.types.toJSON +import kotlinx.coroutines.CompletableDeferred +import kotlinx.coroutines.async +import kotlinx.coroutines.coroutineScope import kotlinx.serialization.builtins.ListSerializer import kotlinx.serialization.json.buildJsonObject import kotlinx.serialization.json.put @@ -344,7 +348,229 @@ class StreamableHttpServerTransportTest { val firstMeta = (responses[0] as ListToolsResult).meta val secondMeta = (responses[1] as ListResourcesResult).meta assertEquals("first", firstMeta?.get("label")?.jsonPrimitive?.content) - assertEquals("second", secondMeta?.get("label")?.jsonPrimitive?.content)*/ + assertEquals("second", secondMeta?.get("label")?.jsonPrimitive?.content)*/ + } + + @Test + fun `batched requests are handled concurrently before replying`() = testApplication { + configTestServer() + + val client = createTestClient() + + val transport = StreamableHttpServerTransport(enableJsonResponse = true) + val firstRequest = JSONRPCRequest(id = RequestId("first"), method = Method.Defined.ToolsList.value) + val secondRequest = JSONRPCRequest(id = RequestId("second"), method = Method.Defined.ResourcesList.value) + val firstHandlerStarted = CompletableDeferred() + val releaseFirstHandler = CompletableDeferred() + val secondHandlerCompleted = CompletableDeferred() + val secondResult = ListResourcesResult(resources = emptyList()) + + transport.onMessage { message -> + if (message is JSONRPCRequest) { + when (message.id) { + firstRequest.id -> { + firstHandlerStarted.complete(Unit) + releaseFirstHandler.await() + transport.send(JSONRPCResponse(message.id, EmptyResult()), null) + } + secondRequest.id -> { + firstHandlerStarted.await() + transport.send(JSONRPCResponse(message.id, secondResult), null) + secondHandlerCompleted.complete(Unit) + } + else -> transport.send(JSONRPCResponse(message.id, EmptyResult()), null) + } + } + } + + configureTransportEndpoint(transport) + + val initResponse = client.post(path) { + addStreamableHeaders() + setBody(buildInitializeRequestPayload()) + } + + coroutineScope { + val responseDeferred = async { + client.post(path) { + addStreamableHeaders() + header(MCP_SESSION_ID_HEADER, initResponse.headers[MCP_SESSION_ID_HEADER]) + setBody(encodeMessages(listOf(firstRequest, secondRequest))) + } + } + + secondHandlerCompleted.await() + releaseFirstHandler.complete(Unit) + + val response = responseDeferred.await() + response.status shouldBe HttpStatusCode.OK + response.body>() shouldContainAll listOf( + JSONRPCResponse(firstRequest.id, EmptyResult()), + JSONRPCResponse(secondRequest.id, secondResult), + ) + } + } + + @Test + fun `batched notifications are handled concurrently`() = testApplication { + configTestServer() + + val client = createTestClient() + + val transport = StreamableHttpServerTransport(enableJsonResponse = true) + val firstNotification = InitializedNotification().toJSON() + val secondNotification = JSONRPCNotification("notifications/test") + val firstHandlerStarted = CompletableDeferred() + val releaseFirstHandler = CompletableDeferred() + val secondHandlerCompleted = CompletableDeferred() + + transport.onMessage { message -> + when (message) { + is JSONRPCRequest -> transport.send(JSONRPCResponse(message.id, EmptyResult())) + firstNotification -> { + firstHandlerStarted.complete(Unit) + releaseFirstHandler.await() + } + secondNotification -> { + firstHandlerStarted.await() + secondHandlerCompleted.complete(Unit) + } + else -> Unit + } + } + + configureTransportEndpoint(transport) + + val initResponse = client.post(path) { + addStreamableHeaders() + setBody(buildInitializeRequestPayload()) + } + + coroutineScope { + val responseDeferred = async { + client.post(path) { + addStreamableHeaders() + header(MCP_SESSION_ID_HEADER, initResponse.headers[MCP_SESSION_ID_HEADER]) + setBody(encodeMessages(listOf(firstNotification, secondNotification))) + } + } + + secondHandlerCompleted.await() + releaseFirstHandler.complete(Unit) + + responseDeferred.await().status shouldBe HttpStatusCode.Accepted + } + } + + @Test + fun `batched notifications report handler error after sibling completes`() = testApplication { + configTestServer() + + val client = createTestClient() + + val transport = StreamableHttpServerTransport(enableJsonResponse = true) + val expected = IllegalStateException("notification failed") + val capturedError = CompletableDeferred() + val firstNotification = InitializedNotification().toJSON() + val secondNotification = JSONRPCNotification("notifications/test") + val firstHandlerStarted = CompletableDeferred() + val releaseFirstHandler = CompletableDeferred() + val secondHandlerStarted = CompletableDeferred() + + transport.onError { capturedError.complete(it) } + transport.onMessage { message -> + when (message) { + is JSONRPCRequest -> transport.send(JSONRPCResponse(message.id, EmptyResult())) + firstNotification -> { + firstHandlerStarted.complete(Unit) + releaseFirstHandler.await() + } + secondNotification -> { + firstHandlerStarted.await() + secondHandlerStarted.complete(Unit) + throw expected + } + else -> Unit + } + } + + configureTransportEndpoint(transport) + + val initResponse = client.post(path) { + addStreamableHeaders() + setBody(buildInitializeRequestPayload()) + } + + coroutineScope { + val responseDeferred = async { + client.post(path) { + addStreamableHeaders() + header(MCP_SESSION_ID_HEADER, initResponse.headers[MCP_SESSION_ID_HEADER]) + setBody(encodeMessages(listOf(firstNotification, secondNotification))) + } + } + + secondHandlerStarted.await() + releaseFirstHandler.complete(Unit) + responseDeferred.await().status + } + + capturedError.await() shouldBe expected + } + + @Test + fun `batch handler error reports onError after response is committed`() = testApplication { + configTestServer() + + val client = createTestClient() + + val transport = StreamableHttpServerTransport(enableJsonResponse = true) + val expected = IllegalStateException("request failed") + val capturedError = CompletableDeferred() + val firstRequest = JSONRPCRequest(id = RequestId("first"), method = Method.Defined.ToolsList.value) + val notification = JSONRPCNotification("notifications/test") + val requestResponseSent = CompletableDeferred() + val releaseNotificationHandler = CompletableDeferred() + + transport.onError { capturedError.complete(it) } + transport.onMessage { message -> + when (message) { + firstRequest -> { + transport.send(JSONRPCResponse(firstRequest.id, EmptyResult()), null) + requestResponseSent.complete(Unit) + } + notification -> { + requestResponseSent.await() + releaseNotificationHandler.await() + throw expected + } + is JSONRPCRequest -> transport.send(JSONRPCResponse(message.id, EmptyResult()), null) + else -> Unit + } + } + + configureTransportEndpoint(transport) + + val initResponse = client.post(path) { + addStreamableHeaders() + setBody(buildInitializeRequestPayload()) + } + + coroutineScope { + val responseDeferred = async { + client.post(path) { + addStreamableHeaders() + header(MCP_SESSION_ID_HEADER, initResponse.headers[MCP_SESSION_ID_HEADER]) + setBody(encodeMessages(listOf(firstRequest, notification))) + } + } + + requestResponseSent.await() + releaseNotificationHandler.complete(Unit) + responseDeferred.await().status + } + + capturedError.await() shouldBe expected } @ParameterizedTest diff --git a/kotlin-sdk-testing/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/testing/ChannelTransport.kt b/kotlin-sdk-testing/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/testing/ChannelTransport.kt index ae5f487ce..df387a3b4 100644 --- a/kotlin-sdk-testing/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/testing/ChannelTransport.kt +++ b/kotlin-sdk-testing/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/testing/ChannelTransport.kt @@ -12,7 +12,6 @@ import kotlinx.coroutines.CoroutineDispatcher import kotlinx.coroutines.CoroutineName import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.Dispatchers -import kotlinx.coroutines.Job import kotlinx.coroutines.SupervisorJob import kotlinx.coroutines.cancel import kotlinx.coroutines.channels.Channel @@ -121,19 +120,8 @@ public class ChannelTransport( for (message in receiveChannel) { logger.debug { "Received message: ${message::class.simpleName}" } - - try { - _onMessage.invoke(message) - logger.trace { "Message processed successfully: ${message::class.simpleName}" } - } catch (e: CancellationException) { - // Let cancellation propagate immediately - logger.debug { "Cancellation requested during message processing" } - throw e - } catch (e: Exception) { - // Report other errors but continue processing - logger.warn(e) { "Error processing message: ${message::class.simpleName}" } - _onError.invoke(e) - } + launchMessageHandler(message) + yield() } logger.info { "ChannelTransport stopped: receive channel closed" } } catch (e: Exception) { @@ -150,6 +138,23 @@ public class ChannelTransport( started.await() } + private fun launchMessageHandler(message: JSONRPCMessage) { + scope.launch(CoroutineName("ChannelTransport#${hashCode()}-message")) { + try { + _onMessage.invoke(message) + logger.trace { "Message processed successfully: ${message::class.simpleName}" } + } catch (e: CancellationException) { + // Let cancellation propagate immediately + logger.debug { "Cancellation requested during message processing" } + throw e + } catch (e: Throwable) { + // Report other errors but continue processing + logger.warn(e) { "Error processing message: ${message::class.simpleName}" } + _onError.invoke(e) + } + } + } + /** * Sends a JSON-RPC message through the transport. * @@ -170,13 +175,13 @@ public class ChannelTransport( */ override suspend fun closeResources() { logger.info { "Closing ChannelTransport" } + invokeOnCloseCallback() sendChannel.close() if (receiveChannel !== sendChannel) { logger.debug { "Cancelling separate receive channel" } receiveChannel.cancel() } scope.cancel() - scope.coroutineContext[Job]?.join() // Wait for cleanup logger.info { "ChannelTransport closed" } } } diff --git a/kotlin-sdk-testing/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/testing/ChannelTransportTest.kt b/kotlin-sdk-testing/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/testing/ChannelTransportTest.kt index 4aae932e9..330bc8944 100644 --- a/kotlin-sdk-testing/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/testing/ChannelTransportTest.kt +++ b/kotlin-sdk-testing/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/testing/ChannelTransportTest.kt @@ -2,6 +2,7 @@ package io.modelcontextprotocol.kotlin.sdk.testing import io.kotest.assertions.nondeterministic.eventually import io.kotest.matchers.collections.shouldContainExactly +import io.kotest.matchers.collections.shouldContainExactlyInAnyOrder import io.kotest.matchers.shouldBe import io.kotest.matchers.types.shouldBeInstanceOf import io.modelcontextprotocol.kotlin.sdk.ExperimentalMcpApi @@ -9,11 +10,14 @@ import io.modelcontextprotocol.kotlin.sdk.types.JSONRPCMessage import io.modelcontextprotocol.kotlin.sdk.types.JSONRPCRequest import io.modelcontextprotocol.kotlin.sdk.types.RequestId import kotlinx.coroutines.CompletableDeferred +import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.cancelAndJoin import kotlinx.coroutines.channels.Channel import kotlinx.coroutines.channels.ClosedSendChannelException import kotlinx.coroutines.launch import kotlinx.coroutines.test.runTest +import kotlinx.coroutines.withContext +import kotlinx.coroutines.withTimeout import kotlin.test.Test import kotlin.time.Duration.Companion.seconds @@ -57,7 +61,40 @@ class ChannelTransportTest { // Wait for messages to be processed messagesProcessed.await() - received.shouldContainExactly(msg1, msg2) + received.shouldContainExactlyInAnyOrder(msg1, msg2) + } + + @Test + fun `receive loop continues while previous handler is suspended`() = runTest { + val receiveChannel = Channel(Channel.UNLIMITED) + val transport = ChannelTransport(Channel(), receiveChannel) + + val firstMessage = JSONRPCRequest(RequestId.NumberId(1), "method1") + val secondMessage = JSONRPCRequest(RequestId.NumberId(2), "method2") + val releaseFirstHandler = CompletableDeferred() + val secondMessageProcessed = CompletableDeferred() + + transport.onMessage { message -> + if (message == firstMessage) { + releaseFirstHandler.await() + } + if (message == secondMessage) { + secondMessageProcessed.complete(Unit) + } + } + + transport.start() + + receiveChannel.send(firstMessage) + receiveChannel.send(secondMessage) + + withContext(Dispatchers.Default.limitedParallelism(1)) { + withTimeout(2.seconds) { + secondMessageProcessed.await() + } + } + releaseFirstHandler.complete(Unit) + transport.close() } @Test @@ -204,8 +241,9 @@ class ChannelTransportTest { allProcessed.await() // All messages should be processed despite error in message 2 - received.shouldContainExactly(1, 2, 3, 4) + received.shouldContainExactlyInAnyOrder(1, 2, 3, 4) errors.size shouldBe 1 errors[0].message shouldBe "Error processing message 2" + transport.close() } } From 0274fab7cf6212ab3656982f508fabfc32c39b11 Mon Sep 17 00:00:00 2001 From: ghostflyby Date: Sat, 30 May 2026 14:17:46 +0800 Subject: [PATCH 02/22] chore: revert a indent --- .../kotlin/sdk/server/StreamableHttpServerTransportTest.kt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/kotlin-sdk-server/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransportTest.kt b/kotlin-sdk-server/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransportTest.kt index 08603d8cb..59300397a 100644 --- a/kotlin-sdk-server/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransportTest.kt +++ b/kotlin-sdk-server/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransportTest.kt @@ -31,8 +31,8 @@ import io.modelcontextprotocol.kotlin.sdk.types.Implementation import io.modelcontextprotocol.kotlin.sdk.types.InitializeRequest import io.modelcontextprotocol.kotlin.sdk.types.InitializeRequestParams import io.modelcontextprotocol.kotlin.sdk.types.InitializedNotification -import io.modelcontextprotocol.kotlin.sdk.types.JSONRPCNotification import io.modelcontextprotocol.kotlin.sdk.types.JSONRPCMessage +import io.modelcontextprotocol.kotlin.sdk.types.JSONRPCNotification import io.modelcontextprotocol.kotlin.sdk.types.JSONRPCRequest import io.modelcontextprotocol.kotlin.sdk.types.JSONRPCResponse import io.modelcontextprotocol.kotlin.sdk.types.LATEST_PROTOCOL_VERSION @@ -348,7 +348,7 @@ class StreamableHttpServerTransportTest { val firstMeta = (responses[0] as ListToolsResult).meta val secondMeta = (responses[1] as ListResourcesResult).meta assertEquals("first", firstMeta?.get("label")?.jsonPrimitive?.content) - assertEquals("second", secondMeta?.get("label")?.jsonPrimitive?.content)*/ + assertEquals("second", secondMeta?.get("label")?.jsonPrimitive?.content)*/ } @Test From 2147ce3bd7082abd788f5bebf7ff75c13e5c75f2 Mon Sep 17 00:00:00 2001 From: ghostflyby Date: Sat, 30 May 2026 14:19:34 +0800 Subject: [PATCH 03/22] fix: BaseTransportTest assertion order problem for readability --- .../modelcontextprotocol/kotlin/sdk/shared/BaseTransportTest.kt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/integration-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/BaseTransportTest.kt b/integration-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/BaseTransportTest.kt index f831c03e9..825310aeb 100644 --- a/integration-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/BaseTransportTest.kt +++ b/integration-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/BaseTransportTest.kt @@ -63,7 +63,7 @@ abstract class BaseTransportTest { repeat(messages.size) { semaphore.acquire() } - messages shouldContainExactlyInAnyOrder readMessages + readMessages shouldContainExactlyInAnyOrder messages transport.close() } From c3a337e7132a835ff2c232b7dca05b740f6d6516 Mon Sep 17 00:00:00 2001 From: ghostflyby Date: Sat, 30 May 2026 14:28:33 +0800 Subject: [PATCH 04/22] fix: error handling in StreamableHttpServerTransport JSONRpc Error responses are now indenpedent --- .../server/StreamableHttpServerTransport.kt | 22 ++++++++++++++----- 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransport.kt b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransport.kt index 2bc70c8b9..8d6bb33ac 100644 --- a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransport.kt +++ b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransport.kt @@ -446,7 +446,7 @@ public class StreamableHttpServerTransport(private val configuration: Configurat dispatchMessagesConcurrently(messages) } catch (e: CancellationException) { throw e - } catch (e: Throwable) { + } catch (e: Exception) { _onError(e) runCatching { call.reject( @@ -460,19 +460,29 @@ public class StreamableHttpServerTransport(private val configuration: Configurat private suspend fun dispatchMessagesConcurrently(messages: List) { supervisorScope { - val handlerErrors = messages.map { message -> + messages.map { message -> async { try { _onMessage(message) - null } catch (e: CancellationException) { throw e - } catch (e: Throwable) { - e + } catch (e: Exception) { + _onError(e) + if (message is JSONRPCRequest) { + send( + JSONRPCError( + id = message.id, + error = RPCError( + code = RPCError.ErrorCode.INTERNAL_ERROR, + message = "Message processing error: ${e.message}", + ), + ), + TransportSendOptions(relatedRequestId = message.id), + ) + } } } }.awaitAll() - handlerErrors.firstOrNull { it != null }?.let { throw it } } } From 2267a15949ae0595e1a747153a02bdfa9b0d6ac0 Mon Sep 17 00:00:00 2001 From: ghostflyby Date: Sat, 30 May 2026 14:55:08 +0800 Subject: [PATCH 05/22] review: address Copilot PR comments - StreamableHttpServerTransport.dispatchMessagesConcurrently: handler errors no longer throw; each error is reported via _onError and for JSONRPCRequest messages an error response is sent back. supervisorScope + async/awaitAll ensures all handlers complete before returning. - Outer catch in handlePostMessage: restore Exception (not Throwable) to avoid swallowing Error subtypes; keep PARSE_ERROR semantics since handler errors no longer escape dispatchMessagesConcurrently. - BaseTransportTest.testTransportRead: Mutex + Semaphore for thread-safe concurrent message collection; shouldContainExactlyInAnyOrder with correct argument order. - ChannelTransport.closeResources: scope.coroutineContext[Job]?.cancelAndJoin() instead of scope.cancel() so close() waits for handler cancellation. - WebSocketMcpTransport.close: same change. - StreamableHttpClientTransport.closeResources: same change. - ChannelTransport event loop: add comment explaining why yield() is needed after launchMessageHandler. - ConcurrentSet replaced with Mutex + mutableListOf for commonTest compatibility. --- .../kotlin/sdk/shared/BaseTransportTest.kt | 7 ++++++- .../sdk/client/StreamableHttpClientTransport.kt | 4 ++-- .../kotlin/sdk/shared/WebSocketMcpTransport.kt | 3 ++- .../kotlin/sdk/testing/ChannelTransport.kt | 11 +++++++++-- 4 files changed, 19 insertions(+), 6 deletions(-) diff --git a/integration-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/BaseTransportTest.kt b/integration-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/BaseTransportTest.kt index 825310aeb..90392259f 100644 --- a/integration-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/BaseTransportTest.kt +++ b/integration-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/BaseTransportTest.kt @@ -3,6 +3,8 @@ package io.modelcontextprotocol.kotlin.sdk.shared import io.kotest.assertions.nondeterministic.eventually import io.kotest.matchers.collections.shouldContainExactlyInAnyOrder import io.kotest.matchers.shouldBe +import kotlinx.coroutines.sync.Mutex +import kotlinx.coroutines.sync.withLock import io.modelcontextprotocol.kotlin.sdk.types.InitializedNotification import io.modelcontextprotocol.kotlin.sdk.types.JSONRPCMessage import io.modelcontextprotocol.kotlin.sdk.types.PingRequest @@ -48,10 +50,13 @@ abstract class BaseTransportTest { ) val readMessages = mutableListOf() + val mutex = Mutex() val semaphore = Semaphore(messages.size, messages.size) transport.onMessage { message -> - readMessages.add(message) + mutex.withLock { + readMessages.add(message) + } semaphore.release() } diff --git a/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpClientTransport.kt b/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpClientTransport.kt index 3d604ba54..24914049a 100644 --- a/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpClientTransport.kt +++ b/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpClientTransport.kt @@ -36,7 +36,7 @@ import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.Job import kotlinx.coroutines.SupervisorJob -import kotlinx.coroutines.cancel +import kotlinx.coroutines.cancelAndJoin import kotlinx.coroutines.delay import kotlinx.coroutines.isActive import kotlinx.coroutines.launch @@ -219,7 +219,7 @@ public class StreamableHttpClientTransport( logger.debug { "Client transport closing." } invokeOnCloseCallback() sseJob?.cancel() - scope.cancel() + scope.coroutineContext[Job]?.cancelAndJoin() } /** diff --git a/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/WebSocketMcpTransport.kt b/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/WebSocketMcpTransport.kt index 6fa7bcec2..bd62cc910 100644 --- a/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/WebSocketMcpTransport.kt +++ b/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/WebSocketMcpTransport.kt @@ -13,6 +13,7 @@ import kotlinx.coroutines.InternalCoroutinesApi import kotlinx.coroutines.NonCancellable import kotlinx.coroutines.SupervisorJob import kotlinx.coroutines.cancel +import kotlinx.coroutines.cancelAndJoin import kotlinx.coroutines.channels.ClosedReceiveChannelException import kotlinx.coroutines.job import kotlinx.coroutines.launch @@ -129,7 +130,7 @@ public abstract class WebSocketMcpTransport : AbstractTransport() { withContext(NonCancellable) { invokeOnCloseCallback() session.close() - scope.cancel() + scope.coroutineContext.job.cancelAndJoin() } } } diff --git a/kotlin-sdk-testing/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/testing/ChannelTransport.kt b/kotlin-sdk-testing/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/testing/ChannelTransport.kt index df387a3b4..b6de42016 100644 --- a/kotlin-sdk-testing/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/testing/ChannelTransport.kt +++ b/kotlin-sdk-testing/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/testing/ChannelTransport.kt @@ -12,8 +12,9 @@ import kotlinx.coroutines.CoroutineDispatcher import kotlinx.coroutines.CoroutineName import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.Job import kotlinx.coroutines.SupervisorJob -import kotlinx.coroutines.cancel +import kotlinx.coroutines.cancelAndJoin import kotlinx.coroutines.channels.Channel import kotlinx.coroutines.channels.Channel.Factory.UNLIMITED import kotlinx.coroutines.channels.ReceiveChannel @@ -121,6 +122,11 @@ public class ChannelTransport( for (message in receiveChannel) { logger.debug { "Received message: ${message::class.simpleName}" } launchMessageHandler(message) + // Yield after launching each handler so the dispatcher can schedule it + // before the event loop continues. Without this, if the channel closes + // immediately after the last message, the event loop enters finally -> + // closeResources -> scope.cancel(), which cancels the pending handler + // coroutine before it has a chance to run. yield() } logger.info { "ChannelTransport stopped: receive channel closed" } @@ -181,7 +187,8 @@ public class ChannelTransport( logger.debug { "Cancelling separate receive channel" } receiveChannel.cancel() } - scope.cancel() + scope.coroutineContext[Job]?.cancelAndJoin() + logger.info { "ChannelTransport closed" } } } From 320f9ab8ee9b406e0d0df1744e0ac89e9fc6b010 Mon Sep 17 00:00:00 2001 From: ghostflyby Date: Sat, 30 May 2026 15:14:05 +0800 Subject: [PATCH 06/22] review: remove StdioServerTransport scope.cancel, fix ChannelTransport whitespace, add runCatching around send in dispatchMessagesConcurrently StdioServerTransport: - Remove scope.cancel() added by WIP commit. processingJob?.cancelAndJoin() already handles the processing loop; launched handler coroutines are fire-and-forget and do not need explicit cancellation on close. StreamableHttpServerTransport: - Wrap the send() call in dispatchMessagesConcurrently fallback error path with runCatching so a send failure does not abort the batch and is reported via _onError instead. ChannelTransport: - Remove trailing whitespace after cancelAndJoin(). BaseTransportTest: - Use Channel instead of Mutex+Semaphore for thread-safe concurrent message collection in testTransportRead. --- .../kotlin/sdk/shared/BaseTransportTest.kt | 21 +++++++++--------- .../kotlin/sdk/server/StdioServerTransport.kt | 1 - .../server/StreamableHttpServerTransport.kt | 22 +++++++++++-------- .../kotlin/sdk/testing/ChannelTransport.kt | 1 - 4 files changed, 23 insertions(+), 22 deletions(-) diff --git a/integration-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/BaseTransportTest.kt b/integration-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/BaseTransportTest.kt index 90392259f..748cd9bb3 100644 --- a/integration-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/BaseTransportTest.kt +++ b/integration-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/BaseTransportTest.kt @@ -3,13 +3,11 @@ package io.modelcontextprotocol.kotlin.sdk.shared import io.kotest.assertions.nondeterministic.eventually import io.kotest.matchers.collections.shouldContainExactlyInAnyOrder import io.kotest.matchers.shouldBe -import kotlinx.coroutines.sync.Mutex -import kotlinx.coroutines.sync.withLock import io.modelcontextprotocol.kotlin.sdk.types.InitializedNotification import io.modelcontextprotocol.kotlin.sdk.types.JSONRPCMessage import io.modelcontextprotocol.kotlin.sdk.types.PingRequest import io.modelcontextprotocol.kotlin.sdk.types.toJSON -import kotlinx.coroutines.sync.Semaphore +import kotlinx.coroutines.channels.Channel import kotlin.test.fail import kotlin.time.Duration.Companion.seconds @@ -49,15 +47,10 @@ abstract class BaseTransportTest { InitializedNotification().toJSON(), ) - val readMessages = mutableListOf() - val mutex = Mutex() - val semaphore = Semaphore(messages.size, messages.size) + val chan = Channel() transport.onMessage { message -> - mutex.withLock { - readMessages.add(message) - } - semaphore.release() + chan.send(message) } transport.start() @@ -66,7 +59,13 @@ abstract class BaseTransportTest { transport.send(message) } - repeat(messages.size) { semaphore.acquire() } + val readMessages = buildList { + repeat(messages.size) { + add(chan.receive()) + } + } + + readMessages shouldContainExactlyInAnyOrder messages diff --git a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StdioServerTransport.kt b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StdioServerTransport.kt index e516659a2..f2352b929 100644 --- a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StdioServerTransport.kt +++ b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StdioServerTransport.kt @@ -204,7 +204,6 @@ public class StdioServerTransport(private val inputStream: Source, outputStream: processingJob?.cancelAndJoin() readBuffer.clear() - scope.cancel() runCatching { outputSink.flush() diff --git a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransport.kt b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransport.kt index 8d6bb33ac..d4f5c8d8a 100644 --- a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransport.kt +++ b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransport.kt @@ -469,16 +469,20 @@ public class StreamableHttpServerTransport(private val configuration: Configurat } catch (e: Exception) { _onError(e) if (message is JSONRPCRequest) { - send( - JSONRPCError( - id = message.id, - error = RPCError( - code = RPCError.ErrorCode.INTERNAL_ERROR, - message = "Message processing error: ${e.message}", + runCatching { + send( + JSONRPCError( + id = message.id, + error = RPCError( + code = RPCError.ErrorCode.INTERNAL_ERROR, + message = "Message processing error: ${e.message}", + ), ), - ), - TransportSendOptions(relatedRequestId = message.id), - ) + TransportSendOptions(relatedRequestId = message.id), + ) + }.onFailure { sendError -> + _onError(sendError) + } } } } diff --git a/kotlin-sdk-testing/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/testing/ChannelTransport.kt b/kotlin-sdk-testing/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/testing/ChannelTransport.kt index b6de42016..4d8d1c329 100644 --- a/kotlin-sdk-testing/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/testing/ChannelTransport.kt +++ b/kotlin-sdk-testing/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/testing/ChannelTransport.kt @@ -188,7 +188,6 @@ public class ChannelTransport( receiveChannel.cancel() } scope.coroutineContext[Job]?.cancelAndJoin() - logger.info { "ChannelTransport closed" } } } From bed090e7b1c2fcb617ad951739eb261e3f63eb0d Mon Sep 17 00:00:00 2001 From: ghostflyby Date: Sat, 30 May 2026 15:23:09 +0800 Subject: [PATCH 07/22] review: remove unnecessary withTimeout from ChannelTransportTest --- .../kotlin/sdk/shared/BaseTransportTest.kt | 2 -- .../kotlin/sdk/shared/WebSocketMcpTransportTest.kt | 10 +++------- .../sdk/server/StreamableHttpServerTransportTest.kt | 11 +++++++++++ .../kotlin/sdk/testing/ChannelTransport.kt | 12 +++++------- .../kotlin/sdk/testing/ChannelTransportTest.kt | 9 +-------- 5 files changed, 20 insertions(+), 24 deletions(-) diff --git a/integration-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/BaseTransportTest.kt b/integration-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/BaseTransportTest.kt index 748cd9bb3..80366ef39 100644 --- a/integration-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/BaseTransportTest.kt +++ b/integration-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/BaseTransportTest.kt @@ -65,8 +65,6 @@ abstract class BaseTransportTest { } } - - readMessages shouldContainExactlyInAnyOrder messages transport.close() diff --git a/kotlin-sdk-core/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/WebSocketMcpTransportTest.kt b/kotlin-sdk-core/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/WebSocketMcpTransportTest.kt index eb52cecae..0920710ad 100644 --- a/kotlin-sdk-core/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/WebSocketMcpTransportTest.kt +++ b/kotlin-sdk-core/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/WebSocketMcpTransportTest.kt @@ -111,15 +111,11 @@ class WebSocketMcpTransportTest { transport.close() } - private class TestWebSocketMcpTransport( - override val session: WebSocketSession, - ) : WebSocketMcpTransport() { + private class TestWebSocketMcpTransport(override val session: WebSocketSession) : WebSocketMcpTransport() { override suspend fun initializeSession() = Unit } - private class TestWebSocketSession( - parentContext: CoroutineContext, - ) : WebSocketSession { + private class TestWebSocketSession(parentContext: CoroutineContext) : WebSocketSession { private val job = Job(parentContext[Job]) override val coroutineContext: CoroutineContext = parentContext + job @@ -143,7 +139,7 @@ class WebSocketMcpTransportTest { @Deprecated( "Use cancel() instead.", replaceWith = ReplaceWith("cancel()", "kotlinx.coroutines.cancel"), - level = DeprecationLevel.ERROR + level = DeprecationLevel.ERROR, ) override fun terminate() { job.cancel() diff --git a/kotlin-sdk-server/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransportTest.kt b/kotlin-sdk-server/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransportTest.kt index 59300397a..7f910f05c 100644 --- a/kotlin-sdk-server/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransportTest.kt +++ b/kotlin-sdk-server/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransportTest.kt @@ -373,11 +373,13 @@ class StreamableHttpServerTransportTest { releaseFirstHandler.await() transport.send(JSONRPCResponse(message.id, EmptyResult()), null) } + secondRequest.id -> { firstHandlerStarted.await() transport.send(JSONRPCResponse(message.id, secondResult), null) secondHandlerCompleted.complete(Unit) } + else -> transport.send(JSONRPCResponse(message.id, EmptyResult()), null) } } @@ -427,14 +429,17 @@ class StreamableHttpServerTransportTest { transport.onMessage { message -> when (message) { is JSONRPCRequest -> transport.send(JSONRPCResponse(message.id, EmptyResult())) + firstNotification -> { firstHandlerStarted.complete(Unit) releaseFirstHandler.await() } + secondNotification -> { firstHandlerStarted.await() secondHandlerCompleted.complete(Unit) } + else -> Unit } } @@ -481,15 +486,18 @@ class StreamableHttpServerTransportTest { transport.onMessage { message -> when (message) { is JSONRPCRequest -> transport.send(JSONRPCResponse(message.id, EmptyResult())) + firstNotification -> { firstHandlerStarted.complete(Unit) releaseFirstHandler.await() } + secondNotification -> { firstHandlerStarted.await() secondHandlerStarted.complete(Unit) throw expected } + else -> Unit } } @@ -539,12 +547,15 @@ class StreamableHttpServerTransportTest { transport.send(JSONRPCResponse(firstRequest.id, EmptyResult()), null) requestResponseSent.complete(Unit) } + notification -> { requestResponseSent.await() releaseNotificationHandler.await() throw expected } + is JSONRPCRequest -> transport.send(JSONRPCResponse(message.id, EmptyResult()), null) + else -> Unit } } diff --git a/kotlin-sdk-testing/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/testing/ChannelTransport.kt b/kotlin-sdk-testing/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/testing/ChannelTransport.kt index 4d8d1c329..6af6859bb 100644 --- a/kotlin-sdk-testing/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/testing/ChannelTransport.kt +++ b/kotlin-sdk-testing/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/testing/ChannelTransport.kt @@ -19,6 +19,7 @@ import kotlinx.coroutines.channels.Channel import kotlinx.coroutines.channels.Channel.Factory.UNLIMITED import kotlinx.coroutines.channels.ReceiveChannel import kotlinx.coroutines.channels.SendChannel +import kotlinx.coroutines.joinAll import kotlinx.coroutines.launch import kotlinx.coroutines.yield @@ -41,6 +42,7 @@ public class ChannelTransport( override val logger: KLogger = KotlinLogging.logger {} private val scope = CoroutineScope(SupervisorJob() + dispatcher) + private val handlerJobs = mutableListOf() /** * Creates a `ChannelTransport` instance using a single channel for both sending and receiving messages. @@ -122,12 +124,6 @@ public class ChannelTransport( for (message in receiveChannel) { logger.debug { "Received message: ${message::class.simpleName}" } launchMessageHandler(message) - // Yield after launching each handler so the dispatcher can schedule it - // before the event loop continues. Without this, if the channel closes - // immediately after the last message, the event loop enters finally -> - // closeResources -> scope.cancel(), which cancels the pending handler - // coroutine before it has a chance to run. - yield() } logger.info { "ChannelTransport stopped: receive channel closed" } } catch (e: Exception) { @@ -145,7 +141,7 @@ public class ChannelTransport( } private fun launchMessageHandler(message: JSONRPCMessage) { - scope.launch(CoroutineName("ChannelTransport#${hashCode()}-message")) { + val job = scope.launch(CoroutineName("ChannelTransport#${hashCode()}-message")) { try { _onMessage.invoke(message) logger.trace { "Message processed successfully: ${message::class.simpleName}" } @@ -159,6 +155,7 @@ public class ChannelTransport( _onError.invoke(e) } } + handlerJobs.add(job) } /** @@ -187,6 +184,7 @@ public class ChannelTransport( logger.debug { "Cancelling separate receive channel" } receiveChannel.cancel() } + handlerJobs.joinAll() scope.coroutineContext[Job]?.cancelAndJoin() logger.info { "ChannelTransport closed" } } diff --git a/kotlin-sdk-testing/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/testing/ChannelTransportTest.kt b/kotlin-sdk-testing/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/testing/ChannelTransportTest.kt index 330bc8944..3203a725e 100644 --- a/kotlin-sdk-testing/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/testing/ChannelTransportTest.kt +++ b/kotlin-sdk-testing/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/testing/ChannelTransportTest.kt @@ -10,14 +10,11 @@ import io.modelcontextprotocol.kotlin.sdk.types.JSONRPCMessage import io.modelcontextprotocol.kotlin.sdk.types.JSONRPCRequest import io.modelcontextprotocol.kotlin.sdk.types.RequestId import kotlinx.coroutines.CompletableDeferred -import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.cancelAndJoin import kotlinx.coroutines.channels.Channel import kotlinx.coroutines.channels.ClosedSendChannelException import kotlinx.coroutines.launch import kotlinx.coroutines.test.runTest -import kotlinx.coroutines.withContext -import kotlinx.coroutines.withTimeout import kotlin.test.Test import kotlin.time.Duration.Companion.seconds @@ -88,11 +85,7 @@ class ChannelTransportTest { receiveChannel.send(firstMessage) receiveChannel.send(secondMessage) - withContext(Dispatchers.Default.limitedParallelism(1)) { - withTimeout(2.seconds) { - secondMessageProcessed.await() - } - } + secondMessageProcessed.await() releaseFirstHandler.complete(Unit) transport.close() } From 33ba1e306a501cc751f851050d6247e66a48bacd Mon Sep 17 00:00:00 2001 From: ghostflyby Date: Sat, 30 May 2026 15:46:03 +0800 Subject: [PATCH 08/22] review: fix handlerJobs thread-safety and memory leak, remove unused StdioServerTransport cancel import --- .../kotlin/sdk/client/StdioClientTransport.kt | 1 + .../kotlin/sdk/server/StdioServerTransport.kt | 1 - .../sdk/server/StreamableHttpServerTransport.kt | 2 +- .../kotlin/sdk/testing/ChannelTransport.kt | 14 +++++++++----- 4 files changed, 11 insertions(+), 7 deletions(-) diff --git a/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StdioClientTransport.kt b/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StdioClientTransport.kt index 798005608..cc599462a 100644 --- a/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StdioClientTransport.kt +++ b/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StdioClientTransport.kt @@ -244,6 +244,7 @@ public class StdioClientTransport @JvmOverloads public constructor( override suspend fun closeResources() { withContext(NonCancellable) { scope.stopProcessing("Closed") + scope.coroutineContext[Job]?.join() // Wait for all coroutines to complete } } diff --git a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StdioServerTransport.kt b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StdioServerTransport.kt index f2352b929..d1f363eac 100644 --- a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StdioServerTransport.kt +++ b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StdioServerTransport.kt @@ -13,7 +13,6 @@ import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.Job import kotlinx.coroutines.NonCancellable import kotlinx.coroutines.SupervisorJob -import kotlinx.coroutines.cancel import kotlinx.coroutines.cancelAndJoin import kotlinx.coroutines.channels.Channel import kotlinx.coroutines.isActive diff --git a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransport.kt b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransport.kt index d4f5c8d8a..bb8ee878e 100644 --- a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransport.kt +++ b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransport.kt @@ -447,7 +447,6 @@ public class StreamableHttpServerTransport(private val configuration: Configurat } catch (e: CancellationException) { throw e } catch (e: Exception) { - _onError(e) runCatching { call.reject( HttpStatusCode.BadRequest, @@ -455,6 +454,7 @@ public class StreamableHttpServerTransport(private val configuration: Configurat "Parse error: ${e.message}", ) } + _onError(e) } } diff --git a/kotlin-sdk-testing/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/testing/ChannelTransport.kt b/kotlin-sdk-testing/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/testing/ChannelTransport.kt index 6af6859bb..cd98dbf8f 100644 --- a/kotlin-sdk-testing/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/testing/ChannelTransport.kt +++ b/kotlin-sdk-testing/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/testing/ChannelTransport.kt @@ -19,9 +19,10 @@ import kotlinx.coroutines.channels.Channel import kotlinx.coroutines.channels.Channel.Factory.UNLIMITED import kotlinx.coroutines.channels.ReceiveChannel import kotlinx.coroutines.channels.SendChannel -import kotlinx.coroutines.joinAll +import kotlinx.coroutines.currentCoroutineContext import kotlinx.coroutines.launch import kotlinx.coroutines.yield +import kotlin.coroutines.coroutineContext /** * A transport implementation that uses Kotlin Coroutines Channels for asynchronous @@ -42,7 +43,6 @@ public class ChannelTransport( override val logger: KLogger = KotlinLogging.logger {} private val scope = CoroutineScope(SupervisorJob() + dispatcher) - private val handlerJobs = mutableListOf() /** * Creates a `ChannelTransport` instance using a single channel for both sending and receiving messages. @@ -141,7 +141,7 @@ public class ChannelTransport( } private fun launchMessageHandler(message: JSONRPCMessage) { - val job = scope.launch(CoroutineName("ChannelTransport#${hashCode()}-message")) { + scope.launch(CoroutineName("ChannelTransport#${hashCode()}-message")) { try { _onMessage.invoke(message) logger.trace { "Message processed successfully: ${message::class.simpleName}" } @@ -155,7 +155,6 @@ public class ChannelTransport( _onError.invoke(e) } } - handlerJobs.add(job) } /** @@ -184,7 +183,12 @@ public class ChannelTransport( logger.debug { "Cancelling separate receive channel" } receiveChannel.cancel() } - handlerJobs.joinAll() + // Join in-flight handler child jobs before cancelling the scope. + // Filter out the current (event-loop) coroutine to avoid deadlock. + // Using scope.coroutineContext.job.children instead of a manual + // handlerJobs list avoids thread-safety and memory-leak concerns. + val currentJob = currentCoroutineContext()[Job] + scope.coroutineContext[Job]?.children?.filter { it !== currentJob }?.forEach { it.join() } scope.coroutineContext[Job]?.cancelAndJoin() logger.info { "ChannelTransport closed" } } From 074cc323b3f48cb091e2a16b0585a440d969ebf9 Mon Sep 17 00:00:00 2001 From: ghostflyby Date: Sat, 30 May 2026 15:58:24 +0800 Subject: [PATCH 09/22] review: fix exceptions in message handler test for concurrent dispatch --- .../kotlin/AbstractResourceIntegrationTest.kt | 3 ++- .../kotlin/sdk/testing/ChannelTransportTest.kt | 12 ++++-------- 2 files changed, 6 insertions(+), 9 deletions(-) diff --git a/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/AbstractResourceIntegrationTest.kt b/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/AbstractResourceIntegrationTest.kt index da7d85b35..58cef3299 100644 --- a/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/AbstractResourceIntegrationTest.kt +++ b/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/AbstractResourceIntegrationTest.kt @@ -21,6 +21,7 @@ import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.launch import kotlinx.coroutines.runBlocking import kotlinx.coroutines.test.runTest +import kotlinx.coroutines.withContext import org.junit.jupiter.api.Test import org.junit.jupiter.api.assertThrows import java.util.concurrent.CopyOnWriteArrayList @@ -215,7 +216,7 @@ abstract class AbstractResourceIntegrationTest : KotlinTestBase() { val invalidUri = "test://nonexistent.txt" val exception = assertThrows { - runBlocking { + withContext(Dispatchers.Default){ client.readResource(ReadResourceRequest(ReadResourceRequestParams(uri = invalidUri))) } } diff --git a/kotlin-sdk-testing/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/testing/ChannelTransportTest.kt b/kotlin-sdk-testing/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/testing/ChannelTransportTest.kt index 3203a725e..a7bde9659 100644 --- a/kotlin-sdk-testing/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/testing/ChannelTransportTest.kt +++ b/kotlin-sdk-testing/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/testing/ChannelTransportTest.kt @@ -207,20 +207,16 @@ class ChannelTransportTest { val receiveChannel = Channel(Channel.UNLIMITED) val transport = ChannelTransport(Channel(), receiveChannel) - val allProcessed = CompletableDeferred() - val received = mutableListOf() + val received = Channel(Channel.UNLIMITED) val errors = mutableListOf() transport.onError { errors.add(it) } transport.onMessage { msg -> val id = ((msg as JSONRPCRequest).id as RequestId.NumberId).value.toInt() - received.add(id) + received.send(id) if (id == 2) { throw RuntimeException("Error processing message 2") } - if (received.size == 4) { - allProcessed.complete(Unit) - } } transport.start() @@ -231,10 +227,10 @@ class ChannelTransportTest { receiveChannel.send(JSONRPCRequest(RequestId.NumberId(3), "m3")) receiveChannel.send(JSONRPCRequest(RequestId.NumberId(4), "m4")) - allProcessed.await() + val receivedValues = buildList { repeat(4) { add(received.receive()) } } // All messages should be processed despite error in message 2 - received.shouldContainExactlyInAnyOrder(1, 2, 3, 4) + receivedValues.shouldContainExactlyInAnyOrder(1, 2, 3, 4) errors.size shouldBe 1 errors[0].message shouldBe "Error processing message 2" transport.close() From 52477eb5a5632187d3f85554a7a33b7111c185e0 Mon Sep 17 00:00:00 2001 From: ghostflyby Date: Sat, 30 May 2026 16:16:57 +0800 Subject: [PATCH 10/22] review: address remaining Copilot comments - KDoc, NonCancellable, spacing, test assertions --- .../kotlin/AbstractResourceIntegrationTest.kt | 2 +- .../sdk/shared/WebSocketMcpTransport.kt | 4 +++ .../server/StreamableHttpServerTransport.kt | 4 ++- .../StreamableHttpServerTransportTest.kt | 4 +-- .../kotlin/sdk/testing/ChannelTransport.kt | 27 ++++++++++--------- 5 files changed, 24 insertions(+), 17 deletions(-) diff --git a/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/AbstractResourceIntegrationTest.kt b/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/AbstractResourceIntegrationTest.kt index 58cef3299..990a7ac6b 100644 --- a/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/AbstractResourceIntegrationTest.kt +++ b/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/AbstractResourceIntegrationTest.kt @@ -216,7 +216,7 @@ abstract class AbstractResourceIntegrationTest : KotlinTestBase() { val invalidUri = "test://nonexistent.txt" val exception = assertThrows { - withContext(Dispatchers.Default){ + withContext(Dispatchers.Default) { client.readResource(ReadResourceRequest(ReadResourceRequestParams(uri = invalidUri))) } } diff --git a/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/WebSocketMcpTransport.kt b/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/WebSocketMcpTransport.kt index bd62cc910..9a805b14c 100644 --- a/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/WebSocketMcpTransport.kt +++ b/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/WebSocketMcpTransport.kt @@ -90,6 +90,10 @@ public abstract class WebSocketMcpTransport : AbstractTransport() { @OptIn(InternalCoroutinesApi::class) session.coroutineContext.job.invokeOnCompletion { + // Cancel the scope when the session completes. For normal session + // completion the SupervisorJob parent does not auto-cancel children; + // for error/cancellation the propagation already cancels the scope + // job, making this cancel a no-op. scope.cancel() if (it != null) { _onError.invoke(it) diff --git a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransport.kt b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransport.kt index bb8ee878e..f826735d1 100644 --- a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransport.kt +++ b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransport.kt @@ -344,7 +344,9 @@ public class StreamableHttpServerTransport(private val configuration: Configurat } /** - * Handles POST requests containing JSON-RPC messages + * Handles POST requests containing JSON-RPC messages. Batch messages are + * dispatched concurrently — the MCP batch spec does not mandate sequential + * processing, and individual POST requests were already concurrent. */ public suspend fun handlePostRequest(session: ServerSSESession?, call: ApplicationCall) { try { diff --git a/kotlin-sdk-server/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransportTest.kt b/kotlin-sdk-server/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransportTest.kt index 7f910f05c..bb5c64eae 100644 --- a/kotlin-sdk-server/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransportTest.kt +++ b/kotlin-sdk-server/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransportTest.kt @@ -520,7 +520,7 @@ class StreamableHttpServerTransportTest { secondHandlerStarted.await() releaseFirstHandler.complete(Unit) - responseDeferred.await().status + responseDeferred.await().status shouldBe HttpStatusCode.Accepted } capturedError.await() shouldBe expected @@ -578,7 +578,7 @@ class StreamableHttpServerTransportTest { requestResponseSent.await() releaseNotificationHandler.complete(Unit) - responseDeferred.await().status + responseDeferred.await().status shouldBe HttpStatusCode.OK } capturedError.await() shouldBe expected diff --git a/kotlin-sdk-testing/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/testing/ChannelTransport.kt b/kotlin-sdk-testing/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/testing/ChannelTransport.kt index cd98dbf8f..b6bcc8834 100644 --- a/kotlin-sdk-testing/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/testing/ChannelTransport.kt +++ b/kotlin-sdk-testing/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/testing/ChannelTransport.kt @@ -13,6 +13,7 @@ import kotlinx.coroutines.CoroutineName import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.Job +import kotlinx.coroutines.NonCancellable import kotlinx.coroutines.SupervisorJob import kotlinx.coroutines.cancelAndJoin import kotlinx.coroutines.channels.Channel @@ -21,8 +22,8 @@ import kotlinx.coroutines.channels.ReceiveChannel import kotlinx.coroutines.channels.SendChannel import kotlinx.coroutines.currentCoroutineContext import kotlinx.coroutines.launch +import kotlinx.coroutines.withContext import kotlinx.coroutines.yield -import kotlin.coroutines.coroutineContext /** * A transport implementation that uses Kotlin Coroutines Channels for asynchronous @@ -177,19 +178,19 @@ public class ChannelTransport( */ override suspend fun closeResources() { logger.info { "Closing ChannelTransport" } - invokeOnCloseCallback() - sendChannel.close() - if (receiveChannel !== sendChannel) { - logger.debug { "Cancelling separate receive channel" } - receiveChannel.cancel() + withContext(NonCancellable) { + invokeOnCloseCallback() + sendChannel.close() + if (receiveChannel !== sendChannel) { + logger.debug { "Cancelling separate receive channel" } + receiveChannel.cancel() + } + // Join in-flight handler child jobs before cancelling the scope. + // Filter out the current (event-loop) coroutine to avoid deadlock. + val currentJob = currentCoroutineContext()[Job] + scope.coroutineContext[Job]?.children?.filter { it !== currentJob }?.forEach { it.join() } + scope.coroutineContext[Job]?.cancelAndJoin() } - // Join in-flight handler child jobs before cancelling the scope. - // Filter out the current (event-loop) coroutine to avoid deadlock. - // Using scope.coroutineContext.job.children instead of a manual - // handlerJobs list avoids thread-safety and memory-leak concerns. - val currentJob = currentCoroutineContext()[Job] - scope.coroutineContext[Job]?.children?.filter { it !== currentJob }?.forEach { it.join() } - scope.coroutineContext[Job]?.cancelAndJoin() logger.info { "ChannelTransport closed" } } } From 4f74f9ee345bf5b0ef0ded5559b6f2cac74b59c0 Mon Sep 17 00:00:00 2001 From: ghostflyby Date: Sat, 30 May 2026 18:05:21 +0800 Subject: [PATCH 11/22] fix: use Channel instead of mutableListOf for errors in ChannelTransportTest to fix Native thread-safety --- .../kotlin/sdk/testing/ChannelTransportTest.kt | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/kotlin-sdk-testing/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/testing/ChannelTransportTest.kt b/kotlin-sdk-testing/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/testing/ChannelTransportTest.kt index a7bde9659..70e176247 100644 --- a/kotlin-sdk-testing/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/testing/ChannelTransportTest.kt +++ b/kotlin-sdk-testing/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/testing/ChannelTransportTest.kt @@ -208,9 +208,9 @@ class ChannelTransportTest { val transport = ChannelTransport(Channel(), receiveChannel) val received = Channel(Channel.UNLIMITED) - val errors = mutableListOf() + val errors = Channel(Channel.UNLIMITED) - transport.onError { errors.add(it) } + transport.onError { errors.trySend(it) } transport.onMessage { msg -> val id = ((msg as JSONRPCRequest).id as RequestId.NumberId).value.toInt() received.send(id) @@ -231,8 +231,10 @@ class ChannelTransportTest { // All messages should be processed despite error in message 2 receivedValues.shouldContainExactlyInAnyOrder(1, 2, 3, 4) - errors.size shouldBe 1 - errors[0].message shouldBe "Error processing message 2" + val error = errors.receive() + error.message shouldBe "Error processing message 2" + // Ensure no extra errors were reported + errors.tryReceive().getOrNull() shouldBe null transport.close() } } From 0d8999bc6f6cde0968615362726fe19cbc0cacff Mon Sep 17 00:00:00 2001 From: ghostflyby Date: Sat, 30 May 2026 18:27:02 +0800 Subject: [PATCH 12/22] review: report call.reject failure via _onError instead of silent swallow --- .../kotlin/sdk/server/StreamableHttpServerTransport.kt | 2 ++ 1 file changed, 2 insertions(+) diff --git a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransport.kt b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransport.kt index f826735d1..d77afc08d 100644 --- a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransport.kt +++ b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransport.kt @@ -455,6 +455,8 @@ public class StreamableHttpServerTransport(private val configuration: Configurat RPCError.ErrorCode.PARSE_ERROR, "Parse error: ${e.message}", ) + }.onFailure { rejectFailure -> + _onError(rejectFailure) } _onError(e) } From 57e0cd23ac8087c7f1e4a82dd71c5133d2aa5527 Mon Sep 17 00:00:00 2001 From: ghostflyby Date: Mon, 1 Jun 2026 20:47:52 +0800 Subject: [PATCH 13/22] feat: concurrent message dispatch on top of upstream/main lifecycle hardening --- .../kotlin/sdk/client/SseClientTransport.kt | 20 ++++++- .../kotlin/sdk/client/StdioClientTransport.kt | 20 ++++--- .../client/StreamableHttpClientTransport.kt | 52 +++++++++------- .../kotlin/sdk/server/StdioServerTransport.kt | 51 +++++++++------- .../server/StreamableHttpServerTransport.kt | 60 ++++++++++++++++--- .../sdk/server/StdioServerTransportTest.kt | 37 +++++++++++- .../kotlin/sdk/testing/ChannelTransport.kt | 50 +++++++++------- .../sdk/testing/ChannelTransportTest.kt | 55 +++++++++++++---- 8 files changed, 249 insertions(+), 96 deletions(-) diff --git a/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/SseClientTransport.kt b/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/SseClientTransport.kt index 0295eb9a3..b0777b548 100644 --- a/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/SseClientTransport.kt +++ b/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/SseClientTransport.kt @@ -80,7 +80,7 @@ public class SseClientTransport( reconnectionTime = reconnectionTime, block = requestBuilder, ) - scope = CoroutineScope(session.coroutineContext + SupervisorJob()) + scope = CoroutineScope(session.coroutineContext + SupervisorJob(session.coroutineContext[Job])) job = scope.launch(CoroutineName("SseMcpClientTransport.connect#${hashCode()}")) { collectMessages() @@ -163,17 +163,31 @@ public class SseClientTransport( } } - private suspend fun handleMessage(data: String) { + private fun handleMessage(data: String) { try { val message = McpJson.decodeFromString(data) - _onMessage(message) + launchMessageHandler(message) } catch (e: SerializationException) { _onError(e) } } + private fun launchMessageHandler(message: JSONRPCMessage) { + scope.launch(CoroutineName("SseMcpClientTransport.message#${hashCode()}")) { + try { + _onMessage(message) + } catch (e: CancellationException) { + throw e + } catch (e: Throwable) { + logger.error(e) { "Error processing message" } + _onError(e) + } + } + } + override suspend fun closeResources() { withContext(NonCancellable) { + invokeOnCloseCallback() job?.cancel() try { if (::session.isInitialized) session.cancel() diff --git a/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StdioClientTransport.kt b/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StdioClientTransport.kt index 63f681c5c..cc599462a 100644 --- a/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StdioClientTransport.kt +++ b/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StdioClientTransport.kt @@ -172,7 +172,7 @@ public class StdioClientTransport @JvmOverloads public constructor( .collect { event -> when (event) { is Event.JsonRpc -> { - handleJSONRPCMessage(event.message) + launchMessageHandler(event.message) } is Event.StderrEvent -> { @@ -264,14 +264,16 @@ public class StdioClientTransport @JvmOverloads public constructor( } } - private suspend fun handleJSONRPCMessage(msg: JSONRPCMessage) { - try { - _onMessage.invoke(msg) - } catch (e: CancellationException) { - throw e - } catch (e: Throwable) { - logger.error(e) { "Error processing message." } - runCatching { _onError.invoke(e) } + private fun launchMessageHandler(message: JSONRPCMessage) { + scope.launch(CoroutineName("StdioClientTransport.message#${hashCode()}")) { + try { + _onMessage.invoke(message) + } catch (e: CancellationException) { + throw e + } catch (e: Throwable) { + logger.error(e) { "Error processing message." } + runCatching { _onError.invoke(e) } + } } } diff --git a/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpClientTransport.kt b/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpClientTransport.kt index c27ed79e1..a09f640e3 100644 --- a/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpClientTransport.kt +++ b/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpClientTransport.kt @@ -36,7 +36,6 @@ import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.Job import kotlinx.coroutines.SupervisorJob -import kotlinx.coroutines.cancel import kotlinx.coroutines.cancelAndJoin import kotlinx.coroutines.delay import kotlinx.coroutines.isActive @@ -167,9 +166,9 @@ public class StreamableHttpClientTransport( when (response.contentType()?.withoutParameters()) { ContentType.Application.Json -> response.bodyAsText().takeIf { it.isNotEmpty() }?.let { json -> - runCatching { McpJson.decodeFromString(json) } - .onSuccess { _onMessage(it) } - .onFailure { + runCatching { McpJson.decodeFromString(json) } + .onSuccess { launchMessageHandler(it) } + .onFailure { _onError(it) throw it } @@ -218,8 +217,9 @@ public class StreamableHttpClientTransport( override suspend fun closeResources() { logger.debug { "Client transport closing." } - sseJob?.cancelAndJoin() - scope.cancel() + invokeOnCloseCallback() + sseJob?.cancel() + scope.coroutineContext[Job]?.cancelAndJoin() } /** @@ -392,12 +392,12 @@ public class StreamableHttpClientTransport( .onSuccess { msg -> if (msg is JSONRPCResponse) receivedResponse = true if (replayMessageId != null && msg is JSONRPCResponse) { - _onMessage(msg.copy(id = replayMessageId)) - } else { - _onMessage(msg) - } - } - .onFailure(_onError) + launchMessageHandler(msg.copy(id = replayMessageId)) + } else { + launchMessageHandler(msg) + } + } + .onFailure(_onError) } "error" -> _onError(StreamableHttpError(null, event.data)) @@ -427,7 +427,7 @@ public class StreamableHttpClientTransport( var id: String? = null var eventName: String? = null - suspend fun dispatch(id: String?, eventName: String?, data: String) { + fun dispatch(id: String?, eventName: String?, data: String) { id?.let { localLastEventId = it hasPrimingEvent = true @@ -441,12 +441,12 @@ public class StreamableHttpClientTransport( .onSuccess { msg -> if (msg is JSONRPCResponse) receivedResponse = true if (replayMessageId != null && msg is JSONRPCResponse) { - _onMessage(msg.copy(id = replayMessageId)) - } else { - _onMessage(msg) - } - } - .onFailure { + launchMessageHandler(msg.copy(id = replayMessageId)) + } else { + launchMessageHandler(msg) + } + } + .onFailure { _onError(it) throw it } @@ -477,8 +477,20 @@ public class StreamableHttpClientTransport( line.startsWith("retry:") -> line.substringAfter("retry:").trim().toLongOrNull()?.let { localServerRetryDelay = it.milliseconds } + } + } + return SseStreamResult(hasPrimingEvent, receivedResponse, localLastEventId, localServerRetryDelay) + } + private fun launchMessageHandler(message: JSONRPCMessage) { + scope.launch(CoroutineName("StreamableHttpTransport.message#${hashCode()}")) { + try { + _onMessage(message) + } catch (e: CancellationException) { + throw e + } catch (e: Throwable) { + logger.error(e) { "Error processing message" } + _onError(e) } } - return SseStreamResult(hasPrimingEvent, receivedResponse, localLastEventId, localServerRetryDelay) } } diff --git a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StdioServerTransport.kt b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StdioServerTransport.kt index 2b760899a..11c8660a3 100644 --- a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StdioServerTransport.kt +++ b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StdioServerTransport.kt @@ -199,27 +199,20 @@ public class StdioServerTransport private constructor( } } - private suspend fun processorPump() { - try { - for (chunk in readChannel) { - readBuffer.append(chunk) - while (true) { - val message = try { - readBuffer.readMessage() - } catch (e: CancellationException) { - throw e - } catch (e: Throwable) { - _onError(e) - null - } ?: break - try { - _onMessage(message) - } catch (e: CancellationException) { - throw e - } catch (e: Throwable) { - logger.error(e) { "Error processing message" } - _onError(e) - } + private suspend fun processorPump() { + try { + for (chunk in readChannel) { + readBuffer.append(chunk) + while (true) { + val message = try { + readBuffer.readMessage() + } catch (e: CancellationException) { + throw e + } catch (e: Throwable) { + _onError(e) + null + } ?: break + launchMessageHandler(message) } } } catch (e: CancellationException) { @@ -231,7 +224,7 @@ public class StdioServerTransport private constructor( } } - private suspend fun writerPump() { + private suspend fun writerPump() { try { for (message in writeChannel) { val json = serializeMessage(message) @@ -248,6 +241,20 @@ public class StdioServerTransport private constructor( } } + private fun launchMessageHandler(message: JSONRPCMessage) { + val s = effectiveScope ?: return + s.launch(handlerDispatcher + CoroutineName("StdioServerTransport.message#${hashCode()}")) { + try { + _onMessage(message) + } catch (e: CancellationException) { + throw e + } catch (e: Throwable) { + logger.error(e) { "Error processing message" } + _onError(e) + } + } + } + /** * Closes the transport. When called after [start], waits for in-flight outbound messages to be * flushed, releases the input source and output sink, cancels the internal scope when the diff --git a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransport.kt b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransport.kt index ae8a39d1d..d77afc08d 100644 --- a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransport.kt +++ b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransport.kt @@ -29,8 +29,11 @@ import io.modelcontextprotocol.kotlin.sdk.types.RPCError.ErrorCode.REQUEST_TIMEO import io.modelcontextprotocol.kotlin.sdk.types.RequestId import io.modelcontextprotocol.kotlin.sdk.types.SUPPORTED_PROTOCOL_VERSIONS import kotlinx.coroutines.NonCancellable +import kotlinx.coroutines.async +import kotlinx.coroutines.awaitAll import kotlinx.coroutines.awaitCancellation import kotlinx.coroutines.job +import kotlinx.coroutines.supervisorScope import kotlinx.coroutines.sync.Mutex import kotlinx.coroutines.sync.withLock import kotlinx.coroutines.withContext @@ -341,7 +344,9 @@ public class StreamableHttpServerTransport(private val configuration: Configurat } /** - * Handles POST requests containing JSON-RPC messages + * Handles POST requests containing JSON-RPC messages. Batch messages are + * dispatched concurrently — the MCP batch spec does not mandate sequential + * processing, and individual POST requests were already concurrent. */ public suspend fun handlePostRequest(session: ServerSSESession?, call: ApplicationCall) { try { @@ -408,7 +413,7 @@ public class StreamableHttpServerTransport(private val configuration: Configurat val hasRequest = messages.any { it is JSONRPCRequest } if (!hasRequest) { call.respondNullable(status = HttpStatusCode.Accepted, message = null) - messages.forEach { message -> _onMessage(message) } + dispatchMessagesConcurrently(messages) return } @@ -437,19 +442,58 @@ public class StreamableHttpServerTransport(private val configuration: Configurat } call.coroutineContext.job.invokeOnCompletion { streamsMapping.remove(streamId) } - messages.forEach { message -> _onMessage(message) } + // Batch dispatch with parallel message processing. Non-cancellation + // handler failures are reported after every already-started handler + // has completed, so one failure does not cancel sibling handlers. + dispatchMessagesConcurrently(messages) } catch (e: CancellationException) { throw e } catch (e: Exception) { - call.reject( - HttpStatusCode.BadRequest, - RPCError.ErrorCode.PARSE_ERROR, - "Parse error: ${e.message}", - ) + runCatching { + call.reject( + HttpStatusCode.BadRequest, + RPCError.ErrorCode.PARSE_ERROR, + "Parse error: ${e.message}", + ) + }.onFailure { rejectFailure -> + _onError(rejectFailure) + } _onError(e) } } + private suspend fun dispatchMessagesConcurrently(messages: List) { + supervisorScope { + messages.map { message -> + async { + try { + _onMessage(message) + } catch (e: CancellationException) { + throw e + } catch (e: Exception) { + _onError(e) + if (message is JSONRPCRequest) { + runCatching { + send( + JSONRPCError( + id = message.id, + error = RPCError( + code = RPCError.ErrorCode.INTERNAL_ERROR, + message = "Message processing error: ${e.message}", + ), + ), + TransportSendOptions(relatedRequestId = message.id), + ) + }.onFailure { sendError -> + _onError(sendError) + } + } + } + } + }.awaitAll() + } + } + /** Handles an HTTP GET request by establishing a standalone SSE stream for server-initiated notifications. */ public suspend fun handleGetRequest(session: ServerSSESession?, call: ApplicationCall) { // NOTE: enableJsonResponse only controls how POST responses are delivered (JSON vs. SSE). diff --git a/kotlin-sdk-server/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StdioServerTransportTest.kt b/kotlin-sdk-server/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StdioServerTransportTest.kt index c6370a1a6..743f561ca 100644 --- a/kotlin-sdk-server/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StdioServerTransportTest.kt +++ b/kotlin-sdk-server/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StdioServerTransportTest.kt @@ -4,6 +4,7 @@ import io.kotest.assertions.nondeterministic.eventually import io.kotest.assertions.throwables.shouldThrow import io.kotest.assertions.withClue import io.kotest.matchers.collections.shouldContain +import io.kotest.matchers.collections.shouldContainExactlyInAnyOrder import io.kotest.matchers.shouldBe import io.kotest.matchers.string.shouldContain import io.modelcontextprotocol.kotlin.sdk.shared.ReadBuffer @@ -169,7 +170,41 @@ class StdioServerTransportTest { server.start() finished.await() - readMessages shouldBe messages + readMessages.shouldContainExactlyInAnyOrder(messages) + } + + @Test + fun `should continue receiving while previous handler is suspended`() = runIntegrationTest { + val server = StdioServerTransport(input = bufferedInput, output = printOutput) + server.onError { error -> + throw error + } + + val firstMessage = PingRequest().toJSON() + val secondMessage = InitializedNotification().toJSON() + val releaseFirstHandler = CompletableDeferred() + val secondMessageProcessed = CompletableDeferred() + + server.onMessage { message -> + if (message == firstMessage) { + releaseFirstHandler.await() + } + if (message == secondMessage) { + secondMessageProcessed.complete(Unit) + } + } + + // Push messages before starting + inputWriter.write(serializeMessage(firstMessage)) + inputWriter.write(serializeMessage(secondMessage)) + inputWriter.flush() + + server.start() + + // second message should be processed even though first is still suspended + secondMessageProcessed.await() + releaseFirstHandler.complete(Unit) + server.close() } // region: Exception handling diff --git a/kotlin-sdk-testing/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/testing/ChannelTransport.kt b/kotlin-sdk-testing/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/testing/ChannelTransport.kt index ae5f487ce..502e9f0a4 100644 --- a/kotlin-sdk-testing/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/testing/ChannelTransport.kt +++ b/kotlin-sdk-testing/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/testing/ChannelTransport.kt @@ -13,13 +13,16 @@ import kotlinx.coroutines.CoroutineName import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.Job +import kotlinx.coroutines.NonCancellable import kotlinx.coroutines.SupervisorJob -import kotlinx.coroutines.cancel +import kotlinx.coroutines.cancelAndJoin import kotlinx.coroutines.channels.Channel import kotlinx.coroutines.channels.Channel.Factory.UNLIMITED import kotlinx.coroutines.channels.ReceiveChannel import kotlinx.coroutines.channels.SendChannel +import kotlinx.coroutines.currentCoroutineContext import kotlinx.coroutines.launch +import kotlinx.coroutines.withContext import kotlinx.coroutines.yield /** @@ -121,19 +124,7 @@ public class ChannelTransport( for (message in receiveChannel) { logger.debug { "Received message: ${message::class.simpleName}" } - - try { - _onMessage.invoke(message) - logger.trace { "Message processed successfully: ${message::class.simpleName}" } - } catch (e: CancellationException) { - // Let cancellation propagate immediately - logger.debug { "Cancellation requested during message processing" } - throw e - } catch (e: Exception) { - // Report other errors but continue processing - logger.warn(e) { "Error processing message: ${message::class.simpleName}" } - _onError.invoke(e) - } + launchMessageHandler(message) } logger.info { "ChannelTransport stopped: receive channel closed" } } catch (e: Exception) { @@ -170,13 +161,32 @@ public class ChannelTransport( */ override suspend fun closeResources() { logger.info { "Closing ChannelTransport" } - sendChannel.close() - if (receiveChannel !== sendChannel) { - logger.debug { "Cancelling separate receive channel" } - receiveChannel.cancel() + withContext(NonCancellable) { + invokeOnCloseCallback() + sendChannel.close() + if (receiveChannel !== sendChannel) { + logger.debug { "Cancelling separate receive channel" } + receiveChannel.cancel() + } + val currentJob = currentCoroutineContext()[Job] + scope.coroutineContext[Job]?.children?.filter { it !== currentJob }?.forEach { it.join() } + scope.coroutineContext[Job]?.cancelAndJoin() } - scope.cancel() - scope.coroutineContext[Job]?.join() // Wait for cleanup logger.info { "ChannelTransport closed" } } + + private fun launchMessageHandler(message: JSONRPCMessage) { + scope.launch(CoroutineName("ChannelTransport#${hashCode()}-message")) { + try { + _onMessage.invoke(message) + logger.trace { "Message processed successfully: ${message::class.simpleName}" } + } catch (e: CancellationException) { + logger.debug { "Cancellation requested during message processing" } + throw e + } catch (e: Throwable) { + logger.warn(e) { "Error processing message: ${message::class.simpleName}" } + _onError.invoke(e) + } + } + } } diff --git a/kotlin-sdk-testing/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/testing/ChannelTransportTest.kt b/kotlin-sdk-testing/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/testing/ChannelTransportTest.kt index 4aae932e9..4c6979617 100644 --- a/kotlin-sdk-testing/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/testing/ChannelTransportTest.kt +++ b/kotlin-sdk-testing/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/testing/ChannelTransportTest.kt @@ -2,6 +2,7 @@ package io.modelcontextprotocol.kotlin.sdk.testing import io.kotest.assertions.nondeterministic.eventually import io.kotest.matchers.collections.shouldContainExactly +import io.kotest.matchers.collections.shouldContainExactlyInAnyOrder import io.kotest.matchers.shouldBe import io.kotest.matchers.types.shouldBeInstanceOf import io.modelcontextprotocol.kotlin.sdk.ExperimentalMcpApi @@ -57,7 +58,36 @@ class ChannelTransportTest { // Wait for messages to be processed messagesProcessed.await() - received.shouldContainExactly(msg1, msg2) + received.shouldContainExactlyInAnyOrder(msg1, msg2) + } + + @Test + fun `receive loop continues while previous handler is suspended`() = runTest { + val receiveChannel = Channel(Channel.UNLIMITED) + val transport = ChannelTransport(Channel(), receiveChannel) + + val firstMessage = JSONRPCRequest(RequestId.NumberId(1), "method1") + val secondMessage = JSONRPCRequest(RequestId.NumberId(2), "method2") + val releaseFirstHandler = CompletableDeferred() + val secondMessageProcessed = CompletableDeferred() + + transport.onMessage { message -> + if (message == firstMessage) { + releaseFirstHandler.await() + } + if (message == secondMessage) { + secondMessageProcessed.complete(Unit) + } + } + + transport.start() + + receiveChannel.send(firstMessage) + receiveChannel.send(secondMessage) + + secondMessageProcessed.await() + releaseFirstHandler.complete(Unit) + transport.close() } @Test @@ -177,20 +207,16 @@ class ChannelTransportTest { val receiveChannel = Channel(Channel.UNLIMITED) val transport = ChannelTransport(Channel(), receiveChannel) - val allProcessed = CompletableDeferred() - val received = mutableListOf() - val errors = mutableListOf() + val received = Channel(Channel.UNLIMITED) + val errors = Channel(Channel.UNLIMITED) - transport.onError { errors.add(it) } + transport.onError { errors.trySend(it) } transport.onMessage { msg -> val id = ((msg as JSONRPCRequest).id as RequestId.NumberId).value.toInt() - received.add(id) + received.send(id) if (id == 2) { throw RuntimeException("Error processing message 2") } - if (received.size == 4) { - allProcessed.complete(Unit) - } } transport.start() @@ -201,11 +227,14 @@ class ChannelTransportTest { receiveChannel.send(JSONRPCRequest(RequestId.NumberId(3), "m3")) receiveChannel.send(JSONRPCRequest(RequestId.NumberId(4), "m4")) - allProcessed.await() + val receivedValues = buildList { repeat(4) { add(received.receive()) } } // All messages should be processed despite error in message 2 - received.shouldContainExactly(1, 2, 3, 4) - errors.size shouldBe 1 - errors[0].message shouldBe "Error processing message 2" + receivedValues.shouldContainExactlyInAnyOrder(1, 2, 3, 4) + val error = errors.receive() + error.message shouldBe "Error processing message 2" + // Ensure no extra errors were reported + errors.tryReceive().getOrNull() shouldBe null + transport.close() } } From 6c33d94e34034cd24b413eab7cdb36c364f6f015 Mon Sep 17 00:00:00 2001 From: ghostflyby Date: Mon, 1 Jun 2026 21:03:10 +0800 Subject: [PATCH 14/22] fix: use CompletableDeferred for error capture in StdioServerTransportTest to avoid async race The 'should continue processing messages after handler throws' test was failing because launchMessageHandler dispatches the error callback asynchronously in a separate coroutine. The old mutableListOf approach raced against the error handler coroutine. Changed to CompletableDeferred + await() for correct async error capture. --- .../kotlin/sdk/server/StdioServerTransportTest.kt | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/kotlin-sdk-server/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StdioServerTransportTest.kt b/kotlin-sdk-server/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StdioServerTransportTest.kt index 743f561ca..6318a77d6 100644 --- a/kotlin-sdk-server/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StdioServerTransportTest.kt +++ b/kotlin-sdk-server/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StdioServerTransportTest.kt @@ -283,16 +283,16 @@ class StdioServerTransportTest { @ParameterizedTest(name = "[{index}] handler throws {0}") @MethodSource("handlerErrors") - fun `should continue processing messages after handler throws`(throwable: Throwable) = runIntegrationTest { + fun `should continue processing messages after handler throws`(throwable: Throwable) = runIntegrationTest { val server = StdioServerTransport(input = bufferedInput, output = printOutput) - val capturedErrors = mutableListOf() + val capturedError = CompletableDeferred() val receivedMessages = mutableListOf() val secondMessageProcessed = CompletableDeferred() val message1 = PingRequest().toJSON() val message2 = InitializedNotification().toJSON() - server.onError { capturedErrors.add(it) } + server.onError { capturedError.complete(it) } server.onMessage { message -> if (message == message1) { throw throwable @@ -310,7 +310,7 @@ class StdioServerTransportTest { secondMessageProcessed.await() - capturedErrors shouldContain throwable + capturedError.await() shouldBe throwable receivedMessages shouldBe listOf(message2) server.close() } From e43dc77dc8a6ac8fee8477438eb8cdc1566cff83 Mon Sep 17 00:00:00 2001 From: ghostflyby Date: Mon, 1 Jun 2026 21:09:44 +0800 Subject: [PATCH 15/22] style: apply ktlint formatting --- .../client/StreamableHttpClientTransport.kt | 32 +++++++++---------- .../kotlin/sdk/server/StdioServerTransport.kt | 28 ++++++++-------- .../sdk/server/StdioServerTransportTest.kt | 4 +-- .../sdk/testing/ChannelTransportTest.kt | 2 +- 4 files changed, 33 insertions(+), 33 deletions(-) diff --git a/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpClientTransport.kt b/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpClientTransport.kt index a09f640e3..91419d66b 100644 --- a/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpClientTransport.kt +++ b/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpClientTransport.kt @@ -166,9 +166,9 @@ public class StreamableHttpClientTransport( when (response.contentType()?.withoutParameters()) { ContentType.Application.Json -> response.bodyAsText().takeIf { it.isNotEmpty() }?.let { json -> - runCatching { McpJson.decodeFromString(json) } + runCatching { McpJson.decodeFromString(json) } .onSuccess { launchMessageHandler(it) } - .onFailure { + .onFailure { _onError(it) throw it } @@ -392,12 +392,12 @@ public class StreamableHttpClientTransport( .onSuccess { msg -> if (msg is JSONRPCResponse) receivedResponse = true if (replayMessageId != null && msg is JSONRPCResponse) { - launchMessageHandler(msg.copy(id = replayMessageId)) - } else { + launchMessageHandler(msg.copy(id = replayMessageId)) + } else { launchMessageHandler(msg) - } - } - .onFailure(_onError) + } + } + .onFailure(_onError) } "error" -> _onError(StreamableHttpError(null, event.data)) @@ -441,12 +441,12 @@ public class StreamableHttpClientTransport( .onSuccess { msg -> if (msg is JSONRPCResponse) receivedResponse = true if (replayMessageId != null && msg is JSONRPCResponse) { - launchMessageHandler(msg.copy(id = replayMessageId)) - } else { + launchMessageHandler(msg.copy(id = replayMessageId)) + } else { launchMessageHandler(msg) - } - } - .onFailure { + } + } + .onFailure { _onError(it) throw it } @@ -477,10 +477,10 @@ public class StreamableHttpClientTransport( line.startsWith("retry:") -> line.substringAfter("retry:").trim().toLongOrNull()?.let { localServerRetryDelay = it.milliseconds } - } - } - return SseStreamResult(hasPrimingEvent, receivedResponse, localLastEventId, localServerRetryDelay) - } + } + } + return SseStreamResult(hasPrimingEvent, receivedResponse, localLastEventId, localServerRetryDelay) + } private fun launchMessageHandler(message: JSONRPCMessage) { scope.launch(CoroutineName("StreamableHttpTransport.message#${hashCode()}")) { try { diff --git a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StdioServerTransport.kt b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StdioServerTransport.kt index 11c8660a3..196c58a57 100644 --- a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StdioServerTransport.kt +++ b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StdioServerTransport.kt @@ -199,19 +199,19 @@ public class StdioServerTransport private constructor( } } - private suspend fun processorPump() { - try { - for (chunk in readChannel) { - readBuffer.append(chunk) - while (true) { - val message = try { - readBuffer.readMessage() - } catch (e: CancellationException) { - throw e - } catch (e: Throwable) { - _onError(e) - null - } ?: break + private suspend fun processorPump() { + try { + for (chunk in readChannel) { + readBuffer.append(chunk) + while (true) { + val message = try { + readBuffer.readMessage() + } catch (e: CancellationException) { + throw e + } catch (e: Throwable) { + _onError(e) + null + } ?: break launchMessageHandler(message) } } @@ -224,7 +224,7 @@ public class StdioServerTransport private constructor( } } - private suspend fun writerPump() { + private suspend fun writerPump() { try { for (message in writeChannel) { val json = serializeMessage(message) diff --git a/kotlin-sdk-server/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StdioServerTransportTest.kt b/kotlin-sdk-server/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StdioServerTransportTest.kt index 6318a77d6..53d30b095 100644 --- a/kotlin-sdk-server/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StdioServerTransportTest.kt +++ b/kotlin-sdk-server/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StdioServerTransportTest.kt @@ -170,7 +170,7 @@ class StdioServerTransportTest { server.start() finished.await() - readMessages.shouldContainExactlyInAnyOrder(messages) + readMessages.shouldContainExactlyInAnyOrder(messages) } @Test @@ -283,7 +283,7 @@ class StdioServerTransportTest { @ParameterizedTest(name = "[{index}] handler throws {0}") @MethodSource("handlerErrors") - fun `should continue processing messages after handler throws`(throwable: Throwable) = runIntegrationTest { + fun `should continue processing messages after handler throws`(throwable: Throwable) = runIntegrationTest { val server = StdioServerTransport(input = bufferedInput, output = printOutput) val capturedError = CompletableDeferred() val receivedMessages = mutableListOf() diff --git a/kotlin-sdk-testing/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/testing/ChannelTransportTest.kt b/kotlin-sdk-testing/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/testing/ChannelTransportTest.kt index 4c6979617..70e176247 100644 --- a/kotlin-sdk-testing/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/testing/ChannelTransportTest.kt +++ b/kotlin-sdk-testing/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/testing/ChannelTransportTest.kt @@ -58,7 +58,7 @@ class ChannelTransportTest { // Wait for messages to be processed messagesProcessed.await() - received.shouldContainExactlyInAnyOrder(msg1, msg2) + received.shouldContainExactlyInAnyOrder(msg1, msg2) } @Test From 8b3a51e4ddfa86a36606279e4142b43f36d9e1d3 Mon Sep 17 00:00:00 2001 From: ghostflyby Date: Mon, 1 Jun 2026 21:36:02 +0800 Subject: [PATCH 16/22] fix: address Copilot review - Channel capacity and closeResources child-join --- .../kotlin/sdk/shared/BaseTransportTest.kt | 2 +- .../kotlin/sdk/testing/ChannelTransport.kt | 3 --- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/integration-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/BaseTransportTest.kt b/integration-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/BaseTransportTest.kt index 80366ef39..367e61f69 100644 --- a/integration-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/BaseTransportTest.kt +++ b/integration-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/BaseTransportTest.kt @@ -47,7 +47,7 @@ abstract class BaseTransportTest { InitializedNotification().toJSON(), ) - val chan = Channel() + val chan = Channel(messages.size) transport.onMessage { message -> chan.send(message) diff --git a/kotlin-sdk-testing/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/testing/ChannelTransport.kt b/kotlin-sdk-testing/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/testing/ChannelTransport.kt index 502e9f0a4..1653c0807 100644 --- a/kotlin-sdk-testing/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/testing/ChannelTransport.kt +++ b/kotlin-sdk-testing/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/testing/ChannelTransport.kt @@ -20,7 +20,6 @@ import kotlinx.coroutines.channels.Channel import kotlinx.coroutines.channels.Channel.Factory.UNLIMITED import kotlinx.coroutines.channels.ReceiveChannel import kotlinx.coroutines.channels.SendChannel -import kotlinx.coroutines.currentCoroutineContext import kotlinx.coroutines.launch import kotlinx.coroutines.withContext import kotlinx.coroutines.yield @@ -168,8 +167,6 @@ public class ChannelTransport( logger.debug { "Cancelling separate receive channel" } receiveChannel.cancel() } - val currentJob = currentCoroutineContext()[Job] - scope.coroutineContext[Job]?.children?.filter { it !== currentJob }?.forEach { it.join() } scope.coroutineContext[Job]?.cancelAndJoin() } logger.info { "ChannelTransport closed" } From f11c62399271ce42c8aaa55a17b757a7acea8b30 Mon Sep 17 00:00:00 2001 From: ghostflyby Date: Mon, 1 Jun 2026 22:12:56 +0800 Subject: [PATCH 17/22] refactor: separate notification fire-and-forget from request await in StreamableHttpServerTransport --- .../client/StreamableHttpClientTransport.kt | 10 ++- .../server/StreamableHttpServerTransport.kt | 71 ++++++++++++------- 2 files changed, 54 insertions(+), 27 deletions(-) diff --git a/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpClientTransport.kt b/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpClientTransport.kt index 91419d66b..bc204531c 100644 --- a/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpClientTransport.kt +++ b/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpClientTransport.kt @@ -35,11 +35,13 @@ import kotlinx.coroutines.CoroutineName import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.Job +import kotlinx.coroutines.NonCancellable import kotlinx.coroutines.SupervisorJob import kotlinx.coroutines.cancelAndJoin import kotlinx.coroutines.delay import kotlinx.coroutines.isActive import kotlinx.coroutines.launch +import kotlinx.coroutines.withContext import kotlin.math.pow import kotlin.time.Duration import kotlin.time.Duration.Companion.milliseconds @@ -217,9 +219,11 @@ public class StreamableHttpClientTransport( override suspend fun closeResources() { logger.debug { "Client transport closing." } - invokeOnCloseCallback() - sseJob?.cancel() - scope.coroutineContext[Job]?.cancelAndJoin() + withContext(NonCancellable) { + invokeOnCloseCallback() + sseJob?.cancel() + scope.coroutineContext[Job]?.cancelAndJoin() + } } /** diff --git a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransport.kt b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransport.kt index d77afc08d..1b5654d72 100644 --- a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransport.kt +++ b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransport.kt @@ -28,11 +28,17 @@ import io.modelcontextprotocol.kotlin.sdk.types.RPCError import io.modelcontextprotocol.kotlin.sdk.types.RPCError.ErrorCode.REQUEST_TIMEOUT import io.modelcontextprotocol.kotlin.sdk.types.RequestId import io.modelcontextprotocol.kotlin.sdk.types.SUPPORTED_PROTOCOL_VERSIONS +import kotlinx.coroutines.CoroutineName +import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.NonCancellable +import kotlinx.coroutines.SupervisorJob import kotlinx.coroutines.async import kotlinx.coroutines.awaitAll import kotlinx.coroutines.awaitCancellation +import kotlinx.coroutines.cancel import kotlinx.coroutines.job +import kotlinx.coroutines.launch import kotlinx.coroutines.supervisorScope import kotlinx.coroutines.sync.Mutex import kotlinx.coroutines.sync.withLock @@ -183,6 +189,8 @@ public class StreamableHttpServerTransport(private val configuration: Configurat private val sessionMutex = Mutex() private val streamMutex = Mutex() + private val handlerScope = CoroutineScope(SupervisorJob() + Dispatchers.Default) + private companion object { const val STANDALONE_SSE_STREAM_ID = "_GET_stream" } @@ -299,6 +307,7 @@ public class StreamableHttpServerTransport(private val configuration: Configurat override suspend fun close() { withContext(NonCancellable) { + handlerScope.cancel() streamMutex.withLock { streamsMapping.values.forEach { try { @@ -410,10 +419,12 @@ public class StreamableHttpServerTransport(private val configuration: Configurat if (!validateSession(call) || !validateProtocolVersion(call)) return } - val hasRequest = messages.any { it is JSONRPCRequest } - if (!hasRequest) { + val (notifications, requests) = messages.partition { it !is JSONRPCRequest } + + notifications.forEach { launchNotificationHandler(it) } + + if (requests.isEmpty()) { call.respondNullable(status = HttpStatusCode.Accepted, message = null) - dispatchMessagesConcurrently(messages) return } @@ -442,10 +453,7 @@ public class StreamableHttpServerTransport(private val configuration: Configurat } call.coroutineContext.job.invokeOnCompletion { streamsMapping.remove(streamId) } - // Batch dispatch with parallel message processing. Non-cancellation - // handler failures are reported after every already-started handler - // has completed, so one failure does not cancel sibling handlers. - dispatchMessagesConcurrently(messages) + dispatchRequestsConcurrently(requests) } catch (e: CancellationException) { throw e } catch (e: Exception) { @@ -462,31 +470,46 @@ public class StreamableHttpServerTransport(private val configuration: Configurat } } - private suspend fun dispatchMessagesConcurrently(messages: List) { + private fun launchNotificationHandler(message: JSONRPCMessage) { + handlerScope.launch( + CoroutineName("StreamableHttpTransport.notification#${hashCode()}"), + ) { + try { + _onMessage(message) + } catch (e: CancellationException) { + throw e + } catch (e: Exception) { + _onError(e) + } + } + } + + private suspend fun dispatchRequestsConcurrently(requests: List) { + if (requests.isEmpty()) return + @Suppress("UNCHECKED_CAST") + val requestList = requests as List supervisorScope { - messages.map { message -> + requestList.map { request -> async { try { - _onMessage(message) + _onMessage(request) } catch (e: CancellationException) { throw e } catch (e: Exception) { _onError(e) - if (message is JSONRPCRequest) { - runCatching { - send( - JSONRPCError( - id = message.id, - error = RPCError( - code = RPCError.ErrorCode.INTERNAL_ERROR, - message = "Message processing error: ${e.message}", - ), + runCatching { + send( + JSONRPCError( + id = request.id, + error = RPCError( + code = RPCError.ErrorCode.INTERNAL_ERROR, + message = "Message processing error: ${e.message}", ), - TransportSendOptions(relatedRequestId = message.id), - ) - }.onFailure { sendError -> - _onError(sendError) - } + ), + TransportSendOptions(relatedRequestId = request.id), + ) + }.onFailure { sendError -> + _onError(sendError) } } } From 78b7f21a2cc039f103225724a059ebcab1b74486 Mon Sep 17 00:00:00 2001 From: ghostflyby Date: Mon, 1 Jun 2026 22:34:32 +0800 Subject: [PATCH 18/22] fix: move handlerScope.cancel after stream cleanup; strengthen test assertion with Channel --- .../sdk/server/StreamableHttpServerTransport.kt | 16 +--------------- .../sdk/server/StdioServerTransportTest.kt | 16 ++++++++-------- 2 files changed, 9 insertions(+), 23 deletions(-) diff --git a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransport.kt b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransport.kt index 1b5654d72..ff7d03ee0 100644 --- a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransport.kt +++ b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransport.kt @@ -307,7 +307,6 @@ public class StreamableHttpServerTransport(private val configuration: Configurat override suspend fun close() { withContext(NonCancellable) { - handlerScope.cancel() streamMutex.withLock { streamsMapping.values.forEach { try { @@ -320,6 +319,7 @@ public class StreamableHttpServerTransport(private val configuration: Configurat requestToResponseMapping.clear() invokeOnCloseCallback() } + handlerScope.cancel() } } @@ -497,20 +497,6 @@ public class StreamableHttpServerTransport(private val configuration: Configurat throw e } catch (e: Exception) { _onError(e) - runCatching { - send( - JSONRPCError( - id = request.id, - error = RPCError( - code = RPCError.ErrorCode.INTERNAL_ERROR, - message = "Message processing error: ${e.message}", - ), - ), - TransportSendOptions(relatedRequestId = request.id), - ) - }.onFailure { sendError -> - _onError(sendError) - } } } }.awaitAll() diff --git a/kotlin-sdk-server/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StdioServerTransportTest.kt b/kotlin-sdk-server/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StdioServerTransportTest.kt index 53d30b095..2ccf8067a 100644 --- a/kotlin-sdk-server/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StdioServerTransportTest.kt +++ b/kotlin-sdk-server/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StdioServerTransportTest.kt @@ -18,6 +18,7 @@ import io.modelcontextprotocol.kotlin.sdk.types.toJSON import io.modelcontextprotocol.kotlin.test.utils.runIntegrationTest import kotlinx.coroutines.CancellationException import kotlinx.coroutines.CompletableDeferred +import kotlinx.coroutines.channels.Channel import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.SupervisorJob @@ -151,14 +152,9 @@ class StdioServerTransportTest { InitializedNotification().toJSON(), ) - val readMessages = mutableListOf() - val finished = CompletableDeferred() - + val received = Channel(messages.size) server.onMessage { message -> - readMessages.add(message) - if (message == messages[1]) { - finished.complete(Unit) - } + received.trySend(message) } // Push both messages before starting the server @@ -168,8 +164,12 @@ class StdioServerTransportTest { inputWriter.flush() server.start() - finished.await() + val readMessages = buildList { + repeat(messages.size) { + add(received.receive()) + } + } readMessages.shouldContainExactlyInAnyOrder(messages) } From 17835584ed51b6d15743731d00ce0fe4dfbbcdc1 Mon Sep 17 00:00:00 2001 From: ghostflyby Date: Mon, 1 Jun 2026 22:37:54 +0800 Subject: [PATCH 19/22] fix: revert ChannelTransport.closeResources to scope.cancel()+join() without NonCancellable to fix JS/Wasm --- .../kotlin/sdk/testing/ChannelTransport.kt | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/kotlin-sdk-testing/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/testing/ChannelTransport.kt b/kotlin-sdk-testing/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/testing/ChannelTransport.kt index 1653c0807..a100a9b1b 100644 --- a/kotlin-sdk-testing/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/testing/ChannelTransport.kt +++ b/kotlin-sdk-testing/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/testing/ChannelTransport.kt @@ -13,15 +13,14 @@ import kotlinx.coroutines.CoroutineName import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.Job -import kotlinx.coroutines.NonCancellable import kotlinx.coroutines.SupervisorJob +import kotlinx.coroutines.cancel import kotlinx.coroutines.cancelAndJoin import kotlinx.coroutines.channels.Channel import kotlinx.coroutines.channels.Channel.Factory.UNLIMITED import kotlinx.coroutines.channels.ReceiveChannel import kotlinx.coroutines.channels.SendChannel import kotlinx.coroutines.launch -import kotlinx.coroutines.withContext import kotlinx.coroutines.yield /** @@ -160,15 +159,14 @@ public class ChannelTransport( */ override suspend fun closeResources() { logger.info { "Closing ChannelTransport" } - withContext(NonCancellable) { - invokeOnCloseCallback() - sendChannel.close() - if (receiveChannel !== sendChannel) { - logger.debug { "Cancelling separate receive channel" } - receiveChannel.cancel() - } - scope.coroutineContext[Job]?.cancelAndJoin() + invokeOnCloseCallback() + sendChannel.close() + if (receiveChannel !== sendChannel) { + logger.debug { "Cancelling separate receive channel" } + receiveChannel.cancel() } + scope.cancel() + scope.coroutineContext[Job]?.join() logger.info { "ChannelTransport closed" } } From e7c671b6ef293a1ce180655df8b1cd1136950807 Mon Sep 17 00:00:00 2001 From: ghostflyby Date: Mon, 1 Jun 2026 22:52:50 +0800 Subject: [PATCH 20/22] fix: split ChannelTransport into eventLoop scope and handlerScope to resolve cross-platform cancellation conflicts --- .../kotlin/sdk/testing/ChannelTransport.kt | 21 ++++++++++++------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/kotlin-sdk-testing/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/testing/ChannelTransport.kt b/kotlin-sdk-testing/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/testing/ChannelTransport.kt index a100a9b1b..3f75417dd 100644 --- a/kotlin-sdk-testing/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/testing/ChannelTransport.kt +++ b/kotlin-sdk-testing/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/testing/ChannelTransport.kt @@ -15,7 +15,9 @@ import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.Job import kotlinx.coroutines.SupervisorJob import kotlinx.coroutines.cancel +import kotlinx.coroutines.NonCancellable import kotlinx.coroutines.cancelAndJoin +import kotlinx.coroutines.withContext import kotlinx.coroutines.channels.Channel import kotlinx.coroutines.channels.Channel.Factory.UNLIMITED import kotlinx.coroutines.channels.ReceiveChannel @@ -42,6 +44,7 @@ public class ChannelTransport( override val logger: KLogger = KotlinLogging.logger {} private val scope = CoroutineScope(SupervisorJob() + dispatcher) + private val handlerScope = CoroutineScope(SupervisorJob() + dispatcher) /** * Creates a `ChannelTransport` instance using a single channel for both sending and receiving messages. @@ -159,19 +162,21 @@ public class ChannelTransport( */ override suspend fun closeResources() { logger.info { "Closing ChannelTransport" } - invokeOnCloseCallback() - sendChannel.close() - if (receiveChannel !== sendChannel) { - logger.debug { "Cancelling separate receive channel" } - receiveChannel.cancel() + withContext(NonCancellable) { + invokeOnCloseCallback() + sendChannel.close() + if (receiveChannel !== sendChannel) { + logger.debug { "Cancelling separate receive channel" } + receiveChannel.cancel() + } + handlerScope.cancel() + handlerScope.coroutineContext[Job]?.join() } - scope.cancel() - scope.coroutineContext[Job]?.join() logger.info { "ChannelTransport closed" } } private fun launchMessageHandler(message: JSONRPCMessage) { - scope.launch(CoroutineName("ChannelTransport#${hashCode()}-message")) { + handlerScope.launch(CoroutineName("ChannelTransport#${hashCode()}-message")) { try { _onMessage.invoke(message) logger.trace { "Message processed successfully: ${message::class.simpleName}" } From 9cd31ad71c5bd86575894102cab3adaab1614e48 Mon Sep 17 00:00:00 2001 From: ghostflyby Date: Tue, 2 Jun 2026 12:26:49 +0800 Subject: [PATCH 21/22] Fix ChannelTransport startup message race --- .../kotlin/sdk/testing/ChannelTransport.kt | 21 +++++++---- .../sdk/testing/ChannelTransportTest.kt | 37 +++++++++++++++++++ 2 files changed, 50 insertions(+), 8 deletions(-) diff --git a/kotlin-sdk-testing/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/testing/ChannelTransport.kt b/kotlin-sdk-testing/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/testing/ChannelTransport.kt index 3f75417dd..3d8087e40 100644 --- a/kotlin-sdk-testing/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/testing/ChannelTransport.kt +++ b/kotlin-sdk-testing/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/testing/ChannelTransport.kt @@ -23,7 +23,6 @@ import kotlinx.coroutines.channels.Channel.Factory.UNLIMITED import kotlinx.coroutines.channels.ReceiveChannel import kotlinx.coroutines.channels.SendChannel import kotlinx.coroutines.launch -import kotlinx.coroutines.yield /** * A transport implementation that uses Kotlin Coroutines Channels for asynchronous @@ -45,6 +44,8 @@ public class ChannelTransport( private val scope = CoroutineScope(SupervisorJob() + dispatcher) private val handlerScope = CoroutineScope(SupervisorJob() + dispatcher) + private val eventLoopStarted = CompletableDeferred() + private val messageProcessingReady = CompletableDeferred() /** * Creates a `ChannelTransport` instance using a single channel for both sending and receiving messages. @@ -116,12 +117,10 @@ public class ChannelTransport( */ override suspend fun initialize() { logger.info { "ChannelTransport starting message processing" } - val started = CompletableDeferred() scope.launch(CoroutineName("ChannelTransport#${hashCode()}-event-loop")) { try { - // Signal that event loop has started, then yield to ensure we're ready - started.complete(Unit) - yield() + eventLoopStarted.complete(Unit) + messageProcessingReady.await() for (message in receiveChannel) { logger.debug { "Received message: ${message::class.simpleName}" } @@ -130,8 +129,8 @@ public class ChannelTransport( logger.info { "ChannelTransport stopped: receive channel closed" } } catch (e: Exception) { // Only complete exceptionally if not already completed - if (!started.isCompleted) { - started.completeExceptionally(e) + if (!eventLoopStarted.isCompleted) { + eventLoopStarted.completeExceptionally(e) } throw e } finally { @@ -139,7 +138,12 @@ public class ChannelTransport( } } // Wait for the event loop to start - started.await() + eventLoopStarted.await() + } + + override suspend fun start() { + super.start() + messageProcessingReady.complete(Unit) } /** @@ -169,6 +173,7 @@ public class ChannelTransport( logger.debug { "Cancelling separate receive channel" } receiveChannel.cancel() } + messageProcessingReady.complete(Unit) handlerScope.cancel() handlerScope.coroutineContext[Job]?.join() } diff --git a/kotlin-sdk-testing/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/testing/ChannelTransportTest.kt b/kotlin-sdk-testing/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/testing/ChannelTransportTest.kt index 70e176247..cb2ce7f24 100644 --- a/kotlin-sdk-testing/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/testing/ChannelTransportTest.kt +++ b/kotlin-sdk-testing/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/testing/ChannelTransportTest.kt @@ -6,15 +6,20 @@ import io.kotest.matchers.collections.shouldContainExactlyInAnyOrder import io.kotest.matchers.shouldBe import io.kotest.matchers.types.shouldBeInstanceOf import io.modelcontextprotocol.kotlin.sdk.ExperimentalMcpApi +import io.modelcontextprotocol.kotlin.sdk.types.EmptyResult import io.modelcontextprotocol.kotlin.sdk.types.JSONRPCMessage import io.modelcontextprotocol.kotlin.sdk.types.JSONRPCRequest +import io.modelcontextprotocol.kotlin.sdk.types.JSONRPCResponse import io.modelcontextprotocol.kotlin.sdk.types.RequestId import kotlinx.coroutines.CompletableDeferred +import kotlinx.coroutines.CoroutineDispatcher +import kotlinx.coroutines.Runnable import kotlinx.coroutines.cancelAndJoin import kotlinx.coroutines.channels.Channel import kotlinx.coroutines.channels.ClosedSendChannelException import kotlinx.coroutines.launch import kotlinx.coroutines.test.runTest +import kotlin.coroutines.CoroutineContext import kotlin.test.Test import kotlin.time.Duration.Companion.seconds @@ -90,6 +95,32 @@ class ChannelTransportTest { transport.close() } + @Test + fun `queued messages are processed after start reaches operational state`() = runTest { + val sendChannel = Channel(Channel.UNLIMITED) + val receiveChannel = Channel(Channel.UNLIMITED) + val transport = ChannelTransport(sendChannel, receiveChannel, ImmediateDispatcher) + + val requestId = RequestId.NumberId(1) + val errors = Channel(Channel.UNLIMITED) + transport.onError { errors.trySend(it) } + transport.onMessage { + transport.send(JSONRPCResponse(requestId, EmptyResult())) + } + + try { + receiveChannel.send(JSONRPCRequest(requestId, "method1")) + transport.start() + + eventually(2.seconds) { + sendChannel.tryReceive().getOrNull() shouldBe JSONRPCResponse(requestId, EmptyResult()) + } + errors.tryReceive().getOrNull() shouldBe null + } finally { + transport.close() + } + } + @Test fun `start completes when channel closes`() = runTest { val receiveChannel = Channel(Channel.UNLIMITED) @@ -237,4 +268,10 @@ class ChannelTransportTest { errors.tryReceive().getOrNull() shouldBe null transport.close() } + + private object ImmediateDispatcher : CoroutineDispatcher() { + override fun dispatch(context: CoroutineContext, block: Runnable) { + block.run() + } + } } From 8e22f9db40e1b989c4a5a168dd786d4a19c14f82 Mon Sep 17 00:00:00 2001 From: ghostflyby Date: Tue, 2 Jun 2026 12:57:31 +0800 Subject: [PATCH 22/22] fix: hardened channel close order on mulitplatforms --- .../kotlin/sdk/testing/ChannelTransport.kt | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/kotlin-sdk-testing/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/testing/ChannelTransport.kt b/kotlin-sdk-testing/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/testing/ChannelTransport.kt index 3d8087e40..47037c0bd 100644 --- a/kotlin-sdk-testing/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/testing/ChannelTransport.kt +++ b/kotlin-sdk-testing/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/testing/ChannelTransport.kt @@ -13,16 +13,16 @@ import kotlinx.coroutines.CoroutineName import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.Job -import kotlinx.coroutines.SupervisorJob -import kotlinx.coroutines.cancel import kotlinx.coroutines.NonCancellable +import kotlinx.coroutines.SupervisorJob import kotlinx.coroutines.cancelAndJoin -import kotlinx.coroutines.withContext import kotlinx.coroutines.channels.Channel import kotlinx.coroutines.channels.Channel.Factory.UNLIMITED import kotlinx.coroutines.channels.ReceiveChannel import kotlinx.coroutines.channels.SendChannel +import kotlinx.coroutines.currentCoroutineContext import kotlinx.coroutines.launch +import kotlinx.coroutines.withContext /** * A transport implementation that uses Kotlin Coroutines Channels for asynchronous @@ -46,6 +46,7 @@ public class ChannelTransport( private val handlerScope = CoroutineScope(SupervisorJob() + dispatcher) private val eventLoopStarted = CompletableDeferred() private val messageProcessingReady = CompletableDeferred() + private var eventLoopJob: Job? = null /** * Creates a `ChannelTransport` instance using a single channel for both sending and receiving messages. @@ -117,7 +118,7 @@ public class ChannelTransport( */ override suspend fun initialize() { logger.info { "ChannelTransport starting message processing" } - scope.launch(CoroutineName("ChannelTransport#${hashCode()}-event-loop")) { + eventLoopJob = scope.launch(CoroutineName("ChannelTransport#${hashCode()}-event-loop")) { try { eventLoopStarted.complete(Unit) messageProcessingReady.await() @@ -166,6 +167,7 @@ public class ChannelTransport( */ override suspend fun closeResources() { logger.info { "Closing ChannelTransport" } + val closeCallerJob = currentCoroutineContext()[Job] withContext(NonCancellable) { invokeOnCloseCallback() sendChannel.close() @@ -174,8 +176,10 @@ public class ChannelTransport( receiveChannel.cancel() } messageProcessingReady.complete(Unit) - handlerScope.cancel() - handlerScope.coroutineContext[Job]?.join() + if (closeCallerJob == eventLoopJob) { + handlerScope.coroutineContext[Job]?.children?.forEach { it.join() } + } + handlerScope.coroutineContext[Job]?.cancelAndJoin() } logger.info { "ChannelTransport closed" } }