@@ -11,15 +11,9 @@ public import Init.System.Promise
1111public import Init.Data.Queue
1212public import Std.Sync.Mutex
1313public import Std.Internal.Async.Select
14- public import Std.Internal.Async.IO
1514
1615public section
1716
18- namespace Std
19-
20- open Std.Internal.Async.IO
21- open Std.Internal.IO.Async
22-
2317/-!
2418The `Std.Sync.Broadcast` module implements a broadcasting primitive for sending values
2519to multiple consumers. It maintains a queue of values and supports both synchronous
@@ -29,6 +23,8 @@ This module is heavily inspired by `Std.Sync.Channel` as well as
2923[ tokio’s broadcast implementation ] (https://github.com/tokio-rs/tokio/blob/master/tokio/src/sync/broadcast.rs).
3024-/
3125
26+ namespace Std
27+
3228/--
3329Errors that may be thrown while interacting with the broadcast channel API.
3430-/
@@ -333,18 +329,11 @@ private def unsubscribe (bd : Bounded.Receiver α) : IO Unit := do
333329 let st ← get
334330
335331 let some next := st.receivers.get? bd.id
336- | return Except.error Broadcast.Error.notSubscribed
332+ | return Except.error Broadcast.Error.noSubscribers
337333
338- let mut currentSt := st
339- let mut currentNext := next
334+ discard <| getValueByPosition next
340335
341- while currentNext < currentSt.pos ∧ currentSt.size > 0 do
342- let some _val ← getValueByPosition currentNext | break
343-
344- currentSt ← get
345- currentNext := currentNext + 1
346-
347- set { currentSt with receivers := currentSt.receivers.erase bd.id }
336+ set { st with receivers := st.receivers.erase bd.id }
348337
349338 pure <| .ok ()
350339
@@ -358,24 +347,21 @@ private def tryRecv'
358347 let next := st.receivers.get! receiverId
359348
360349 if let some val ← getValueByPosition next then
361- modify ({ · with receivers := st.receivers.modify receiverId (· + 1 ) })
350+ set { st with receivers := st.receivers.modify receiverId (· + 1 ) }
362351 return some val
363352 else
364353 return none
365354
366355private def tryRecv (ch : Bounded.Receiver α) : BaseIO (Option α) :=
367356 ch.state.atomically do
368- if (← get).closed ∨ ¬(← get).receivers.contains ch.id then
369- return none
370-
371357 tryRecv' ch.id
372358
373359private partial def recv (ch : Bounded.Receiver α) : BaseIO (Task (Option α)) := do
374360 ch.state.atomically do
375- if (← get).closed ∨ ¬(← get).receivers.contains ch.id then
376- return .pure none
377- else if let some val ← tryRecv' ch.id then
361+ if let some val ← tryRecv' ch.id then
378362 return .pure <| some val
363+ else if (← get).closed then
364+ return .pure <| none
379365 else
380366 let promise ← IO.Promise.new
381367 modify fun st => { st with waiters := st.waiters.enqueue ⟨promise, none⟩ }
@@ -506,13 +492,6 @@ Subscribes a new `Receiver` from the `Broadcast` channel.
506492def subscribe (ch : Broadcast α) : IO (Broadcast.Receiver α) := do
507493 Broadcast.Receiver.mk <$> ch.inner.subscribe
508494
509- /--
510- Closes a `Broadcast` channel.
511- -/
512- @[inline]
513- def close (ch : Broadcast α) : IO Unit := do
514- ch.inner.close
515-
516495/--
517496Send a value through the broadcast channel, returning a task that will resolve once the transmission
518497could be completed.
@@ -538,8 +517,11 @@ def tryRecv (ch : Broadcast.Receiver α) : BaseIO (Option α) :=
538517Receive a value from the broadcast receiver, returning a task that will resolve with
539518the next available message. This will block until a message is available.
540519-/
541- def recv [Inhabited α] (ch : Broadcast.Receiver α) : BaseIO (Task (Option α)) := do
542- Std.Bounded.Receiver.recv ch.inner
520+ def recv [Inhabited α] (ch : Broadcast.Receiver α) : BaseIO (Task α) := do
521+ BaseIO.bindTask (sync := true ) (← Std.Bounded.Receiver.recv ch.inner)
522+ fun
523+ | some val => return .pure val
524+ | none => unreachable!
543525
544526open Internal.IO.Async in
545527
@@ -581,18 +563,6 @@ partial def forAsync (f : α → BaseIO Unit) (ch : Broadcast.Receiver α)
581563 (prio : Task.Priority := .default) : BaseIO (Task Unit) := do
582564 ch.inner.forAsync f prio
583565
584- instance [Inhabited α] : AsyncStream (Broadcast.Receiver α) α where
585- next channel := channel.recvSelector
586- stop channel := channel.unsubscribe
587-
588- instance [Inhabited α] : AsyncRead (Broadcast.Receiver α) (Option α) where
589- read receiver := Internal.IO.Async.Async.ofIOTask receiver.recv
590-
591- instance [Inhabited α] : AsyncWrite (Broadcast α) α where
592- write receiver x := do
593- let task ← receiver.send x
594- discard <| Async.ofETask <| task
595-
596566end Receiver
597567
598568/--
@@ -638,18 +608,15 @@ def tryRecv (ch : Sync.Receiver α) : BaseIO (Option α) := Broadcast.Receiver.t
638608/--
639609Receive a value from the channel, blocking until the transmission could be completed.
640610-/
641- def recv [Inhabited α] (ch : Sync.Receiver α) : BaseIO (Option α) := do
611+ def recv [Inhabited α] (ch : Sync.Receiver α) : BaseIO α := do
642612 IO.wait (← Broadcast.Receiver.recv ch)
643613
644614partial def forIn [Inhabited α] [Monad m] [MonadLiftT BaseIO m]
645615 (ch : Sync.Receiver α) (f : α → β → m (ForInStep β)) : β → m β := fun b => do
646616 let a ← ch.recv
647- match a with
648- | none => pure b
649- | some a =>
650- match ← f a b with
651- | .done b => pure b
652- | .yield b => ch.forIn f b
617+ match ← f a b with
618+ | .done b => pure b
619+ | .yield b => ch.forIn f b
653620
654621/-- `for msg in ch.sync do ...` receives all messages in the channel until it is closed. -/
655622instance [Inhabited α] [MonadLiftT BaseIO m] : ForIn m (Sync.Receiver α) α where
0 commit comments