Skip to content

Commit 34bbaa0

Browse files
mathsjamescopybara-github
authored andcommitted
Speed up flatten_challenge_matrix by using u128 instead of Scalar.
Switch to using independent random psi instead of powers of some base psi and implement most of the computation in flatten_challenge_matrix with u128 instead of Scalar. The change to psi is necessary to make the implementation in u128 possible. This change reduces the runtime of flatten_challenge_matrix by a factor of 3, thus speeding up the client message creation by 17%. PiperOrigin-RevId: 845807650
1 parent 53b301b commit 34bbaa0

File tree

1 file changed

+93
-79
lines changed

1 file changed

+93
-79
lines changed

willow/src/zk/rlwe_relation.rs

Lines changed: 93 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -266,15 +266,15 @@ fn create_public_vec(
266266
fn update_public_vec_for_range_proof(
267267
public_vec: &mut Vec<Scalar>,
268268
result: &mut Scalar,
269-
R_r: &Vec<Scalar>,
270-
R_e: &Vec<Scalar>,
271-
R_vw: &Vec<Scalar>,
272-
z_r: &Vec<Scalar>,
273-
z_e: &Vec<Scalar>,
274-
z_vw: &Vec<Scalar>,
275-
psi_r: Scalar,
276-
psi_e: Scalar,
277-
psi_vw: Scalar,
269+
R_r: &[Scalar],
270+
R_e: &[Scalar],
271+
R_vw: &[Scalar],
272+
z_r: &[Scalar],
273+
z_e: &[Scalar],
274+
z_vw: &[Scalar],
275+
psi_r: &[Scalar],
276+
psi_e: &[Scalar],
277+
psi_vw: &[Scalar],
278278
n: usize,
279279
range_comm_offset: usize,
280280
samples_required: usize,
@@ -298,20 +298,14 @@ fn update_public_vec_for_range_proof(
298298

299299
// The range proofs equation also involves length 128 innerproducts involving the relevant
300300
// psi these are included in the last 3*128 entries of the inner product vectors.
301-
let mut phi_psi_r_pow = phi;
302-
let mut phi2_psi_e_pow = phi2;
303-
let mut phi3_psi_vw_pow = phi3;
304301
for i in 0..128 {
305-
public_vec[i + range_comm_offset] = phi_psi_r_pow;
306-
public_vec[i + range_comm_offset + 128] = phi2_psi_e_pow;
307-
public_vec[i + range_comm_offset + 256] = phi3_psi_vw_pow;
302+
public_vec[i + range_comm_offset] = phi * Scalar::from(psi_r[i]);
303+
public_vec[i + range_comm_offset + 128] = phi2 * Scalar::from(psi_e[i]);
304+
public_vec[i + range_comm_offset + 256] = phi3 * Scalar::from(psi_vw[i]);
308305
// Add contributions of the range proofs to the overall inner product result.
309-
*result += z_r[i] * phi_psi_r_pow;
310-
*result += z_e[i] * phi2_psi_e_pow;
311-
*result += z_vw[i] * phi3_psi_vw_pow;
312-
phi_psi_r_pow *= psi_r;
313-
phi2_psi_e_pow *= psi_e;
314-
phi3_psi_vw_pow *= psi_vw;
306+
*result += z_r[i] * public_vec[i + range_comm_offset];
307+
*result += z_e[i] * public_vec[i + range_comm_offset + 128];
308+
*result += z_vw[i] * public_vec[i + range_comm_offset + 256];
315309
}
316310
}
317311

@@ -364,33 +358,38 @@ pub fn flatten_challenge_matrix(
364358
R1: Vec<u128>,
365359
R2: Vec<u128>,
366360
challenge_label: &'static [u8],
367-
) -> Result<(Vec<Scalar>, Scalar), status::StatusError> {
361+
) -> Result<(Vec<Scalar>, Vec<Scalar>), status::StatusError> {
368362
let n = R1.len();
369363
if n != R2.len() {
370364
return Err(status::failed_precondition("R1 and R2 have different lengths".to_string()));
371365
}
372366

373-
let mut buf = [0u8; 64];
374-
transcript.challenge_bytes(challenge_label, &mut buf);
375-
let psi = Scalar::from_bytes_mod_order_wide(&buf);
376-
377-
let mut R = vec![Scalar::from(0 as u64); n];
378-
let mut psi_powers = [Scalar::from(1 as u64); 128];
379-
for j in 1..128 {
380-
psi_powers[j] = psi_powers[j - 1] * psi;
367+
let mut Rplus = vec![0u128; n];
368+
let mut Rminus = vec![0u128; n];
369+
let mut Rscalar = vec![Scalar::from(0u64); n];
370+
371+
let mut psi = [0u128; 128];
372+
let mut psi_scalar = vec![Scalar::from(0 as u64); 128];
373+
let mut buf = [0u8; 16];
374+
for j in 0..128 {
375+
transcript.challenge_bytes(challenge_label, &mut buf);
376+
// We only take challenges up to 2^121 so that the sum of 128 of them will fit in a u128.
377+
psi[j] = u128::from_le_bytes(buf) >> 7;
378+
psi_scalar[j] = Scalar::from(psi[j]);
381379
}
382380
for i in 0..n {
383381
for j in 0..128 {
384382
if R1[i] & (1u128 << j) != 0 {
385-
R[i] += psi_powers[j];
383+
Rplus[i] += psi[j];
386384
}
387385
if R2[i] & (1u128 << j) != 0 {
388-
R[i] -= psi_powers[j];
386+
Rminus[i] += psi[j];
389387
}
390388
}
389+
Rscalar[i] = Scalar::from(Rplus[i]) - Scalar::from(Rminus[i]);
391390
}
392391

393-
Ok((R, psi))
392+
Ok((Rscalar, psi_scalar))
394393
}
395394

396395
// Check that loose_bound = bound*2500*sqrt(v.len()+1) fits within an i128.
@@ -407,6 +406,16 @@ fn check_loose_bound_will_not_overflow(bound: u128, n: usize) -> Result<(), stat
407406
Ok(())
408407
}
409408

409+
// Struct to hold the results of the generate_range_product function.
410+
struct RangeProductMetadata {
411+
R: Vec<Scalar>,
412+
comm_y: RistrettoPoint,
413+
y: Vec<Scalar>,
414+
delta_y: Scalar,
415+
psi: Vec<Scalar>,
416+
z: Vec<Scalar>,
417+
}
418+
410419
// Return the inner product that needs to be checked for the range proof, the commitment to y that
411420
// the verifier will need to verify it and the blinding information required for the proof.
412421
//
@@ -429,10 +438,7 @@ fn generate_range_product(
429438
start: usize,
430439
transcript: &mut (impl Transcript + Clone),
431440
challenge_label: &'static [u8],
432-
) -> Result<
433-
(Vec<Scalar>, RistrettoPoint, Vec<Scalar>, Scalar, Scalar, Vec<Scalar>),
434-
status::StatusError,
435-
> {
441+
) -> Result<RangeProductMetadata, status::StatusError> {
436442
// Check that computing loose bound does not result in an overflow.
437443
check_loose_bound_will_not_overflow(bound, v.len())?;
438444

@@ -512,17 +518,19 @@ fn generate_range_product(
512518
})
513519
.collect();
514520

515-
Ok((R, comm_y, scalar_y, delta_y, psi, scalar_z))
521+
Ok(RangeProductMetadata { R, comm_y, y: scalar_y, delta_y, psi, z: scalar_z })
516522
}
517523

524+
// Verifies the z bound and returns the linear combination of the 128 rows of the range proof
525+
// projection matrix R and a vector psi of the coefficients used in that linear combination.
518526
fn generate_range_product_for_verification_and_verify_z_bound(
519527
n: usize,
520528
bound: u128,
521529
comm_y: RistrettoPoint,
522-
z: &Vec<Scalar>,
530+
z: &[Scalar],
523531
transcript: &mut impl Transcript,
524532
challenge_label: &'static [u8],
525-
) -> Result<(Vec<Scalar>, Scalar), status::StatusError> {
533+
) -> Result<(Vec<Scalar>, Vec<Scalar>), status::StatusError> {
526534
// Check that computing loose bound does not result in an overflow.
527535
check_loose_bound_will_not_overflow(bound, n)?;
528536

@@ -762,23 +770,23 @@ impl<'a> ZeroKnowledgeProver<RlweRelationProofStatement<'a>, RlweRelationProofWi
762770
// Get inner products to prove for range proofs. We then need to check
763771
// <R_r,r> + <psi_r^128,y_r> = <psi_r^128,z_r> mod P etc.
764772
// This is explained in more detail in the comment above generate_range_product.
765-
let (R_r, comm_y_r, y_r, delta_y_r, psi_r, z_r) = generate_range_product(
773+
let range_product_r = generate_range_product(
766774
&signed_r,
767775
bound_r,
768776
&self.prover,
769777
range_comm_offset,
770778
transcript,
771779
b"range matrix r",
772780
)?;
773-
let (R_e, comm_y_e, y_e, delta_y_e, psi_e, z_e) = generate_range_product(
781+
let range_product_e = generate_range_product(
774782
&signed_e,
775783
bound_e,
776784
&self.prover,
777785
range_comm_offset + 128,
778786
transcript,
779787
b"range matrix e",
780788
)?;
781-
let (R_vw, comm_y_vw, y_vw, delta_y_vw, psi_vw, z_vw) = generate_range_product(
789+
let range_product_vw = generate_range_product(
782790
&signed_vw,
783791
q * (n as u128),
784792
&self.prover,
@@ -792,15 +800,15 @@ impl<'a> ZeroKnowledgeProver<RlweRelationProofStatement<'a>, RlweRelationProofWi
792800
update_public_vec_for_range_proof(
793801
&mut public_vec,
794802
&mut result,
795-
&R_r,
796-
&R_e,
797-
&R_vw,
798-
&z_r,
799-
&z_e,
800-
&z_vw,
801-
psi_r,
802-
psi_e,
803-
psi_vw,
803+
&range_product_r.R,
804+
&range_product_e.R,
805+
&range_product_vw.R,
806+
&range_product_r.z,
807+
&range_product_e.z,
808+
&range_product_vw.z,
809+
&range_product_r.psi,
810+
&range_product_e.psi,
811+
&range_product_vw.psi,
804812
n,
805813
range_comm_offset,
806814
samples_required,
@@ -818,13 +826,21 @@ impl<'a> ZeroKnowledgeProver<RlweRelationProofStatement<'a>, RlweRelationProofWi
818826
private_vec[i + n + n + n] = scalar_wrho_vec[i];
819827
}
820828
for i in 0..128 {
821-
private_vec[i + range_comm_offset] = y_r[i];
822-
private_vec[i + range_comm_offset + 128] = y_e[i];
823-
private_vec[i + range_comm_offset + 256] = y_vw[i];
829+
private_vec[i + range_comm_offset] = range_product_r.y[i];
830+
private_vec[i + range_comm_offset + 128] = range_product_e.y[i];
831+
private_vec[i + range_comm_offset + 256] = range_product_vw.y[i];
824832
}
825833

826-
let private_vec_comm = comm_rev + comm_wrho + comm_y_r + comm_y_e + comm_y_vw;
827-
let blinding_factor = delta_rev + delta_w + delta_y_r + delta_y_e + delta_y_vw;
834+
let private_vec_comm = comm_rev
835+
+ comm_wrho
836+
+ range_product_r.comm_y
837+
+ range_product_e.comm_y
838+
+ range_product_vw.comm_y;
839+
let blinding_factor = delta_rev
840+
+ delta_w
841+
+ range_product_r.delta_y
842+
+ range_product_e.delta_y
843+
+ range_product_vw.delta_y;
828844

829845
// Set up linear product statement and prove it
830846
let lip_statement = LinearInnerProductProofStatement {
@@ -841,12 +857,12 @@ impl<'a> ZeroKnowledgeProver<RlweRelationProofStatement<'a>, RlweRelationProofWi
841857
Ok(RlweRelationProof {
842858
comm_rev: comm_rev.compress(),
843859
comm_wrho: comm_wrho.compress(),
844-
comm_y_r: comm_y_r.compress(),
845-
comm_y_e: comm_y_e.compress(),
846-
comm_y_vw: comm_y_vw.compress(),
847-
z_r: z_r,
848-
z_e: z_e,
849-
z_vw: z_vw,
860+
comm_y_r: range_product_r.comm_y.compress(),
861+
comm_y_e: range_product_e.comm_y.compress(),
862+
comm_y_vw: range_product_vw.comm_y.compress(),
863+
z_r: range_product_r.z,
864+
z_e: range_product_e.z,
865+
z_vw: range_product_vw.z,
850866
lip_proof: lip_proof,
851867
})
852868
}
@@ -977,9 +993,9 @@ impl<'a> ZeroKnowledgeVerifier<RlweRelationProofStatement<'a>, RlweRelationProof
977993
&proof.z_r,
978994
&proof.z_e,
979995
&proof.z_vw,
980-
psi_r,
981-
psi_e,
982-
psi_vw,
996+
&psi_r,
997+
&psi_e,
998+
&psi_vw,
983999
n,
9841000
range_comm_offset,
9851001
samples_required,
@@ -1247,33 +1263,31 @@ mod tests {
12471263
let v = [1, -2, 3, -4];
12481264
let prover = LinearInnerProductProver::new(b"42", 132);
12491265
let mut transcript = MerlinTranscript::new(b"42");
1250-
let (R, comm_y, y, delta_y, psi, z) =
1266+
let result =
12511267
generate_range_product(&v, bound, &prover, 4, &mut transcript, b"test vector")?;
12521268
let mut private_vec = [Scalar::from(0u128); 132];
12531269
for i in 0..4 {
12541270
private_vec[i] = Scalar::from((v[i] + (bound as i128)) as u128) - Scalar::from(bound);
12551271
}
1256-
for i in 4..132 {
1257-
private_vec[i] = y[i - 4];
1272+
for i in 0..128 {
1273+
private_vec[i + 4] = result.y[i];
12581274
}
12591275
let mut public_vec = [Scalar::from(0u128); 132];
12601276
for i in 0..4 {
1261-
public_vec[i] = R[i];
1277+
public_vec[i] = result.R[i];
12621278
}
1263-
let mut psi_pow = Scalar::from(1u128);
1264-
let mut result = Scalar::from(0u128);
1265-
for i in 4..132 {
1266-
public_vec[i] = psi_pow;
1267-
result += z[i - 4] * psi_pow;
1268-
psi_pow *= psi;
1279+
let mut inner_product = Scalar::from(0u128);
1280+
for i in 0..128 {
1281+
public_vec[i + 4] = result.psi[i];
1282+
inner_product += result.z[i] * result.psi[i];
12691283
}
12701284
let mut expected_result = Scalar::from(0u128);
12711285
for j in 0..132 {
12721286
expected_result += public_vec[j] * private_vec[j];
12731287
}
1274-
assert_eq!(result, expected_result);
1275-
let expected_comm_y = prover.commit_partial(&y, delta_y, 4, 132)?;
1276-
assert_eq!(comm_y, expected_comm_y);
1288+
assert_eq!(inner_product, expected_result);
1289+
let expected_comm_y = prover.commit_partial(&result.y, result.delta_y, 4, 132)?;
1290+
assert_eq!(result.comm_y, expected_comm_y);
12771291
Ok(())
12781292
}
12791293

0 commit comments

Comments
 (0)