Skip to content

Commit c696508

Browse files
committed
feat: Implement NIOThreadPool using swift concurrency
1 parent c2509c2 commit c696508

1 file changed

Lines changed: 286 additions & 0 deletions

File tree

Lines changed: 286 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,286 @@
1+
//===----------------------------------------------------------------------===//
2+
//
3+
// This source file is part of the SwiftNIO open source project
4+
//
5+
// Copyright (c) 2025 Apple Inc. and the SwiftNIO project authors
6+
// Licensed under Apache License v2.0
7+
//
8+
// See LICENSE.txt for license information
9+
// See CONTRIBUTORS.txt for the list of SwiftNIO project authors
10+
//
11+
// SPDX-License-Identifier: Apache-2.0
12+
//
13+
//===----------------------------------------------------------------------===//
14+
15+
import DequeModule
16+
import NIOConcurrencyHelpers
17+
18+
import class Atomics.ManagedAtomic
19+
import protocol NIOCore.EventLoop
20+
import class NIOCore.EventLoopFuture
21+
import enum NIOCore.System
22+
23+
/// Errors that may be thrown when executing work on a `NIOThreadPool`.
24+
public enum NIOThreadPoolError: Sendable {
25+
public struct ThreadPoolInactive: Error {
26+
public init() {}
27+
}
28+
29+
public struct UnsupportedOperation: Error {
30+
public init() {}
31+
}
32+
}
33+
34+
/// Drop‑in stand‑in for `NIOThreadPool`, powered by Swift Concurrency.
35+
@available(macOS 10.15, *)
36+
public final class NIOThreadPool: @unchecked Sendable {
37+
/// The state of the `WorkItem`.
38+
public enum WorkItemState: Sendable {
39+
/// The work item is currently being executed.
40+
case active
41+
/// The work item has been cancelled and will not run.
42+
case cancelled
43+
}
44+
45+
/// The work that should be done by the thread pool.
46+
public typealias WorkItem = @Sendable (WorkItemState) -> Void
47+
48+
@usableFromInline
49+
struct IdentifiableWorkItem: Sendable {
50+
@usableFromInline var workItem: WorkItem
51+
@usableFromInline var id: Int?
52+
}
53+
54+
private let shutdownFlag = ManagedAtomic(false)
55+
private let started = ManagedAtomic(false)
56+
private let numberOfThreads: Int
57+
private let workQueue = WorkQueue()
58+
private let workerTasksLock = NIOLock()
59+
private var workerTasks: [Task<Void, Never>] = []
60+
61+
public init(numberOfThreads: Int? = nil) {
62+
let threads = numberOfThreads ?? System.coreCount
63+
self.numberOfThreads = max(1, threads)
64+
}
65+
66+
public func start() {
67+
startWorkersIfNeeded()
68+
}
69+
70+
private var isActive: Bool {
71+
self.started.load(ordering: .acquiring) && !self.shutdownFlag.load(ordering: .acquiring)
72+
}
73+
74+
// MARK: - Public API -
75+
76+
public func submit(_ body: @escaping WorkItem) {
77+
guard self.isActive else {
78+
body(.cancelled)
79+
return
80+
}
81+
82+
startWorkersIfNeeded()
83+
84+
Task {
85+
await self.workQueue.enqueue(IdentifiableWorkItem(workItem: body, id: nil))
86+
}
87+
}
88+
89+
@preconcurrency
90+
public func submit<T>(on eventLoop: EventLoop, _ fn: @escaping @Sendable () throws -> T)
91+
-> EventLoopFuture<T>
92+
{
93+
self.submit(on: eventLoop) { () throws -> _UncheckedSendable<T> in
94+
_UncheckedSendable(try fn())
95+
}.map { $0.value }
96+
}
97+
98+
public func submit<T: Sendable>(
99+
on eventLoop: EventLoop,
100+
_ fn: @escaping @Sendable () throws -> T
101+
) -> EventLoopFuture<T> {
102+
self.makeFutureByRunningOnPool(eventLoop: eventLoop, fn)
103+
}
104+
105+
/// Async helper mirroring `runIfActive` without an EventLoop context.
106+
public func runIfActive<T: Sendable>(_ body: @escaping @Sendable () throws -> T) async throws -> T
107+
{
108+
try Task.checkCancellation()
109+
guard self.isActive else { throw CancellationError() }
110+
111+
return try await Task {
112+
try Task.checkCancellation()
113+
guard self.isActive else { throw CancellationError() }
114+
return try body()
115+
}.value
116+
}
117+
118+
/// Event‑loop variant returning only the future.
119+
@preconcurrency
120+
public func runIfActive<T>(eventLoop: EventLoop, _ body: @escaping @Sendable () throws -> T)
121+
-> EventLoopFuture<T>
122+
{
123+
self.runIfActive(eventLoop: eventLoop) { () throws -> _UncheckedSendable<T> in
124+
_UncheckedSendable(try body())
125+
}.map { $0.value }
126+
}
127+
128+
public func runIfActive<T: Sendable>(
129+
eventLoop: EventLoop,
130+
_ body: @escaping @Sendable () throws -> T
131+
) -> EventLoopFuture<T> {
132+
self.makeFutureByRunningOnPool(eventLoop: eventLoop, body)
133+
}
134+
135+
private func makeFutureByRunningOnPool<T: Sendable>(
136+
eventLoop: EventLoop,
137+
_ body: @escaping @Sendable () throws -> T
138+
) -> EventLoopFuture<T> {
139+
guard self.isActive else {
140+
return eventLoop.makeFailedFuture(NIOThreadPoolError.ThreadPoolInactive())
141+
}
142+
143+
let promise = eventLoop.makePromise(of: T.self)
144+
self.submit { state in
145+
switch state {
146+
case .active:
147+
do {
148+
let value = try body()
149+
promise.succeed(value)
150+
} catch {
151+
promise.fail(error)
152+
}
153+
case .cancelled:
154+
promise.fail(NIOThreadPoolError.ThreadPoolInactive())
155+
}
156+
}
157+
return promise.futureResult
158+
}
159+
160+
// Lifecycle --------------------------------------------------------------
161+
162+
public static let singleton: NIOThreadPool = {
163+
let pool = NIOThreadPool()
164+
pool.start()
165+
return pool
166+
}()
167+
168+
@preconcurrency
169+
public func shutdownGracefully(_ callback: @escaping @Sendable (Error?) -> Void = { _ in }) {
170+
_shutdownGracefully {
171+
callback(nil)
172+
}
173+
}
174+
175+
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
176+
public func shutdownGracefully() async throws {
177+
try await withCheckedThrowingContinuation { continuation in
178+
_shutdownGracefully {
179+
continuation.resume(returning: ())
180+
}
181+
}
182+
}
183+
184+
private func _shutdownGracefully(completion: (@Sendable () -> Void)? = nil) {
185+
if shutdownFlag.exchange(true, ordering: .acquiring) {
186+
completion?()
187+
return
188+
}
189+
190+
Task {
191+
let remaining = await workQueue.shutdown()
192+
for item in remaining {
193+
item.workItem(.cancelled)
194+
}
195+
196+
workerTasksLock.withLock {
197+
for worker in workerTasks {
198+
worker.cancel()
199+
}
200+
workerTasks.removeAll()
201+
}
202+
203+
started.store(false, ordering: .releasing)
204+
completion?()
205+
}
206+
}
207+
208+
// MARK: - Worker infrastructure
209+
210+
private func startWorkersIfNeeded() {
211+
if self.shutdownFlag.load(ordering: .acquiring) {
212+
return
213+
}
214+
215+
if self.started.compareExchange(expected: false, desired: true, ordering: .acquiring).exchanged
216+
{
217+
spawnWorkers()
218+
}
219+
}
220+
221+
private func spawnWorkers() {
222+
workerTasksLock.withLock {
223+
guard workerTasks.isEmpty else { return }
224+
for index in 0..<numberOfThreads {
225+
workerTasks.append(
226+
Task.detached { [weak self] in
227+
await self?.workerLoop(identifier: index)
228+
}
229+
)
230+
}
231+
}
232+
}
233+
234+
private func workerLoop(identifier _: Int) async {
235+
while let workItem = await workQueue.nextWorkItem(shutdownFlag: shutdownFlag) {
236+
if self.shutdownFlag.load(ordering: .acquiring) {
237+
workItem.workItem(.cancelled)
238+
} else {
239+
workItem.workItem(.active)
240+
}
241+
}
242+
}
243+
244+
actor WorkQueue {
245+
private var queue = Deque<IdentifiableWorkItem>()
246+
private var waiters: [CheckedContinuation<IdentifiableWorkItem?, Never>] = []
247+
private var isShuttingDown = false
248+
249+
func enqueue(_ item: IdentifiableWorkItem) {
250+
if let continuation = waiters.popLast() {
251+
continuation.resume(returning: item)
252+
} else {
253+
queue.append(item)
254+
}
255+
}
256+
257+
func nextWorkItem(shutdownFlag: ManagedAtomic<Bool>) async -> IdentifiableWorkItem? {
258+
if !queue.isEmpty {
259+
return queue.removeFirst()
260+
}
261+
262+
if isShuttingDown || shutdownFlag.load(ordering: .acquiring) {
263+
return nil
264+
}
265+
266+
return await withCheckedContinuation { continuation in
267+
waiters.append(continuation)
268+
}
269+
}
270+
271+
func shutdown() -> [IdentifiableWorkItem] {
272+
isShuttingDown = true
273+
let remaining = Array(queue)
274+
queue.removeAll()
275+
while let waiter = waiters.popLast() {
276+
waiter.resume(returning: nil)
277+
}
278+
return remaining
279+
}
280+
}
281+
282+
private struct _UncheckedSendable<T>: @unchecked Sendable {
283+
let value: T
284+
init(_ value: T) { self.value = value }
285+
}
286+
}

0 commit comments

Comments
 (0)