Skip to content

Commit 25a9d99

Browse files
committed
Implement simple atomic stream select
for #577
1 parent 57ace76 commit 25a9d99

File tree

5 files changed

+95
-6
lines changed

5 files changed

+95
-6
lines changed

lib_eio/stream.ml

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,45 @@ module Locking = struct
9494
Mutex.unlock t.mutex;
9595
Some v
9696

97+
let select_of_many streams_fns =
98+
let finished = Atomic.make false in
99+
let cancel_fns = ref [] in
100+
let add_cancel_fn fn = cancel_fns := fn :: !cancel_fns in
101+
let cancel_all () = List.iter (fun fn -> fn ()) !cancel_fns in
102+
let wait ctx enqueue (t, f) = begin
103+
Mutex.lock t.mutex;
104+
(* First check if any items are already available and return early if there are. *)
105+
if not (Queue.is_empty t.items)
106+
then (
107+
cancel_all ();
108+
Mutex.unlock t.mutex;
109+
enqueue (Ok (f (Queue.take t.items))))
110+
else add_cancel_fn @@
111+
(* Otherwise, register interest in this stream. *)
112+
Waiters.cancellable_await_internal ~mutex:(Some t.mutex) t.readers t.id ctx (fun r ->
113+
if Result.is_ok r then (
114+
if not (Atomic.compare_and_set finished false true) then (
115+
(* Another stream has yielded an item in the meantime. However, as
116+
we have been waiting on this stream it must have been empty.
117+
118+
As the stream's mutex was held since before last checking for an item,
119+
the queue must be empty.
120+
*)
121+
assert ((Queue.length t.items) < t.capacity);
122+
Queue.add (Result.get_ok r) t.items
123+
) else (
124+
(* remove all other entries of this fiber in other streams' waiters. *)
125+
cancel_all ()
126+
));
127+
(* item is returned to waiting caller through enqueue and enter_unchecked. *)
128+
enqueue (Result.map f r))
129+
end in
130+
(* Register interest in all streams and return first available item. *)
131+
let wait_for_stream streams_fns = begin
132+
Suspend.enter_unchecked (fun ctx enqueue -> List.iter (wait ctx enqueue) streams_fns)
133+
end in
134+
wait_for_stream streams_fns
135+
97136
let length t =
98137
Mutex.lock t.mutex;
99138
let len = Queue.length t.items in
@@ -125,6 +164,13 @@ let take_nonblocking = function
125164
| Sync x -> Sync.take_nonblocking x
126165
| Locking x -> Locking.take_nonblocking x
127166

167+
let select streams =
168+
let filter s = match s with
169+
| (Sync _, _) -> assert false
170+
| (Locking x, f) -> (x, f)
171+
in
172+
Locking.select_of_many (List.map filter streams)
173+
128174
let length = function
129175
| Sync _ -> 0
130176
| Locking x -> Locking.length x

lib_eio/stream.mli

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,10 @@ val take_nonblocking : 'a t -> 'a option
4040
Note that if another domain may add to the stream then a [None]
4141
result may already be out-of-date by the time this returns. *)
4242

43+
val select : 'a 'b. ('a t * ('a -> 'b)) list -> 'b
44+
(** [select] returns the first item yielded by any stream. This only
45+
works for streams with non-zero capacity. *)
46+
4347
val length : 'a t -> int
4448
(** [length t] returns the number of items currently in [t]. *)
4549

lib_eio/waiters.ml

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,11 +38,12 @@ let rec wake_one t v =
3838

3939
let is_empty = Lwt_dllist.is_empty
4040

41-
let await_internal ~mutex (t:'a t) id ctx enqueue =
41+
let cancellable_await_internal ~mutex (t:'a t) id ctx enqueue =
4242
match Fiber_context.get_error ctx with
4343
| Some ex ->
4444
Option.iter Mutex.unlock mutex;
45-
enqueue (Error ex)
45+
enqueue (Error ex);
46+
fun () -> ()
4647
| None ->
4748
let resolved_waiter = ref Hook.null in
4849
let finished = Atomic.make false in
@@ -56,14 +57,24 @@ let await_internal ~mutex (t:'a t) id ctx enqueue =
5657
enqueue (Error ex)
5758
)
5859
in
60+
let unwait () =
61+
if Atomic.compare_and_set finished false true
62+
then Hook.remove !resolved_waiter
63+
in
5964
Fiber_context.set_cancel_fn ctx cancel;
6065
let waiter = { enqueue; finished } in
6166
match mutex with
6267
| None ->
63-
resolved_waiter := add_waiter t waiter
68+
resolved_waiter := add_waiter t waiter;
69+
unwait
6470
| Some mutex ->
6571
resolved_waiter := add_waiter_protected ~mutex t waiter;
66-
Mutex.unlock mutex
72+
Mutex.unlock mutex;
73+
unwait
74+
75+
let await_internal ~mutex (t: 'a t) id ctx enqueue =
76+
let _cancel = (cancellable_await_internal ~mutex t id ctx enqueue) in
77+
()
6778

6879
(* Returns a result if the wait succeeds, or raises if cancelled. *)
6980
let await ~mutex waiters id =

lib_eio/waiters.mli

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,8 @@ val await :
2727
If [t] can be used from multiple domains:
2828
- [mutex] must be set to the mutex to use to unlock it.
2929
- [mutex] must be already held when calling this function, which will unlock it before blocking.
30-
When [await] returns, [mutex] will have been unlocked.
31-
@raise Cancel.Cancelled if the fiber's context is cancelled *)
30+
When [await] returns, [mutex] will have been unlocked.
31+
@raise Cancel.Cancelled if the fiber's context is cancelled *)
3232

3333
val await_internal :
3434
mutex:Mutex.t option ->
@@ -40,3 +40,12 @@ val await_internal :
4040
Note: [enqueue] is called from the triggering domain,
4141
which is currently calling {!wake_one} or {!wake_all}
4242
and must therefore be holding [mutex]. *)
43+
44+
val cancellable_await_internal :
45+
mutex:Mutex.t option ->
46+
'a t -> Ctf.id -> Fiber_context.t ->
47+
(('a, exn) result -> unit) -> (unit -> unit)
48+
(** Like [await_internal], but returns a function which, when called,
49+
removes the current fiber continuation from the waiters list.
50+
This is used when a fiber is waiting for multiple [Waiter]s simultaneously,
51+
and needs to remove itself from other waiters once it has been enqueued by one.*)

tests/stream.md

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -357,3 +357,22 @@ Non-blocking take with zero-capacity stream:
357357
+Got None from stream
358358
- : unit = ()
359359
```
360+
361+
Selecting from multiple channels:
362+
363+
```ocaml
364+
# run @@ fun () -> Switch.run (fun sw ->
365+
let t1, t2 = (S.create 2), (S.create 2) in
366+
let selector = [(t1, fun x -> x); (t2, fun x -> x)] in
367+
Fiber.fork ~sw (fun () -> S.add t2 "Hello");
368+
Fiber.fork ~sw (fun () -> traceln "%s" (S.select selector));
369+
Fiber.fork ~sw (fun () -> traceln "%s" (S.select selector));
370+
Fiber.fork ~sw (fun () -> traceln "%s" (S.select selector));
371+
Fiber.fork ~sw (fun () -> S.add t2 "Hello");
372+
Fiber.fork ~sw (fun () -> S.add t1 "World");
373+
)
374+
+Hello
375+
+Hello
376+
+World
377+
- : unit = ()
378+
```

0 commit comments

Comments
 (0)