Skip to content

Commit 78e42bd

Browse files
committed
chore: add the same tests to @chainsafe/swap-or-not-shuffle
1 parent 714267d commit 78e42bd

2 files changed

Lines changed: 169 additions & 17 deletions

File tree

test/bun/src/shuffle.ts

Lines changed: 34 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,12 @@ export function shuffleList(
99
seed: Uint8Array,
1010
rounds: number,
1111
): Uint32Array {
12+
validateShufflingParams(activeIndices, seed, rounds);
13+
14+
const clonedActiveIndices = activeIndices.slice();
1215
const result = binding.shuffleList(
13-
activeIndices,
14-
activeIndices.length,
16+
clonedActiveIndices,
17+
clonedActiveIndices.length,
1518
seed,
1619
seed.length,
1720
rounds,
@@ -21,7 +24,7 @@ export function shuffleList(
2124
throw new Error(`Shuffle failed with error code: ${result}`);
2225
}
2326

24-
return activeIndices;
27+
return clonedActiveIndices;
2528
}
2629

2730
/**
@@ -33,9 +36,12 @@ export function unshuffleList(
3336
seed: Uint8Array,
3437
rounds: number,
3538
): Uint32Array {
39+
validateShufflingParams(activeIndices, seed, rounds);
40+
41+
const clonedActiveIndices = activeIndices.slice();
3642
const result = binding.unshuffleList(
37-
activeIndices,
38-
activeIndices.length,
43+
clonedActiveIndices,
44+
clonedActiveIndices.length,
3945
seed,
4046
seed.length,
4147
rounds,
@@ -45,7 +51,7 @@ export function unshuffleList(
4551
throw new Error(`Unshuffle failed with error code: ${result}`);
4652
}
4753

48-
return activeIndices;
54+
return clonedActiveIndices;
4955
}
5056

5157
// same value to ErrorCode.Pending at zig side
@@ -82,10 +88,13 @@ export function withPollingParams(
8288
rounds: number,
8389
) => number,
8490
): Promise<Uint32Array> {
91+
validateShufflingParams(activeIndices, seed, rounds);
92+
8593
const start = Date.now();
94+
const clonedActiveIndices = activeIndices.slice();
8695
const pointerIdx = nativeShuffleFn(
87-
activeIndices,
88-
activeIndices.length,
96+
clonedActiveIndices,
97+
clonedActiveIndices.length,
8998
seed,
9099
seed.length,
91100
rounds,
@@ -109,7 +118,7 @@ export function withPollingParams(
109118
case POLL_STATUS_SUCCESS:
110119
clearInterval(interval);
111120
binding.releaseAsyncResult(pointerIdx);
112-
resolve(activeIndices);
121+
resolve(clonedActiveIndices);
113122
return;
114123
case POLL_STATUS_PENDING:
115124
break;
@@ -141,3 +150,19 @@ export function withPollingParams(
141150
doShuffleList(activeIndices, seed, rounds, binding.asyncUnshuffleList),
142151
};
143152
}
153+
154+
function validateShufflingParams(
155+
activeIndices: Uint32Array,
156+
seed: Uint8Array,
157+
rounds: number,
158+
): void {
159+
if (activeIndices.length >= 0xffffffff) {
160+
throw new Error("ActiveIndices must fit in a u32");
161+
}
162+
if (seed.length !== 32) {
163+
throw new Error("Shuffling seed must be 32 bytes long");
164+
}
165+
if (rounds < 0 || rounds > 255) {
166+
throw new Error("Rounds must be between 0 and 255");
167+
}
168+
}

test/bun/test/unit/shuffle.test.ts

Lines changed: 135 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,14 @@
11
import { describe, expect, it } from "bun:test";
2+
import { randomBytes } from "node:crypto";
23
import {
34
shuffleList,
45
unshuffleList,
56
withPollingParams,
67
} from "../../src/shuffle.js";
8+
import * as referenceImplementation from "../referenceImplementation.js";
9+
10+
// start polling right after the call for every 1ms, throw error if after 100ms
11+
const { asyncShuffleList, asyncUnshuffleList } = withPollingParams(0, 1, 100);
712

813
describe("unshuffleList", () => {
914
const testCases: { input: Uint32Array; expected: Uint32Array }[] = [
@@ -33,13 +38,6 @@ describe("unshuffleList", () => {
3338
expect(result2).toEqual(input);
3439
});
3540

36-
// start polling right after the call for every 1ms, throw error if after 100ms
37-
const { asyncShuffleList, asyncUnshuffleList } = withPollingParams(
38-
0,
39-
1,
40-
100,
41-
);
42-
4341
const testWithNFactor = async (n: number) => {
4442
let promises: Promise<Uint32Array>[] = [];
4543
// call asyncUnshuffleList in parallel n times
@@ -82,4 +80,133 @@ describe("unshuffleList", () => {
8280
}
8381
});
8482

85-
// TODO: the same tests to @chainsafe/swap-or-not-shuffle
83+
// the same tests to @chainsafe/swap-or-not-shuffle
84+
interface ShuffleTestCase {
85+
id: string;
86+
rounds: number;
87+
seed: Uint8Array;
88+
input: Uint32Array;
89+
shuffled: string;
90+
unshuffled: string;
91+
}
92+
93+
function fromHex(hexInput: string): Uint8Array {
94+
let hex = hexInput;
95+
if (typeof hex !== "string") {
96+
throw new Error(`hex argument type ${typeof hex} must be of type string`);
97+
}
98+
99+
if (hex.startsWith("0x")) {
100+
hex = hex.slice(2);
101+
}
102+
103+
if (hex.length % 2 !== 0) {
104+
throw new Error(`hex string length ${hex.length} must be multiple of 2`);
105+
}
106+
107+
const b = Buffer.from(hex, "hex");
108+
return new Uint8Array(b.buffer, b.byteOffset, b.length);
109+
}
110+
111+
function getInputArray(count: number): Uint32Array {
112+
return Uint32Array.from(Array.from({ length: count }, (_, i) => i));
113+
}
114+
115+
function buildReferenceTestCase(
116+
count: number,
117+
rounds: number,
118+
): ShuffleTestCase {
119+
const seed = randomBytes(32);
120+
const input = getInputArray(count);
121+
const shuffled = input.slice();
122+
referenceImplementation.shuffleList(shuffled, seed, rounds);
123+
const unshuffled = input.slice();
124+
referenceImplementation.unshuffleList(unshuffled, seed, rounds);
125+
return {
126+
id: `TestCase for ${count} indices with seed of $0x${seed.toString("hex")}`,
127+
seed,
128+
rounds,
129+
input,
130+
shuffled: Buffer.from(shuffled).toString("hex"),
131+
unshuffled: Buffer.from(unshuffled).toString("hex"),
132+
};
133+
}
134+
135+
describe("shuffle", () => {
136+
it("should throw for invalid seed", () => {
137+
const test = buildReferenceTestCase(10, 10);
138+
let invalidSeed = Buffer.alloc(31, 0xac);
139+
expect(() => unshuffleList(test.input, invalidSeed, test.rounds)).toThrow(
140+
"Shuffling seed must be 32 bytes long",
141+
);
142+
invalidSeed = Buffer.alloc(33, 0xac);
143+
expect(() => unshuffleList(test.input, invalidSeed, test.rounds)).toThrow(
144+
"Shuffling seed must be 32 bytes long",
145+
);
146+
});
147+
148+
it("should throw for invalid number of rounds", () => {
149+
const test = buildReferenceTestCase(10, 10);
150+
expect(() => unshuffleList(test.input, test.seed, -1)).toThrow(
151+
"Rounds must be between 0 and 255",
152+
);
153+
expect(() => unshuffleList(test.input, test.seed, 256)).toThrow(
154+
"Rounds must be between 0 and 255",
155+
);
156+
});
157+
158+
/**
159+
* Leave this test commented for github runners. It fails on memory allocations. Leave in test suite
160+
* to confirm that it does work though (local tests if you want to verify)
161+
* TODO: not sure why Bun suspends creating a Buffer of 2**32 bytes
162+
* but the implementation checks for max length of input already
163+
*/
164+
it.skip("should throw for invalid input array length", () => {
165+
const test = buildReferenceTestCase(10, 10);
166+
const input = Uint32Array.from(Buffer.alloc(2 ** 32, 0xac));
167+
expect(() => unshuffleList(input, test.seed, 100)).toThrow(
168+
"ActiveIndices must fit in a u32",
169+
);
170+
});
171+
172+
it("should match spec test results", () => {
173+
const seed =
174+
"0x4fe91d85d6bc19b20413659c61f3c690a1c4d48be41cab8363a130cebabada97";
175+
const rounds = 10;
176+
const expected = Buffer.from([
177+
99, 71, 51, 5, 78, 61, 12, 17, 30, 3, 59, 47, 6, 9, 1, 41, 18, 37, 55, 43,
178+
20, 31, 38, 79, 29, 69, 70, 54, 53, 36, 34, 62, 77, 87, 39, 96, 56, 92,
179+
16, 82, 40, 27, 58, 14, 68, 76, 80, 13, 28, 81, 64, 26, 19, 60, 90, 2, 98,
180+
67, 66, 52, 46, 95, 49, 72, 8, 21, 75, 57, 97, 83, 84, 88, 86, 7, 74, 32,
181+
63, 85, 23, 65, 24, 91, 0, 48, 35, 15, 44, 25, 22, 73, 93, 45, 4, 33, 89,
182+
94, 10, 42, 11, 50,
183+
]).toString("hex");
184+
185+
const result = unshuffleList(getInputArray(100), fromHex(seed), rounds);
186+
187+
expect(Buffer.from(result).toString("hex")).toEqual(expected);
188+
});
189+
190+
const testCases: ShuffleTestCase[] = [
191+
buildReferenceTestCase(8, 10),
192+
buildReferenceTestCase(16, 10),
193+
buildReferenceTestCase(16, 100),
194+
buildReferenceTestCase(256, 192),
195+
buildReferenceTestCase(256, 192),
196+
];
197+
198+
for (const { id, seed, rounds, input, shuffled, unshuffled } of testCases) {
199+
it(`sync - ${id}`, () => {
200+
const unshuffledResult = unshuffleList(input, seed, rounds);
201+
const shuffledResult = shuffleList(input, seed, rounds);
202+
expect(Buffer.from(shuffledResult).toString("hex")).toEqual(shuffled);
203+
expect(Buffer.from(unshuffledResult).toString("hex")).toEqual(unshuffled);
204+
});
205+
it(`async - ${id}`, async () => {
206+
const unshuffledResult = await asyncUnshuffleList(input, seed, rounds);
207+
const shuffledResult = await asyncShuffleList(input, seed, rounds);
208+
expect(Buffer.from(shuffledResult).toString("hex")).toEqual(shuffled);
209+
expect(Buffer.from(unshuffledResult).toString("hex")).toEqual(unshuffled);
210+
});
211+
}
212+
});

0 commit comments

Comments
 (0)