Skip to content

Commit 8506675

Browse files
mathsjamescopybara-github
authored andcommitted
No public description
PiperOrigin-RevId: 846194679
1 parent 7965712 commit 8506675

File tree

2 files changed

+105
-83
lines changed

2 files changed

+105
-83
lines changed

willow/src/zk/linear_ip.rs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ pub struct LinearInnerProductParameters {
3636
F: RistrettoPoint,
3737
F_: RistrettoPoint,
3838
G: Vec<RistrettoPoint>,
39+
seed: Vec<u8>,
3940
}
4041

4142
pub fn inner_product(a: &[Scalar], b: &[Scalar]) -> Scalar {
@@ -59,6 +60,7 @@ fn common_setup(length: usize, parameter_seed: &[u8]) -> LinearInnerProductParam
5960
)
6061
})
6162
.collect(),
63+
seed: parameter_seed.to_vec(),
6264
}
6365
}
6466

@@ -67,11 +69,9 @@ fn append_params_to_transcript(
6769
params: &LinearInnerProductParameters,
6870
) {
6971
transcript.append_u64(b"n", params.n as u64);
70-
for G_i in &params.G {
71-
transcript.append_message(b"G_i", G_i.compress().as_bytes());
72-
}
73-
transcript.append_message(b"F", params.F.compress().as_bytes());
74-
transcript.append_message(b"F_", params.F_.compress().as_bytes());
72+
// We append the seed not the resulting params themselves because appending that many params
73+
// more than doubles the run time of both prove and verify.
74+
transcript.append_message(b"seed", &params.seed);
7575
}
7676

