Skip to content

Commit 225fe95

Browse files
committed
Add support for scoped threads
Add loom::thread::scope to mirror std::thread::scope provided by the standard library.
1 parent a0b154d commit 225fe95

File tree

2 files changed

+340
-42
lines changed

2 files changed

+340
-42
lines changed

src/thread.rs

+242-42
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,7 @@ use std::{fmt, io};
1414
use tracing::trace;
1515

1616
/// Mock implementation of `std::thread::JoinHandle`.
17-
pub struct JoinHandle<T> {
18-
result: Arc<Mutex<Option<std::thread::Result<T>>>>,
19-
notify: rt::Notify,
20-
thread: Thread,
21-
}
17+
pub struct JoinHandle<T>(JoinHandleInner<'static, T>);
2218

2319
/// Mock implementation of `std::thread::Thread`.
2420
#[derive(Clone, Debug)]
@@ -128,7 +124,7 @@ where
128124
F: 'static,
129125
T: 'static,
130126
{
131-
spawn_internal(f, None, location!())
127+
JoinHandle(spawn_internal_static(f, None, location!()))
132128
}
133129

134130
/// Mock implementation of `std::thread::park`.
@@ -142,38 +138,6 @@ pub fn park() {
142138
rt::park(location!());
143139
}
144140

145-
fn spawn_internal<F, T>(f: F, name: Option<String>, location: Location) -> JoinHandle<T>
146-
where
147-
F: FnOnce() -> T,
148-
F: 'static,
149-
T: 'static,
150-
{
151-
let result = Arc::new(Mutex::new(None));
152-
let notify = rt::Notify::new(true, false);
153-
154-
let id = {
155-
let name = name.clone();
156-
let result = result.clone();
157-
rt::spawn(move || {
158-
rt::execution(|execution| {
159-
init_current(execution, name);
160-
});
161-
162-
*result.lock().unwrap() = Some(Ok(f()));
163-
notify.notify(location);
164-
})
165-
};
166-
167-
JoinHandle {
168-
result,
169-
notify,
170-
thread: Thread {
171-
id: ThreadId { id },
172-
name,
173-
},
174-
}
175-
}
176-
177141
impl Builder {
178142
/// Generates the base configuration for spawning a thread, from which
179143
/// configuration methods can be chained.
@@ -206,21 +170,40 @@ impl Builder {
206170
F: Send + 'static,
207171
T: Send + 'static,
208172
{
209-
Ok(spawn_internal(f, self.name, location!()))
173+
Ok(JoinHandle(spawn_internal_static(f, self.name, location!())))
174+
}
175+
}
176+
177+
impl Builder {
178+
/// Spawns a new scoped thread using the settings set through this `Builder`.
179+
pub fn spawn_scoped<'scope, 'env, F, T>(
180+
self,
181+
scope: &'scope Scope<'scope, 'env>,
182+
f: F,
183+
) -> io::Result<ScopedJoinHandle<'scope, T>>
184+
where
185+
F: FnOnce() -> T + Send + 'scope,
186+
T: Send + 'scope,
187+
{
188+
Ok(ScopedJoinHandle(
189+
// Safety: the call to this function requires a `&'scope Scope`
190+
// which can only be constructed by `scope()`, which ensures that
191+
// all spawned threads are joined before the `Scope` is destroyed.
192+
unsafe { spawn_internal(f, self.name, Some(scope.data.clone()), location!()) },
193+
))
210194
}
211195
}
212196

213197
impl<T> JoinHandle<T> {
214198
/// Waits for the associated thread to finish.
215199
#[track_caller]
216200
pub fn join(self) -> std::thread::Result<T> {
217-
self.notify.wait(location!());
218-
self.result.lock().unwrap().take().unwrap()
201+
self.0.join()
219202
}
220203

221204
/// Gets a handle to the underlying [`Thread`]
222205
pub fn thread(&self) -> &Thread {
223-
&self.thread
206+
self.0.thread()
224207
}
225208
}
226209

@@ -301,3 +284,220 @@ impl<T: 'static> fmt::Debug for LocalKey<T> {
301284
f.pad("LocalKey { .. }")
302285
}
303286
}
287+
288+
/// A scope for spawning scoped threads.
289+
///
290+
/// See [`scope`] for more details.
291+
#[derive(Debug)]
292+
pub struct Scope<'scope, 'env: 'scope> {
293+
data: Arc<ScopeData>,
294+
scope: PhantomData<&'scope mut &'scope ()>,
295+
env: PhantomData<&'env mut &'env ()>,
296+
}
297+
298+
/// An owned permission to join on a scoped thread (block on its termination).
299+
///
300+
/// See [`Scope::spawn`] for details.
301+
#[derive(Debug)]
302+
pub struct ScopedJoinHandle<'scope, T>(JoinHandleInner<'scope, T>);
303+
304+
/// Create a scope for spawning scoped threads.
305+
///
306+
/// Mock implementation of [`std::thread::scope`].
307+
#[track_caller]
308+
pub fn scope<'env, F, T>(f: F) -> T
309+
where
310+
F: for<'scope> FnOnce(&'scope Scope<'scope, 'env>) -> T,
311+
{
312+
let scope = Scope {
313+
data: Arc::new(ScopeData {
314+
running_threads: Mutex::default(),
315+
main_thread: current(),
316+
}),
317+
env: PhantomData,
318+
scope: PhantomData,
319+
};
320+
321+
// Run `f`, but catch panics so we can make sure to wait for all the threads to join.
322+
let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| f(&scope)));
323+
324+
// Wait until all the threads are finished. This is required to fulfill
325+
// the safety requirements of `spawn_internal`.
326+
let running = loop {
327+
{
328+
let running = scope.data.running_threads.lock().unwrap();
329+
if running.count == 0 {
330+
break running;
331+
}
332+
}
333+
park();
334+
};
335+
336+
for notify in &running.notify_on_finished {
337+
notify.wait(location!())
338+
}
339+
340+
// Throw any panic from `f`, or the return value of `f` if no thread panicked.
341+
match result {
342+
Err(e) => std::panic::resume_unwind(e),
343+
Ok(result) => result,
344+
}
345+
}
346+
347+
impl<'scope, 'env> Scope<'scope, 'env> {
348+
/// Spawns a new thread within a scope, returning a [`ScopedJoinHandle`] for it.
349+
///
350+
/// See [`std::thread::Scope`] and [`std::thread::scope`] for details.
351+
pub fn spawn<F, T>(&'scope self, f: F) -> ScopedJoinHandle<'scope, T>
352+
where
353+
F: FnOnce() -> T + Send + 'scope,
354+
T: Send + 'scope,
355+
{
356+
Builder::new()
357+
.spawn_scoped(self, f)
358+
.expect("failed to spawn thread")
359+
}
360+
}
361+
362+
impl<'scope, T> ScopedJoinHandle<'scope, T> {
363+
/// Extracts a handle to the underlying thread.
364+
pub fn thread(&self) -> &Thread {
365+
self.0.thread()
366+
}
367+
368+
/// Waits for the associated thread to finish.
369+
pub fn join(self) -> std::thread::Result<T> {
370+
self.0.join()
371+
}
372+
}
373+
374+
/// Handle for joining on a thread with a scope.
375+
#[derive(Debug)]
376+
struct JoinHandleInner<'scope, T> {
377+
data: Arc<ThreadData<'scope, T>>,
378+
notify: rt::Notify,
379+
thread: Thread,
380+
}
381+
382+
/// Spawns a thread without a local scope.
383+
fn spawn_internal_static<F, T>(
384+
f: F,
385+
name: Option<String>,
386+
location: Location,
387+
) -> JoinHandleInner<'static, T>
388+
where
389+
F: FnOnce() -> T,
390+
F: 'static,
391+
T: 'static,
392+
{
393+
// Safety: the requirements of `spawn_internal` are trivially satisfied
394+
// since there is no `scope`.
395+
unsafe { spawn_internal(f, name, None, location) }
396+
}
397+
398+
/// Spawns a thread with an optional scope.
399+
///
400+
/// The caller must ensure that if `scope` is not None, the provided closure
401+
/// finishes before `'scope` ends.
402+
unsafe fn spawn_internal<'scope, F, T>(
403+
f: F,
404+
name: Option<String>,
405+
scope: Option<Arc<ScopeData>>,
406+
location: Location,
407+
) -> JoinHandleInner<'scope, T>
408+
where
409+
F: FnOnce() -> T,
410+
F: 'scope,
411+
T: 'scope,
412+
{
413+
let scope_notify = scope
414+
.clone()
415+
.map(|scope| (scope.add_running_thread(), scope));
416+
let thread_data = Arc::new(ThreadData::new());
417+
let notify = rt::Notify::new(true, false);
418+
419+
let id = {
420+
let name = name.clone();
421+
let thread_data = thread_data.clone();
422+
let body: Box<dyn FnOnce() + 'scope> = Box::new(move || {
423+
rt::execution(|execution| {
424+
init_current(execution, name);
425+
});
426+
427+
*thread_data.result.lock().unwrap() = Some(Ok(f()));
428+
notify.notify(location);
429+
430+
if let Some((notifier, scope)) = scope_notify {
431+
notifier.notify(location!());
432+
scope.remove_running_thread()
433+
}
434+
});
435+
rt::spawn(std::mem::transmute::<_, Box<dyn FnOnce()>>(body))
436+
};
437+
438+
JoinHandleInner {
439+
data: thread_data,
440+
notify,
441+
thread: Thread {
442+
id: ThreadId { id },
443+
name,
444+
},
445+
}
446+
}
447+
448+
/// Data for a running thread.
449+
#[derive(Debug)]
450+
struct ThreadData<'scope, T> {
451+
result: Mutex<Option<std::thread::Result<T>>>,
452+
_marker: PhantomData<Option<&'scope ScopeData>>,
453+
}
454+
455+
impl<'scope, T> ThreadData<'scope, T> {
456+
fn new() -> Self {
457+
Self {
458+
result: Mutex::new(None),
459+
_marker: PhantomData,
460+
}
461+
}
462+
}
463+
464+
impl<'scope, T> JoinHandleInner<'scope, T> {
465+
fn join(self) -> std::thread::Result<T> {
466+
self.notify.wait(location!());
467+
self.data.result.lock().unwrap().take().unwrap()
468+
}
469+
470+
fn thread(&self) -> &Thread {
471+
&self.thread
472+
}
473+
}
474+
475+
#[derive(Default, Debug)]
476+
struct ScopeThreads {
477+
count: usize,
478+
notify_on_finished: Vec<rt::Notify>,
479+
}
480+
481+
#[derive(Debug)]
482+
struct ScopeData {
483+
running_threads: Mutex<ScopeThreads>,
484+
main_thread: Thread,
485+
}
486+
487+
impl ScopeData {
488+
fn add_running_thread(&self) -> rt::Notify {
489+
let mut running = self.running_threads.lock().unwrap();
490+
running.count += 1;
491+
let notify = rt::Notify::new(true, false);
492+
running.notify_on_finished.push(notify);
493+
notify
494+
}
495+
496+
fn remove_running_thread(&self) {
497+
let mut running = self.running_threads.lock().unwrap();
498+
running.count -= 1;
499+
if running.count == 0 {
500+
self.main_thread.unpark()
501+
}
502+
}
503+
}

0 commit comments

Comments
 (0)