Skip to content

Commit e314d45

Browse files
committed
feat: add base functions
1 parent 551a3e8 commit e314d45

File tree

2 files changed

+142
-11
lines changed

2 files changed

+142
-11
lines changed

src/Std/Internal/Async/Basic.lean

Lines changed: 91 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -355,6 +355,20 @@ Converts a `EAsync` to a `ExceptTask`.
355355
protected def toEIO (x : EAsync ε α) : EIO ε (ExceptTask ε α) :=
356356
MaybeExceptTask.toTask <$> x.toRawEIO
357357

358+
/--
359+
Creates a new `EAsync` out of a `Task`.
360+
-/
361+
@[inline]
362+
protected def ofTask (x : Task α) : EAsync ε α :=
363+
.mk (pure (MaybeExceptTask.ofTask <| x.map (.ok)))
364+
365+
/--
366+
Creates a new `EAsync` out of a `ExceptTask`.
367+
-/
368+
@[inline]
369+
protected def ofExceptTask (x : ExceptTask ε α) : EAsync ε α :=
370+
.mk (pure (MaybeExceptTask.ofTask x))
371+
358372
/--
359373
Creates an `EAsync` computation that immediately returns the given value.
360374
-/
@@ -475,6 +489,9 @@ def async (self : EAsync ε α) : EAsync ε (ExceptTask ε α) :=
475489
instance : MonadAwait (ExceptTask ε) (EAsync ε) where
476490
await t := mk <| pure <| .ofTask t
477491

492+
instance : MonadAwait Task (EAsync ε) where
493+
await t := mk <| pure <| .ofTask (t.map (.ok))
494+
478495
instance : MonadAwait AsyncTask (EAsync IO.Error) where
479496
await t := mk <| pure <| .ofTask t
480497

@@ -515,13 +532,6 @@ instance : OrElse (EAsync ε α) where
515532
instance [Inhabited ε] : Inhabited (EAsync ε α) where
516533
default := ⟨.error default⟩
517534

518-
/--
519-
Starts the given `ExceptTask` in the background and discards the result.
520-
-/
521-
@[inline]
522-
def parallel {α : Type} (x : EAsync ε (ExceptTask ε α)) : EAsync ε Unit :=
523-
discard <| x
524-
525535
/--
526536
A tail recursive version of the `forIn` for while loops inside the `EAsync` Monad.
527537
-/
@@ -579,9 +589,16 @@ Returns the `Async` computation inside an `AsyncTask`, so it can be awaited.
579589
def async (self : Async α) : Async (AsyncTask α) :=
580590
EAsync.lift <| self.asTask
581591

582-
@[default_instance] instance : MonadAwait AsyncTask Async := inferInstanceAs (MonadAwait AsyncTask (EAsync IO.Error))
583-
@[default_instance] instance : MonadAsync AsyncTask Async := inferInstanceAs (MonadAsync (ExceptTask IO.Error) (EAsync IO.Error))
584-
instance : MonadAwait IO.Promise Async := inferInstanceAs (MonadAwait IO.Promise (EAsync IO.Error))
592+
@[default_instance]
593+
instance : MonadAwait AsyncTask Async :=
594+
inferInstanceAs (MonadAwait AsyncTask (EAsync IO.Error))
595+
596+
@[default_instance]
597+
instance : MonadAsync AsyncTask Async :=
598+
inferInstanceAs (MonadAsync (ExceptTask IO.Error) (EAsync IO.Error))
599+
600+
instance : MonadAwait IO.Promise Async :=
601+
inferInstanceAs (MonadAwait IO.Promise (EAsync IO.Error))
585602

586603
end Async
587604

@@ -641,11 +658,74 @@ instance : MonadAwait Task BaseAsync where
641658
instance : MonadAsync Task BaseAsync where
642659
async := BaseAsync.async
643660

661+
instance : MonadLiftT BaseAsync (EAsync ε) where
662+
monadLift {α} x :=
663+
let r : EIO ε (MaybeExceptTask ε α) := do
664+
let r ← BaseIO.toEIO (x.toRawEIO)
665+
match r with
666+
| .pure res => pure <| .pure res
667+
| .ofTask t => pure <| .ofTask <| Task.map (fun (.ok t) => .ok t) t
668+
⟨r⟩
669+
670+
instance : MonadLiftT BaseAsync Async :=
671+
inferInstanceAs (MonadLiftT BaseAsync (EAsync IO.Error))
672+
644673
end BaseAsync
645674

