11// Copyright 2025 Irreducible Inc.
22
3- use std:: iter;
3+ use std:: { array , iter} ;
44
55use binius_core:: {
66 ShiftVariant ,
77 constraint_system:: { Operand , ShiftedValueIndex } ,
88} ;
9- use binius_field:: { BinaryField , Field , util:: powers} ;
9+ use binius_field:: { BinaryField , FieldOps , util:: powers} ;
1010use binius_math:: {
11- BinarySubspace , FieldBuffer , FieldSlice ,
12- inner_product:: { inner_product , inner_product_buffers } ,
13- multilinear:: { eq:: eq_ind_partial_eval , evaluate:: evaluate_inplace } ,
14- univariate:: lagrange_evals ,
11+ BinarySubspace ,
12+ inner_product:: inner_product_scalars ,
13+ multilinear:: { eq:: eq_ind_partial_eval_scalars , evaluate:: evaluate_inplace_scalars } ,
14+ univariate:: lagrange_evals_scalars ,
1515} ;
16- use binius_utils:: rayon:: prelude:: * ;
1716
1817use super :: {
1918 SHIFT_VARIANT_COUNT ,
@@ -28,71 +27,66 @@ use crate::config::{LOG_WORD_SIZE_BITS, WORD_SIZE_BITS};
2827///
2928/// This is the verifier's version of the h-parts evaluation - instead of building
3029/// full multilinear polynomials, it directly computes their evaluations.
31- pub fn evaluate_h_op < F : Field > (
32- l_tilde : FieldSlice < F > ,
33- r_j : & [ F ] ,
34- r_s : & [ F ] ,
35- ) -> [ F ; SHIFT_VARIANT_COUNT ] {
36- assert_eq ! ( l_tilde. log_len( ) , LOG_WORD_SIZE_BITS ) ;
30+ pub fn evaluate_h_op < E : FieldOps > ( l_tilde : & [ E ] , r_j : & [ E ] , r_s : & [ E ] ) -> [ E ; SHIFT_VARIANT_COUNT ] {
31+ assert_eq ! ( l_tilde. len( ) , WORD_SIZE_BITS ) ;
3732 assert_eq ! ( r_j. len( ) , LOG_WORD_SIZE_BITS ) ;
3833 assert_eq ! ( r_s. len( ) , LOG_WORD_SIZE_BITS ) ;
3934
4035 // Use helper functions to compute shift indicator helpers for 64-bit shifts
4136 let ( sigma, sigma_prime) = partial_eval_sigmas ( r_j, r_s) ;
4237 let sigma_transpose = partial_eval_sigmas_transpose ( r_j, r_s) ;
4338 let phi = partial_eval_phi ( r_s) ;
44- let j_product = r_j. iter ( ) . product :: < F > ( ) ;
39+ let j_product: E = r_j. iter ( ) . cloned ( ) . product ( ) ;
4540
4641 // Use helper functions to compute shift indicator helpers for 32-bit shifts
4742 let ( sigma32, sigma32_prime) = partial_eval_sigmas ( & r_j[ ..5 ] , & r_s[ ..5 ] ) ;
4843 let sigma32_transpose = partial_eval_sigmas_transpose ( & r_j[ ..5 ] , & r_s[ ..5 ] ) ;
4944 let phi32 = partial_eval_phi ( & r_s[ ..5 ] ) ;
50- let j_product32 = r_j[ ..5 ] . iter ( ) . product :: < F > ( ) ;
45+ let j_product32: E = r_j[ ..5 ] . iter ( ) . cloned ( ) . product ( ) ;
5146
5247 // Compute final results
53- let sll = inner_product_buffers ( & l_tilde, & sigma_transpose) ;
54- let srl = inner_product_buffers ( & l_tilde, & sigma) ;
48+ let sll = inner_product_scalars ( l_tilde. iter ( ) . cloned ( ) , sigma_transpose) ;
49+ let srl = inner_product_scalars ( l_tilde. iter ( ) . cloned ( ) , sigma. iter ( ) . cloned ( ) ) ;
5550 // sra == ∑ᵢ L̃(i) ⋅ (srlᵢ + ∏ₖ rⱼ[k] ⋅ φᵢ)
5651 // == ∑ᵢ L̃(i) ⋅ srlᵢ + ∏ₖ rⱼ[k] ⋅ [ ∑ᵢ L̃(i) ⋅ φᵢ ]
5752 // == srl + ∏ₖ rⱼ[k] ⋅ [ ∑ᵢ L̃(i) ⋅ φᵢ ]
58- let sra = srl + j_product * inner_product_buffers ( & l_tilde, & phi) ;
59- let rotr = inner_product (
60- l_tilde. iter_scalars ( ) ,
61- iter:: zip ( sigma. as_ref ( ) , sigma_prime. as_ref ( ) ) . map ( |( & s_i, & s_prime_i) | s_i + s_prime_i) ,
53+ let sra = srl. clone ( ) + j_product * inner_product_scalars ( l_tilde. iter ( ) . cloned ( ) , phi) ;
54+ let rotr = inner_product_scalars (
55+ l_tilde. iter ( ) . cloned ( ) ,
56+ iter:: zip ( & sigma, & sigma_prime) . map ( |( s_i, s_prime_i) | s_i. clone ( ) + s_prime_i) ,
6257 ) ;
6358
64- // TODO: This is really gross, need to clean it up after other shift reduction modifications.
65- let r_j_rest_tensor = eq_ind_partial_eval :: < F > ( & r_j[ 5 ..] ) ;
66- let l_tilde_chunks = l_tilde. chunks ( 5 ) ;
59+ let r_j_rest_tensor = eq_ind_partial_eval_scalars ( & r_j[ 5 ..] ) ;
60+ let chunk_size = 1 << 5 ; // 32
6761
68- let sll32 = inner_product (
69- l_tilde_chunks
70- . clone ( )
71- . map ( |l_tilde_i| inner_product_buffers ( & l_tilde_i , & sigma32_transpose ) ) ,
72- r_j_rest_tensor. iter_scalars ( ) ,
62+ let sll32 = inner_product_scalars (
63+ l_tilde . chunks ( chunk_size ) . map ( |chunk| {
64+ inner_product_scalars ( chunk . iter ( ) . cloned ( ) , sigma32_transpose . iter ( ) . cloned ( ) )
65+ } ) ,
66+ r_j_rest_tensor. iter ( ) . cloned ( ) ,
7367 ) ;
74- let srl32 = inner_product (
75- l_tilde_chunks
76- . clone ( )
77- . map ( |l_tilde_i| inner_product_buffers ( & l_tilde_i , & sigma32) ) ,
78- r_j_rest_tensor. iter_scalars ( ) ,
68+ let srl32 = inner_product_scalars (
69+ l_tilde
70+ . chunks ( chunk_size )
71+ . map ( |chunk| inner_product_scalars ( chunk . iter ( ) . cloned ( ) , sigma32. iter ( ) . cloned ( ) ) ) ,
72+ r_j_rest_tensor. iter ( ) . cloned ( ) ,
7973 ) ;
80- let sra32 = srl32
81- + inner_product (
82- l_tilde_chunks
83- . clone ( )
84- . map ( |l_tilde_i| j_product32 * inner_product_buffers ( & l_tilde_i, & phi32) ) ,
85- r_j_rest_tensor. iter_scalars ( ) ,
74+ let sra32 = srl32. clone ( )
75+ + inner_product_scalars (
76+ l_tilde. chunks ( chunk_size) . map ( |chunk| {
77+ j_product32. clone ( )
78+ * inner_product_scalars ( chunk. iter ( ) . cloned ( ) , phi32. iter ( ) . cloned ( ) )
79+ } ) ,
80+ r_j_rest_tensor. iter ( ) . cloned ( ) ,
8681 ) ;
87- let rotr32 = inner_product (
88- l_tilde_chunks. clone ( ) . map ( |l_tilde_i| {
89- inner_product (
90- l_tilde_i. iter_scalars ( ) ,
91- iter:: zip ( sigma32. as_ref ( ) , sigma32_prime. as_ref ( ) )
92- . map ( |( & s_i, & s_prime_i) | s_i + s_prime_i) ,
82+ let rotr32 = inner_product_scalars (
83+ l_tilde. chunks ( chunk_size) . map ( |chunk| {
84+ inner_product_scalars (
85+ chunk. iter ( ) . cloned ( ) ,
86+ iter:: zip ( & sigma32, & sigma32_prime) . map ( |( s_i, s_prime_i) | s_i. clone ( ) + s_prime_i) ,
9387 )
9488 } ) ,
95- r_j_rest_tensor. iter_scalars ( ) ,
89+ r_j_rest_tensor,
9690 ) ;
9791
9892 [ sll, srl, sra, rotr, sll32, srl32, sra32, rotr32]
@@ -133,31 +127,32 @@ pub fn evaluate_h_op<F: Field>(
133127/// # Errors
134128///
135129/// Returns an error if the binary subspace construction fails.
136- pub fn evaluate_monster_multilinear_for_operation < F : BinaryField , const ARITY : usize > (
130+ pub fn evaluate_monster_multilinear_for_operation < F , E , const ARITY : usize > (
137131 operand_vecs : & [ Vec < & Operand > ] ,
138- operator_data : & OperatorData < F , ARITY > ,
132+ operator_data : & OperatorData < E , ARITY > ,
139133 subspace : & BinarySubspace < F > ,
140- lambda : F ,
141- r_j : & [ F ] ,
142- r_s : & [ F ] ,
143- r_y : & [ F ] ,
144- ) -> Result < F , Error > {
134+ lambda : E ,
135+ r_j : & [ E ] ,
136+ r_s : & [ E ] ,
137+ r_y : & [ E ] ,
138+ ) -> Result < E , Error >
139+ where
140+ F : BinaryField ,
141+ E : FieldOps < Scalar = F > + From < F > ,
142+ {
145143 assert_eq ! ( subspace. dim( ) , LOG_WORD_SIZE_BITS ) ; // precondition
146144
147- let r_x_prime_tensor = eq_ind_partial_eval :: < F > ( & operator_data. r_x_prime ) ;
148- let r_y_tensor = eq_ind_partial_eval :: < F > ( r_y) ;
145+ let r_x_prime_tensor = eq_ind_partial_eval_scalars ( & operator_data. r_x_prime ) ;
146+ let r_y_tensor = eq_ind_partial_eval_scalars ( r_y) ;
149147
150- let l_tilde = lagrange_evals ( subspace, operator_data. r_zhat_prime ) ;
151- let h_op_evals = evaluate_h_op ( l_tilde. to_ref ( ) , r_j, r_s) ;
148+ let l_tilde = lagrange_evals_scalars ( subspace, operator_data. r_zhat_prime . clone ( ) ) ;
149+ let h_op_evals = evaluate_h_op ( & l_tilde, r_j, r_s) ;
152150
153151 let lambda_powers = powers ( lambda) . skip ( 1 ) . take ( ARITY ) . collect :: < Vec < _ > > ( ) ;
154152 let evals = evaluate_matrices ( operand_vecs, & lambda_powers, & r_x_prime_tensor, & r_y_tensor) ;
155153
156- let eval = inner_product (
157- evals. map ( |mut evals_op| {
158- let evals_op = FieldBuffer :: new ( LOG_WORD_SIZE_BITS , & mut evals_op[ ..] ) ;
159- evaluate_inplace ( evals_op, r_s)
160- } ) ,
154+ let eval = inner_product_scalars (
155+ evals. map ( |mut evals_op| evaluate_inplace_scalars ( & mut evals_op[ ..] , r_s) ) ,
161156 h_op_evals,
162157 ) ;
163158
@@ -178,32 +173,27 @@ pub fn evaluate_monster_multilinear_for_operation<F: BinaryField, const ARITY: u
178173/// * `operand_coeffs` - Coefficients (λ powers) for batching operand evaluations
179174/// * `r_x_prime_tensor` - Multilinear challenge tensor for constraint variables
180175/// * `r_y_tensor` - Challenge tensor for word index variables
181- ///
182- /// Note: This function uses multithreading (par_iter), which is an exception to the general
183- /// rule that the verifier should be single-threaded. The sparse matrix evaluation takes time
184- /// linear in the size of the constraint system, so we use parallelization here to make the
185- /// verifier performant on large constraint systems.
186- fn evaluate_matrices < F : BinaryField > (
176+ fn evaluate_matrices < F : BinaryField , E : FieldOps < Scalar = F > + From < F > > (
187177 operands : & [ Vec < & Operand > ] ,
188- operand_coeffs : & [ F ] ,
189- r_x_prime_tensor : & FieldBuffer < F > ,
190- r_y_tensor : & FieldBuffer < F > ,
191- ) -> [ [ F ; WORD_SIZE_BITS ] ; SHIFT_VARIANT_COUNT ] {
178+ operand_coeffs : & [ E ] ,
179+ r_x_prime_tensor : & [ E ] ,
180+ r_y_tensor : & [ E ] ,
181+ ) -> [ [ E ; WORD_SIZE_BITS ] ; SHIFT_VARIANT_COUNT ] {
192182 assert_eq ! ( operands. len( ) , operand_coeffs. len( ) ) ;
193183
194- // Use parallelization for performance (see docstring for rationale).
195- ( operand_coeffs , operands )
196- . into_par_iter ( )
197- . map ( |( & coeff, constraint_operands) | {
198- let mut evals = [ [ F :: ZERO ; WORD_SIZE_BITS ] ; SHIFT_VARIANT_COUNT ] ;
199- for ( & operand_terms , & constraint_eval ) in
200- iter:: zip ( constraint_operands, r_x_prime_tensor. as_ref ( ) )
184+ let zero_evals = array :: from_fn ( |_| array :: from_fn :: < E , WORD_SIZE_BITS , _ > ( |_| E :: zero ( ) ) ) ;
185+
186+ iter :: zip ( operand_coeffs , operands )
187+ . map ( |( coeff, constraint_operands) | {
188+ let mut evals: [ [ E ; WORD_SIZE_BITS ] ; SHIFT_VARIANT_COUNT ] =
189+ array :: from_fn ( |_| array :: from_fn :: < E , WORD_SIZE_BITS , _ > ( |_| E :: zero ( ) ) ) ;
190+ for ( operand_terms , constraint_eval ) in iter:: zip ( constraint_operands, r_x_prime_tensor)
201191 {
202192 for ShiftedValueIndex {
203193 value_index,
204194 shift_variant,
205195 amount,
206- } in operand_terms
196+ } in * operand_terms
207197 {
208198 let shift_id = match shift_variant {
209199 ShiftVariant :: Sll => 0 ,
@@ -216,7 +206,7 @@ fn evaluate_matrices<F: BinaryField>(
216206 ShiftVariant :: Rotr32 => 7 ,
217207 } ;
218208 evals[ shift_id] [ * amount] +=
219- constraint_eval * r_y_tensor. get ( value_index. 0 as usize ) ;
209+ constraint_eval. clone ( ) * & r_y_tensor[ value_index. 0 as usize ] ;
220210 }
221211 }
222212
@@ -229,26 +219,23 @@ fn evaluate_matrices<F: BinaryField>(
229219
230220 evals
231221 } )
232- . reduce (
233- || [ [ F :: ZERO ; WORD_SIZE_BITS ] ; SHIFT_VARIANT_COUNT ] ,
234- |mut a, b| {
235- for ( a_op, b_op) in iter:: zip ( & mut a, b) {
236- for ( a_op_s, b_op_s) in iter:: zip ( & mut * a_op, b_op) {
237- * a_op_s += b_op_s;
238- }
222+ . fold ( zero_evals, |mut a, b| {
223+ for ( a_op, b_op) in iter:: zip ( & mut a, b) {
224+ for ( a_op_s, b_op_s) in iter:: zip ( & mut * a_op, b_op) {
225+ * a_op_s += b_op_s;
239226 }
240- a
241- } ,
242- )
227+ }
228+ a
229+ } )
243230}
244231
245232#[ cfg( test) ]
246233mod tests {
247- use binius_field:: { BinaryField128bGhash , Random } ;
234+ use binius_field:: { BinaryField128bGhash , Field , Random } ;
248235 use binius_math:: {
249236 BinarySubspace ,
250237 test_utils:: { index_to_hypercube_point, random_scalars} ,
251- univariate:: lagrange_evals ,
238+ univariate:: lagrange_evals_scalars ,
252239 } ;
253240 use rand:: { Rng , SeedableRng , rngs:: StdRng } ;
254241
@@ -273,13 +260,13 @@ mod tests {
273260 let s = rng. random_range ( 0 ..64 ) ;
274261
275262 let challenge = subspace. get ( i) ;
276- let l_tilde = lagrange_evals ( & subspace, challenge) ;
263+ let l_tilde = lagrange_evals_scalars ( & subspace, challenge) ;
277264
278265 let r_j = index_to_hypercube_point :: < BinaryField128bGhash > ( LOG_WORD_SIZE_BITS , j) ;
279266 let r_s = index_to_hypercube_point :: < BinaryField128bGhash > ( LOG_WORD_SIZE_BITS , s) ;
280267
281268 let [ sll, srl, sra, rotr, sll32, srl32, sra32, rotr32] =
282- evaluate_h_op ( l_tilde. to_ref ( ) , & r_j, & r_s) ;
269+ evaluate_h_op ( & l_tilde, & r_j, & r_s) ;
283270
284271 let expected_sll = j + s == i;
285272 let expected_srl = i + s == j;
@@ -324,7 +311,7 @@ mod tests {
324311 // Generate random evaluation points
325312 let challenge = BinaryField128bGhash :: random ( & mut rng) ;
326313 let subspace = BinarySubspace :: < BinaryField128bGhash > :: with_dim ( LOG_WORD_SIZE_BITS ) ;
327- let l_tilde = lagrange_evals ( & subspace, challenge) ;
314+ let l_tilde = lagrange_evals_scalars ( & subspace, challenge) ;
328315 let r_j = random_scalars :: < BinaryField128bGhash > ( & mut rng, LOG_WORD_SIZE_BITS ) ;
329316 let r_s = random_scalars :: < BinaryField128bGhash > ( & mut rng, LOG_WORD_SIZE_BITS ) ;
330317
@@ -336,7 +323,7 @@ mod tests {
336323 let mut r_j_at_1 = r_j. clone ( ) ;
337324 r_j_at_1[ i] = BinaryField128bGhash :: ONE ;
338325 let [ result_0, result_1, result_y] = [ & r_j_at_0, & r_j_at_1, & r_j]
339- . map ( |r_j_variant| evaluate_h_op ( l_tilde. to_ref ( ) , r_j_variant, & r_s) ) ;
326+ . map ( |r_j_variant| evaluate_h_op ( & l_tilde, r_j_variant, & r_s) ) ;
340327 for variant in 0 ..SHIFT_VARIANT_COUNT {
341328 let expected = result_0[ variant] * ( BinaryField128bGhash :: ONE - r_j[ i] )
342329 + result_1[ variant] * r_j[ i] ;
@@ -349,7 +336,7 @@ mod tests {
349336 let mut r_s_at_1 = r_s. clone ( ) ;
350337 r_s_at_1[ i] = BinaryField128bGhash :: ONE ;
351338 let [ result_0, result_1, result_y] = [ & r_s_at_0, & r_s_at_1, & r_s]
352- . map ( |r_s_variant| evaluate_h_op ( l_tilde. to_ref ( ) , & r_j, r_s_variant) ) ;
339+ . map ( |r_s_variant| evaluate_h_op ( & l_tilde, & r_j, r_s_variant) ) ;
353340 for variant in 0 ..SHIFT_VARIANT_COUNT {
354341 let expected = result_0[ variant] * ( BinaryField128bGhash :: ONE - r_s[ i] )
355342 + result_1[ variant] * r_s[ i] ;
0 commit comments