Skip to content

Commit 85f4205

Browse files
authored
feat(mlkem): initialize mlkem from golang standard library
1 parent cc7ce9e commit 85f4205

9 files changed

Lines changed: 3355 additions & 0 deletions

File tree

mlkem/field.go

Lines changed: 573 additions & 0 deletions
Large diffs are not rendered by default.

mlkem/field_test.go

Lines changed: 271 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,271 @@
1+
// Copyright 2023 The Go Authors. All rights reserved.
2+
// Use of this source code is governed by a BSD-style
3+
// license that can be found in the LICENSE file.
4+
5+
//go:build go1.24
6+
7+
package mlkem
8+
9+
import (
10+
"bytes"
11+
"crypto/rand"
12+
"math/big"
13+
mathrand "math/rand/v2"
14+
"strconv"
15+
"testing"
16+
)
17+
18+
func TestFieldReduce(t *testing.T) {
19+
for a := range uint32(2*q*q) {
20+
got := fieldReduce(a)
21+
exp := fieldElement(a % q)
22+
if got != exp {
23+
t.Fatalf("reduce(%d) = %d, expected %d", a, got, exp)
24+
}
25+
}
26+
}
27+
28+
func TestFieldAdd(t *testing.T) {
29+
for a := range fieldElement(q) {
30+
for b := range fieldElement(q) {
31+
got := fieldAdd(a, b)
32+
exp := (a + b) % q
33+
if got != exp {
34+
t.Fatalf("%d + %d = %d, expected %d", a, b, got, exp)
35+
}
36+
}
37+
}
38+
}
39+
40+
func TestFieldSub(t *testing.T) {
41+
for a := range fieldElement(q) {
42+
for b := range fieldElement(q) {
43+
got := fieldSub(a, b)
44+
exp := (a - b + q) % q
45+
if got != exp {
46+
t.Fatalf("%d - %d = %d, expected %d", a, b, got, exp)
47+
}
48+
}
49+
}
50+
}
51+
52+
func TestFieldMul(t *testing.T) {
53+
for a := range fieldElement(q) {
54+
for b := range fieldElement(q) {
55+
got := fieldMul(a, b)
56+
exp := fieldElement((uint32(a) * uint32(b)) % q)
57+
if got != exp {
58+
t.Fatalf("%d * %d = %d, expected %d", a, b, got, exp)
59+
}
60+
}
61+
}
62+
}
63+
64+
func TestDecompressCompress(t *testing.T) {
65+
for _, bits := range []uint8{1, 4, 10} {
66+
for a := uint16(0); a < 1<<bits; a++ {
67+
f := decompress(a, bits)
68+
if f >= q {
69+
t.Fatalf("decompress(%d, %d) = %d >= q", a, bits, f)
70+
}
71+
got := compress(f, bits)
72+
if got != a {
73+
t.Fatalf("compress(decompress(%d, %d), %d) = %d", a, bits, bits, got)
74+
}
75+
}
76+
77+
for a := fieldElement(0); a < q; a++ {
78+
c := compress(a, bits)
79+
if c >= 1<<bits {
80+
t.Fatalf("compress(%d, %d) = %d >= 2^bits", a, bits, c)
81+
}
82+
got := decompress(c, bits)
83+
diff := min(a-got, got-a, a-got+q, got-a+q)
84+
ceil := q / (1 << bits)
85+
if diff > fieldElement(ceil) {
86+
t.Fatalf("decompress(compress(%d, %d), %d) = %d (diff %d, max diff %d)",
87+
a, bits, bits, got, diff, ceil)
88+
}
89+
}
90+
}
91+
}
92+
93+
func CompressRat(x fieldElement, d uint8) uint16 {
94+
if x >= q {
95+
panic("x out of range")
96+
}
97+
if d <= 0 || d >= 12 {
98+
panic("d out of range")
99+
}
100+
101+
precise := big.NewRat((1<<d)*int64(x), q) // (2ᵈ / q) * x == (2ᵈ * x) / q
102+
103+
// FloatString rounds halves away from 0, and our result should always be positive,
104+
// so it should work as we expect. (There's no direct way to round a Rat.)
105+
rounded, err := strconv.ParseInt(precise.FloatString(0), 10, 64)
106+
if err != nil {
107+
panic(err)
108+
}
109+
110+
// If we rounded up, `rounded` may be equal to 2ᵈ, so we perform a final reduction.
111+
return uint16(rounded % (1 << d))
112+
}
113+
114+
func TestCompress(t *testing.T) {
115+
for d := 1; d < 12; d++ {
116+
for n := range q {
117+
expected := CompressRat(fieldElement(n), uint8(d))
118+
result := compress(fieldElement(n), uint8(d))
119+
if result != expected {
120+
t.Errorf("compress(%d, %d): got %d, expected %d", n, d, result, expected)
121+
}
122+
}
123+
}
124+
}
125+
126+
func DecompressRat(y uint16, d uint8) fieldElement {
127+
if y >= 1<<d {
128+
panic("y out of range")
129+
}
130+
if d <= 0 || d >= 12 {
131+
panic("d out of range")
132+
}
133+
134+
precise := big.NewRat(q*int64(y), 1<<d) // (q / 2ᵈ) * y == (q * y) / 2ᵈ
135+
136+
// FloatString rounds halves away from 0, and our result should always be positive,
137+
// so it should work as we expect. (There's no direct way to round a Rat.)
138+
rounded, err := strconv.ParseInt(precise.FloatString(0), 10, 64)
139+
if err != nil {
140+
panic(err)
141+
}
142+
143+
// If we rounded up, `rounded` may be equal to q, so we perform a final reduction.
144+
return fieldElement(rounded % q)
145+
}
146+
147+
func TestDecompress(t *testing.T) {
148+
for d := 1; d < 12; d++ {
149+
for n := 0; n < (1 << d); n++ {
150+
expected := DecompressRat(uint16(n), uint8(d))
151+
result := decompress(uint16(n), uint8(d))
152+
if result != expected {
153+
t.Errorf("decompress(%d, %d): got %d, expected %d", n, d, result, expected)
154+
}
155+
}
156+
}
157+
}
158+
159+
func randomRingElement() ringElement {
160+
var r ringElement
161+
for i := range r {
162+
r[i] = fieldElement(mathrand.IntN(q))
163+
}
164+
return r
165+
}
166+
167+
func TestEncodeDecode(t *testing.T) {
168+
f := randomRingElement()
169+
b := make([]byte, 12*n/8)
170+
rand.Read(b)
171+
172+
// Compare ringCompressAndEncode to ringCompressAndEncodeN.
173+
e1 := ringCompressAndEncode(nil, f, 10)
174+
e2 := ringCompressAndEncode10(nil, f)
175+
if !bytes.Equal(e1, e2) {
176+
t.Errorf("ringCompressAndEncode = %x, ringCompressAndEncode10 = %x", e1, e2)
177+
}
178+
e1 = ringCompressAndEncode(nil, f, 4)
179+
e2 = ringCompressAndEncode4(nil, f)
180+
if !bytes.Equal(e1, e2) {
181+
t.Errorf("ringCompressAndEncode = %x, ringCompressAndEncode4 = %x", e1, e2)
182+
}
183+
e1 = ringCompressAndEncode(nil, f, 1)
184+
e2 = ringCompressAndEncode1(nil, f)
185+
if !bytes.Equal(e1, e2) {
186+
t.Errorf("ringCompressAndEncode = %x, ringCompressAndEncode1 = %x", e1, e2)
187+
}
188+
189+
// Compare ringDecodeAndDecompress to ringDecodeAndDecompressN.
190+
g1 := ringDecodeAndDecompress(b[:encodingSize10], 10)
191+
g2 := ringDecodeAndDecompress10((*[encodingSize10]byte)(b))
192+
if g1 != g2 {
193+
t.Errorf("ringDecodeAndDecompress = %v, ringDecodeAndDecompress10 = %v", g1, g2)
194+
}
195+
g1 = ringDecodeAndDecompress(b[:encodingSize4], 4)
196+
g2 = ringDecodeAndDecompress4((*[encodingSize4]byte)(b))
197+
if g1 != g2 {
198+
t.Errorf("ringDecodeAndDecompress = %v, ringDecodeAndDecompress4 = %v", g1, g2)
199+
}
200+
g1 = ringDecodeAndDecompress(b[:encodingSize1], 1)
201+
g2 = ringDecodeAndDecompress1((*[encodingSize1]byte)(b))
202+
if g1 != g2 {
203+
t.Errorf("ringDecodeAndDecompress = %v, ringDecodeAndDecompress1 = %v", g1, g2)
204+
}
205+
206+
// Round-trip ringCompressAndEncode and ringDecodeAndDecompress.
207+
for d := 1; d < 12; d++ {
208+
encodingSize := d * n / 8
209+
g := ringDecodeAndDecompress(b[:encodingSize], uint8(d))
210+
out := ringCompressAndEncode(nil, g, uint8(d))
211+
if !bytes.Equal(out, b[:encodingSize]) {
212+
t.Errorf("roundtrip failed for d = %d", d)
213+
}
214+
}
215+
216+
// Round-trip ringCompressAndEncodeN and ringDecodeAndDecompressN.
217+
g := ringDecodeAndDecompress10((*[encodingSize10]byte)(b))
218+
out := ringCompressAndEncode10(nil, g)
219+
if !bytes.Equal(out, b[:encodingSize10]) {
220+
t.Errorf("roundtrip failed for specialized 10")
221+
}
222+
g = ringDecodeAndDecompress4((*[encodingSize4]byte)(b))
223+
out = ringCompressAndEncode4(nil, g)
224+
if !bytes.Equal(out, b[:encodingSize4]) {
225+
t.Errorf("roundtrip failed for specialized 4")
226+
}
227+
g = ringDecodeAndDecompress1((*[encodingSize1]byte)(b))
228+
out = ringCompressAndEncode1(nil, g)
229+
if !bytes.Equal(out, b[:encodingSize1]) {
230+
t.Errorf("roundtrip failed for specialized 1")
231+
}
232+
}
233+
234+
func BitRev7(n uint8) uint8 {
235+
if n>>7 != 0 {
236+
panic("not 7 bits")
237+
}
238+
var r uint8
239+
r |= n >> 6 & 0b0000_0001
240+
r |= n >> 4 & 0b0000_0010
241+
r |= n >> 2 & 0b0000_0100
242+
r |= n /**/ & 0b0000_1000
243+
r |= n << 2 & 0b0001_0000
244+
r |= n << 4 & 0b0010_0000
245+
r |= n << 6 & 0b0100_0000
246+
return r
247+
}
248+
249+
func TestZetas(t *testing.T) {
250+
ζ := big.NewInt(17)
251+
q := big.NewInt(q)
252+
for k, zeta := range zetas {
253+
// ζ^BitRev7(k) mod q
254+
exp := new(big.Int).Exp(ζ, big.NewInt(int64(BitRev7(uint8(k)))), q)
255+
if big.NewInt(int64(zeta)).Cmp(exp) != 0 {
256+
t.Errorf("zetas[%d] = %v, expected %v", k, zeta, exp)
257+
}
258+
}
259+
}
260+
261+
func TestGammas(t *testing.T) {
262+
ζ := big.NewInt(17)
263+
q := big.NewInt(q)
264+
for k, gamma := range gammas {
265+
// ζ^2BitRev7(i)+1
266+
exp := new(big.Int).Exp(ζ, big.NewInt(int64(BitRev7(uint8(k)))*2+1), q)
267+
if big.NewInt(int64(gamma)).Cmp(exp) != 0 {
268+
t.Errorf("gammas[%d] = %v, expected %v", k, gamma, exp)
269+
}
270+
}
271+
}

0 commit comments

Comments
 (0)