Skip to content

Commit cb128b3

Browse files
committed
feat: cancellation
1 parent af842cf commit cb128b3

File tree

2 files changed

+240
-264
lines changed

2 files changed

+240
-264
lines changed

src/Std/Sync/CancellationToken.lean

Lines changed: 114 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -15,50 +15,144 @@ public import Std.Internal.Async.Select
1515
public section
1616

1717
/-!
18-
This module implements a hierarchical cancellation token system with bottom-up cancellation propagation and automatic cleanup.
18+
This module contains the implementation of `Std.CancellationToken`. `Std.CancellationToken` provides a
19+
cancellation primitive for signaling cancellation between tasks or threads. It supports both synchronous
20+
and asynchronous waiting, and is useful for cases where you want to notify one or more waiters
21+
that a cancellation has occurred.
22+
23+
When cancelled, all waiting consumers receive `true`. When unregistered normally, consumers receive `false`.
24+
Once cancelled, the token remains in a cancelled state and all future waits immediately return `true`.
1925
-/
2026

2127
namespace Std
22-
2328
open Std.Internal.IO.Async
2429

30+
inductive CancellationToken.Consumer where
31+
| normal (promise : IO.Promise Unit)
32+
| select (finished : Waiter Unit)
33+
34+
def CancellationToken.Consumer.resolve (c : Consumer) : BaseIO Bool := do
35+
match c with
36+
| .normal promise =>
37+
promise.resolve ()
38+
return true
39+
| .select waiter =>
40+
let lose := return false
41+
let win promise := do
42+
promise.resolve (.ok ())
43+
return true
44+
waiter.race lose win
45+
2546
/--
26-
A cancellation token provides a way to cancel operations and gracefully shutdown.
47+
The central state structure for a `CancellationToken`.
2748
-/
28-
@[expose]
29-
def CancellationToken := IO.Promise Unit
49+
structure CancellationToken.State where
50+
/--
51+
Whether this token has been cancelled.
52+
-/
53+
cancelled : Bool
54+
55+
/--
56+
Consumers that are blocked waiting for cancellation.
57+
--/
58+
consumers : Std.Queue (CancellationToken.Consumer)
59+
60+
/--
61+
A cancellation token is a synchronization primitive that allows multiple consumers to wait
62+
until cancellation is requested.
63+
-/
64+
structure CancellationToken where
65+
state : Std.Mutex CancellationToken.State
3066

3167
namespace CancellationToken
3268

3369
/--
34-
Creates a new cancellation token.
70+
Create a new cancellation token.
71+
-/
72+
def new : BaseIO CancellationToken := do
73+
return { state := ← Std.Mutex.new { cancelled := false, consumers := ∅ } }
74+
75+
/--
76+
Cancel the token, notifying all currently waiting consumers with `true`.
77+
Once cancelled, the token remains cancelled.
78+
-/
79+
def cancel (x : CancellationToken) : BaseIO Unit := do
80+
x.state.atomically do
81+
let mut st ← get
82+
83+
if st.cancelled then
84+
return
85+
86+
let mut remainingConsumers := st.consumers
87+
st := { cancelled := true, consumers := ∅ }
88+
89+
while true do
90+
if let some (consumer, rest) := remainingConsumers.dequeue? then
91+
remainingConsumers := rest
92+
discard <| consumer.resolve
93+
else
94+
break
95+
96+
set st
97+
98+
/--
99+
Check if the token is cancelled.
35100
-/
36-
def new : IO CancellationToken :=
37-
IO.Promise.new
101+
def isCancelled (x : CancellationToken) : BaseIO Bool := do
102+
x.state.atomically do
103+
let st ← get
104+
return st.cancelled
38105

39106
/--
40-
Cancels the token.
107+
Wait for cancellation. Returns a task that completes with `true` when cancelled,
108+
or `false` if unregistered normally. If already cancelled, immediately returns `true`.
41109
-/
42-
def cancel (token : CancellationToken) : BaseIO Unit :=
43-
token.resolve ()
110+
def wait (x : CancellationToken) : IO (AsyncTask Unit) :=
111+
x.state.atomically do
112+
let st ← get
113+
114+
if st.cancelled then
115+
return Task.pure (.ok ())
116+
117+
let promise ← IO.Promise.new
118+
119+
modify fun st => { st with consumers := st.consumers.enqueue (.normal promise) }
120+
121+
IO.bindTask promise.result? fun
122+
| some _ => pure <| Task.pure (.ok ())
123+
| none => throw (IO.userError "cancellation token dropped")
44124

45125
/--
46-
Creates a selector that resolves when the token is cancelled.
126+
Creates a selector that waits for cancellation
47127
-/
48-
def selector (token : CancellationToken) : Selector Unit where
128+
def selector (token : CancellationToken) : Selector Unit := {
49129
tryFn := do
50-
if ← token.isResolved then
130+
if ← token.isCancelled then
51131
return some ()
52132
else
53133
return none
54134

55-
registerFn waiter := do
56-
IO.chainTask token.result? fun
57-
| some _ => waiter.promise.resolve (.ok ())
58-
| none => return ()
135+
registerFn := fun waiter => do
136+
token.state.atomically do
137+
let st ← get
59138

60-
unregisterFn := return ()
139+
if st.cancelled then
140+
discard <| waiter.race (return false) (fun promise => do
141+
promise.resolve (.ok ())
142+
return true)
143+
else
144+
modify fun st => { st with consumers := st.consumers.enqueue (.select waiter) }
61145

62-
end CancellationToken
146+
unregisterFn := do
147+
token.state.atomically do
148+
let st ← get
63149

150+
let consumers ← st.consumers.filterM fun
151+
| .normal _ => return true
152+
| .select waiter => return !(← waiter.checkFinished)
153+
154+
set { st with consumers }
155+
}
156+
157+
end CancellationToken
64158
end Std

0 commit comments

Comments
 (0)