Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 112 additions & 0 deletions Sources/NIOPosix/NIORandomizedDNSResolver.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
//===----------------------------------------------------------------------===//
//
// This source file is part of the SwiftNIO open source project
//
// Copyright (c) 2026 Apple Inc. and the SwiftNIO project authors
// Licensed under Apache License v2.0
//
// See LICENSE.txt for license information
// See CONTRIBUTORS.txt for the list of SwiftNIO project authors
//
// SPDX-License-Identifier: Apache-2.0
//
//===----------------------------------------------------------------------===//

#if !os(WASI)

import NIOCore

/// A DNS resolver that randomizes the order of addresses returned by the system resolver
/// to enable round-robin DNS load balancing.
///
/// By default, the system's `getaddrinfo` function returns addresses in a deterministic order
/// as specified by RFC 6724 (destination address selection). While this ordering is correct
/// per the RFC, it defeats DNS-based load balancing techniques such as those used by Kubernetes
/// headless services, where multiple A/AAAA records are returned and clients are expected to
/// distribute connections across all available backends.
///
/// `NIORandomizedDNSResolver` wraps the standard `getaddrinfo`-based resolver and shuffles the
/// returned addresses within each address family (IPv4 and IPv6 independently) before returning
/// them. This ensures that connection attempts are distributed across all available backends
/// rather than always targeting the same one.
///
/// This resolver is a single-use object: it can only be used to perform a single host resolution,
/// just like the underlying system resolver.
///
/// ### Usage with `ClientBootstrap`
///
/// ```swift
/// let group = MultiThreadedEventLoopGroup(numberOfThreads: 1)
/// let bootstrap = ClientBootstrap(group: group)
/// .resolver(NIORandomizedDNSResolver(loop: group.next()))
/// let channel = try await bootstrap.connect(
/// host: "my-service.default.svc.cluster.local",
/// port: 8080
/// ).get()
/// ```
public final class NIORandomizedDNSResolver: Resolver, Sendable {
private let underlying: Resolver & Sendable

/// The function used to shuffle address results. Defaults to `Array.shuffled()`.
/// Exposed as `internal` to allow deterministic testing via `@testable import`.
internal let shuffleFunction: @Sendable ([SocketAddress]) -> [SocketAddress]

/// Create a new `NIORandomizedDNSResolver` for use with TCP stream connections.
///
/// - Parameters:
/// - loop: The `EventLoop` to use for DNS resolution.
public init(loop: EventLoop) {
self.underlying = GetaddrinfoResolver(
loop: loop,
aiSocktype: .stream,
aiProtocol: .tcp
)
self.shuffleFunction = { $0.shuffled() }
}

/// Internal initializer for testing with an injectable resolver and shuffle function.
///
/// - Parameters:
/// - resolver: The underlying resolver to delegate DNS queries to.
/// - shuffleFunction: A function that reorders the address array. Defaults to `Array.shuffled()`.
init(
resolver: Resolver & Sendable,
shuffleFunction: @escaping @Sendable ([SocketAddress]) -> [SocketAddress] = { $0.shuffled() }
) {
self.underlying = resolver
self.shuffleFunction = shuffleFunction
}

/// Initiate a DNS A query for a given host.
///
/// The results from the underlying system resolver are shuffled before being returned.
///
/// - Parameters:
/// - host: The hostname to do an A lookup on.
/// - port: The port we'll be connecting to.
/// - Returns: An `EventLoopFuture` that fires with the shuffled result of the A lookup.
public func initiateAQuery(host: String, port: Int) -> EventLoopFuture<[SocketAddress]> {
self.underlying.initiateAQuery(host: host, port: port).map { self.shuffleFunction($0) }
}

/// Initiate a DNS AAAA query for a given host.
///
/// The results from the underlying system resolver are shuffled before being returned.
///
/// - Parameters:
/// - host: The hostname to do a AAAA lookup on.
/// - port: The port we'll be connecting to.
/// - Returns: An `EventLoopFuture` that fires with the shuffled result of the AAAA lookup.
public func initiateAAAAQuery(host: String, port: Int) -> EventLoopFuture<[SocketAddress]> {
self.underlying.initiateAAAAQuery(host: host, port: port).map { self.shuffleFunction($0) }
}

/// Cancel all outstanding DNS queries.
///
/// This forwards the cancellation to the underlying system resolver.
public func cancelQueries() {
self.underlying.cancelQueries()
}
}

#endif // !os(WASI)
205 changes: 205 additions & 0 deletions Tests/NIOPosixTests/NIORandomizedDNSResolverTest.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,205 @@
//===----------------------------------------------------------------------===//
//
// This source file is part of the SwiftNIO open source project
//
// Copyright (c) 2026 Apple Inc. and the SwiftNIO project authors
// Licensed under Apache License v2.0
//
// See LICENSE.txt for license information
// See CONTRIBUTORS.txt for the list of SwiftNIO project authors
//
// SPDX-License-Identifier: Apache-2.0
//
//===----------------------------------------------------------------------===//

import NIOCore
import Testing

#if !os(WASI)
@testable import NIOPosix

