diff --git a/README.md b/README.md index f9d2d2b..b4789fd 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,8 @@ # GraphQLTransportWS +[![](https://img.shields.io/endpoint?url=https%3A%2F%2Fswiftpackageindex.com%2Fapi%2Fpackages%2FGraphQLSwift%2FGraphQLTransportWS%2Fbadge%3Ftype%3Dplatforms)](https://swiftpackageindex.com/GraphQLSwift/GraphQLTransportWS) +[![](https://img.shields.io/endpoint?url=https%3A%2F%2Fswiftpackageindex.com%2Fapi%2Fpackages%2FGraphQLSwift%2FGraphQLTransportWS%2Fbadge%3Ftype%3Dswift-versions)](https://swiftpackageindex.com/GraphQLSwift/GraphQLTransportWS) + This implements the [graphql-transport-ws WebSocket subprotocol](https://github.com/enisdenjo/graphql-ws/blob/master/PROTOCOL.md). It is mainly intended for server support, but there is a basic client implementation included. @@ -14,7 +17,7 @@ Features: To use this package, include it in your `Package.swift` dependencies: ```swift -.package(url: "git@gitlab.com:PassiveLogic/platform/GraphQLTransportWS.git", from: "") +.package(url: "https://github.com/GraphQLSwift/GraphQLTransportWS", from: "") ``` Then create a class to implement the `Messenger` protocol. Here's an example using @@ -25,33 +28,18 @@ import WebSocketKit import GraphQLTransportWS /// Messenger wrapper for WebSockets -class WebSocketMessenger: Messenger { - private weak var websocket: WebSocket? - private var onReceive: (String) async throws -> Void = { _ in } - - init(websocket: WebSocket) { - self.websocket = websocket - websocket.onText { _, message in - try await self.onReceive(message) - } - } +struct WebSocketMessenger: Messenger { + let websocket: WebSocket func send(_ message: S) where S: Collection, S.Element == Character async throws { - guard let websocket = websocket else { return } try await websocket.send(message) } - func onReceive(callback: @escaping (String) async throws -> Void) { - self.onReceive = callback - } - func error(_ message: String, code: Int) async throws { - guard let websocket = websocket else { return } try await websocket.send("\(code): \(message)") } func close() async throws { - guard let websocket = websocket else { return } try await websocket.close() } } @@ -85,6 +73,12 @@ routes.webSocket( ) } ) + let incoming = AsyncStream { continuation in + websocket.onText { _, message in + continuation.yield(message) + } + } + try await server.listen(to: incoming) } ) ``` @@ -125,12 +119,3 @@ This example would require `connection_init` message from the client to look lik ``` If the `payload` field is not required on your server, you may make Server's generic declaration optional like `Server` - -## Memory Management - -Memory ownership among the Server, Client, and Messenger may seem a little backwards. This is because the Swift/Vapor WebSocket -implementation persists WebSocket objects long after their callback and they are expected to retain strong memory references to the -objects required for responses. In order to align cleanly and avoid memory cycles, Server and Client are injected strongly into Messenger -callbacks, and only hold weak references to their Messenger. This means that Messenger objects (or their enclosing WebSocket) must -be persisted to have the connected Server or Client objects function. That is, if a Server's Messenger falls out of scope and deinitializes, -the Server will no longer respond to messages. diff --git a/Sources/GraphQLTransportWS/Client.swift b/Sources/GraphQLTransportWS/Client.swift index 9b0e816..bcac773 100644 --- a/Sources/GraphQLTransportWS/Client.swift +++ b/Sources/GraphQLTransportWS/Client.swift @@ -2,15 +2,13 @@ import Foundation import GraphQL /// Client is an open-ended implementation of the client side of the protocol. It parses and adds callbacks for each type of server respose. -public class Client { - // We keep this weak because we strongly inject this object into the messenger callback - weak var messenger: Messenger? +public actor Client { + let messenger: Messenger - var onConnectionAck: (ConnectionAckResponse, Client) async throws -> Void = { _, _ in } - var onNext: (NextResponse, Client) async throws -> Void = { _, _ in } - var onError: (ErrorResponse, Client) async throws -> Void = { _, _ in } - var onComplete: (CompleteResponse, Client) async throws -> Void = { _, _ in } - var onMessage: (String, Client) async throws -> Void = { _, _ in } + let onConnectionAck: (ConnectionAckResponse, Client) async throws -> Void + let onNext: (NextResponse, Client) async throws -> Void + let onError: (ErrorResponse, Client) async throws -> Void + let onComplete: (CompleteResponse, Client) async throws -> Void let encoder = GraphQLJSONEncoder() let decoder = JSONDecoder() @@ -19,13 +17,28 @@ public class Client { /// /// - Parameters: /// - messenger: The messenger to bind the client to. + /// - onConnectionAck: The callback run on receipt of a `connection_ack` message + /// - onNext: The callback run on receipt of a `next` message + /// - onError: The callback run on receipt of an `error` message + /// - onComplete: The callback run on receipt of a `complete` message public init( - messenger: Messenger + messenger: Messenger, + onConnectionAck: @escaping (ConnectionAckResponse, Client) async throws -> Void = { _, _ in }, + onNext: @escaping (NextResponse, Client) async throws -> Void = { _, _ in }, + onError: @escaping (ErrorResponse, Client) async throws -> Void = { _, _ in }, + onComplete: @escaping (CompleteResponse, Client) async throws -> Void = { _, _ in } ) { self.messenger = messenger - messenger.onReceive { message in - try await self.onMessage(message, self) + self.onConnectionAck = onConnectionAck + self.onNext = onNext + self.onError = onError + self.onComplete = onComplete + } + /// Listen and react to the provided async sequence of server messages. This function will block until the stream is completed. + /// - Parameter incoming: The server message sequence that the client should react to. + public func listen(to incoming: A) async throws -> Void where A.Element == String { + for try await message in incoming { // Detect and ignore error responses. if message.starts(with: "44") { // TODO: Determine what to do with returned error messages @@ -33,13 +46,13 @@ public class Client { } guard let json = message.data(using: .utf8) else { - try await self.error(.invalidEncoding()) + try await error(.invalidEncoding()) return } let response: Response do { - response = try self.decoder.decode(Response.self, from: json) + response = try decoder.decode(Response.self, from: json) } catch { try await self.error(.noType()) return @@ -47,68 +60,37 @@ public class Client { switch response.type { case .connectionAck: - guard let connectionAckResponse = try? self.decoder.decode(ConnectionAckResponse.self, from: json) else { - try await self.error(.invalidResponseFormat(messageType: .connectionAck)) + guard let connectionAckResponse = try? decoder.decode(ConnectionAckResponse.self, from: json) else { + try await error(.invalidResponseFormat(messageType: .connectionAck)) return } - try await self.onConnectionAck(connectionAckResponse, self) + try await onConnectionAck(connectionAckResponse, self) case .next: - guard let nextResponse = try? self.decoder.decode(NextResponse.self, from: json) else { - try await self.error(.invalidResponseFormat(messageType: .next)) + guard let nextResponse = try? decoder.decode(NextResponse.self, from: json) else { + try await error(.invalidResponseFormat(messageType: .next)) return } - try await self.onNext(nextResponse, self) + try await onNext(nextResponse, self) case .error: - guard let errorResponse = try? self.decoder.decode(ErrorResponse.self, from: json) else { - try await self.error(.invalidResponseFormat(messageType: .error)) + guard let errorResponse = try? decoder.decode(ErrorResponse.self, from: json) else { + try await error(.invalidResponseFormat(messageType: .error)) return } - try await self.onError(errorResponse, self) + try await onError(errorResponse, self) case .complete: - guard let completeResponse = try? self.decoder.decode(CompleteResponse.self, from: json) else { - try await self.error(.invalidResponseFormat(messageType: .complete)) + guard let completeResponse = try? decoder.decode(CompleteResponse.self, from: json) else { + try await error(.invalidResponseFormat(messageType: .complete)) return } - try await self.onComplete(completeResponse, self) + try await onComplete(completeResponse, self) default: - try await self.error(.invalidType()) + try await error(.invalidType()) } } } - /// Define the callback run on receipt of a `connection_ack` message - /// - Parameter callback: The callback to assign - public func onConnectionAck(_ callback: @escaping (ConnectionAckResponse, Client) async throws -> Void) { - onConnectionAck = callback - } - - /// Define the callback run on receipt of a `next` message - /// - Parameter callback: The callback to assign - public func onNext(_ callback: @escaping (NextResponse, Client) async throws -> Void) { - onNext = callback - } - - /// Define the callback run on receipt of an `error` message - /// - Parameter callback: The callback to assign - public func onError(_ callback: @escaping (ErrorResponse, Client) async throws -> Void) { - onError = callback - } - - /// Define the callback run on receipt of a `complete` message - /// - Parameter callback: The callback to assign - public func onComplete(_ callback: @escaping (CompleteResponse, Client) async throws -> Void) { - onComplete = callback - } - - /// Define the callback run on receipt of any message - /// - Parameter callback: The callback to assign - public func onMessage(_ callback: @escaping (String, Client) async throws -> Void) { - onMessage = callback - } - /// Send a `connection_init` request through the messenger public func sendConnectionInit(payload: InitPayload) async throws { - guard let messenger = messenger else { return } try await messenger.send( ConnectionInitRequest( payload: payload @@ -118,7 +100,6 @@ public class Client { /// Send a `subscribe` request through the messenger public func sendStart(payload: GraphQLRequest, id: String) async throws { - guard let messenger = messenger else { return } try await messenger.send( SubscribeRequest( payload: payload, @@ -129,7 +110,6 @@ public class Client { /// Send a `complete` request through the messenger public func sendStop(id: String) async throws { - guard let messenger = messenger else { return } try await messenger.send( CompleteRequest( id: id @@ -139,7 +119,6 @@ public class Client { /// Send an error through the messenger and close the connection private func error(_ error: GraphQLTransportWSError) async throws { - guard let messenger = messenger else { return } try await messenger.error(error.message, code: error.code.rawValue) } } diff --git a/Sources/GraphQLTransportWS/GraphqlTransportWSError.swift b/Sources/GraphQLTransportWS/GraphqlTransportWSError.swift index 61775a1..2b200dc 100644 --- a/Sources/GraphQLTransportWS/GraphqlTransportWSError.swift +++ b/Sources/GraphQLTransportWS/GraphqlTransportWSError.swift @@ -89,7 +89,7 @@ struct GraphQLTransportWSError: Error { /// Error codes for miscellaneous issues public enum ErrorCode: Int, CustomStringConvertible, Sendable { - // Miscellaneous + /// Miscellaneous case miscellaneous = 4400 // Internal errors diff --git a/Sources/GraphQLTransportWS/Messenger.swift b/Sources/GraphQLTransportWS/Messenger.swift index 3a9c157..86ca9d4 100644 --- a/Sources/GraphQLTransportWS/Messenger.swift +++ b/Sources/GraphQLTransportWS/Messenger.swift @@ -1,15 +1,10 @@ import Foundation -/// Protocol for an object that can send and recieve messages. This allows mocking in tests -public protocol Messenger: AnyObject { - // AnyObject compliance requires that the implementing object is a class and we can reference it weakly - +/// Protocol for an object that can send messages. +public protocol Messenger: Sendable { /// Send a message through this messenger /// - Parameter message: The message to send - func send(_ message: S) async throws -> Void where S: Collection, S.Element == Character - - /// Set the callback that should be run when a message is recieved - func onReceive(callback: @escaping (String) async throws -> Void) + func send(_ message: S) async throws -> Void where S.Element == Character /// Close the messenger func close() async throws diff --git a/Sources/GraphQLTransportWS/Server.swift b/Sources/GraphQLTransportWS/Server.swift index 751b863..10f0bfe 100644 --- a/Sources/GraphQLTransportWS/Server.swift +++ b/Sources/GraphQLTransportWS/Server.swift @@ -4,29 +4,26 @@ import GraphQL /// Server implements the server-side portion of the protocol, allowing a few callbacks for customization. /// /// By default, there are no authorization checks -public class Server< +public actor Server< InitPayload: Equatable & Codable & Sendable, + InitPayloadResult: Sendable, SubscriptionSequenceType: AsyncSequence & Sendable ->: @unchecked Sendable where +> where SubscriptionSequenceType.Element == GraphQLResult { - // We keep this weak because we strongly inject this object into the messenger callback - weak var messenger: Messenger? + let messenger: Messenger - let onExecute: (GraphQLRequest) async throws -> GraphQLResult - let onSubscribe: (GraphQLRequest) async throws -> SubscriptionSequenceType - var auth: (InitPayload) async throws -> Void - - var onExit: () async throws -> Void = {} - var onMessage: (String) async throws -> Void = { _ in } - var onOperationComplete: (String) async throws -> Void = { _ in } - var onOperationError: (String, [Error]) async throws -> Void = { _, _ in } - - var initialized = false + let onInit: (InitPayload) async throws -> InitPayloadResult + let onExecute: (GraphQLRequest, InitPayloadResult) async throws -> GraphQLResult + let onSubscribe: (GraphQLRequest, InitPayloadResult) async throws -> SubscriptionSequenceType + let onOperationComplete: (String) async throws -> Void + let onOperationError: (String, [Error]) async throws -> Void let decoder = JSONDecoder() let encoder = GraphQLJSONEncoder() + private var initialized = false + private var initResult: InitPayloadResult? private var subscriptionTasks = [String: Task]() /// Create a new server @@ -35,21 +32,28 @@ public class Server< /// - messenger: The messenger to bind the server to. /// - onExecute: Callback run during `start` resolution for non-streaming queries. Typically this is `API.execute`. /// - onSubscribe: Callback run during `start` resolution for streaming queries. Typically this is `API.subscribe`. + /// - onOperationComplete: Optional callback run when an operation completes + /// - onOperationError: Optional callback run when an operation errors public init( messenger: Messenger, - onExecute: @escaping (GraphQLRequest) async throws -> GraphQLResult, - onSubscribe: @escaping (GraphQLRequest) async throws -> SubscriptionSequenceType + onInit: @escaping (InitPayload) async throws -> InitPayloadResult, + onExecute: @escaping (GraphQLRequest, InitPayloadResult) async throws -> GraphQLResult, + onSubscribe: @escaping (GraphQLRequest, InitPayloadResult) async throws -> SubscriptionSequenceType, + onOperationComplete: @escaping (String) async throws -> Void = { _ in }, + onOperationError: @escaping (String, [Error]) async throws -> Void = { _, _ in } ) { self.messenger = messenger + self.onInit = onInit self.onExecute = onExecute self.onSubscribe = onSubscribe - auth = { _ in } - - messenger.onReceive { message in - guard let messenger = self.messenger else { return } - - try await self.onMessage(message) + self.onOperationComplete = onOperationComplete + self.onOperationError = onOperationError + } + /// Listen and react to the provided async sequence of client messages. This function will block until the stream is completed. + /// - Parameter incoming: The client message sequence that the server should react to. + public func listen(to incoming: A) async throws -> Void where A.Element == String { + for try await message in incoming { // Detect and ignore error responses. if message.starts(with: "44") { // TODO: Determine what to do with returned error messages @@ -57,13 +61,13 @@ public class Server< } guard let json = message.data(using: .utf8) else { - try await self.error(.invalidEncoding()) + try await error(.invalidEncoding()) return } let request: Request do { - request = try self.decoder.decode(Request.self, from: json) + request = try decoder.decode(Request.self, from: json) } catch { try await self.error(.noType()) return @@ -72,25 +76,25 @@ public class Server< // handle incoming message switch request.type { case .connectionInit: - guard let connectionInitRequest = try? self.decoder.decode(ConnectionInitRequest.self, from: json) else { - try await self.error(.invalidRequestFormat(messageType: .connectionInit)) + guard let connectionInitRequest = try? decoder.decode(ConnectionInitRequest.self, from: json) else { + try await error(.invalidRequestFormat(messageType: .connectionInit)) return } - try await self.onConnectionInit(connectionInitRequest, messenger) + try await onConnectionInit(connectionInitRequest, messenger) case .subscribe: - guard let subscribeRequest = try? self.decoder.decode(SubscribeRequest.self, from: json) else { - try await self.error(.invalidRequestFormat(messageType: .subscribe)) + guard let subscribeRequest = try? decoder.decode(SubscribeRequest.self, from: json) else { + try await error(.invalidRequestFormat(messageType: .subscribe)) return } - try await self.onSubscribe(subscribeRequest) + try await onSubscribe(subscribeRequest) case .complete: - guard let completeRequest = try? self.decoder.decode(CompleteRequest.self, from: json) else { - try await self.error(.invalidRequestFormat(messageType: .complete)) + guard let completeRequest = try? decoder.decode(CompleteRequest.self, from: json) else { + try await error(.invalidRequestFormat(messageType: .complete)) return } - try await self.onOperationComplete(completeRequest) + try await onOperationComplete(completeRequest) default: - try await self.error(.invalidType()) + try await error(.invalidType()) } } } @@ -99,37 +103,6 @@ public class Server< subscriptionTasks.values.forEach { $0.cancel() } } - /// Define a custom callback run during `connection_init` resolution that allows authorization using the `payload`. - /// Throw from this closure to indicate that authorization has failed. - /// - Parameter callback: The callback to assign - public func auth(_ callback: @escaping (InitPayload) async throws -> Void) { - auth = callback - } - - /// Define the callback run when the communication is shut down, either by the client or server - /// - Parameter callback: The callback to assign - public func onExit(_ callback: @escaping () -> Void) { - onExit = callback - } - - /// Define the callback run on receipt of any message - /// - Parameter callback: The callback to assign - public func onMessage(_ callback: @escaping (String) -> Void) { - onMessage = callback - } - - /// Define the callback run on the completion a full operation (query/mutation, end of subscription) - /// - Parameter callback: The callback to assign - public func onOperationComplete(_ callback: @escaping (String) -> Void) { - onOperationComplete = callback - } - - /// Define the callback to run on error of any full operation (failed query, interrupted subscription) - /// - Parameter callback: The callback to assign - public func onOperationError(_ callback: @escaping (String, [Error]) -> Void) { - onOperationError = callback - } - private func onConnectionInit(_ connectionInitRequest: ConnectionInitRequest, _: Messenger) async throws { guard !initialized else { try await error(.tooManyInitializations()) @@ -137,7 +110,7 @@ public class Server< } do { - try await auth(connectionInitRequest.payload) + initResult = try await onInit(connectionInitRequest.payload) } catch { try await self.error(.unauthorized()) return @@ -148,7 +121,7 @@ public class Server< } private func onSubscribe(_ subscribeRequest: SubscribeRequest) async throws { - guard initialized else { + guard initialized, let initResult else { try await error(.notInitialized()) return } @@ -171,7 +144,7 @@ public class Server< if isStreaming { subscriptionTasks[id] = Task { do { - let stream = try await onSubscribe(graphQLRequest) + let stream = try await onSubscribe(graphQLRequest, initResult) for try await event in stream { try Task.checkCancellation() try await self.sendNext(event, id: id) @@ -186,13 +159,12 @@ public class Server< } } else { do { - let result = try await onExecute(graphQLRequest) + let result = try await onExecute(graphQLRequest, initResult) try await sendNext(result, id: id) try await sendComplete(id: id) } catch { try await sendError(error, id: id) } - try await messenger?.close() } } @@ -212,7 +184,6 @@ public class Server< /// Send a `connection_ack` response through the messenger private func sendConnectionAck(_ payload: [String: Map]? = nil) async throws { - guard let messenger = messenger else { return } try await messenger.send( ConnectionAckResponse(payload: payload).toJSON(encoder) ) @@ -220,7 +191,6 @@ public class Server< /// Send a `next` response through the messenger private func sendNext(_ payload: GraphQLResult? = nil, id: String) async throws { - guard let messenger = messenger else { return } try await messenger.send( NextResponse( payload: payload, @@ -231,7 +201,6 @@ public class Server< /// Send a `complete` response through the messenger private func sendComplete(id: String) async throws { - guard let messenger = messenger else { return } try await messenger.send( CompleteResponse( id: id @@ -242,7 +211,6 @@ public class Server< /// Send an `error` response through the messenger private func sendError(_ errors: [Error], id: String) async throws { - guard let messenger = messenger else { return } try await messenger.send( ErrorResponse( errors, @@ -264,7 +232,6 @@ public class Server< /// Send an error through the messenger and close the connection private func error(_ error: GraphQLTransportWSError) async throws { - guard let messenger = messenger else { return } try await messenger.error(error.message, code: error.code.rawValue) } } diff --git a/Tests/GraphQLTransportWSTests/GraphQLTransportWSTests.swift b/Tests/GraphQLTransportWSTests/GraphQLTransportWSTests.swift index d8bc65b..9597059 100644 --- a/Tests/GraphQLTransportWSTests/GraphQLTransportWSTests.swift +++ b/Tests/GraphQLTransportWSTests/GraphQLTransportWSTests.swift @@ -1,59 +1,51 @@ import Foundation - import GraphQL -import XCTest - import GraphQLTransportWS +import Testing -class GraphqlTransportWSTests: XCTestCase { - var clientMessenger: TestMessenger! - var serverMessenger: TestMessenger! - var server: Server>! - var context: TestContext! - var subscribeReady: Bool! = false - - override func setUp() { - // Point the client and server at each other - clientMessenger = TestMessenger() - serverMessenger = TestMessenger() - clientMessenger.other = serverMessenger - serverMessenger.other = clientMessenger +@Suite +struct GraphqlTransportWSTests { + let clientMessenger = TestMessenger() + let serverMessenger = TestMessenger() + /// Tests that trying to run methods before `connection_init` is not allowed + @Test func initialize() async throws { let api = TestAPI() let context = TestContext() - - server = .init( + let server = Server>( messenger: serverMessenger, - onExecute: { graphQLRequest in + onInit: { _ in }, + onExecute: { graphQLRequest, _ in try await api.execute( request: graphQLRequest.query, context: context ) }, - onSubscribe: { graphQLRequest in - let subscription = try await api.subscribe( + onSubscribe: { graphQLRequest, _ in + try await api.subscribe( request: graphQLRequest.query, context: context ).get() - self.subscribeReady = true - return subscription } ) - self.context = context - } - - /// Tests that trying to run methods before `connection_init` is not allowed - func testInitialize() async throws { - let client = Client(messenger: clientMessenger) - let messageStream = AsyncThrowingStream { continuation in - client.onMessage { message, _ in - continuation.yield(message) - // Expect only one message - continuation.finish() - } - client.onError { message, _ in - continuation.finish(throwing: message.payload[0]) + let (messageStream, messageContinuation) = AsyncThrowingStream.makeStream() + let serverMessageStream = serverMessenger.stream.map { message in + messageContinuation.yield(message) + // Expect only one message + messageContinuation.finish() + return message + } + let client = Client( + messenger: clientMessenger, + onError: { message, _ in + messageContinuation.finish(throwing: message.payload[0]) + await clientMessenger.close() } + ) + let clientStream = clientMessenger.stream + Task { + try await server.listen(to: clientStream) + await serverMessenger.close() } try await client.sendStart( @@ -66,32 +58,57 @@ class GraphqlTransportWSTests: XCTestCase { ), id: UUID().uuidString ) + try await client.listen(to: serverMessageStream) let messages = try await messageStream.reduce(into: [String]()) { result, message in result.append(message) } - XCTAssertEqual( - messages, - ["\(ErrorCode.notInitialized): Connection not initialized"] + #expect( + messages == + ["\(ErrorCode.notInitialized): Connection not initialized"] ) } /// Tests that throwing in the authorization callback forces an unauthorized error - func testAuthWithThrow() async throws { - server.auth { _ in - throw TestError.couldBeAnything - } - - let client = Client(messenger: clientMessenger) - let messageStream = AsyncThrowingStream { continuation in - client.onMessage { message, _ in - continuation.yield(message) - // Expect only one message - continuation.finish() + @Test func authWithThrow() async throws { + let api = TestAPI() + let context = TestContext() + let server = Server>( + messenger: serverMessenger, + onInit: { _ in + throw TestError.couldBeAnything + }, + onExecute: { graphQLRequest, _ in + try await api.execute( + request: graphQLRequest.query, + context: context + ) + }, + onSubscribe: { graphQLRequest, _ in + try await api.subscribe( + request: graphQLRequest.query, + context: context + ).get() } - client.onError { message, _ in - continuation.finish(throwing: message.payload[0]) + ) + let (messageStream, messageContinuation) = AsyncThrowingStream.makeStream() + let serverMessageStream = serverMessenger.stream.map { message in + messageContinuation.yield(message) + // Expect only one message + messageContinuation.finish() + return message + } + let client = Client( + messenger: clientMessenger, + onError: { message, _ in + messageContinuation.finish(throwing: message.payload[0]) + await clientMessenger.close() } + ) + let clientStream = clientMessenger.stream + Task { + try await server.listen(to: clientStream) + await serverMessenger.close() } try await client.sendConnectionInit( @@ -99,23 +116,47 @@ class GraphqlTransportWSTests: XCTestCase { authToken: "" ) ) + try await client.listen(to: serverMessageStream) let messages = try await messageStream.reduce(into: [String]()) { result, message in result.append(message) } - XCTAssertEqual( - messages, - ["\(ErrorCode.unauthorized): Unauthorized"] + #expect( + messages == + ["\(ErrorCode.unauthorized): Unauthorized"] ) } /// Tests a single-op conversation - func testSingleOp() async throws { + @Test func singleOp() async throws { + let api = TestAPI() + let context = TestContext() let id = UUID().description - let client = Client(messenger: clientMessenger) - let messageStream = AsyncThrowingStream { continuation in - client.onConnectionAck { _, client in + let server = Server>( + messenger: serverMessenger, + onInit: { _ in }, + onExecute: { graphQLRequest, _ in + try await api.execute( + request: graphQLRequest.query, + context: context + ) + }, + onSubscribe: { graphQLRequest, _ in + try await api.subscribe( + request: graphQLRequest.query, + context: context + ).get() + } + ) + let (messageStream, messageContinuation) = AsyncThrowingStream.makeStream() + let serverMessageStream = serverMessenger.stream.map { message in + messageContinuation.yield(message) + return message + } + let client = Client( + messenger: clientMessenger, + onConnectionAck: { _, client in try await client.sendStart( payload: GraphQLRequest( query: """ @@ -126,16 +167,20 @@ class GraphqlTransportWSTests: XCTestCase { ), id: id ) + }, + onError: { message, _ in + messageContinuation.finish(throwing: message.payload[0]) + await clientMessenger.close() + }, + onComplete: { _, _ in + messageContinuation.finish() + await clientMessenger.close() } - client.onMessage { message, _ in - continuation.yield(message) - } - client.onError { message, _ in - continuation.finish(throwing: message.payload[0]) - } - client.onComplete { _, _ in - continuation.finish() - } + ) + let clientStream = clientMessenger.stream + Task { + try await server.listen(to: clientStream) + await serverMessenger.close() } try await client.sendConnectionInit( @@ -143,27 +188,53 @@ class GraphqlTransportWSTests: XCTestCase { authToken: "" ) ) + try await client.listen(to: serverMessageStream) let messages = try await messageStream.reduce(into: [String]()) { result, message in result.append(message) } - XCTAssertEqual( - messages.count, - 3, // 1 connection_ack, 1 next, 1 complete - "Messages: \(messages.description)" + #expect( + messages.count == + 3 // 1 connection_ack, 1 next, 1 complete ) } /// Tests a streaming conversation - func testStreaming() async throws { + @Test func streaming() async throws { + let api = TestAPI() + let context = TestContext() let id = UUID().description - var dataIndex = 1 let dataIndexMax = 3 - let client = Client(messenger: clientMessenger) - let messageStream = AsyncThrowingStream { continuation in - client.onConnectionAck { _, client in + let (subscribeReadyStream, subscribeReadyContinuation) = AsyncStream.makeStream() + let server = Server>( + messenger: serverMessenger, + onInit: { _ in }, + onExecute: { graphQLRequest, _ in + try await api.execute( + request: graphQLRequest.query, + context: context + ) + }, + onSubscribe: { graphQLRequest, _ in + let subscription = try await api.subscribe( + request: graphQLRequest.query, + context: context + ).get() + subscribeReadyContinuation.finish() + return subscription + } + ) + let (messageStream, messageContinuation) = AsyncThrowingStream.makeStream() + // Used to extract the server messages + let serverMessageStream = serverMessenger.stream.map { message in + messageContinuation.yield(message) + return message + } + let client = Client( + messenger: clientMessenger, + onConnectionAck: { _, client in try await client.sendStart( payload: GraphQLRequest( query: """ @@ -176,34 +247,31 @@ class GraphqlTransportWSTests: XCTestCase { ) // Wait until server has registered subscription - var i = 0 - while !self.subscribeReady, i < 50 { - usleep(1000) - i = i + 1 - } - if i == 50 { - XCTFail("Subscription timeout: Took longer than 50ms to set up") - } + for await _ in subscribeReadyStream {} - self.context.publisher.emit(event: "hello \(dataIndex)") - } - client.onNext { _, _ in + context.publisher.emit(event: "hello \(dataIndex)") + }, + onNext: { _, _ in dataIndex = dataIndex + 1 if dataIndex <= dataIndexMax { - self.context.publisher.emit(event: "hello \(dataIndex)") + context.publisher.emit(event: "hello \(dataIndex)") } else { - self.context.publisher.cancel() + context.publisher.cancel() } + }, + onError: { message, _ in + messageContinuation.finish(throwing: message.payload[0]) + await clientMessenger.close() + }, + onComplete: { _, _ in + messageContinuation.finish() + await clientMessenger.close() } - client.onMessage { message, _ in - continuation.yield(message) - } - client.onError { message, _ in - continuation.finish(throwing: message.payload[0]) - } - client.onComplete { _, _ in - continuation.finish() - } + ) + let clientStream = clientMessenger.stream + Task { + try await server.listen(to: clientStream) + await serverMessenger.close() } try await client.sendConnectionInit( @@ -211,14 +279,13 @@ class GraphqlTransportWSTests: XCTestCase { authToken: "" ) ) + try await client.listen(to: serverMessageStream) let messages = try await messageStream.reduce(into: [String]()) { result, message in result.append(message) } - XCTAssertEqual( - messages.count, - 5, // 1 connection_ack, 3 next, 1 complete - "Messages: \(messages.description)" + #expect( + messages.count == 5 // 1 connection_ack, 3 next, 1 complete ) } diff --git a/Tests/GraphQLTransportWSTests/Utils/TestMessenger.swift b/Tests/GraphQLTransportWSTests/Utils/TestMessenger.swift index a35aa09..90449c9 100644 --- a/Tests/GraphQLTransportWSTests/Utils/TestMessenger.swift +++ b/Tests/GraphQLTransportWSTests/Utils/TestMessenger.swift @@ -1,35 +1,29 @@ import Foundation - @testable import GraphQLTransportWS /// Messenger for simple testing that doesn't require starting up a websocket server. -/// -/// Note that this only retains a weak reference to 'other', so the client should retain references -/// or risk them being deinitialized early -class TestMessenger: Messenger, @unchecked Sendable { - weak var other: TestMessenger? - var onReceive: (String) async throws -> Void = { _ in } - let queue: DispatchQueue = .init(label: "Test messenger") - - init() {} - - func send(_ message: S) async throws where S: Collection, S.Element == Character { - guard let other = other else { - return - } - try await other.onReceive(String(message)) +actor TestMessenger: Messenger { + /// An async stream of the messages sent through this messenger. + let stream: AsyncStream + private var continuation: AsyncStream.Continuation + + init() { + let (stream, continuation) = AsyncStream.makeStream() + self.stream = stream + self.continuation = continuation } - func onReceive(callback: @escaping (String) async throws -> Void) { - onReceive = callback + func send(_ message: S) async throws where S.Element == Character { + continuation.yield(String(message)) } func error(_ message: String, code: Int) async throws { - try await send("\(code): \(message)") + continuation.yield("\(code): \(message)") + continuation.finish() } func close() { - // This is a testing no-op + continuation.finish() } }