|
| 1 | +/- |
| 2 | +Copyright (c) 2025 Lean FRO, LLC. All rights reserved. |
| 3 | +Released under Apache 2.0 license as described in the file LICENSE. |
| 4 | +Authors: Sofia Rodrigues |
| 5 | +-/ |
| 6 | +module |
| 7 | + |
| 8 | +prelude |
| 9 | +public import Std.Time |
| 10 | +public import Std.Internal.UV |
| 11 | +public import Std.Internal.Async.Basic |
| 12 | +public import Std.Internal.Async.Timer |
| 13 | +public import Std.Sync.Context |
| 14 | + |
| 15 | +public section |
| 16 | + |
| 17 | +namespace Std |
| 18 | +namespace Internal |
| 19 | +namespace IO |
| 20 | +namespace Async |
| 21 | + |
| 22 | +/-- |
| 23 | +An asynchronous computation with cancellation support via a `Context`. |
| 24 | +-/ |
| 25 | +@[expose] abbrev ContextAsync (α : Type) := ReaderT Context Async α |
| 26 | + |
| 27 | +namespace ContextAsync |
| 28 | + |
| 29 | +/-- |
| 30 | +Run a `ContextAsync` computation with a given context. |
| 31 | +-/ |
| 32 | +@[inline] |
| 33 | +protected def run (ctx : Context) (x : ContextAsync α) : Async α := |
| 34 | + x ctx |
| 35 | + |
| 36 | +/-- |
| 37 | +Create a `ContextAsync` from an `Async` computation. |
| 38 | +-/ |
| 39 | +@[inline] |
| 40 | +protected def lift (x : Async α) : ContextAsync α := |
| 41 | + fun _ => x |
| 42 | + |
| 43 | +/-- |
| 44 | +Get the current context. |
| 45 | +-/ |
| 46 | +@[inline] |
| 47 | +def getContext : ContextAsync Context := |
| 48 | + fun ctx => pure ctx |
| 49 | + |
| 50 | +/-- |
| 51 | +Fork a child context and run a computation within it. |
| 52 | +-/ |
| 53 | +@[inline] |
| 54 | +def fork (x : ContextAsync α) : ContextAsync α := |
| 55 | + fun parent => do |
| 56 | + let child ← Context.fork parent |
| 57 | + x child |
| 58 | + |
| 59 | +/-- |
| 60 | +Check if the current context is cancelled. |
| 61 | +-/ |
| 62 | +@[inline] |
| 63 | +def isCancelled : ContextAsync Bool := do |
| 64 | + let ctx ← getContext |
| 65 | + ctx.isCancelled |
| 66 | + |
| 67 | +/-- |
| 68 | +Get the cancellation reason if the context is cancelled. |
| 69 | +-/ |
| 70 | +@[inline] |
| 71 | +def getCancellationReason : ContextAsync (Option CancellationReason) := do |
| 72 | + let ctx ← getContext |
| 73 | + ctx.getCancellationReason |
| 74 | + |
| 75 | +/-- |
| 76 | +Cancel the current context with the given reason. |
| 77 | +-/ |
| 78 | +@[inline] |
| 79 | +def cancel (reason : CancellationReason) : ContextAsync Unit := do |
| 80 | + let ctx ← getContext |
| 81 | + ctx.cancel reason |
| 82 | + |
| 83 | +/-- |
| 84 | +Wait for the current context to be cancelled. |
| 85 | +-/ |
| 86 | +@[inline] |
| 87 | +def doneSelector : ContextAsync (Selector Unit) := do |
| 88 | + let ctx ← getContext |
| 89 | + return ctx.doneSelector |
| 90 | + |
| 91 | +/-- |
| 92 | +Wait for the current context to be cancelled. |
| 93 | +-/ |
| 94 | +@[inline] |
| 95 | +def awaitCancellation : ContextAsync Unit := do |
| 96 | + let ctx ← getContext |
| 97 | + let task ← ctx.done |
| 98 | + await task |
| 99 | + |
| 100 | +/-- |
| 101 | +Run two computations concurrently and return both results. If either fails or is cancelled, |
| 102 | +both are cancelled. |
| 103 | +-/ |
| 104 | +@[inline, specialize] |
| 105 | +def concurrently (x : ContextAsync α) (y : ContextAsync β) |
| 106 | + (prio := Task.Priority.default) : ContextAsync (α × β) := do |
| 107 | + let ctx ← getContext |
| 108 | + Async.concurrently (x ctx) (y ctx) prio |
| 109 | + |
| 110 | +/-- |
| 111 | +Run two computations concurrently and return the result of the first to complete. |
| 112 | +The loser's context is cancelled. |
| 113 | +-/ |
| 114 | +@[inline, specialize] |
| 115 | +def race [Inhabited α] (x : ContextAsync α) (y : ContextAsync α) |
| 116 | + (prio := Task.Priority.default) : ContextAsync α := do |
| 117 | + let parent ← getContext |
| 118 | + let ctx1 ← Context.fork parent |
| 119 | + let ctx2 ← Context.fork parent |
| 120 | + |
| 121 | + let task1 ← async (x ctx1) prio |
| 122 | + let task2 ← async (y ctx2) prio |
| 123 | + |
| 124 | + let result ← Async.race |
| 125 | + (await task1 <* ctx2.cancel .cancel) |
| 126 | + (await task2 <* ctx1.cancel .cancel) |
| 127 | + prio |
| 128 | + |
| 129 | + pure result |
| 130 | + |
| 131 | +/-- |
| 132 | +Run all computations concurrently and collect results. If any fails or is cancelled, |
| 133 | +all are cancelled. |
| 134 | +-/ |
| 135 | +@[inline, specialize] |
| 136 | +def concurrentlyAll (xs : Array (ContextAsync α)) |
| 137 | + (prio := Task.Priority.default) : ContextAsync (Array α) := do |
| 138 | + let ctx ← getContext |
| 139 | + Async.concurrentlyAll (xs.map (· ctx)) prio |
| 140 | + |
| 141 | +/-- |
| 142 | +Run all computations concurrently and return the first result. All losers are cancelled. |
| 143 | +-/ |
| 144 | +@[inline, specialize] |
| 145 | +def raceAll [ForM ContextAsync c (ContextAsync α)] (xs : c) |
| 146 | + (prio := Task.Priority.default) : ContextAsync α := do |
| 147 | + let parent ← getContext |
| 148 | + let promise ← IO.Promise.new |
| 149 | + |
| 150 | + ForM.forM xs fun x => do |
| 151 | + let ctx ← Context.fork parent |
| 152 | + let task ← async (x ctx) prio |
| 153 | + background do |
| 154 | + try |
| 155 | + let result ← await task |
| 156 | + -- Cancel parent to stop other tasks |
| 157 | + parent.cancel .cancel |
| 158 | + promise.resolve (.ok result) |
| 159 | + catch e => |
| 160 | + -- Only set error if promise not already resolved |
| 161 | + discard $ promise.resolve (.error e) |
| 162 | + |
| 163 | + let result ← await promise |
| 164 | + Async.ofExcept result |
| 165 | + |
| 166 | +/-- |
| 167 | +Run a computation in the background with its own forked context. |
| 168 | +-/ |
| 169 | +@[inline] |
| 170 | +def background (x : ContextAsync α) (prio := Task.Priority.default) : ContextAsync Unit := do |
| 171 | + let parent ← getContext |
| 172 | + let child ← Context.fork parent |
| 173 | + Async.background (x child) prio |
| 174 | + |
| 175 | +instance : Functor ContextAsync where |
| 176 | + map f x := fun ctx => f <$> x ctx |
| 177 | + |
| 178 | +instance : Monad ContextAsync where |
| 179 | + pure a := fun _ => pure a |
| 180 | + bind x f := fun ctx => x ctx >>= fun a => f a ctx |
| 181 | + |
| 182 | +instance : MonadLift Async ContextAsync where |
| 183 | + monadLift := ContextAsync.lift |
| 184 | + |
| 185 | +instance : MonadLift IO ContextAsync where |
| 186 | + monadLift x := fun _ => Async.ofIOTask (Task.pure <$> x) |
| 187 | + |
| 188 | +instance : MonadLift BaseIO ContextAsync where |
| 189 | + monadLift x := fun _ => liftM (m := Async) x |
| 190 | + |
| 191 | +instance : MonadExcept IO.Error ContextAsync where |
| 192 | + throw e := fun _ => throw e |
| 193 | + tryCatch x h := fun ctx => tryCatch (x ctx) (fun e => h e ctx) |
| 194 | + |
| 195 | +instance : MonadFinally ContextAsync where |
| 196 | + tryFinally' x f := fun ctx => |
| 197 | + tryFinally' (x ctx) (fun opt => f opt ctx) |
| 198 | + |
| 199 | +instance [Inhabited α] : Inhabited (ContextAsync α) where |
| 200 | + default := fun _ => default |
| 201 | + |
| 202 | +instance : MonadAwait AsyncTask ContextAsync where |
| 203 | + await t := fun _ => await t |
| 204 | + |
| 205 | +instance : MonadAsync AsyncTask ContextAsync where |
| 206 | + async x prio := fun ctx => async (x ctx) prio |
| 207 | + |
| 208 | +end ContextAsync |
| 209 | + |
| 210 | +end Async |
| 211 | +end IO |
| 212 | +end Internal |
| 213 | +end Std |
0 commit comments