Skip to content

Fix NioAsyncWriter test on concurrency thread pool with single thread #3135

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
Open
2 changes: 2 additions & 0 deletions Package.swift
Original file line number Diff line number Diff line change
Expand Up @@ -441,6 +441,7 @@ let package = Package(
"NIOCore",
"NIOEmbedded",
"NIOFoundationCompat",
"NIOTestUtils",
swiftAtomics,
],
swiftSettings: strictConcurrencySettings
Expand Down Expand Up @@ -520,6 +521,7 @@ let package = Package(
dependencies: [
"NIOTestUtils",
"NIOCore",
"NIOConcurrencyHelpers",
"NIOEmbedded",
"NIOPosix",
]
Expand Down
116 changes: 116 additions & 0 deletions Sources/NIOTestUtils/NIOThreadPoolTaskExecutor.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
//===----------------------------------------------------------------------===//
//
// This source file is part of the SwiftNIO open source project
//
// Copyright (c) 2022 Apple Inc. and the SwiftNIO project authors
// Licensed under Apache License v2.0
//
// See LICENSE.txt for license information
// See CONTRIBUTORS.txt for the list of SwiftNIO project authors
//
// SPDX-License-Identifier: Apache-2.0
//
//===----------------------------------------------------------------------===//

#if compiler(>=6)

import NIOPosix

/// Run a `NIOThreadPool` based `TaskExecutor` while executing the given `body`.
///
/// This function provides a `TaskExecutor`, **not** a `SerialExecutor`. The executor can be
/// used for setting the executor preference of a task.
///
/// Example usage:
/// ```swift
/// await withNIOThreadPoolTaskExecutor(numberOfThreads: 2) { taskExecutor in
/// await withDiscardingTaskGroup { group in
/// group.addTask(executorPreference: taskExecutor) { ... }
/// }
/// }
/// ```
///
/// - warning: Do not escape the task executor from the closure for later use and make sure that
/// all tasks running on the executor are completely finished before `body` returns.
/// For unstructured tasks, this means awaiting their results. If any task is still
/// running on the executor when `body` returns, this results in a fatalError.
/// It is highly recommended to use structured concurrency with this task executor.
///
/// - Parameters:
/// - numberOfThreads: The number of threads in the pool.
/// - body: The closure that will accept the task executor.
///
/// - Throws: When `body` throws.
///
/// - Returns: The value returned by `body`.
@inlinable
public func withNIOThreadPoolTaskExecutor<T, Failure>(
numberOfThreads: Int,
body: (NIOThreadPoolTaskExecutor) async throws(Failure) -> T
) async throws(Failure) -> T {
let taskExecutor = NIOThreadPoolTaskExecutor(numberOfThreads: numberOfThreads)
taskExecutor.start()

let result: Result<T, Failure>
do {
result = .success(try await body(taskExecutor))
} catch {
result = .failure(error)
}

await taskExecutor.shutdownGracefully()

return try result.get()
}

/// A task executor based on NIOThreadPool.
///
/// Provides a `TaskExecutor`, **not** a `SerialExecutor`. The executor can be
/// used for setting the executor preference of a task.
///
public final class NIOThreadPoolTaskExecutor: TaskExecutor {
let nioThreadPool: NIOThreadPool

/// Initialize a `NIOThreadPoolTaskExecutor`, using a thread pool with `numberOfThreads` threads.
///
/// - Parameters:
/// - numberOfThreads: The number of threads to use for the thread pool.
public init(numberOfThreads: Int) {
self.nioThreadPool = NIOThreadPool(numberOfThreads: numberOfThreads)
}

/// Start the `NIOThreadPoolTaskExecutor`.
public func start() {
nioThreadPool.start()
}

/// Gracefully shutdown this `NIOThreadPoolTaskExecutor`.
///
/// Make sure that all tasks running on the executor are finished before shutting down.
///
/// - warning: If any task is still running on the executor, this results in a fatalError.
public func shutdownGracefully() async {
do {
try await nioThreadPool.shutdownGracefully()
} catch {
fatalError("Failed to shutdown NIOThreadPool")
}
}

/// Enqueue a job.
///
/// Called by the concurrency runtime.
///
/// - Parameter job: The job to enqueue.
public func enqueue(_ job: consuming ExecutorJob) {
let unownedJob = UnownedJob(job)
self.nioThreadPool.submit { shouldRun in
guard case shouldRun = NIOThreadPool.WorkItemState.active else {
fatalError("Shutdown before all tasks finished")
}
unownedJob.runSynchronously(on: self.asUnownedTaskExecutor())
}
}
}

#endif // compiler(>=6)
90 changes: 50 additions & 40 deletions Tests/NIOCoreTests/AsyncSequences/NIOAsyncWriterTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import DequeModule
import NIOConcurrencyHelpers
import NIOTestUtils
import XCTest

@testable import NIOCore
Expand Down Expand Up @@ -606,48 +607,57 @@ final class NIOAsyncWriterTests: XCTestCase {
self.assert(suspendCallCount: 1, yieldCallCount: 1, terminateCallCount: 1)
}

func testSuspendingBufferedYield_whenWriterFinished() async throws {
self.sink.setWritability(to: false)

let bothSuspended = expectation(description: "suspended on both yields")
let suspendedAgain = ConditionLock(value: false)
self.delegate.didSuspendHandler = {
if self.delegate.didSuspendCallCount == 2 {
bothSuspended.fulfill()
} else if self.delegate.didSuspendCallCount > 2 {
suspendedAgain.lock()
suspendedAgain.unlock(withValue: true)
}
}

self.delegate.didYieldHandler = { _ in
if self.delegate.didYieldCallCount == 1 {
// Delay this yield until the other yield is suspended again.
suspendedAgain.lock(whenValue: true)
suspendedAgain.unlock()
func testWriterFinish_AndSuspendBufferedYield() async throws {
#if compiler(>=6)
try await withNIOThreadPoolTaskExecutor(numberOfThreads: 2) { taskExecutor in
try await withThrowingTaskGroup(of: Void.self) { group in
self.sink.setWritability(to: false)

let bothSuspended = expectation(description: "suspended on both yields")
let suspendedAgain = ConditionLock(value: false)
self.delegate.didSuspendHandler = {
if self.delegate.didSuspendCallCount == 2 {
bothSuspended.fulfill()
} else if self.delegate.didSuspendCallCount > 2 {
suspendedAgain.lock()
suspendedAgain.unlock(withValue: true)
}
}

self.delegate.didYieldHandler = { _ in
if self.delegate.didYieldCallCount == 1 {
// Delay this yield until the other yield is suspended again.
if suspendedAgain.lock(whenValue: true, timeoutSeconds: 5) {
suspendedAgain.unlock()
} else {
XCTFail("Timeout while waiting for other yield to suspend again.")
}
}
}

group.addTask(executorPreference: taskExecutor) { [writer] in
try await writer!.yield("message1")
}
group.addTask(executorPreference: taskExecutor) { [writer] in
try await writer!.yield("message2")
}

await fulfillment(of: [bothSuspended], timeout: 5)
self.writer.finish()

self.assert(suspendCallCount: 2, yieldCallCount: 0, terminateCallCount: 0)

// We have to become writable again to unbuffer the yields
// The first call to didYield will pause, so that the other yield will be suspended again.
self.sink.setWritability(to: true)

await XCTAssertNoThrow(try await group.next())
await XCTAssertNoThrow(try await group.next())

self.assert(suspendCallCount: 3, yieldCallCount: 2, terminateCallCount: 1)
}
}

let task1 = Task { [writer] in
try await writer!.yield("message1")
}
let task2 = Task { [writer] in
try await writer!.yield("message2")
}

await fulfillment(of: [bothSuspended], timeout: 1)
self.writer.finish()

self.assert(suspendCallCount: 2, yieldCallCount: 0, terminateCallCount: 0)

// We have to become writable again to unbuffer the yields
// The first call to didYield will pause, so that the other yield will be suspended again.
self.sink.setWritability(to: true)

await XCTAssertNoThrow(try await task1.value)
await XCTAssertNoThrow(try await task2.value)

self.assert(suspendCallCount: 3, yieldCallCount: 2, terminateCallCount: 1)
#endif // compiler(>=6)
}

func testWriterFinish_whenFinished() {
Expand Down
81 changes: 81 additions & 0 deletions Tests/NIOTestUtilsTests/NIOThreadPoolTaskExecutorTest.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
//===----------------------------------------------------------------------===//
//
// This source file is part of the SwiftNIO open source project
//
// Copyright (c) 2019-2025 Apple Inc. and the SwiftNIO project authors
// Licensed under Apache License v2.0
//
// See LICENSE.txt for license information
// See CONTRIBUTORS.txt for the list of SwiftNIO project authors
//
// SPDX-License-Identifier: Apache-2.0
//
//===----------------------------------------------------------------------===//

import NIOConcurrencyHelpers
import NIOTestUtils
import XCTest

class NIOThreadPoolTaskExecutorTest: XCTestCase {
struct TestError: Error {}

func runTasksSimultaneously(numberOfTasks: Int) async {
await withNIOThreadPoolTaskExecutor(numberOfThreads: numberOfTasks) { taskExecutor in
await withDiscardingTaskGroup { group in
var taskBlockers = [ConditionLock<Bool>]()
defer {
// Unblock all tasks
for taskBlocker in taskBlockers {
taskBlocker.lock()
taskBlocker.unlock(withValue: true)
}
}

for taskNumber in 1...numberOfTasks {
let taskStarted = ConditionLock(value: false)
let taskBlocker = ConditionLock(value: false)
taskBlockers.append(taskBlocker)

// Start task and block it
group.addTask(executorPreference: taskExecutor) {
taskStarted.lock()
taskStarted.unlock(withValue: true)
taskBlocker.lock(whenValue: true)
taskBlocker.unlock()
}

// Verify that task was able to start
if taskStarted.lock(whenValue: true, timeoutSeconds: 5) {
taskStarted.unlock()
} else {
XCTFail("Task \(taskNumber) failed to start.")
break
}
}
}
}
}

func testRunsTaskOnSingleThread() async {
await runTasksSimultaneously(numberOfTasks: 1)
}

func testRunsMultipleTasksOnMultipleThreads() async {
await runTasksSimultaneously(numberOfTasks: 3)
}

func testReturnsBodyResult() async {
let expectedResult = "result"
let result = await withNIOThreadPoolTaskExecutor(numberOfThreads: 1) { _ in return expectedResult }
XCTAssertEqual(result, expectedResult)
}

func testRethrows() async {
do {
try await withNIOThreadPoolTaskExecutor(numberOfThreads: 1) { _ in throw TestError() }
XCTFail("Function did not rethrow.")
} catch {
XCTAssertTrue(error is TestError)
}
}
}
Loading