From fcf29f7f46e21f19f3d93e80090ba1365379ec4c Mon Sep 17 00:00:00 2001 From: Jay Herron Date: Tue, 10 Feb 2026 22:58:27 -0700 Subject: [PATCH 1/8] feat!: Passes init payload result This automatically propagates the init payload result from the callback into the `onExecute` and `onSubscribe` closures. Since the init callback is usually used to determine authentication and authorization, this should be usable from our execution and subscription calls, and lifecycle is most easily managed within this package --- Sources/GraphQLTransportWS/Server.swift | 32 +++---- .../GraphQLTransportWSTests.swift | 93 +++++++++++++++---- 2 files changed, 87 insertions(+), 38 deletions(-) diff --git a/Sources/GraphQLTransportWS/Server.swift b/Sources/GraphQLTransportWS/Server.swift index 751b863..e8b57e2 100644 --- a/Sources/GraphQLTransportWS/Server.swift +++ b/Sources/GraphQLTransportWS/Server.swift @@ -6,16 +6,17 @@ import GraphQL /// By default, there are no authorization checks public class Server< InitPayload: Equatable & Codable & Sendable, + InitPayloadResult: Sendable, SubscriptionSequenceType: AsyncSequence & Sendable >: @unchecked Sendable where SubscriptionSequenceType.Element == GraphQLResult { // We keep this weak because we strongly inject this object into the messenger callback weak var messenger: Messenger? - - let onExecute: (GraphQLRequest) async throws -> GraphQLResult - let onSubscribe: (GraphQLRequest) async throws -> SubscriptionSequenceType - var auth: (InitPayload) async throws -> Void + + let onInit: (InitPayload) async throws -> InitPayloadResult + let onExecute: (GraphQLRequest, InitPayloadResult) async throws -> GraphQLResult + let onSubscribe: (GraphQLRequest, InitPayloadResult) async throws -> SubscriptionSequenceType var onExit: () async throws -> Void = {} var onMessage: (String) async throws -> Void = { _ in } @@ -23,6 +24,7 @@ public class Server< var onOperationError: (String, [Error]) async throws -> Void = { _, _ in } var initialized = false + var initResult: InitPayloadResult? let decoder = JSONDecoder() let encoder = GraphQLJSONEncoder() @@ -37,13 +39,14 @@ public class Server< /// - onSubscribe: Callback run during `start` resolution for streaming queries. Typically this is `API.subscribe`. public init( messenger: Messenger, - onExecute: @escaping (GraphQLRequest) async throws -> GraphQLResult, - onSubscribe: @escaping (GraphQLRequest) async throws -> SubscriptionSequenceType + onInit: @escaping (InitPayload) async throws -> InitPayloadResult, + onExecute: @escaping (GraphQLRequest, InitPayloadResult) async throws -> GraphQLResult, + onSubscribe: @escaping (GraphQLRequest, InitPayloadResult) async throws -> SubscriptionSequenceType ) { self.messenger = messenger + self.onInit = onInit self.onExecute = onExecute self.onSubscribe = onSubscribe - auth = { _ in } messenger.onReceive { message in guard let messenger = self.messenger else { return } @@ -99,13 +102,6 @@ public class Server< subscriptionTasks.values.forEach { $0.cancel() } } - /// Define a custom callback run during `connection_init` resolution that allows authorization using the `payload`. - /// Throw from this closure to indicate that authorization has failed. - /// - Parameter callback: The callback to assign - public func auth(_ callback: @escaping (InitPayload) async throws -> Void) { - auth = callback - } - /// Define the callback run when the communication is shut down, either by the client or server /// - Parameter callback: The callback to assign public func onExit(_ callback: @escaping () -> Void) { @@ -137,7 +133,7 @@ public class Server< } do { - try await auth(connectionInitRequest.payload) + initResult = try await onInit(connectionInitRequest.payload) } catch { try await self.error(.unauthorized()) return @@ -148,7 +144,7 @@ public class Server< } private func onSubscribe(_ subscribeRequest: SubscribeRequest) async throws { - guard initialized else { + guard initialized, let initResult else { try await error(.notInitialized()) return } @@ -171,7 +167,7 @@ public class Server< if isStreaming { subscriptionTasks[id] = Task { do { - let stream = try await onSubscribe(graphQLRequest) + let stream = try await onSubscribe(graphQLRequest, initResult) for try await event in stream { try Task.checkCancellation() try await self.sendNext(event, id: id) @@ -186,7 +182,7 @@ public class Server< } } else { do { - let result = try await onExecute(graphQLRequest) + let result = try await onExecute(graphQLRequest, initResult) try await sendNext(result, id: id) try await sendComplete(id: id) } catch { diff --git a/Tests/GraphQLTransportWSTests/GraphQLTransportWSTests.swift b/Tests/GraphQLTransportWSTests/GraphQLTransportWSTests.swift index d8bc65b..4e2447a 100644 --- a/Tests/GraphQLTransportWSTests/GraphQLTransportWSTests.swift +++ b/Tests/GraphQLTransportWSTests/GraphQLTransportWSTests.swift @@ -8,9 +8,10 @@ import GraphQLTransportWS class GraphqlTransportWSTests: XCTestCase { var clientMessenger: TestMessenger! var serverMessenger: TestMessenger! - var server: Server>! - var context: TestContext! var subscribeReady: Bool! = false + + let context = TestContext() + let api = TestAPI() override func setUp() { // Point the client and server at each other @@ -18,32 +19,29 @@ class GraphqlTransportWSTests: XCTestCase { serverMessenger = TestMessenger() clientMessenger.other = serverMessenger serverMessenger.other = clientMessenger + } - let api = TestAPI() - let context = TestContext() - - server = .init( + /// Tests that trying to run methods before `connection_init` is not allowed + func testInitialize() async throws { + _ = Server>( messenger: serverMessenger, - onExecute: { graphQLRequest in - try await api.execute( + onInit: { _ in }, + onExecute: { graphQLRequest, _ in + try await self.api.execute( request: graphQLRequest.query, - context: context + context: self.context ) }, - onSubscribe: { graphQLRequest in - let subscription = try await api.subscribe( + onSubscribe: { graphQLRequest, _ in + let subscription = try await self.api.subscribe( request: graphQLRequest.query, - context: context + context: self.context ).get() self.subscribeReady = true return subscription } ) - self.context = context - } - - /// Tests that trying to run methods before `connection_init` is not allowed - func testInitialize() async throws { + let client = Client(messenger: clientMessenger) let messageStream = AsyncThrowingStream { continuation in client.onMessage { message, _ in @@ -78,9 +76,26 @@ class GraphqlTransportWSTests: XCTestCase { /// Tests that throwing in the authorization callback forces an unauthorized error func testAuthWithThrow() async throws { - server.auth { _ in - throw TestError.couldBeAnything - } + _ = Server>( + messenger: serverMessenger, + onInit: { _ in + throw TestError.couldBeAnything + }, + onExecute: { graphQLRequest, _ in + try await self.api.execute( + request: graphQLRequest.query, + context: self.context + ) + }, + onSubscribe: { graphQLRequest, _ in + let subscription = try await self.api.subscribe( + request: graphQLRequest.query, + context: self.context + ).get() + self.subscribeReady = true + return subscription + } + ) let client = Client(messenger: clientMessenger) let messageStream = AsyncThrowingStream { continuation in @@ -111,6 +126,25 @@ class GraphqlTransportWSTests: XCTestCase { /// Tests a single-op conversation func testSingleOp() async throws { + _ = Server>( + messenger: serverMessenger, + onInit: { _ in }, + onExecute: { graphQLRequest, _ in + try await self.api.execute( + request: graphQLRequest.query, + context: self.context + ) + }, + onSubscribe: { graphQLRequest, _ in + let subscription = try await self.api.subscribe( + request: graphQLRequest.query, + context: self.context + ).get() + self.subscribeReady = true + return subscription + } + ) + let id = UUID().description let client = Client(messenger: clientMessenger) @@ -156,6 +190,25 @@ class GraphqlTransportWSTests: XCTestCase { /// Tests a streaming conversation func testStreaming() async throws { + _ = Server>( + messenger: serverMessenger, + onInit: { _ in }, + onExecute: { graphQLRequest, _ in + try await self.api.execute( + request: graphQLRequest.query, + context: self.context + ) + }, + onSubscribe: { graphQLRequest, _ in + let subscription = try await self.api.subscribe( + request: graphQLRequest.query, + context: self.context + ).get() + self.subscribeReady = true + return subscription + } + ) + let id = UUID().description var dataIndex = 1 From a726bd3cb55ac1b28481e53e5ae5231d9e78f18a Mon Sep 17 00:00:00 2001 From: Jay Herron Date: Wed, 11 Feb 2026 00:16:19 -0700 Subject: [PATCH 2/8] feat!: Messenger is send-only To support receiving messages, we added `listen` functions to server and client. This resolves the confusing ownership rules by avoiding `onReceive` callbacks in Messenger. --- README.md | 28 +++------- Sources/GraphQLTransportWS/Client.swift | 16 +++--- Sources/GraphQLTransportWS/Messenger.swift | 11 ++-- Sources/GraphQLTransportWS/Server.swift | 17 +++---- .../GraphQLTransportWSTests.swift | 51 +++++++++++++++---- .../Utils/TestMessenger.swift | 33 +++++------- 6 files changed, 79 insertions(+), 77 deletions(-) diff --git a/README.md b/README.md index f9d2d2b..3b99f6b 100644 --- a/README.md +++ b/README.md @@ -26,32 +26,21 @@ import GraphQLTransportWS /// Messenger wrapper for WebSockets class WebSocketMessenger: Messenger { - private weak var websocket: WebSocket? - private var onReceive: (String) async throws -> Void = { _ in } + let websocket: WebSocket? init(websocket: WebSocket) { self.websocket = websocket - websocket.onText { _, message in - try await self.onReceive(message) - } } func send(_ message: S) where S: Collection, S.Element == Character async throws { - guard let websocket = websocket else { return } try await websocket.send(message) } - func onReceive(callback: @escaping (String) async throws -> Void) { - self.onReceive = callback - } - func error(_ message: String, code: Int) async throws { - guard let websocket = websocket else { return } try await websocket.send("\(code): \(message)") } func close() async throws { - guard let websocket = websocket else { return } try await websocket.close() } } @@ -85,6 +74,12 @@ routes.webSocket( ) } ) + let incoming = AsyncStream { continuation in + websocket.onText { _, message in + continuation.yield(message) + } + } + try await server.listen(to: incoming) } ) ``` @@ -125,12 +120,3 @@ This example would require `connection_init` message from the client to look lik ``` If the `payload` field is not required on your server, you may make Server's generic declaration optional like `Server` - -## Memory Management - -Memory ownership among the Server, Client, and Messenger may seem a little backwards. This is because the Swift/Vapor WebSocket -implementation persists WebSocket objects long after their callback and they are expected to retain strong memory references to the -objects required for responses. In order to align cleanly and avoid memory cycles, Server and Client are injected strongly into Messenger -callbacks, and only hold weak references to their Messenger. This means that Messenger objects (or their enclosing WebSocket) must -be persisted to have the connected Server or Client objects function. That is, if a Server's Messenger falls out of scope and deinitializes, -the Server will no longer respond to messages. diff --git a/Sources/GraphQLTransportWS/Client.swift b/Sources/GraphQLTransportWS/Client.swift index 9b0e816..70846e1 100644 --- a/Sources/GraphQLTransportWS/Client.swift +++ b/Sources/GraphQLTransportWS/Client.swift @@ -2,9 +2,8 @@ import Foundation import GraphQL /// Client is an open-ended implementation of the client side of the protocol. It parses and adds callbacks for each type of server respose. -public class Client { - // We keep this weak because we strongly inject this object into the messenger callback - weak var messenger: Messenger? +public class Client: @unchecked Sendable { + let messenger: Messenger var onConnectionAck: (ConnectionAckResponse, Client) async throws -> Void = { _, _ in } var onNext: (NextResponse, Client) async throws -> Void = { _, _ in } @@ -23,7 +22,12 @@ public class Client { messenger: Messenger ) { self.messenger = messenger - messenger.onReceive { message in + } + + /// Listen and react to the provided async sequence of server messages. This function will block until the stream is completed. + /// - Parameter incoming: The server message sequence that the client should react to. + public func listen(to incoming: A) async throws -> Void where A.Element == String { + for try await message in incoming { try await self.onMessage(message, self) // Detect and ignore error responses. @@ -108,7 +112,6 @@ public class Client { /// Send a `connection_init` request through the messenger public func sendConnectionInit(payload: InitPayload) async throws { - guard let messenger = messenger else { return } try await messenger.send( ConnectionInitRequest( payload: payload @@ -118,7 +121,6 @@ public class Client { /// Send a `subscribe` request through the messenger public func sendStart(payload: GraphQLRequest, id: String) async throws { - guard let messenger = messenger else { return } try await messenger.send( SubscribeRequest( payload: payload, @@ -129,7 +131,6 @@ public class Client { /// Send a `complete` request through the messenger public func sendStop(id: String) async throws { - guard let messenger = messenger else { return } try await messenger.send( CompleteRequest( id: id @@ -139,7 +140,6 @@ public class Client { /// Send an error through the messenger and close the connection private func error(_ error: GraphQLTransportWSError) async throws { - guard let messenger = messenger else { return } try await messenger.error(error.message, code: error.code.rawValue) } } diff --git a/Sources/GraphQLTransportWS/Messenger.swift b/Sources/GraphQLTransportWS/Messenger.swift index 3a9c157..0be9a35 100644 --- a/Sources/GraphQLTransportWS/Messenger.swift +++ b/Sources/GraphQLTransportWS/Messenger.swift @@ -1,15 +1,10 @@ import Foundation -/// Protocol for an object that can send and recieve messages. This allows mocking in tests -public protocol Messenger: AnyObject { - // AnyObject compliance requires that the implementing object is a class and we can reference it weakly - +/// Protocol for an object that can send messages. This allows mocking in tests +public protocol Messenger { /// Send a message through this messenger /// - Parameter message: The message to send - func send(_ message: S) async throws -> Void where S: Collection, S.Element == Character - - /// Set the callback that should be run when a message is recieved - func onReceive(callback: @escaping (String) async throws -> Void) + func send(_ message: S) async throws -> Void where S: Collection, S.Element == Character /// Close the messenger func close() async throws diff --git a/Sources/GraphQLTransportWS/Server.swift b/Sources/GraphQLTransportWS/Server.swift index e8b57e2..bc22205 100644 --- a/Sources/GraphQLTransportWS/Server.swift +++ b/Sources/GraphQLTransportWS/Server.swift @@ -12,7 +12,7 @@ public class Server< SubscriptionSequenceType.Element == GraphQLResult { // We keep this weak because we strongly inject this object into the messenger callback - weak var messenger: Messenger? + let messenger: Messenger let onInit: (InitPayload) async throws -> InitPayloadResult let onExecute: (GraphQLRequest, InitPayloadResult) async throws -> GraphQLResult @@ -47,10 +47,12 @@ public class Server< self.onInit = onInit self.onExecute = onExecute self.onSubscribe = onSubscribe + } - messenger.onReceive { message in - guard let messenger = self.messenger else { return } - + /// Listen and react to the provided async sequence of client messages. This function will block until the stream is completed. + /// - Parameter incoming: The client message sequence that the server should react to. + public func listen(to incoming: A) async throws -> Void where A.Element == String { + for try await message in incoming { try await self.onMessage(message) // Detect and ignore error responses. @@ -188,7 +190,7 @@ public class Server< } catch { try await sendError(error, id: id) } - try await messenger?.close() + try await messenger.close() } } @@ -208,7 +210,6 @@ public class Server< /// Send a `connection_ack` response through the messenger private func sendConnectionAck(_ payload: [String: Map]? = nil) async throws { - guard let messenger = messenger else { return } try await messenger.send( ConnectionAckResponse(payload: payload).toJSON(encoder) ) @@ -216,7 +217,6 @@ public class Server< /// Send a `next` response through the messenger private func sendNext(_ payload: GraphQLResult? = nil, id: String) async throws { - guard let messenger = messenger else { return } try await messenger.send( NextResponse( payload: payload, @@ -227,7 +227,6 @@ public class Server< /// Send a `complete` response through the messenger private func sendComplete(id: String) async throws { - guard let messenger = messenger else { return } try await messenger.send( CompleteResponse( id: id @@ -238,7 +237,6 @@ public class Server< /// Send an `error` response through the messenger private func sendError(_ errors: [Error], id: String) async throws { - guard let messenger = messenger else { return } try await messenger.send( ErrorResponse( errors, @@ -260,7 +258,6 @@ public class Server< /// Send an error through the messenger and close the connection private func error(_ error: GraphQLTransportWSError) async throws { - guard let messenger = messenger else { return } try await messenger.error(error.message, code: error.code.rawValue) } } diff --git a/Tests/GraphQLTransportWSTests/GraphQLTransportWSTests.swift b/Tests/GraphQLTransportWSTests/GraphQLTransportWSTests.swift index 4e2447a..30416d3 100644 --- a/Tests/GraphQLTransportWSTests/GraphQLTransportWSTests.swift +++ b/Tests/GraphQLTransportWSTests/GraphQLTransportWSTests.swift @@ -14,16 +14,13 @@ class GraphqlTransportWSTests: XCTestCase { let api = TestAPI() override func setUp() { - // Point the client and server at each other clientMessenger = TestMessenger() serverMessenger = TestMessenger() - clientMessenger.other = serverMessenger - serverMessenger.other = clientMessenger } /// Tests that trying to run methods before `connection_init` is not allowed func testInitialize() async throws { - _ = Server>( + let server = Server>( messenger: serverMessenger, onInit: { _ in }, onExecute: { graphQLRequest, _ in @@ -41,8 +38,16 @@ class GraphqlTransportWSTests: XCTestCase { return subscription } ) - let client = Client(messenger: clientMessenger) + let serverStream = serverMessenger.stream + let clientStream = clientMessenger.stream + Task { + try await server.listen(to: clientStream) + } + Task { + try await client.listen(to: serverStream) + } + let messageStream = AsyncThrowingStream { continuation in client.onMessage { message, _ in continuation.yield(message) @@ -76,7 +81,7 @@ class GraphqlTransportWSTests: XCTestCase { /// Tests that throwing in the authorization callback forces an unauthorized error func testAuthWithThrow() async throws { - _ = Server>( + let server = Server>( messenger: serverMessenger, onInit: { _ in throw TestError.couldBeAnything @@ -96,8 +101,16 @@ class GraphqlTransportWSTests: XCTestCase { return subscription } ) - let client = Client(messenger: clientMessenger) + let clientStream = clientMessenger.stream + let serverStream = serverMessenger.stream + Task { + try await server.listen(to: clientStream) + } + Task { + try await client.listen(to: serverStream) + } + let messageStream = AsyncThrowingStream { continuation in client.onMessage { message, _ in continuation.yield(message) @@ -126,7 +139,7 @@ class GraphqlTransportWSTests: XCTestCase { /// Tests a single-op conversation func testSingleOp() async throws { - _ = Server>( + let server = Server>( messenger: serverMessenger, onInit: { _ in }, onExecute: { graphQLRequest, _ in @@ -144,10 +157,18 @@ class GraphqlTransportWSTests: XCTestCase { return subscription } ) + let client = Client(messenger: clientMessenger) + let clientStream = clientMessenger.stream + let serverStream = serverMessenger.stream + Task { + try await server.listen(to: clientStream) + } + Task { + try await client.listen(to: serverStream) + } let id = UUID().description - let client = Client(messenger: clientMessenger) let messageStream = AsyncThrowingStream { continuation in client.onConnectionAck { _, client in try await client.sendStart( @@ -190,7 +211,7 @@ class GraphqlTransportWSTests: XCTestCase { /// Tests a streaming conversation func testStreaming() async throws { - _ = Server>( + let server = Server>( messenger: serverMessenger, onInit: { _ in }, onExecute: { graphQLRequest, _ in @@ -208,13 +229,21 @@ class GraphqlTransportWSTests: XCTestCase { return subscription } ) + let client = Client(messenger: clientMessenger) + let clientStream = clientMessenger.stream + let serverStream = serverMessenger.stream + Task { + try await server.listen(to: clientStream) + } + Task { + try await client.listen(to: serverStream) + } let id = UUID().description var dataIndex = 1 let dataIndexMax = 3 - let client = Client(messenger: clientMessenger) let messageStream = AsyncThrowingStream { continuation in client.onConnectionAck { _, client in try await client.sendStart( diff --git a/Tests/GraphQLTransportWSTests/Utils/TestMessenger.swift b/Tests/GraphQLTransportWSTests/Utils/TestMessenger.swift index a35aa09..20b1e70 100644 --- a/Tests/GraphQLTransportWSTests/Utils/TestMessenger.swift +++ b/Tests/GraphQLTransportWSTests/Utils/TestMessenger.swift @@ -4,32 +4,27 @@ import Foundation @testable import GraphQLTransportWS /// Messenger for simple testing that doesn't require starting up a websocket server. -/// -/// Note that this only retains a weak reference to 'other', so the client should retain references -/// or risk them being deinitialized early -class TestMessenger: Messenger, @unchecked Sendable { - weak var other: TestMessenger? - var onReceive: (String) async throws -> Void = { _ in } - let queue: DispatchQueue = .init(label: "Test messenger") - - init() {} - - func send(_ message: S) async throws where S: Collection, S.Element == Character { - guard let other = other else { - return - } - try await other.onReceive(String(message)) +actor TestMessenger: Messenger { + /// An async stream of the messages sent through this messenger. + let stream: AsyncStream + private var continuation: AsyncStream.Continuation + + init() { + let (stream, continuation) = AsyncStream.makeStream() + self.stream = stream + self.continuation = continuation } - func onReceive(callback: @escaping (String) async throws -> Void) { - onReceive = callback + func send(_ message: S) async throws where S: Collection, S.Element == Character { + continuation.yield(String(message)) } func error(_ message: String, code: Int) async throws { - try await send("\(code): \(message)") + continuation.yield("\(code): \(message)") + continuation.finish() } func close() { - // This is a testing no-op + continuation.finish() } } From 3f125f521ca5430800e788fa2a59dc0fdf12b2ed Mon Sep 17 00:00:00 2001 From: Jay Herron Date: Wed, 11 Feb 2026 00:34:48 -0700 Subject: [PATCH 3/8] feat!: Actor conversion and Swift Testing Client and Server became actors to ensure sendability, and Messenger was marked sendable --- Sources/GraphQLTransportWS/Client.swift | 54 ++-- Sources/GraphQLTransportWS/Messenger.swift | 4 +- Sources/GraphQLTransportWS/Server.swift | 54 ++-- .../GraphQLTransportWSTests.swift | 254 +++++++++--------- 4 files changed, 158 insertions(+), 208 deletions(-) diff --git a/Sources/GraphQLTransportWS/Client.swift b/Sources/GraphQLTransportWS/Client.swift index 70846e1..db79afc 100644 --- a/Sources/GraphQLTransportWS/Client.swift +++ b/Sources/GraphQLTransportWS/Client.swift @@ -2,14 +2,14 @@ import Foundation import GraphQL /// Client is an open-ended implementation of the client side of the protocol. It parses and adds callbacks for each type of server respose. -public class Client: @unchecked Sendable { +public actor Client { let messenger: Messenger - var onConnectionAck: (ConnectionAckResponse, Client) async throws -> Void = { _, _ in } - var onNext: (NextResponse, Client) async throws -> Void = { _, _ in } - var onError: (ErrorResponse, Client) async throws -> Void = { _, _ in } - var onComplete: (CompleteResponse, Client) async throws -> Void = { _, _ in } - var onMessage: (String, Client) async throws -> Void = { _, _ in } + let onConnectionAck: (ConnectionAckResponse, Client) async throws -> Void + let onNext: (NextResponse, Client) async throws -> Void + let onError: (ErrorResponse, Client) async throws -> Void + let onComplete: (CompleteResponse, Client) async throws -> Void + let onMessage: (String, Client) async throws -> Void let encoder = GraphQLJSONEncoder() let decoder = JSONDecoder() @@ -19,9 +19,19 @@ public class Client: @unchecked Sendable { /// - Parameters: /// - messenger: The messenger to bind the client to. public init( - messenger: Messenger + messenger: Messenger, + onConnectionAck: @escaping (ConnectionAckResponse, Client) async throws -> Void = { _, _ in }, + onNext: @escaping (NextResponse, Client) async throws -> Void = { _, _ in }, + onError: @escaping (ErrorResponse, Client) async throws -> Void = { _, _ in }, + onComplete: @escaping (CompleteResponse, Client) async throws -> Void = { _, _ in }, + onMessage: @escaping (String, Client) async throws -> Void = { _, _ in } ) { self.messenger = messenger + self.onConnectionAck = onConnectionAck + self.onNext = onNext + self.onError = onError + self.onComplete = onComplete + self.onMessage = onMessage } /// Listen and react to the provided async sequence of server messages. This function will block until the stream is completed. @@ -80,36 +90,6 @@ public class Client: @unchecked Sendable { } } - /// Define the callback run on receipt of a `connection_ack` message - /// - Parameter callback: The callback to assign - public func onConnectionAck(_ callback: @escaping (ConnectionAckResponse, Client) async throws -> Void) { - onConnectionAck = callback - } - - /// Define the callback run on receipt of a `next` message - /// - Parameter callback: The callback to assign - public func onNext(_ callback: @escaping (NextResponse, Client) async throws -> Void) { - onNext = callback - } - - /// Define the callback run on receipt of an `error` message - /// - Parameter callback: The callback to assign - public func onError(_ callback: @escaping (ErrorResponse, Client) async throws -> Void) { - onError = callback - } - - /// Define the callback run on receipt of a `complete` message - /// - Parameter callback: The callback to assign - public func onComplete(_ callback: @escaping (CompleteResponse, Client) async throws -> Void) { - onComplete = callback - } - - /// Define the callback run on receipt of any message - /// - Parameter callback: The callback to assign - public func onMessage(_ callback: @escaping (String, Client) async throws -> Void) { - onMessage = callback - } - /// Send a `connection_init` request through the messenger public func sendConnectionInit(payload: InitPayload) async throws { try await messenger.send( diff --git a/Sources/GraphQLTransportWS/Messenger.swift b/Sources/GraphQLTransportWS/Messenger.swift index 0be9a35..aa4e1a9 100644 --- a/Sources/GraphQLTransportWS/Messenger.swift +++ b/Sources/GraphQLTransportWS/Messenger.swift @@ -1,7 +1,7 @@ import Foundation -/// Protocol for an object that can send messages. This allows mocking in tests -public protocol Messenger { +/// Protocol for an object that can send messages. This allows mocking in tests. +public protocol Messenger: Sendable { /// Send a message through this messenger /// - Parameter message: The message to send func send(_ message: S) async throws -> Void where S: Collection, S.Element == Character diff --git a/Sources/GraphQLTransportWS/Server.swift b/Sources/GraphQLTransportWS/Server.swift index bc22205..e7cf0e2 100644 --- a/Sources/GraphQLTransportWS/Server.swift +++ b/Sources/GraphQLTransportWS/Server.swift @@ -4,31 +4,28 @@ import GraphQL /// Server implements the server-side portion of the protocol, allowing a few callbacks for customization. /// /// By default, there are no authorization checks -public class Server< +public actor Server< InitPayload: Equatable & Codable & Sendable, InitPayloadResult: Sendable, SubscriptionSequenceType: AsyncSequence & Sendable ->: @unchecked Sendable where +> where SubscriptionSequenceType.Element == GraphQLResult { - // We keep this weak because we strongly inject this object into the messenger callback let messenger: Messenger let onInit: (InitPayload) async throws -> InitPayloadResult let onExecute: (GraphQLRequest, InitPayloadResult) async throws -> GraphQLResult let onSubscribe: (GraphQLRequest, InitPayloadResult) async throws -> SubscriptionSequenceType - - var onExit: () async throws -> Void = {} - var onMessage: (String) async throws -> Void = { _ in } - var onOperationComplete: (String) async throws -> Void = { _ in } - var onOperationError: (String, [Error]) async throws -> Void = { _, _ in } - - var initialized = false - var initResult: InitPayloadResult? + + let onMessage: (String) async throws -> Void + let onOperationComplete: (String) async throws -> Void + let onOperationError: (String, [Error]) async throws -> Void let decoder = JSONDecoder() let encoder = GraphQLJSONEncoder() + private var initialized = false + private var initResult: InitPayloadResult? private var subscriptionTasks = [String: Task]() /// Create a new server @@ -37,16 +34,25 @@ public class Server< /// - messenger: The messenger to bind the server to. /// - onExecute: Callback run during `start` resolution for non-streaming queries. Typically this is `API.execute`. /// - onSubscribe: Callback run during `start` resolution for streaming queries. Typically this is `API.subscribe`. + /// - onMessage: Optional callback run on every message event + /// - onOperationComplete: Optional callback run when an operation completes + /// - onOperationError: Optional callback run when an operation errors public init( messenger: Messenger, onInit: @escaping (InitPayload) async throws -> InitPayloadResult, onExecute: @escaping (GraphQLRequest, InitPayloadResult) async throws -> GraphQLResult, - onSubscribe: @escaping (GraphQLRequest, InitPayloadResult) async throws -> SubscriptionSequenceType + onSubscribe: @escaping (GraphQLRequest, InitPayloadResult) async throws -> SubscriptionSequenceType, + onMessage: @escaping (String) async throws -> Void = { _ in }, + onOperationComplete: @escaping (String) async throws -> Void = { _ in }, + onOperationError: @escaping (String, [Error]) async throws -> Void = { _, _ in }, ) { self.messenger = messenger self.onInit = onInit self.onExecute = onExecute self.onSubscribe = onSubscribe + self.onMessage = onMessage + self.onOperationComplete = onOperationComplete + self.onOperationError = onOperationError } /// Listen and react to the provided async sequence of client messages. This function will block until the stream is completed. @@ -104,30 +110,6 @@ public class Server< subscriptionTasks.values.forEach { $0.cancel() } } - /// Define the callback run when the communication is shut down, either by the client or server - /// - Parameter callback: The callback to assign - public func onExit(_ callback: @escaping () -> Void) { - onExit = callback - } - - /// Define the callback run on receipt of any message - /// - Parameter callback: The callback to assign - public func onMessage(_ callback: @escaping (String) -> Void) { - onMessage = callback - } - - /// Define the callback run on the completion a full operation (query/mutation, end of subscription) - /// - Parameter callback: The callback to assign - public func onOperationComplete(_ callback: @escaping (String) -> Void) { - onOperationComplete = callback - } - - /// Define the callback to run on error of any full operation (failed query, interrupted subscription) - /// - Parameter callback: The callback to assign - public func onOperationError(_ callback: @escaping (String, [Error]) -> Void) { - onOperationError = callback - } - private func onConnectionInit(_ connectionInitRequest: ConnectionInitRequest, _: Messenger) async throws { guard !initialized else { try await error(.tooManyInitializations()) diff --git a/Tests/GraphQLTransportWSTests/GraphQLTransportWSTests.swift b/Tests/GraphQLTransportWSTests/GraphQLTransportWSTests.swift index 30416d3..77796d3 100644 --- a/Tests/GraphQLTransportWSTests/GraphQLTransportWSTests.swift +++ b/Tests/GraphQLTransportWSTests/GraphQLTransportWSTests.swift @@ -1,44 +1,48 @@ import Foundation import GraphQL -import XCTest +import Testing import GraphQLTransportWS -class GraphqlTransportWSTests: XCTestCase { - var clientMessenger: TestMessenger! - var serverMessenger: TestMessenger! - var subscribeReady: Bool! = false - - let context = TestContext() - let api = TestAPI() - - override func setUp() { - clientMessenger = TestMessenger() - serverMessenger = TestMessenger() - } +@Suite +struct GraphqlTransportWSTests { + let clientMessenger = TestMessenger() + let serverMessenger = TestMessenger() /// Tests that trying to run methods before `connection_init` is not allowed - func testInitialize() async throws { + @Test func initialize() async throws { + let api = TestAPI() + let context = TestContext() let server = Server>( messenger: serverMessenger, onInit: { _ in }, onExecute: { graphQLRequest, _ in - try await self.api.execute( + try await api.execute( request: graphQLRequest.query, - context: self.context + context: context ) }, onSubscribe: { graphQLRequest, _ in - let subscription = try await self.api.subscribe( + let subscription = try await api.subscribe( request: graphQLRequest.query, - context: self.context + context: context ).get() - self.subscribeReady = true return subscription } ) - let client = Client(messenger: clientMessenger) + let (messageStream, messageContinuation) = AsyncThrowingStream.makeStream() + let client = Client( + messenger: clientMessenger, + onError: { message, _ in + messageContinuation.finish(throwing: message.payload[0]) + }, + onMessage: { message, _ in + messageContinuation.yield(message) + // Expect only one message + messageContinuation.finish() + } + ) let serverStream = serverMessenger.stream let clientStream = clientMessenger.stream Task { @@ -48,16 +52,6 @@ class GraphqlTransportWSTests: XCTestCase { try await client.listen(to: serverStream) } - let messageStream = AsyncThrowingStream { continuation in - client.onMessage { message, _ in - continuation.yield(message) - // Expect only one message - continuation.finish() - } - client.onError { message, _ in - continuation.finish(throwing: message.payload[0]) - } - } try await client.sendStart( payload: GraphQLRequest( @@ -73,35 +67,47 @@ class GraphqlTransportWSTests: XCTestCase { let messages = try await messageStream.reduce(into: [String]()) { result, message in result.append(message) } - XCTAssertEqual( - messages, + #expect( + messages == ["\(ErrorCode.notInitialized): Connection not initialized"] ) } /// Tests that throwing in the authorization callback forces an unauthorized error - func testAuthWithThrow() async throws { + @Test func authWithThrow() async throws { + let api = TestAPI() + let context = TestContext() let server = Server>( messenger: serverMessenger, onInit: { _ in throw TestError.couldBeAnything }, onExecute: { graphQLRequest, _ in - try await self.api.execute( + try await api.execute( request: graphQLRequest.query, - context: self.context + context: context ) }, onSubscribe: { graphQLRequest, _ in - let subscription = try await self.api.subscribe( + let subscription = try await api.subscribe( request: graphQLRequest.query, - context: self.context + context: context ).get() - self.subscribeReady = true return subscription } ) - let client = Client(messenger: clientMessenger) + let (messageStream, messageContinuation) = AsyncThrowingStream.makeStream() + let client = Client( + messenger: clientMessenger, + onError: { message, _ in + messageContinuation.finish(throwing: message.payload[0]) + }, + onMessage: { message, _ in + messageContinuation.yield(message) + // Expect only one message + messageContinuation.finish() + } + ) let clientStream = clientMessenger.stream let serverStream = serverMessenger.stream Task { @@ -110,17 +116,6 @@ class GraphqlTransportWSTests: XCTestCase { Task { try await client.listen(to: serverStream) } - - let messageStream = AsyncThrowingStream { continuation in - client.onMessage { message, _ in - continuation.yield(message) - // Expect only one message - continuation.finish() - } - client.onError { message, _ in - continuation.finish(throwing: message.payload[0]) - } - } try await client.sendConnectionInit( payload: TokenInitPayload( @@ -131,46 +126,39 @@ class GraphqlTransportWSTests: XCTestCase { let messages = try await messageStream.reduce(into: [String]()) { result, message in result.append(message) } - XCTAssertEqual( - messages, + #expect( + messages == ["\(ErrorCode.unauthorized): Unauthorized"] ) } /// Tests a single-op conversation - func testSingleOp() async throws { + @Test func singleOp() async throws { + let api = TestAPI() + let context = TestContext() + let id = UUID().description + let server = Server>( messenger: serverMessenger, onInit: { _ in }, onExecute: { graphQLRequest, _ in - try await self.api.execute( + try await api.execute( request: graphQLRequest.query, - context: self.context + context: context ) }, onSubscribe: { graphQLRequest, _ in - let subscription = try await self.api.subscribe( + let subscription = try await api.subscribe( request: graphQLRequest.query, - context: self.context + context: context ).get() - self.subscribeReady = true return subscription } ) - let client = Client(messenger: clientMessenger) - let clientStream = clientMessenger.stream - let serverStream = serverMessenger.stream - Task { - try await server.listen(to: clientStream) - } - Task { - try await client.listen(to: serverStream) - } - - let id = UUID().description - - let messageStream = AsyncThrowingStream { continuation in - client.onConnectionAck { _, client in + let (messageStream, messageContinuation) = AsyncThrowingStream.makeStream() + let client = Client( + messenger: clientMessenger, + onConnectionAck: { _, client in try await client.sendStart( payload: GraphQLRequest( query: """ @@ -181,16 +169,24 @@ class GraphqlTransportWSTests: XCTestCase { ), id: id ) + }, + onError: { message, _ in + messageContinuation.finish(throwing: message.payload[0]) + }, + onComplete: { _, _ in + messageContinuation.finish() + }, + onMessage: { message, _ in + messageContinuation.yield(message) } - client.onMessage { message, _ in - continuation.yield(message) - } - client.onError { message, _ in - continuation.finish(throwing: message.payload[0]) - } - client.onComplete { _, _ in - continuation.finish() - } + ) + let clientStream = clientMessenger.stream + let serverStream = serverMessenger.stream + Task { + try await server.listen(to: clientStream) + } + Task { + try await client.listen(to: serverStream) } try await client.sendConnectionInit( @@ -202,50 +198,43 @@ class GraphqlTransportWSTests: XCTestCase { let messages = try await messageStream.reduce(into: [String]()) { result, message in result.append(message) } - XCTAssertEqual( - messages.count, + #expect( + messages.count == 3, // 1 connection_ack, 1 next, 1 complete - "Messages: \(messages.description)" ) } /// Tests a streaming conversation - func testStreaming() async throws { + @Test func streaming() async throws { + let api = TestAPI() + let context = TestContext() + let id = UUID().description + var dataIndex = 1 + let dataIndexMax = 3 + + let (subscribeReadyStream, subscribeReadyContinuation) = AsyncStream.makeStream() let server = Server>( messenger: serverMessenger, onInit: { _ in }, onExecute: { graphQLRequest, _ in - try await self.api.execute( + try await api.execute( request: graphQLRequest.query, - context: self.context + context: context ) }, onSubscribe: { graphQLRequest, _ in - let subscription = try await self.api.subscribe( + let subscription = try await api.subscribe( request: graphQLRequest.query, - context: self.context + context: context ).get() - self.subscribeReady = true + subscribeReadyContinuation.finish() return subscription } ) - let client = Client(messenger: clientMessenger) - let clientStream = clientMessenger.stream - let serverStream = serverMessenger.stream - Task { - try await server.listen(to: clientStream) - } - Task { - try await client.listen(to: serverStream) - } - - let id = UUID().description - - var dataIndex = 1 - let dataIndexMax = 3 - - let messageStream = AsyncThrowingStream { continuation in - client.onConnectionAck { _, client in + let (messageStream, messageContinuation) = AsyncThrowingStream.makeStream() + let client = Client( + messenger: clientMessenger, + onConnectionAck: { _, client in try await client.sendStart( payload: GraphQLRequest( query: """ @@ -258,34 +247,35 @@ class GraphqlTransportWSTests: XCTestCase { ) // Wait until server has registered subscription - var i = 0 - while !self.subscribeReady, i < 50 { - usleep(1000) - i = i + 1 - } - if i == 50 { - XCTFail("Subscription timeout: Took longer than 50ms to set up") - } + for await _ in subscribeReadyStream {} - self.context.publisher.emit(event: "hello \(dataIndex)") - } - client.onNext { _, _ in + context.publisher.emit(event: "hello \(dataIndex)") + }, + onNext: { _, _ in dataIndex = dataIndex + 1 if dataIndex <= dataIndexMax { - self.context.publisher.emit(event: "hello \(dataIndex)") + context.publisher.emit(event: "hello \(dataIndex)") } else { - self.context.publisher.cancel() + context.publisher.cancel() } + }, + onError: { message, _ in + messageContinuation.finish(throwing: message.payload[0]) + }, + onComplete: { _, _ in + messageContinuation.finish() + }, + onMessage: { message, _ in + messageContinuation.yield(message) } - client.onMessage { message, _ in - continuation.yield(message) - } - client.onError { message, _ in - continuation.finish(throwing: message.payload[0]) - } - client.onComplete { _, _ in - continuation.finish() - } + ) + let clientStream = clientMessenger.stream + let serverStream = serverMessenger.stream + Task { + try await server.listen(to: clientStream) + } + Task { + try await client.listen(to: serverStream) } try await client.sendConnectionInit( @@ -297,10 +287,8 @@ class GraphqlTransportWSTests: XCTestCase { let messages = try await messageStream.reduce(into: [String]()) { result, message in result.append(message) } - XCTAssertEqual( - messages.count, - 5, // 1 connection_ack, 3 next, 1 complete - "Messages: \(messages.description)" + #expect( + messages.count == 5, // 1 connection_ack, 3 next, 1 complete ) } From dcb4e31f76fcb4e197c24f3152e4b6d17912c393 Mon Sep 17 00:00:00 2001 From: Jay Herron Date: Wed, 11 Feb 2026 00:42:03 -0700 Subject: [PATCH 4/8] chore: Formatting --- Sources/GraphQLTransportWS/Client.swift | 32 +++++++++---------- .../GraphqlTransportWSError.swift | 2 +- Sources/GraphQLTransportWS/Messenger.swift | 2 +- Sources/GraphQLTransportWS/Server.swift | 32 +++++++++---------- .../GraphQLTransportWSTests.swift | 26 ++++++--------- .../Utils/TestMessenger.swift | 3 +- 6 files changed, 45 insertions(+), 52 deletions(-) diff --git a/Sources/GraphQLTransportWS/Client.swift b/Sources/GraphQLTransportWS/Client.swift index db79afc..426d6ad 100644 --- a/Sources/GraphQLTransportWS/Client.swift +++ b/Sources/GraphQLTransportWS/Client.swift @@ -38,7 +38,7 @@ public actor Client { /// - Parameter incoming: The server message sequence that the client should react to. public func listen(to incoming: A) async throws -> Void where A.Element == String { for try await message in incoming { - try await self.onMessage(message, self) + try await onMessage(message, self) // Detect and ignore error responses. if message.starts(with: "44") { @@ -47,13 +47,13 @@ public actor Client { } guard let json = message.data(using: .utf8) else { - try await self.error(.invalidEncoding()) + try await error(.invalidEncoding()) return } let response: Response do { - response = try self.decoder.decode(Response.self, from: json) + response = try decoder.decode(Response.self, from: json) } catch { try await self.error(.noType()) return @@ -61,31 +61,31 @@ public actor Client { switch response.type { case .connectionAck: - guard let connectionAckResponse = try? self.decoder.decode(ConnectionAckResponse.self, from: json) else { - try await self.error(.invalidResponseFormat(messageType: .connectionAck)) + guard let connectionAckResponse = try? decoder.decode(ConnectionAckResponse.self, from: json) else { + try await error(.invalidResponseFormat(messageType: .connectionAck)) return } - try await self.onConnectionAck(connectionAckResponse, self) + try await onConnectionAck(connectionAckResponse, self) case .next: - guard let nextResponse = try? self.decoder.decode(NextResponse.self, from: json) else { - try await self.error(.invalidResponseFormat(messageType: .next)) + guard let nextResponse = try? decoder.decode(NextResponse.self, from: json) else { + try await error(.invalidResponseFormat(messageType: .next)) return } - try await self.onNext(nextResponse, self) + try await onNext(nextResponse, self) case .error: - guard let errorResponse = try? self.decoder.decode(ErrorResponse.self, from: json) else { - try await self.error(.invalidResponseFormat(messageType: .error)) + guard let errorResponse = try? decoder.decode(ErrorResponse.self, from: json) else { + try await error(.invalidResponseFormat(messageType: .error)) return } - try await self.onError(errorResponse, self) + try await onError(errorResponse, self) case .complete: - guard let completeResponse = try? self.decoder.decode(CompleteResponse.self, from: json) else { - try await self.error(.invalidResponseFormat(messageType: .complete)) + guard let completeResponse = try? decoder.decode(CompleteResponse.self, from: json) else { + try await error(.invalidResponseFormat(messageType: .complete)) return } - try await self.onComplete(completeResponse, self) + try await onComplete(completeResponse, self) default: - try await self.error(.invalidType()) + try await error(.invalidType()) } } } diff --git a/Sources/GraphQLTransportWS/GraphqlTransportWSError.swift b/Sources/GraphQLTransportWS/GraphqlTransportWSError.swift index 61775a1..2b200dc 100644 --- a/Sources/GraphQLTransportWS/GraphqlTransportWSError.swift +++ b/Sources/GraphQLTransportWS/GraphqlTransportWSError.swift @@ -89,7 +89,7 @@ struct GraphQLTransportWSError: Error { /// Error codes for miscellaneous issues public enum ErrorCode: Int, CustomStringConvertible, Sendable { - // Miscellaneous + /// Miscellaneous case miscellaneous = 4400 // Internal errors diff --git a/Sources/GraphQLTransportWS/Messenger.swift b/Sources/GraphQLTransportWS/Messenger.swift index aa4e1a9..f50baa6 100644 --- a/Sources/GraphQLTransportWS/Messenger.swift +++ b/Sources/GraphQLTransportWS/Messenger.swift @@ -4,7 +4,7 @@ import Foundation public protocol Messenger: Sendable { /// Send a message through this messenger /// - Parameter message: The message to send - func send(_ message: S) async throws -> Void where S: Collection, S.Element == Character + func send(_ message: S) async throws -> Void where S.Element == Character /// Close the messenger func close() async throws diff --git a/Sources/GraphQLTransportWS/Server.swift b/Sources/GraphQLTransportWS/Server.swift index e7cf0e2..9b28d0d 100644 --- a/Sources/GraphQLTransportWS/Server.swift +++ b/Sources/GraphQLTransportWS/Server.swift @@ -12,11 +12,11 @@ public actor Server< SubscriptionSequenceType.Element == GraphQLResult { let messenger: Messenger - + let onInit: (InitPayload) async throws -> InitPayloadResult let onExecute: (GraphQLRequest, InitPayloadResult) async throws -> GraphQLResult let onSubscribe: (GraphQLRequest, InitPayloadResult) async throws -> SubscriptionSequenceType - + let onMessage: (String) async throws -> Void let onOperationComplete: (String) async throws -> Void let onOperationError: (String, [Error]) async throws -> Void @@ -44,7 +44,7 @@ public actor Server< onSubscribe: @escaping (GraphQLRequest, InitPayloadResult) async throws -> SubscriptionSequenceType, onMessage: @escaping (String) async throws -> Void = { _ in }, onOperationComplete: @escaping (String) async throws -> Void = { _ in }, - onOperationError: @escaping (String, [Error]) async throws -> Void = { _, _ in }, + onOperationError: @escaping (String, [Error]) async throws -> Void = { _, _ in } ) { self.messenger = messenger self.onInit = onInit @@ -59,7 +59,7 @@ public actor Server< /// - Parameter incoming: The client message sequence that the server should react to. public func listen(to incoming: A) async throws -> Void where A.Element == String { for try await message in incoming { - try await self.onMessage(message) + try await onMessage(message) // Detect and ignore error responses. if message.starts(with: "44") { @@ -68,13 +68,13 @@ public actor Server< } guard let json = message.data(using: .utf8) else { - try await self.error(.invalidEncoding()) + try await error(.invalidEncoding()) return } let request: Request do { - request = try self.decoder.decode(Request.self, from: json) + request = try decoder.decode(Request.self, from: json) } catch { try await self.error(.noType()) return @@ -83,25 +83,25 @@ public actor Server< // handle incoming message switch request.type { case .connectionInit: - guard let connectionInitRequest = try? self.decoder.decode(ConnectionInitRequest.self, from: json) else { - try await self.error(.invalidRequestFormat(messageType: .connectionInit)) + guard let connectionInitRequest = try? decoder.decode(ConnectionInitRequest.self, from: json) else { + try await error(.invalidRequestFormat(messageType: .connectionInit)) return } - try await self.onConnectionInit(connectionInitRequest, messenger) + try await onConnectionInit(connectionInitRequest, messenger) case .subscribe: - guard let subscribeRequest = try? self.decoder.decode(SubscribeRequest.self, from: json) else { - try await self.error(.invalidRequestFormat(messageType: .subscribe)) + guard let subscribeRequest = try? decoder.decode(SubscribeRequest.self, from: json) else { + try await error(.invalidRequestFormat(messageType: .subscribe)) return } - try await self.onSubscribe(subscribeRequest) + try await onSubscribe(subscribeRequest) case .complete: - guard let completeRequest = try? self.decoder.decode(CompleteRequest.self, from: json) else { - try await self.error(.invalidRequestFormat(messageType: .complete)) + guard let completeRequest = try? decoder.decode(CompleteRequest.self, from: json) else { + try await error(.invalidRequestFormat(messageType: .complete)) return } - try await self.onOperationComplete(completeRequest) + try await onOperationComplete(completeRequest) default: - try await self.error(.invalidType()) + try await error(.invalidType()) } } } diff --git a/Tests/GraphQLTransportWSTests/GraphQLTransportWSTests.swift b/Tests/GraphQLTransportWSTests/GraphQLTransportWSTests.swift index 77796d3..3b108cd 100644 --- a/Tests/GraphQLTransportWSTests/GraphQLTransportWSTests.swift +++ b/Tests/GraphQLTransportWSTests/GraphQLTransportWSTests.swift @@ -1,9 +1,7 @@ import Foundation - import GraphQL -import Testing - import GraphQLTransportWS +import Testing @Suite struct GraphqlTransportWSTests { @@ -24,11 +22,10 @@ struct GraphqlTransportWSTests { ) }, onSubscribe: { graphQLRequest, _ in - let subscription = try await api.subscribe( + try await api.subscribe( request: graphQLRequest.query, context: context ).get() - return subscription } ) let (messageStream, messageContinuation) = AsyncThrowingStream.makeStream() @@ -51,7 +48,6 @@ struct GraphqlTransportWSTests { Task { try await client.listen(to: serverStream) } - try await client.sendStart( payload: GraphQLRequest( @@ -69,7 +65,7 @@ struct GraphqlTransportWSTests { } #expect( messages == - ["\(ErrorCode.notInitialized): Connection not initialized"] + ["\(ErrorCode.notInitialized): Connection not initialized"] ) } @@ -89,11 +85,10 @@ struct GraphqlTransportWSTests { ) }, onSubscribe: { graphQLRequest, _ in - let subscription = try await api.subscribe( + try await api.subscribe( request: graphQLRequest.query, context: context ).get() - return subscription } ) let (messageStream, messageContinuation) = AsyncThrowingStream.makeStream() @@ -128,7 +123,7 @@ struct GraphqlTransportWSTests { } #expect( messages == - ["\(ErrorCode.unauthorized): Unauthorized"] + ["\(ErrorCode.unauthorized): Unauthorized"] ) } @@ -137,7 +132,7 @@ struct GraphqlTransportWSTests { let api = TestAPI() let context = TestContext() let id = UUID().description - + let server = Server>( messenger: serverMessenger, onInit: { _ in }, @@ -148,11 +143,10 @@ struct GraphqlTransportWSTests { ) }, onSubscribe: { graphQLRequest, _ in - let subscription = try await api.subscribe( + try await api.subscribe( request: graphQLRequest.query, context: context ).get() - return subscription } ) let (messageStream, messageContinuation) = AsyncThrowingStream.makeStream() @@ -200,7 +194,7 @@ struct GraphqlTransportWSTests { } #expect( messages.count == - 3, // 1 connection_ack, 1 next, 1 complete + 3 // 1 connection_ack, 1 next, 1 complete ) } @@ -211,7 +205,7 @@ struct GraphqlTransportWSTests { let id = UUID().description var dataIndex = 1 let dataIndexMax = 3 - + let (subscribeReadyStream, subscribeReadyContinuation) = AsyncStream.makeStream() let server = Server>( messenger: serverMessenger, @@ -288,7 +282,7 @@ struct GraphqlTransportWSTests { result.append(message) } #expect( - messages.count == 5, // 1 connection_ack, 3 next, 1 complete + messages.count == 5 // 1 connection_ack, 3 next, 1 complete ) } diff --git a/Tests/GraphQLTransportWSTests/Utils/TestMessenger.swift b/Tests/GraphQLTransportWSTests/Utils/TestMessenger.swift index 20b1e70..90449c9 100644 --- a/Tests/GraphQLTransportWSTests/Utils/TestMessenger.swift +++ b/Tests/GraphQLTransportWSTests/Utils/TestMessenger.swift @@ -1,6 +1,5 @@ import Foundation - @testable import GraphQLTransportWS /// Messenger for simple testing that doesn't require starting up a websocket server. @@ -15,7 +14,7 @@ actor TestMessenger: Messenger { self.continuation = continuation } - func send(_ message: S) async throws where S: Collection, S.Element == Character { + func send(_ message: S) async throws where S.Element == Character { continuation.yield(String(message)) } From f1283d2f46d5c79cd9e3281ce3fbe34014c96f36 Mon Sep 17 00:00:00 2001 From: Jay Herron Date: Wed, 11 Feb 2026 00:52:26 -0700 Subject: [PATCH 5/8] docs: Readme updates --- README.md | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index 3b99f6b..b4789fd 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,8 @@ # GraphQLTransportWS +[![](https://img.shields.io/endpoint?url=https%3A%2F%2Fswiftpackageindex.com%2Fapi%2Fpackages%2FGraphQLSwift%2FGraphQLTransportWS%2Fbadge%3Ftype%3Dplatforms)](https://swiftpackageindex.com/GraphQLSwift/GraphQLTransportWS) +[![](https://img.shields.io/endpoint?url=https%3A%2F%2Fswiftpackageindex.com%2Fapi%2Fpackages%2FGraphQLSwift%2FGraphQLTransportWS%2Fbadge%3Ftype%3Dswift-versions)](https://swiftpackageindex.com/GraphQLSwift/GraphQLTransportWS) + This implements the [graphql-transport-ws WebSocket subprotocol](https://github.com/enisdenjo/graphql-ws/blob/master/PROTOCOL.md). It is mainly intended for server support, but there is a basic client implementation included. @@ -14,7 +17,7 @@ Features: To use this package, include it in your `Package.swift` dependencies: ```swift -.package(url: "git@gitlab.com:PassiveLogic/platform/GraphQLTransportWS.git", from: "") +.package(url: "https://github.com/GraphQLSwift/GraphQLTransportWS", from: "") ``` Then create a class to implement the `Messenger` protocol. Here's an example using @@ -25,12 +28,8 @@ import WebSocketKit import GraphQLTransportWS /// Messenger wrapper for WebSockets -class WebSocketMessenger: Messenger { - let websocket: WebSocket? - - init(websocket: WebSocket) { - self.websocket = websocket - } +struct WebSocketMessenger: Messenger { + let websocket: WebSocket func send(_ message: S) where S: Collection, S.Element == Character async throws { try await websocket.send(message) From 7d54144af027d4364d0e4af359e4d65361df6196 Mon Sep 17 00:00:00 2001 From: Jay Herron Date: Wed, 11 Feb 2026 11:31:15 -0700 Subject: [PATCH 6/8] feat!: Remove `onMessage` callback This is not necessary anymore, since you can just map the incoming AsyncStream --- Sources/GraphQLTransportWS/Client.swift | 7 +- Sources/GraphQLTransportWS/Server.swift | 6 -- .../GraphQLTransportWSTests.swift | 67 ++++++++++--------- 3 files changed, 36 insertions(+), 44 deletions(-) diff --git a/Sources/GraphQLTransportWS/Client.swift b/Sources/GraphQLTransportWS/Client.swift index 426d6ad..3baa6b0 100644 --- a/Sources/GraphQLTransportWS/Client.swift +++ b/Sources/GraphQLTransportWS/Client.swift @@ -9,7 +9,6 @@ public actor Client { let onNext: (NextResponse, Client) async throws -> Void let onError: (ErrorResponse, Client) async throws -> Void let onComplete: (CompleteResponse, Client) async throws -> Void - let onMessage: (String, Client) async throws -> Void let encoder = GraphQLJSONEncoder() let decoder = JSONDecoder() @@ -23,23 +22,19 @@ public actor Client { onConnectionAck: @escaping (ConnectionAckResponse, Client) async throws -> Void = { _, _ in }, onNext: @escaping (NextResponse, Client) async throws -> Void = { _, _ in }, onError: @escaping (ErrorResponse, Client) async throws -> Void = { _, _ in }, - onComplete: @escaping (CompleteResponse, Client) async throws -> Void = { _, _ in }, - onMessage: @escaping (String, Client) async throws -> Void = { _, _ in } + onComplete: @escaping (CompleteResponse, Client) async throws -> Void = { _, _ in } ) { self.messenger = messenger self.onConnectionAck = onConnectionAck self.onNext = onNext self.onError = onError self.onComplete = onComplete - self.onMessage = onMessage } /// Listen and react to the provided async sequence of server messages. This function will block until the stream is completed. /// - Parameter incoming: The server message sequence that the client should react to. public func listen(to incoming: A) async throws -> Void where A.Element == String { for try await message in incoming { - try await onMessage(message, self) - // Detect and ignore error responses. if message.starts(with: "44") { // TODO: Determine what to do with returned error messages diff --git a/Sources/GraphQLTransportWS/Server.swift b/Sources/GraphQLTransportWS/Server.swift index 9b28d0d..51859ed 100644 --- a/Sources/GraphQLTransportWS/Server.swift +++ b/Sources/GraphQLTransportWS/Server.swift @@ -16,8 +16,6 @@ public actor Server< let onInit: (InitPayload) async throws -> InitPayloadResult let onExecute: (GraphQLRequest, InitPayloadResult) async throws -> GraphQLResult let onSubscribe: (GraphQLRequest, InitPayloadResult) async throws -> SubscriptionSequenceType - - let onMessage: (String) async throws -> Void let onOperationComplete: (String) async throws -> Void let onOperationError: (String, [Error]) async throws -> Void @@ -42,7 +40,6 @@ public actor Server< onInit: @escaping (InitPayload) async throws -> InitPayloadResult, onExecute: @escaping (GraphQLRequest, InitPayloadResult) async throws -> GraphQLResult, onSubscribe: @escaping (GraphQLRequest, InitPayloadResult) async throws -> SubscriptionSequenceType, - onMessage: @escaping (String) async throws -> Void = { _ in }, onOperationComplete: @escaping (String) async throws -> Void = { _ in }, onOperationError: @escaping (String, [Error]) async throws -> Void = { _, _ in } ) { @@ -50,7 +47,6 @@ public actor Server< self.onInit = onInit self.onExecute = onExecute self.onSubscribe = onSubscribe - self.onMessage = onMessage self.onOperationComplete = onOperationComplete self.onOperationError = onOperationError } @@ -59,8 +55,6 @@ public actor Server< /// - Parameter incoming: The client message sequence that the server should react to. public func listen(to incoming: A) async throws -> Void where A.Element == String { for try await message in incoming { - try await onMessage(message) - // Detect and ignore error responses. if message.starts(with: "44") { // TODO: Determine what to do with returned error messages diff --git a/Tests/GraphQLTransportWSTests/GraphQLTransportWSTests.swift b/Tests/GraphQLTransportWSTests/GraphQLTransportWSTests.swift index 3b108cd..9597059 100644 --- a/Tests/GraphQLTransportWSTests/GraphQLTransportWSTests.swift +++ b/Tests/GraphQLTransportWSTests/GraphQLTransportWSTests.swift @@ -29,24 +29,23 @@ struct GraphqlTransportWSTests { } ) let (messageStream, messageContinuation) = AsyncThrowingStream.makeStream() + let serverMessageStream = serverMessenger.stream.map { message in + messageContinuation.yield(message) + // Expect only one message + messageContinuation.finish() + return message + } let client = Client( messenger: clientMessenger, onError: { message, _ in messageContinuation.finish(throwing: message.payload[0]) - }, - onMessage: { message, _ in - messageContinuation.yield(message) - // Expect only one message - messageContinuation.finish() + await clientMessenger.close() } ) - let serverStream = serverMessenger.stream let clientStream = clientMessenger.stream Task { try await server.listen(to: clientStream) - } - Task { - try await client.listen(to: serverStream) + await serverMessenger.close() } try await client.sendStart( @@ -59,6 +58,7 @@ struct GraphqlTransportWSTests { ), id: UUID().uuidString ) + try await client.listen(to: serverMessageStream) let messages = try await messageStream.reduce(into: [String]()) { result, message in result.append(message) @@ -92,24 +92,23 @@ struct GraphqlTransportWSTests { } ) let (messageStream, messageContinuation) = AsyncThrowingStream.makeStream() + let serverMessageStream = serverMessenger.stream.map { message in + messageContinuation.yield(message) + // Expect only one message + messageContinuation.finish() + return message + } let client = Client( messenger: clientMessenger, onError: { message, _ in messageContinuation.finish(throwing: message.payload[0]) - }, - onMessage: { message, _ in - messageContinuation.yield(message) - // Expect only one message - messageContinuation.finish() + await clientMessenger.close() } ) let clientStream = clientMessenger.stream - let serverStream = serverMessenger.stream Task { try await server.listen(to: clientStream) - } - Task { - try await client.listen(to: serverStream) + await serverMessenger.close() } try await client.sendConnectionInit( @@ -117,6 +116,7 @@ struct GraphqlTransportWSTests { authToken: "" ) ) + try await client.listen(to: serverMessageStream) let messages = try await messageStream.reduce(into: [String]()) { result, message in result.append(message) @@ -150,6 +150,10 @@ struct GraphqlTransportWSTests { } ) let (messageStream, messageContinuation) = AsyncThrowingStream.makeStream() + let serverMessageStream = serverMessenger.stream.map { message in + messageContinuation.yield(message) + return message + } let client = Client( messenger: clientMessenger, onConnectionAck: { _, client in @@ -166,21 +170,17 @@ struct GraphqlTransportWSTests { }, onError: { message, _ in messageContinuation.finish(throwing: message.payload[0]) + await clientMessenger.close() }, onComplete: { _, _ in messageContinuation.finish() - }, - onMessage: { message, _ in - messageContinuation.yield(message) + await clientMessenger.close() } ) let clientStream = clientMessenger.stream - let serverStream = serverMessenger.stream Task { try await server.listen(to: clientStream) - } - Task { - try await client.listen(to: serverStream) + await serverMessenger.close() } try await client.sendConnectionInit( @@ -188,6 +188,7 @@ struct GraphqlTransportWSTests { authToken: "" ) ) + try await client.listen(to: serverMessageStream) let messages = try await messageStream.reduce(into: [String]()) { result, message in result.append(message) @@ -226,6 +227,11 @@ struct GraphqlTransportWSTests { } ) let (messageStream, messageContinuation) = AsyncThrowingStream.makeStream() + // Used to extract the server messages + let serverMessageStream = serverMessenger.stream.map { message in + messageContinuation.yield(message) + return message + } let client = Client( messenger: clientMessenger, onConnectionAck: { _, client in @@ -255,21 +261,17 @@ struct GraphqlTransportWSTests { }, onError: { message, _ in messageContinuation.finish(throwing: message.payload[0]) + await clientMessenger.close() }, onComplete: { _, _ in messageContinuation.finish() - }, - onMessage: { message, _ in - messageContinuation.yield(message) + await clientMessenger.close() } ) let clientStream = clientMessenger.stream - let serverStream = serverMessenger.stream Task { try await server.listen(to: clientStream) - } - Task { - try await client.listen(to: serverStream) + await serverMessenger.close() } try await client.sendConnectionInit( @@ -277,6 +279,7 @@ struct GraphqlTransportWSTests { authToken: "" ) ) + try await client.listen(to: serverMessageStream) let messages = try await messageStream.reduce(into: [String]()) { result, message in result.append(message) From a66eeadbde325a2ec228a2029bf71b6206f45639 Mon Sep 17 00:00:00 2001 From: Jay Herron Date: Wed, 11 Feb 2026 11:31:54 -0700 Subject: [PATCH 7/8] fix!: Avoids closing server on single executions --- Sources/GraphQLTransportWS/Server.swift | 1 - 1 file changed, 1 deletion(-) diff --git a/Sources/GraphQLTransportWS/Server.swift b/Sources/GraphQLTransportWS/Server.swift index 51859ed..fc457a1 100644 --- a/Sources/GraphQLTransportWS/Server.swift +++ b/Sources/GraphQLTransportWS/Server.swift @@ -166,7 +166,6 @@ public actor Server< } catch { try await sendError(error, id: id) } - try await messenger.close() } } From 8287e178dfb7633d3523e9c04151d3f95ea28689 Mon Sep 17 00:00:00 2001 From: Jay Herron Date: Wed, 11 Feb 2026 23:21:51 -0700 Subject: [PATCH 8/8] docs: Updates public function documentation --- Sources/GraphQLTransportWS/Client.swift | 4 ++++ Sources/GraphQLTransportWS/Messenger.swift | 2 +- Sources/GraphQLTransportWS/Server.swift | 1 - 3 files changed, 5 insertions(+), 2 deletions(-) diff --git a/Sources/GraphQLTransportWS/Client.swift b/Sources/GraphQLTransportWS/Client.swift index 3baa6b0..bcac773 100644 --- a/Sources/GraphQLTransportWS/Client.swift +++ b/Sources/GraphQLTransportWS/Client.swift @@ -17,6 +17,10 @@ public actor Client { /// /// - Parameters: /// - messenger: The messenger to bind the client to. + /// - onConnectionAck: The callback run on receipt of a `connection_ack` message + /// - onNext: The callback run on receipt of a `next` message + /// - onError: The callback run on receipt of an `error` message + /// - onComplete: The callback run on receipt of a `complete` message public init( messenger: Messenger, onConnectionAck: @escaping (ConnectionAckResponse, Client) async throws -> Void = { _, _ in }, diff --git a/Sources/GraphQLTransportWS/Messenger.swift b/Sources/GraphQLTransportWS/Messenger.swift index f50baa6..86ca9d4 100644 --- a/Sources/GraphQLTransportWS/Messenger.swift +++ b/Sources/GraphQLTransportWS/Messenger.swift @@ -1,6 +1,6 @@ import Foundation -/// Protocol for an object that can send messages. This allows mocking in tests. +/// Protocol for an object that can send messages. public protocol Messenger: Sendable { /// Send a message through this messenger /// - Parameter message: The message to send diff --git a/Sources/GraphQLTransportWS/Server.swift b/Sources/GraphQLTransportWS/Server.swift index fc457a1..10f0bfe 100644 --- a/Sources/GraphQLTransportWS/Server.swift +++ b/Sources/GraphQLTransportWS/Server.swift @@ -32,7 +32,6 @@ public actor Server< /// - messenger: The messenger to bind the server to. /// - onExecute: Callback run during `start` resolution for non-streaming queries. Typically this is `API.execute`. /// - onSubscribe: Callback run during `start` resolution for streaming queries. Typically this is `API.subscribe`. - /// - onMessage: Optional callback run on every message event /// - onOperationComplete: Optional callback run when an operation completes /// - onOperationError: Optional callback run when an operation errors public init(