Skip to content
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 33 additions & 4 deletions src/ssz/type/bit_vector.zig
Original file line number Diff line number Diff line change
Expand Up @@ -40,14 +40,23 @@ pub fn BitVector(comptime _length: comptime_int) type {
}
}

fn maskedByte(self: *const @This(), i_byte: usize) u8 {
const remainder_bits = length % 8;
if (remainder_bits == 0 or i_byte + 1 < byte_len) {
return self.data[i_byte];
}
const tail_mask: u8 = (@as(u8, 1) << @intCast(remainder_bits)) - 1;
return self.data[i_byte] & tail_mask;
}
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The maskedByte function should be refactored to adhere to the repository style guide:

  1. Control Flow (Rule 133): Split the compound condition into nested if/else branches. This improves readability and ensures that the logic for handling the last byte is explicitly separated from the common case.
  2. Assertions (Rule 51): Add a precondition assertion to verify that i_byte is within bounds. This helps catch programmer errors early and documents the function's contract.
  3. Pair Assertions (Rule 58): Add an assertion to verify that the resulting byte has no bits set beyond the logical length. This provides defense-in-depth and documents the expected state of the data.
  4. Assertion Style (Rule 68): Use single-line if for assertions of implications where appropriate.
        fn maskedByte(self: *const @This(), i_byte: usize) u8 {
            std.debug.assert(i_byte < byte_len);

            const remainder_bits = length % 8;
            if (remainder_bits == 0) {
                return self.data[i_byte];
            } else {
                if (i_byte + 1 < byte_len) {
                    return self.data[i_byte];
                }
            }

            const tail_mask: u8 = (@as(u8, 1) << @intCast(remainder_bits)) - 1;
            const byte = self.data[i_byte] & tail_mask;

            // Pair assertion: ensure no bits are set beyond the logical length.
            if (remainder_bits > 0) std.debug.assert(@clz(byte) >= 8 - remainder_bits);

            return byte;
        }
References
  1. Split compound conditions into simple conditions using nested if/else branches. (link)
  2. Assert all function arguments and return values, pre/postconditions and invariants. (link)


pub fn getTrueBitIndexes(self: *const @This(), out: []usize) !usize {
if (out.len < length) {
return error.InvalidSize;
}
var true_bit_count: usize = 0;

for (0..byte_len) |i_byte| {
var b = self.data[i_byte];
var b = self.maskedByte(i_byte);

while (b != 0) {
const lsb: usize = @as(u8, @ctz(b));
Expand All @@ -65,7 +74,7 @@ pub fn BitVector(comptime _length: comptime_int) type {
var found_index: ?usize = null;

for (0..byte_len) |i_byte| {
var b = self.data[i_byte];
var b = self.maskedByte(i_byte);

while (b != 0) {
if (found_index != null) {
Expand Down Expand Up @@ -132,10 +141,10 @@ pub fn BitVector(comptime _length: comptime_int) type {
allocator: std.mem.Allocator,
values: *const [length]T,
) !std.array_list.AlignedManaged(T, null) {
var indices = try std.array_list.AlignedManaged(T, null).initCapacity(allocator, byte_len * 8);
var indices = try std.array_list.AlignedManaged(T, null).initCapacity(allocator, length);

for (0..byte_len) |i_byte| {
var b = self.data[i_byte];
var b = self.maskedByte(i_byte);
// Kernighan's algorithm to count the set bits instead of going through 0..8 for every byte
while (b != 0) {
const lsb: usize = @as(u8, @ctz(b)); // Get the index of least significant bit
Expand Down Expand Up @@ -559,6 +568,26 @@ test "BitVectorType - tree.deserializeFromBytes 128 bits" {
}
}

test "BitVectorType helpers ignore padding bits" {
const allocator = std.testing.allocator;
const Bits = BitVectorType(4);

var value = Bits.Type.empty;
value.data[0] = 0b1111_0001;

var indexes: [Bits.length]usize = undefined;
const count = try value.getTrueBitIndexes(indexes[0..]);
try std.testing.expectEqual(@as(usize, 1), count);
try std.testing.expectEqualSlices(usize, &.{0}, indexes[0..count]);

try std.testing.expectEqual(@as(?usize, 0), value.getSingleTrueBit());

const input_values = [_]u8{ 10, 11, 12, 13 };
var intersected = try value.intersectValues(u8, allocator, &input_values);
defer intersected.deinit();
try std.testing.expectEqualSlices(u8, &.{10}, intersected.items);
}

const TypeTestCase = @import("test_utils.zig").TypeTestCase;
test "BitVectorType of 128 bits" {
const testCases = [_]TypeTestCase{
Expand Down
Loading