Skip to content

Commit dd6d97d

Browse files
committed
Allow suspending endpoints
1 parent 3a4dcc7 commit dd6d97d

File tree

7 files changed

+349
-4
lines changed

7 files changed

+349
-4
lines changed

wl-proxy/src/client.rs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,4 +86,18 @@ impl Client {
8686
self.handler.set(None);
8787
self.state.remove_endpoint(&self.endpoint);
8888
}
89+
90+
/// Suspends or unsuspends dispatching messages from the client.
91+
///
92+
/// Suspending takes effect immediately. That is, if this is called from within a
93+
/// message handler, no further messages from the client will be dispatched until it
94+
/// is unsuspended.
95+
///
96+
/// This can be useful in situations where one clients needs to synchronize with
97+
/// another. For example, when a client sends `wl_surface.commit` and another client
98+
/// needs to take some action before the commit is forwarded to the server.
99+
pub fn set_suspended(self: &Rc<Self>, suspended: bool) {
100+
self.state
101+
.set_endpoint_suspended(&self.endpoint, Some(self), suspended);
102+
}
89103
}

wl-proxy/src/endpoint.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,9 @@ pub(crate) struct Endpoint {
3232
pub(crate) current_interest: Cell<u32>,
3333
pub(crate) desired_interest: Cell<u32>,
3434
pub(crate) interest_update_queued: Cell<bool>,
35+
pub(crate) suspended: Cell<bool>,
36+
pub(crate) desired_suspended: Cell<bool>,
37+
pub(crate) unsuspend_queued: Cell<bool>,
3538
incoming: RefCell<InputState>,
3639
}
3740

