diff --git a/Cargo.toml b/Cargo.toml index 751d438a..57ff3aba 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -119,9 +119,6 @@ bytemuck = { version = "1.13", features = ["derive"] } zerocopy = "0.8" serio = { version = "0.2" } -# io -uid-mux = { version = "0.2" } - # testing rstest = "0.12" pretty_assertions = "1" diff --git a/crates/common/Cargo.toml b/crates/common/Cargo.toml index 79ae7e09..72ce792d 100644 --- a/crates/common/Cargo.toml +++ b/crates/common/Cargo.toml @@ -8,7 +8,7 @@ default = [] executor = [] sync = ["tokio/sync"] future = [] -test-utils = ["uid-mux/test-utils", "tokio/io-util", "tokio-util/compat"] +test-utils = ["tokio/io-util", "tokio-util/compat"] ideal = ["tokio/sync"] [dependencies] @@ -18,7 +18,6 @@ bytes = { workspace = true } pin-project-lite.workspace = true thiserror.workspace = true serio.workspace = true -uid-mux = { workspace = true } serde = { workspace = true, features = ["derive"] } pollster.workspace = true cfg-if.workspace = true @@ -34,7 +33,6 @@ tokio = { workspace = true, features = [ "net", ] } tokio-util = { workspace = true, features = ["compat"] } -uid-mux = { workspace = true, features = ["test-utils"] } tracing-subscriber = { workspace = true, features = ["fmt"] } criterion = { workspace = true, features = ["async_tokio"] } rstest = { workspace = true } diff --git a/crates/common/src/context.rs b/crates/common/src/context.rs index f20f78ef..f2fb4fad 100644 --- a/crates/common/src/context.rs +++ b/crates/common/src/context.rs @@ -118,7 +118,7 @@ impl Context { match &mut self.mode { Mode::St => Ok(st::map(self, items, f).await), Mode::Mt { threads } => { - let threads = threads.get(threads.concurrency()).await?; + let threads = threads.get(threads.concurrency())?; mt::map(threads, items, f, weight).await } } @@ -138,7 +138,7 @@ impl Context { match &mut self.mode { Mode::St => Ok(st::join(self, a, b).await), Mode::Mt { threads } => { - let threads = threads.get(2).await?; + let threads = threads.get(2)?; mt::join(threads, a, b).await } } @@ -167,7 +167,7 @@ impl Context { match &mut self.mode { Mode::St => Ok(st::try_join(self, a, b).await), Mode::Mt { threads } => { - let threads = threads.get(2).await?; + let threads = threads.get(2)?; mt::try_join(threads, a, b).await } } @@ -192,7 +192,7 @@ impl Context { match &mut self.mode { Mode::St => Ok(st::try_join3(self, a, b, c).await), Mode::Mt { threads } => { - let threads = threads.get(3).await?; + let threads = threads.get(3)?; mt::try_join3(threads, a, b, c).await } } @@ -220,7 +220,7 @@ impl Context { match &mut self.mode { Mode::St => Ok(st::try_join4(self, a, b, c, d).await), Mode::Mt { threads } => { - let threads = threads.get(4).await?; + let threads = threads.get(4)?; mt::try_join4(threads, a, b, c, d).await } } diff --git a/crates/common/src/context/mt.rs b/crates/common/src/context/mt.rs index cde831c1..3dd09c4c 100644 --- a/crates/common/src/context/mt.rs +++ b/crates/common/src/context/mt.rs @@ -41,10 +41,12 @@ impl Multithread { ContextError::new(ErrorKind::Thread, "thread ID overflow".to_string()) })?; - let io_fut = { self.builder.lock().unwrap().mux.open(id.clone()) }; - - let io = io_fut - .await + let io = self + .builder + .lock() + .unwrap() + .mux + .open(id.clone()) .map_err(|e| ContextError::new(ErrorKind::Mux, e))?; let ctx = @@ -60,15 +62,16 @@ pub(crate) struct ThreadBuilder { } impl ThreadBuilder { - async fn spawn( + fn spawn( this: Arc>, id: ThreadId, config: Arc, ) -> Result { - let io_fut = { this.lock().unwrap().mux.open(id.clone()) }; - - let io = io_fut - .await + let io = this + .lock() + .unwrap() + .mux + .open(id.clone()) .map_err(|e| ContextError::new(ErrorKind::Mux, e))?; let ctx = Context::new_multi_threaded(id.clone(), io, config, this.clone()); @@ -115,7 +118,7 @@ impl Threads { self.config.concurrency } - pub(crate) async fn get(&mut self, count: usize) -> Result<&[Handle], ContextError> { + pub(crate) fn get(&mut self, count: usize) -> Result<&[Handle], ContextError> { if count > self.config.concurrency { return Err(ContextError::new( ErrorKind::Thread, @@ -128,8 +131,7 @@ impl Threads { ContextError::new(ErrorKind::Thread, "thread ID overflow".to_string()) })?; - let child = - ThreadBuilder::spawn(self.builder.clone(), id, self.config.clone()).await?; + let child = ThreadBuilder::spawn(self.builder.clone(), id, self.config.clone())?; self.children.push(child); } } diff --git a/crates/common/src/context/mt/builder.rs b/crates/common/src/context/mt/builder.rs index d3211a37..4eb9e935 100644 --- a/crates/common/src/context/mt/builder.rs +++ b/crates/common/src/context/mt/builder.rs @@ -1,7 +1,5 @@ use std::sync::{Arc, Mutex}; -use uid_mux::UidMux; - use crate::{ ThreadId, context::{ @@ -73,18 +71,8 @@ impl MultithreadBuilder { } /// Sets the multiplexer. - pub fn mux(mut self, mux: M) -> Self - where - M: UidMux + Clone + Send + Sync + 'static, - >::Error: std::error::Error + Send + Sync + 'static, - { - self.mux = Some(Box::new(mux)); - self - } - - #[allow(dead_code)] - pub(crate) fn mux_internal(mut self, mux: Box) -> Self { - self.mux = Some(mux); + pub fn mux>>(mut self, mux: M) -> Self { + self.mux = Some(mux.into()); self } diff --git a/crates/common/src/context/test/helpers.rs b/crates/common/src/context/test/helpers.rs index 2a348e5f..c5bfa43c 100644 --- a/crates/common/src/context/test/helpers.rs +++ b/crates/common/src/context/test/helpers.rs @@ -1,8 +1,8 @@ //! Basic test context helpers. +use crate::mux::test_framed_mux; use futures::{AsyncRead, AsyncWrite}; use serio::channel::duplex; -use uid_mux::test_utils::test_framed_mux; use crate::{ context::{Context, Multithread, SpawnError}, @@ -41,8 +41,8 @@ pub fn test_mt_context(io_buffer: usize) -> (Multithread, Multithread) { let mux_1: Box = Box::new(mux_1); ( - Multithread::builder().mux_internal(mux_0).build().unwrap(), - Multithread::builder().mux_internal(mux_1).build().unwrap(), + Multithread::builder().mux(mux_0).build().unwrap(), + Multithread::builder().mux(mux_1).build().unwrap(), ) } @@ -62,12 +62,12 @@ where ( Multithread::builder() .spawn_handler(spawn.clone()) - .mux_internal(mux_0) + .mux(mux_0) .build() .unwrap(), Multithread::builder() .spawn_handler(spawn) - .mux_internal(mux_1) + .mux(mux_1) .build() .unwrap(), ) @@ -95,13 +95,13 @@ where Multithread::builder() .concurrency(concurrency) .spawn_handler(spawn.clone()) - .mux_internal(mux_0) + .mux(mux_0) .build() .unwrap(), Multithread::builder() .concurrency(concurrency) .spawn_handler(spawn) - .mux_internal(mux_1) + .mux(mux_1) .build() .unwrap(), ) diff --git a/crates/common/src/context/test/recording.rs b/crates/common/src/context/test/recording.rs index f4ddc0a8..fc94ab10 100644 --- a/crates/common/src/context/test/recording.rs +++ b/crates/common/src/context/test/recording.rs @@ -177,75 +177,69 @@ impl std::fmt::Debug for RecordingTestMux { } impl Mux for RecordingTestMux { - fn open( - &self, - id: ThreadId, - ) -> Pin> + Send>> { - let mux = self.clone(); - Box::pin(async move { - let mut state = mux.state.lock().unwrap(); - - // Check if channel already exists from the other side - match mux.role { - RecordingRole::A => { - if let Some(stream) = state.waiting_a.remove(&id) { - return Ok(if let Some(limit) = mux.max_frame_length { - Io::from_io_with_limit(stream, limit) - } else { - Io::from_io(stream) - }); - } + fn open(&self, id: ThreadId) -> Result { + let mut state = self.state.lock().unwrap(); + + // Check if channel already exists from the other side + match self.role { + RecordingRole::A => { + if let Some(stream) = state.waiting_a.remove(&id) { + return Ok(if let Some(limit) = self.max_frame_length { + Io::from_io_with_limit(stream, limit) + } else { + Io::from_io(stream) + }); } - RecordingRole::B => { - if let Some(recording_stream) = state.waiting_b.remove(&id) { - return Ok(if let Some(limit) = mux.max_frame_length { - Io::from_io_with_limit(recording_stream, limit) - } else { - Io::from_io(recording_stream) - }); - } + } + RecordingRole::B => { + if let Some(recording_stream) = state.waiting_b.remove(&id) { + return Ok(if let Some(limit) = self.max_frame_length { + Io::from_io_with_limit(recording_stream, limit) + } else { + Io::from_io(recording_stream) + }); } } + } - // Check for duplicate - if !state.opened.insert(id.clone()) { - return Err(std::io::Error::other("duplicate stream id")); - } + // Check for duplicate + if !state.opened.insert(id.clone()) { + return Err(std::io::Error::other("duplicate stream id")); + } - // Create new byte-based channel pair - let (stream_a, stream_b) = tokio::io::duplex(mux.buffer); - - // Role B's writes are recorded - let recorded_for_channel = mux.recorded.clone(); - let channel_id = id.clone(); - - match mux.role { - RecordingRole::A => { - // A gets plain stream, B gets recording stream - let recording_stream = - RecordingDuplexWithId::new(stream_b, channel_id, recorded_for_channel); - state - .waiting_b - .insert(id, recording_stream.into_recording_duplex()); - Ok(if let Some(limit) = mux.max_frame_length { - Io::from_io_with_limit(stream_a.compat(), limit) - } else { - Io::from_io(stream_a.compat()) - }) - } - RecordingRole::B => { - // B gets recording stream, A gets plain stream - state.waiting_a.insert(id, stream_a.compat()); - let recording_stream = - RecordingDuplexWithId::new(stream_b, channel_id, recorded_for_channel); - Ok(if let Some(limit) = mux.max_frame_length { - Io::from_io_with_limit(recording_stream.into_recording_duplex(), limit) - } else { - Io::from_io(recording_stream.into_recording_duplex()) - }) - } + // Create new byte-based channel pair + let (stream_a, stream_b) = tokio::io::duplex(self.buffer); + + // Role B's writes are recorded + let recorded_for_channel = self.recorded.clone(); + let channel_id = id.clone(); + + match self.role { + RecordingRole::A => { + // A gets plain stream, B gets recording stream + let recording_stream = + RecordingDuplexWithId::new(stream_b, channel_id, recorded_for_channel); + state + .waiting_b + .insert(id, recording_stream.into_recording_duplex()); + Ok(if let Some(limit) = self.max_frame_length { + Io::from_io_with_limit(stream_a.compat(), limit) + } else { + Io::from_io(stream_a.compat()) + }) + } + RecordingRole::B => { + // B gets recording stream, A gets plain stream + state.waiting_a.insert(id, stream_a.compat()); + let recording_stream = + RecordingDuplexWithId::new(stream_b, channel_id, recorded_for_channel); + Ok(if let Some(limit) = self.max_frame_length { + Io::from_io_with_limit(recording_stream.into_recording_duplex(), limit) + } else { + Io::from_io(recording_stream.into_recording_duplex()) + }) } - }) + } } } @@ -374,8 +368,8 @@ pub fn recording_mt_context( let mux_1: Box = Box::new(mux_1); ( - Multithread::builder().mux_internal(mux_0).build().unwrap(), - Multithread::builder().mux_internal(mux_1).build().unwrap(), + Multithread::builder().mux(mux_0).build().unwrap(), + Multithread::builder().mux(mux_1).build().unwrap(), recorded, ) } @@ -397,8 +391,8 @@ pub fn recording_mt_context_with_limit( let mux_1: Box = Box::new(mux_1); ( - Multithread::builder().mux_internal(mux_0).build().unwrap(), - Multithread::builder().mux_internal(mux_1).build().unwrap(), + Multithread::builder().mux(mux_0).build().unwrap(), + Multithread::builder().mux(mux_1).build().unwrap(), recorded, ) } @@ -425,12 +419,12 @@ where ( Multithread::builder() .spawn_handler(spawn.clone()) - .mux_internal(mux_0) + .mux(mux_0) .build() .unwrap(), Multithread::builder() .spawn_handler(spawn) - .mux_internal(mux_1) + .mux(mux_1) .build() .unwrap(), recorded, @@ -465,13 +459,13 @@ where Multithread::builder() .spawn_handler(spawn.clone()) .concurrency(concurrency) - .mux_internal(mux_0) + .mux(mux_0) .build() .unwrap(), Multithread::builder() .spawn_handler(spawn) .concurrency(concurrency) - .mux_internal(mux_1) + .mux(mux_1) .build() .unwrap(), recorded, diff --git a/crates/common/src/context/test/replay.rs b/crates/common/src/context/test/replay.rs index c3d38c7f..f3695343 100644 --- a/crates/common/src/context/test/replay.rs +++ b/crates/common/src/context/test/replay.rs @@ -103,24 +103,19 @@ impl ReplayTestMux { } impl Mux for ReplayTestMux { - fn open( - &self, - id: ThreadId, - ) -> Pin> + Send>> { + fn open(&self, id: ThreadId) -> Result { let recorded = self.recorded.clone(); let max_frame_length = self.max_frame_length; - Box::pin(async move { - let data = { - let mut rec = recorded.lock().unwrap(); - rec.channels.remove(&id).unwrap_or_default() - }; - let replay = ReplayDuplex::new(data); - if let Some(limit) = max_frame_length { - Ok(Io::from_io_with_limit(replay, limit)) - } else { - Ok(Io::from_io(replay)) - } - }) + let data = { + let mut rec = recorded.lock().unwrap(); + rec.channels.remove(&id).unwrap_or_default() + }; + let replay = ReplayDuplex::new(data); + if let Some(limit) = max_frame_length { + Ok(Io::from_io_with_limit(replay, limit)) + } else { + Ok(Io::from_io(replay)) + } } } @@ -137,7 +132,7 @@ pub fn replay_mt_context(recorded: RecordedMtData) -> Multithread { let mux = ReplayTestMux::new(recorded, None); let mux: Box = Box::new(mux); - Multithread::builder().mux_internal(mux).build().unwrap() + Multithread::builder().mux(mux).build().unwrap() } /// Creates a multi-threaded context that replays recorded data with a custom @@ -154,7 +149,7 @@ pub fn replay_mt_context_with_limit( let mux = ReplayTestMux::new(recorded, Some(max_frame_length)); let mux: Box = Box::new(mux); - Multithread::builder().mux_internal(mux).build().unwrap() + Multithread::builder().mux(mux).build().unwrap() } /// Creates a multi-threaded context that replays recorded data with custom @@ -173,7 +168,7 @@ where Multithread::builder() .spawn_handler(spawn) - .mux_internal(mux) + .mux(mux) .build() .unwrap() } @@ -203,7 +198,7 @@ where Multithread::builder() .spawn_handler(spawn) .concurrency(concurrency) - .mux_internal(mux) + .mux(mux) .build() .unwrap() } diff --git a/crates/common/src/io.rs b/crates/common/src/io.rs index 39ca3ed3..87e38386 100644 --- a/crates/common/src/io.rs +++ b/crates/common/src/io.rs @@ -122,6 +122,7 @@ impl Io { } } + #[cfg(any(test, feature = "test-utils"))] pub(crate) fn from_io_with_limit( io: Io, max_frame_length: usize, diff --git a/crates/common/src/lib.rs b/crates/common/src/lib.rs index 7cf040b6..721824fd 100644 --- a/crates/common/src/lib.rs +++ b/crates/common/src/lib.rs @@ -23,7 +23,7 @@ mod id; pub mod ideal; pub mod io; pub(crate) mod load_balance; -mod mux; +pub mod mux; #[cfg(feature = "sync")] pub mod sync; mod task; diff --git a/crates/common/src/mux.rs b/crates/common/src/mux.rs index 40b75e96..937fc59e 100644 --- a/crates/common/src/mux.rs +++ b/crates/common/src/mux.rs @@ -1,63 +1,131 @@ -use std::{future::Future, pin::Pin}; - -use uid_mux::UidMux; +//! Multiplexing types. use crate::{ThreadId, io::Io}; -pub(crate) trait Mux { +/// A multiplexer. +pub trait Mux { /// Opens a new I/O channel for the given thread. - fn open( - &self, - id: ThreadId, - ) -> Pin> + Send>>; -} - -impl Mux for T -where - T: UidMux + Clone + Send + Sync + 'static, - >::Error: std::error::Error + Send + Sync + 'static, -{ - fn open( - &self, - id: ThreadId, - ) -> Pin> + Send>> { - let mux = self.clone(); - Box::pin(async move { - let io = mux.open(&id).await.map_err(std::io::Error::other)?; - - Ok(Io::from_io(io)) - }) - } + fn open(&self, id: ThreadId) -> Result; } #[cfg(any(test, feature = "test-utils"))] mod test_utils { - use super::*; - use uid_mux::{FramedUidMux, test_utils::TestFramedMux}; + use std::{ + collections::{HashMap, HashSet}, + sync::{Arc, Mutex}, + }; + + use serio::channel::{MemoryDuplex, duplex}; + + use crate::{ThreadId, io::Io, mux::Mux}; + + #[derive(Debug, Default)] + struct State { + exists: HashSet>, + waiting_a: HashMap, MemoryDuplex>, + waiting_b: HashMap, MemoryDuplex>, + } + + #[derive(Debug, Clone, Copy)] + enum Role { + A, + B, + } + + /// A test framed mux. + #[derive(Debug, Clone)] + pub struct TestFramedMux { + role: Role, + buffer: usize, + state: Arc>, + } impl Mux for TestFramedMux { - fn open( - &self, - id: ThreadId, - ) -> Pin> + Send>> { - let mux = self.clone(); - Box::pin(async move { - let io = mux.open_framed(&id).await.map_err(std::io::Error::other)?; - - Ok(Io::from_channel(io)) - }) + fn open(&self, id: ThreadId) -> Result { + let mut state = self.state.lock().unwrap(); + + if let Some(channel) = match self.role { + Role::A => state.waiting_a.remove(id.as_ref()), + Role::B => state.waiting_b.remove(id.as_ref()), + } { + Ok(Io::from_channel(channel)) + } else { + if !state.exists.insert(id.as_ref().to_vec()) { + return Err(std::io::Error::other("duplicate stream id")); + } + + let (a, b) = duplex(self.buffer); + + match self.role { + Role::A => { + state.waiting_b.insert(id.as_ref().to_vec(), b); + Ok(Io::from_channel(a)) + } + Role::B => { + state.waiting_a.insert(id.as_ref().to_vec(), a); + Ok(Io::from_channel(b)) + } + } + } } } -} -#[cfg(test)] -mod tests { - use super::*; - use uid_mux::yamux::YamuxCtrl; + /// Creates a test pair of framed mux instances. + pub fn test_framed_mux(buffer: usize) -> (TestFramedMux, TestFramedMux) { + let state = Arc::new(Mutex::new(State::default())); - #[test] - fn test_yamux_is_mux() { - fn assert_mux() {} - assert_mux::(); + ( + TestFramedMux { + role: Role::A, + buffer, + state: state.clone(), + }, + TestFramedMux { + role: Role::B, + buffer, + state, + }, + ) + } + + #[cfg(test)] + mod tests { + use crate::{ThreadId, mux::Mux}; + use serio::{SinkExt, StreamExt}; + + #[test] + fn test_framed_mux() { + let (a, b) = super::test_framed_mux(1); + + futures::executor::block_on(async { + let mut a_0 = a.open(ThreadId::new(0)).unwrap(); + let mut b_0 = b.open(ThreadId::new(0)).unwrap(); + + let mut a_1 = a.open(ThreadId::new(1)).unwrap(); + let mut b_1 = b.open(ThreadId::new(1)).unwrap(); + + a_0.send(42u8).await.unwrap(); + assert_eq!(b_0.next::().await.unwrap().unwrap(), 42); + + a_1.send(69u8).await.unwrap(); + assert_eq!(b_1.next::().await.unwrap().unwrap(), 69u8); + }) + } + + #[test] + fn test_framed_mux_duplicate() { + let (a, b) = super::test_framed_mux(1); + + futures::executor::block_on(async { + let _ = a.open(ThreadId::new(0)).unwrap(); + let _ = b.open(ThreadId::new(0)).unwrap(); + + assert!(a.open(ThreadId::new(0)).is_err()); + assert!(b.open(ThreadId::new(0)).is_err()); + }) + } } } + +#[cfg(any(test, feature = "test-utils"))] +pub use test_utils::{TestFramedMux, test_framed_mux};