Skip to content

Commit ee1f901

Browse files
committed
feat: context
1 parent 856825a commit ee1f901

File tree

5 files changed

+578
-10
lines changed

5 files changed

+578
-10
lines changed

src/Std/Sync.lean

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,5 +16,6 @@ public import Std.Sync.Notify
1616
public import Std.Sync.Broadcast
1717
public import Std.Sync.StreamMap
1818
public import Std.Sync.CancellationToken
19+
public import Std.Sync.Context
1920

2021
@[expose] public section

src/Std/Sync/CancellationToken.lean

Lines changed: 39 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,27 @@ that a cancellation has occurred.
2323
namespace Std
2424
open Std.Internal.IO.Async
2525

26+
/--
27+
Reasons for cancellation.
28+
-/
29+
inductive CancellationReason where
30+
/-- Cancelled due to a deadline or timeout -/
31+
| deadline
32+
/-- Cancelled due to shutdown -/
33+
| shutdown
34+
/-- Explicitly cancelled -/
35+
| cancel
36+
/-- Custom cancellation reason -/
37+
| custom (msg : String)
38+
deriving Repr, BEq
39+
40+
instance : ToString CancellationReason where
41+
toString
42+
| .deadline => "deadline"
43+
| .shutdown => "shutdown"
44+
| .cancel => "cancel"
45+
| .custom msg => s!"custom(\"{msg}\")"
46+
2647
inductive CancellationToken.Consumer where
2748
| normal (promise : IO.Promise Unit)
2849
| select (finished : Waiter Unit)
@@ -44,9 +65,9 @@ The central state structure for a `CancellationToken`.
4465
-/
4566
structure CancellationToken.State where
4667
/--
47-
Whether this token has been cancelled.
68+
The cancellation reason if cancelled, none otherwise.
4869
-/
49-
cancelled : Bool
70+
reason : Option CancellationReason
5071

5172
/--
5273
Consumers that are blocked waiting for cancellation.
@@ -66,21 +87,21 @@ namespace CancellationToken
6687
Create a new cancellation token.
6788
-/
6889
def new : BaseIO CancellationToken := do
69-
return { state := ← Std.Mutex.new { cancelled := false, consumers := ∅ } }
90+
return { state := ← Std.Mutex.new { reason := none, consumers := ∅ } }
7091

7192
/--
72-
Cancel the token, notifying all currently waiting consumers with `true`.
93+
Cancel the token with the given reason, notifying all currently waiting consumers.
7394
Once cancelled, the token remains cancelled.
7495
-/
75-
def cancel (x : CancellationToken) : BaseIO Unit := do
96+
def cancel (x : CancellationToken) (reason : CancellationReason := .cancel) : BaseIO Unit := do
7697
x.state.atomically do
7798
let mut st ← get
7899

79-
if st.cancelled then
100+
if st.reason.isSome then
80101
return
81102

82103
let mut remainingConsumers := st.consumers
83-
st := { cancelled := true, consumers := ∅ }
104+
st := { reason := some reason, consumers := ∅ }
84105

85106
while true do
86107
if let some (consumer, rest) := remainingConsumers.dequeue? then
@@ -97,7 +118,15 @@ Check if the token is cancelled.
97118
def isCancelled (x : CancellationToken) : BaseIO Bool := do
98119
x.state.atomically do
99120
let st ← get
100-
return st.cancelled
121+
return st.reason.isSome
122+
123+
/--
124+
Get the cancellation reason if the token is cancelled.
125+
-/
126+
def getCancellationReason (x : CancellationToken) : BaseIO (Option CancellationReason) := do
127+
x.state.atomically do
128+
let st ← get
129+
return st.reason
101130

102131
/--
103132
Wait for cancellation. Returns a task that completes when cancelled,
@@ -106,7 +135,7 @@ def wait (x : CancellationToken) : IO (AsyncTask Unit) :=
106135
x.state.atomically do
107136
let st ← get
108137

109-
if st.cancelled then
138+
if st.reason.isSome then
110139
return Task.pure (.ok ())
111140

