Skip to content

Commit 8d07ce1

Browse files
authored
Generalize shift verifier to be symbolic over Channel::Elem (#1444)
## Summary

- Remove rayon multithreading from shift `evaluate_matrices` (verifier should be single-threaded)
- Generalize `verify_iop` to operate symbolically over `Channel::Elem`, using `_scalars` variants of math functions and replacing `FieldBuffer` with plain `Vec`/slice

## Test plan

```
cargo test -p binius-verifier -p binius-prover
```
1 parent efec4da commit 8d07ce1

File tree

7 files changed

+186
-251
lines changed

7 files changed

+186
-251
lines changed

crates/math/src/inner_product.rs

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
use std::{iter, ops::Deref};
44

5-
use binius_field::{ExtensionField, Field, PackedField};
5+
use binius_field::{ExtensionField, Field, FieldOps, PackedField};
66
use binius_utils::rayon::prelude::*;
77

88
use crate::FieldBuffer;
@@ -15,6 +15,14 @@ pub fn inner_product<F: Field>(
1515
inner_product_subfield(a, b)
1616
}
1717

18+
#[inline]
19+
pub fn inner_product_scalars<F: FieldOps>(
20+
a: impl IntoIterator<Item = F>,
21+
b: impl IntoIterator<Item = F>,
22+
) -> F {
23+
itertools::zip_eq(a, b).map(|(a_i, b_i)| b_i * a_i).sum()
24+
}
25+
1826
#[inline]
1927
pub fn inner_product_subfield<F, FSub>(
2028
a: impl IntoIterator<Item = FSub>,

crates/prover/src/protocols/shift/monster.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,7 @@ where
152152
]
153153
.map(|r_zhat_prime| {
154154
let l_tilde = lagrange_evals(&subspace, r_zhat_prime);
155-
evaluate_h_op(l_tilde.to_ref(), r_j, r_s)
155+
evaluate_h_op(l_tilde.as_ref(), r_j, r_s)
156156
});
157157

158158
let r_s_tensor = eq_ind_partial_eval::<F>(r_s);
@@ -247,7 +247,7 @@ mod tests {
247247
let subspace =
248248
BinarySubspace::<AESTowerField8b>::with_dim(LOG_WORD_SIZE_BITS).isomorphic();
249249
let l_tilde = lagrange_evals(&subspace, r_zhat_prime);
250-
let succinct_evaluations = evaluate_h_op(l_tilde.to_ref(), &r_j, &r_s);
250+
let succinct_evaluations = evaluate_h_op(l_tilde.as_ref(), &r_j, &r_s);
251251

252252
// Method 2: Direct evaluation via multilinear part
253253
let h_parts = build_h_parts(r_zhat_prime);

crates/prover/tests/shift.rs

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -316,11 +316,18 @@ fn test_shift_prove_and_verify() {
316316
let verifier_output =
317317
verify(&cs, &verifier_bitand_data, &verifier_intmul_data, &mut verifier_transcript)
318318
.unwrap();
319-
verifier_transcript.finalize().unwrap();
320319

321320
// Check consistency with verifier output
322-
check_eval(&cs, &verifier_bitand_data, &verifier_intmul_data, &subspace, &verifier_output)
323-
.unwrap();
321+
check_eval(
322+
&cs,
323+
&verifier_bitand_data,
324+
&verifier_intmul_data,
325+
&subspace,
326+
&verifier_output,
327+
&mut verifier_transcript,
328+
)
329+
.unwrap();
330+
verifier_transcript.finalize().unwrap();
324331

325332
// Check the claimed eval matches the computed eval
326333
let expected_eval = evaluate_witness(

crates/verifier/src/protocols/shift/monster.rs

Lines changed: 86 additions & 99 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,18 @@
11
// Copyright 2025 Irreducible Inc.
22

3-
use std::iter;
3+
use std::{array, iter};
44

55
use binius_core::{
66
ShiftVariant,
77
constraint_system::{Operand, ShiftedValueIndex},
88
};
9-
use binius_field::{BinaryField, Field, util::powers};
9+
use binius_field::{BinaryField, FieldOps, util::powers};
1010
use binius_math::{
11-
BinarySubspace, FieldBuffer, FieldSlice,
12-
inner_product::{inner_product, inner_product_buffers},
13-
multilinear::{eq::eq_ind_partial_eval, evaluate::evaluate_inplace},
14-
univariate::lagrange_evals,
11+
BinarySubspace,
12+
inner_product::inner_product_scalars,
13+
multilinear::{eq::eq_ind_partial_eval_scalars, evaluate::evaluate_inplace_scalars},
14+
univariate::lagrange_evals_scalars,
1515
};
16-
use binius_utils::rayon::prelude::*;
1716

1817
use super::{
1918
SHIFT_VARIANT_COUNT,
@@ -28,71 +27,66 @@ use crate::config::{LOG_WORD_SIZE_BITS, WORD_SIZE_BITS};
2827
///
2928
/// This is the verifier's version of the h-parts evaluation - instead of building
3029
/// full multilinear polynomials, it directly computes their evaluations.
31-
pub fn evaluate_h_op<F: Field>(
32-
l_tilde: FieldSlice<F>,
33-
r_j: &[F],
34-
r_s: &[F],
35-
) -> [F; SHIFT_VARIANT_COUNT] {
36-
assert_eq!(l_tilde.log_len(), LOG_WORD_SIZE_BITS);
30+
pub fn evaluate_h_op<E: FieldOps>(l_tilde: &[E], r_j: &[E], r_s: &[E]) -> [E; SHIFT_VARIANT_COUNT] {
31+
assert_eq!(l_tilde.len(), WORD_SIZE_BITS);
3732
assert_eq!(r_j.len(), LOG_WORD_SIZE_BITS);
3833
assert_eq!(r_s.len(), LOG_WORD_SIZE_BITS);
3934

4035
// Use helper functions to compute shift indicator helpers for 64-bit shifts
4136
let (sigma, sigma_prime) = partial_eval_sigmas(r_j, r_s);
4237
let sigma_transpose = partial_eval_sigmas_transpose(r_j, r_s);
4338
let phi = partial_eval_phi(r_s);
44-
let j_product = r_j.iter().product::<F>();
39+
let j_product: E = r_j.iter().cloned().product();
4540

4641
// Use helper functions to compute shift indicator helpers for 32-bit shifts
4742
let (sigma32, sigma32_prime) = partial_eval_sigmas(&r_j[..5], &r_s[..5]);
4843
let sigma32_transpose = partial_eval_sigmas_transpose(&r_j[..5], &r_s[..5]);
4944
let phi32 = partial_eval_phi(&r_s[..5]);
50-
let j_product32 = r_j[..5].iter().product::<F>();
45+
let j_product32: E = r_j[..5].iter().cloned().product();
5146

5247
// Compute final results
53-
let sll = inner_product_buffers(&l_tilde, &sigma_transpose);
54-
let srl = inner_product_buffers(&l_tilde, &sigma);
48+
let sll = inner_product_scalars(l_tilde.iter().cloned(), sigma_transpose);
49+
let srl = inner_product_scalars(l_tilde.iter().cloned(), sigma.iter().cloned());
5550
// sra == ∑ᵢ L̃(i) ⋅ (srlᵢ + ∏ₖ rⱼ[k] ⋅ φᵢ)
5651
// == ∑ᵢ L̃(i) ⋅ srlᵢ + ∏ₖ rⱼ[k] ⋅ [ ∑ᵢ L̃(i) ⋅ φᵢ ]
5752
// == srl + ∏ₖ rⱼ[k] ⋅ [ ∑ᵢ L̃(i) ⋅ φᵢ ]
58-
let sra = srl + j_product * inner_product_buffers(&l_tilde, &phi);
59-
let rotr = inner_product(
60-
l_tilde.iter_scalars(),
61-
iter::zip(sigma.as_ref(), sigma_prime.as_ref()).map(|(&s_i, &s_prime_i)| s_i + s_prime_i),
53+
let sra = srl.clone() + j_product * inner_product_scalars(l_tilde.iter().cloned(), phi);
54+
let rotr = inner_product_scalars(
55+
l_tilde.iter().cloned(),
56+
iter::zip(&sigma, &sigma_prime).map(|(s_i, s_prime_i)| s_i.clone() + s_prime_i),
6257
);
6358

64-
// TODO: This is really gross, need to clean it up after other shift reduction modifications.
65-
let r_j_rest_tensor = eq_ind_partial_eval::<F>(&r_j[5..]);
66-
let l_tilde_chunks = l_tilde.chunks(5);
59+
let r_j_rest_tensor = eq_ind_partial_eval_scalars(&r_j[5..]);
60+
let chunk_size = 1 << 5; // 32
6761

68-
let sll32 = inner_product(
69-
l_tilde_chunks
70-
.clone()
71-
.map(|l_tilde_i| inner_product_buffers(&l_tilde_i, &sigma32_transpose)),
72-
r_j_rest_tensor.iter_scalars(),
62+
let sll32 = inner_product_scalars(
63+
l_tilde.chunks(chunk_size).map(|chunk| {
64+
inner_product_scalars(chunk.iter().cloned(), sigma32_transpose.iter().cloned())
65+
}),
66+
r_j_rest_tensor.iter().cloned(),
7367
);
74-
let srl32 = inner_product(
75-
l_tilde_chunks
76-
.clone()
77-
.map(|l_tilde_i| inner_product_buffers(&l_tilde_i, &sigma32)),
78-
r_j_rest_tensor.iter_scalars(),
68+
let srl32 = inner_product_scalars(
69+
l_tilde
70+
.chunks(chunk_size)
71+
.map(|chunk| inner_product_scalars(chunk.iter().cloned(), sigma32.iter().cloned())),
72+
r_j_rest_tensor.iter().cloned(),
7973
);
80-
let sra32 = srl32
81-
+ inner_product(
82-
l_tilde_chunks
83-
.clone()
84-
.map(|l_tilde_i| j_product32 * inner_product_buffers(&l_tilde_i, &phi32)),
85-
r_j_rest_tensor.iter_scalars(),
74+
let sra32 = srl32.clone()
75+
+ inner_product_scalars(
76+
l_tilde.chunks(chunk_size).map(|chunk| {
77+
j_product32.clone()
78+
* inner_product_scalars(chunk.iter().cloned(), phi32.iter().cloned())
79+
}),
80+
r_j_rest_tensor.iter().cloned(),
8681
);
87-
let rotr32 = inner_product(
88-
l_tilde_chunks.clone().map(|l_tilde_i| {
89-
inner_product(
90-
l_tilde_i.iter_scalars(),
91-
iter::zip(sigma32.as_ref(), sigma32_prime.as_ref())
92-
.map(|(&s_i, &s_prime_i)| s_i + s_prime_i),
82+
let rotr32 = inner_product_scalars(
83+
l_tilde.chunks(chunk_size).map(|chunk| {
84+
inner_product_scalars(
85+
chunk.iter().cloned(),
86+
iter::zip(&sigma32, &sigma32_prime).map(|(s_i, s_prime_i)| s_i.clone() + s_prime_i),
9387
)
9488
}),
95-
r_j_rest_tensor.iter_scalars(),
89+
r_j_rest_tensor,
9690
);
9791

9892
[sll, srl, sra, rotr, sll32, srl32, sra32, rotr32]
@@ -133,31 +127,32 @@ pub fn evaluate_h_op<F: Field>(
133127
/// # Errors
134128
///
135129
/// Returns an error if the binary subspace construction fails.
136-
pub fn evaluate_monster_multilinear_for_operation<F: BinaryField, const ARITY: usize>(
130+
pub fn evaluate_monster_multilinear_for_operation<F, E, const ARITY: usize>(
137131
operand_vecs: &[Vec<&Operand>],
138-
operator_data: &OperatorData<F, ARITY>,
132+
operator_data: &OperatorData<E, ARITY>,
139133
subspace: &BinarySubspace<F>,
140-
lambda: F,
141-
r_j: &[F],
142-
r_s: &[F],
143-
r_y: &[F],
144-
) -> Result<F, Error> {
134+
lambda: E,
135+
r_j: &[E],
136+
r_s: &[E],
137+
r_y: &[E],
138+
) -> Result<E, Error>
139+
where
140+
F: BinaryField,
141+
E: FieldOps<Scalar = F> + From<F>,
142+
{
145143
assert_eq!(subspace.dim(), LOG_WORD_SIZE_BITS); // precondition
146144

147-
let r_x_prime_tensor = eq_ind_partial_eval::<F>(&operator_data.r_x_prime);
148-
let r_y_tensor = eq_ind_partial_eval::<F>(r_y);
145+
let r_x_prime_tensor = eq_ind_partial_eval_scalars(&operator_data.r_x_prime);
146+
let r_y_tensor = eq_ind_partial_eval_scalars(r_y);
149147

150-
let l_tilde = lagrange_evals(subspace, operator_data.r_zhat_prime);
151-
let h_op_evals = evaluate_h_op(l_tilde.to_ref(), r_j, r_s);
148+
let l_tilde = lagrange_evals_scalars(subspace, operator_data.r_zhat_prime.clone());
149+
let h_op_evals = evaluate_h_op(&l_tilde, r_j, r_s);
152150

153151
let lambda_powers = powers(lambda).skip(1).take(ARITY).collect::<Vec<_>>();
154152
let evals = evaluate_matrices(operand_vecs, &lambda_powers, &r_x_prime_tensor, &r_y_tensor);
155153

156-
let eval = inner_product(
157-
evals.map(|mut evals_op| {
158-
let evals_op = FieldBuffer::new(LOG_WORD_SIZE_BITS, &mut evals_op[..]);
159-
evaluate_inplace(evals_op, r_s)
160-
}),
154+
let eval = inner_product_scalars(
155+
evals.map(|mut evals_op| evaluate_inplace_scalars(&mut evals_op[..], r_s)),
161156
h_op_evals,
162157
);
163158

@@ -178,32 +173,27 @@ pub fn evaluate_monster_multilinear_for_operation<F: BinaryField, const ARITY: u
178173
/// * `operand_coeffs` - Coefficients (λ powers) for batching operand evaluations
179174
/// * `r_x_prime_tensor` - Multilinear challenge tensor for constraint variables
180175
/// * `r_y_tensor` - Challenge tensor for word index variables
181-
///
182-
/// Note: This function uses multithreading (par_iter), which is an exception to the general
183-
/// rule that the verifier should be single-threaded. The sparse matrix evaluation takes time
184-
/// linear in the size of the constraint system, so we use parallelization here to make the
185-
/// verifier performant on large constraint systems.
186-
fn evaluate_matrices<F: BinaryField>(
176+
fn evaluate_matrices<F: BinaryField, E: FieldOps<Scalar = F> + From<F>>(
187177
operands: &[Vec<&Operand>],
188-
operand_coeffs: &[F],
189-
r_x_prime_tensor: &FieldBuffer<F>,
190-
r_y_tensor: &FieldBuffer<F>,
191-
) -> [[F; WORD_SIZE_BITS]; SHIFT_VARIANT_COUNT] {
178+
operand_coeffs: &[E],
179+
r_x_prime_tensor: &[E],
180+
r_y_tensor: &[E],
181+
) -> [[E; WORD_SIZE_BITS]; SHIFT_VARIANT_COUNT] {
192182
assert_eq!(operands.len(), operand_coeffs.len());
193183

194-
// Use parallelization for performance (see docstring for rationale).
195-
(operand_coeffs, operands)
196-
.into_par_iter()
197-
.map(|(&coeff, constraint_operands)| {
198-
let mut evals = [[F::ZERO; WORD_SIZE_BITS]; SHIFT_VARIANT_COUNT];
199-
for (&operand_terms, &constraint_eval) in
200-
iter::zip(constraint_operands, r_x_prime_tensor.as_ref())
184+
let zero_evals = array::from_fn(|_| array::from_fn::<E, WORD_SIZE_BITS, _>(|_| E::zero()));
185+
186+
iter::zip(operand_coeffs, operands)
187+
.map(|(coeff, constraint_operands)| {
188+
let mut evals: [[E; WORD_SIZE_BITS]; SHIFT_VARIANT_COUNT] =
189+
array::from_fn(|_| array::from_fn::<E, WORD_SIZE_BITS, _>(|_| E::zero()));
190+
for (operand_terms, constraint_eval) in iter::zip(constraint_operands, r_x_prime_tensor)
201191
{
202192
for ShiftedValueIndex {
203193
value_index,
204194
shift_variant,
205195
amount,
206-
} in operand_terms
196+
} in *operand_terms
207197
{
208198
let shift_id = match shift_variant {
209199
ShiftVariant::Sll => 0,
@@ -216,7 +206,7 @@ fn evaluate_matrices<F: BinaryField>(
216206
ShiftVariant::Rotr32 => 7,
217207
};
218208
evals[shift_id][*amount] +=
219-
constraint_eval * r_y_tensor.get(value_index.0 as usize);
209+
constraint_eval.clone() * &r_y_tensor[value_index.0 as usize];
220210
}
221211
}
222212

@@ -229,26 +219,23 @@ fn evaluate_matrices<F: BinaryField>(
229219

230220
evals
231221
})
232-
.reduce(
233-
|| [[F::ZERO; WORD_SIZE_BITS]; SHIFT_VARIANT_COUNT],
234-
|mut a, b| {
235-
for (a_op, b_op) in iter::zip(&mut a, b) {
236-
for (a_op_s, b_op_s) in iter::zip(&mut *a_op, b_op) {
237-
*a_op_s += b_op_s;
238-
}
222+
.fold(zero_evals, |mut a, b| {
223+
for (a_op, b_op) in iter::zip(&mut a, b) {
224+
for (a_op_s, b_op_s) in iter::zip(&mut *a_op, b_op) {
225+
*a_op_s += b_op_s;
239226
}
240-
a
241-
},
242-
)
227+
}
228+
a
229+
})
243230
}
244231

245232
#[cfg(test)]
246233
mod tests {
247-
use binius_field::{BinaryField128bGhash, Random};
234+
use binius_field::{BinaryField128bGhash, Field, Random};
248235
use binius_math::{
249236
BinarySubspace,
250237
test_utils::{index_to_hypercube_point, random_scalars},
251-
univariate::lagrange_evals,
238+
univariate::lagrange_evals_scalars,
252239
};
253240
use rand::{Rng, SeedableRng, rngs::StdRng};
254241

@@ -273,13 +260,13 @@ mod tests {
273260
let s = rng.random_range(0..64);
274261

275262
let challenge = subspace.get(i);
276-
let l_tilde = lagrange_evals(&subspace, challenge);
263+
let l_tilde = lagrange_evals_scalars(&subspace, challenge);
277264

278265
let r_j = index_to_hypercube_point::<BinaryField128bGhash>(LOG_WORD_SIZE_BITS, j);
279266
let r_s = index_to_hypercube_point::<BinaryField128bGhash>(LOG_WORD_SIZE_BITS, s);
280267

281268
let [sll, srl, sra, rotr, sll32, srl32, sra32, rotr32] =
282-
evaluate_h_op(l_tilde.to_ref(), &r_j, &r_s);
269+
evaluate_h_op(&l_tilde, &r_j, &r_s);
283270

284271
let expected_sll = j + s == i;
285272
let expected_srl = i + s == j;
@@ -324,7 +311,7 @@ mod tests {
324311
// Generate random evaluation points
325312
let challenge = BinaryField128bGhash::random(&mut rng);
326313
let subspace = BinarySubspace::<BinaryField128bGhash>::with_dim(LOG_WORD_SIZE_BITS);
327-
let l_tilde = lagrange_evals(&subspace, challenge);
314+
let l_tilde = lagrange_evals_scalars(&subspace, challenge);
328315
let r_j = random_scalars::<BinaryField128bGhash>(&mut rng, LOG_WORD_SIZE_BITS);
329316
let r_s = random_scalars::<BinaryField128bGhash>(&mut rng, LOG_WORD_SIZE_BITS);
330317

@@ -336,7 +323,7 @@ mod tests {
336323
let mut r_j_at_1 = r_j.clone();
337324
r_j_at_1[i] = BinaryField128bGhash::ONE;
338325
let [result_0, result_1, result_y] = [&r_j_at_0, &r_j_at_1, &r_j]
339-
.map(|r_j_variant| evaluate_h_op(l_tilde.to_ref(), r_j_variant, &r_s));
326+
.map(|r_j_variant| evaluate_h_op(&l_tilde, r_j_variant, &r_s));
340327
for variant in 0..SHIFT_VARIANT_COUNT {
341328
let expected = result_0[variant] * (BinaryField128bGhash::ONE - r_j[i])
342329
+ result_1[variant] * r_j[i];
@@ -349,7 +336,7 @@ mod tests {
349336
let mut r_s_at_1 = r_s.clone();
350337
r_s_at_1[i] = BinaryField128bGhash::ONE;
351338
let [result_0, result_1, result_y] = [&r_s_at_0, &r_s_at_1, &r_s]
352-
.map(|r_s_variant| evaluate_h_op(l_tilde.to_ref(), &r_j, r_s_variant));
339+
.map(|r_s_variant| evaluate_h_op(&l_tilde, &r_j, r_s_variant));
353340
for variant in 0..SHIFT_VARIANT_COUNT {
354341
let expected = result_0[variant] * (BinaryField128bGhash::ONE - r_s[i])
355342
+ result_1[variant] * r_s[i];

0 commit comments

Comments
 (0)