@@ -103,6 +106,9 @@ impl Endpoint {
103106
current_interest: Default::default(),
104107
desired_interest: Default::default(),
105108
interest_update_queued: Default::default(),
109+
suspended: Default::default(),
110+
desired_suspended: Default::default(),
111+
unsuspend_queued: Default::default(),
106112
incoming: Default::default(),
107113
})
108114
}
@@ -128,6 +134,9 @@ impl Endpoint {
128134
let fds = &mut incoming.fds;
129135
let mut may_read_from_socket = true;
130136
loop {
137+
if self.suspended.get() {
138+
break;
139+
}
131140
if let Some(client) = client
132141
&& client.destroyed.get()
133142
{

wl-proxy/src/poll.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,6 @@ impl Poller {
129129
.map_err(|e| PollError::Update(e.into()))
130130
}
131131

132-
#[cfg_attr(not(test), expect(dead_code))]
133132
pub(crate) fn register_edge_triggered(
134133
&self,
135134
id: u64,

wl-proxy/src/state.rs

Lines changed: 81 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ use {
2323
cell::{Cell, RefCell},
2424
collections::HashMap,
2525
io::{self, pipe},
26-
os::fd::{AsFd, OwnedFd},
26+
os::fd::{AsFd, AsRawFd, OwnedFd},
2727
rc::{Rc, Weak},
2828
sync::{
2929
Arc,
@@ -91,6 +91,8 @@ enum StateErrorKind {
9191
WaylandSocketSetFd(#[source] io::Error),
9292
#[error(transparent)]
9393
PollError(PollError),
94+
#[error("Could not create an eventfd")]
95+
CreateEventfd(#[source] io::Error),
9496
}
9597

9698
/// The proxy state.
@@ -177,6 +179,10 @@ pub struct State {
177179
pub(crate) object_stash: Stash<Rc<dyn Object>>,
178180
pub(crate) forward_to_client: Cell<bool>,
179181
pub(crate) forward_to_server: Cell<bool>,
182+
unsuspend_fd: OwnedFd,
183+
unsuspend_requests: Stack<EndpointWithClient>,
184+
has_unsuspend_requests: Cell<bool>,
185+
unsuspend_triggered: Cell<bool>,
180186
}
181187

182188
/// A handler for events emitted by a [`State`].
@@ -213,6 +219,7 @@ enum Pollable {
213219
Endpoint(EndpointWithClient),
214220
Acceptor(Rc<Acceptor>),
215221
Destructor(OwnedFd, Arc<AtomicBool>),
222+
Unsuspend,
216223
}
217224

218225
#[derive(Clone)]
@@ -302,6 +309,24 @@ impl State {
302309
Ok(true)
303310
}
304311

312+
fn unsuspend_endpoints(self: &Rc<Self>, _lock: &HandlerLock<'_>) -> Result<(), StateError> {
313+
if !self.has_unsuspend_requests.get() {
314+
return Ok(());
315+
}
316+
self.check_destroyed()?;
317+
while let Some(ewc) = self.unsuspend_requests.pop() {
318+
ewc.endpoint.unsuspend_queued.set(false);
319+
if ewc.endpoint.desired_suspended.get() {
320+
continue;
321+
}
322+
ewc.endpoint.suspended.set(false);
323+
self.readable_endpoints.push(ewc);
324+
self.has_readable_endpoints.set(true);
325+
}
326+
self.has_unsuspend_requests.set(false);
327+
Ok(())
328+
}
329+
305330
fn accept_connections(self: &Rc<Self>, lock: &HandlerLock<'_>) -> Result<bool, StateError> {
306331
if !self.has_acceptable_acceptors.get() {
307332
return Ok(false);
@@ -346,12 +371,50 @@ impl State {
346371
return Err(StateErrorKind::DispatchEvents(e).into());
347372
}
348373
}
349-
self.change_interest(&ewc.endpoint, |i| i | poll::READABLE);
374+
if !ewc.endpoint.suspended.get() {
375+
self.change_interest(&ewc.endpoint, |i| i | poll::READABLE);
376+
}
350377
}
351378
self.has_readable_endpoints.set(false);
352379
Ok(true)
353380
}
354381

382+
pub(crate) fn set_endpoint_suspended(
383+
&self,
384+
endpoint: &Rc<Endpoint>,
385+
client: Option<&Rc<Client>>,
386+
suspended: bool,
387+
) {
388+
if self.destroyed.get() {
389+
return;
390+
}
391+
if suspended {
392+
endpoint.suspended.set(true);
393+
endpoint.desired_suspended.set(true);
394+
return;
395+
}
396+
endpoint.desired_suspended.set(false);
397+
if endpoint.unsuspend_queued.get() {
398+
return;
399+
}
400+
if !self.unsuspend_triggered.get() {
401+
if let Err(e) = uapi::eventfd_write(self.unsuspend_fd.as_raw_fd(), 1) {
402+
log::error!(
403+
"Could not write to eventfd: {}",
404+
Report::new(io::Error::from(e)),
405+
);
406+
self.destroy();
407+
return;
408+
}
409+
self.unsuspend_triggered.set(true);
410+
}
411+
self.unsuspend_requests.push(EndpointWithClient {
412+
endpoint: endpoint.clone(),
413+
client: client.cloned(),
414+
});
415+
endpoint.unsuspend_queued.set(true);
416+
}
417+
355418
fn change_interest(&self, endpoint: &Rc<Endpoint>, f: impl FnOnce(u32) -> u32) {
356419
if self.destroyed.get() {
357420
return;
@@ -435,6 +498,10 @@ impl State {
435498
return Err(StateErrorKind::RemoteDestroyed.into());
436499
}
437500
}
501+
Pollable::Unsuspend => {
502+
self.has_unsuspend_requests.set(true);
503+
self.unsuspend_triggered.set(false);
504+
}
438505
}
439506
}
440507
}
@@ -561,12 +628,22 @@ impl State {
561628
did_work |= self.flush_locked(&lock)?;
562629
}
563630
self.wait_for_work(&lock, timeout)?;
631+
self.unsuspend_endpoints(&lock)?;
564632
did_work |= self.accept_connections(&lock)?;
565633
did_work |= self.read_messages(&lock)?;
566634
did_work |= self.flush_locked(&lock)?;
567635
destroy_on_error.forget();
568636
Ok(did_work)
569637
}
638+
639+
/// Suspends or unsuspends dispatching messages from the server.
640+
///
641+
/// See also [`Client::set_suspended`].
642+
pub fn set_suspended(&self, suspended: bool) {
643+
if let Some(endpoint) = &self.server {
644+
self.set_endpoint_suspended(endpoint, None, suspended);
645+
}
646+
}
570647
}
571648

572649
impl State {
@@ -832,6 +909,7 @@ impl State {
832909
}
833910
Pollable::Acceptor(a) => &a.socket,
834911
Pollable::Destructor(fd, _) => fd,
912+
Pollable::Unsuspend => &self.unsuspend_fd,
835913
};
836914
self.poller.unregister(fd.as_fd());
837915
}
@@ -853,6 +931,7 @@ impl State {
853931
self.flushable_endpoints.take();
854932
self.interest_update_endpoints.take();
855933
self.interest_update_acceptors.take();
934+
self.unsuspend_requests.take();
856935
self.all_objects.borrow_mut().clear();
857936
// Ensure that the poll fd stays permanently readable.
858937
let _ = self.create_remote_destructor();

wl-proxy/src/state/builder.rs

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ enum Server {
4242
#[derive(Copy, Clone, Linearize)]
4343
pub(crate) enum StaticPollableIds {
4444
Server,
45+
Unsuspend,
4546
}
4647

4748
impl StateBuilder {
@@ -136,6 +137,10 @@ impl StateBuilder {
136137
);
137138
server = Some(s);
138139
}
140+
let unsuspend_fd = uapi::eventfd(0, c::EFD_CLOEXEC | c::EFD_NONBLOCK)
141+
.map(Into::into)
142+
.map_err(|e| StateErrorKind::CreateEventfd(e.into()))?;
143+
endpoints.insert(StaticPollableIds::Unsuspend as u64, Pollable::Unsuspend);
139144
let poller = Poller::new().map_err(StateErrorKind::PollError)?;
140145
#[cfg(feature = "logging")]
141146
let log_prefix = {
@@ -190,6 +195,10 @@ impl StateBuilder {
190195
object_stash: Default::default(),
191196
forward_to_client: Cell::new(true),
192197
forward_to_server: Cell::new(true),
198+
unsuspend_fd,
199+
unsuspend_requests: Default::default(),
200+
has_unsuspend_requests: Default::default(),
201+
unsuspend_triggered: Default::default(),
193202
});
194203
if let Some(server) = &state.server {
195204
state.change_interest(server, |i| i | poll::READABLE);
@@ -203,6 +212,14 @@ impl StateBuilder {
203212
.set_server_id_unchecked(1, display.clone())
204213
.unwrap();
205214
}
215+
state
216+
.poller
217+
.register_edge_triggered(
218+
StaticPollableIds::Unsuspend as u64,
219+
state.unsuspend_fd.as_fd(),
220+
poll::READABLE,
221+
)
222+
.map_err(StateErrorKind::PollError)?;
206223
Ok(state)
207224
}
208225

0 commit comments

Comments
 (0)