Skip to content

Commit f7f6ff1

Browse files
committed
Fix async waitUntil waiting until whenever the closure finishes before returning
1 parent a535e4c commit f7f6ff1

File tree

2 files changed

+33
-16
lines changed

2 files changed

+33
-16
lines changed

Sources/Nimble/Utils/AsyncAwait.swift

Lines changed: 26 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -61,18 +61,30 @@ final class BlockingTask: Sendable {
6161
}
6262

6363
func run() async {
64-
if let continuation = lock.withLock({ self.continuation }) {
64+
let continuation: CheckedContinuation<Void, Never>? = {
65+
lock.lock()
66+
let continuation = self.continuation
67+
lock.unlock()
68+
return continuation
69+
}()
70+
71+
if let continuation {
6572
continuation.resume()
6673
}
6774
await withTaskCancellationHandler {
6875
await withCheckedContinuation {
6976
lock.lock()
70-
defer { lock.unlock() }
7177

78+
let shouldResume: Bool
7279
if finished {
73-
$0.resume()
80+
shouldResume = true
7481
} else {
7582
self.continuation = $0
83+
shouldResume = false
84+
}
85+
lock.unlock()
86+
if shouldResume {
87+
$0.resume()
7688
}
7789
}
7890
} onCancel: {
@@ -83,25 +95,27 @@ final class BlockingTask: Sendable {
8395

8496
func complete() {
8597
lock.lock()
86-
defer { lock.unlock() }
98+
let wasFinished = finished
99+
finished = true
100+
lock.unlock()
87101

88-
if finished {
102+
if wasFinished {
89103
fail(
90104
"waitUntil(...) expects its completion closure to be only called once",
91105
location: sourceLocation
92106
)
93107
} else {
94-
finished = true
95108
self.continuation?.resume()
96109
self.continuation = nil
97110
}
98111
}
99112

100113
func handleCancellation() {
101114
lock.lock()
102-
defer { lock.unlock() }
115+
let wasFinished = finished
116+
lock.unlock()
103117

104-
guard finished == false else {
118+
guard wasFinished == false else {
105119
return
106120
}
107121
continuation?.resume()
@@ -156,15 +170,15 @@ internal func performBlock(
156170
#endif
157171
#endif
158172

159-
return await withTaskGroup(of: Void.self) { taskGroup in
173+
return await withTaskGroup(of: Void.self, returning: AsyncPollResult<Void>.self) { taskGroup in
160174
let blocker = BlockingTask(sourceLocation: sourceLocation)
161175
let tracker = ResultTracker<Void>()
162176

163177
taskGroup.addTask {
164178
await blocker.run()
165179
}
166180

167-
taskGroup.addTask {
181+
let task = Task {
168182
do {
169183
try await closure {
170184
blocker.complete()
@@ -179,9 +193,7 @@ internal func performBlock(
179193
do {
180194
try await Task.sleep(nanoseconds: (timeout + leeway).nanoseconds)
181195
tracker.finish(with: .timedOut)
182-
} catch {
183-
184-
}
196+
} catch {}
185197
}
186198

187199
var result: AsyncPollResult<Void> = .incomplete
@@ -194,6 +206,7 @@ internal func performBlock(
194206
break
195207
}
196208
taskGroup.cancelAll()
209+
task.cancel()
197210
return result
198211
}
199212
}

Tests/NimbleTests/AsyncAwaitTest.swift

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -282,10 +282,14 @@ final class AsyncAwaitTest: XCTestCase { // swiftlint:disable:this type_body_len
282282
}
283283
timer.resume()
284284

285-
for index in 0..<1000 {
286-
if failed { break }
285+
let runQueue = DispatchQueue(label: "Nimble.waitUntilTest.runQueue", attributes: .concurrent)
286+
287+
for _ in 0..<1000 {
288+
if failed {
289+
break
290+
}
287291
await waitUntil() { done in
288-
DispatchQueue(label: "Nimble.waitUntilTest.\(index)").async {
292+
runQueue.async {
289293
done()
290294
}
291295
}

0 commit comments

Comments
 (0)