@@ -24,6 +24,7 @@ pub fn Pool(comptime Node: type) type {
2424 head : ? * Node = null ,
2525
2626 // Tracks chunks of allocated nodes, used for freeing them at deinit() time.
27+ cleanup_mu : std.Thread.Mutex = .{},
2728 cleanup : std .ArrayListUnmanaged ([* ]Node ) = .{},
2829
2930 // How many nodes to allocate at once for each chunk in the pool.
@@ -72,10 +73,25 @@ pub fn Pool(comptime Node: type) type {
7273 break ; // Pool is empty
7374 }
7475
75- // Pool is empty, allocate new chunk of nodes, and track the pointer for later cleanup
76+ // Pool is empty, we need to allocate new nodes
77+ // This is the rare path where we need a lock to ensure thread safety only for the
78+ // pool.cleanup tracking list.
79+ pool .cleanup_mu .lock ();
80+
81+ // Check the pool again after acquiring the lock
82+ // Another thread might have already allocated nodes while we were waiting
83+ const head2 = @atomicLoad (? * Node , & pool .head , .acquire );
84+ if (head2 ) | _ | {
85+ // Pool is no longer empty, release the lock and try to acquire a node again
86+ pool .cleanup_mu .unlock ();
87+ return pool .acquire (allocator );
88+ }
89+
90+ // Pool still empty, allocate new chunk of nodes, and track the pointer for later cleanup
7691 const new_nodes = try allocator .alloc (Node , pool .chunk_size );
7792 errdefer allocator .free (new_nodes );
7893 try pool .cleanup .append (allocator , @ptrCast (new_nodes .ptr ));
94+ pool .cleanup_mu .unlock ();
7995
8096 // Link all our new nodes (except the first one acquired by the caller) into a chain
8197 // with eachother.
@@ -311,3 +327,43 @@ test "basic" {
311327 try std .testing .expectEqual (queue .pop (), 3 );
312328 try std .testing .expectEqual (queue .pop (), null );
313329}
330+
331+ test "concurrent producers" {
332+ const allocator = std .testing .allocator ;
333+
334+ var queue : Queue (u32 ) = undefined ;
335+ try queue .init (allocator , 32 );
336+ defer queue .deinit (allocator );
337+
338+ const n_jobs = 100 ;
339+ const n_entries : u32 = 10000 ;
340+
341+ var pool : std.Thread.Pool = undefined ;
342+ try std .Thread .Pool .init (& pool , .{ .allocator = allocator , .n_jobs = n_jobs });
343+ defer pool .deinit ();
344+
345+ var wg : std.Thread.WaitGroup = .{};
346+ for (0.. n_jobs ) | _ | {
347+ pool .spawnWg (
348+ & wg ,
349+ struct {
350+ pub fn run (q : * Queue (u32 )) void {
351+ var i : u32 = 0 ;
352+ while (i < n_entries ) : (i += 1 ) {
353+ q .push (allocator , i ) catch unreachable ;
354+ }
355+ }
356+ }.run ,
357+ .{& queue },
358+ );
359+ }
360+
361+ wg .wait ();
362+
363+ // Verify we can read some values without crashing
364+ var count : usize = 0 ;
365+ while (queue .pop ()) | _ | {
366+ count += 1 ;
367+ if (count >= n_jobs * n_entries ) break ;
368+ }
369+ }
0 commit comments