7777
fn validate_and_append_point(

willow/src/zk/rlwe_relation.rs

Lines changed: 100 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -272,9 +272,9 @@ fn update_public_vec_for_range_proof(
272272
z_r: &Vec<Scalar>,
273273
z_e: &Vec<Scalar>,
274274
z_vw: &Vec<Scalar>,
275-
psi_r: Scalar,
276-
psi_e: Scalar,
277-
psi_vw: Scalar,
275+
psi_r: &Vec<Scalar>,
276+
psi_e: &Vec<Scalar>,
277+
psi_vw: &Vec<Scalar>,
278278
n: usize,
279279
range_comm_offset: usize,
280280
samples_required: usize,
@@ -298,20 +298,14 @@ fn update_public_vec_for_range_proof(
298298

299299
// The range proofs equation also involves length 128 innerproducts involving the relevant
300300
// psi these are included in the last 3*128 entries of the inner product vectors.
301-
let mut phi_psi_r_pow = phi;
302-
let mut phi2_psi_e_pow = phi2;
303-
let mut phi3_psi_vw_pow = phi3;
304301
for i in 0..128 {
305-
public_vec[i + range_comm_offset] = phi_psi_r_pow;
306-
public_vec[i + range_comm_offset + 128] = phi2_psi_e_pow;
307-
public_vec[i + range_comm_offset + 256] = phi3_psi_vw_pow;
302+
public_vec[i + range_comm_offset] = phi * Scalar::from(psi_r[i]);
303+
public_vec[i + range_comm_offset + 128] = phi2 * Scalar::from(psi_e[i]);
304+
public_vec[i + range_comm_offset + 256] = phi3 * Scalar::from(psi_vw[i]);
308305
// Add contributions of the range proofs to the overall inner product result.
309-
*result += z_r[i] * phi_psi_r_pow;
310-
*result += z_e[i] * phi2_psi_e_pow;
311-
*result += z_vw[i] * phi3_psi_vw_pow;
312-
phi_psi_r_pow *= psi_r;
313-
phi2_psi_e_pow *= psi_e;
314-
phi3_psi_vw_pow *= psi_vw;
306+
*result += z_r[i] * public_vec[i + range_comm_offset];
307+
*result += z_e[i] * public_vec[i + range_comm_offset + 128];
308+
*result += z_vw[i] * public_vec[i + range_comm_offset + 256];
315309
}
316310
}
317311

@@ -333,28 +327,39 @@ pub fn generate_challenge_matrix(
333327
result
334328
}
335329

336-
// Multiplies a 128 by n matrix m and a length n vector v.
337-
// m is a binary matrix each column of which has entries given by the bits of a single entry in the
338-
// input vector m.
339-
// Both the output and v are vectors of 128 bit signed integers.
340-
pub fn multiply_by_challenge_matrix(
330+
// Applies a challenge matrix R1-R2 to a vector v and checks if the result satisfies the conditions
331+
// for not needing to be rejected. An internal error is returned in the event of rejection otherwise
332+
// the resulting vector z is returned.
333+
pub fn try_matrices_and_compute_z(
341334
v: &[i128],
342-
m: &[u128],
335+
R1: &[u128],
336+
R2: &[u128],
337+
y: &[i128],
338+
half_loose_bound: i128,
343339
) -> Result<Vec<i128>, status::StatusError> {
344340
let n = v.len();
345-
if m.len() != n {
346-
return Err(status::failed_precondition("m and v have different lengths".to_string()));
341+
if n != R1.len() || n != R2.len() {
342+
return Err(status::failed_precondition(
343+
"R1, R2, and v must have the same length".to_string(),
344+
));
347345
}
348-
349-
let mut result = vec![0 as i128; 128];
350-
for i in 0..n {
351-
for j in 0..128 {
352-
if m[i] & (1u128 << j) != 0 {
353-
result[j] += v[i];
346+
let mut z = vec![0 as i128; 128];
347+
for j in 0..128 {
348+
let mut u = 0i128;
349+
for i in 0..n {
350+
if R1[i] & (1u128 << j) != 0 {
351+
u += v[i];
352+
}
353+
if R2[i] & (1u128 << j) != 0 {
354+
u -= v[i];
354355
}
355356
}
357+
z[j] = u + y[j];
358+
if u.abs() > half_loose_bound / 128 || z[j].abs() > half_loose_bound {
359+
return Err(status::internal("Sample Rejected"));
360+
}
356361
}
357-
Ok(result)
362+
Ok(z)
358363
}
359364

360365
// Linearly combines the 128 vector challenges of a challenge matrix into a single vector challenge
@@ -364,33 +369,37 @@ pub fn flatten_challenge_matrix(
364369
R1: Vec<u128>,
365370
R2: Vec<u128>,
366371
challenge_label: &'static [u8],
367-
) -> Result<(Vec<Scalar>, Scalar), status::StatusError> {
372+
) -> Result<(Vec<Scalar>, Vec<Scalar>), status::StatusError> {
368373
let n = R1.len();
369374
if n != R2.len() {
370375
return Err(status::failed_precondition("R1 and R2 have different lengths".to_string()));
371376
}
372377

373-
let mut buf = [0u8; 64];
374-
transcript.challenge_bytes(challenge_label, &mut buf);
375-
let psi = Scalar::from_bytes_mod_order_wide(&buf);
376-
377-
let mut R = vec![Scalar::from(0 as u64); n];
378-
let mut psi_powers = [Scalar::from(1 as u64); 128];
379-
for j in 1..128 {
380-
psi_powers[j] = psi_powers[j - 1] * psi;
378+
let mut Rplus = vec![0u128; n];
379+
let mut Rminus = vec![0u128; n];
380+
let mut Rscalar = vec![Scalar::from(0 as u64); n];
381+
382+
let mut psi = [0u128; 128];
383+
let mut psi_scalar = vec![Scalar::from(0 as u64); 128];
384+
let mut buf = [0u8; 16];
385+
for j in 0..128 {
386+
transcript.challenge_bytes(challenge_label, &mut buf);
387+
psi[j] = u128::from_le_bytes(buf).rem_euclid(1u128 << 120);
388+
psi_scalar[j] = Scalar::from(psi[j]);
381389
}
382390
for i in 0..n {
383391
for j in 0..128 {
384392
if R1[i] & (1u128 << j) != 0 {
385-
R[i] += psi_powers[j];
393+
Rplus[i] += psi[j];
386394
}
387395
if R2[i] & (1u128 << j) != 0 {
388-
R[i] -= psi_powers[j];
396+
Rminus[i] += psi[j];
389397
}
390398
}
399+
Rscalar[i] = Scalar::from(Rplus[i]) - Scalar::from(Rminus[i]);
391400
}
392401

393-
Ok((R, psi))
402+
Ok((Rscalar, psi_scalar))
394403
}
395404

396405
// Check that loose_bound = bound*2500*sqrt(v.len()+1) fits within an i128.
@@ -430,7 +439,7 @@ fn generate_range_product(
430439
transcript: &mut (impl Transcript + Clone),
431440
challenge_label: &'static [u8],
432441
) -> Result<
433-
(Vec<Scalar>, RistrettoPoint, Vec<Scalar>, Scalar, Scalar, Vec<Scalar>),
442+
(Vec<Scalar>, RistrettoPoint, Vec<Scalar>, Scalar, Vec<Scalar>, Vec<Scalar>),
434443
status::StatusError,
435444
> {
436445
// Check that computing loose bound does not result in an overflow.
@@ -453,7 +462,6 @@ fn generate_range_product(
453462
let mut z = vec![0 as i128; 128];
454463
let mut attempts = 0;
455464
loop {
456-
let mut done = true;
457465
attempts += 1;
458466
y = (0..128).map(|_| (rng.gen_range(0..possible_y) as i128)).collect();
459467
for i in 0..128 {
@@ -468,21 +476,9 @@ fn generate_range_product(
468476
// subtracting the other we get a challenge matrix with the correct distribution.
469477
R1 = generate_challenge_matrix(transcript, challenge_label, v.len());
470478
R2 = generate_challenge_matrix(transcript, challenge_label, v.len());
471-
let u1 = multiply_by_challenge_matrix(v, &R1)?;
472-
let u2 = multiply_by_challenge_matrix(v, &R2)?;
473-
for i in 0..128 {
474-
let u = u1[i] - u2[i];
475-
if u.abs() > half_loose_bound / 128 {
476-
done = false;
477-
break;
478-
}
479-
z[i] = u + y[i];
480-
if z[i].abs() > half_loose_bound {
481-
done = false;
482-
break;
483-
}
484-
}
485-
if done {
479+
let z_or_error = try_matrices_and_compute_z(v, &R1, &R2, &y, half_loose_bound);
480+
if z_or_error.is_ok() {
481+
z = z_or_error.unwrap();
486482
break;
487483
}
488484
if attempts > 1000 {
@@ -522,7 +518,7 @@ fn generate_range_product_for_verification_and_verify_z_bound(
522518
z: &Vec<Scalar>,
523519
transcript: &mut impl Transcript,
524520
challenge_label: &'static [u8],
525-
) -> Result<(Vec<Scalar>, Scalar), status::StatusError> {
521+
) -> Result<(Vec<Scalar>, Vec<Scalar>), status::StatusError> {
526522
// Check that computing loose bound does not result in an overflow.
527523
check_loose_bound_will_not_overflow(bound, n)?;
528524

@@ -798,9 +794,9 @@ impl<'a> ZeroKnowledgeProver<RlweRelationProofStatement<'a>, RlweRelationProofWi
798794
&z_r,
799795
&z_e,
800796
&z_vw,
801-
psi_r,
802-
psi_e,
803-
psi_vw,
797+
&psi_r,
798+
&psi_e,
799+
&psi_vw,
804800
n,
805801
range_comm_offset,
806802
samples_required,
@@ -977,9 +973,9 @@ impl<'a> ZeroKnowledgeVerifier<RlweRelationProofStatement<'a>, RlweRelationProof
977973
&proof.z_r,
978974
&proof.z_e,
979975
&proof.z_vw,
980-
psi_r,
981-
psi_e,
982-
psi_vw,
976+
&psi_r,
977+
&psi_e,
978+
&psi_vw,
983979
n,
984980
range_comm_offset,
985981
samples_required,
@@ -1228,16 +1224,44 @@ mod tests {
12281224
}
12291225

12301226
#[test]
1231-
fn test_multiply_by_challenge_matrix_basic_case() -> googletest::Result<()> {
1232-
let v = &[10i128, 20i128];
1233-
let m = &[(1u128 << 0) | (1u128 << 2), (1u128 << 1) | (1u128 << 2)];
1227+
fn test_try_matrices_and_compute_z_valid() -> googletest::Result<()> {
1228+
let v = [1, -2, 3, -4];
1229+
let R1 = [1, 2, 3, 4];
1230+
let R2 = [4, 3, 2, 1];
1231+
let y = [1; 128];
1232+
let half_loose_bound = 10000;
1233+
let result = try_matrices_and_compute_z(&v, &R1, &R2, &y, half_loose_bound)?;
1234+
let mut expected_z = vec![1; 128];
1235+
expected_z[0] += 10;
1236+
expected_z[1] += 0;
1237+
expected_z[2] += -5;
1238+
verify_eq!(result, expected_z)?;
1239+
Ok(())
1240+
}
12341241

1235-
let mut expected_result = vec![0i128; 128];
1236-
expected_result[0] = 10;
1237-
expected_result[1] = 20;
1238-
expected_result[2] = 30;
1242+
#[test]
1243+
fn test_try_matrices_and_compute_z_mismatched_lengths() -> googletest::Result<()> {
1244+
let v = [1, -2, 3, -4];
1245+
let R1 = [1, 2, 3];
1246+
let R2 = [4, 3, 2, 1];
1247+
let y = [1; 128];
1248+
let half_loose_bound = 1000;
1249+
let result = try_matrices_and_compute_z(&v, &R1, &R2, &y, half_loose_bound);
1250+
assert!(result.is_err());
1251+
verify_eq!(result.unwrap_err().message(), "R1, R2, and v must have the same length")?;
1252+
Ok(())
1253+
}
12391254

1240-
assert_eq!(multiply_by_challenge_matrix(v, m).unwrap(), expected_result);
1255+
#[test]
1256+
fn test_try_matrices_and_compute_z_sample_rejected() -> googletest::Result<()> {
1257+
let v = [1000, -2000, 3000, -4000];
1258+
let R1 = [1, 2, 3, 4];
1259+
let R2 = [4, 3, 2, 1];
1260+
let y = [1; 128];
1261+
let half_loose_bound = 100000;
1262+
let result = try_matrices_and_compute_z(&v, &R1, &R2, &y, half_loose_bound);
1263+
assert!(result.is_err());
1264+
verify_eq!(result.unwrap_err().message(), "Sample Rejected")?;
12411265
Ok(())
12421266
}
12431267

@@ -1260,12 +1284,10 @@ mod tests {
12601284
for i in 0..4 {
12611285
public_vec[i] = R[i];
12621286
}
1263-
let mut psi_pow = Scalar::from(1u128);
12641287
let mut result = Scalar::from(0u128);
12651288
for i in 4..132 {
1266-
public_vec[i] = psi_pow;
1267-
result += z[i - 4] * psi_pow;
1268-
psi_pow *= psi;
1289+
public_vec[i] = psi[i - 4];
1290+
result += z[i - 4] * psi[i - 4];
12691291
}
12701292
let mut expected_result = Scalar::from(0u128);
12711293
for j in 0..132 {

0 commit comments

Comments
 (0)