diff --git a/integration-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/BaseTransportTest.kt b/integration-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/BaseTransportTest.kt index a6c6e3df2..367e61f69 100644 --- a/integration-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/BaseTransportTest.kt +++ b/integration-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/BaseTransportTest.kt @@ -1,12 +1,13 @@ package io.modelcontextprotocol.kotlin.sdk.shared import io.kotest.assertions.nondeterministic.eventually +import io.kotest.matchers.collections.shouldContainExactlyInAnyOrder import io.kotest.matchers.shouldBe import io.modelcontextprotocol.kotlin.sdk.types.InitializedNotification import io.modelcontextprotocol.kotlin.sdk.types.JSONRPCMessage import io.modelcontextprotocol.kotlin.sdk.types.PingRequest import io.modelcontextprotocol.kotlin.sdk.types.toJSON -import kotlinx.coroutines.CompletableDeferred +import kotlinx.coroutines.channels.Channel import kotlin.test.fail import kotlin.time.Duration.Companion.seconds @@ -46,14 +47,10 @@ abstract class BaseTransportTest { InitializedNotification().toJSON(), ) - val readMessages = mutableListOf() - val finished = CompletableDeferred() + val chan = Channel(messages.size) transport.onMessage { message -> - readMessages.add(message) - if (message == messages.last()) { - finished.complete(Unit) - } + chan.send(message) } transport.start() @@ -62,9 +59,13 @@ abstract class BaseTransportTest { transport.send(message) } - finished.await() + val readMessages = buildList { + repeat(messages.size) { + add(chan.receive()) + } + } - messages shouldBe readMessages + readMessages shouldContainExactlyInAnyOrder messages transport.close() } diff --git a/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/AbstractResourceIntegrationTest.kt b/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/AbstractResourceIntegrationTest.kt index da7d85b35..990a7ac6b 100644 --- a/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/AbstractResourceIntegrationTest.kt +++ b/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/AbstractResourceIntegrationTest.kt @@ -21,6 +21,7 @@ import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.launch import kotlinx.coroutines.runBlocking import kotlinx.coroutines.test.runTest +import kotlinx.coroutines.withContext import org.junit.jupiter.api.Test import org.junit.jupiter.api.assertThrows import java.util.concurrent.CopyOnWriteArrayList @@ -215,7 +216,7 @@ abstract class AbstractResourceIntegrationTest : KotlinTestBase() { val invalidUri = "test://nonexistent.txt" val exception = assertThrows { - runBlocking { + withContext(Dispatchers.Default) { client.readResource(ReadResourceRequest(ReadResourceRequestParams(uri = invalidUri))) } } diff --git a/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/SseClientTransport.kt b/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/SseClientTransport.kt index 0295eb9a3..b0777b548 100644 --- a/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/SseClientTransport.kt +++ b/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/SseClientTransport.kt @@ -80,7 +80,7 @@ public class SseClientTransport( reconnectionTime = reconnectionTime, block = requestBuilder, ) - scope = CoroutineScope(session.coroutineContext + SupervisorJob()) + scope = CoroutineScope(session.coroutineContext + SupervisorJob(session.coroutineContext[Job])) job = scope.launch(CoroutineName("SseMcpClientTransport.connect#${hashCode()}")) { collectMessages() @@ -163,17 +163,31 @@ public class SseClientTransport( } } - private suspend fun handleMessage(data: String) { + private fun handleMessage(data: String) { try { val message = McpJson.decodeFromString(data) - _onMessage(message) + launchMessageHandler(message) } catch (e: SerializationException) { _onError(e) } } + private fun launchMessageHandler(message: JSONRPCMessage) { + scope.launch(CoroutineName("SseMcpClientTransport.message#${hashCode()}")) { + try { + _onMessage(message) + } catch (e: CancellationException) { + throw e + } catch (e: Throwable) { + logger.error(e) { "Error processing message" } + _onError(e) + } + } + } + override suspend fun closeResources() { withContext(NonCancellable) { + invokeOnCloseCallback() job?.cancel() try { if (::session.isInitialized) session.cancel() diff --git a/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StdioClientTransport.kt b/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StdioClientTransport.kt index 63f681c5c..cc599462a 100644 --- a/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StdioClientTransport.kt +++ b/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StdioClientTransport.kt @@ -172,7 +172,7 @@ public class StdioClientTransport @JvmOverloads public constructor( .collect { event -> when (event) { is Event.JsonRpc -> { - handleJSONRPCMessage(event.message) + launchMessageHandler(event.message) } is Event.StderrEvent -> { @@ -264,14 +264,16 @@ public class StdioClientTransport @JvmOverloads public constructor( } } - private suspend fun handleJSONRPCMessage(msg: JSONRPCMessage) { - try { - _onMessage.invoke(msg) - } catch (e: CancellationException) { - throw e - } catch (e: Throwable) { - logger.error(e) { "Error processing message." } - runCatching { _onError.invoke(e) } + private fun launchMessageHandler(message: JSONRPCMessage) { + scope.launch(CoroutineName("StdioClientTransport.message#${hashCode()}")) { + try { + _onMessage.invoke(message) + } catch (e: CancellationException) { + throw e + } catch (e: Throwable) { + logger.error(e) { "Error processing message." } + runCatching { _onError.invoke(e) } + } } } diff --git a/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpClientTransport.kt b/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpClientTransport.kt index c27ed79e1..bc204531c 100644 --- a/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpClientTransport.kt +++ b/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpClientTransport.kt @@ -35,12 +35,13 @@ import kotlinx.coroutines.CoroutineName import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.Job +import kotlinx.coroutines.NonCancellable import kotlinx.coroutines.SupervisorJob -import kotlinx.coroutines.cancel import kotlinx.coroutines.cancelAndJoin import kotlinx.coroutines.delay import kotlinx.coroutines.isActive import kotlinx.coroutines.launch +import kotlinx.coroutines.withContext import kotlin.math.pow import kotlin.time.Duration import kotlin.time.Duration.Companion.milliseconds @@ -168,7 +169,7 @@ public class StreamableHttpClientTransport( when (response.contentType()?.withoutParameters()) { ContentType.Application.Json -> response.bodyAsText().takeIf { it.isNotEmpty() }?.let { json -> runCatching { McpJson.decodeFromString(json) } - .onSuccess { _onMessage(it) } + .onSuccess { launchMessageHandler(it) } .onFailure { _onError(it) throw it @@ -218,8 +219,11 @@ public class StreamableHttpClientTransport( override suspend fun closeResources() { logger.debug { "Client transport closing." } - sseJob?.cancelAndJoin() - scope.cancel() + withContext(NonCancellable) { + invokeOnCloseCallback() + sseJob?.cancel() + scope.coroutineContext[Job]?.cancelAndJoin() + } } /** @@ -392,9 +396,9 @@ public class StreamableHttpClientTransport( .onSuccess { msg -> if (msg is JSONRPCResponse) receivedResponse = true if (replayMessageId != null && msg is JSONRPCResponse) { - _onMessage(msg.copy(id = replayMessageId)) + launchMessageHandler(msg.copy(id = replayMessageId)) } else { - _onMessage(msg) + launchMessageHandler(msg) } } .onFailure(_onError) @@ -427,7 +431,7 @@ public class StreamableHttpClientTransport( var id: String? = null var eventName: String? = null - suspend fun dispatch(id: String?, eventName: String?, data: String) { + fun dispatch(id: String?, eventName: String?, data: String) { id?.let { localLastEventId = it hasPrimingEvent = true @@ -441,9 +445,9 @@ public class StreamableHttpClientTransport( .onSuccess { msg -> if (msg is JSONRPCResponse) receivedResponse = true if (replayMessageId != null && msg is JSONRPCResponse) { - _onMessage(msg.copy(id = replayMessageId)) + launchMessageHandler(msg.copy(id = replayMessageId)) } else { - _onMessage(msg) + launchMessageHandler(msg) } } .onFailure { @@ -481,4 +485,16 @@ public class StreamableHttpClientTransport( } return SseStreamResult(hasPrimingEvent, receivedResponse, localLastEventId, localServerRetryDelay) } + private fun launchMessageHandler(message: JSONRPCMessage) { + scope.launch(CoroutineName("StreamableHttpTransport.message#${hashCode()}")) { + try { + _onMessage(message) + } catch (e: CancellationException) { + throw e + } catch (e: Throwable) { + logger.error(e) { "Error processing message" } + _onError(e) + } + } + } } diff --git a/kotlin-sdk-client/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/streamable/http/StreamableHttpClientTransportTest.kt b/kotlin-sdk-client/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/streamable/http/StreamableHttpClientTransportTest.kt index b380f64a6..63011136f 100644 --- a/kotlin-sdk-client/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/streamable/http/StreamableHttpClientTransportTest.kt +++ b/kotlin-sdk-client/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/streamable/http/StreamableHttpClientTransportTest.kt @@ -2,6 +2,7 @@ package io.modelcontextprotocol.kotlin.sdk.client.streamable.http import io.kotest.assertions.fail import io.kotest.assertions.nondeterministic.eventually +import io.kotest.matchers.collections.shouldContainExactlyInAnyOrder import io.kotest.matchers.collections.shouldHaveSize import io.kotest.matchers.shouldBe import io.kotest.matchers.types.shouldBeInstanceOf @@ -611,11 +612,10 @@ class StreamableHttpClientTransportTest { receivedMessages shouldHaveSize 2 - val firstNotification = receivedMessages[0] as JSONRPCNotification - firstNotification.method shouldBe "notifications/progress" - - val secondNotification = receivedMessages[1] as JSONRPCNotification - secondNotification.method shouldBe "notifications/tools/list_changed" + receivedMessages + .filterIsInstance() + .map { it.method } + .shouldContainExactlyInAnyOrder("notifications/progress", "notifications/tools/list_changed") transport.close() } diff --git a/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/WebSocketMcpTransport.kt b/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/WebSocketMcpTransport.kt index baf37233e..9a805b14c 100644 --- a/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/WebSocketMcpTransport.kt +++ b/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/WebSocketMcpTransport.kt @@ -10,10 +10,14 @@ import io.modelcontextprotocol.kotlin.sdk.types.McpJson import kotlinx.coroutines.CoroutineName import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.InternalCoroutinesApi +import kotlinx.coroutines.NonCancellable import kotlinx.coroutines.SupervisorJob +import kotlinx.coroutines.cancel +import kotlinx.coroutines.cancelAndJoin import kotlinx.coroutines.channels.ClosedReceiveChannelException import kotlinx.coroutines.job import kotlinx.coroutines.launch +import kotlinx.coroutines.withContext import kotlin.concurrent.atomics.AtomicBoolean import kotlin.concurrent.atomics.ExperimentalAtomicApi import kotlin.coroutines.cancellation.CancellationException @@ -30,7 +34,7 @@ private val logger = KotlinLogging.logger {} @OptIn(ExperimentalAtomicApi::class) public abstract class WebSocketMcpTransport : AbstractTransport() { private val scope by lazy { - CoroutineScope(session.coroutineContext + SupervisorJob()) + CoroutineScope(session.coroutineContext + SupervisorJob(session.coroutineContext.job)) } private val initialized: AtomicBoolean = AtomicBoolean(false) @@ -73,8 +77,8 @@ public abstract class WebSocketMcpTransport : AbstractTransport() { } try { - val message = McpJson.decodeFromString(message.readText()) - _onMessage.invoke(message) + val parsedMessage = McpJson.decodeFromString(message.readText()) + launchMessageHandler(parsedMessage) } catch (e: CancellationException) { throw e } catch (e: Exception) { @@ -86,6 +90,11 @@ public abstract class WebSocketMcpTransport : AbstractTransport() { @OptIn(InternalCoroutinesApi::class) session.coroutineContext.job.invokeOnCompletion { + // Cancel the scope when the session completes. For normal session + // completion the SupervisorJob parent does not auto-cancel children; + // for error/cancellation the propagation already cancels the scope + // job, making this cancel a no-op. + scope.cancel() if (it != null) { _onError.invoke(it) } else { @@ -94,6 +103,19 @@ public abstract class WebSocketMcpTransport : AbstractTransport() { } } + private fun launchMessageHandler(message: JSONRPCMessage) { + scope.launch(CoroutineName("WebSocketMcpTransport.message#${hashCode()}")) { + try { + _onMessage.invoke(message) + } catch (e: CancellationException) { + throw e + } catch (e: Throwable) { + logger.error(e) { "Error processing message" } + _onError.invoke(e) + } + } + } + override suspend fun send(message: JSONRPCMessage, options: TransportSendOptions?) { logger.debug { "Sending message" } if (!initialized.load()) { @@ -109,7 +131,10 @@ public abstract class WebSocketMcpTransport : AbstractTransport() { } logger.debug { "Closing websocket session" } - session.close() - session.coroutineContext.job.join() + withContext(NonCancellable) { + invokeOnCloseCallback() + session.close() + scope.coroutineContext.job.cancelAndJoin() + } } } diff --git a/kotlin-sdk-core/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/WebSocketMcpTransportTest.kt b/kotlin-sdk-core/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/WebSocketMcpTransportTest.kt new file mode 100644 index 000000000..0920710ad --- /dev/null +++ b/kotlin-sdk-core/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/WebSocketMcpTransportTest.kt @@ -0,0 +1,150 @@ +package io.modelcontextprotocol.kotlin.sdk.shared + +import io.ktor.websocket.Frame +import io.ktor.websocket.WebSocketExtension +import io.ktor.websocket.WebSocketSession +import io.modelcontextprotocol.kotlin.sdk.types.InitializedNotification +import io.modelcontextprotocol.kotlin.sdk.types.PingRequest +import io.modelcontextprotocol.kotlin.sdk.types.toJSON +import kotlinx.coroutines.CompletableDeferred +import kotlinx.coroutines.Job +import kotlinx.coroutines.channels.Channel +import kotlinx.coroutines.test.runTest +import kotlin.coroutines.CoroutineContext +import kotlin.coroutines.cancellation.CancellationException +import kotlin.test.Test +import kotlin.test.assertSame +import kotlin.test.fail + +class WebSocketMcpTransportTest { + @Test + fun `should continue receiving while previous handler is suspended`() = runTest { + val session = TestWebSocketSession(coroutineContext) + val transport = TestWebSocketMcpTransport(session) + val firstMessage = PingRequest().toJSON() + val secondMessage = InitializedNotification().toJSON() + val releaseFirstHandler = CompletableDeferred() + val secondMessageProcessed = CompletableDeferred() + + transport.onMessage { message -> + if (message == firstMessage) { + releaseFirstHandler.await() + } + if (message == secondMessage) { + secondMessageProcessed.complete(Unit) + } + } + + transport.start() + + session.incoming.send(Frame.Text(serializeMessage(firstMessage))) + session.incoming.send(Frame.Text(serializeMessage(secondMessage))) + + secondMessageProcessed.await() + releaseFirstHandler.complete(Unit) + transport.close() + } + + @Test + fun `close cancels suspended message handler`() = runTest { + val session = TestWebSocketSession(coroutineContext) + val transport = TestWebSocketMcpTransport(session) + val handlerStarted = CompletableDeferred() + val handlerCancelled = CompletableDeferred() + + transport.onMessage { + handlerStarted.complete(Unit) + try { + CompletableDeferred().await() + } catch (e: CancellationException) { + handlerCancelled.complete(Unit) + throw e + } + } + + transport.start() + + session.incoming.send(Frame.Text(serializeMessage(PingRequest().toJSON()))) + handlerStarted.await() + transport.close() + + handlerCancelled.await() + } + + @Test + fun `handler errors are reported via onError`() = runTest { + val session = TestWebSocketSession(coroutineContext) + val transport = TestWebSocketMcpTransport(session) + val expected = IllegalStateException("handler failed") + val capturedError = CompletableDeferred() + + transport.onError { capturedError.complete(it) } + transport.onMessage { throw expected } + transport.start() + + session.incoming.send(Frame.Text(serializeMessage(PingRequest().toJSON()))) + + assertSame(expected, capturedError.await()) + transport.close() + } + + @Test + fun `handler cancellation is not reported via onError`() = runTest { + val session = TestWebSocketSession(coroutineContext) + val transport = TestWebSocketMcpTransport(session) + val capturedError = CompletableDeferred() + val handlerStarted = CompletableDeferred() + + transport.onError { capturedError.complete(it) } + transport.onMessage { + handlerStarted.complete(Unit) + throw CancellationException("handler cancelled") + } + transport.start() + + session.incoming.send(Frame.Text(serializeMessage(PingRequest().toJSON()))) + handlerStarted.await() + + if (capturedError.isCompleted) { + fail("CancellationException should not be reported via onError") + } + transport.close() + } + + private class TestWebSocketMcpTransport(override val session: WebSocketSession) : WebSocketMcpTransport() { + override suspend fun initializeSession() = Unit + } + + private class TestWebSocketSession(parentContext: CoroutineContext) : WebSocketSession { + private val job = Job(parentContext[Job]) + + override val coroutineContext: CoroutineContext = parentContext + job + override var masking: Boolean = false + override var maxFrameSize: Long = Long.MAX_VALUE + override val incoming = Channel(Channel.UNLIMITED) + override val outgoing = Channel(Channel.UNLIMITED) + override val extensions: List> = emptyList() + + override suspend fun flush() = Unit + + override suspend fun send(frame: Frame) { + outgoing.send(frame) + if (frame is Frame.Close) { + job.cancel() + incoming.close() + outgoing.close() + } + } + + @Deprecated( + "Use cancel() instead.", + replaceWith = ReplaceWith("cancel()", "kotlinx.coroutines.cancel"), + level = DeprecationLevel.ERROR, + ) + override fun terminate() { + job.cancel() + incoming.close() + outgoing.close() + } + } +} diff --git a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/SSEServerTransport.kt b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/SSEServerTransport.kt index a6d461833..41816d9b0 100644 --- a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/SSEServerTransport.kt +++ b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/SSEServerTransport.kt @@ -109,6 +109,9 @@ public class SseServerTransport(private val endpoint: String, private val sessio /** * Handle a client message, regardless of how it arrived. * This can be used to inform the server of messages that arrive via a means different from HTTP POST. + * + * Ktor provides concurrency across POST requests by invoking [handlePostMessage] in a separate + * [ApplicationCall] coroutine, so this path keeps handler errors synchronous to the current POST. */ public suspend fun handleMessage(message: String) { try { diff --git a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StdioServerTransport.kt b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StdioServerTransport.kt index 2b760899a..196c58a57 100644 --- a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StdioServerTransport.kt +++ b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StdioServerTransport.kt @@ -212,14 +212,7 @@ public class StdioServerTransport private constructor( _onError(e) null } ?: break - try { - _onMessage(message) - } catch (e: CancellationException) { - throw e - } catch (e: Throwable) { - logger.error(e) { "Error processing message" } - _onError(e) - } + launchMessageHandler(message) } } } catch (e: CancellationException) { @@ -248,6 +241,20 @@ public class StdioServerTransport private constructor( } } + private fun launchMessageHandler(message: JSONRPCMessage) { + val s = effectiveScope ?: return + s.launch(handlerDispatcher + CoroutineName("StdioServerTransport.message#${hashCode()}")) { + try { + _onMessage(message) + } catch (e: CancellationException) { + throw e + } catch (e: Throwable) { + logger.error(e) { "Error processing message" } + _onError(e) + } + } + } + /** * Closes the transport. When called after [start], waits for in-flight outbound messages to be * flushed, releases the input source and output sink, cancels the internal scope when the diff --git a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransport.kt b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransport.kt index ae8a39d1d..ff7d03ee0 100644 --- a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransport.kt +++ b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransport.kt @@ -28,9 +28,18 @@ import io.modelcontextprotocol.kotlin.sdk.types.RPCError import io.modelcontextprotocol.kotlin.sdk.types.RPCError.ErrorCode.REQUEST_TIMEOUT import io.modelcontextprotocol.kotlin.sdk.types.RequestId import io.modelcontextprotocol.kotlin.sdk.types.SUPPORTED_PROTOCOL_VERSIONS +import kotlinx.coroutines.CoroutineName +import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.NonCancellable +import kotlinx.coroutines.SupervisorJob +import kotlinx.coroutines.async +import kotlinx.coroutines.awaitAll import kotlinx.coroutines.awaitCancellation +import kotlinx.coroutines.cancel import kotlinx.coroutines.job +import kotlinx.coroutines.launch +import kotlinx.coroutines.supervisorScope import kotlinx.coroutines.sync.Mutex import kotlinx.coroutines.sync.withLock import kotlinx.coroutines.withContext @@ -180,6 +189,8 @@ public class StreamableHttpServerTransport(private val configuration: Configurat private val sessionMutex = Mutex() private val streamMutex = Mutex() + private val handlerScope = CoroutineScope(SupervisorJob() + Dispatchers.Default) + private companion object { const val STANDALONE_SSE_STREAM_ID = "_GET_stream" } @@ -308,6 +319,7 @@ public class StreamableHttpServerTransport(private val configuration: Configurat requestToResponseMapping.clear() invokeOnCloseCallback() } + handlerScope.cancel() } } @@ -341,7 +353,9 @@ public class StreamableHttpServerTransport(private val configuration: Configurat } /** - * Handles POST requests containing JSON-RPC messages + * Handles POST requests containing JSON-RPC messages. Batch messages are + * dispatched concurrently — the MCP batch spec does not mandate sequential + * processing, and individual POST requests were already concurrent. */ public suspend fun handlePostRequest(session: ServerSSESession?, call: ApplicationCall) { try { @@ -405,10 +419,12 @@ public class StreamableHttpServerTransport(private val configuration: Configurat if (!validateSession(call) || !validateProtocolVersion(call)) return } - val hasRequest = messages.any { it is JSONRPCRequest } - if (!hasRequest) { + val (notifications, requests) = messages.partition { it !is JSONRPCRequest } + + notifications.forEach { launchNotificationHandler(it) } + + if (requests.isEmpty()) { call.respondNullable(status = HttpStatusCode.Accepted, message = null) - messages.forEach { message -> _onMessage(message) } return } @@ -437,19 +453,56 @@ public class StreamableHttpServerTransport(private val configuration: Configurat } call.coroutineContext.job.invokeOnCompletion { streamsMapping.remove(streamId) } - messages.forEach { message -> _onMessage(message) } + dispatchRequestsConcurrently(requests) } catch (e: CancellationException) { throw e } catch (e: Exception) { - call.reject( - HttpStatusCode.BadRequest, - RPCError.ErrorCode.PARSE_ERROR, - "Parse error: ${e.message}", - ) + runCatching { + call.reject( + HttpStatusCode.BadRequest, + RPCError.ErrorCode.PARSE_ERROR, + "Parse error: ${e.message}", + ) + }.onFailure { rejectFailure -> + _onError(rejectFailure) + } _onError(e) } } + private fun launchNotificationHandler(message: JSONRPCMessage) { + handlerScope.launch( + CoroutineName("StreamableHttpTransport.notification#${hashCode()}"), + ) { + try { + _onMessage(message) + } catch (e: CancellationException) { + throw e + } catch (e: Exception) { + _onError(e) + } + } + } + + private suspend fun dispatchRequestsConcurrently(requests: List) { + if (requests.isEmpty()) return + @Suppress("UNCHECKED_CAST") + val requestList = requests as List + supervisorScope { + requestList.map { request -> + async { + try { + _onMessage(request) + } catch (e: CancellationException) { + throw e + } catch (e: Exception) { + _onError(e) + } + } + }.awaitAll() + } + } + /** Handles an HTTP GET request by establishing a standalone SSE stream for server-initiated notifications. */ public suspend fun handleGetRequest(session: ServerSSESession?, call: ApplicationCall) { // NOTE: enableJsonResponse only controls how POST responses are delivered (JSON vs. SSE). diff --git a/kotlin-sdk-server/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StdioServerTransportTest.kt b/kotlin-sdk-server/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StdioServerTransportTest.kt index c6370a1a6..2ccf8067a 100644 --- a/kotlin-sdk-server/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StdioServerTransportTest.kt +++ b/kotlin-sdk-server/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StdioServerTransportTest.kt @@ -4,6 +4,7 @@ import io.kotest.assertions.nondeterministic.eventually import io.kotest.assertions.throwables.shouldThrow import io.kotest.assertions.withClue import io.kotest.matchers.collections.shouldContain +import io.kotest.matchers.collections.shouldContainExactlyInAnyOrder import io.kotest.matchers.shouldBe import io.kotest.matchers.string.shouldContain import io.modelcontextprotocol.kotlin.sdk.shared.ReadBuffer @@ -17,6 +18,7 @@ import io.modelcontextprotocol.kotlin.sdk.types.toJSON import io.modelcontextprotocol.kotlin.test.utils.runIntegrationTest import kotlinx.coroutines.CancellationException import kotlinx.coroutines.CompletableDeferred +import kotlinx.coroutines.channels.Channel import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.SupervisorJob @@ -150,14 +152,9 @@ class StdioServerTransportTest { InitializedNotification().toJSON(), ) - val readMessages = mutableListOf() - val finished = CompletableDeferred() - + val received = Channel(messages.size) server.onMessage { message -> - readMessages.add(message) - if (message == messages[1]) { - finished.complete(Unit) - } + received.trySend(message) } // Push both messages before starting the server @@ -167,9 +164,47 @@ class StdioServerTransportTest { inputWriter.flush() server.start() - finished.await() - readMessages shouldBe messages + val readMessages = buildList { + repeat(messages.size) { + add(received.receive()) + } + } + readMessages.shouldContainExactlyInAnyOrder(messages) + } + + @Test + fun `should continue receiving while previous handler is suspended`() = runIntegrationTest { + val server = StdioServerTransport(input = bufferedInput, output = printOutput) + server.onError { error -> + throw error + } + + val firstMessage = PingRequest().toJSON() + val secondMessage = InitializedNotification().toJSON() + val releaseFirstHandler = CompletableDeferred() + val secondMessageProcessed = CompletableDeferred() + + server.onMessage { message -> + if (message == firstMessage) { + releaseFirstHandler.await() + } + if (message == secondMessage) { + secondMessageProcessed.complete(Unit) + } + } + + // Push messages before starting + inputWriter.write(serializeMessage(firstMessage)) + inputWriter.write(serializeMessage(secondMessage)) + inputWriter.flush() + + server.start() + + // second message should be processed even though first is still suspended + secondMessageProcessed.await() + releaseFirstHandler.complete(Unit) + server.close() } // region: Exception handling @@ -250,14 +285,14 @@ class StdioServerTransportTest { @MethodSource("handlerErrors") fun `should continue processing messages after handler throws`(throwable: Throwable) = runIntegrationTest { val server = StdioServerTransport(input = bufferedInput, output = printOutput) - val capturedErrors = mutableListOf() + val capturedError = CompletableDeferred() val receivedMessages = mutableListOf() val secondMessageProcessed = CompletableDeferred() val message1 = PingRequest().toJSON() val message2 = InitializedNotification().toJSON() - server.onError { capturedErrors.add(it) } + server.onError { capturedError.complete(it) } server.onMessage { message -> if (message == message1) { throw throwable @@ -275,7 +310,7 @@ class StdioServerTransportTest { secondMessageProcessed.await() - capturedErrors shouldContain throwable + capturedError.await() shouldBe throwable receivedMessages shouldBe listOf(message2) server.close() } diff --git a/kotlin-sdk-server/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransportTest.kt b/kotlin-sdk-server/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransportTest.kt index f6dda1520..bb5c64eae 100644 --- a/kotlin-sdk-server/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransportTest.kt +++ b/kotlin-sdk-server/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransportTest.kt @@ -32,6 +32,7 @@ import io.modelcontextprotocol.kotlin.sdk.types.InitializeRequest import io.modelcontextprotocol.kotlin.sdk.types.InitializeRequestParams import io.modelcontextprotocol.kotlin.sdk.types.InitializedNotification import io.modelcontextprotocol.kotlin.sdk.types.JSONRPCMessage +import io.modelcontextprotocol.kotlin.sdk.types.JSONRPCNotification import io.modelcontextprotocol.kotlin.sdk.types.JSONRPCRequest import io.modelcontextprotocol.kotlin.sdk.types.JSONRPCResponse import io.modelcontextprotocol.kotlin.sdk.types.LATEST_PROTOCOL_VERSION @@ -44,6 +45,9 @@ import io.modelcontextprotocol.kotlin.sdk.types.ServerCapabilities import io.modelcontextprotocol.kotlin.sdk.types.Tool import io.modelcontextprotocol.kotlin.sdk.types.ToolSchema import io.modelcontextprotocol.kotlin.sdk.types.toJSON +import kotlinx.coroutines.CompletableDeferred +import kotlinx.coroutines.async +import kotlinx.coroutines.coroutineScope import kotlinx.serialization.builtins.ListSerializer import kotlinx.serialization.json.buildJsonObject import kotlinx.serialization.json.put @@ -347,6 +351,239 @@ class StreamableHttpServerTransportTest { assertEquals("second", secondMeta?.get("label")?.jsonPrimitive?.content)*/ } + @Test + fun `batched requests are handled concurrently before replying`() = testApplication { + configTestServer() + + val client = createTestClient() + + val transport = StreamableHttpServerTransport(enableJsonResponse = true) + val firstRequest = JSONRPCRequest(id = RequestId("first"), method = Method.Defined.ToolsList.value) + val secondRequest = JSONRPCRequest(id = RequestId("second"), method = Method.Defined.ResourcesList.value) + val firstHandlerStarted = CompletableDeferred() + val releaseFirstHandler = CompletableDeferred() + val secondHandlerCompleted = CompletableDeferred() + val secondResult = ListResourcesResult(resources = emptyList()) + + transport.onMessage { message -> + if (message is JSONRPCRequest) { + when (message.id) { + firstRequest.id -> { + firstHandlerStarted.complete(Unit) + releaseFirstHandler.await() + transport.send(JSONRPCResponse(message.id, EmptyResult()), null) + } + + secondRequest.id -> { + firstHandlerStarted.await() + transport.send(JSONRPCResponse(message.id, secondResult), null) + secondHandlerCompleted.complete(Unit) + } + + else -> transport.send(JSONRPCResponse(message.id, EmptyResult()), null) + } + } + } + + configureTransportEndpoint(transport) + + val initResponse = client.post(path) { + addStreamableHeaders() + setBody(buildInitializeRequestPayload()) + } + + coroutineScope { + val responseDeferred = async { + client.post(path) { + addStreamableHeaders() + header(MCP_SESSION_ID_HEADER, initResponse.headers[MCP_SESSION_ID_HEADER]) + setBody(encodeMessages(listOf(firstRequest, secondRequest))) + } + } + + secondHandlerCompleted.await() + releaseFirstHandler.complete(Unit) + + val response = responseDeferred.await() + response.status shouldBe HttpStatusCode.OK + response.body>() shouldContainAll listOf( + JSONRPCResponse(firstRequest.id, EmptyResult()), + JSONRPCResponse(secondRequest.id, secondResult), + ) + } + } + + @Test + fun `batched notifications are handled concurrently`() = testApplication { + configTestServer() + + val client = createTestClient() + + val transport = StreamableHttpServerTransport(enableJsonResponse = true) + val firstNotification = InitializedNotification().toJSON() + val secondNotification = JSONRPCNotification("notifications/test") + val firstHandlerStarted = CompletableDeferred() + val releaseFirstHandler = CompletableDeferred() + val secondHandlerCompleted = CompletableDeferred() + + transport.onMessage { message -> + when (message) { + is JSONRPCRequest -> transport.send(JSONRPCResponse(message.id, EmptyResult())) + + firstNotification -> { + firstHandlerStarted.complete(Unit) + releaseFirstHandler.await() + } + + secondNotification -> { + firstHandlerStarted.await() + secondHandlerCompleted.complete(Unit) + } + + else -> Unit + } + } + + configureTransportEndpoint(transport) + + val initResponse = client.post(path) { + addStreamableHeaders() + setBody(buildInitializeRequestPayload()) + } + + coroutineScope { + val responseDeferred = async { + client.post(path) { + addStreamableHeaders() + header(MCP_SESSION_ID_HEADER, initResponse.headers[MCP_SESSION_ID_HEADER]) + setBody(encodeMessages(listOf(firstNotification, secondNotification))) + } + } + + secondHandlerCompleted.await() + releaseFirstHandler.complete(Unit) + + responseDeferred.await().status shouldBe HttpStatusCode.Accepted + } + } + + @Test + fun `batched notifications report handler error after sibling completes`() = testApplication { + configTestServer() + + val client = createTestClient() + + val transport = StreamableHttpServerTransport(enableJsonResponse = true) + val expected = IllegalStateException("notification failed") + val capturedError = CompletableDeferred() + val firstNotification = InitializedNotification().toJSON() + val secondNotification = JSONRPCNotification("notifications/test") + val firstHandlerStarted = CompletableDeferred() + val releaseFirstHandler = CompletableDeferred() + val secondHandlerStarted = CompletableDeferred() + + transport.onError { capturedError.complete(it) } + transport.onMessage { message -> + when (message) { + is JSONRPCRequest -> transport.send(JSONRPCResponse(message.id, EmptyResult())) + + firstNotification -> { + firstHandlerStarted.complete(Unit) + releaseFirstHandler.await() + } + + secondNotification -> { + firstHandlerStarted.await() + secondHandlerStarted.complete(Unit) + throw expected + } + + else -> Unit + } + } + + configureTransportEndpoint(transport) + + val initResponse = client.post(path) { + addStreamableHeaders() + setBody(buildInitializeRequestPayload()) + } + + coroutineScope { + val responseDeferred = async { + client.post(path) { + addStreamableHeaders() + header(MCP_SESSION_ID_HEADER, initResponse.headers[MCP_SESSION_ID_HEADER]) + setBody(encodeMessages(listOf(firstNotification, secondNotification))) + } + } + + secondHandlerStarted.await() + releaseFirstHandler.complete(Unit) + responseDeferred.await().status shouldBe HttpStatusCode.Accepted + } + + capturedError.await() shouldBe expected + } + + @Test + fun `batch handler error reports onError after response is committed`() = testApplication { + configTestServer() + + val client = createTestClient() + + val transport = StreamableHttpServerTransport(enableJsonResponse = true) + val expected = IllegalStateException("request failed") + val capturedError = CompletableDeferred() + val firstRequest = JSONRPCRequest(id = RequestId("first"), method = Method.Defined.ToolsList.value) + val notification = JSONRPCNotification("notifications/test") + val requestResponseSent = CompletableDeferred() + val releaseNotificationHandler = CompletableDeferred() + + transport.onError { capturedError.complete(it) } + transport.onMessage { message -> + when (message) { + firstRequest -> { + transport.send(JSONRPCResponse(firstRequest.id, EmptyResult()), null) + requestResponseSent.complete(Unit) + } + + notification -> { + requestResponseSent.await() + releaseNotificationHandler.await() + throw expected + } + + is JSONRPCRequest -> transport.send(JSONRPCResponse(message.id, EmptyResult()), null) + + else -> Unit + } + } + + configureTransportEndpoint(transport) + + val initResponse = client.post(path) { + addStreamableHeaders() + setBody(buildInitializeRequestPayload()) + } + + coroutineScope { + val responseDeferred = async { + client.post(path) { + addStreamableHeaders() + header(MCP_SESSION_ID_HEADER, initResponse.headers[MCP_SESSION_ID_HEADER]) + setBody(encodeMessages(listOf(firstRequest, notification))) + } + } + + requestResponseSent.await() + releaseNotificationHandler.complete(Unit) + responseDeferred.await().status shouldBe HttpStatusCode.OK + } + + capturedError.await() shouldBe expected + } + @ParameterizedTest @MethodSource("invalidPayloads") fun `POST with a null or empty body returns an error`(payload: String?) = testApplication { diff --git a/kotlin-sdk-testing/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/testing/ChannelTransport.kt b/kotlin-sdk-testing/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/testing/ChannelTransport.kt index ae5f487ce..47037c0bd 100644 --- a/kotlin-sdk-testing/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/testing/ChannelTransport.kt +++ b/kotlin-sdk-testing/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/testing/ChannelTransport.kt @@ -13,14 +13,16 @@ import kotlinx.coroutines.CoroutineName import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.Job +import kotlinx.coroutines.NonCancellable import kotlinx.coroutines.SupervisorJob -import kotlinx.coroutines.cancel +import kotlinx.coroutines.cancelAndJoin import kotlinx.coroutines.channels.Channel import kotlinx.coroutines.channels.Channel.Factory.UNLIMITED import kotlinx.coroutines.channels.ReceiveChannel import kotlinx.coroutines.channels.SendChannel +import kotlinx.coroutines.currentCoroutineContext import kotlinx.coroutines.launch -import kotlinx.coroutines.yield +import kotlinx.coroutines.withContext /** * A transport implementation that uses Kotlin Coroutines Channels for asynchronous @@ -41,6 +43,10 @@ public class ChannelTransport( override val logger: KLogger = KotlinLogging.logger {} private val scope = CoroutineScope(SupervisorJob() + dispatcher) + private val handlerScope = CoroutineScope(SupervisorJob() + dispatcher) + private val eventLoopStarted = CompletableDeferred() + private val messageProcessingReady = CompletableDeferred() + private var eventLoopJob: Job? = null /** * Creates a `ChannelTransport` instance using a single channel for both sending and receiving messages. @@ -112,34 +118,20 @@ public class ChannelTransport( */ override suspend fun initialize() { logger.info { "ChannelTransport starting message processing" } - val started = CompletableDeferred() - scope.launch(CoroutineName("ChannelTransport#${hashCode()}-event-loop")) { + eventLoopJob = scope.launch(CoroutineName("ChannelTransport#${hashCode()}-event-loop")) { try { - // Signal that event loop has started, then yield to ensure we're ready - started.complete(Unit) - yield() + eventLoopStarted.complete(Unit) + messageProcessingReady.await() for (message in receiveChannel) { logger.debug { "Received message: ${message::class.simpleName}" } - - try { - _onMessage.invoke(message) - logger.trace { "Message processed successfully: ${message::class.simpleName}" } - } catch (e: CancellationException) { - // Let cancellation propagate immediately - logger.debug { "Cancellation requested during message processing" } - throw e - } catch (e: Exception) { - // Report other errors but continue processing - logger.warn(e) { "Error processing message: ${message::class.simpleName}" } - _onError.invoke(e) - } + launchMessageHandler(message) } logger.info { "ChannelTransport stopped: receive channel closed" } } catch (e: Exception) { // Only complete exceptionally if not already completed - if (!started.isCompleted) { - started.completeExceptionally(e) + if (!eventLoopStarted.isCompleted) { + eventLoopStarted.completeExceptionally(e) } throw e } finally { @@ -147,7 +139,12 @@ public class ChannelTransport( } } // Wait for the event loop to start - started.await() + eventLoopStarted.await() + } + + override suspend fun start() { + super.start() + messageProcessingReady.complete(Unit) } /** @@ -170,13 +167,35 @@ public class ChannelTransport( */ override suspend fun closeResources() { logger.info { "Closing ChannelTransport" } - sendChannel.close() - if (receiveChannel !== sendChannel) { - logger.debug { "Cancelling separate receive channel" } - receiveChannel.cancel() + val closeCallerJob = currentCoroutineContext()[Job] + withContext(NonCancellable) { + invokeOnCloseCallback() + sendChannel.close() + if (receiveChannel !== sendChannel) { + logger.debug { "Cancelling separate receive channel" } + receiveChannel.cancel() + } + messageProcessingReady.complete(Unit) + if (closeCallerJob == eventLoopJob) { + handlerScope.coroutineContext[Job]?.children?.forEach { it.join() } + } + handlerScope.coroutineContext[Job]?.cancelAndJoin() } - scope.cancel() - scope.coroutineContext[Job]?.join() // Wait for cleanup logger.info { "ChannelTransport closed" } } + + private fun launchMessageHandler(message: JSONRPCMessage) { + handlerScope.launch(CoroutineName("ChannelTransport#${hashCode()}-message")) { + try { + _onMessage.invoke(message) + logger.trace { "Message processed successfully: ${message::class.simpleName}" } + } catch (e: CancellationException) { + logger.debug { "Cancellation requested during message processing" } + throw e + } catch (e: Throwable) { + logger.warn(e) { "Error processing message: ${message::class.simpleName}" } + _onError.invoke(e) + } + } + } } diff --git a/kotlin-sdk-testing/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/testing/ChannelTransportTest.kt b/kotlin-sdk-testing/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/testing/ChannelTransportTest.kt index 4aae932e9..cb2ce7f24 100644 --- a/kotlin-sdk-testing/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/testing/ChannelTransportTest.kt +++ b/kotlin-sdk-testing/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/testing/ChannelTransportTest.kt @@ -2,18 +2,24 @@ package io.modelcontextprotocol.kotlin.sdk.testing import io.kotest.assertions.nondeterministic.eventually import io.kotest.matchers.collections.shouldContainExactly +import io.kotest.matchers.collections.shouldContainExactlyInAnyOrder import io.kotest.matchers.shouldBe import io.kotest.matchers.types.shouldBeInstanceOf import io.modelcontextprotocol.kotlin.sdk.ExperimentalMcpApi +import io.modelcontextprotocol.kotlin.sdk.types.EmptyResult import io.modelcontextprotocol.kotlin.sdk.types.JSONRPCMessage import io.modelcontextprotocol.kotlin.sdk.types.JSONRPCRequest +import io.modelcontextprotocol.kotlin.sdk.types.JSONRPCResponse import io.modelcontextprotocol.kotlin.sdk.types.RequestId import kotlinx.coroutines.CompletableDeferred +import kotlinx.coroutines.CoroutineDispatcher +import kotlinx.coroutines.Runnable import kotlinx.coroutines.cancelAndJoin import kotlinx.coroutines.channels.Channel import kotlinx.coroutines.channels.ClosedSendChannelException import kotlinx.coroutines.launch import kotlinx.coroutines.test.runTest +import kotlin.coroutines.CoroutineContext import kotlin.test.Test import kotlin.time.Duration.Companion.seconds @@ -57,7 +63,62 @@ class ChannelTransportTest { // Wait for messages to be processed messagesProcessed.await() - received.shouldContainExactly(msg1, msg2) + received.shouldContainExactlyInAnyOrder(msg1, msg2) + } + + @Test + fun `receive loop continues while previous handler is suspended`() = runTest { + val receiveChannel = Channel(Channel.UNLIMITED) + val transport = ChannelTransport(Channel(), receiveChannel) + + val firstMessage = JSONRPCRequest(RequestId.NumberId(1), "method1") + val secondMessage = JSONRPCRequest(RequestId.NumberId(2), "method2") + val releaseFirstHandler = CompletableDeferred() + val secondMessageProcessed = CompletableDeferred() + + transport.onMessage { message -> + if (message == firstMessage) { + releaseFirstHandler.await() + } + if (message == secondMessage) { + secondMessageProcessed.complete(Unit) + } + } + + transport.start() + + receiveChannel.send(firstMessage) + receiveChannel.send(secondMessage) + + secondMessageProcessed.await() + releaseFirstHandler.complete(Unit) + transport.close() + } + + @Test + fun `queued messages are processed after start reaches operational state`() = runTest { + val sendChannel = Channel(Channel.UNLIMITED) + val receiveChannel = Channel(Channel.UNLIMITED) + val transport = ChannelTransport(sendChannel, receiveChannel, ImmediateDispatcher) + + val requestId = RequestId.NumberId(1) + val errors = Channel(Channel.UNLIMITED) + transport.onError { errors.trySend(it) } + transport.onMessage { + transport.send(JSONRPCResponse(requestId, EmptyResult())) + } + + try { + receiveChannel.send(JSONRPCRequest(requestId, "method1")) + transport.start() + + eventually(2.seconds) { + sendChannel.tryReceive().getOrNull() shouldBe JSONRPCResponse(requestId, EmptyResult()) + } + errors.tryReceive().getOrNull() shouldBe null + } finally { + transport.close() + } } @Test @@ -177,20 +238,16 @@ class ChannelTransportTest { val receiveChannel = Channel(Channel.UNLIMITED) val transport = ChannelTransport(Channel(), receiveChannel) - val allProcessed = CompletableDeferred() - val received = mutableListOf() - val errors = mutableListOf() + val received = Channel(Channel.UNLIMITED) + val errors = Channel(Channel.UNLIMITED) - transport.onError { errors.add(it) } + transport.onError { errors.trySend(it) } transport.onMessage { msg -> val id = ((msg as JSONRPCRequest).id as RequestId.NumberId).value.toInt() - received.add(id) + received.send(id) if (id == 2) { throw RuntimeException("Error processing message 2") } - if (received.size == 4) { - allProcessed.complete(Unit) - } } transport.start() @@ -201,11 +258,20 @@ class ChannelTransportTest { receiveChannel.send(JSONRPCRequest(RequestId.NumberId(3), "m3")) receiveChannel.send(JSONRPCRequest(RequestId.NumberId(4), "m4")) - allProcessed.await() + val receivedValues = buildList { repeat(4) { add(received.receive()) } } // All messages should be processed despite error in message 2 - received.shouldContainExactly(1, 2, 3, 4) - errors.size shouldBe 1 - errors[0].message shouldBe "Error processing message 2" + receivedValues.shouldContainExactlyInAnyOrder(1, 2, 3, 4) + val error = errors.receive() + error.message shouldBe "Error processing message 2" + // Ensure no extra errors were reported + errors.tryReceive().getOrNull() shouldBe null + transport.close() + } + + private object ImmediateDispatcher : CoroutineDispatcher() { + override fun dispatch(context: CoroutineContext, block: Runnable) { + block.run() + } } }