From 2deae428fad139268f0e28a51d5133b2463a2fb3 Mon Sep 17 00:00:00 2001 From: itowlson Date: Tue, 26 Nov 2024 10:23:49 +1300 Subject: [PATCH] Add optional close/shutdown message to AsyncWriteStream Signed-off-by: itowlson --- crates/wasi/src/write_stream.rs | 63 +++++++++++++++++++++++++++++++-- 1 file changed, 61 insertions(+), 2 deletions(-) diff --git a/crates/wasi/src/write_stream.rs b/crates/wasi/src/write_stream.rs index fe3658662ff2..e394b8fd7d0d 100644 --- a/crates/wasi/src/write_stream.rs +++ b/crates/wasi/src/write_stream.rs @@ -9,6 +9,7 @@ struct WorkerState { items: std::collections::VecDeque, write_budget: usize, flush_pending: bool, + shutdown_pending: bool, error: Option, } @@ -31,6 +32,7 @@ struct Worker { } enum Job { + Shutdown, Flush, Write(Bytes), } @@ -43,6 +45,7 @@ impl Worker { items: std::collections::VecDeque::new(), write_budget, flush_pending: false, + shutdown_pending: false, error: None, }), new_work: tokio::sync::Notify::new(), @@ -55,7 +58,7 @@ impl Worker { let state = self.state(); if state.error.is_some() || !state.alive - || (!state.flush_pending && state.write_budget > 0) + || (!state.flush_pending && !state.shutdown_pending && state.write_budget > 0) { return; } @@ -69,7 +72,7 @@ impl Worker { return Err(e); } - if state.flush_pending || state.write_budget == 0 { + if state.flush_pending || state.shutdown_pending || state.write_budget == 0 { return Ok(0); } @@ -84,6 +87,9 @@ impl Worker { if state.flush_pending { return Some(Job::Flush); } + if state.shutdown_pending { + return Some(Job::Shutdown); + } } else if let Some(bytes) = state.items.pop_front() { return Some(Job::Write(bytes)); } @@ -96,6 +102,7 @@ impl Worker { state.alive = false; state.error = Some(e.into()); state.flush_pending = false; + state.shutdown_pending = false; } self.write_ready_changed.notify_one(); } @@ -114,6 +121,14 @@ impl Worker { self.state().flush_pending = false; } + Job::Shutdown => { + if let Err(e) = writer.shutdown().await { + self.report_error(e); + return; + } + self.state().shutdown_pending = false; + } + Job::Write(mut bytes) => { tracing::debug!("worker writing: {bytes:?}"); let len = bytes.len(); @@ -140,6 +155,7 @@ impl Worker { pub struct AsyncWriteStream { worker: Arc, join_handle: Option>, + shutdown_join_handle: Option, } impl AsyncWriteStream { @@ -157,6 +173,46 @@ impl AsyncWriteStream { AsyncWriteStream { worker, join_handle: Some(join_handle), + shutdown_join_handle: None, + } + } + + /// Create a [`AsyncWriteStream`]. In order to use the [`HostOutputStream`] impl + /// provided by this struct, the argument must impl [`tokio::io::AsyncWrite`]. + /// + /// The [`AsyncWriteStream`] created by this constructor can be shut down (that is, + /// graceful EOF) by sending a message through the sender side of the `shutdown_rx` + /// sync channel. + pub fn new_closeable( + write_budget: usize, + writer: T, + mut shutdown_rx: tokio::sync::mpsc::Receiver<()>, + ) -> Self { + let worker = Arc::new(Worker::new(write_budget)); + + let w = Arc::clone(&worker); + let join_handle = crate::runtime::spawn(async move { w.work(writer).await }); + + let w_clone = worker.clone(); + let shutdown_join_handle = tokio::spawn(async move { + let shutdown_msg = shutdown_rx.recv().await; + if shutdown_msg.is_some() { + let mut state = w_clone.state(); + if state.check_error().is_err() { + // The stream is already failing - no point shutting it down. + return; + } + + state.shutdown_pending = true; + w_clone.new_work.notify_one(); + } + }) + .abort_handle(); + + AsyncWriteStream { + worker, + join_handle: Some(join_handle), + shutdown_join_handle: Some(shutdown_join_handle), } } } @@ -197,6 +253,9 @@ impl HostOutputStream for AsyncWriteStream { } async fn cancel(&mut self) { + if let Some(handle) = self.shutdown_join_handle.take() { + handle.abort(); + }; match self.join_handle.take() { Some(task) => _ = task.cancel().await, None => {}