@@ -14,11 +14,7 @@ use std::{fmt, io};
14
14
use tracing:: trace;
15
15
16
16
/// Mock implementation of `std::thread::JoinHandle`.
17
- pub struct JoinHandle < T > {
18
- result : Arc < Mutex < Option < std:: thread:: Result < T > > > > ,
19
- notify : rt:: Notify ,
20
- thread : Thread ,
21
- }
17
+ pub struct JoinHandle < T > ( JoinHandleInner < ' static , T > ) ;
22
18
23
19
/// Mock implementation of `std::thread::Thread`.
24
20
#[ derive( Clone , Debug ) ]
@@ -128,7 +124,7 @@ where
128
124
F : ' static ,
129
125
T : ' static ,
130
126
{
131
- spawn_internal ( f, None , location ! ( ) )
127
+ JoinHandle ( spawn_internal_static ( f, None , location ! ( ) ) )
132
128
}
133
129
134
130
/// Mock implementation of `std::thread::park`.
@@ -142,38 +138,6 @@ pub fn park() {
142
138
rt:: park ( location ! ( ) ) ;
143
139
}
144
140
145
- fn spawn_internal < F , T > ( f : F , name : Option < String > , location : Location ) -> JoinHandle < T >
146
- where
147
- F : FnOnce ( ) -> T ,
148
- F : ' static ,
149
- T : ' static ,
150
- {
151
- let result = Arc :: new ( Mutex :: new ( None ) ) ;
152
- let notify = rt:: Notify :: new ( true , false ) ;
153
-
154
- let id = {
155
- let name = name. clone ( ) ;
156
- let result = result. clone ( ) ;
157
- rt:: spawn ( move || {
158
- rt:: execution ( |execution| {
159
- init_current ( execution, name) ;
160
- } ) ;
161
-
162
- * result. lock ( ) . unwrap ( ) = Some ( Ok ( f ( ) ) ) ;
163
- notify. notify ( location) ;
164
- } )
165
- } ;
166
-
167
- JoinHandle {
168
- result,
169
- notify,
170
- thread : Thread {
171
- id : ThreadId { id } ,
172
- name,
173
- } ,
174
- }
175
- }
176
-
177
141
impl Builder {
178
142
/// Generates the base configuration for spawning a thread, from which
179
143
/// configuration methods can be chained.
@@ -206,21 +170,40 @@ impl Builder {
206
170
F : Send + ' static ,
207
171
T : Send + ' static ,
208
172
{
209
- Ok ( spawn_internal ( f, self . name , location ! ( ) ) )
173
+ Ok ( JoinHandle ( spawn_internal_static ( f, self . name , location ! ( ) ) ) )
174
+ }
175
+ }
176
+
177
+ impl Builder {
178
+ /// Spawns a new scoped thread using the settings set through this `Builder`.
179
+ pub fn spawn_scoped < ' scope , ' env , F , T > (
180
+ self ,
181
+ scope : & ' scope Scope < ' scope , ' env > ,
182
+ f : F ,
183
+ ) -> io:: Result < ScopedJoinHandle < ' scope , T > >
184
+ where
185
+ F : FnOnce ( ) -> T + Send + ' scope ,
186
+ T : Send + ' scope ,
187
+ {
188
+ Ok ( ScopedJoinHandle (
189
+ // Safety: the call to this function requires a `&'scope Scope`
190
+ // which can only be constructed by `scope()`, which ensures that
191
+ // all spawned threads are joined before the `Scope` is destroyed.
192
+ unsafe { spawn_internal ( f, self . name , Some ( scope. data . clone ( ) ) , location ! ( ) ) } ,
193
+ ) )
210
194
}
211
195
}
212
196
213
197
impl < T > JoinHandle < T > {
214
198
/// Waits for the associated thread to finish.
215
199
#[ track_caller]
216
200
pub fn join ( self ) -> std:: thread:: Result < T > {
217
- self . notify . wait ( location ! ( ) ) ;
218
- self . result . lock ( ) . unwrap ( ) . take ( ) . unwrap ( )
201
+ self . 0 . join ( )
219
202
}
220
203
221
204
/// Gets a handle to the underlying [`Thread`]
222
205
pub fn thread ( & self ) -> & Thread {
223
- & self . thread
206
+ self . 0 . thread ( )
224
207
}
225
208
}
226
209
@@ -301,3 +284,220 @@ impl<T: 'static> fmt::Debug for LocalKey<T> {
301
284
f. pad ( "LocalKey { .. }" )
302
285
}
303
286
}
287
+
288
+ /// A scope for spawning scoped threads.
289
+ ///
290
+ /// See [`scope`] for more details.
291
+ #[ derive( Debug ) ]
292
+ pub struct Scope < ' scope , ' env : ' scope > {
293
+ data : Arc < ScopeData > ,
294
+ scope : PhantomData < & ' scope mut & ' scope ( ) > ,
295
+ env : PhantomData < & ' env mut & ' env ( ) > ,
296
+ }
297
+
298
+ /// An owned permission to join on a scoped thread (block on its termination).
299
+ ///
300
+ /// See [`Scope::spawn`] for details.
301
+ #[ derive( Debug ) ]
302
+ pub struct ScopedJoinHandle < ' scope , T > ( JoinHandleInner < ' scope , T > ) ;
303
+
304
+ /// Create a scope for spawning scoped threads.
305
+ ///
306
+ /// Mock implementation of [`std::thread::scope`].
307
+ #[ track_caller]
308
+ pub fn scope < ' env , F , T > ( f : F ) -> T
309
+ where
310
+ F : for < ' scope > FnOnce ( & ' scope Scope < ' scope , ' env > ) -> T ,
311
+ {
312
+ let scope = Scope {
313
+ data : Arc :: new ( ScopeData {
314
+ running_threads : Mutex :: default ( ) ,
315
+ main_thread : current ( ) ,
316
+ } ) ,
317
+ env : PhantomData ,
318
+ scope : PhantomData ,
319
+ } ;
320
+
321
+ // Run `f`, but catch panics so we can make sure to wait for all the threads to join.
322
+ let result = std:: panic:: catch_unwind ( std:: panic:: AssertUnwindSafe ( || f ( & scope) ) ) ;
323
+
324
+ // Wait until all the threads are finished. This is required to fulfill
325
+ // the safety requirements of `spawn_internal`.
326
+ let running = loop {
327
+ {
328
+ let running = scope. data . running_threads . lock ( ) . unwrap ( ) ;
329
+ if running. count == 0 {
330
+ break running;
331
+ }
332
+ }
333
+ park ( ) ;
334
+ } ;
335
+
336
+ for notify in & running. notify_on_finished {
337
+ notify. wait ( location ! ( ) )
338
+ }
339
+
340
+ // Throw any panic from `f`, or the return value of `f` if no thread panicked.
341
+ match result {
342
+ Err ( e) => std:: panic:: resume_unwind ( e) ,
343
+ Ok ( result) => result,
344
+ }
345
+ }
346
+
347
+ impl < ' scope , ' env > Scope < ' scope , ' env > {
348
+ /// Spawns a new thread within a scope, returning a [`ScopedJoinHandle`] for it.
349
+ ///
350
+ /// See [`std::thread::Scope`] and [`std::thread::scope`] for details.
351
+ pub fn spawn < F , T > ( & ' scope self , f : F ) -> ScopedJoinHandle < ' scope , T >
352
+ where
353
+ F : FnOnce ( ) -> T + Send + ' scope ,
354
+ T : Send + ' scope ,
355
+ {
356
+ Builder :: new ( )
357
+ . spawn_scoped ( self , f)
358
+ . expect ( "failed to spawn thread" )
359
+ }
360
+ }
361
+
362
+ impl < ' scope , T > ScopedJoinHandle < ' scope , T > {
363
+ /// Extracts a handle to the underlying thread.
364
+ pub fn thread ( & self ) -> & Thread {
365
+ self . 0 . thread ( )
366
+ }
367
+
368
+ /// Waits for the associated thread to finish.
369
+ pub fn join ( self ) -> std:: thread:: Result < T > {
370
+ self . 0 . join ( )
371
+ }
372
+ }
373
+
374
+ /// Handle for joining on a thread with a scope.
375
+ #[ derive( Debug ) ]
376
+ struct JoinHandleInner < ' scope , T > {
377
+ data : Arc < ThreadData < ' scope , T > > ,
378
+ notify : rt:: Notify ,
379
+ thread : Thread ,
380
+ }
381
+
382
+ /// Spawns a thread without a local scope.
383
+ fn spawn_internal_static < F , T > (
384
+ f : F ,
385
+ name : Option < String > ,
386
+ location : Location ,
387
+ ) -> JoinHandleInner < ' static , T >
388
+ where
389
+ F : FnOnce ( ) -> T ,
390
+ F : ' static ,
391
+ T : ' static ,
392
+ {
393
+ // Safety: the requirements of `spawn_internal` are trivially satisfied
394
+ // since there is no `scope`.
395
+ unsafe { spawn_internal ( f, name, None , location) }
396
+ }
397
+
398
+ /// Spawns a thread with an optional scope.
399
+ ///
400
+ /// The caller must ensure that if `scope` is not None, the provided closure
401
+ /// finishes before `'scope` ends.
402
+ unsafe fn spawn_internal < ' scope , F , T > (
403
+ f : F ,
404
+ name : Option < String > ,
405
+ scope : Option < Arc < ScopeData > > ,
406
+ location : Location ,
407
+ ) -> JoinHandleInner < ' scope , T >
408
+ where
409
+ F : FnOnce ( ) -> T ,
410
+ F : ' scope ,
411
+ T : ' scope ,
412
+ {
413
+ let scope_notify = scope
414
+ . clone ( )
415
+ . map ( |scope| ( scope. add_running_thread ( ) , scope) ) ;
416
+ let thread_data = Arc :: new ( ThreadData :: new ( ) ) ;
417
+ let notify = rt:: Notify :: new ( true , false ) ;
418
+
419
+ let id = {
420
+ let name = name. clone ( ) ;
421
+ let thread_data = thread_data. clone ( ) ;
422
+ let body: Box < dyn FnOnce ( ) + ' scope > = Box :: new ( move || {
423
+ rt:: execution ( |execution| {
424
+ init_current ( execution, name) ;
425
+ } ) ;
426
+
427
+ * thread_data. result . lock ( ) . unwrap ( ) = Some ( Ok ( f ( ) ) ) ;
428
+ notify. notify ( location) ;
429
+
430
+ if let Some ( ( notifier, scope) ) = scope_notify {
431
+ notifier. notify ( location ! ( ) ) ;
432
+ scope. remove_running_thread ( )
433
+ }
434
+ } ) ;
435
+ rt:: spawn ( std:: mem:: transmute :: < _ , Box < dyn FnOnce ( ) > > ( body) )
436
+ } ;
437
+
438
+ JoinHandleInner {
439
+ data : thread_data,
440
+ notify,
441
+ thread : Thread {
442
+ id : ThreadId { id } ,
443
+ name,
444
+ } ,
445
+ }
446
+ }
447
+
448
+ /// Data for a running thread.
449
+ #[ derive( Debug ) ]
450
+ struct ThreadData < ' scope , T > {
451
+ result : Mutex < Option < std:: thread:: Result < T > > > ,
452
+ _marker : PhantomData < Option < & ' scope ScopeData > > ,
453
+ }
454
+
455
+ impl < ' scope , T > ThreadData < ' scope , T > {
456
+ fn new ( ) -> Self {
457
+ Self {
458
+ result : Mutex :: new ( None ) ,
459
+ _marker : PhantomData ,
460
+ }
461
+ }
462
+ }
463
+
464
+ impl < ' scope , T > JoinHandleInner < ' scope , T > {
465
+ fn join ( self ) -> std:: thread:: Result < T > {
466
+ self . notify . wait ( location ! ( ) ) ;
467
+ self . data . result . lock ( ) . unwrap ( ) . take ( ) . unwrap ( )
468
+ }
469
+
470
+ fn thread ( & self ) -> & Thread {
471
+ & self . thread
472
+ }
473
+ }
474
+
475
+ #[ derive( Default , Debug ) ]
476
+ struct ScopeThreads {
477
+ count : usize ,
478
+ notify_on_finished : Vec < rt:: Notify > ,
479
+ }
480
+
481
+ #[ derive( Debug ) ]
482
+ struct ScopeData {
483
+ running_threads : Mutex < ScopeThreads > ,
484
+ main_thread : Thread ,
485
+ }
486
+
487
+ impl ScopeData {
488
+ fn add_running_thread ( & self ) -> rt:: Notify {
489
+ let mut running = self . running_threads . lock ( ) . unwrap ( ) ;
490
+ running. count += 1 ;
491
+ let notify = rt:: Notify :: new ( true , false ) ;
492
+ running. notify_on_finished . push ( notify) ;
493
+ notify
494
+ }
495
+
496
+ fn remove_running_thread ( & self ) {
497
+ let mut running = self . running_threads . lock ( ) . unwrap ( ) ;
498
+ running. count -= 1 ;
499
+ if running. count == 0 {
500
+ self . main_thread . unpark ( )
501
+ }
502
+ }
503
+ }
0 commit comments