@@ -27,7 +27,9 @@ fn Falcon(Hash: type, N: u32) type {
2727 pub const PublicKey = struct {
2828 h : Polynomial (N , Fq ),
2929
30- const V = @Vector (4 , i16 );
30+ const length = 16 ;
31+ const stride = (length / 4 ) * 7 ;
32+ const V = @Vector (length , i16 );
3133 const QV : V = @splat (Q );
3234
3335 const BITS_PER_VALUE = 14 ;
@@ -45,32 +47,36 @@ fn Falcon(Hash: type, N: u32) type {
4547 // values are compressed into a bit sequence of 14 * N bits, or 14N/8 bytes.
4648 const h = bytes [1.. ];
4749 var coeff : [N ]Fq = undefined ;
48- inline for (0.. N / 4 ) | i | {
50+ inline for (0.. N / length ) | i | {
4951 // Given that each element is 14 bits, 7 bytes hold 4 elements (56 / 14 = 4).
5052 // We represent the elements as u32 for efficient arithmetics.
51- const in = h [i * 7 .. ][0.. 7 ];
52- const out : * [4 ]u32 = @ptrCast (coeff [i * 4 .. ][0.. 4 ]);
53- const mask : @Vector (4 , u32 ) = @splat ((1 << 14 ) - 1 );
53+ const in = h [i * stride .. ][0.. stride ];
54+ const out : * [length ]u32 = @ptrCast (coeff [i * length .. ][0.. length ]);
55+ const mask : @Vector (length , u32 ) = @splat ((1 << 14 ) - 1 );
5456
5557 // We perform 2 movs to load words at `in` and `in + 3`.
5658 // The vector now contains 4 compressed elements (end-exclusive ranges):
5759 // 1. 00..14 (bytes 0, 1)
5860 // 2. 14..28 (bytes 1, 2, 3)
5961 // 3. 28..42 (bytes 3, 4, 5)
6062 // 4. 42..56 (bytes 5, 6)
61- const compressed : @Vector (4 , u32 ) = .{
62- @bitCast (in [0.. 4].* ),
63- @bitCast (in [0.. 4].* ),
64- @bitCast (in [3.. 7].* ),
65- @bitCast (in [3.. 7].* ),
66- };
67- const shifted = @byteSwap (compressed ) >> .{ 18 , 4 , 14 , 0 };
63+ var compressed : @Vector (length , u32 ) = undefined ;
64+ inline for (0.. length / 4 ) | j | {
65+ @setEvalBranchQuota (length * 1_000 );
66+ compressed [(j * 4 ) + 0 ] = @bitCast (in [j * 7 .. ][0.. 4].* );
67+ compressed [(j * 4 ) + 1 ] = @bitCast (in [j * 7 .. ][0.. 4].* );
68+ compressed [(j * 4 ) + 2 ] = @bitCast (in [j * 7 .. ][3.. 7].* );
69+ compressed [(j * 4 ) + 3 ] = @bitCast (in [j * 7 .. ][3.. 7].* );
70+ }
71+ const shifted = @byteSwap (compressed ) >> std .simd .repeat (length , [_ ]u5 { 18 , 4 , 14 , 0 });
72+ }
73+ const shifted = @byteSwap (compressed ) >> std .simd .repeat (length , [_ ]u5 { 18 , 4 , 14 , 0 });
6874 // After the mask, each element fits into 14-bits, so it'll always fit into signed 16 bits.
6975 const masked : V = @intCast (shifted & mask );
7076 // We perform the modulus check in parallel, checking each element and returning
7177 // an error if any of the elements are greater than greater than or equal to the modulus.
7278 if (@reduce (.Or , masked >= QV )) return error .InvalidCoeff ;
73- out .* = Fq .Vector (4 ).init (masked );
79+ out .* = Fq .Vector (length ).init (masked );
7480 }
7581
7682 return .{ .h = .{ .coeff = coeff } };
@@ -696,9 +702,7 @@ fn Falcon(Hash: type, N: u32) type {
696702
697703 // a[j] = a[j] * n^-1 mod q
698704 const ninv = T .precompute .ninv ;
699- for (& a ) | * aj | {
700- aj .* = aj .mul (ninv );
701- }
705+ for (& a ) | * aj | aj .* = aj .mul (ninv );
702706
703707 return .{ .coeff = a };
704708 }
0 commit comments