@@ -22,6 +22,11 @@ const Signature = blst.Signature;
2222const BlstError = @import ("error.zig" ).BlstError ;
2323const SecretKey = @import ("SecretKey.zig" );
2424
/// Errors surfaced by the pool's work-submission API.
const PoolError = error{
    /// The pool is currently shutting down; no new work is accepted.
    ShuttingDown,
};
29+
/// Hard cap on the number of worker threads a pool may spawn; it bounds the
/// fixed-size `threads` array. The value itself is arbitrary.
pub const MAX_WORKERS: usize = 16;
2732
@@ -39,7 +44,12 @@ pub const Opts = struct {
3944allocator : Allocator ,
4045n_workers : usize ,
4146threads : [MAX_WORKERS ]std.Thread = undefined ,
47+ /// Signals workers to exit after draining the queue. Checked by `workerLoop`
48+ /// only when the queue is empty, so all pending items are processed first.
4249shutdown : std .atomic .Value (bool ) = std .atomic .Value (bool ).init (false ),
50+ /// Signals `pushBatch` to reject new work. Set before `shutdown` so no new
51+ /// items enter the queue while workers are draining it.
52+ shutting_down : std .atomic .Value (bool ) = std .atomic .Value (bool ).init (false ),
4353queue : JobQueue ,
4454
4555/// Thread-safe FIFO work queue. Workers wait on `cond` for new items
@@ -50,10 +60,16 @@ const JobQueue = struct {
5060 head : ? * WorkItem = null ,
5161 tail : ? * WorkItem = null ,
5262
53- fn pushBatch (self : * JobQueue , items : []* WorkItem ) void {
63+ /// Pushes a batch of `WorkItem`s to the `JobQueue`.
64+ ///
65+ /// Returns false, without pushing any work, if the pool has signalled
66+ /// that it is shutting down.
67+ fn pushBatch (self : * JobQueue , pool : * ThreadPool , items : []* WorkItem ) bool {
5468 self .mutex .lock ();
5569 defer self .mutex .unlock ();
5670
71+ if (pool .shutting_down .load (.acquire )) return false ;
72+
5773 for (items ) | item | {
5874 item .next = null ;
5975 if (self .tail ) | tail | {
@@ -64,6 +80,7 @@ const JobQueue = struct {
6480 self .tail = item ;
6581 }
6682 self .cond .broadcast ();
83+ return true ;
6784 }
6885
6986 fn pop (self : * JobQueue ) ? * WorkItem {
@@ -104,31 +121,45 @@ pub fn init(allocator_: Allocator, opts: Opts) (Allocator.Error || std.Thread.Sp
104121}
105122
/// Shuts down the thread pool and frees all resources.
///
/// Teardown proceeds in three phases:
///   1. stop accepting new work,
///   2. signal workers to drain the queue and exit,
///   3. join the workers and release the pool's memory.
///
/// The pool pointer is invalid after this call.
pub fn deinit(pool: *ThreadPool) void {
    {
        pool.queue.mutex.lock();
        defer pool.queue.mutex.unlock();

        // Phase 1: refuse new submissions. Done under the queue mutex so a
        // concurrent `pushBatch` cannot slip work in past this point.
        pool.shutting_down.store(true, .release);

        // Phase 2: tell workers to drain the queue and then exit, and wake
        // any worker currently waiting on the condition variable.
        pool.shutdown.store(true, .release);
        pool.queue.cond.broadcast();
    }

    // Phase 3: wait for every worker to finish draining, then free the pool.
    for (pool.threads[0..pool.n_workers]) |t| {
        t.join();
    }
    pool.allocator.destroy(pool);
}
120145
146+ /// Main loop for worker threads.
147+ ///
148+ /// Pops work before checking the `shutdown` signal, so workers drain
149+ /// all queued items before exiting.
150+ ///
151+ /// Safety: it is safe to pop work first since we stop accepting work
152+ /// in `pushBatch` by checking for the `shutting_down` signal; no new
153+ /// work can be accepted at the point of entry into this loop.
121154fn workerLoop (pool : * ThreadPool ) void {
122155 while (true ) {
123156 const item : * WorkItem = blk : {
124157 pool .queue .mutex .lock ();
125158 defer pool .queue .mutex .unlock ();
126159
127160 while (true ) {
161+ if (pool .queue .pop ()) | wi | break :blk wi ;
128162 if (pool .shutdown .load (.acquire )) return ;
129- if (pool .queue .pop ()) | wi | {
130- break :blk wi ;
131- }
132163 pool .queue .cond .wait (& pool .queue .mutex );
133164 }
134165 };
@@ -139,8 +170,8 @@ fn workerLoop(pool: *ThreadPool) void {
139170}
140171
141172/// Submit work items to the pool and wait for all to complete.
142- fn submitAndWait (pool : * ThreadPool , items : []* WorkItem ) void {
143- pool .queue .pushBatch (items );
173+ fn submitAndWait (pool : * ThreadPool , items : []* WorkItem ) PoolError ! void {
174+ if ( ! pool .queue .pushBatch (pool , items )) return PoolError . ShuttingDown ;
144175 for (items ) | item | {
145176 item .done .wait ();
146177 }
@@ -220,7 +251,7 @@ pub fn verifyMultipleAggregateSignatures(
220251 sigs : []const * Signature ,
221252 sigs_groupcheck : bool ,
222253 rands : []const [32 ]u8 ,
223- ) BlstError ! bool {
254+ ) ( BlstError || PoolError ) ! bool {
224255 if (n_elems == 0 or
225256 pks .len != n_elems or
226257 sigs .len != n_elems or
@@ -274,7 +305,7 @@ pub fn verifyMultipleAggregateSignatures(
274305 item_ptrs [i ] = & work_items [i ].base ;
275306 }
276307
277- pool .submitAndWait (item_ptrs [0.. n_active ]);
308+ try pool .submitAndWait (item_ptrs [0.. n_active ]);
278309
279310 if (job .err_flag .load (.acquire )) return BlstError .VerifyFail ;
280311
@@ -349,7 +380,7 @@ pub fn aggregateVerify(
349380 dst : []const u8 ,
350381 pks : []const * PublicKey ,
351382 pks_validate : bool ,
352- ) BlstError ! bool {
383+ ) ( BlstError || PoolError ) ! bool {
353384 const n_elems = pks .len ;
354385 if (n_elems == 0 or msgs .len != n_elems ) return BlstError .VerifyFail ;
355386
@@ -395,7 +426,7 @@ pub fn aggregateVerify(
395426 item_ptrs [i ] = & work_items [i ].base ;
396427 }
397428
398- pool .submitAndWait (item_ptrs [0.. n_active ]);
429+ try pool .submitAndWait (item_ptrs [0.. n_active ]);
399430
400431 if (job .err_flag .load (.acquire )) return false ;
401432
0 commit comments