/// A mock DNS resolver that returns pre-configured results without real DNS lookups.
private final class MockResolver: Resolver, @unchecked Sendable {
private let loop: EventLoop
private let v4Results: [SocketAddress]
private let v6Results: [SocketAddress]

init(loop: EventLoop, v4Results: [SocketAddress] = [], v6Results: [SocketAddress] = []) {
self.loop = loop
self.v4Results = v4Results
self.v6Results = v6Results
}

func initiateAQuery(host: String, port: Int) -> EventLoopFuture<[SocketAddress]> {
self.loop.makeSucceededFuture(self.v4Results)
}

func initiateAAAAQuery(host: String, port: Int) -> EventLoopFuture<[SocketAddress]> {
self.loop.makeSucceededFuture(self.v6Results)
}

func cancelQueries() {}
}

/// A mock DNS resolver that always fails with an error.
private final class FailingMockResolver: Resolver, @unchecked Sendable {
private let loop: EventLoop

struct MockDNSError: Error {}

init(loop: EventLoop) {
self.loop = loop
}

func initiateAQuery(host: String, port: Int) -> EventLoopFuture<[SocketAddress]> {
self.loop.makeFailedFuture(MockDNSError())
}

func initiateAAAAQuery(host: String, port: Int) -> EventLoopFuture<[SocketAddress]> {
self.loop.makeFailedFuture(MockDNSError())
}

func cancelQueries() {}
}

/// A mock DNS resolver that tracks whether cancelQueries was called.
private final class CancelTrackingMockResolver: Resolver, @unchecked Sendable {
private let loop: EventLoop
private(set) var cancelQueriesCalled = false

init(loop: EventLoop) {
self.loop = loop
}

func initiateAQuery(host: String, port: Int) -> EventLoopFuture<[SocketAddress]> {
self.loop.makeSucceededFuture([])
}

func initiateAAAAQuery(host: String, port: Int) -> EventLoopFuture<[SocketAddress]> {
self.loop.makeSucceededFuture([])
}

func cancelQueries() {
self.cancelQueriesCalled = true
}
}

@Suite("NIORandomizedDNSResolverTest")
struct NIORandomizedDNSResolverTest {

@Test
func defaultInitializerResolves() async throws {
let group = MultiThreadedEventLoopGroup(numberOfThreads: 1)
defer {
#expect(throws: Never.self) {
try group.syncShutdownGracefully()
}
}

let resolver = NIORandomizedDNSResolver(loop: group.next())

// Both queries must be initiated before awaiting — initiateAAAAQuery
// triggers the actual getaddrinfo call that completes both futures.
let v4Future = resolver.initiateAQuery(host: "127.0.0.1", port: 80)
let v6Future = resolver.initiateAAAAQuery(host: "127.0.0.1", port: 80)

let addressV4 = try await v4Future.get()
let addressV6 = try await v6Future.get()
let expectedV4 = try SocketAddress(ipAddress: "127.0.0.1", port: 80)
#expect(addressV4.count == 1)
#expect(addressV4[0] == expectedV4)
#expect(addressV6.isEmpty)
}

@Test
func dnsFailurePropagates() throws {
let group = MultiThreadedEventLoopGroup(numberOfThreads: 1)
defer {
#expect(throws: Never.self) {
try group.syncShutdownGracefully()
}
}

let mock = FailingMockResolver(loop: group.next())
let resolver = NIORandomizedDNSResolver(resolver: mock)

#expect(throws: (any Error).self) {
try resolver.initiateAQuery(host: "any.host", port: 80).wait()
}
#expect(throws: (any Error).self) {
try resolver.initiateAAAAQuery(host: "any.host", port: 80).wait()
}
}

@Test
func multipleAddressesAreReversed() throws {
let group = MultiThreadedEventLoopGroup(numberOfThreads: 1)
defer {
#expect(throws: Never.self) {
try group.syncShutdownGracefully()
}
}

let v4Addresses = try (1...5).map {
try SocketAddress(ipAddress: "10.0.0.\($0)", port: 80)
}
let mock = MockResolver(loop: group.next(), v4Results: v4Addresses)
let resolver = NIORandomizedDNSResolver(
resolver: mock,
shuffleFunction: { $0.reversed() }
)

let results = try resolver.initiateAQuery(host: "multi.example.com", port: 80).wait()
let expected = try (1...5).reversed().map {
try SocketAddress(ipAddress: "10.0.0.\($0)", port: 80)
}

#expect(results.count == 5)
#expect(results == expected)
}

@Test
func multipleV6AddressesAreReversed() throws {
let group = MultiThreadedEventLoopGroup(numberOfThreads: 1)
defer {
#expect(throws: Never.self) {
try group.syncShutdownGracefully()
}
}

let v6Addresses = try (1...5).map {
try SocketAddress(ipAddress: "fe80::\($0)", port: 443)
}
let mock = MockResolver(loop: group.next(), v6Results: v6Addresses)
let resolver = NIORandomizedDNSResolver(
resolver: mock,
shuffleFunction: { $0.reversed() }
)

let results = try resolver.initiateAAAAQuery(host: "multi.example.com", port: 443).wait()
let expected = try (1...5).reversed().map {
try SocketAddress(ipAddress: "fe80::\($0)", port: 443)
}

#expect(results.count == 5)
#expect(results == expected)
}

@Test
func cancelQueriesForwardsToUnderlying() throws {
let group = MultiThreadedEventLoopGroup(numberOfThreads: 1)
defer {
#expect(throws: Never.self) {
try group.syncShutdownGracefully()
}
}

let mock = CancelTrackingMockResolver(loop: group.next())
let resolver = NIORandomizedDNSResolver(resolver: mock)

#expect(!mock.cancelQueriesCalled)
resolver.cancelQueries()
#expect(mock.cancelQueriesCalled)
}
}
#endif // !os(WASI)
Loading