Skip to content

Commit 47a54d1

Browse files
committed
fix: graceful shutdown
@nazarhussain correctly pointed out that the loose `deinit` logic could result in memory leaks: added a new `shutting_down` bool to denote that the workers are no longer accepting work in `submitAndWait`. We might also consider Nazar's suggestion of adding an `active_jobs` counter, but reordering the logic of `workerLoop` to pop first and then check for shutdown allows the final `join` call to finish pending work before shutting down. Safety: it is safe to pop work first since we stop accepting work in `pushBatch` by checking the `shutting_down` signal; no new work can be accepted at the point of entry into this loop.
1 parent 69a695e commit 47a54d1

File tree

1 file changed

+49
-18
lines changed

1 file changed

+49
-18
lines changed

src/bls/ThreadPool.zig

Lines changed: 49 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -22,6 +22,11 @@ const Signature = blst.Signature;
2222
const BlstError = @import("error.zig").BlstError;
2323
const SecretKey = @import("SecretKey.zig");
2424

25+
const PoolError = error{
26+
/// Pool is currently shutting down.
27+
ShuttingDown,
28+
};
29+
2530
/// This is pretty arbitrary
2631
pub const MAX_WORKERS: usize = 16;
2732

@@ -39,7 +44,12 @@ pub const Opts = struct {
3944
allocator: Allocator,
4045
n_workers: usize,
4146
threads: [MAX_WORKERS]std.Thread = undefined,
47+
/// Signals workers to exit after draining the queue. Checked by `workerLoop`
48+
/// only when the queue is empty, so all pending items are processed first.
4249
shutdown: std.atomic.Value(bool) = std.atomic.Value(bool).init(false),
50+
/// Signals `pushBatch` to reject new work. Set before `shutdown` so no new
51+
/// items enter the queue while workers are draining it.
52+
shutting_down: std.atomic.Value(bool) = std.atomic.Value(bool).init(false),
4353
queue: JobQueue,
4454

4555
/// Thread-safe FIFO work queue. Workers wait on `cond` for new items
@@ -50,10 +60,16 @@ const JobQueue = struct {
5060
head: ?*WorkItem = null,
5161
tail: ?*WorkItem = null,
5262

53-
fn pushBatch(self: *JobQueue, items: []*WorkItem) void {
63+
/// Pushes a batch of `WorkItem`s to the `JobQueue`.
64+
///
65+
/// Returns false if the pool has signalled that it is shutting down and does
66+
/// not push any work.
67+
fn pushBatch(self: *JobQueue, pool: *ThreadPool, items: []*WorkItem) bool {
5468
self.mutex.lock();
5569
defer self.mutex.unlock();
5670

71+
if (pool.shutting_down.load(.acquire)) return false;
72+
5773
for (items) |item| {
5874
item.next = null;
5975
if (self.tail) |tail| {
@@ -64,6 +80,7 @@ const JobQueue = struct {
6480
self.tail = item;
6581
}
6682
self.cond.broadcast();
83+
return true;
6784
}
6885

6986
fn pop(self: *JobQueue) ?*WorkItem {
@@ -104,31 +121,45 @@ pub fn init(allocator_: Allocator, opts: Opts) (Allocator.Error || std.Thread.Sp
104121
}
105122

106123
/// Shuts down the thread pool and frees resources.
124+
///
125+
/// Cleanup happens in 3 phases:
126+
/// 1) stop accepting new work,
127+
/// 2) finish existing work,
128+
/// 3) clean up resources.
129+
///
107130
/// The pool pointer is invalid after this call.
108131
pub fn deinit(pool: *ThreadPool) void {
132+
// Phase 1: stop accepting new work
133+
pool.queue.mutex.lock();
134+
pool.shutting_down.store(true, .release);
135+
136+
// Phase 2: tell workers to drain queue then exit
109137
pool.shutdown.store(true, .release);
110-
{
111-
pool.queue.mutex.lock();
112-
defer pool.queue.mutex.unlock();
113-
pool.queue.cond.broadcast();
114-
}
115-
for (pool.threads[0..pool.n_workers]) |t| {
116-
t.join();
117-
}
138+
pool.queue.cond.broadcast();
139+
pool.queue.mutex.unlock();
140+
141+
// Phase 3: wait for workers to finish draining and exit
142+
for (pool.threads[0..pool.n_workers]) |t| t.join();
118143
pool.allocator.destroy(pool);
119144
}
120145

146+
/// Main loop for worker threads.
147+
///
148+
/// Pops work first before checking for `shutdown` signal, allowing
149+
/// workers to finish their work before closing.
150+
///
151+
/// Safety: it is safe to pop work first since we stop accepting work
152+
/// in `pushBatch` by checking for the `shutting_down` signal; no new
153+
/// work can be accepted at the point of entry into this loop.
121154
fn workerLoop(pool: *ThreadPool) void {
122155
while (true) {
123156
const item: *WorkItem = blk: {
124157
pool.queue.mutex.lock();
125158
defer pool.queue.mutex.unlock();
126159

127160
while (true) {
161+
if (pool.queue.pop()) |wi| break :blk wi;
128162
if (pool.shutdown.load(.acquire)) return;
129-
if (pool.queue.pop()) |wi| {
130-
break :blk wi;
131-
}
132163
pool.queue.cond.wait(&pool.queue.mutex);
133164
}
134165
};
@@ -139,8 +170,8 @@ fn workerLoop(pool: *ThreadPool) void {
139170
}
140171

141172
/// Submit work items to the pool and wait for all to complete.
142-
fn submitAndWait(pool: *ThreadPool, items: []*WorkItem) void {
143-
pool.queue.pushBatch(items);
173+
fn submitAndWait(pool: *ThreadPool, items: []*WorkItem) PoolError!void {
174+
if (!pool.queue.pushBatch(pool, items)) return PoolError.ShuttingDown;
144175
for (items) |item| {
145176
item.done.wait();
146177
}
@@ -220,7 +251,7 @@ pub fn verifyMultipleAggregateSignatures(
220251
sigs: []const *Signature,
221252
sigs_groupcheck: bool,
222253
rands: []const [32]u8,
223-
) BlstError!bool {
254+
) (BlstError || PoolError)!bool {
224255
if (n_elems == 0 or
225256
pks.len != n_elems or
226257
sigs.len != n_elems or
@@ -274,7 +305,7 @@ pub fn verifyMultipleAggregateSignatures(
274305
item_ptrs[i] = &work_items[i].base;
275306
}
276307

277-
pool.submitAndWait(item_ptrs[0..n_active]);
308+
try pool.submitAndWait(item_ptrs[0..n_active]);
278309

279310
if (job.err_flag.load(.acquire)) return BlstError.VerifyFail;
280311

@@ -349,7 +380,7 @@ pub fn aggregateVerify(
349380
dst: []const u8,
350381
pks: []const *PublicKey,
351382
pks_validate: bool,
352-
) BlstError!bool {
383+
) (BlstError || PoolError)!bool {
353384
const n_elems = pks.len;
354385
if (n_elems == 0 or msgs.len != n_elems) return BlstError.VerifyFail;
355386

@@ -395,7 +426,7 @@ pub fn aggregateVerify(
395426
item_ptrs[i] = &work_items[i].base;
396427
}
397428

398-
pool.submitAndWait(item_ptrs[0..n_active]);
429+
try pool.submitAndWait(item_ptrs[0..n_active]);
399430

400431
if (job.err_flag.load(.acquire)) return false;
401432

0 commit comments

Comments
 (0)