Skip to content

Commit dc4da04

Browse files
committed
Use Uint8Array supertype compatible in both TS5.6 + TS5.9
1 parent a2095f4 commit dc4da04

12 files changed

Lines changed: 986 additions & 567 deletions

File tree

src/_crystals.ts

Lines changed: 44 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,14 @@
66
import { FFTCore, reverseBits } from '@noble/curves/abstract/fft.js';
77
import { shake128, shake256 } from '@noble/hashes/sha3.js';
88
import type { TypedArray } from '@noble/hashes/utils.js';
9-
import { type BytesCoderLen, cleanBytes, type Coder, getMask } from './utils.ts';
9+
import {
10+
type BytesCoderLen,
11+
cleanBytes,
12+
type Coder,
13+
getMask,
14+
type TArg,
15+
type TRet,
16+
} from './utils.ts';
1017

1118
/** Extendable-output reader used by the CRYSTALS implementations. */
1219
export type XOF = (
@@ -61,6 +68,19 @@ export type CrystalOpts<T extends TypedArray> = {
6168
/** Constructor function for typed polynomial containers. */
6269
export type TypedCons<T extends TypedArray> = (n: number) => T;
6370

71+
type Crystals<T extends TypedArray> = {
72+
mod: (a: number, modulo?: number) => number;
73+
smod: (a: number, modulo?: number) => number;
74+
nttZetas: T;
75+
NTT: {
76+
/** Forward transform in place. Mutates and returns `r`. */
77+
encode: (r: T) => T;
78+
/** Inverse transform in place. Mutates and returns `r`. */
79+
decode: (r: T) => T;
80+
};
81+
bitsCoder: (d: number, c: Coder<number, number>) => BytesCoderLen<T>;
82+
};
83+
6484
/**
6585
* Creates shared modular arithmetic, NTT, and packing helpers for CRYSTALS schemes.
6686
* @param opts - Polynomial and transform parameters. See {@link CrystalOpts}.
@@ -80,20 +100,7 @@ export type TypedCons<T extends TypedArray> = (n: number) => T;
80100
* const reduced = crystals.mod(-1);
81101
* ```
82102
*/
83-
export const genCrystals = <T extends TypedArray>(
84-
opts: CrystalOpts<T>
85-
): {
86-
mod: (a: number, modulo?: number) => number;
87-
smod: (a: number, modulo?: number) => number;
88-
nttZetas: T;
89-
NTT: {
90-
/** Forward transform in place. Mutates and returns `r`. */
91-
encode: (r: T) => T;
92-
/** Inverse transform in place. Mutates and returns `r`. */
93-
decode: (r: T) => T;
94-
};
95-
bitsCoder: (d: number, c: Coder<number, number>) => BytesCoderLen<T>;
96-
} => {
103+
export const genCrystals = <T extends TypedArray>(opts: CrystalOpts<T>): TRet<Crystals<T>> => {
97104
// isKyber: true means Kyber, false means Dilithium
98105
const { newPoly, N, Q, F, ROOT_OF_UNITY, brvBits, isKyber } = opts;
99106
// Normalize JS `%` into the canonical Z_m representative `[0, modulo-1]` expected by
@@ -160,38 +167,48 @@ export const genCrystals = <T extends TypedArray>(
160167
};
161168
// Pack one little-endian `d`-bit word per coefficient, matching FIPS 203 ByteEncode /
162169
// ByteDecode and the FIPS 204 BitsToBytes-based polynomial packing helpers.
163-
const bitsCoder = (d: number, c: Coder<number, number>): BytesCoderLen<T> => {
170+
const bitsCoder = (d: number, c: Coder<number, number>): TRet<BytesCoderLen<T>> => {
164171
const mask = getMask(d);
165172
const bytesLen = d * (N / 8);
166173
return {
167174
bytesLen,
168-
encode: (poly: T): Uint8Array => {
175+
encode: (poly_: TArg<T>): TRet<Uint8Array> => {
176+
const poly = poly_ as T;
169177
const r = new Uint8Array(bytesLen);
170178
for (let i = 0, buf = 0, bufLen = 0, pos = 0; i < poly.length; i++) {
171179
buf |= (c.encode(poly[i]) & mask) << bufLen;
172180
bufLen += d;
173181
for (; bufLen >= 8; bufLen -= 8, buf >>= 8) r[pos++] = buf & getMask(bufLen);
174182
}
175-
return r;
183+
return r as TRet<Uint8Array>;
176184
},
177-
decode: (bytes: Uint8Array): T => {
185+
decode: (bytes: TArg<Uint8Array>): TRet<T> => {
178186
const r = newPoly(N);
179187
for (let i = 0, buf = 0, bufLen = 0, pos = 0; i < bytes.length; i++) {
180188
buf |= bytes[i] << bufLen;
181189
bufLen += 8;
182190
for (; bufLen >= d; bufLen -= d, buf >>= d) r[pos++] = c.decode(buf & mask);
183191
}
184-
return r;
192+
return r as TRet<T>;
185193
},
186-
};
194+
} as TRet<BytesCoderLen<T>>;
187195
};
188196

189-
return { mod, smod, nttZetas, NTT, bitsCoder };
197+
return {
198+
mod,
199+
smod,
200+
nttZetas: nttZetas as TRet<T>,
201+
NTT: {
202+
encode: (r: TArg<T>): TRet<T> => NTT.encode(r as T) as TRet<T>,
203+
decode: (r: TArg<T>): TRet<T> => NTT.decode(r as T) as TRet<T>,
204+
},
205+
bitsCoder: bitsCoder as TRet<Crystals<T>>['bitsCoder'],
206+
};
190207
};
191208

192209
const createXofShake =
193-
(shake: typeof shake128): XOF =>
194-
(seed: Uint8Array, blockLen?: number) => {
210+
(shake: typeof shake128): TRet<XOF> =>
211+
(seed: TArg<Uint8Array>, blockLen?: number) => {
195212
if (!blockLen) blockLen = shake.blockLen;
196213
// Optimizations that won't mater:
197214
// - cached seed update (two .update(), on start and on the end)
@@ -217,7 +234,7 @@ const createXofShake =
217234
calls++;
218235
return () => {
219236
xofs++;
220-
return h.xofInto(buf);
237+
return h.xofInto(buf) as TRet<Uint8Array>;
221238
};
222239
},
223240
clean: () => {
@@ -243,7 +260,7 @@ const createXofShake =
243260
* const block = reader.get(0, 0)();
244261
* ```
245262
*/
246-
export const XOF128: XOF = /* @__PURE__ */ createXofShake(shake128);
263+
export const XOF128: TRet<XOF> = /* @__PURE__ */ createXofShake(shake128);
247264
/**
248265
* SHAKE256-based extendable-output reader factory used by ML-DSA.
249266
* `get(x, y)` appends raw one-byte coordinates to the seed, invalidates previously returned
@@ -260,4 +277,4 @@ export const XOF128: XOF = /* @__PURE__ */ createXofShake(shake128);
260277
* const block = reader.get(0, 0)();
261278
* ```
262279
*/
263-
export const XOF256: XOF = /* @__PURE__ */ createXofShake(shake256);
280+
export const XOF256: TRet<XOF> = /* @__PURE__ */ createXofShake(shake256);

0 commit comments

Comments
 (0)