Skip to content

Commit f9adafe

Browse files
feat: adds acceptSelector and modified selectors (#10667)
This PR adds more selectors for TCP and Signals. It also fixes a problem with `Selectors` that they cannot be closures over a promise, otherwise it causes the waiter promise to never be dropped.
1 parent 69d8d63 commit f9adafe

File tree

15 files changed

+230
-51
lines changed

15 files changed

+230
-51
lines changed

src/Std/Internal/Async/Signal.lean

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -238,23 +238,26 @@ def stop (s : Signal.Waiter) : IO Unit :=
238238
s.native.stop
239239

240240
/--
241-
Create a `Selector` that resolves once `s` has received the signal. Note that calling this function starts `s`
242-
if it hasn't already started.
241+
Create a `Selector` that resolves once `s` has received the signal. Note that calling this function
242+
does not start the signal waiter.
243243
-/
244-
def selector (s : Signal.Waiter) : IO (Selector Unit) := do
245-
let signalWaiter ← s.wait
246-
return {
244+
def selector (s : Signal.Waiter) : Selector Unit :=
245+
{
247246
tryFn := do
247+
let signalWaiter : AsyncTask _ ← async s.wait
248248
if ← IO.hasFinished signalWaiter then
249249
return some ()
250250
else
251+
s.native.cancel
251252
return none
252253

253254
registerFn waiter := do
255+
let signalWaiter ← s.wait
254256
discard <| AsyncTask.mapIO (x := signalWaiter) fun _ => do
255257
let lose := return ()
256258
let win promise := promise.resolve (.ok ())
257259
waiter.race lose win
258260

259-
unregisterFn := s.stop
261+
unregisterFn := s.native.cancel
262+
260263
}

src/Std/Internal/Async/TCP.lean

Lines changed: 48 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,44 @@ def accept (s : Server) : Async Client := do
6969
|> Async.ofPromise
7070
|>.map Client.ofNative
7171

72+
/--
73+
Tries to accept an incoming connection.
74+
-/
75+
@[inline]
76+
def tryAccept (s : Server) : IO (Option Client) := do
77+
let res ← s.native.tryAccept
78+
let socket ← IO.ofExcept res
79+
return Client.ofNative <$> socket
80+
81+
/--
82+
Creates a `Selector` that resolves once `s` has a connection available. Calling this function
83+
does not start the connection wait, so it must not be called in parallel with `accept`.
84+
-/
85+
def acceptSelector (s : TCP.Socket.Server) : Selector Client :=
86+
{
87+
tryFn :=
88+
s.tryAccept
89+
90+
registerFn waiter := do
91+
let task ← s.native.accept
92+
93+
-- If we get cancelled the promise will be dropped so prepare for that
94+
IO.chainTask (t := task.result?) fun res => do
95+
match res with
96+
| none => return ()
97+
| some res =>
98+
let lose := return ()
99+
let win promise := do
100+
try
101+
let result ← IO.ofExcept res
102+
promise.resolve (.ok (Client.ofNative result))
103+
catch e =>
104+
promise.resolve (.error e)
105+
waiter.race lose win
106+
107+
unregisterFn := s.native.cancelAccept
108+
}
109+
72110
/--
73111
Gets the local address of the server socket.
74112
-/
@@ -143,20 +181,25 @@ def recv? (s : Client) (size : UInt64) : Async (Option ByteArray) :=
143181

144182
/--
145183
Creates a `Selector` that resolves once `s` has data available, up to at most `size` bytes,
146-
and provides that data. Calling this function starts the data wait, so it must not be called
184+
and provides that data. Calling this function does not starts the data wait, so it must not be called
147185
in parallel with `recv?`.
148186
-/
149-
def recvSelector (s : TCP.Socket.Client) (size : UInt64) : Async (Selector (Option ByteArray)) := do
150-
let readableWaiter ← s.native.waitReadable
151-
return {
187+
def recvSelector (s : TCP.Socket.Client) (size : UInt64) : Selector (Option ByteArray) :=
188+
{
152189
tryFn := do
190+
let readableWaiter ← s.native.waitReadable
191+
153192
if ← readableWaiter.isResolved then
154193
-- We know that this read should not block
155194
let res ← (s.recv? size).block
156195
return some res
157196
else
197+
s.native.cancelRecv
158198
return none
199+
159200
registerFn waiter := do
201+
let readableWaiter ← s.native.waitReadable
202+
160203
-- If we get cancelled the promise will be dropped so prepare for that
161204
discard <| IO.mapTask (t := readableWaiter.result?) fun res => do
162205
match res with
@@ -172,6 +215,7 @@ def recvSelector (s : TCP.Socket.Client) (size : UInt64) : Async (Selector (Opti
172215
catch e =>
173216
promise.resolve (.error e)
174217
waiter.race lose win
218+
175219
unregisterFn := s.native.cancelRecv
176220
}
177221

src/Std/Internal/Async/Timer.lean

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -69,17 +69,18 @@ def stop (s : Sleep) : IO Unit :=
6969
s.native.stop
7070

7171
/--
72-
Create a `Selector` that resolves once `s` has finished. Note that calling this function starts `s`
73-
if it hasn't already started.
72+
Create a `Selector` that resolves once `s` has finished. `s` only starts when it runs inside of a Selectable.
7473
-/
75-
def selector (s : Sleep) : Async (Selector Unit) := do
76-
return {
74+
def selector (s : Sleep) : Selector Unit :=
75+
{
7776
tryFn := do
7877
let sleepWaiter ← s.native.next
7978
if ← sleepWaiter.isResolved then
8079
return some ()
8180
else
81+
s.native.cancel
8282
return none
83+
8384
registerFn waiter := do
8485
let sleepWaiter ← s.native.next
8586
BaseIO.chainTask sleepWaiter.result? fun
@@ -107,7 +108,7 @@ Return a `Selector` that completes after `duration`.
107108
-/
108109
def Selector.sleep (duration : Std.Time.Millisecond.Offset) : Async (Selector Unit) := do
109110
let sleeper ← Sleep.mk duration
110-
sleeper.selector
111+
return sleeper.selector
111112

112113
/--
113114
`Interval` can be used to repeatedly wait for some duration like a clock.

src/Std/Internal/Async/UDP.lean

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -91,20 +91,24 @@ def recv (s : Socket) (size : UInt64) : Async (ByteArray × Option SocketAddress
9191
Creates a `Selector` that resolves once `s` has data available, up to at most `size` bytes,
9292
and provides that data. If the socket has not been previously bound with `bind`, it is
9393
automatically bound to `0.0.0.0` (all interfaces) with a random port.
94-
Calling this function starts the data wait, so it must not be called in parallel with `recv`.
94+
Calling this function does starts the data wait, only when it's used with `Selectable.one` or `combine`.
95+
It must not be called in parallel with `recv`.
9596
-/
96-
def recvSelector (s : Socket) (size : UInt64) :
97-
Async (Selector (ByteArray × Option SocketAddress)) := do
98-
let readableWaiter ← s.native.waitReadable
99-
return {
97+
def recvSelector (s : Socket) (size : UInt64) : Selector (ByteArray × Option SocketAddress) :=
98+
{
10099
tryFn := do
100+
let readableWaiter ← s.native.waitReadable
101+
101102
if ← readableWaiter.isResolved then
102103
-- We know that this read should not block
103104
let res ← (s.recv size).block
104105
return some res
105106
else
107+
s.native.cancelRecv
106108
return none
107109
registerFn waiter := do
110+
let readableWaiter ← s.native.waitReadable
111+
108112
-- If we get cancelled the promise will be dropped so prepare for that
109113
discard <| IO.mapTask (t := readableWaiter.result?) fun res => do
110114
match res with
@@ -120,6 +124,7 @@ def recvSelector (s : Socket) (size : UInt64) :
120124
catch e =>
121125
promise.resolve (.error e)
122126
waiter.race lose win
127+
123128
unregisterFn := s.native.cancelRecv
124129
}
125130

src/Std/Internal/UV/Signal.lean

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,15 @@ This function has different behavior depending on the state of the `Signal`:
7777
@[extern "lean_uv_signal_stop"]
7878
opaque stop (signal : @& Signal) : IO Unit
7979

80+
/--
81+
This function has different behavior depending on the state of the `Signal`:
82+
- If it is initial or finished this is a no-op.
83+
- If it's running then it drops the accept promise and if it's not repeatable it sets
84+
the signal handler to the initial state.
85+
-/
86+
@[extern "lean_uv_signal_cancel"]
87+
opaque cancel (signal : @& Signal) : IO Unit
88+
8089
end Signal
8190

8291
end UV

src/Std/Internal/UV/TCP.lean

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,18 @@ Accepts an incoming connection on a listening TCP socket.
9595
@[extern "lean_uv_tcp_accept"]
9696
opaque accept (socket : @& Socket) : IO (IO.Promise (Except IO.Error Socket))
9797

98+
/--
99+
Tries to accept an incoming connection on a listening TCP socket.
100+
-/
101+
@[extern "lean_uv_tcp_try_accept"]
102+
opaque tryAccept (socket : @& Socket) : IO (Except IO.Error (Option Socket))
103+
104+
/--
105+
Cancels the accept request of a socket.
106+
-/
107+
@[extern "lean_uv_tcp_cancel_accept"]
108+
opaque cancelAccept (socket : @& Socket) : IO Unit
109+
98110
/--
99111
Shuts down an incoming connection on a listening TCP socket.
100112
-/

src/Std/Internal/UV/Timer.lean

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ opaque stop (timer : @& Timer) : IO Unit
8787
This function has different behavior depending on the state of the `Timer`:
8888
- If it is initial or finished this is a no-op.
8989
- If it is running, the promise generated by the `next` function is dropped.
90-
- If `repeating` is `false` then it sets the timer to the finished state.
90+
- If `repeating` is `false` then it sets the timer to the initial state.
9191
-/
9292
@[extern "lean_uv_timer_cancel"]
9393
opaque cancel (timer : @& Timer) : IO Unit

src/runtime/uv/signal.cpp

Lines changed: 58 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -39,29 +39,29 @@ void initialize_libuv_signal() {
3939
}
4040

4141
static bool signal_promise_is_finished(lean_uv_signal_object * signal) {
42-
return lean_io_get_task_state_core((lean_object *)lean_to_promise(signal->m_promise)->m_result) == 2;
42+
return signal->m_promise == NULL || lean_io_get_task_state_core((lean_object *)lean_to_promise(signal->m_promise)->m_result) == 2;
4343
}
4444

4545
void handle_signal_event(uv_signal_t* handle, int signum) {
4646
lean_object * obj = (lean_object*)handle->data;
4747
lean_uv_signal_object * signal = lean_to_uv_signal(obj);
4848

4949
lean_assert(signal->m_state == SIGNAL_STATE_RUNNING);
50-
lean_assert(signal->m_promise != NULL);
5150

5251
if (signal->m_repeating) {
5352
if (!signal_promise_is_finished(signal)) {
5453
lean_object* res = lean_io_promise_resolve(lean_box(signum), signal->m_promise, lean_io_mk_world());
5554
lean_dec(res);
5655
}
5756
} else {
58-
lean_assert(!signal_promise_is_finished(signal));
57+
if (signal->m_promise != NULL) {
58+
lean_object* res = lean_io_promise_resolve(lean_box(signum), signal->m_promise, lean_io_mk_world());
59+
lean_dec(res);
60+
}
61+
5962
uv_signal_stop(signal->m_uv_signal);
6063
signal->m_state = SIGNAL_STATE_FINISHED;
6164

62-
lean_object* res = lean_io_promise_resolve(lean_box(signum), signal->m_promise, lean_io_mk_world());
63-
lean_dec(res);
64-
6565
lean_dec(obj);
6666
}
6767
}
@@ -154,33 +154,37 @@ extern "C" LEAN_EXPORT lean_obj_res lean_uv_signal_next(b_obj_arg obj, obj_arg /
154154
}
155155
case SIGNAL_STATE_RUNNING:
156156
{
157-
lean_assert(signal->m_promise != NULL);
158-
// 2 indicates finished
159157
if (signal_promise_is_finished(signal)) {
160-
lean_dec(signal->m_promise);
158+
if (signal->m_promise != NULL) {
159+
lean_dec(signal->m_promise);
160+
}
161+
161162
signal->m_promise = create_promise();
162-
lean_inc(signal->m_promise);
163-
return lean_io_result_mk_ok(signal->m_promise);
164-
} else {
165-
lean_inc(signal->m_promise);
166-
return lean_io_result_mk_ok(signal->m_promise);
167163
}
164+
165+
lean_inc(signal->m_promise);
166+
return lean_io_result_mk_ok(signal->m_promise);
168167
}
169168
case SIGNAL_STATE_FINISHED:
170169
{
171-
lean_assert(signal->m_promise != NULL);
170+
if (signal->m_promise == NULL) {
171+
lean_object* finished_promise = create_promise();
172+
return lean_io_result_mk_ok(finished_promise);
173+
}
174+
172175
lean_inc(signal->m_promise);
173176
return lean_io_result_mk_ok(signal->m_promise);
174177
}
175178
}
176179
} else {
177180
if (signal->m_state == SIGNAL_STATE_INITIAL) {
178181
return setup_signal();
179-
} else {
180-
lean_assert(signal->m_promise != NULL);
181-
182+
} else if (signal->m_promise != NULL) {
182183
lean_inc(signal->m_promise);
183184
return lean_io_result_mk_ok(signal->m_promise);
185+
} else {
186+
lean_object* finished_promise = create_promise();
187+
return lean_io_result_mk_ok(finished_promise);
184188
}
185189
}
186190
}
@@ -190,12 +194,15 @@ extern "C" LEAN_EXPORT lean_obj_res lean_uv_signal_stop(b_obj_arg obj, obj_arg /
190194
lean_uv_signal_object * signal = lean_to_uv_signal(obj);
191195

192196
if (signal->m_state == SIGNAL_STATE_RUNNING) {
193-
lean_assert(signal->m_promise != NULL);
194-
195197
event_loop_lock(&global_ev);
196198
int result = uv_signal_stop(signal->m_uv_signal);
197199
event_loop_unlock(&global_ev);
198200

201+
if (signal->m_promise != NULL) {
202+
lean_dec(signal->m_promise);
203+
signal->m_promise = NULL;
204+
}
205+
199206
signal->m_state = SIGNAL_STATE_FINISHED;
200207

201208
// The loop does not need to keep the signal alive anymore.
@@ -211,6 +218,30 @@ extern "C" LEAN_EXPORT lean_obj_res lean_uv_signal_stop(b_obj_arg obj, obj_arg /
211218
}
212219
}
213220

221+
/* Std.Internal.UV.Signal.cancel (signal : @& Signal) : IO Unit */
222+
extern "C" LEAN_EXPORT lean_obj_res lean_uv_signal_cancel(b_obj_arg obj, obj_arg /* w */) {
223+
lean_uv_signal_object * signal = lean_to_uv_signal(obj);
224+
225+
// It's locking here to avoid changing the state during other operations.
226+
event_loop_lock(&global_ev);
227+
228+
if (signal->m_state == SIGNAL_STATE_RUNNING && signal->m_promise != NULL) {
229+
if (signal->m_repeating) {
230+
lean_dec(signal->m_promise);
231+
signal->m_promise = NULL;
232+
} else {
233+
uv_signal_stop(signal->m_uv_signal);
234+
lean_dec(signal->m_promise);
235+
signal->m_promise = NULL;
236+
signal->m_state = SIGNAL_STATE_INITIAL;
237+
238+
lean_dec(obj);
239+
}
240+
}
241+
242+
event_loop_unlock(&global_ev);
243+
return lean_io_result_mk_ok(lean_box(0));
244+
}
214245

215246
#else
216247

@@ -235,6 +266,13 @@ extern "C" LEAN_EXPORT lean_obj_res lean_uv_signal_stop(b_obj_arg signal, obj_ar
235266
);
236267
}
237268

269+
/* Std.Internal.UV.Signal.cancel (signal : @& Signal) : IO Unit */
270+
extern "C" LEAN_EXPORT lean_obj_res lean_uv_signal_cancel(b_obj_arg obj, obj_arg /* w */) {
271+
lean_always_assert(
272+
false && ("Please build a version of Lean4 with libuv to invoke this.")
273+
);
274+
}
275+
238276
#endif
239277

240278
}

src/runtime/uv/signal.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,5 +51,6 @@ static inline lean_uv_signal_object* lean_to_uv_signal(lean_object * o) { return
5151
extern "C" LEAN_EXPORT lean_obj_res lean_uv_signal_mk(uint32_t signum_obj, uint8_t repeating, obj_arg /* w */);
5252
extern "C" LEAN_EXPORT lean_obj_res lean_uv_signal_next(b_obj_arg signal, obj_arg /* w */);
5353
extern "C" LEAN_EXPORT lean_obj_res lean_uv_signal_stop(b_obj_arg signal, obj_arg /* w */);
54+
extern "C" LEAN_EXPORT lean_obj_res lean_uv_signal_cancel(b_obj_arg obj, obj_arg /* w */);
5455

5556
}

0 commit comments

Comments
 (0)