Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 28 additions & 18 deletions crossbeam-utils/src/sync/wait_group.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
use crate::primitive::sync::atomic::{AtomicUsize, Ordering};
use crate::primitive::sync::{Arc, Condvar, Mutex};
use core::mem::ManuallyDrop;
use std::fmt;

/// Enables threads to synchronize the beginning or end of some computation.
Expand Down Expand Up @@ -50,15 +52,17 @@ pub struct WaitGroup {
/// Inner state of a `WaitGroup`.
struct Inner {
cvar: Condvar,
count: Mutex<usize>,
lock: Mutex<()>,
count: AtomicUsize,
}

impl Default for WaitGroup {
fn default() -> Self {
Self {
inner: Arc::new(Inner {
cvar: Condvar::new(),
count: Mutex::new(1),
lock: Mutex::new(()),
count: AtomicUsize::new(1),
}),
}
}
Expand Down Expand Up @@ -102,36 +106,42 @@ impl WaitGroup {
/// # t.join().unwrap(); // join thread to avoid https://github.com/rust-lang/miri/issues/1371
/// ```
pub fn wait(self) {
if *self.inner.count.lock().unwrap() == 1 {
// SAFETY: this is equivalent to let Self { inner } = self, without calling our Drop.
let inner = unsafe {
let slf = ManuallyDrop::new(self);
core::ptr::read(&slf.inner)
};

if inner.count.fetch_sub(1, Ordering::AcqRel) == 1 {
// Acquire lock after updating count, see below.
drop(inner.lock.lock().unwrap());
inner.cvar.notify_all();
return;
}

let inner = self.inner.clone();
drop(self);

let mut count = inner.count.lock().unwrap();
while *count > 0 {
count = inner.cvar.wait(count).unwrap();
// We check the counter while holding the lock, and notifiers acquire
// the lock between updating the counter and notifying, ensuring we
// can not miss the notification.
let mut guard = inner.lock.lock().unwrap();
while inner.count.load(Ordering::Acquire) != 0 {
guard = inner.cvar.wait(guard).unwrap();
}
}
}

impl Drop for WaitGroup {
fn drop(&mut self) {
let mut count = self.inner.count.lock().unwrap();
*count -= 1;

if *count == 0 {
if self.inner.count.fetch_sub(1, Ordering::Release) == 1 {
// Acquire lock after updating count, see wait().
drop(self.inner.lock.lock().unwrap());
self.inner.cvar.notify_all();
}
}
}

impl Clone for WaitGroup {
fn clone(&self) -> Self {
let mut count = self.inner.count.lock().unwrap();
*count += 1;

self.inner.count.fetch_add(1, Ordering::Relaxed);
Self {
inner: self.inner.clone(),
}
Expand All @@ -140,7 +150,7 @@ impl Clone for WaitGroup {

impl fmt::Debug for WaitGroup {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let count: &usize = &self.inner.count.lock().unwrap();
f.debug_struct("WaitGroup").field("count", count).finish()
let count = self.inner.count.load(Ordering::Relaxed);
f.debug_struct("WaitGroup").field("count", &count).finish()
}
}
Loading