Skip to content

Commit 9f66953

Browse files
committed
feat: add contextual monad
1 parent 9114332 commit 9f66953

File tree

2 files changed

+214
-0
lines changed

2 files changed

+214
-0
lines changed

src/Std/Internal/Async.lean

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ module
77

88
prelude
99
public import Std.Internal.Async.Basic
10+
public import Std.Internal.Async.Context
1011
public import Std.Internal.Async.Timer
1112
public import Std.Internal.Async.TCP
1213
public import Std.Internal.Async.UDP
Lines changed: 213 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,213 @@
1+
/-
2+
Copyright (c) 2025 Lean FRO, LLC. All rights reserved.
3+
Released under Apache 2.0 license as described in the file LICENSE.
4+
Authors: Sofia Rodrigues
5+
-/
6+
module
7+
8+
prelude
9+
public import Std.Time
10+
public import Std.Internal.UV
11+
public import Std.Internal.Async.Basic
12+
public import Std.Internal.Async.Timer
13+
public import Std.Sync.Context
14+
15+
public section
16+
17+
namespace Std
18+
namespace Internal
19+
namespace IO
20+
namespace Async
21+
22+
/--
23+
An asynchronous computation with cancellation support via a `Context`.
24+
-/
25+
@[expose] abbrev ContextAsync (α : Type) := ReaderT Context Async α
26+
27+
namespace ContextAsync
28+
29+
/--
30+
Run a `ContextAsync` computation with a given context.
31+
-/
32+
@[inline]
33+
protected def run (ctx : Context) (x : ContextAsync α) : Async α :=
34+
x ctx
35+
36+
/--
37+
Create a `ContextAsync` from an `Async` computation.
38+
-/
39+
@[inline]
40+
protected def lift (x : Async α) : ContextAsync α :=
41+
fun _ => x
42+
43+
/--
44+
Get the current context.
45+
-/
46+
@[inline]
47+
def getContext : ContextAsync Context :=
48+
fun ctx => pure ctx
49+
50+
/--
51+
Fork a child context and run a computation within it.
52+
-/
53+
@[inline]
54+
def fork (x : ContextAsync α) : ContextAsync α :=
55+
fun parent => do
56+
let child ← Context.fork parent
57+
x child
58+
59+
/--
60+
Check if the current context is cancelled.
61+
-/
62+
@[inline]
63+
def isCancelled : ContextAsync Bool := do
64+
let ctx ← getContext
65+
ctx.isCancelled
66+
67+
/--
68+
Get the cancellation reason if the context is cancelled.
69+
-/
70+
@[inline]
71+
def getCancellationReason : ContextAsync (Option CancellationReason) := do
72+
let ctx ← getContext
73+
ctx.getCancellationReason
74+
75+
/--
76+
Cancel the current context with the given reason.
77+
-/
78+
@[inline]
79+
def cancel (reason : CancellationReason) : ContextAsync Unit := do
80+
let ctx ← getContext
81+
ctx.cancel reason
82+
83+
/--
84+
Wait for the current context to be cancelled.
85+
-/
86+
@[inline]
87+
def doneSelector : ContextAsync (Selector Unit) := do
88+
let ctx ← getContext
89+
return ctx.doneSelector
90+
91+
/--
92+
Wait for the current context to be cancelled.
93+
-/
94+
@[inline]
95+
def awaitCancellation : ContextAsync Unit := do
96+
let ctx ← getContext
97+
let task ← ctx.done
98+
await task
99+
100+
/--
101+
Run two computations concurrently and return both results. If either fails or is cancelled,
102+
both are cancelled.
103+
-/
104+
@[inline, specialize]
105+
def concurrently (x : ContextAsync α) (y : ContextAsync β)
106+
(prio := Task.Priority.default) : ContextAsync (α × β) := do
107+
let ctx ← getContext
108+
Async.concurrently (x ctx) (y ctx) prio
109+
110+
/--
111+
Run two computations concurrently and return the result of the first to complete.
112+
The loser's context is cancelled.
113+
-/
114+
@[inline, specialize]
115+
def race [Inhabited α] (x : ContextAsync α) (y : ContextAsync α)
116+
(prio := Task.Priority.default) : ContextAsync α := do
117+
let parent ← getContext
118+
let ctx1 ← Context.fork parent
119+
let ctx2 ← Context.fork parent
120+
121+
let task1 ← async (x ctx1) prio
122+
let task2 ← async (y ctx2) prio
123+
124+
let result ← Async.race
125+
(await task1 <* ctx2.cancel .cancel)
126+
(await task2 <* ctx1.cancel .cancel)
127+
prio
128+
129+
pure result
130+
131+
/--
132+
Run all computations concurrently and collect results. If any fails or is cancelled,
133+
all are cancelled.
134+
-/
135+
@[inline, specialize]
136+
def concurrentlyAll (xs : Array (ContextAsync α))
137+
(prio := Task.Priority.default) : ContextAsync (Array α) := do
138+
let ctx ← getContext
139+
Async.concurrentlyAll (xs.map (· ctx)) prio
140+
141+
/--
142+
Run all computations concurrently and return the first result. All losers are cancelled.
143+
-/
144+
@[inline, specialize]
145+
def raceAll [ForM ContextAsync c (ContextAsync α)] (xs : c)
146+
(prio := Task.Priority.default) : ContextAsync α := do
147+
let parent ← getContext
148+
let promise ← IO.Promise.new
149+
150+
ForM.forM xs fun x => do
151+
let ctx ← Context.fork parent
152+
let task ← async (x ctx) prio
153+
background do
154+
try
155+
let result ← await task
156+
-- Cancel parent to stop other tasks
157+
parent.cancel .cancel
158+
promise.resolve (.ok result)
159+
catch e =>
160+
-- Only set error if promise not already resolved
161+
discard $ promise.resolve (.error e)
162+
163+
let result ← await promise
164+
Async.ofExcept result
165+
166+
/--
167+
Run a computation in the background with its own forked context.
168+
-/
169+
@[inline]
170+
def background (x : ContextAsync α) (prio := Task.Priority.default) : ContextAsync Unit := do
171+
let parent ← getContext
172+
let child ← Context.fork parent
173+
Async.background (x child) prio
174+
175+
instance : Functor ContextAsync where
176+
map f x := fun ctx => f <$> x ctx
177+
178+
instance : Monad ContextAsync where
179+
pure a := fun _ => pure a
180+
bind x f := fun ctx => x ctx >>= fun a => f a ctx
181+
182+
instance : MonadLift Async ContextAsync where
183+
monadLift := ContextAsync.lift
184+
185+
instance : MonadLift IO ContextAsync where
186+
monadLift x := fun _ => Async.ofIOTask (Task.pure <$> x)
187+
188+
instance : MonadLift BaseIO ContextAsync where
189+
monadLift x := fun _ => liftM (m := Async) x
190+
191+
instance : MonadExcept IO.Error ContextAsync where
192+
throw e := fun _ => throw e
193+
tryCatch x h := fun ctx => tryCatch (x ctx) (fun e => h e ctx)
194+
195+
instance : MonadFinally ContextAsync where
196+
tryFinally' x f := fun ctx =>
197+
tryFinally' (x ctx) (fun opt => f opt ctx)
198+
199+
instance [Inhabited α] : Inhabited (ContextAsync α) where
200+
default := fun _ => default
201+
202+
instance : MonadAwait AsyncTask ContextAsync where
203+
await t := fun _ => await t
204+
205+
instance : MonadAsync AsyncTask ContextAsync where
206+
async x prio := fun ctx => async (x ctx) prio
207+
208+
end ContextAsync
209+
210+
end Async
211+
end IO
212+
end Internal
213+
end Std

0 commit comments

Comments
 (0)