Skip to content

Commit 8facb07

Browse files
mathsjamescopybara-github
authored andcommitted
Allow early rejection during multiplication by challenge matrices.
This will (on average) roughly halve the run time spent on multiplication by challenge matrices, reducing the run time of create_client_message by about 8%. PiperOrigin-RevId: 846194679
1 parent e9f13a3 commit 8facb07

File tree

1 file changed

+67
-39
lines changed

1 file changed

+67
-39
lines changed

willow/src/zk/rlwe_relation.rs

Lines changed: 67 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -327,28 +327,41 @@ pub fn generate_challenge_matrix(
327327
result
328328
}
329329

330-
// Multiplies a 128 by n matrix m and a length n vector v.
331-
// m is a binary matrix each column of which has entries given by the bits of a single entry in the
332-
// input vector m.
333-
// Both the output and v are vectors of 128 bit signed integers.
334-
pub fn multiply_by_challenge_matrix(
330+
// Applies a challenge matrix R1-R2 to a vector v and checks if the result satisfies the conditions
331+
// for not needing to be rejected. An internal error is returned in the event of rejection otherwise
332+
// the resulting vector z is returned.
333+
//
334+
// To understand the rejection conditions see the comment for generate_range_product.
335+
pub fn try_matrices_and_compute_z(
335336
v: &[i128],
336-
m: &[u128],
337+
R1: &[u128],
338+
R2: &[u128],
339+
y: &[i128],
340+
half_loose_bound: i128,
337341
) -> Result<Vec<i128>, status::StatusError> {
338342
let n = v.len();
339-
if m.len() != n {
340-
return Err(status::failed_precondition("m and v have different lengths".to_string()));
343+
if n != R1.len() || n != R2.len() {
344+
return Err(status::failed_precondition(
345+
"R1, R2, and v must have the same length".to_string(),
346+
));
341347
}
342-
343-
let mut result = vec![0 as i128; 128];
344-
for i in 0..n {
345-
for j in 0..128 {
346-
if m[i] & (1u128 << j) != 0 {
347-
result[j] += v[i];
348+
let mut z = vec![0 as i128; 128];
349+
for j in 0..128 {
350+
let mut u = 0i128;
351+
for i in 0..n {
352+
if R1[i] & (1u128 << j) != 0 {
353+
u += v[i];
348354
}
355+
if R2[i] & (1u128 << j) != 0 {
356+
u -= v[i];
357+
}
358+
}
359+
z[j] = u + y[j];
360+
if u.abs() > half_loose_bound / 128 || z[j].abs() > half_loose_bound {
361+
return Err(status::internal("Sample Rejected"));
349362
}
350363
}
351-
Ok(result)
364+
Ok(z)
352365
}
353366

354367
// Linearly combines the 128 vector challenges of a challenge matrix into a single vector challenge
@@ -459,7 +472,6 @@ fn generate_range_product(
459472
let mut z = vec![0 as i128; 128];
460473
let mut attempts = 0;
461474
loop {
462-
let mut done = true;
463475
attempts += 1;
464476
y = (0..128).map(|_| (rng.gen_range(0..possible_y) as i128)).collect();
465477
for i in 0..128 {
@@ -474,21 +486,9 @@ fn generate_range_product(
474486
// subtracting the other we get a challenge matrix with the correct distribution.
475487
R1 = generate_challenge_matrix(transcript, challenge_label, v.len());
476488
R2 = generate_challenge_matrix(transcript, challenge_label, v.len());
477-
let u1 = multiply_by_challenge_matrix(v, &R1)?;
478-
let u2 = multiply_by_challenge_matrix(v, &R2)?;
479-
for i in 0..128 {
480-
let u = u1[i] - u2[i];
481-
if u.abs() > half_loose_bound / 128 {
482-
done = false;
483-
break;
484-
}
485-
z[i] = u + y[i];
486-
if z[i].abs() > half_loose_bound {
487-
done = false;
488-
break;
489-
}
490-
}
491-
if done {
489+
let z_or_error = try_matrices_and_compute_z(v, &R1, &R2, &y, half_loose_bound);
490+
if z_or_error.is_ok() {
491+
z = z_or_error.unwrap();
492492
break;
493493
}
494494
if attempts > 1000 {
@@ -1244,16 +1244,44 @@ mod tests {
12441244
}
12451245

12461246
#[test]
1247-
fn test_multiply_by_challenge_matrix_basic_case() -> googletest::Result<()> {
1248-
let v = &[10i128, 20i128];
1249-
let m = &[(1u128 << 0) | (1u128 << 2), (1u128 << 1) | (1u128 << 2)];
1247+
fn test_try_matrices_and_compute_z_valid() -> googletest::Result<()> {
1248+
let v = [1, -2, 3, -4];
1249+
let R1 = [1, 2, 3, 4];
1250+
let R2 = [4, 3, 2, 1];
1251+
let y = [1; 128];
1252+
let half_loose_bound = 10000;
1253+
let result = try_matrices_and_compute_z(&v, &R1, &R2, &y, half_loose_bound)?;
1254+
let mut expected_z = vec![1; 128];
1255+
expected_z[0] += 10;
1256+
expected_z[1] += 0;
1257+
expected_z[2] += -5;
1258+
verify_eq!(result, expected_z)?;
1259+
Ok(())
1260+
}
12501261

1251-
let mut expected_result = vec![0i128; 128];
1252-
expected_result[0] = 10;
1253-
expected_result[1] = 20;
1254-
expected_result[2] = 30;
1262+
#[test]
1263+
fn test_try_matrices_and_compute_z_mismatched_lengths() -> googletest::Result<()> {
1264+
let v = [1, -2, 3, -4];
1265+
let R1 = [1, 2, 3];
1266+
let R2 = [4, 3, 2, 1];
1267+
let y = [1; 128];
1268+
let half_loose_bound = 1000;
1269+
let result = try_matrices_and_compute_z(&v, &R1, &R2, &y, half_loose_bound);
1270+
assert!(result.is_err());
1271+
verify_eq!(result.unwrap_err().message(), "R1, R2, and v must have the same length")?;
1272+
Ok(())
1273+
}
12551274

1256-
assert_eq!(multiply_by_challenge_matrix(v, m).unwrap(), expected_result);
1275+
#[test]
1276+
fn test_try_matrices_and_compute_z_sample_rejected() -> googletest::Result<()> {
1277+
let v = [1000, -2000, 3000, -4000];
1278+
let R1 = [1, 2, 3, 4];
1279+
let R2 = [4, 3, 2, 1];
1280+
let y = [1; 128];
1281+
let half_loose_bound = 100000;
1282+
let result = try_matrices_and_compute_z(&v, &R1, &R2, &y, half_loose_bound);
1283+
assert!(result.is_err());
1284+
verify_eq!(result.unwrap_err().message(), "Sample Rejected")?;
12571285
Ok(())
12581286
}
12591287

0 commit comments

Comments
 (0)