Skip to content

Commit 9844b27

Browse files
committed
falcon: generalize the SIMD around pubkey parsing
1 parent edf7089 commit 9844b27

File tree

1 file changed

+20
-16
lines changed

1 file changed

+20
-16
lines changed

src/signatures/falcon.zig

Lines changed: 20 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,9 @@ fn Falcon(Hash: type, N: u32) type {
2727
pub const PublicKey = struct {
2828
h: Polynomial(N, Fq),
2929

30-
const V = @Vector(4, i16);
30+
const length = 16;
31+
const stride = (length / 4) * 7;
32+
const V = @Vector(length, i16);
3133
const QV: V = @splat(Q);
3234

3335
const BITS_PER_VALUE = 14;
@@ -45,32 +47,36 @@ fn Falcon(Hash: type, N: u32) type {
4547
// values are compressed into a bit sequence of 14 * N bits, or 14N/8 bytes.
4648
const h = bytes[1..];
4749
var coeff: [N]Fq = undefined;
48-
inline for (0..N / 4) |i| {
50+
inline for (0..N / length) |i| {
4951
// Given that each element is 14 bits, 7 bytes hold 4 elements (56 / 14 = 4).
5052
// We represent the elements as u32 for efficient arithmetics.
51-
const in = h[i * 7 ..][0..7];
52-
const out: *[4]u32 = @ptrCast(coeff[i * 4 ..][0..4]);
53-
const mask: @Vector(4, u32) = @splat((1 << 14) - 1);
53+
const in = h[i * stride ..][0..stride];
54+
const out: *[length]u32 = @ptrCast(coeff[i * length ..][0..length]);
55+
const mask: @Vector(length, u32) = @splat((1 << 14) - 1);
5456

5557
// We perform 2 movs to load words at `in` and `in + 3`.
5658
// The vector now contains 4 compressed elements (end-exclusive ranges):
5759
// 1. 00..14 (bytes 0, 1)
5860
// 2. 14..28 (bytes 1, 2, 3)
5961
// 3. 28..42 (bytes 3, 4, 5)
6062
// 4. 42..56 (bytes 5, 6)
61-
const compressed: @Vector(4, u32) = .{
62-
@bitCast(in[0..4].*),
63-
@bitCast(in[0..4].*),
64-
@bitCast(in[3..7].*),
65-
@bitCast(in[3..7].*),
66-
};
67-
const shifted = @byteSwap(compressed) >> .{ 18, 4, 14, 0 };
63+
var compressed: @Vector(length, u32) = undefined;
64+
inline for (0..length / 4) |j| {
65+
@setEvalBranchQuota(length * 1_000);
66+
compressed[(j * 4) + 0] = @bitCast(in[j * 7 ..][0..4].*);
67+
compressed[(j * 4) + 1] = @bitCast(in[j * 7 ..][0..4].*);
68+
compressed[(j * 4) + 2] = @bitCast(in[j * 7 ..][3..7].*);
69+
compressed[(j * 4) + 3] = @bitCast(in[j * 7 ..][3..7].*);
70+
}
71+
const shifted = @byteSwap(compressed) >> std.simd.repeat(length, [_]u5{ 18, 4, 14, 0 });
72+
}
73+
const shifted = @byteSwap(compressed) >> std.simd.repeat(length, [_]u5{ 18, 4, 14, 0 });
6874
// After the mask, each element fits into 14-bits, so it'll always fit into signed 16 bits.
6975
const masked: V = @intCast(shifted & mask);
7076
// We perform the modulus check in parallel, checking each element and returning
7177
// an error if any of the elements are greater than greater than or equal to the modulus.
7278
if (@reduce(.Or, masked >= QV)) return error.InvalidCoeff;
73-
out.* = Fq.Vector(4).init(masked);
79+
out.* = Fq.Vector(length).init(masked);
7480
}
7581

7682
return .{ .h = .{ .coeff = coeff } };
@@ -696,9 +702,7 @@ fn Falcon(Hash: type, N: u32) type {
696702

697703
// a[j] = a[j] * n^-1 mod q
698704
const ninv = T.precompute.ninv;
699-
for (&a) |*aj| {
700-
aj.* = aj.mul(ninv);
701-
}
705+
for (&a) |*aj| aj.* = aj.mul(ninv);
702706

703707
return .{ .coeff = a };
704708
}

0 commit comments

Comments
 (0)