66import { FFTCore , reverseBits } from '@noble/curves/abstract/fft.js' ;
77import { shake128 , shake256 } from '@noble/hashes/sha3.js' ;
88import type { TypedArray } from '@noble/hashes/utils.js' ;
9- import { type BytesCoderLen , cleanBytes , type Coder , getMask } from './utils.ts' ;
9+ import {
10+ type BytesCoderLen ,
11+ cleanBytes ,
12+ type Coder ,
13+ getMask ,
14+ type TArg ,
15+ type TRet ,
16+ } from './utils.ts' ;
1017
1118/** Extendable-output reader used by the CRYSTALS implementations. */
1219export type XOF = (
@@ -61,6 +68,19 @@ export type CrystalOpts<T extends TypedArray> = {
6168/** Constructor function for typed polynomial containers. */
6269export type TypedCons < T extends TypedArray > = ( n : number ) => T ;
6370
71+ type Crystals < T extends TypedArray > = {
72+ mod : ( a : number , modulo ?: number ) => number ;
73+ smod : ( a : number , modulo ?: number ) => number ;
74+ nttZetas : T ;
75+ NTT : {
76+ /** Forward transform in place. Mutates and returns `r`. */
77+ encode : ( r : T ) => T ;
78+ /** Inverse transform in place. Mutates and returns `r`. */
79+ decode : ( r : T ) => T ;
80+ } ;
81+ bitsCoder : ( d : number , c : Coder < number , number > ) => BytesCoderLen < T > ;
82+ } ;
83+
6484/**
6585 * Creates shared modular arithmetic, NTT, and packing helpers for CRYSTALS schemes.
6686 * @param opts - Polynomial and transform parameters. See {@link CrystalOpts}.
@@ -80,20 +100,7 @@ export type TypedCons<T extends TypedArray> = (n: number) => T;
80100 * const reduced = crystals.mod(-1);
81101 * ```
82102 */
83- export const genCrystals = < T extends TypedArray > (
84- opts : CrystalOpts < T >
85- ) : {
86- mod : ( a : number , modulo ?: number ) => number ;
87- smod : ( a : number , modulo ?: number ) => number ;
88- nttZetas : T ;
89- NTT : {
90- /** Forward transform in place. Mutates and returns `r`. */
91- encode : ( r : T ) => T ;
92- /** Inverse transform in place. Mutates and returns `r`. */
93- decode : ( r : T ) => T ;
94- } ;
95- bitsCoder : ( d : number , c : Coder < number , number > ) => BytesCoderLen < T > ;
96- } => {
103+ export const genCrystals = < T extends TypedArray > ( opts : CrystalOpts < T > ) : TRet < Crystals < T > > => {
97104 // isKyber: true means Kyber, false means Dilithium
98105 const { newPoly, N, Q, F, ROOT_OF_UNITY , brvBits, isKyber } = opts ;
99106 // Normalize JS `%` into the canonical Z_m representative `[0, modulo-1]` expected by
@@ -160,38 +167,48 @@ export const genCrystals = <T extends TypedArray>(
160167 } ;
161168 // Pack one little-endian `d`-bit word per coefficient, matching FIPS 203 ByteEncode /
162169 // ByteDecode and the FIPS 204 BitsToBytes-based polynomial packing helpers.
163- const bitsCoder = ( d : number , c : Coder < number , number > ) : BytesCoderLen < T > => {
170+ const bitsCoder = ( d : number , c : Coder < number , number > ) : TRet < BytesCoderLen < T > > => {
164171 const mask = getMask ( d ) ;
165172 const bytesLen = d * ( N / 8 ) ;
166173 return {
167174 bytesLen,
168- encode : ( poly : T ) : Uint8Array => {
175+ encode : ( poly_ : TArg < T > ) : TRet < Uint8Array > => {
176+ const poly = poly_ as T ;
169177 const r = new Uint8Array ( bytesLen ) ;
170178 for ( let i = 0 , buf = 0 , bufLen = 0 , pos = 0 ; i < poly . length ; i ++ ) {
171179 buf |= ( c . encode ( poly [ i ] ) & mask ) << bufLen ;
172180 bufLen += d ;
173181 for ( ; bufLen >= 8 ; bufLen -= 8 , buf >>= 8 ) r [ pos ++ ] = buf & getMask ( bufLen ) ;
174182 }
175- return r ;
183+ return r as TRet < Uint8Array > ;
176184 } ,
177- decode : ( bytes : Uint8Array ) : T => {
185+ decode : ( bytes : TArg < Uint8Array > ) : TRet < T > => {
178186 const r = newPoly ( N ) ;
179187 for ( let i = 0 , buf = 0 , bufLen = 0 , pos = 0 ; i < bytes . length ; i ++ ) {
180188 buf |= bytes [ i ] << bufLen ;
181189 bufLen += 8 ;
182190 for ( ; bufLen >= d ; bufLen -= d , buf >>= d ) r [ pos ++ ] = c . decode ( buf & mask ) ;
183191 }
184- return r ;
192+ return r as TRet < T > ;
185193 } ,
186- } ;
194+ } as TRet < BytesCoderLen < T > > ;
187195 } ;
188196
189- return { mod, smod, nttZetas, NTT , bitsCoder } ;
197+ return {
198+ mod,
199+ smod,
200+ nttZetas : nttZetas as TRet < T > ,
201+ NTT : {
202+ encode : ( r : TArg < T > ) : TRet < T > => NTT . encode ( r as T ) as TRet < T > ,
203+ decode : ( r : TArg < T > ) : TRet < T > => NTT . decode ( r as T ) as TRet < T > ,
204+ } ,
205+ bitsCoder : bitsCoder as TRet < Crystals < T > > [ 'bitsCoder' ] ,
206+ } ;
190207} ;
191208
192209const createXofShake =
193- ( shake : typeof shake128 ) : XOF =>
194- ( seed : Uint8Array , blockLen ?: number ) => {
210+ ( shake : typeof shake128 ) : TRet < XOF > =>
211+ ( seed : TArg < Uint8Array > , blockLen ?: number ) => {
195212 if ( ! blockLen ) blockLen = shake . blockLen ;
196213 // Optimizations that won't mater:
197214 // - cached seed update (two .update(), on start and on the end)
@@ -217,7 +234,7 @@ const createXofShake =
217234 calls ++ ;
218235 return ( ) => {
219236 xofs ++ ;
220- return h . xofInto ( buf ) ;
237+ return h . xofInto ( buf ) as TRet < Uint8Array > ;
221238 } ;
222239 } ,
223240 clean : ( ) => {
@@ -243,7 +260,7 @@ const createXofShake =
243260 * const block = reader.get(0, 0)();
244261 * ```
245262 */
246- export const XOF128 : XOF = /* @__PURE__ */ createXofShake ( shake128 ) ;
263+ export const XOF128 : TRet < XOF > = /* @__PURE__ */ createXofShake ( shake128 ) ;
247264/**
248265 * SHAKE256-based extendable-output reader factory used by ML-DSA.
249266 * `get(x, y)` appends raw one-byte coordinates to the seed, invalidates previously returned
@@ -260,4 +277,4 @@ export const XOF128: XOF = /* @__PURE__ */ createXofShake(shake128);
260277 * const block = reader.get(0, 0)();
261278 * ```
262279 */
263- export const XOF256 : XOF = /* @__PURE__ */ createXofShake ( shake256 ) ;
280+ export const XOF256 : TRet < XOF > = /* @__PURE__ */ createXofShake ( shake256 ) ;
0 commit comments