Skip to content

Commit fc8cd02

Browse files
committed
Merge branch 'sofia/sync-broadcast' of https://github.com/leanprover/lean4 into sofia/sync-broadcast
2 parents df99764 + f53c92a commit fc8cd02

File tree

1 file changed

+18
-51
lines changed

1 file changed

+18
-51
lines changed

src/Std/Sync/Broadcast.lean

Lines changed: 18 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -11,15 +11,9 @@ public import Init.System.Promise
1111
public import Init.Data.Queue
1212
public import Std.Sync.Mutex
1313
public import Std.Internal.Async.Select
14-
public import Std.Internal.Async.IO
1514

1615
public section
1716

18-
namespace Std
19-
20-
open Std.Internal.Async.IO
21-
open Std.Internal.IO.Async
22-
2317
/-!
2418
The `Std.Sync.Broadcast` module implements a broadcasting primitive for sending values
2519
to multiple consumers. It maintains a queue of values and supports both synchronous
@@ -29,6 +23,8 @@ This module is heavily inspired by `Std.Sync.Channel` as well as
2923
[tokio’s broadcast implementation](https://github.com/tokio-rs/tokio/blob/master/tokio/src/sync/broadcast.rs).
3024
-/
3125

26+
namespace Std
27+
3228
/--
3329
Errors that may be thrown while interacting with the broadcast channel API.
3430
-/
@@ -333,18 +329,11 @@ private def unsubscribe (bd : Bounded.Receiver α) : IO Unit := do
333329
let st ← get
334330

335331
let some next := st.receivers.get? bd.id
336-
| return Except.error Broadcast.Error.notSubscribed
332+
| return Except.error Broadcast.Error.noSubscribers
337333

338-
let mut currentSt := st
339-
let mut currentNext := next
334+
discard <| getValueByPosition next
340335

341-
while currentNext < currentSt.pos ∧ currentSt.size > 0 do
342-
let some _val ← getValueByPosition currentNext | break
343-
344-
currentSt ← get
345-
currentNext := currentNext + 1
346-
347-
set { currentSt with receivers := currentSt.receivers.erase bd.id }
336+
set { st with receivers := st.receivers.erase bd.id }
348337

349338
pure <| .ok ()
350339

@@ -358,24 +347,21 @@ private def tryRecv'
358347
let next := st.receivers.get! receiverId
359348

360349
if let some val ← getValueByPosition next then
361-
modify ({ · with receivers := st.receivers.modify receiverId (· + 1) })
350+
set { st with receivers := st.receivers.modify receiverId (· + 1) }
362351
return some val
363352
else
364353
return none
365354

366355
private def tryRecv (ch : Bounded.Receiver α) : BaseIO (Option α) :=
367356
ch.state.atomically do
368-
if (← get).closed ∨ ¬(← get).receivers.contains ch.id then
369-
return none
370-
371357
tryRecv' ch.id
372358

373359
private partial def recv (ch : Bounded.Receiver α) : BaseIO (Task (Option α)) := do
374360
ch.state.atomically do
375-
if (← get).closed ∨ ¬(← get).receivers.contains ch.id then
376-
return .pure none
377-
else if let some val ← tryRecv' ch.id then
361+
if let some val ← tryRecv' ch.id then
378362
return .pure <| some val
363+
else if (← get).closed then
364+
return .pure <| none
379365
else
380366
let promise ← IO.Promise.new
381367
modify fun st => { st with waiters := st.waiters.enqueue ⟨promise, none⟩ }
@@ -506,13 +492,6 @@ Subscribes a new `Receiver` from the `Broadcast` channel.
506492
def subscribe (ch : Broadcast α) : IO (Broadcast.Receiver α) := do
507493
Broadcast.Receiver.mk <$> ch.inner.subscribe
508494

509-
/--
510-
Closes a `Broadcast` channel.
511-
-/
512-
@[inline]
513-
def close (ch : Broadcast α) : IO Unit := do
514-
ch.inner.close
515-
516495
/--
517496
Send a value through the broadcast channel, returning a task that will resolve once the transmission
518497
could be completed.
@@ -538,8 +517,11 @@ def tryRecv (ch : Broadcast.Receiver α) : BaseIO (Option α) :=
538517
Receive a value from the broadcast receiver, returning a task that will resolve with
539518
the next available message. This will block until a message is available.
540519
-/
541-
def recv [Inhabited α] (ch : Broadcast.Receiver α) : BaseIO (Task (Option α)) := do
542-
Std.Bounded.Receiver.recv ch.inner
520+
def recv [Inhabited α] (ch : Broadcast.Receiver α) : BaseIO (Task α) := do
521+
BaseIO.bindTask (sync := true) (← Std.Bounded.Receiver.recv ch.inner)
522+
fun
523+
| some val => return .pure val
524+
| none => unreachable!
543525

544526
open Internal.IO.Async in
545527

@@ -581,18 +563,6 @@ partial def forAsync (f : α → BaseIO Unit) (ch : Broadcast.Receiver α)
581563
(prio : Task.Priority := .default) : BaseIO (Task Unit) := do
582564
ch.inner.forAsync f prio
583565

584-
instance [Inhabited α] : AsyncStream (Broadcast.Receiver α) α where
585-
next channel := channel.recvSelector
586-
stop channel := channel.unsubscribe
587-
588-
instance [Inhabited α] : AsyncRead (Broadcast.Receiver α) (Option α) where
589-
read receiver := Internal.IO.Async.Async.ofIOTask receiver.recv
590-
591-
instance [Inhabited α] : AsyncWrite (Broadcast α) α where
592-
write receiver x := do
593-
let task ← receiver.send x
594-
discard <| Async.ofETask <| task
595-
596566
end Receiver
597567

598568
/--
@@ -638,18 +608,15 @@ def tryRecv (ch : Sync.Receiver α) : BaseIO (Option α) := Broadcast.Receiver.t
638608
/--
639609
Receive a value from the channel, blocking until the transmission could be completed.
640610
-/
641-
def recv [Inhabited α] (ch : Sync.Receiver α) : BaseIO (Option α) := do
611+
def recv [Inhabited α] (ch : Sync.Receiver α) : BaseIO α := do
642612
IO.wait (← Broadcast.Receiver.recv ch)
643613

644614
partial def forIn [Inhabited α] [Monad m] [MonadLiftT BaseIO m]
645615
(ch : Sync.Receiver α) (f : α → β → m (ForInStep β)) : β → m β := fun b => do
646616
let a ← ch.recv
647-
match a with
648-
| none => pure b
649-
| some a =>
650-
match ← f a b with
651-
| .done b => pure b
652-
| .yield b => ch.forIn f b
617+
match ← f a b with
618+
| .done b => pure b
619+
| .yield b => ch.forIn f b
653620

654621
/-- `for msg in ch.sync do ...` receives all messages in the channel until it is closed. -/
655622
instance [Inhabited α] [MonadLiftT BaseIO m] : ForIn m (Sync.Receiver α) α where

0 commit comments

Comments
 (0)