Skip to content

Commit a265638

Browse files
committed
Fix async waitUntil waiting until whenever the closure finishes before returning
1 parent 7f21319 commit a265638

File tree

2 files changed

+34
-17
lines changed

2 files changed

+34
-17
lines changed

Sources/Nimble/Utils/AsyncAwait.swift

Lines changed: 26 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -56,18 +56,30 @@ final class BlockingTask: Sendable {
5656
}
5757

5858
func run() async {
59-
if let continuation = lock.withLock({ self.continuation }) {
59+
let continuation: CheckedContinuation<Void, Never>? = {
60+
lock.lock()
61+
let continuation = self.continuation
62+
lock.unlock()
63+
return continuation
64+
}()
65+
66+
if let continuation {
6067
continuation.resume()
6168
}
6269
await withTaskCancellationHandler {
6370
await withCheckedContinuation {
6471
lock.lock()
65-
defer { lock.unlock() }
6672

73+
let shouldResume: Bool
6774
if finished {
68-
$0.resume()
75+
shouldResume = true
6976
} else {
7077
self.continuation = $0
78+
shouldResume = false
79+
}
80+
lock.unlock()
81+
if shouldResume {
82+
$0.resume()
7183
}
7284
}
7385
} onCancel: {
@@ -78,25 +90,27 @@ final class BlockingTask: Sendable {
7890

7991
func complete() {
8092
lock.lock()
81-
defer { lock.unlock() }
93+
let wasFinished = finished
94+
finished = true
95+
lock.unlock()
8296

83-
if finished {
97+
if wasFinished {
8498
fail(
8599
"waitUntil(...) expects its completion closure to be only called once",
86100
location: sourceLocation
87101
)
88102
} else {
89-
finished = true
90103
self.continuation?.resume()
91104
self.continuation = nil
92105
}
93106
}
94107

95108
func handleCancellation() {
96109
lock.lock()
97-
defer { lock.unlock() }
110+
let wasFinished = finished
111+
lock.unlock()
98112

99-
guard finished == false else {
113+
guard wasFinished == false else {
100114
return
101115
}
102116
continuation?.resume()
@@ -151,15 +165,15 @@ internal func performBlock(
151165
#endif
152166
#endif
153167

154-
return await withTaskGroup(of: Void.self) { taskGroup in
168+
return await withTaskGroup(of: Void.self, returning: AsyncPollResult<Void>.self) { taskGroup in
155169
let blocker = BlockingTask(sourceLocation: sourceLocation)
156170
let tracker = ResultTracker<Void>()
157171

158172
taskGroup.addTask {
159173
await blocker.run()
160174
}
161175

162-
taskGroup.addTask {
176+
let task = Task {
163177
do {
164178
try await closure {
165179
blocker.complete()
@@ -174,9 +188,7 @@ internal func performBlock(
174188
do {
175189
try await Task.sleep(nanoseconds: (timeout + leeway).nanoseconds)
176190
tracker.finish(with: .timedOut)
177-
} catch {
178-
179-
}
191+
} catch {}
180192
}
181193

182194
var result: AsyncPollResult<Void> = .incomplete
@@ -189,6 +201,7 @@ internal func performBlock(
189201
break
190202
}
191203
taskGroup.cancelAll()
204+
task.cancel()
192205
return result
193206
}
194207
}

Tests/NimbleTests/AsyncAwaitTest.swift

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -272,7 +272,7 @@ final class AsyncAwaitTest: XCTestCase { // swiftlint:disable:this type_body_len
272272
let timeoutQueue = DispatchQueue(label: "Nimble.waitUntilTest.timeout", qos: .background)
273273
let timer = DispatchSource.makeTimerSource(flags: .strict, queue: timeoutQueue)
274274
timer.schedule(
275-
deadline: DispatchTime.now() + 5,
275+
deadline: DispatchTime.now() + 60,
276276
repeating: .never,
277277
leeway: .milliseconds(1)
278278
)
@@ -282,10 +282,14 @@ final class AsyncAwaitTest: XCTestCase { // swiftlint:disable:this type_body_len
282282
}
283283
timer.resume()
284284

285-
for index in 0..<1000 {
286-
if failed { break }
285+
let runQueue = DispatchQueue(label: "Nimble.waitUntilTest.runQueue", attributes: .concurrent)
286+
287+
for _ in 0..<1000 {
288+
if failed {
289+
break
290+
}
287291
await waitUntil() { done in
288-
DispatchQueue(label: "Nimble.waitUntilTest.\(index)").async {
292+
runQueue.async {
289293
done()
290294
}
291295
}

0 commit comments

Comments
 (0)