Skip to content

Commit 8e0298a

Browse files
committed
fix(persistent_merkle_tree): address batchGetRoot review comments
1 parent: 3ab5c94; commit: 8e0298a

2 files changed

Lines changed: 64 additions & 44 deletions

File tree

src/persistent_merkle_tree/Node.zig

Lines changed: 10 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -419,6 +419,8 @@ pub const Id = enum(u32) {
419419
/// Collects dirty branches via BFS, then hashes them bottom-up in batches
420420
/// using SIMD-accelerated `hashBatch`. Produces identical results to `getRoot`
421421
/// but is ~2x faster on large trees with many dirty nodes.
422+
///
423+
/// `depth` must be the full height of `node_id`'s tree.
422424
pub fn batchGetRoot(node_id: Id, pool: *Pool, depth: Depth, allocator: Allocator) Error!*const [32]u8 {
423425
const states = pool.nodes.items(.state);
424426
const hashes = pool.nodes.items(.hash);
@@ -427,21 +429,15 @@ pub const Id = enum(u32) {
427429
if (!states[@intFromEnum(node_id)].isBranchLazy()) {
428430
return &hashes[@intFromEnum(node_id)];
429431
}
430-
431-
// Fast path: depth 0 means the node is a leaf, just return its hash.
432-
if (depth == 0) {
433-
return &hashes[@intFromEnum(node_id)];
434-
}
432+
if (depth == 0) return Error.InvalidNode;
435433

436434
const lefts = pool.nodes.items(.left);
437435
const rights = pool.nodes.items(.right);
438436

439437
// Collect dirty branch nodes per level via BFS.
440438
// levels[i] holds dirty node IDs at depth i (0 = root level).
441439
var levels: [max_depth]std.ArrayListUnmanaged(Id) = undefined;
442-
for (0..depth) |i| {
443-
levels[i] = .empty;
444-
}
440+
@memset(levels[0..depth], .empty);
445441
defer for (0..depth) |i| {
446442
levels[i].deinit(allocator);
447443
};
@@ -466,16 +462,15 @@ pub const Id = enum(u32) {
466462
// Bottom-up batch hash: process from deepest level to root.
467463
var pairs = std.ArrayListUnmanaged([32]u8).empty;
468464
defer pairs.deinit(allocator);
465+
469466
var outs = std.ArrayListUnmanaged([32]u8).empty;
470467
defer outs.deinit(allocator);
471468

472-
var level_i: usize = depth - 1;
473-
while (true) : (level_i -= 1) {
469+
var level_i: usize = depth;
470+
while (level_i > 0) {
471+
level_i -= 1;
474472
const dirty_nodes = levels[level_i].items;
475-
if (dirty_nodes.len == 0) {
476-
if (level_i == 0) break;
477-
continue;
478-
}
473+
if (dirty_nodes.len == 0) continue;
479474

480475
// Build input pairs: [left_hash, right_hash] for each dirty node.
481476
pairs.clearRetainingCapacity();
@@ -487,17 +482,14 @@ pub const Id = enum(u32) {
487482

488483
// Batch hash
489484
outs.clearRetainingCapacity();
490-
try outs.ensureTotalCapacity(allocator, dirty_nodes.len);
491-
outs.items.len = dirty_nodes.len;
485+
try outs.resize(allocator, dirty_nodes.len);
492486
hashBatch(outs.items, pairs.items) catch unreachable;
493487

494488
// Write results back
495489
for (dirty_nodes, 0..) |id, i| {
496490
hashes[@intFromEnum(id)] = outs.items[i];
497491
states[@intFromEnum(id)].setBranchComputed();
498492
}
499-
500-
if (level_i == 0) break;
501493
}
502494

503495
return &hashes[@intFromEnum(node_id)];

src/persistent_merkle_tree/node_test.zig

Lines changed: 54 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -640,26 +640,38 @@ test "batchGetRoot matches getRoot for dirty branches at various depths" {
640640
const depth: Depth = @intCast(depth_usize);
641641
const max_length = @as(usize, 1) << @intCast(depth);
642642

643-
// Build a full tree, then modify some leaves to create dirty branches.
644643
var leaves = try allocator.alloc(Node.Id, max_length);
645644
defer allocator.free(leaves);
646645
for (0..max_length) |i| {
647646
leaves[i] = try pool.createLeafFromUint(@intCast(i + 1));
647+
try pool.ref(leaves[i]);
648648
}
649+
defer for (leaves) |leaf| pool.unref(leaf);
649650

650-
const root = try Node.fillWithContents(p, leaves, depth);
651-
defer pool.unref(root);
651+
// Build two identical dirty trees. `getRoot` mutates in-place, so the
652+
// batch path must run on a separate tree to exercise dirty branches.
653+
const leaves_for_get_root = try allocator.dupe(Node.Id, leaves);
654+
defer allocator.free(leaves_for_get_root);
655+
656+
const leaves_for_batch = try allocator.dupe(Node.Id, leaves);
657+
defer allocator.free(leaves_for_batch);
658+
659+
const root_get = try Node.fillWithContents(p, leaves_for_get_root, depth);
660+
defer pool.unref(root_get);
661+
662+
const root_batch = try Node.fillWithContents(p, leaves_for_batch, depth);
663+
defer pool.unref(root_batch);
664+
665+
const new_leaf_get = try pool.createLeafFromUint(0xBEEF);
666+
const dirty_get = try root_get.setNodeAtDepth(p, depth, 0, new_leaf_get);
667+
defer pool.unref(dirty_get);
652668

653-
// Modify a few leaves to make branches dirty.
654-
const new_leaf = try pool.createLeafFromUint(0xBEEF);
655-
const dirty_root = try root.setNodeAtDepth(p, depth, 0, new_leaf);
656-
defer pool.unref(dirty_root);
669+
const new_leaf_batch = try pool.createLeafFromUint(0xBEEF);
670+
const dirty_batch = try root_batch.setNodeAtDepth(p, depth, 0, new_leaf_batch);
671+
defer pool.unref(dirty_batch);
657672

658-
// Clone the tree for getRoot (since getRoot mutates state in-place).
659-
// Both should produce the same hash.
660-
const expected = dirty_root.getRoot(p);
661-
// getRoot already computed the hashes, but batchGetRoot should return the same result.
662-
const actual = try dirty_root.batchGetRoot(p, depth, allocator);
673+
const expected = dirty_get.getRoot(p);
674+
const actual = try dirty_batch.batchGetRoot(p, depth, allocator);
663675
try std.testing.expectEqualSlices(u8, expected, actual);
664676
}
665677
}
@@ -673,31 +685,47 @@ test "batchGetRoot matches getRoot with multiple dirty leaves" {
673685
const depth: Depth = 4;
674686
const max_length = @as(usize, 1) << depth;
675687

676-
// Build a full tree with all leaves set.
677688
var leaves = try allocator.alloc(Node.Id, max_length);
678689
defer allocator.free(leaves);
679690
for (0..max_length) |i| {
680691
leaves[i] = try pool.createLeafFromUint(@intCast(i + 1));
692+
try pool.ref(leaves[i]);
681693
}
694+
defer for (leaves) |leaf| pool.unref(leaf);
695+
696+
// Build two identical dirty trees. `getRoot` mutates in-place, so the
697+
// batch path must run on a separate tree to exercise dirty branches.
698+
const leaves_for_get_root = try allocator.dupe(Node.Id, leaves);
699+
defer allocator.free(leaves_for_get_root);
700+
701+
const leaves_for_batch = try allocator.dupe(Node.Id, leaves);
702+
defer allocator.free(leaves_for_batch);
703+
704+
const root_get = try Node.fillWithContents(p, leaves_for_get_root, depth);
705+
defer pool.unref(root_get);
706+
var modified_get = root_get;
682707

683-
const root = try Node.fillWithContents(p, leaves, depth);
708+
const root_batch = try Node.fillWithContents(p, leaves_for_batch, depth);
709+
defer pool.unref(root_batch);
710+
var modified_batch = root_batch;
684711

685-
// Modify multiple scattered leaves.
686-
var modified = root;
687712
const modify_indices = [_]usize{ 0, 3, 7, 12, 15 };
688713
for (modify_indices) |idx| {
689-
const new_leaf = try pool.createLeafFromUint(@intCast(0xF000 + idx));
690-
const old = modified;
691-
modified = try modified.setNodeAtDepth(p, depth, idx, new_leaf);
692-
if (old != root) pool.unref(old);
714+
const new_leaf_get = try pool.createLeafFromUint(@intCast(0xF000 + idx));
715+
const old_get = modified_get;
716+
modified_get = try modified_get.setNodeAtDepth(p, depth, idx, new_leaf_get);
717+
if (old_get != root_get) pool.unref(old_get);
718+
719+
const new_leaf_batch = try pool.createLeafFromUint(@intCast(0xF000 + idx));
720+
const old_batch = modified_batch;
721+
modified_batch = try modified_batch.setNodeAtDepth(p, depth, idx, new_leaf_batch);
722+
if (old_batch != root_batch) pool.unref(old_batch);
693723
}
694-
defer pool.unref(modified);
695-
pool.unref(root);
724+
defer pool.unref(modified_get);
725+
defer pool.unref(modified_batch);
696726

697-
// Use getRoot on a copy-by-reference to get expected hash, then verify batchGetRoot matches.
698-
// Since getRoot computes in-place, calling it first then batchGetRoot should still agree.
699-
const expected = modified.getRoot(p);
700-
const actual = try modified.batchGetRoot(p, depth, allocator);
727+
const expected = modified_get.getRoot(p);
728+
const actual = try modified_batch.batchGetRoot(p, depth, allocator);
701729
try std.testing.expectEqualSlices(u8, expected, actual);
702730
}
703731

0 commit comments

Comments
 (0)