646675
export MonadAsync (async)
647676
export MonadAwait (await)
648-
export EAsync (parallel)
677+
678+
/--
679+
Starts the given async task in the background and discards the result.
680+
-/
681+
@[inline, specialize]
682+
def parallel [Monad m] [MonadAsync t m] (x : m (t α)) : m Unit :=
683+
discard <| x
684+
685+
/--
686+
Runs two computations concurrently and returns both results as a pair.
687+
-/
688+
@[inline, specialize]
689+
def concurrently [Monad m] [MonadAwait t m] [MonadAsync t m] (x : m α) (y : m β) : m (α × β) := do
690+
let taskX : t α ← async x
691+
let taskY : t β ← async y
692+
let resultX ← await taskX
693+
let resultY ← await taskY
694+
return (resultX, resultY)
695+
696+
/--
697+
Runs two computations concurrently and returns the result of the one that finishes first.
698+
The other result is discarded.
699+
-/
700+
@[inline, specialize]
701+
def race [MonadLiftT BaseIO m] [MonadAwait Task m] [MonadAsync t m] [MonadAwait t m] [Monad m] [Inhabited α] (x : m α) (y : m α) : m α := do
702+
let promise ← IO.Promise.new
703+
704+
discard (async (n := t) <| Bind.bind x (liftM ∘ promise.resolve))
705+
discard (async (n := t) <| Bind.bind y (liftM ∘ promise.resolve))
706+
707+
await promise.result!
708+
709+
/--
710+
Runs all computations in an `Array` concurrently and returns all results as an array.
711+
-/
712+
@[inline, specialize]
713+
def concurrentlyAll [Monad m] [MonadAwait t m] [MonadAsync t m] (xs : Array (m α)) : m (Array α) := do
714+
let tasks : Array (t α) ← xs.mapM async
715+
tasks.mapM await
716+
717+
/--
718+
Runs all computations concurrently and returns the result of the first one to finish.
719+
All other results are discarded.
720+
-/
721+
@[inline, specialize]
722+
def raceAll [ForM m c (m α)] [MonadLiftT BaseIO m] [MonadAwait Task m] [MonadAsync t m] [MonadAwait t m] [Monad m] [Inhabited α] (xs : c) : m α := do
723+
let promise ← IO.Promise.new
724+
725+
ForM.forM xs fun x =>
726+
discard (async (n := t) <| Bind.bind x (liftM ∘ promise.resolve))
727+
728+
await promise.result!
649729

650730
end Async
651731
end IO
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
import Std.Internal.Async
2+
3+
open Std.Internal.IO.Async
4+
5+
def wait (ms : Nat) (ref : IO.Ref Nat) (val : Nat) : Async Unit := do
6+
ref.modify (· * val)
7+
IO.sleep ms
8+
ref.modify (· + val)
9+
10+
-- Tests
11+
12+
def sequential : Async Unit := do
13+
let ref ← IO.mkRef 0
14+
wait 200 ref 1
15+
wait 400 ref 2
16+
ref.modify (· * 10)
17+
assert! (← ref.get) == 40
18+
19+
#eval do (← sequential.toRawEIO).wait
20+
21+
def conc : Async Unit := do
22+
let ref ← IO.mkRef 0
23+
discard <| concurrently (wait 200 ref 1) (wait 400 ref 2)
24+
ref.modify (· * 10)
25+
assert! (← ref.get) == 30
26+
27+
#eval do (← conc.toRawEIO).wait
28+
29+
def racer : Async Unit := do
30+
let ref ← IO.mkRef 0
31+
race (wait 200 ref 1) (wait 400 ref 2)
32+
ref.modify (· * 10)
33+
assert! (← ref.get) == 10
34+
35+
#eval do (← racer.toRawEIO).wait
36+
37+
def concAll : Async Unit := do
38+
let ref ← IO.mkRef 0
39+
discard <| concurrentlyAll #[(wait 200 ref 1), (wait 400 ref 2)]
40+
ref.modify (· * 10)
41+
assert! (← ref.get) == 30
42+
43+
#eval do (← concAll.toRawEIO).wait
44+
45+
def racerAll : Async Unit := do
46+
let ref ← IO.mkRef 0
47+
raceAll #[(wait 200 ref 1), (wait 400 ref 2)]
48+
ref.modify (· * 10)
49+
assert! (← ref.get) == 10
50+
51+
#eval do (← racerAll.toRawEIO).wait

0 commit comments

Comments
 (0)