diff --git a/crossbeam-channel/src/channel.rs b/crossbeam-channel/src/channel.rs index 5447e3303..be52df064 100644 --- a/crossbeam-channel/src/channel.rs +++ b/crossbeam-channel/src/channel.rs @@ -656,6 +656,44 @@ impl Sender { _ => false, } } + + /// Disconnects the channel. + /// + /// Explicitly disconnects the channel and returns `true`, as if all instances of either side + /// (sender or receiver) are dropped, unless it has been already disconnected. Otherwise, this + /// method does nothing and returns `false`. + /// + /// The successful disconnect operation results in immediately rejecting sending subsequent + /// messages to receivers. + /// + /// Disconnected channels can be restored connected by calling [`connect`], as long as there + /// are at least a sender and a receiver for the channel. + /// + /// [`connect`]: Self::connect + pub fn disconnect(&self) -> bool { + match &self.flavor { + SenderFlavor::List(chan) => chan.disconnect(), + _ => todo!(), + } + } + + /// Connects the channel. + /// + /// Connects the disconnected channel and returns `true`, unless it has been already + /// connected. Otherwise, this method does nothing and returns `false`. + /// + /// The successful connect operation results in immediately succeeding sending subsequent + /// messages to receivers. + /// + /// Connected channels can be disconnected again by calling [`disconnect`]. + /// + /// [`disconnect`]: Self::disconnect + pub fn connect(&self) -> bool { + match &self.flavor { + SenderFlavor::List(chan) => chan.connect(), + _ => todo!(), + } + } } impl Drop for Sender { @@ -663,7 +701,7 @@ impl Drop for Sender { unsafe { match &self.flavor { SenderFlavor::Array(chan) => chan.release(|c| c.disconnect()), - SenderFlavor::List(chan) => chan.release(|c| c.disconnect_senders()), + SenderFlavor::List(chan) => chan.release(|c| c.disconnect_senders_to_drop()), SenderFlavor::Zero(chan) => chan.release(|c| c.disconnect()), } } @@ -1153,6 +1191,49 @@ impl Receiver { _ => false, } } + + /// Disconnects the channel. + /// + /// Explicitly disconnects the channel and returns `true`, as if all instances of either side + /// (receiver and sender) are dropped, unless it has been already disconnected. Otherwise, this + /// method does nothing and returns `false`. Also, this method does nothing and returns + /// `false` as well, if this method is called on channels other than bounded-capacity, + /// unbounded-capacity or zero-capacity channels. + /// + /// The successful disconnect operation results in immediately rejecting receiving messages, + /// unless they are sent already but not received yet. + /// + /// Disconnected channels can be restored connected by calling [`connect`], as long as there + /// are at least a sender and a receiver for the channel. + /// + /// [`connect`]: Self::connect + pub fn disconnect(&self) -> bool { + match &self.flavor { + ReceiverFlavor::List(chan) => chan.disconnect(), + _ => todo!(), + } + } + + /// Connects the channel. + /// + /// Connects the disconnected channel and returns `true`, unless it has been already + /// connected. Otherwise, this method does nothing and returns `false`. Also, this method does + /// nothing and returns `false` as well, if this method is called on channels other than + /// bounded-capacity, unbounded-capacity or zero-capacity channels. + + /// + /// The successful connect operation results in immediately succeeding receiving messages + /// from senders. + /// + /// Connected channels can be disconnected again by calling [`disconnect`]. + /// + /// [`disconnect`]: Self::disconnect + pub fn connect(&self) -> bool { + match &self.flavor { + ReceiverFlavor::List(chan) => chan.connect(), + _ => todo!(), + } + } } impl Drop for Receiver { @@ -1160,7 +1241,7 @@ impl Drop for Receiver { unsafe { match &self.flavor { ReceiverFlavor::Array(chan) => chan.release(|c| c.disconnect()), - ReceiverFlavor::List(chan) => chan.release(|c| c.disconnect_receivers()), + ReceiverFlavor::List(chan) => chan.release(|c| c.disconnect_receivers_to_drop()), ReceiverFlavor::Zero(chan) => chan.release(|c| c.disconnect()), ReceiverFlavor::At(_) => {} ReceiverFlavor::Tick(_) => {} diff --git a/crossbeam-channel/src/flavors/list.rs b/crossbeam-channel/src/flavors/list.rs index e86551ad2..75fd589aa 100644 --- a/crossbeam-channel/src/flavors/list.rs +++ b/crossbeam-channel/src/flavors/list.rs @@ -34,11 +34,16 @@ const LAP: usize = 32; // The maximum number of messages a block can hold. const BLOCK_CAP: usize = LAP - 1; // How many lower bits are reserved for metadata. -const SHIFT: usize = 1; -// Has two different purposes: -// * If set in head, indicates that the block is not the last one. -// * If set in tail, indicates that the channel is disconnected. -const MARK_BIT: usize = 1; +const SHIFT: usize = 2; + +const HEAD_NOT_LAST: usize = 1; + +// Channel is disconnected implicitly via drop +const TAIL_DISCONNECT_IMPLICIT: usize = 1; +// Channel is disconnected explicitly as requested +const TAIL_DISCONNECT_EXPLICIT: usize = 2; + +const TAIL_DISCONNECT_ANY: usize = TAIL_DISCONNECT_IMPLICIT | TAIL_DISCONNECT_EXPLICIT; /// A slot in a block. struct Slot { @@ -204,7 +209,7 @@ impl Channel { loop { // Check if the channel is disconnected. - if tail & MARK_BIT != 0 { + if tail & TAIL_DISCONNECT_ANY != 0 { token.list.block = ptr::null(); return true; } @@ -317,14 +322,14 @@ impl Channel { let mut new_head = head + (1 << SHIFT); - if new_head & MARK_BIT == 0 { + if new_head & HEAD_NOT_LAST == 0 { atomic::fence(Ordering::SeqCst); let tail = self.tail.index.load(Ordering::Relaxed); // If the tail equals the head, that means the channel is empty. if head >> SHIFT == tail >> SHIFT { // If the channel is disconnected... - if tail & MARK_BIT != 0 { + if tail & TAIL_DISCONNECT_ANY != 0 { // ...then receive an error. token.list.block = ptr::null(); return true; @@ -334,9 +339,9 @@ impl Channel { } } - // If head and tail are not in the same block, set `MARK_BIT` in head. + // If head and tail are not in the same block, set `HEAD_NOT_LAST` in head. if (head >> SHIFT) / LAP != (tail >> SHIFT) / LAP { - new_head |= MARK_BIT; + new_head |= HEAD_NOT_LAST; } } @@ -360,9 +365,9 @@ impl Channel { // If we've reached the end of the block, move to the next one. if offset + 1 == BLOCK_CAP { let next = (*block).wait_next(); - let mut next_index = (new_head & !MARK_BIT).wrapping_add(1 << SHIFT); + let mut next_index = (new_head & !HEAD_NOT_LAST).wrapping_add(1 << SHIFT); if !(*next).next.load(Ordering::Relaxed).is_null() { - next_index |= MARK_BIT; + next_index |= HEAD_NOT_LAST; } self.head.block.store(next, Ordering::Release); @@ -538,10 +543,27 @@ impl Channel { /// Disconnects senders and wakes up all blocked receivers. /// /// Returns `true` if this call disconnected the channel. - pub(crate) fn disconnect_senders(&self) -> bool { - let tail = self.tail.index.fetch_or(MARK_BIT, Ordering::SeqCst); + pub(crate) fn disconnect_senders_to_drop(&self) -> bool { + let tail = self + .tail + .index + .fetch_or(TAIL_DISCONNECT_IMPLICIT, Ordering::SeqCst); + + if tail & TAIL_DISCONNECT_IMPLICIT == 0 { + self.receivers.disconnect(); + true + } else { + false + } + } + + pub(crate) fn disconnect(&self) -> bool { + let tail = self + .tail + .index + .fetch_or(TAIL_DISCONNECT_EXPLICIT, Ordering::SeqCst); - if tail & MARK_BIT == 0 { + if tail & TAIL_DISCONNECT_ANY == 0 { self.receivers.disconnect(); true } else { @@ -549,13 +571,25 @@ impl Channel { } } + pub(crate) fn connect(&self) -> bool { + let tail = self + .tail + .index + .fetch_and(!TAIL_DISCONNECT_EXPLICIT, Ordering::SeqCst); + + tail & TAIL_DISCONNECT_ANY == TAIL_DISCONNECT_EXPLICIT + } + /// Disconnects receivers. /// /// Returns `true` if this call disconnected the channel. - pub(crate) fn disconnect_receivers(&self) -> bool { - let tail = self.tail.index.fetch_or(MARK_BIT, Ordering::SeqCst); + pub(crate) fn disconnect_receivers_to_drop(&self) -> bool { + let tail = self + .tail + .index + .fetch_or(TAIL_DISCONNECT_IMPLICIT, Ordering::SeqCst); - if tail & MARK_BIT == 0 { + if tail & TAIL_DISCONNECT_IMPLICIT == 0 { // If receivers are dropped first, discard all messages to free // memory eagerly. self.discard_all_messages(); @@ -628,13 +662,13 @@ impl Channel { drop(Box::from_raw(block)); } } - head &= !MARK_BIT; + head &= !HEAD_NOT_LAST; self.head.index.store(head, Ordering::Release); } /// Returns `true` if the channel is disconnected. pub(crate) fn is_disconnected(&self) -> bool { - self.tail.index.load(Ordering::SeqCst) & MARK_BIT != 0 + self.tail.index.load(Ordering::SeqCst) & TAIL_DISCONNECT_ANY != 0 } /// Returns `true` if the channel is empty. diff --git a/crossbeam-channel/tests/list.rs b/crossbeam-channel/tests/list.rs index ebe6f6f85..685185398 100644 --- a/crossbeam-channel/tests/list.rs +++ b/crossbeam-channel/tests/list.rs @@ -580,3 +580,84 @@ fn channel_through_channel() { }) .unwrap(); } + +#[test] +fn disconnect_by_sender() { + let (s, r) = unbounded::<()>(); + let s2 = s.clone(); + + assert!(s.disconnect()); + + assert!(!s.disconnect()); + assert!(!s2.disconnect()); + + drop(r); +} + +#[test] +fn connect_by_sender() { + let (s, r) = unbounded::<()>(); + assert!(s.disconnect()); + + assert!(s.connect()); + assert!(s.disconnect()); + + drop(r); + // connect should fail after all receivers has gone + assert!(!s.connect()); +} + +#[test] +fn disconnect_by_receiver() { + let (s, r) = unbounded::<()>(); + let r2 = r.clone(); + assert!(r.disconnect()); + assert!(!r.disconnect()); + assert!(!r2.disconnect()); + drop(s); +} + +#[test] +fn connect_by_receiver() { + let (s, r) = unbounded::<()>(); + assert!(r.disconnect()); + + assert!(r.connect()); + assert!(r.disconnect()); + + drop(s); + // connect should fail after all senders has gone + assert!(!r.connect()); +} + +#[test] +fn send_after_disconnect_then_connect() { + let (s, r) = unbounded::<()>(); + + assert!(s.disconnect()); + assert_eq!(s.send(()), Err(SendError(()))); + + assert!(s.connect()); + assert_eq!(s.send(()), Ok(())); + + drop(r); +} + +#[test] +fn receive_after_disconnect_then_connect() { + let (s, r) = unbounded::<()>(); + s.send(()).unwrap(); + + assert!(r.disconnect()); + assert_eq!(r.recv(), Ok(())); + assert_eq!( + r.recv_timeout(Duration::from_millis(1)), + Err(RecvTimeoutError::Disconnected) + ); + + assert!(r.connect()); + assert_eq!( + r.recv_timeout(Duration::from_millis(1)), + Err(RecvTimeoutError::Timeout) + ); +}