Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 43 additions & 1 deletion src/ssz/type/bit_vector.zig
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,13 @@ pub fn BitVector(comptime _length: comptime_int) type {
}
}

fn maskedFinalByte(self: *const @This()) u8 {
const remainder_bits = length % 8;
std.debug.assert(remainder_bits != 0);
const tail_mask: u8 = (@as(u8, 1) << @intCast(remainder_bits)) - 1;
return self.data[byte_len - 1] & tail_mask;
}

pub fn getTrueBitIndexes(self: *const @This(), out: []usize) !usize {
if (out.len < length) {
return error.InvalidSize;
Expand All @@ -48,6 +55,11 @@ pub fn BitVector(comptime _length: comptime_int) type {

for (0..byte_len) |i_byte| {
var b = self.data[i_byte];
if (length % 8 != 0) {
if (i_byte + 1 == byte_len) {
b = self.maskedFinalByte();
}
}

while (b != 0) {
const lsb: usize = @as(u8, @ctz(b));
Expand All @@ -66,6 +78,11 @@ pub fn BitVector(comptime _length: comptime_int) type {

for (0..byte_len) |i_byte| {
var b = self.data[i_byte];
if (length % 8 != 0) {
if (i_byte + 1 == byte_len) {
b = self.maskedFinalByte();
}
}

while (b != 0) {
if (found_index != null) {
Expand Down Expand Up @@ -132,10 +149,15 @@ pub fn BitVector(comptime _length: comptime_int) type {
allocator: std.mem.Allocator,
values: *const [length]T,
) !std.array_list.AlignedManaged(T, null) {
var indices = try std.array_list.AlignedManaged(T, null).initCapacity(allocator, byte_len * 8);
var indices = try std.array_list.AlignedManaged(T, null).initCapacity(allocator, length);

for (0..byte_len) |i_byte| {
var b = self.data[i_byte];
if (length % 8 != 0) {
if (i_byte + 1 == byte_len) {
b = self.maskedFinalByte();
}
}
// Kernighan's algorithm to count the set bits instead of going through 0..8 for every byte
while (b != 0) {
const lsb: usize = @as(u8, @ctz(b)); // Get the index of least significant bit
Expand Down Expand Up @@ -559,6 +581,26 @@ test "BitVectorType - tree.deserializeFromBytes 128 bits" {
}
}

test "BitVectorType helpers ignore padding bits" {
const allocator = std.testing.allocator;
const Bits = BitVectorType(4);

var value = Bits.Type.empty;
value.data[0] = 0b1111_0001;

var indexes: [Bits.length]usize = undefined;
const count = try value.getTrueBitIndexes(indexes[0..]);
try std.testing.expectEqual(@as(usize, 1), count);
try std.testing.expectEqualSlices(usize, &.{0}, indexes[0..count]);

try std.testing.expectEqual(@as(?usize, 0), value.getSingleTrueBit());

const input_values = [_]u8{ 10, 11, 12, 13 };
var intersected = try value.intersectValues(u8, allocator, &input_values);
defer intersected.deinit();
try std.testing.expectEqualSlices(u8, &.{10}, intersected.items);
}

const TypeTestCase = @import("test_utils.zig").TypeTestCase;
test "BitVectorType of 128 bits" {
const testCases = [_]TypeTestCase{
Expand Down
Loading