diff --git a/circuits/circom/expand_message_xmd.circom b/circuits/circom/expand_message_xmd.circom index 78951ec..293a4f3 100644 --- a/circuits/circom/expand_message_xmd.circom +++ b/circuits/circom/expand_message_xmd.circom @@ -61,7 +61,7 @@ template HashMsgPrimeToB0(msg_length) { n2b[i] = Num2Bits(8); n2b[i].in <== msg_prime[i]; for (var j = 0; j < 8; j ++) { - hasher.in[i * 8 + (7 - j)] <== n2b[i].out[j]; + hasher.in[i * 8 + (j)] <== n2b[i].out[7 - j]; } } @@ -127,7 +127,6 @@ template HashB(b_idx) { } template StrXor(n) { - // TODO: For safety, should the inputs be constrained to be 0 <= n <= 255? signal input a[n]; signal input b[n]; signal output out[n]; @@ -137,6 +136,7 @@ template StrXor(n) { component n2b_b[n]; component b2n[n]; for (var i = 0; i < n; i ++) { + // Constraints each byte of inputs to be in range [0, 255] as well as decomposing them into bits n2b_a[i] = Num2Bits(8); n2b_a[i].in <== a[i]; n2b_b[i] = Num2Bits(8); @@ -154,6 +154,7 @@ template StrXor(n) { } } +// Spec https://www.ietf.org/archive/id/draft-irtf-cfrg-hash-to-curve-13.html#name-expand_message_xmd template ExpandMessageXmd(msg_length) { signal input msg[msg_length]; signal output out[96]; @@ -212,231 +213,4 @@ template ExpandMessageXmd(msg_length) { } out[64 + i] <== b2n_b3[i].out; } -} - -// msg_prime_length is in bytes -template VerifyMsgPrime(padded_length) { - signal input msg_prime[padded_length]; - signal input offset_msg[padded_length]; - signal input msg_length; // in bytes - - var offset = 64; - - // msg_prime = z_pad || msg || lib_str || 0 || dst_prime - - // Step 1: Check that msg_prime starts with z_pad (64 zeroes) - for (var i = 0; i < offset; i ++) { - msg_prime[i] === 0; - } - - // Step 2: Verify that offset_msg is valid - component zs = ZeroSandwich(offset, padded_length); - zs.substring_length <== msg_length; - for (var i = 0; i < padded_length; i ++) { - zs.in[i] <== offset_msg[i]; - } - - // Step 3: Check that msg_prime starts with offset_msg values from offset - // onwards - /* - msg_length = 4 - [1, 2, 3, 4, 5, 0, 0, 8] <- msg_prime - [1, 2, 3, 4, 0, 0, 0, 0] <- offset_msg - a: [1, 1, 1, 1, 0, 1, 1, 0] <- [msg_prime[i] == offset_msg[i] for i in length] - b: [1, 1, 1, 1, 0, 0, 0, 0] <- [i < msg_length ? 1 : 0 for i in length] - c: [1, 1, 1, 1, 0, 0, 0, 0] <- [a[i] * b[i] for i in length] - sum(c) === msg_length - */ - - component iseq_a[padded_length - offset]; - component lt_b[padded_length - offset]; - component ct_c = CalculateTotal(padded_length - offset); - - for (var i = 0; i < padded_length - offset; i ++) { - iseq_a[i] = IsEqual(); - iseq_a[i].in[0] <== msg_prime[offset + i]; - iseq_a[i].in[1] <== offset_msg[offset + i]; - - // TODO: save on constraints by first checking that msg_length is - // less than length (using LessThan(252) outside the loop), then using - // LessThan(log2(msg_length)) inside the loop - lt_b[i] = LessThan(252); - lt_b[i].in[0] <== offset + i; - lt_b[i].in[1] <== offset + msg_length; - - ct_c.in[i] <== iseq_a[i].out * lt_b[i].out; - } - - ct_c.out === msg_length; - - // Step 4: Check that msg_prime contains lib_str || 0 || dst_prime from - component lib_str_selector[2]; - component zero_selector = Selector(padded_length); - component dst_prime_selector[50]; - - var lib_str[2] = get_lib_str(); - var dst_prime[50] = get_dst_prime(); - - for (var i = 0; i < 2; i ++) { - lib_str_selector[i] = Selector(padded_length); - lib_str_selector[i].index <== offset + msg_length + i; - for (var j = 0; j < padded_length; j ++) { - lib_str_selector[i].in[j] <== msg_prime[j]; - } - lib_str_selector[i].out === lib_str[i]; - } - - zero_selector.index <== offset + msg_length + 2; - for (var i = 0; i < padded_length; i ++) { - zero_selector.in[i] <== msg_prime[i]; - } - zero_selector.out === 0; - - for (var i = 0; i < 50; i ++) { - dst_prime_selector[i] = Selector(padded_length); - dst_prime_selector[i].index <== offset + msg_length + 2 + 1 + i; - for (var j = 0; j < padded_length; j ++) { - dst_prime_selector[i].in[j] <== msg_prime[j]; - } - dst_prime_selector[i].out === dst_prime[i]; - } -} - -// padded_msg_prime_length is in bytes -template ExpandMessageXmd2(padded_msg_prime_length) { - // offset_msg must be offset by 64 zeros and end with 0s. e.g. if msg = [1, - // 2, 3], then offset_msg = [0, 0, ... 0, 1, 2, 3, 0, 0...] - signal input offset_msg[padded_msg_prime_length]; - signal input msg_length; // in bytes - - signal input msg_prime[padded_msg_prime_length]; - signal input padded_msg_prime[padded_msg_prime_length]; - - signal output out[96]; - - var offset = 64; // the length of z_prime - - // Step 1: Ensure that each value in offset_msg, msg_prime, and - // padded_msg_prime are < 256 - // TODO - - // Step 2: Verify that msg_prime is valid - component v = VerifyMsgPrime(padded_msg_prime_length); - v.msg_length <== msg_length; - for (var i = 0; i < padded_msg_prime_length; i ++) { - v.msg_prime[i] <== msg_prime[i]; - v.offset_msg[i] <== offset_msg[i]; - } - - // Step 3: Convert padded_msg_prime and msg_prime_bits to bits - signal padded_msg_prime_bits[padded_msg_prime_length * 8]; - signal msg_prime_bits[padded_msg_prime_length * 8]; - component n2b_a[padded_msg_prime_length]; - component n2b_b[padded_msg_prime_length]; - for (var i = 0; i < padded_msg_prime_length; i ++) { - n2b_a[i] = Num2Bits(8); - n2b_b[i] = Num2Bits(8); - n2b_a[i].in <== padded_msg_prime[i]; - n2b_b[i].in <== msg_prime[i]; - for (var j = 0; j < 8; j ++) { - padded_msg_prime_bits[i * 8 + (7 - j)] <== n2b_a[i].out[j]; - msg_prime_bits[i * 8 + (7 - j)] <== n2b_b[i].out[j]; - } - } - - // Step 4: Hash padded_msg_prime to derive B0 - component b0 = Sha256Hash(padded_msg_prime_length * 8); - for (var i = 0; i < padded_msg_prime_length * 8; i ++) { - b0.padded_bits[i] <== padded_msg_prime_bits[i]; - b0.msg[i] <== msg_prime_bits[i]; - } - - /*signal output out[256];*/ - /*for (var i = 0; i < 256; i ++) {*/ - /*out[i] <== hasher.out[i];*/ - /*}*/ - component b1 = HashB(1); - for (var i = 0; i < 256; i ++) { - b1.b_bits[i] <== b0.out[i]; - } - - component b2 = HashBi(2); - for (var i = 0; i < 256; i ++) { - b2.b0_bits[i] <== b0.out[i]; - b2.bi_minus_one_bits[i] <== b1.bi_bits[i]; - } - - component b3 = HashBi(3); - for (var i = 0; i < 256; i ++) { - b3.b0_bits[i] <== b0.out[i]; - b3.bi_minus_one_bits[i] <== b2.bi_bits[i]; - } - - component b2n_b1[32]; - for (var i = 0; i < 32; i ++) { - b2n_b1[i] = Bits2Num(8); - for (var j = 0; j < 8; j ++) { - b2n_b1[i].in[j] <== b1.bi_bits[i * 8 + (7 - j)]; - } - out[i] <== b2n_b1[i].out; - } - - component b2n_b2[32]; - for (var i = 0; i < 32; i ++) { - b2n_b2[i] = Bits2Num(8); - for (var j = 0; j < 8; j ++) { - b2n_b2[i].in[j] <== b2.bi_bits[i * 8 + (7 - j)]; - } - out[32 + i] <== b2n_b2[i].out; - } - - component b2n_b3[32]; - for (var i = 0; i < 32; i ++) { - b2n_b3[i] = Bits2Num(8); - for (var j = 0; j < 8; j ++) { - b2n_b3[i].in[j] <== b3.bi_bits[i * 8 + (7 - j)]; - } - out[64 + i] <== b2n_b3[i].out; - } -} - -// The first offset elements of in should be 0, followed by substring_length -// elements, and the rest should be 0 -template ZeroSandwich(offset, length) { - assert(offset < length); - signal input in[length]; - signal input substring_length; - - // Check that the first offset elements are 0 - component isz[offset]; - for (var i = 0; i < offset; i ++) { - isz[i] = IsZero(); - isz[i].in <== in[i]; - isz[i].out === 1; - } - - /* - length = 8 - offset = 4 - substring_length = 2 -in: [0, 0, 0, 0, 5, 6, 0, 0] -a: [0, 0, 0, 0, 0, 0, 1, 1] <- [i >= (offset + substring_length) for i in length] -b: [0, 0, 0, 0, 0, 0, 0, 0] <- [a[i] * in[i] for in in length] -check that each element from offset onwards is 0 - */ - - component gte_a[length - offset]; - component isz_b[length - offset]; - for (var i = offset; i < length; i ++) { - // TODO: save on constraints by first checking that substring_length is - // less than length (using LessThan(252) outside the loop), then using - // GreaterEqThan(log2(length)) inside the loop - gte_a[i - offset] = GreaterEqThan(252); - gte_a[i - offset].in[0] <== i; - gte_a[i - offset].in[1] <== offset + substring_length; - - isz_b[i - offset] = IsZero(); - isz_b[i - offset].in <== gte_a[i - offset].out * in[i]; - isz_b[i - offset].out === 1; - } -} +} \ No newline at end of file diff --git a/circuits/circom/hash_to_field.circom b/circuits/circom/hash_to_field.circom index 57c0213..7b0ec38 100644 --- a/circuits/circom/hash_to_field.circom +++ b/circuits/circom/hash_to_field.circom @@ -2,6 +2,7 @@ pragma circom 2.0.0; include "./constants.circom"; include "./expand_message_xmd.circom"; +// Spec https://www.ietf.org/archive/id/draft-irtf-cfrg-hash-to-curve-13.html#name-hash_to_field-implementatio template HashToField(msg_length) { signal input msg[msg_length]; signal output u[2][4]; @@ -26,6 +27,9 @@ template HashToField(msg_length) { } // Converts a 48-byte array into a 4-register BigInt modulo the secp256k1 prime +// +// 48 bytes has enough entropy so that the distribution of points is close +// enough to uniform for 128 bit security: L = ceil((ceil(log2(p)) + k) / 8) = 48 template BytesToRegisters() { signal input bytes[48]; signal output out[4]; diff --git a/circuits/circom/map_to_curve.circom b/circuits/circom/map_to_curve.circom index 894a0b7..5b7f8cc 100644 --- a/circuits/circom/map_to_curve.circom +++ b/circuits/circom/map_to_curve.circom @@ -17,16 +17,15 @@ template CMov() { signal input c; signal output out[4]; - component mux[4]; + component mux = MultiMux1(4); for (var i = 0; i < 4; i ++) { - mux[i] = Mux1(); - mux[i].c[0] <== a[i]; - mux[i].c[1] <== b[i]; - mux[i].s <== c; + mux.c[0][i] <== a[i]; + mux.c[1][i] <== b[i]; } + mux.s <== c; for (var i = 0; i < 4; i ++) { - out[i] <== mux[i].out; + out[i] <== mux.out[i]; } } @@ -170,6 +169,7 @@ template XY2Selector() { } } +// Each step corresponds to a line in the reference implementation here https://github.com/cfrg/draft-irtf-cfrg-hash-to-curve/blob/eb001eaea0f49066dad611a4c7cb2749f167b97e/poc/sswu_generic.sage#L76-L97 template MapToCurve() { signal input u[4]; signal input gx1_sqrt[4]; diff --git a/circuits/circom/test/expand_msg_xmd2_test.circom b/circuits/circom/test/expand_msg_xmd2_test.circom deleted file mode 100644 index 2bb592d..0000000 --- a/circuits/circom/test/expand_msg_xmd2_test.circom +++ /dev/null @@ -1,5 +0,0 @@ -pragma circom 2.0.0; - -include "../expand_message_xmd.circom"; - -component main = ExpandMessageXmd2(192); diff --git a/circuits/circom/test/verify_msg_prime_test.circom b/circuits/circom/test/verify_msg_prime_test.circom deleted file mode 100644 index 45256c4..0000000 --- a/circuits/circom/test/verify_msg_prime_test.circom +++ /dev/null @@ -1,5 +0,0 @@ -pragma circom 2.0.0; - -include "../expand_message_xmd.circom"; - -component main = VerifyMsgPrime(120); diff --git a/circuits/ts/__tests__/ExpandMessageXmd.test.ts b/circuits/ts/__tests__/ExpandMessageXmd.test.ts index 0ab5be0..d81d552 100644 --- a/circuits/ts/__tests__/ExpandMessageXmd.test.ts +++ b/circuits/ts/__tests__/ExpandMessageXmd.test.ts @@ -215,119 +215,4 @@ describe('ExpandMessageXmd', () => { //expect(bytes[i]).toEqual(expected[i]) //} //}) - - //it('ZeroSandwich (valid)', async () => { - //const circuit = 'zero_sandwich_test' - //const circuitInputs = stringifyBigInts({ - //in: [0, 0, 0, 0, 5, 6, 0, 0], - //substring_length: 2, - //}) - - //const witness = await genWitness(circuit, circuitInputs) - //}) - - //it('ZeroSandwich (invalid)', async () => { - //const circuit = 'zero_sandwich_test' - //const circuitInputs = stringifyBigInts({ - //in: [0, 0, 0, 0, 5, 6, 3, 0], - //substring_length: 2, - //}) - - //try { - //const witness = await genWitness(circuit, circuitInputs) - //expect(false).toBeTruthy() - //} catch { - //expect(true).toBeTruthy() - //} - //expect.assertions(1) - //}) - - //it('ZeroSandwich (invalid)', async () => { - //const circuit = 'zero_sandwich_test' - //const circuitInputs = stringifyBigInts({ - //in: [0, 0, 0, 1, 5, 6, 0, 0], - //substring_length: 2, - //}) - - //try { - //const witness = await genWitness(circuit, circuitInputs) - //expect(false).toBeTruthy() - //} catch { - //expect(true).toBeTruthy() - //} - //expect.assertions(1) - //}) - - //it('VerifyMsgPrime', async () => { - //const circuit = 'verify_msg_prime_test' - //const msg_prime = gen_msg_prime(msg) - //const b = bufToPaddedBytes(Buffer.from(msg_prime)) - ////console.log(msg_prime.length) - ////console.log(Buffer.from(expected_msg_prime).toString('hex')) - ////console.log(Buffer.from(b).toString('hex')) - - //let offset_msg_buf = Buffer.alloc(msg_prime.length) - //for (let i = 64; i < 64 + msg.length; i ++) { - //offset_msg_buf[i] = Buffer.from(msg[i - 64])[0] - //} - //let offset_msg: Number[] = [] - //for (let i = 0; i < offset_msg_buf.length; i ++) { - //offset_msg.push(Number(offset_msg_buf[i])) - //} - - //const circuitInputs = stringifyBigInts({ - //msg_prime, - //offset_msg, - //msg_length: msg.length - //}) - //const witness = await genWitness(circuit, circuitInputs) - //}) - - it('ExpandMessageXmd2', async () => { - const circuit = 'expand_msg_xmd2_test' - const msg_prime = gen_msg_prime(str_to_array(msg)) - const padded_msg_prime = bufToPaddedBytes(Buffer.from(msg_prime)) - - let offset_msg_buf = Buffer.alloc(padded_msg_prime.length) - for (let i = 64; i < 64 + msg.length; i ++) { - offset_msg_buf[i] = Buffer.from(msg[i - 64])[0] - } - let offset_msg: Number[] = [] - for (let i = 0; i < offset_msg_buf.length; i ++) { - offset_msg.push(Number(offset_msg_buf[i])) - } - - const circuitInputs = stringifyBigInts({ - msg_prime: padded_msg_prime, - offset_msg, - msg_length: msg.length, - padded_msg_prime: bufToPaddedBytes(Buffer.from(msg_prime)), - }) - const witness = await genWitness(circuit, circuitInputs) - - //const hash = gen_b0(expected_msg_prime) - - //const hash_bits = buffer2bitArray(Buffer.from(hash)) - - //const bits: number[] = [] - //for (let i = 0; i < 256; i ++) { - ////const out = Number(await getSignalByName(circuit, witness, 'main.hash[' + i.toString() + ']')) - //const out = Number(witness[1 + i]) - //bits.push(out) - //} - //expect(bits.join('')).toEqual(hash_bits.join('')) - - const bytes: number[] = [] - for (let i = 0; i < 96; i ++) { - //const out = Number(await getSignalByName(circuit, witness, 'main.out[' + i.toString() + ']')) - const out = Number(witness[1 + i]) - bytes.push(out) - } - - const expected = expand_msg_xmd(str_to_array(msg)) - expect(expected.length).toEqual(bytes.length) - for (let i = 0; i < 96; i ++) { - expect(bytes[i]).toEqual(expected[i]) - } - }) })