diff --git a/Sources/NIOPosix/NIORandomizedDNSResolver.swift b/Sources/NIOPosix/NIORandomizedDNSResolver.swift new file mode 100644 index 0000000000..56fd497927 --- /dev/null +++ b/Sources/NIOPosix/NIORandomizedDNSResolver.swift @@ -0,0 +1,112 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the SwiftNIO open source project +// +// Copyright (c) 2026 Apple Inc. and the SwiftNIO project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of SwiftNIO project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +#if !os(WASI) + +import NIOCore + +/// A DNS resolver that randomizes the order of addresses returned by the system resolver +/// to enable round-robin DNS load balancing. +/// +/// By default, the system's `getaddrinfo` function returns addresses in a deterministic order +/// as specified by RFC 6724 (destination address selection). While this ordering is correct +/// per the RFC, it defeats DNS-based load balancing techniques such as those used by Kubernetes +/// headless services, where multiple A/AAAA records are returned and clients are expected to +/// distribute connections across all available backends. +/// +/// `NIORandomizedDNSResolver` wraps the standard `getaddrinfo`-based resolver and shuffles the +/// returned addresses within each address family (IPv4 and IPv6 independently) before returning +/// them. This ensures that connection attempts are distributed across all available backends +/// rather than always targeting the same one. +/// +/// This resolver is a single-use object: it can only be used to perform a single host resolution, +/// just like the underlying system resolver. +/// +/// ### Usage with `ClientBootstrap` +/// +/// ```swift +/// let group = MultiThreadedEventLoopGroup(numberOfThreads: 1) +/// let bootstrap = ClientBootstrap(group: group) +/// .resolver(NIORandomizedDNSResolver(loop: group.next())) +/// let channel = try await bootstrap.connect( +/// host: "my-service.default.svc.cluster.local", +/// port: 8080 +/// ).get() +/// ``` +public final class NIORandomizedDNSResolver: Resolver, Sendable { + private let underlying: Resolver & Sendable + + /// The function used to shuffle address results. Defaults to `Array.shuffled()`. + /// Exposed as `internal` to allow deterministic testing via `@testable import`. + internal let shuffleFunction: @Sendable ([SocketAddress]) -> [SocketAddress] + + /// Create a new `NIORandomizedDNSResolver` for use with TCP stream connections. + /// + /// - Parameters: + /// - loop: The `EventLoop` to use for DNS resolution. + public init(loop: EventLoop) { + self.underlying = GetaddrinfoResolver( + loop: loop, + aiSocktype: .stream, + aiProtocol: .tcp + ) + self.shuffleFunction = { $0.shuffled() } + } + + /// Internal initializer for testing with an injectable resolver and shuffle function. + /// + /// - Parameters: + /// - resolver: The underlying resolver to delegate DNS queries to. + /// - shuffleFunction: A function that reorders the address array. Defaults to `Array.shuffled()`. + init( + resolver: Resolver & Sendable, + shuffleFunction: @escaping @Sendable ([SocketAddress]) -> [SocketAddress] = { $0.shuffled() } + ) { + self.underlying = resolver + self.shuffleFunction = shuffleFunction + } + + /// Initiate a DNS A query for a given host. + /// + /// The results from the underlying system resolver are shuffled before being returned. + /// + /// - Parameters: + /// - host: The hostname to do an A lookup on. + /// - port: The port we'll be connecting to. + /// - Returns: An `EventLoopFuture` that fires with the shuffled result of the A lookup. + public func initiateAQuery(host: String, port: Int) -> EventLoopFuture<[SocketAddress]> { + self.underlying.initiateAQuery(host: host, port: port).map { self.shuffleFunction($0) } + } + + /// Initiate a DNS AAAA query for a given host. + /// + /// The results from the underlying system resolver are shuffled before being returned. + /// + /// - Parameters: + /// - host: The hostname to do a AAAA lookup on. + /// - port: The port we'll be connecting to. + /// - Returns: An `EventLoopFuture` that fires with the shuffled result of the AAAA lookup. + public func initiateAAAAQuery(host: String, port: Int) -> EventLoopFuture<[SocketAddress]> { + self.underlying.initiateAAAAQuery(host: host, port: port).map { self.shuffleFunction($0) } + } + + /// Cancel all outstanding DNS queries. + /// + /// This forwards the cancellation to the underlying system resolver. + public func cancelQueries() { + self.underlying.cancelQueries() + } +} + +#endif // !os(WASI) diff --git a/Tests/NIOPosixTests/NIORandomizedDNSResolverTest.swift b/Tests/NIOPosixTests/NIORandomizedDNSResolverTest.swift new file mode 100644 index 0000000000..2c70e4b6ea --- /dev/null +++ b/Tests/NIOPosixTests/NIORandomizedDNSResolverTest.swift @@ -0,0 +1,202 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the SwiftNIO open source project +// +// Copyright (c) 2026 Apple Inc. and the SwiftNIO project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of SwiftNIO project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +import NIOCore +import Testing + +#if !os(WASI) +@testable import NIOPosix + +/// A mock DNS resolver that returns pre-configured results without real DNS lookups. +private final class MockResolver: Resolver, @unchecked Sendable { + private let loop: EventLoop + private let v4Results: [SocketAddress] + private let v6Results: [SocketAddress] + + init(loop: EventLoop, v4Results: [SocketAddress] = [], v6Results: [SocketAddress] = []) { + self.loop = loop + self.v4Results = v4Results + self.v6Results = v6Results + } + + func initiateAQuery(host: String, port: Int) -> EventLoopFuture<[SocketAddress]> { + self.loop.makeSucceededFuture(self.v4Results) + } + + func initiateAAAAQuery(host: String, port: Int) -> EventLoopFuture<[SocketAddress]> { + self.loop.makeSucceededFuture(self.v6Results) + } + + func cancelQueries() {} +} + +/// A mock DNS resolver that always fails with an error. +private final class FailingMockResolver: Resolver, @unchecked Sendable { + private let loop: EventLoop + + struct MockDNSError: Error {} + + init(loop: EventLoop) { + self.loop = loop + } + + func initiateAQuery(host: String, port: Int) -> EventLoopFuture<[SocketAddress]> { + self.loop.makeFailedFuture(MockDNSError()) + } + + func initiateAAAAQuery(host: String, port: Int) -> EventLoopFuture<[SocketAddress]> { + self.loop.makeFailedFuture(MockDNSError()) + } + + func cancelQueries() {} +} + +/// A mock DNS resolver that tracks whether cancelQueries was called. +private final class CancelTrackingMockResolver: Resolver, @unchecked Sendable { + private let loop: EventLoop + private(set) var cancelQueriesCalled = false + + init(loop: EventLoop) { + self.loop = loop + } + + func initiateAQuery(host: String, port: Int) -> EventLoopFuture<[SocketAddress]> { + self.loop.makeSucceededFuture([]) + } + + func initiateAAAAQuery(host: String, port: Int) -> EventLoopFuture<[SocketAddress]> { + self.loop.makeSucceededFuture([]) + } + + func cancelQueries() { + self.cancelQueriesCalled = true + } +} + +@Suite("NIORandomizedDNSResolverTest") +struct NIORandomizedDNSResolverTest { + + @Test + func defaultInitializerResolves() throws { + let group = MultiThreadedEventLoopGroup(numberOfThreads: 1) + defer { + #expect(throws: Never.self) { + try group.syncShutdownGracefully() + } + } + + let v4Addresses = [try SocketAddress(ipAddress: "127.0.0.1", port: 80)] + let mock = MockResolver(loop: group.next(), v4Results: v4Addresses) + let resolver = NIORandomizedDNSResolver(resolver: mock) + + let addressV4 = try resolver.initiateAQuery(host: "127.0.0.1", port: 80).wait() + let addressV6 = try resolver.initiateAAAAQuery(host: "127.0.0.1", port: 80).wait() + let expectedV4 = try SocketAddress(ipAddress: "127.0.0.1", port: 80) + #expect(addressV4.count == 1) + #expect(addressV4[0] == expectedV4) + #expect(addressV6.isEmpty) + } + + @Test + func dnsFailurePropagates() throws { + let group = MultiThreadedEventLoopGroup(numberOfThreads: 1) + defer { + #expect(throws: Never.self) { + try group.syncShutdownGracefully() + } + } + + let mock = FailingMockResolver(loop: group.next()) + let resolver = NIORandomizedDNSResolver(resolver: mock) + + #expect(throws: (any Error).self) { + try resolver.initiateAQuery(host: "any.host", port: 80).wait() + } + #expect(throws: (any Error).self) { + try resolver.initiateAAAAQuery(host: "any.host", port: 80).wait() + } + } + + @Test + func multipleAddressesAreReversed() throws { + let group = MultiThreadedEventLoopGroup(numberOfThreads: 1) + defer { + #expect(throws: Never.self) { + try group.syncShutdownGracefully() + } + } + + let v4Addresses = try (1...5).map { + try SocketAddress(ipAddress: "10.0.0.\($0)", port: 80) + } + let mock = MockResolver(loop: group.next(), v4Results: v4Addresses) + let resolver = NIORandomizedDNSResolver( + resolver: mock, + shuffleFunction: { $0.reversed() } + ) + + let results = try resolver.initiateAQuery(host: "multi.example.com", port: 80).wait() + let expected = try (1...5).reversed().map { + try SocketAddress(ipAddress: "10.0.0.\($0)", port: 80) + } + + #expect(results.count == 5) + #expect(results == expected) + } + + @Test + func multipleV6AddressesAreReversed() throws { + let group = MultiThreadedEventLoopGroup(numberOfThreads: 1) + defer { + #expect(throws: Never.self) { + try group.syncShutdownGracefully() + } + } + + let v6Addresses = try (1...5).map { + try SocketAddress(ipAddress: "fe80::\($0)", port: 443) + } + let mock = MockResolver(loop: group.next(), v6Results: v6Addresses) + let resolver = NIORandomizedDNSResolver( + resolver: mock, + shuffleFunction: { $0.reversed() } + ) + + let results = try resolver.initiateAAAAQuery(host: "multi.example.com", port: 443).wait() + let expected = try (1...5).reversed().map { + try SocketAddress(ipAddress: "fe80::\($0)", port: 443) + } + + #expect(results.count == 5) + #expect(results == expected) + } + + @Test + func cancelQueriesForwardsToUnderlying() throws { + let group = MultiThreadedEventLoopGroup(numberOfThreads: 1) + defer { + #expect(throws: Never.self) { + try group.syncShutdownGracefully() + } + } + + let mock = CancelTrackingMockResolver(loop: group.next()) + let resolver = NIORandomizedDNSResolver(resolver: mock) + + #expect(!mock.cancelQueriesCalled) + resolver.cancelQueries() + #expect(mock.cancelQueriesCalled) + } +} +#endif // !os(WASI)