diff --git a/Sources/NIOHTTPServer/NIOHTTPServer+HTTP1_1.swift b/Sources/NIOHTTPServer/NIOHTTPServer+HTTP1_1.swift index 83affad..dc588a0 100644 --- a/Sources/NIOHTTPServer/NIOHTTPServer+HTTP1_1.swift +++ b/Sources/NIOHTTPServer/NIOHTTPServer+HTTP1_1.swift @@ -65,47 +65,52 @@ extension NIOHTTPServer { func setupHTTP1_1ServerChannels( bindTargets: [NIOHTTPServerConfiguration.BindTarget] - ) async throws -> [NIOAsyncChannel, Never>] { - let bootstrap = ServerBootstrap(group: .singletonMultiThreadedEventLoopGroup) + ) async throws -> [( + NIOAsyncChannel, Never>, ServerQuiescingHelper + )] { + let bootstrap = ServerBootstrap(group: self.eventLoopGroup) .serverChannelOption(.socketOption(.so_reuseaddr), value: 1) - .serverChannelInitializer { channel in - channel.eventLoop.makeCompletedFuture { - try channel.pipeline.syncOperations.addHandler( - self.serverQuiescingHelper.makeServerChannelHandler(channel: channel) - ) - } - } - var serverChannels = [NIOAsyncChannel, Never>]() + var serverChannels = [ + (NIOAsyncChannel, Never>, ServerQuiescingHelper) + ]() + do { for bindTarget in bindTargets { switch bindTarget.backing { case .hostAndPort(let host, let port): - let serverChannel = - try await bootstrap.bind(host: host, port: port) { channel in - self.setupHTTP1_1Connection( - channel: channel, - asyncChannelConfiguration: .init( - backPressureStrategy: .init(self.configuration.backpressureStrategy), - isOutboundHalfClosureEnabled: true - ), - isSecure: false + let serverQuiescingHelper = ServerQuiescingHelper(group: self.eventLoopGroup) + + let serverChannel = try await bootstrap.serverChannelInitializer { channel in + channel.eventLoop.makeCompletedFuture { + try channel.pipeline.syncOperations.addHandler( + serverQuiescingHelper.makeServerChannelHandler(channel: channel) ) } - serverChannels.append(serverChannel) + }.bind(host: host, port: port) { channel in + self.setupHTTP1_1Connection( + channel: channel, + asyncChannelConfiguration: .init( + backPressureStrategy: .init(self.configuration.backpressureStrategy), + isOutboundHalfClosureEnabled: true + ), + isSecure: false + ) + } + serverChannels.append((serverChannel, serverQuiescingHelper)) } } } catch { // A later bind failed: close any channels we already bound to avoid leaking sockets. // We await the closes so the sockets are fully released by the time we throw, giving the // caller deterministic semantics: when `serve` throws, all cleanup is done. - for serverChannel in serverChannels { + for (serverChannel, _) in serverChannels { try? await serverChannel.channel.close() } throw error } - try self.addressesBound(serverChannels.map { $0.channel.localAddress }) + try self.addressesBound(serverChannels.map { (serverChannel, _) in serverChannel.channel.localAddress }) return serverChannels } diff --git a/Sources/NIOHTTPServer/NIOHTTPServer+SecureUpgrade.swift b/Sources/NIOHTTPServer/NIOHTTPServer+SecureUpgrade.swift index 5ce6f22..bba1c16 100644 --- a/Sources/NIOHTTPServer/NIOHTTPServer+SecureUpgrade.swift +++ b/Sources/NIOHTTPServer/NIOHTTPServer+SecureUpgrade.swift @@ -157,44 +157,44 @@ extension NIOHTTPServer { bindTargets: [NIOHTTPServerConfiguration.BindTarget], supportedHTTPVersions: Set, sslContext: NIOSSLContext - ) async throws -> [NIOAsyncChannel, Never>] { - let bootstrap = ServerBootstrap(group: .singletonMultiThreadedEventLoopGroup) + ) async throws -> [(NIOAsyncChannel, Never>, ServerQuiescingHelper)] { + let bootstrap = ServerBootstrap(group: self.eventLoopGroup) .serverChannelOption(.socketOption(.so_reuseaddr), value: 1) - .serverChannelInitializer { channel in - channel.eventLoop.makeCompletedFuture { - try channel.pipeline.syncOperations.addHandler( - self.serverQuiescingHelper.makeServerChannelHandler(channel: channel) - ) - } - } - var serverChannels = [NIOAsyncChannel, Never>]() + var serverChannels = [(NIOAsyncChannel, Never>, ServerQuiescingHelper)]() do { for bindTarget in bindTargets { switch bindTarget.backing { case .hostAndPort(let host, let port): - let serverChannel = - try await bootstrap.bind(host: host, port: port) { channel in - self.setupSecureUpgradeConnectionChildChannel( - channel: channel, - supportedHTTPVersions: supportedHTTPVersions, - sslContext: sslContext + let serverQuiescingHelper = ServerQuiescingHelper(group: self.eventLoopGroup) + + let serverChannel = try await bootstrap.serverChannelInitializer { channel in + channel.eventLoop.makeCompletedFuture { + try channel.pipeline.syncOperations.addHandler( + serverQuiescingHelper.makeServerChannelHandler(channel: channel) ) } - serverChannels.append(serverChannel) + }.bind(host: host, port: port) { channel in + self.setupSecureUpgradeConnectionChildChannel( + channel: channel, + supportedHTTPVersions: supportedHTTPVersions, + sslContext: sslContext + ) + } + serverChannels.append((serverChannel, serverQuiescingHelper)) } } } catch { // A later bind failed: close any channels we already bound to avoid leaking sockets. // We await the closes so the sockets are fully released by the time we throw, giving the // caller deterministic semantics: when `serve` throws, all cleanup is done. - for serverChannel in serverChannels { + for (serverChannel, _) in serverChannels { try? await serverChannel.channel.close() } throw error } - try self.addressesBound(serverChannels.map { $0.channel.localAddress }) + try self.addressesBound(serverChannels.map { (serverChannel, _) in serverChannel.channel.localAddress }) return serverChannels } diff --git a/Sources/NIOHTTPServer/NIOHTTPServer.swift b/Sources/NIOHTTPServer/NIOHTTPServer.swift index 0434465..8210ca2 100644 --- a/Sources/NIOHTTPServer/NIOHTTPServer.swift +++ b/Sources/NIOHTTPServer/NIOHTTPServer.swift @@ -88,7 +88,12 @@ public struct NIOHTTPServer: HTTPServer { let logger: Logger let configuration: NIOHTTPServerConfiguration - let serverQuiescingHelper: ServerQuiescingHelper + /// The event loop group on which the server runs. + /// + /// This event loop group is used for every channel the server binds. It also provides the event loop that fulfills + /// the listening address promise and the group from which a `ServerQuiescingHelper` is created for each bound + /// channel. + let eventLoopGroup: MultiThreadedEventLoopGroup var listeningAddressState: NIOLockedValueBox @@ -109,10 +114,8 @@ public struct NIOHTTPServer: HTTPServer { self.configuration = configuration // TODO: If we allow users to pass in an event loop, use that instead of the singleton MTELG. - let eventLoopGroup: MultiThreadedEventLoopGroup = .singletonMultiThreadedEventLoopGroup - self.listeningAddressState = .init(.idle(eventLoopGroup.any().makePromise())) - - self.serverQuiescingHelper = .init(group: eventLoopGroup) + self.eventLoopGroup = .singletonMultiThreadedEventLoopGroup + self.listeningAddressState = .init(.idle(self.eventLoopGroup.any().makePromise())) } /// Starts an HTTP server with the specified request handler. @@ -123,13 +126,9 @@ public struct NIOHTTPServer: HTTPServer { /// /// ## All-or-nothing listening /// - /// The server treats its set of listening addresses as a single unit. If any one of the bound addresses - /// stops listening — whether due to its underlying socket closing, an unrecoverable error on the - /// listening channel, or any other reason — the server stops listening on **all** remaining addresses - /// and this method returns. After that point, ``listeningAddresses`` will throw - /// ``ListeningAddressError/serverClosed``. - /// - /// This also applies during graceful shutdown and task cancellation: all channels are shut down together. + /// The server treats its set of listening addresses as a single unit. If an unrecoverable error occurs on any of + /// the listening channels, the server stops listening on **all** remaining addresses and this method returns. After + /// that point, ``listeningAddresses`` will throw ``ListeningAddressError/serverClosed``. /// /// - Parameter handler: A ``HTTPServerRequestHandler`` implementation that processes incoming HTTP /// requests. The handler receives each request along with a body reader and response sender function. @@ -164,7 +163,7 @@ public struct NIOHTTPServer: HTTPServer { try await withGracefulShutdownHandler { try await self._serve(serverChannels: serverChannels, handler: handler) } onGracefulShutdown: { - self.beginGracefulShutdown() + self.beginGracefulShutdown(serverChannels: serverChannels) } } onCancel: { // Forcefully close down the server channels @@ -177,7 +176,9 @@ public struct NIOHTTPServer: HTTPServer { switch self.configuration.transportSecurity.backing { case .plaintext: return try await self.setupHTTP1_1ServerChannels(bindTargets: self.configuration.bindTargets) - .map { .plaintextHTTP1_1($0) } + .map { channel, quiescingHelper in + .plaintextHTTP1_1(channel: channel, quiescingHelper: quiescingHelper) + } case .tls, .mTLS: return try await self.setupSecureUpgradeServerChannels( @@ -187,7 +188,9 @@ public struct NIOHTTPServer: HTTPServer { transportSecurity: self.configuration.transportSecurity, alpnIdentifiers: self.configuration.supportedHTTPVersions.alpnIdentifiers ), - ).map { .secureUpgrade($0) } + ).map { channel, quiescingHelper in + .secureUpgrade(channel: channel, quiescingHelper: quiescingHelper) + } } } @@ -199,19 +202,22 @@ public struct NIOHTTPServer: HTTPServer { for serverChannel in serverChannels { group.addTask { switch serverChannel { - case .plaintextHTTP1_1(let http1Channel): + case .plaintextHTTP1_1(let http1Channel, _): try await self.serveInsecureHTTP1_1(serverChannel: http1Channel, handler: handler) - case .secureUpgrade(let secureUpgradeChannel): + case .secureUpgrade(let secureUpgradeChannel, _): try await self.serveSecureUpgrade(serverChannel: secureUpgradeChannel, handler: handler) } } } - // Wait for the first channel to complete (either normally or by throwing). - // If any channel stops serving, bring down all remaining channels. - try await group.next() - group.cancelAll() + // If an error occurs in any channel, bring down all other channels too and propagate the error. + do { + for try await _ in group {} + } catch { + // Propagate the error. This will cancel the entire group. + throw error + } } } @@ -303,9 +309,16 @@ public struct NIOHTTPServer: HTTPServer { } /// Initiates a graceful shutdown, allowing existing connections to drain before closing. - private func beginGracefulShutdown() { + private func beginGracefulShutdown(serverChannels: [ServerChannel]) { self.finishListeningAddressPromise() - self.serverQuiescingHelper.initiateShutdown(promise: nil) + + for serverChannel in serverChannels { + switch serverChannel { + case .plaintextHTTP1_1(_, let quiescingHelper), + .secureUpgrade(_, let quiescingHelper): + quiescingHelper.initiateShutdown(promise: nil) + } + } } /// Forcefully closes the server channels without waiting for existing connections to drain. @@ -314,10 +327,10 @@ public struct NIOHTTPServer: HTTPServer { for serverChannel in serverChannels { switch serverChannel { - case .plaintextHTTP1_1(let http1Channel): + case .plaintextHTTP1_1(let http1Channel, _): http1Channel.channel.close(promise: nil) - case .secureUpgrade(let secureUpgradeChannel): + case .secureUpgrade(let secureUpgradeChannel, _): secureUpgradeChannel.channel.close(promise: nil) } } diff --git a/Sources/NIOHTTPServer/ServerChannel.swift b/Sources/NIOHTTPServer/ServerChannel.swift index 339be9a..ae65104 100644 --- a/Sources/NIOHTTPServer/ServerChannel.swift +++ b/Sources/NIOHTTPServer/ServerChannel.swift @@ -13,6 +13,7 @@ //===----------------------------------------------------------------------===// import NIOCore +import NIOExtras import NIOHTTPTypes @available(anyAppleOS 26.0, *) @@ -20,7 +21,14 @@ extension NIOHTTPServer { /// Abstracts over the two types of server channels ``NIOHTTPServer`` can create: plaintext HTTP/1.1 and Secure /// Upgrade. enum ServerChannel { - case plaintextHTTP1_1(NIOAsyncChannel, Never>) - case secureUpgrade(NIOAsyncChannel, Never>) + case plaintextHTTP1_1( + channel: NIOAsyncChannel, Never>, + quiescingHelper: ServerQuiescingHelper + ) + + case secureUpgrade( + channel: NIOAsyncChannel, Never>, + quiescingHelper: ServerQuiescingHelper + ) } } diff --git a/Tests/NIOHTTPServerTests/HTTPKeepAliveHandlerTests.swift b/Tests/NIOHTTPServerTests/HTTPKeepAliveHandlerTests.swift index 70ec9ac..8611778 100644 --- a/Tests/NIOHTTPServerTests/HTTPKeepAliveHandlerTests.swift +++ b/Tests/NIOHTTPServerTests/HTTPKeepAliveHandlerTests.swift @@ -528,13 +528,13 @@ struct HTTPKeepAliveHandlerTests { try await writer.writeAndConclude("".utf8.span, finalElement: nil) }, body: { serverAddress in - let clientChannel = try await ClientBootstrap(group: .singletonMultiThreadedEventLoopGroup) + let client = try await ClientBootstrap(group: .singletonMultiThreadedEventLoopGroup) .connectToTestSecureUpgradeHTTPServer( at: serverAddress, trustRoots: serverChain.chain, applicationProtocol: HTTPVersion.http1_1.alpnIdentifier ) - let client = try await NIOHTTPServerTests.unwrapNegotiatedChannel(clientChannel, .http1_1) + .unwrapChannel(expectedHTTPVersion: .http1_1) try await client.executeThenClose { inbound, outbound in try await outbound.write( diff --git a/Tests/NIOHTTPServerTests/NIOHTTPServer+ServiceLifecycleTests.swift b/Tests/NIOHTTPServerTests/NIOHTTPServer+ServiceLifecycleTests.swift index 22cf55d..4145d02 100644 --- a/Tests/NIOHTTPServerTests/NIOHTTPServer+ServiceLifecycleTests.swift +++ b/Tests/NIOHTTPServerTests/NIOHTTPServer+ServiceLifecycleTests.swift @@ -15,6 +15,7 @@ import AsyncStreaming import HTTPTypes import Logging +import NIOConcurrencyHelpers import NIOCore import NIOHTTPTypes import NIOPosix @@ -32,20 +33,16 @@ struct NIOHTTPServiceLifecycleTests { static let trailer: HTTPFields = [.trailer: "test_trailer"] static let reqEnd = HTTPRequestPart.end(trailer) - let serverLogger = Logger(label: "Test Server") - let serviceGroupLogger = Logger(label: "Test ServiceGroup") + let serverLogger = Logger(label: "NIOHTTPServiceLifecycleTests") + let serviceGroupLogger = Logger(label: "NIOHTTPServiceLifecycleTests_ServiceGroup") - @Test("HTTP/1.1 active connection completes when graceful shutdown triggered", ) + @Test( + "Active connection completes when graceful shutdown triggered", + arguments: [HTTPVersion.http1_1, HTTPVersion.http2] + ) @available(anyAppleOS 26.0, *) - func activeHTTP1ConnectionCanCompleteWhenGracefulShutdown() async throws { - let server = NIOHTTPServer( - logger: self.serverLogger, - configuration: try .init( - bindTarget: .hostAndPort(host: "127.0.0.1", port: 0), - supportedHTTPVersions: [.http1_1], - transportSecurity: .plaintext - ) - ) + func activeConnectionCanCompleteWhenGracefullyShutdown(httpVersion: HTTPVersion) async throws { + let (server, serverChain) = try NIOHTTPServerTests.makeSecureUpgradeServer(logger: self.serverLogger) // This promise will be fulfilled when the server receives the first part of the body. Once this happens, we can // initiate the graceful shutdown and then send the remaining body. If graceful shutdown is respected, we should @@ -85,7 +82,12 @@ struct NIOHTTPServiceLifecycleTests { let serverAddress = try await server.listeningAddresses.first! let client = try await ClientBootstrap(group: .singletonMultiThreadedEventLoopGroup) - .connectToTestHTTP1Server(at: serverAddress) + .connectToTestSecureUpgradeHTTPServer( + at: serverAddress, + trustRoots: serverChain.chain, + applicationProtocol: httpVersion.alpnIdentifier + ) + .unwrapChannel(expectedHTTPVersion: httpVersion) try await client.executeThenClose { inbound, outbound in try await outbound.write(Self.reqHead) @@ -124,17 +126,13 @@ struct NIOHTTPServiceLifecycleTests { } } - @Test("HTTP/1.1 active connection forcefully shutdown when server task cancelled") + @Test( + "Active connection forcefully shutdown when server task cancelled", + arguments: [HTTPVersion.http1_1, HTTPVersion.http2] + ) @available(anyAppleOS 26.0, *) - func activeHTTP1ConnectionForcefullyShutdownWhenServerTaskCancelled() async throws { - let server = NIOHTTPServer( - logger: self.serverLogger, - configuration: try .init( - bindTarget: .hostAndPort(host: "127.0.0.1", port: 0), - supportedHTTPVersions: [.http1_1], - transportSecurity: .plaintext - ) - ) + func activeConnectionForcefullyShutdownWhenServerTaskCancelled(httpVersion: HTTPVersion) async throws { + let (server, serverChain) = try NIOHTTPServerTests.makeSecureUpgradeServer(logger: self.serverLogger) // This promise will be fulfilled when the server receives the first part of the request body. Once this // happens, we cancel the server task and test whether the in-flight request's connection was forcefully shut. @@ -171,7 +169,12 @@ struct NIOHTTPServiceLifecycleTests { let serverAddress = try await server.listeningAddresses.first! let client = try await ClientBootstrap(group: .singletonMultiThreadedEventLoopGroup) - .connectToTestHTTP1Server(at: serverAddress) + .connectToTestSecureUpgradeHTTPServer( + at: serverAddress, + trustRoots: serverChain.chain, + applicationProtocol: httpVersion.alpnIdentifier + ) + .unwrapChannel(expectedHTTPVersion: httpVersion) try await client.executeThenClose { inbound, outbound in try await outbound.write(Self.reqHead) @@ -297,4 +300,158 @@ struct NIOHTTPServiceLifecycleTests { } } } + + @Test( + "Active connections across different listeners can complete when graceful shutdown triggered", + arguments: [ + (HTTPVersion.http1_1, HTTPVersion.http1_1), + (HTTPVersion.http1_1, HTTPVersion.http2), + (HTTPVersion.http2, HTTPVersion.http1_1), + (HTTPVersion.http2, HTTPVersion.http2), + ] + ) + @available(anyAppleOS 26.0, *) + func activeConnectionsAcrossDifferentListenersCanCompleteWhenGracefullyShutdown( + firstClientHTTPVersion: HTTPVersion, + secondClientHTTPVersion: HTTPVersion + ) async throws { + let (server, serverChain) = try NIOHTTPServerTests.makeSecureUpgradeServer( + bindTargets: [ + // Configure two bind targets. We want to test whether graceful shutdown works independently on each + // bind target. + .hostAndPort(host: "127.0.0.1", port: 0), + .hostAndPort(host: "127.0.0.1", port: 0), + ], + logger: self.serverLogger + ) + + // The test needs both clients to have an active in-flight request before triggering graceful shutdown. To + // express this, we create two promises (one for each bind target), which will be fulfilled by the server's + // request handler once it has *started* processing the corresponding request. + let elg = MultiThreadedEventLoopGroup.singletonMultiThreadedEventLoopGroup + let firstTargetRequestStartedPromise = elg.any().makePromise(of: Void.self) + let secondTargetRequestStartedPromise = elg.any().makePromise(of: Void.self) + + // The server handler needs to know which of the two promises to fulfill. Since the second client only sends + // its request after the server has started processing the first client's request, we set up a counter so that + // the server can know to fulfill `firstTargetRequestStartedPromise` on the first request and + // `secondTargetRequestStartedPromise` on the second request. + let requestNumber = NIOLockedValueBox(0) + + let serverService = ClosureService { + try await server.serve { request, requestContext, requestReader, responseSender in + _ = try await requestReader.consumeAndConclude { bodyReader in + var bodyReader = bodyReader + try await bodyReader.read { _ in } + + let count = requestNumber.withLockedValue { n in + n += 1 + return n + } + + if count == 1 { + firstTargetRequestStartedPromise.succeed() + } else if count == 2 { + secondTargetRequestStartedPromise.succeed() + } + + var requestFinished = false + while !requestFinished { + try await bodyReader.read { if $0.isEmpty { requestFinished = true } } + } + } + + let responseBodyWriter = try await responseSender.send(.init(status: .ok)) + try await responseBodyWriter.produceAndConclude { writer in + var writer = writer + try await writer.write([1, 2].span) + return .none + } + } + } + + try await confirmation(expectedCount: 2) { responseReceived in + try await testGracefulShutdown { trigger in + try await withThrowingTaskGroup { group in + let serviceGroup = ServiceGroup(services: [serverService], logger: self.serviceGroupLogger) + group.addTask { try await serviceGroup.run() } + + let firstServerAddress = try await server.listeningAddresses[0] + let secondServerAddress = try await server.listeningAddresses[1] + + let firstClient = try await ClientBootstrap(group: .singletonMultiThreadedEventLoopGroup) + .connectToTestSecureUpgradeHTTPServer( + at: firstServerAddress, + trustRoots: serverChain.chain, + applicationProtocol: firstClientHTTPVersion.alpnIdentifier + ) + .unwrapChannel(expectedHTTPVersion: firstClientHTTPVersion) + + try await firstClient.executeThenClose { firstInbound, firstOutbound in + try await firstOutbound.write(Self.reqHead) + try await firstOutbound.write(Self.reqBody) + + // Wait until the server has received the body part. + try await firstTargetRequestStartedPromise.futureResult.get() + + let secondClient = try await ClientBootstrap(group: .singletonMultiThreadedEventLoopGroup) + .connectToTestSecureUpgradeHTTPServer( + at: secondServerAddress, + trustRoots: serverChain.chain, + applicationProtocol: secondClientHTTPVersion.alpnIdentifier + ) + .unwrapChannel(expectedHTTPVersion: secondClientHTTPVersion) + + try await secondClient.executeThenClose { secondInbound, secondOutbound in + try await secondOutbound.write(Self.reqHead) + try await secondOutbound.write(Self.reqBody) + + // Wait until the server has received the body part. + try await secondTargetRequestStartedPromise.futureResult.get() + + // Now start the shutdown. + trigger.triggerGracefulShutdown() + + // The second client should be able to complete its request. + try await secondOutbound.write(Self.reqBody) + try await secondOutbound.write(Self.reqEnd) + + for try await response in secondInbound { + switch response { + case .head(let head): + #expect(head.status == .ok) + case .body(let body): + #expect(body == .init([1, 2])) + case .end(let trailers): + #expect(trailers == nil) + } + } + + responseReceived() + } + + // And so should the first client. + try await firstOutbound.write(Self.reqBody) + try await firstOutbound.write(Self.reqEnd) + + for try await response in firstInbound { + switch response { + case .head(let head): + #expect(head.status == .ok) + case .body(let body): + #expect(body == .init([1, 2])) + case .end(let trailers): + #expect(trailers == nil) + } + } + + responseReceived() + + // The server should now shut down. Wait for this. + try await group.waitForAll() + } + } + } + } + } } diff --git a/Tests/NIOHTTPServerTests/NIOHTTPServerTests.swift b/Tests/NIOHTTPServerTests/NIOHTTPServerTests.swift index d314fc3..f06f044 100644 --- a/Tests/NIOHTTPServerTests/NIOHTTPServerTests.swift +++ b/Tests/NIOHTTPServerTests/NIOHTTPServerTests.swift @@ -39,7 +39,7 @@ struct NIOHTTPServerTests { @Test("Obtain the listening address correctly") func testListeningAddress() async throws { let server = NIOHTTPServer( - logger: Logger(label: "NIOHTTPServerTests"), + logger: self.serverLogger, configuration: try .init( bindTarget: .hostAndPort(host: "127.0.0.1", port: 1234), supportedHTTPVersions: [.http1_1], @@ -67,7 +67,7 @@ struct NIOHTTPServerTests { @available(anyAppleOS 26.0, *) func testPlaintext() async throws { let server = NIOHTTPServer( - logger: Logger(label: "NIOHTTPServerTests"), + logger: self.serverLogger, configuration: try .init( bindTarget: .hostAndPort(host: "127.0.0.1", port: 0), supportedHTTPVersions: [.http1_1], @@ -173,14 +173,14 @@ struct NIOHTTPServerTests { } }, body: { serverAddress in - let clientChannel = try await ClientBootstrap(group: .singletonMultiThreadedEventLoopGroup) + let client = try await ClientBootstrap(group: .singletonMultiThreadedEventLoopGroup) .connectToTestSecureUpgradeHTTPServerOverMTLS( at: serverAddress, clientChain: clientChain, trustRoots: [serverChain.ca], applicationProtocol: httpVersion.alpnIdentifier ) - let client = try await Self.unwrapNegotiatedChannel(clientChannel, httpVersion) + .unwrapChannel(expectedHTTPVersion: httpVersion) try await client.executeThenClose { inbound, outbound in try await outbound.write(.head(.init(method: .post, scheme: "https", authority: "", path: "/"))) @@ -205,7 +205,7 @@ struct NIOHTTPServerTests { @available(anyAppleOS 26.0, *) @Test("Multiple informational response headers", arguments: [HTTPVersion.http1_1, HTTPVersion.http2]) func testMultipleInformationalResponseHeaders(httpVersion: HTTPVersion) async throws { - let (server, serverChain) = try self.makeSecureUpgradeServer() + let (server, serverChain) = try NIOHTTPServerTests.makeSecureUpgradeServer(logger: self.serverLogger) try await confirmation { responseReceived in try await Self.withServer( @@ -222,13 +222,13 @@ struct NIOHTTPServerTests { } }, body: { serverAddress in - let clientChannel = try await ClientBootstrap(group: .singletonMultiThreadedEventLoopGroup) + let client = try await ClientBootstrap(group: .singletonMultiThreadedEventLoopGroup) .connectToTestSecureUpgradeHTTPServer( at: serverAddress, trustRoots: serverChain.chain, applicationProtocol: httpVersion.alpnIdentifier ) - let client = try await Self.unwrapNegotiatedChannel(clientChannel, httpVersion) + .unwrapChannel(expectedHTTPVersion: httpVersion) try await client.executeThenClose { inbound, outbound in try await outbound.write(.head(.init(method: .get, scheme: "https", authority: "", path: "/"))) @@ -255,7 +255,7 @@ struct NIOHTTPServerTests { @available(anyAppleOS 26.0, *) @Test("Client closes stream without sending end part", arguments: [HTTPVersion.http1_1, HTTPVersion.http2]) func testRequestWithoutEndPart(httpVersion: HTTPVersion) async throws { - let (server, serverChain) = try self.makeSecureUpgradeServer() + let (server, serverChain) = try NIOHTTPServerTests.makeSecureUpgradeServer(logger: self.serverLogger) let elg: EventLoopGroup = .singletonMultiThreadedEventLoopGroup let requestReadPromise = elg.any().makePromise(of: Void.self) @@ -287,13 +287,13 @@ struct NIOHTTPServerTests { } }, body: { serverAddress in - let clientChannel = try await ClientBootstrap(group: .singletonMultiThreadedEventLoopGroup) + let client = try await ClientBootstrap(group: .singletonMultiThreadedEventLoopGroup) .connectToTestSecureUpgradeHTTPServer( at: serverAddress, trustRoots: serverChain.chain, applicationProtocol: httpVersion.alpnIdentifier ) - let client = try await Self.unwrapNegotiatedChannel(clientChannel, httpVersion) + .unwrapChannel(expectedHTTPVersion: httpVersion) try await client.executeThenClose { inbound, outbound in // Only send a request head; finish the stream immediately afterwards. @@ -313,7 +313,7 @@ struct NIOHTTPServerTests { @available(anyAppleOS 26.0, *) @Test("Bi-directional streaming", arguments: [HTTPVersion.http1_1, HTTPVersion.http2]) func testBidirectionalStreaming(httpVersion: HTTPVersion) async throws { - let (server, serverChain) = try self.makeSecureUpgradeServer() + let (server, serverChain) = try NIOHTTPServerTests.makeSecureUpgradeServer(logger: self.serverLogger) try await Self.withServer( server: server, @@ -344,13 +344,13 @@ struct NIOHTTPServerTests { } }, body: { serverAddress in - let clientChannel = try await ClientBootstrap(group: .singletonMultiThreadedEventLoopGroup) + let client = try await ClientBootstrap(group: .singletonMultiThreadedEventLoopGroup) .connectToTestSecureUpgradeHTTPServer( at: serverAddress, trustRoots: serverChain.chain, applicationProtocol: httpVersion.alpnIdentifier ) - let client = try await Self.unwrapNegotiatedChannel(clientChannel, httpVersion) + .unwrapChannel(expectedHTTPVersion: httpVersion) try await client.executeThenClose { inbound, outbound in try await outbound.write(.head(.init(method: .post, scheme: "https", authority: "", path: "/"))) @@ -438,7 +438,7 @@ struct NIOHTTPServerTests { @available(anyAppleOS 26.0, *) @Test("Multiple concurrent connections", arguments: [HTTPVersion.http1_1, HTTPVersion.http2]) func testMultipleConcurrentConnections(httpVersion: HTTPVersion) async throws { - let (server, serverChain) = try self.makeSecureUpgradeServer() + let (server, serverChain) = try NIOHTTPServerTests.makeSecureUpgradeServer(logger: self.serverLogger) // We will create 10 connections and send a request from each connection. The server will fulfill the // `allOtherRequestsCanProceedPromise` promise after seeing the 10th request. All other requests will be blocked @@ -470,16 +470,14 @@ struct NIOHTTPServerTests { await withThrowingTaskGroup { group in for _ in 1...numConnections { group.addTask { - let clientChannel = try await ClientBootstrap( - group: .singletonMultiThreadedEventLoopGroup - ) - .connectToTestSecureUpgradeHTTPServer( - at: serverAddress, - trustRoots: serverChain.chain, - applicationProtocol: httpVersion.alpnIdentifier - ) + let client = try await ClientBootstrap(group: .singletonMultiThreadedEventLoopGroup) + .connectToTestSecureUpgradeHTTPServer( + at: serverAddress, + trustRoots: serverChain.chain, + applicationProtocol: httpVersion.alpnIdentifier + ) + .unwrapChannel(expectedHTTPVersion: httpVersion) - let client = try await Self.unwrapNegotiatedChannel(clientChannel, httpVersion) try await client.executeThenClose { inbound, outbound in try await outbound.write( .head(.init(method: .post, scheme: "https", authority: "", path: "/")) @@ -507,7 +505,7 @@ struct NIOHTTPServerTests { @available(anyAppleOS 26.0, *) @Test("Multiple concurrent HTTP/2 streams") func testMultipleConcurrentHTTP2Streams() async throws { - let (server, serverChain) = try self.makeSecureUpgradeServer() + let (server, serverChain) = try NIOHTTPServerTests.makeSecureUpgradeServer(logger: self.serverLogger) let numStreams = 10 let requestCounter = Mutex(0) @@ -577,7 +575,7 @@ struct NIOHTTPServerTests { @available(anyAppleOS 26.0, *) @Test("Server can still process other connections despite one failing") func testServerCanContinueDespiteFailedConnection() async throws { - let server = try self.makePlaintextHTTP1Server() + let server = try NIOHTTPServerTests.makePlaintextHTTP1Server(logger: self.serverLogger) let elg: EventLoopGroup = .singletonMultiThreadedEventLoopGroup let firstRequestErrorCaught = elg.any().makePromise(of: Void.self) @@ -637,7 +635,7 @@ struct NIOHTTPServerTests { @Test("Bind to multiple addresses") func testMultipleBindAddresses() async throws { let server = NIOHTTPServer( - logger: Logger(label: "NIOHTTPServerTests"), + logger: self.serverLogger, configuration: try .init( bindTargets: [ .hostAndPort(host: "127.0.0.1", port: 0), @@ -662,7 +660,7 @@ struct NIOHTTPServerTests { @Test("Serve requests on multiple addresses independently") func testServeOnMultipleAddresses() async throws { let server = NIOHTTPServer( - logger: Logger(label: "NIOHTTPServerTests"), + logger: self.serverLogger, configuration: try .init( bindTargets: [ .hostAndPort(host: "127.0.0.1", port: 0), @@ -732,7 +730,7 @@ struct NIOHTTPServerTests { @Test("All addresses stop together and listeningAddresses throws after server stops") func testAllAddressesStopTogether() async throws { let server = NIOHTTPServer( - logger: Logger(label: "NIOHTTPServerTests"), + logger: self.serverLogger, configuration: try .init( bindTargets: [ .hostAndPort(host: "127.0.0.1", port: 0), @@ -823,7 +821,7 @@ struct NIOHTTPServerTests { // Configure a server that binds to [firstPort, occupiedPort]. The first bind should succeed, // the second should fail with "address already in use", causing cleanup of the first channel. let server = NIOHTTPServer( - logger: Logger(label: "NIOHTTPServerTests"), + logger: self.serverLogger, configuration: try .init( bindTargets: [ .hostAndPort(host: "127.0.0.1", port: firstPort), @@ -860,9 +858,9 @@ extension NIOHTTPServerTests { static let reqEnd = HTTPRequestPart.end(trailer) @available(anyAppleOS 26.0, *) - func makePlaintextHTTP1Server() throws -> NIOHTTPServer { + static func makePlaintextHTTP1Server(logger: Logger) throws -> NIOHTTPServer { let server = NIOHTTPServer( - logger: self.serverLogger, + logger: logger, configuration: try .init( bindTarget: .hostAndPort(host: "127.0.0.1", port: 0), supportedHTTPVersions: [.http1_1], @@ -874,13 +872,16 @@ extension NIOHTTPServerTests { } @available(anyAppleOS 26.0, *) - func makeSecureUpgradeServer() throws -> (NIOHTTPServer, ChainPrivateKeyPair) { + static func makeSecureUpgradeServer( + bindTargets: [NIOHTTPServerConfiguration.BindTarget] = [.hostAndPort(host: "127.0.0.1", port: 0)], + logger: Logger + ) throws -> (NIOHTTPServer, ChainPrivateKeyPair) { let serverChain = try TestCA.makeSelfSignedChain() let server = NIOHTTPServer( - logger: self.serverLogger, + logger: logger, configuration: try .init( - bindTarget: .hostAndPort(host: "127.0.0.1", port: 0), + bindTargets: bindTargets, supportedHTTPVersions: [.http1_1, .http2(config: .defaults)], transportSecurity: .tls( credentials: .inMemory(certificateChain: serverChain.chain, privateKey: serverChain.privateKey) @@ -924,32 +925,6 @@ extension NIOHTTPServerTests { } } - /// Unwraps a negotiated channel, asserting it matches the expected `httpVersion`. For HTTP/2, opens and returns a - /// new stream channel. - static func unwrapNegotiatedChannel( - _ negotiatedChannel: NegotiatedClientConnection, - _ httpVersion: HTTPVersion, - sourceLocation: SourceLocation = #_sourceLocation - ) async throws -> NIOAsyncChannel { - switch negotiatedChannel { - case .http1(let http1Channel): - #expect( - httpVersion == .http1_1, - "Unexpectedly established an HTTP/1 connection.", - sourceLocation: sourceLocation - ) - return http1Channel - - case .http2(let http2StreamManager): - #expect( - httpVersion == .http2, - "Unexpectedly established an HTTP/2 connection.", - sourceLocation: sourceLocation - ) - return try await http2StreamManager.openStream() - } - } - /// Returns the body encoding header fields required for the given HTTP version. static func makeBodyEncodingHeaders(for httpVersion: HTTPVersion) -> HTTPFields { switch httpVersion { diff --git a/Tests/NIOHTTPServerTests/Utilities/NegotiatedClientConnection.swift b/Tests/NIOHTTPServerTests/Utilities/NegotiatedClientConnection.swift index 5292c6c..9e438b0 100644 --- a/Tests/NIOHTTPServerTests/Utilities/NegotiatedClientConnection.swift +++ b/Tests/NIOHTTPServerTests/Utilities/NegotiatedClientConnection.swift @@ -16,6 +16,7 @@ import NIOCore import NIOHTTP2 import NIOHTTPTypes import NIOHTTPTypesHTTP2 +import Testing /// A testing utility that wraps the result of ALPN negotiation for HTTP/1.1 or HTTP/2 client connections. /// @@ -61,3 +62,30 @@ enum NegotiatedClientConnection { } } } + +extension NegotiatedClientConnection { + /// Unwraps a negotiated channel, asserting it matches the expected `httpVersion`. For HTTP/2, opens and returns a + /// new stream channel. + func unwrapChannel( + expectedHTTPVersion: HTTPVersion, + sourceLocation: SourceLocation = #_sourceLocation + ) async throws -> NIOAsyncChannel { + switch self { + case .http1(let http1Channel): + #expect( + expectedHTTPVersion == .http1_1, + "Unexpectedly established an HTTP/1 connection.", + sourceLocation: sourceLocation + ) + return http1Channel + + case .http2(let http2StreamManager): + #expect( + expectedHTTPVersion == .http2, + "Unexpectedly established an HTTP/2 connection.", + sourceLocation: sourceLocation + ) + return try await http2StreamManager.openStream() + } + } +}