112141
let promise ← IO.Promise.new
@@ -131,7 +160,7 @@ def selector (token : CancellationToken) : Selector Unit := {
131160
token.state.atomically do
132161
let st ← get
133162

134-
if st.cancelled then
163+
if st.reason.isSome then
135164
discard <| waiter.race (return false) (fun promise => do
136165
promise.resolve (.ok ())
137166
return true)

src/Std/Sync/Context.lean

Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
/-
2+
Copyright (c) 2025 Lean FRO, LLC. All rights reserved.
3+
Released under Apache 2.0 license as described in the file LICENSE.
4+
Authors: Sofia Rodrigues
5+
-/
6+
7+
module
8+
9+
prelude
10+
public import Std.Data
11+
public import Init.System.Promise
12+
public import Init.Data.Queue
13+
public import Std.Sync.Mutex
14+
public import Std.Sync.CancellationToken
15+
public import Std.Internal.Async.Select
16+
17+
public section
18+
19+
/-!
20+
Context interface for cancellation and deadline management
21+
-/
22+
23+
namespace Std
24+
open Std.Internal.IO.Async
25+
26+
/--
27+
The central state structure shared by all context types.cc
28+
-/
29+
structure ContextState where
30+
/--
31+
Map of token IDs to optional tokens and their children.
32+
`none` represents a background context that cannot be cancelled.
33+
-/
34+
tokens : TreeMap UInt64 (Option CancellationToken × Array UInt64) := .empty
35+
36+
/--
37+
Next available ID
38+
-/
39+
id : UInt64 := 1
40+
41+
/--
42+
A cancellation context that allows multiple consumers to wait
43+
until cancellation is requested. Forms a tree structure where
44+
cancelling a parent cancels all children.
45+
-/
46+
structure Context where
47+
state : Std.Mutex ContextState
48+
token : CancellationToken
49+
id : UInt64
50+
51+
namespace Context
52+
53+
/--
54+
Create a new root cancellation context.
55+
-/
56+
def new : BaseIO Context := do
57+
let token ← Std.CancellationToken.new
58+
return {
59+
state := ← Std.Mutex.new { tokens := .empty |>.insert 0 (some token, #[]) },
60+
token,
61+
id := 0
62+
}
63+
64+
/--
65+
Fork a child context from a parent. If the parent is already cancelled,
66+
returns the parent context. Otherwise, creates a new child that will be
67+
cancelled when the parent is cancelled.
68+
-/
69+
def fork (root : Context) : BaseIO Context := do
70+
if ← root.token.isCancelled then
71+
return root
72+
73+
let token ← Std.CancellationToken.new
74+
75+
root.state.atomically do
76+
let st ← get
77+
let newId := st.id
78+
set { st with
79+
id := newId + 1,
80+
tokens := st.tokens.insert newId (some token, #[])
81+
|>.modify root.id (.map (·) (.push · newId))
82+
}
83+
return { state := root.state, token, id := newId }
84+
85+
/--
86+
Recursively cancel a context and all its children with the given reason.
87+
-/
88+
private partial def cancelChildren (state : ContextState) (id : UInt64) (reason : CancellationReason) : BaseIO ContextState := do
89+
let mut state := state
90+
91+
let some (tokenOpt, children) := state.tokens.get? id
92+
| return state
93+
94+
for tokenId in children do
95+
state ← cancelChildren state tokenId reason
96+
97+
if let some token := tokenOpt then
98+
token.cancel reason
99+
100+
pure { state with tokens := state.tokens.erase id }
101+
102+
/--
103+
Cancel this context and all child contexts with the given reason.
104+
-/
105+
def cancel (x : Context) (reason : CancellationReason) : BaseIO Unit := do
106+
if ← x.token.isCancelled then
107+
return
108+
109+
x.state.atomically do
110+
let st ← get
111+
let st ← cancelChildren st x.id reason
112+
set st
113+
114+
/--
115+
Check if the context is cancelled.
116+
-/
117+
@[inline]
118+
def isCancelled (x : Context) : BaseIO Bool := do
119+
x.token.isCancelled
120+
121+
/--
122+
Get the cancellation reason if the context is cancelled.
123+
-/
124+
@[inline]
125+
def getCancellationReason (x : Context) : BaseIO (Option CancellationReason) := do
126+
x.token.getCancellationReason
127+
128+
/--
129+
Wait for cancellation. Returns a task that completes when the context is cancelled.
130+
-/
131+
@[inline]
132+
def done (x : Context) : IO (AsyncTask Unit) :=
133+
x.token.wait
134+
135+
/--
136+
Creates a selector that waits for cancellation.
137+
-/
138+
@[inline]
139+
def doneSelector (x : Context) : Selector Unit :=
140+
x.token.selector
141+
142+
end Context
143+
end Std

0 commit comments

Comments
 (0)