From f4ecabe11f49b676e6f590167a64d9b9d0abb959 Mon Sep 17 00:00:00 2001
From: Tom Wambsgans
Date: Tue, 25 Nov 2025 01:14:57 +0400
Subject: [PATCH] max_width = 120

---
 crates/air/src/prove.rs                        |  46 +-
 crates/air/src/uni_skip_utils.rs               |   9 +-
 crates/air/src/utils.rs                        |  14 +-
 crates/air/src/verify.rs                       |  66 +-
 crates/air/tests/complex_air.rs                |  88 +--
 crates/air/tests/fib_air.rs                    |  58 +-
 crates/lean_compiler/src/a_simplify_lang.rs    | 601 ++++--------
 .../src/b_compile_intermediate.rs              |  87 +--
 crates/lean_compiler/src/c_compile_final.rs    |  55 +-
 crates/lean_compiler/src/ir/instruction.rs     |  13 +-
 crates/lean_compiler/src/lang.rs               |  84 +--
 crates/lean_compiler/src/parser/grammar.rs     |   5 +-
 .../src/parser/parsers/expression.rs           |  40 +-
 .../src/parser/parsers/function.rs             |  20 +-
 .../src/parser/parsers/literal.rs              |  37 +-
 .../lean_compiler/src/parser/parsers/mod.rs    |   7 +-
 .../src/parser/parsers/program.rs              |   5 +-
 .../src/parser/parsers/statement.rs            |  33 +-
 crates/lean_compiler/tests/test_compiler.rs    | 119 +---
 crates/lean_prover/src/common.rs               |  37 +-
 crates/lean_prover/src/prove_execution.rs      | 177 ++----
 crates/lean_prover/src/verify_execution.rs     | 100 +--
 crates/lean_prover/tests/hash_chain.rs         |   9 +-
 crates/lean_prover/tests/test_zkvm.rs          |  26 +-
 .../witness_generation/src/execution_trace.rs  |  15 +-
 .../src/instruction_encoder.rs                 |   6 +-
 .../witness_generation/src/poseidon_tables.rs  |   3 +-
 crates/lean_vm/src/core/label.rs               |  17 +-
 crates/lean_vm/src/diagnostics/profiler.rs     |  34 +-
 crates/lean_vm/src/diagnostics/stack_trace.rs  |  35 +-
 crates/lean_vm/src/execution/context.rs        |   6 +-
 crates/lean_vm/src/execution/memory.rs         |  10 +-
 crates/lean_vm/src/execution/runner.rs         |  89 +--
 crates/lean_vm/src/execution/tests.rs          |   5 +-
 crates/lean_vm/src/isa/hint.rs                 |  18 +-
 crates/lean_vm/src/isa/instruction.rs          |  12 +-
 crates/lean_vm/src/tables/dot_product/air.rs   |   6 +-
 crates/lean_vm/src/tables/dot_product/exec.rs  |  15 +-
 crates/lean_vm/src/tables/dot_product/mod.rs   |   6 +-
 crates/lean_vm/src/tables/execution/air.rs     |  14 +-
 crates/lean_vm/src/tables/execution/mod.rs     |   9 +-
 crates/lean_vm/src/tables/poseidon_16/mod.rs   |  15 +-
 crates/lean_vm/src/tables/poseidon_24/mod.rs   |   5 +-
 crates/lean_vm/src/tables/table_trait.rs       |  32 +-
 crates/lean_vm/tests/test_lean_vm.rs           |  48 +-
 crates/lookup/src/logup_star.rs                |  67 +-
 crates/lookup/src/quotient_gkr.rs              |  96 +--
 .../src/gkr_layers/batch_partial_rounds.rs     |   4 +-
 .../src/gkr_layers/compression.rs              |   4 +-
 crates/poseidon_circuit/src/gkr_layers/mod.rs  |   6 +-
 crates/poseidon_circuit/src/lib.rs             |   6 +-
 crates/poseidon_circuit/src/prove.rs           |  34 +-
 crates/poseidon_circuit/src/tests.rs           |  68 +-
 crates/poseidon_circuit/src/utils.rs           |   7 +-
 crates/poseidon_circuit/src/verify.rs          |  54 +-
 crates/poseidon_circuit/src/witness_gen.rs     |  22 +-
 crates/rec_aggregation/src/recursion.rs        |  56 +-
 crates/rec_aggregation/src/xmss_aggregate.rs   |  36 +-
 .../src/commit_extension_from_base.rs          |  11 +-
 .../src/generic_packed_lookup.rs               |  79 +--
 .../sub_protocols/src/normal_packed_lookup.rs  |  14 +-
 crates/sub_protocols/src/packed_pcs.rs         | 142 ++--
 .../src/vectorized_packed_lookup.rs            |  36 +-
 .../tests/test_generic_packed_lookup.rs        |  17 +-
 .../tests/test_normal_packed_lookup.rs         |  14 +-
 .../tests/test_vectorized_packed_lookup.rs     |  23 +-
 crates/utils/src/misc.rs                       |  43 +-
 crates/utils/src/multilinear.rs                |  30 +-
 crates/utils/src/wrappers.rs                   |   3 +-
 crates/xmss/src/lib.rs                         |   6 +-
 crates/xmss/src/wots.rs                        |  35 +-
 crates/xmss/src/xmss.rs                        |  33 +-
 crates/xmss/tests/test_xmss.rs                 |   4 +-
 rustfmt.toml                                   |   2 +-
 src/main.rs                                    |  12 +-
 75 files changed,
775 insertions(+), 2325 deletions(-) diff --git a/crates/air/src/prove.rs b/crates/air/src/prove.rs index d231d82a..58ff6494 100644 --- a/crates/air/src/prove.rs +++ b/crates/air/src/prove.rs @@ -44,12 +44,7 @@ where *extra_data.alpha_powers_mut() = alpha .powers() - .take( - air.n_constraints() - + virtual_column_statements - .as_ref() - .map_or(0, |s| s.values.len()), - ) + .take(air.n_constraints() + virtual_column_statements.as_ref().map_or(0, |s| s.values.len())) .collect(); let n_sc_rounds = log_n_rows + 1 - univariate_skips; @@ -64,9 +59,7 @@ where .down_column_indexes_f() .par_iter() .zip_eq(last_row_shifted_f) - .map(|(&col_index, &final_value)| { - column_shifted(columns_f[col_index], final_value.as_base().unwrap()) - }) + .map(|(&col_index, &final_value)| column_shifted(columns_f[col_index], final_value.as_base().unwrap())) .collect::>(); let shifted_rows_ef = air .down_column_indexes_ef() @@ -81,10 +74,8 @@ where let mut columns_up_down_ef = columns_ef.to_vec(); // orginal columns, followed by shifted ones columns_up_down_ef.extend(shifted_rows_ef.iter().map(Vec::as_slice)); - let columns_up_down_group_f: MleGroupRef<'_, EF> = - MleGroupRef::<'_, EF>::Base(columns_up_down_f); - let columns_up_down_group_ef: MleGroupRef<'_, EF> = - MleGroupRef::<'_, EF>::Extension(columns_up_down_ef); + let columns_up_down_group_f: MleGroupRef<'_, EF> = MleGroupRef::<'_, EF>::Base(columns_up_down_f); + let columns_up_down_group_ef: MleGroupRef<'_, EF> = MleGroupRef::<'_, EF>::Extension(columns_up_down_ef); let columns_up_down_group_f_packed = columns_up_down_group_f.pack(); let columns_up_down_group_ef_packed = columns_up_down_group_ef.pack(); @@ -130,10 +121,8 @@ fn open_columns>>( columns_ef: &[&[EF]], outer_sumcheck_challenge: &[EF], ) -> (MultilinearPoint, Vec, Vec) { - let n_up_down_columns = columns_f.len() - + columns_ef.len() - + columns_with_shift_f.len() - + columns_with_shift_ef.len(); + let n_up_down_columns = + columns_f.len() + columns_ef.len() + columns_with_shift_f.len() + columns_with_shift_ef.len(); let batching_scalars = prover_state.sample_vec(log2_ceil_usize(n_up_down_columns)); let eval_eq_batching_scalars = eval_eq(&batching_scalars)[..n_up_down_columns].to_vec(); @@ -153,14 +142,8 @@ fn open_columns>>( }); } - let columns_shifted_f = &columns_with_shift_f - .iter() - .map(|&i| columns_f[i]) - .collect::>(); - let columns_shifted_ef = &columns_with_shift_ef - .iter() - .map(|&i| columns_ef[i]) - .collect::>(); + let columns_shifted_f = &columns_with_shift_f.iter().map(|&i| columns_f[i]).collect::>(); + let columns_shifted_ef = &columns_with_shift_ef.iter().map(|&i| columns_ef[i]).collect::>(); let mut batched_column_down = if columns_shifted_f.is_empty() { tracing::warn!("TODO optimize open_columns when no shifted F columns"); @@ -168,16 +151,14 @@ fn open_columns>>( } else { multilinears_linear_combination( columns_shifted_f, - &eval_eq_batching_scalars[columns_f.len() + columns_ef.len()..] 
- [..columns_shifted_f.len()], + &eval_eq_batching_scalars[columns_f.len() + columns_ef.len()..][..columns_shifted_f.len()], ) }; if !columns_shifted_ef.is_empty() { let batched_column_down_ef = multilinears_linear_combination( columns_shifted_ef, - &eval_eq_batching_scalars - [columns_f.len() + columns_ef.len() + columns_shifted_f.len()..], + &eval_eq_batching_scalars[columns_f.len() + columns_ef.len() + columns_shifted_f.len()..], ); batched_column_down .par_iter_mut() @@ -277,12 +258,7 @@ impl>> SumcheckComputation for MySumcheck { point[0] * point[1] + point[2] * point[3] } #[inline(always)] - fn eval_packed_base( - &self, - _: &[PFPacking], - _: &[EFPacking], - _: &Self::ExtraData, - ) -> EFPacking { + fn eval_packed_base(&self, _: &[PFPacking], _: &[EFPacking], _: &Self::ExtraData) -> EFPacking { unreachable!() } #[inline(always)] diff --git a/crates/air/src/uni_skip_utils.rs b/crates/air/src/uni_skip_utils.rs index aa7958b8..b1b42a37 100644 --- a/crates/air/src/uni_skip_utils.rs +++ b/crates/air/src/uni_skip_utils.rs @@ -6,8 +6,8 @@ pub fn matrix_next_mle_folded>>(outer_challenges: &[F]) let n = outer_challenges.len(); let mut res = F::zero_vec(1 << n); for k in 0..n { - let outer_challenges_prod = (F::ONE - outer_challenges[n - k - 1]) - * outer_challenges[n - k..].iter().copied().product::(); + let outer_challenges_prod = + (F::ONE - outer_challenges[n - k - 1]) * outer_challenges[n - k..].iter().copied().product::(); let mut eq_mle = eval_eq_scaled(&outer_challenges[0..n - k - 1], outer_challenges_prod); for (mut i, v) in eq_mle.iter_mut().enumerate() { i <<= k + 1; @@ -36,10 +36,7 @@ mod tests { for y in 0..1 << n_vars { let y_bools = to_big_endian_in_field::(y, n_vars); let expected = F::from_bool(x + 1 == y); - assert_eq!( - matrix.evaluate(&MultilinearPoint(y_bools.clone())), - expected - ); + assert_eq!(matrix.evaluate(&MultilinearPoint(y_bools.clone())), expected); assert_eq!(next_mle(&[x_bools.clone(), y_bools].concat()), expected); } } diff --git a/crates/air/src/utils.rs b/crates/air/src/utils.rs index c34093ea..63ae1972 100644 --- a/crates/air/src/utils.rs +++ b/crates/air/src/utils.rs @@ -36,11 +36,7 @@ use multilinear_toolkit::prelude::*; /// Field element: 1 if y = x + 1, 0 otherwise. pub(crate) fn next_mle(point: &[F]) -> F { // Check that the point length is even: we split into x and y of equal length. - assert_eq!( - point.len() % 2, - 0, - "Input point must have an even number of variables." - ); + assert_eq!(point.len() % 2, 0, "Input point must have an even number of variables."); let n = point.len() / 2; // Split point into x (first n) and y (last n). @@ -56,9 +52,7 @@ pub(crate) fn next_mle(point: &[F]) -> F { // // Indices are reversed because bits are big-endian. let eq_high_bits = (k + 1..n) - .map(|i| { - x[n - 1 - i] * y[n - 1 - i] + (F::ONE - x[n - 1 - i]) * (F::ONE - y[n - 1 - i]) - }) + .map(|i| x[n - 1 - i] * y[n - 1 - i] + (F::ONE - x[n - 1 - i]) * (F::ONE - y[n - 1 - i])) .product::(); // Term 2: carry bit at position k @@ -71,9 +65,7 @@ pub(crate) fn next_mle(point: &[F]) -> F { // // For i < k, enforce x_i = 1 and y_i = 0. // Condition: x_i * (1 - y_i). - let low_bits_are_one_zero = (0..k) - .map(|i| x[n - 1 - i] * (F::ONE - y[n - 1 - i])) - .product::(); + let low_bits_are_one_zero = (0..k).map(|i| x[n - 1 - i] * (F::ONE - y[n - 1 - i])).product::(); // Multiply the three terms for this k, representing one "carry pattern". 
eq_high_bits * carry_bit * low_bits_are_one_zero diff --git a/crates/air/src/verify.rs b/crates/air/src/verify.rs index 02b8cbc8..cdbfec62 100644 --- a/crates/air/src/verify.rs +++ b/crates/air/src/verify.rs @@ -23,12 +23,7 @@ where *extra_data.alpha_powers_mut() = alpha .powers() - .take( - air.n_constraints() - + virtual_column_statements - .as_ref() - .map_or(0, |s| s.values.len()), - ) + .take(air.n_constraints() + virtual_column_statements.as_ref().map_or(0, |s| s.values.len())) .collect(); let n_sc_rounds = log_n_rows + 1 - univariate_skips; @@ -38,12 +33,8 @@ where .unwrap_or_else(|| verifier_state.sample_vec(n_sc_rounds)); assert_eq!(zerocheck_challenges.len(), n_sc_rounds); - let (sc_sum, outer_statement) = sumcheck_verify_with_univariate_skip::( - verifier_state, - air.degree() + 1, - log_n_rows, - univariate_skips, - )?; + let (sc_sum, outer_statement) = + sumcheck_verify_with_univariate_skip::(verifier_state, air.degree() + 1, log_n_rows, univariate_skips)?; if sc_sum != virtual_column_statements .as_ref() @@ -59,9 +50,7 @@ where .collect::>(); let mut inner_sums = verifier_state.next_extension_scalars_vec( - air.n_columns_air() - + air.down_column_indexes_f().len() - + air.down_column_indexes_ef().len(), + air.n_columns_air() + air.down_column_indexes_f().len() + air.down_column_indexes_ef().len(), )?; let n_columns_down_f = air.down_column_indexes_f().len(); @@ -72,11 +61,7 @@ where &extra_data, ); - if eq_poly_with_skip( - &zerocheck_challenges, - &outer_statement.point, - univariate_skips, - ) * constraint_evals + if eq_poly_with_skip(&zerocheck_challenges, &outer_statement.point, univariate_skips) * constraint_evals != outer_statement.value { return Err(ProofError::InvalidProof); @@ -128,14 +113,8 @@ fn open_columns>>( evals_up_and_down.len() ); let last_row_selector = outer_selector_evals[(1 << univariate_skips) - 1] - * outer_sumcheck_challenge - .point - .iter() - .copied() - .product::(); - for (&last_row_value, down_col_eval) in - last_row_f.iter().zip(&mut evals_up_and_down[n_columns..]) - { + * outer_sumcheck_challenge.point.iter().copied().product::(); + for (&last_row_value, down_col_eval) in last_row_f.iter().zip(&mut evals_up_and_down[n_columns..]) { *down_col_eval -= last_row_selector * last_row_value; } for (&last_row_value, down_col_eval) in last_row_ef @@ -145,9 +124,7 @@ fn open_columns>>( *down_col_eval -= last_row_selector * last_row_value; } - let batching_scalars = verifier_state.sample_vec(log2_ceil_usize( - n_columns + last_row_f.len() + last_row_ef.len(), - )); + let batching_scalars = verifier_state.sample_vec(log2_ceil_usize(n_columns + last_row_f.len() + last_row_ef.len())); let eval_eq_batching_scalars = eval_eq(&batching_scalars); let batching_scalars_up = &eval_eq_batching_scalars[..n_columns]; @@ -155,13 +132,12 @@ fn open_columns>>( let sub_evals = verifier_state.next_extension_scalars_vec(1 << univariate_skips)?; - if dot_product::( - sub_evals.iter().copied(), - outer_selector_evals.iter().copied(), - ) != dot_product::( - evals_up_and_down.iter().copied(), - eval_eq_batching_scalars.iter().copied(), - ) { + if dot_product::(sub_evals.iter().copied(), outer_selector_evals.iter().copied()) + != dot_product::( + evals_up_and_down.iter().copied(), + eval_eq_batching_scalars.iter().copied(), + ) + { return Err(ProofError::InvalidProof); } @@ -173,9 +149,8 @@ fn open_columns>>( return Err(ProofError::InvalidProof); } - let matrix_up_sc_eval = - MultilinearPoint([epsilons.0.clone(), outer_sumcheck_challenge.point.0.clone()].concat()) - 
.eq_poly_outside(&inner_sumcheck_stement.point); + let matrix_up_sc_eval = MultilinearPoint([epsilons.0.clone(), outer_sumcheck_challenge.point.0.clone()].concat()) + .eq_poly_outside(&inner_sumcheck_stement.point); let matrix_down_sc_eval = next_mle( &[ epsilons.0, @@ -185,10 +160,8 @@ fn open_columns>>( .concat(), ); - let evaluations_remaining_to_verify_f = - verifier_state.next_extension_scalars_vec(n_columns_f)?; - let evaluations_remaining_to_verify_ef = - verifier_state.next_extension_scalars_vec(n_columns_ef)?; + let evaluations_remaining_to_verify_f = verifier_state.next_extension_scalars_vec(n_columns_f)?; + let evaluations_remaining_to_verify_ef = verifier_state.next_extension_scalars_vec(n_columns_ef)?; let evaluations_remaining_to_verify = [ evaluations_remaining_to_verify_f.clone(), evaluations_remaining_to_verify_ef.clone(), @@ -211,8 +184,7 @@ fn open_columns>>( .sum::(); if inner_sumcheck_stement.value - != matrix_up_sc_eval * batched_col_up_sc_eval - + matrix_down_sc_eval * batched_col_down_sc_eval + != matrix_up_sc_eval * batched_col_up_sc_eval + matrix_down_sc_eval * batched_col_down_sc_eval { return Err(ProofError::InvalidProof); } diff --git a/crates/air/tests/complex_air.rs b/crates/air/tests/complex_air.rs index eeca103b..e29f6f51 100644 --- a/crates/air/tests/complex_air.rs +++ b/crates/air/tests/complex_air.rs @@ -13,11 +13,7 @@ const N_COLS_F: usize = 2; type F = KoalaBear; type EF = QuinticExtensionFieldKB; -struct ExampleStructuredAir< - const N_COLUMNS: usize, - const N_PREPROCESSED_COLUMNS: usize, - const VIRTUAL_COLUMN: bool, ->; +struct ExampleStructuredAir; impl Air for ExampleStructuredAir @@ -93,9 +89,7 @@ fn generate_trace( witness_col.push( witness_cols_j_i_min_1 + F::from_usize(j + N_PREPROCESSED_COLUMNS) - + (0..3).map(|k| trace_ef[k][i - 1]).product::() - * trace_f[0][i - 1] - * trace_f[1][i - 1], + + (0..3).map(|k| trace_ef[k][i - 1]).product::() * trace_f[0][i - 1] * trace_f[1][i - 1], ); } } @@ -116,20 +110,10 @@ fn test_air_helper() { let n_rows = 1 << log_n_rows; let mut prover_state = build_prover_state::(); - let (columns_plus_one_f, columns_plus_one_ef) = - generate_trace::(n_rows + 1); - let columns_ref_f = columns_plus_one_f - .iter() - .map(|col| &col[..n_rows]) - .collect::>(); - let columns_ref_ef = columns_plus_one_ef - .iter() - .map(|col| &col[..n_rows]) - .collect::>(); - let mut last_row_ef = columns_plus_one_ef - .iter() - .map(|col| col[n_rows]) - .collect::>(); + let (columns_plus_one_f, columns_plus_one_ef) = generate_trace::(n_rows + 1); + let columns_ref_f = columns_plus_one_f.iter().map(|col| &col[..n_rows]).collect::>(); + let columns_ref_ef = columns_plus_one_ef.iter().map(|col| &col[..n_rows]).collect::>(); + let mut last_row_ef = columns_plus_one_ef.iter().map(|col| col[n_rows]).collect::>(); last_row_ef = last_row_ef[N_PREPROCESSED_COLUMNS - N_COLS_F..].to_vec(); let virtual_column_statement_prover = if VIRTUAL_COLUMN { @@ -172,29 +156,20 @@ fn test_air_helper() { let air = ExampleStructuredAir:: {}; - check_air_validity( + check_air_validity(&air, &vec![], &columns_ref_f, &columns_ref_ef, &[], &last_row_ef).unwrap(); + + let (point_prover, evaluations_remaining_to_prove_f, evaluations_remaining_to_prove_ef) = prove_air( + &mut prover_state, &air, - &vec![], + vec![], + UNIVARIATE_SKIPS, &columns_ref_f, &columns_ref_ef, &[], &last_row_ef, - ) - .unwrap(); - - let (point_prover, evaluations_remaining_to_prove_f, evaluations_remaining_to_prove_ef) = - prove_air( - &mut prover_state, - &air, - vec![], - UNIVARIATE_SKIPS, - 
&columns_ref_f, - &columns_ref_ef, - &[], - &last_row_ef, - virtual_column_statement_prover, - true, - ); + virtual_column_statement_prover, + true, + ); let mut verifier_state = build_verifier_state(&prover_state); let virtual_column_statement_verifier = if VIRTUAL_COLUMN { @@ -210,27 +185,20 @@ fn test_air_helper() { None }; - let (point_verifier, evaluations_remaining_to_verify_f, evaluations_remaining_to_verify_ef) = - verify_air( - &mut verifier_state, - &air, - vec![], - UNIVARIATE_SKIPS, - log_n_rows, - &[], - &last_row_ef, - virtual_column_statement_verifier, - ) - .unwrap(); + let (point_verifier, evaluations_remaining_to_verify_f, evaluations_remaining_to_verify_ef) = verify_air( + &mut verifier_state, + &air, + vec![], + UNIVARIATE_SKIPS, + log_n_rows, + &[], + &last_row_ef, + virtual_column_statement_verifier, + ) + .unwrap(); assert_eq!(point_prover, point_verifier); - assert_eq!( - &evaluations_remaining_to_prove_f, - &evaluations_remaining_to_verify_f - ); - assert_eq!( - &evaluations_remaining_to_prove_ef, - &evaluations_remaining_to_verify_ef - ); + assert_eq!(&evaluations_remaining_to_prove_f, &evaluations_remaining_to_verify_f); + assert_eq!(&evaluations_remaining_to_prove_ef, &evaluations_remaining_to_verify_ef); for i in 0..N_COLS_F { assert_eq!( columns_ref_f[i].evaluate(&point_prover), diff --git a/crates/air/tests/fib_air.rs b/crates/air/tests/fib_air.rs index 746cc66b..57659134 100644 --- a/crates/air/tests/fib_air.rs +++ b/crates/air/tests/fib_air.rs @@ -80,42 +80,34 @@ fn test_air_fibonacci() { ) .unwrap(); - let (point_prover, evaluations_remaining_to_prove_f, evaluations_remaining_to_prove_ef) = - prove_air( - &mut prover_state, - &air, - vec![], - UNIVARIATE_SKIPS, - &columns_ref_f, - &columns_ref_ef, - &last_row_f, - &last_row_ef, - None, - true, - ); + let (point_prover, evaluations_remaining_to_prove_f, evaluations_remaining_to_prove_ef) = prove_air( + &mut prover_state, + &air, + vec![], + UNIVARIATE_SKIPS, + &columns_ref_f, + &columns_ref_ef, + &last_row_f, + &last_row_ef, + None, + true, + ); let mut verifier_state = build_verifier_state(&prover_state); - let (point_verifier, evaluations_remaining_to_verify_f, evaluations_remaining_to_verify_ef) = - verify_air( - &mut verifier_state, - &air, - vec![], - UNIVARIATE_SKIPS, - log_n_rows, - &last_row_f, - &last_row_ef, - None, - ) - .unwrap(); + let (point_verifier, evaluations_remaining_to_verify_f, evaluations_remaining_to_verify_ef) = verify_air( + &mut verifier_state, + &air, + vec![], + UNIVARIATE_SKIPS, + log_n_rows, + &last_row_f, + &last_row_ef, + None, + ) + .unwrap(); assert_eq!(point_prover, point_verifier); - assert_eq!( - &evaluations_remaining_to_prove_f, - &evaluations_remaining_to_verify_f - ); - assert_eq!( - &evaluations_remaining_to_prove_ef, - &evaluations_remaining_to_verify_ef - ); + assert_eq!(&evaluations_remaining_to_prove_f, &evaluations_remaining_to_verify_f); + assert_eq!(&evaluations_remaining_to_prove_ef, &evaluations_remaining_to_verify_ef); assert_eq!( columns_ref_f[0].evaluate(&point_prover), evaluations_remaining_to_verify_f[0] diff --git a/crates/lean_compiler/src/a_simplify_lang.rs b/crates/lean_compiler/src/a_simplify_lang.rs index f54a9691..4a7fc1b6 100644 --- a/crates/lean_compiler/src/a_simplify_lang.rs +++ b/crates/lean_compiler/src/a_simplify_lang.rs @@ -2,8 +2,8 @@ use crate::{ Counter, F, ir::HighLevelOperation, lang::{ - AssumeBoolean, Boolean, Condition, ConstExpression, ConstMallocLabel, ConstantValue, - Expression, Function, Line, Program, SimpleExpr, Var, + 
AssumeBoolean, Boolean, Condition, ConstExpression, ConstMallocLabel, ConstantValue, Expression, Function, + Line, Program, SimpleExpr, Var, }, }; use lean_vm::{SourceLineNumber, Table, TableT}; @@ -39,13 +39,9 @@ impl From for SimpleExpr { fn from(var_or_const: VarOrConstMallocAccess) -> Self { match var_or_const { VarOrConstMallocAccess::Var(var) => Self::Var(var), - VarOrConstMallocAccess::ConstMallocAccess { - malloc_label, - offset, - } => Self::ConstMallocAccess { - malloc_label, - offset, - }, + VarOrConstMallocAccess::ConstMallocAccess { malloc_label, offset } => { + Self::ConstMallocAccess { malloc_label, offset } + } } } } @@ -56,13 +52,9 @@ impl TryInto for SimpleExpr { fn try_into(self) -> Result { match self { Self::Var(var) => Ok(VarOrConstMallocAccess::Var(var)), - Self::ConstMallocAccess { - malloc_label, - offset, - } => Ok(VarOrConstMallocAccess::ConstMallocAccess { - malloc_label, - offset, - }), + Self::ConstMallocAccess { malloc_label, offset } => { + Ok(VarOrConstMallocAccess::ConstMallocAccess { malloc_label, offset }) + } _ => Err(()), } } @@ -225,8 +217,7 @@ impl ArrayManager { } let new_var = format!("@arr_aux_{}", self.counter); self.counter += 1; - self.aux_vars - .insert((array.clone(), index.clone()), new_var.clone()); + self.aux_vars.insert((array.clone(), index.clone()), new_var.clone()); new_var } } @@ -243,14 +234,10 @@ fn simplify_lines( for line in lines { match line { Line::Match { value, arms } => { - let simple_value = - simplify_expr(value, &mut res, counters, array_manager, const_malloc); + let simple_value = simplify_expr(value, &mut res, counters, array_manager, const_malloc); let mut simple_arms = vec![]; for (i, (pattern, statements)) in arms.iter().enumerate() { - assert_eq!( - *pattern, i, - "match patterns should be consecutive, starting from 0" - ); + assert_eq!(*pattern, i, "match patterns should be consecutive, starting from 0"); simple_arms.push(simplify_lines( statements, counters, @@ -285,14 +272,9 @@ fn simplify_lines( const_malloc, ); } - Expression::Binary { - left, - operation, - right, - } => { + Expression::Binary { left, operation, right } => { let left = simplify_expr(left, &mut res, counters, array_manager, const_malloc); - let right = - simplify_expr(right, &mut res, counters, array_manager, const_malloc); + let right = simplify_expr(right, &mut res, counters, array_manager, const_malloc); res.push(SimpleLine::Assignment { var: var.clone().into(), operation: *operation, @@ -302,11 +284,7 @@ fn simplify_lines( } Expression::Log2Ceil { .. 
} => unreachable!(), }, - Line::ArrayAssign { - array, - index, - value, - } => { + Line::ArrayAssign { array, index, value } => { handle_array_assignment( counters, &mut res, @@ -320,8 +298,7 @@ fn simplify_lines( Line::Assert(boolean, line_number) => match boolean { Boolean::Different { left, right } => { let left = simplify_expr(left, &mut res, counters, array_manager, const_malloc); - let right = - simplify_expr(right, &mut res, counters, array_manager, const_malloc); + let right = simplify_expr(right, &mut res, counters, array_manager, const_malloc); let diff_var = format!("@aux_var_{}", counters.aux_vars); counters.aux_vars += 1; res.push(SimpleLine::Assignment { @@ -339,8 +316,7 @@ fn simplify_lines( } Boolean::Equal { left, right } => { let left = simplify_expr(left, &mut res, counters, array_manager, const_malloc); - let right = - simplify_expr(right, &mut res, counters, array_manager, const_malloc); + let right = simplify_expr(right, &mut res, counters, array_manager, const_malloc); let (var, other) = if let Ok(left) = left.clone().try_into() { (left, right) } else if let Ok(right) = right.clone().try_into() { @@ -367,18 +343,12 @@ fn simplify_lines( // Transform if a == b then X else Y into if a != b then Y else X let (left, right, then_branch, else_branch) = match condition { - Boolean::Equal { left, right } => { - (left, right, else_branch, then_branch) - } // switched - Boolean::Different { left, right } => { - (left, right, then_branch, else_branch) - } + Boolean::Equal { left, right } => (left, right, else_branch, then_branch), // switched + Boolean::Different { left, right } => (left, right, then_branch, else_branch), }; - let left_simplified = - simplify_expr(left, &mut res, counters, array_manager, const_malloc); - let right_simplified = - simplify_expr(right, &mut res, counters, array_manager, const_malloc); + let left_simplified = simplify_expr(left, &mut res, counters, array_manager, const_malloc); + let right_simplified = simplify_expr(right, &mut res, counters, array_manager, const_malloc); let diff_var = format!("@diff_{}", counters.aux_vars); counters.aux_vars += 1; @@ -391,13 +361,8 @@ fn simplify_lines( (diff_var.into(), then_branch, else_branch) } Condition::Expression(condition, assume_boolean) => { - let condition_simplified = simplify_expr( - condition, - &mut res, - counters, - array_manager, - const_malloc, - ); + let condition_simplified = + simplify_expr(condition, &mut res, counters, array_manager, const_malloc); match assume_boolean { AssumeBoolean::AssumeBoolean => {} @@ -408,9 +373,7 @@ fn simplify_lines( res.push(SimpleLine::Assignment { var: one_minus_condition_var.clone().into(), operation: HighLevelOperation::Sub, - arg0: SimpleExpr::Constant(ConstExpression::Value( - ConstantValue::Scalar(1), - )), + arg0: SimpleExpr::Constant(ConstExpression::Value(ConstantValue::Scalar(1))), arg1: condition_simplified.clone(), }); res.push(SimpleLine::TestZero { @@ -498,13 +461,7 @@ fn simplify_lines( for i in range { let mut body_copy = body.clone(); - replace_vars_for_unroll( - &mut body_copy, - iterator, - unroll_index, - i, - &internal_variables, - ); + replace_vars_for_unroll(&mut body_copy, iterator, unroll_index, i, &internal_variables); unrolled_lines.extend(simplify_lines( &body_copy, counters, @@ -555,10 +512,8 @@ fn simplify_lines( let mut external_vars: Vec<_> = external_vars.into_iter().collect(); - let start_simplified = - simplify_expr(start, &mut res, counters, array_manager, const_malloc); - let end_simplified = - simplify_expr(end, &mut res, 
counters, array_manager, const_malloc); + let start_simplified = simplify_expr(start, &mut res, counters, array_manager, const_malloc); + let end_simplified = simplify_expr(end, &mut res, counters, array_manager, const_malloc); for (simplified, original) in [ (start_simplified.clone(), start.clone()), @@ -617,10 +572,7 @@ fn simplify_lines( }); } Line::FunctionRet { return_data } => { - assert!( - !in_a_loop, - "Function return inside a loop is not currently supported" - ); + assert!(!in_a_loop, "Function return inside a loop is not currently supported"); let simplified_return_data = return_data .iter() .map(|ret| simplify_expr(ret, &mut res, counters, array_manager, const_malloc)) @@ -651,9 +603,7 @@ fn simplify_lines( } Line::Break => { assert!(in_a_loop, "Break statement outside of a loop"); - res.push(SimpleLine::FunctionRet { - return_data: vec![], - }); + res.push(SimpleLine::FunctionRet { return_data: vec![] }); } Line::MAlloc { var, @@ -661,27 +611,14 @@ fn simplify_lines( vectorized, vectorized_len, } => { - let simplified_size = - simplify_expr(size, &mut res, counters, array_manager, const_malloc); - let simplified_vectorized_len = simplify_expr( - vectorized_len, - &mut res, - counters, - array_manager, - const_malloc, - ); - if simplified_size.is_constant() - && !*vectorized - && const_malloc.forbidden_vars.contains(var) - { - println!( - "TODO: Optimization missed: Requires to align const malloc in if/else branches" - ); + let simplified_size = simplify_expr(size, &mut res, counters, array_manager, const_malloc); + let simplified_vectorized_len = + simplify_expr(vectorized_len, &mut res, counters, array_manager, const_malloc); + if simplified_size.is_constant() && !*vectorized && const_malloc.forbidden_vars.contains(var) { + println!("TODO: Optimization missed: Requires to align const malloc in if/else branches"); } match simplified_size { - SimpleExpr::Constant(const_size) - if !*vectorized && !const_malloc.forbidden_vars.contains(var) => - { + SimpleExpr::Constant(const_size) if !*vectorized && !const_malloc.forbidden_vars.contains(var) => { // TODO do this optimization even if we are in an if/else branch let label = const_malloc.counter; const_malloc.counter += 1; @@ -706,9 +643,7 @@ fn simplify_lines( assert!(!const_malloc.forbidden_vars.contains(var), "TODO"); let simplified_to_decompose = to_decompose .iter() - .map(|expr| { - simplify_expr(expr, &mut res, counters, array_manager, const_malloc) - }) + .map(|expr| simplify_expr(expr, &mut res, counters, array_manager, const_malloc)) .collect::>(); let label = const_malloc.counter; const_malloc.counter += 1; @@ -723,9 +658,7 @@ fn simplify_lines( assert!(!const_malloc.forbidden_vars.contains(var), "TODO"); let simplified_to_decompose = to_decompose .iter() - .map(|expr| { - simplify_expr(expr, &mut res, counters, array_manager, const_malloc) - }) + .map(|expr| simplify_expr(expr, &mut res, counters, array_manager, const_malloc)) .collect::>(); let label = const_malloc.counter; const_malloc.counter += 1; @@ -743,9 +676,7 @@ fn simplify_lines( res.push(SimpleLine::Panic); } Line::LocationReport { location } => { - res.push(SimpleLine::LocationReport { - location: *location, - }); + res.push(SimpleLine::LocationReport { location: *location }); } } } @@ -791,17 +722,11 @@ fn simplify_expr( ); SimpleExpr::Var(aux_arr) } - Expression::Binary { - left, - operation, - right, - } => { + Expression::Binary { left, operation, right } => { let left_var = simplify_expr(left, lines, counters, array_manager, const_malloc); let right_var 
= simplify_expr(right, lines, counters, array_manager, const_malloc); - if let (SimpleExpr::Constant(left_cst), SimpleExpr::Constant(right_cst)) = - (&left_var, &right_var) - { + if let (SimpleExpr::Constant(left_cst), SimpleExpr::Constant(right_cst)) = (&left_var, &right_var) { return SimpleExpr::Constant(ConstExpression::Binary { left: Box::new(left_cst.clone()), operation: *operation, @@ -835,19 +760,16 @@ pub fn find_variable_usage(lines: &[Line]) -> (BTreeSet, BTreeSet) { let mut internal_vars = BTreeSet::new(); let mut external_vars = BTreeSet::new(); - let on_new_expr = - |expr: &Expression, internal_vars: &BTreeSet, external_vars: &mut BTreeSet| { - for var in vars_in_expression(expr) { - if !internal_vars.contains(&var) { - external_vars.insert(var); - } + let on_new_expr = |expr: &Expression, internal_vars: &BTreeSet, external_vars: &mut BTreeSet| { + for var in vars_in_expression(expr) { + if !internal_vars.contains(&var) { + external_vars.insert(var); } - }; + } + }; let on_new_condition = - |condition: &Condition, - internal_vars: &BTreeSet, - external_vars: &mut BTreeSet| match condition { + |condition: &Condition, internal_vars: &BTreeSet, external_vars: &mut BTreeSet| match condition { Condition::Comparison(Boolean::Equal { left, right }) | Condition::Comparison(Boolean::Different { left, right }) => { on_new_expr(left, internal_vars, external_vars); @@ -865,11 +787,7 @@ pub fn find_variable_usage(lines: &[Line]) -> (BTreeSet, BTreeSet) { for (_, statements) in arms { let (stmt_internal, stmt_external) = find_variable_usage(statements); internal_vars.extend(stmt_internal); - external_vars.extend( - stmt_external - .into_iter() - .filter(|v| !internal_vars.contains(v)), - ); + external_vars.extend(stmt_external.into_iter().filter(|v| !internal_vars.contains(v))); } } Line::Assignment { var, value } => { @@ -895,9 +813,7 @@ pub fn find_variable_usage(lines: &[Line]) -> (BTreeSet, BTreeSet) { .cloned(), ); } - Line::FunctionCall { - args, return_data, .. - } => { + Line::FunctionCall { args, return_data, .. 
} => { for arg in args { on_new_expr(arg, &internal_vars, &mut external_vars); } @@ -929,8 +845,7 @@ pub fn find_variable_usage(lines: &[Line]) -> (BTreeSet, BTreeSet) { on_new_expr(var, &internal_vars, &mut external_vars); } } - Line::DecomposeBits { var, to_decompose } - | Line::DecomposeCustom { var, to_decompose } => { + Line::DecomposeBits { var, to_decompose } | Line::DecomposeCustom { var, to_decompose } => { for expr in to_decompose { on_new_expr(expr, &internal_vars, &mut external_vars); } @@ -956,11 +871,7 @@ pub fn find_variable_usage(lines: &[Line]) -> (BTreeSet, BTreeSet) { on_new_expr(start, &internal_vars, &mut external_vars); on_new_expr(end, &internal_vars, &mut external_vars); } - Line::ArrayAssign { - array, - index, - value, - } => { + Line::ArrayAssign { array, index, value } => { on_new_expr(&array.clone().into(), &internal_vars, &mut external_vars); on_new_expr(index, &internal_vars, &mut external_vars); on_new_expr(value, &internal_vars, &mut external_vars); @@ -972,11 +883,7 @@ pub fn find_variable_usage(lines: &[Line]) -> (BTreeSet, BTreeSet) { (internal_vars, external_vars) } -fn inline_simple_expr( - simple_expr: &mut SimpleExpr, - args: &BTreeMap, - inlining_count: usize, -) { +fn inline_simple_expr(simple_expr: &mut SimpleExpr, args: &BTreeMap, inlining_count: usize) { if let SimpleExpr::Var(var) = simple_expr { if let Some(replacement) = args.get(var) { *simple_expr = replacement.clone(); @@ -1005,12 +912,7 @@ fn inline_expr(expr: &mut Expression, args: &BTreeMap, inlining } } -pub fn inline_lines( - lines: &mut Vec, - args: &BTreeMap, - res: &[Var], - inlining_count: usize, -) { +pub fn inline_lines(lines: &mut Vec, args: &BTreeMap, res: &[Var], inlining_count: usize) { let inline_comparison = |comparison: &mut Boolean| { let (Boolean::Equal { left, right } | Boolean::Different { left, right }) = comparison; inline_expr(left, args, inlining_count); @@ -1103,8 +1005,7 @@ pub fn inline_lines( inline_expr(var, args, inlining_count); } } - Line::DecomposeBits { var, to_decompose } - | Line::DecomposeCustom { var, to_decompose } => { + Line::DecomposeBits { var, to_decompose } | Line::DecomposeCustom { var, to_decompose } => { for expr in to_decompose { inline_expr(expr, args, inlining_count); } @@ -1127,11 +1028,7 @@ pub fn inline_lines( inline_expr(start, args, inlining_count); inline_expr(end, args, inlining_count); } - Line::ArrayAssign { - array, - index, - value, - } => { + Line::ArrayAssign { array, index, value } => { inline_simple_expr(array, args, inlining_count); inline_expr(index, args, inlining_count); inline_expr(value, args, inlining_count); @@ -1189,11 +1086,7 @@ fn handle_array_assignment( if let SimpleExpr::Constant(offset) = simplified_index.clone() && let SimpleExpr::Var(array_var) = &array && let Some(label) = const_malloc.map.get(array_var) - && let ArrayAccessType::ArrayIsAssigned(Expression::Binary { - left, - operation, - right, - }) = &access_type + && let ArrayAccessType::ArrayIsAssigned(Expression::Binary { left, operation, right }) = &access_type { let arg0 = simplify_expr(left, res, counters, array_manager, const_malloc); let arg1 = simplify_expr(right, res, counters, array_manager, const_malloc); @@ -1211,9 +1104,7 @@ fn handle_array_assignment( let value_simplified = match access_type { ArrayAccessType::VarIsAssigned(var) => SimpleExpr::Var(var), - ArrayAccessType::ArrayIsAssigned(expr) => { - simplify_expr(&expr, res, counters, array_manager, const_malloc) - } + ArrayAccessType::ArrayIsAssigned(expr) => simplify_expr(&expr, res, 
counters, array_manager, const_malloc), }; // TODO opti: in some case we could use ConstMallocAccess @@ -1269,9 +1160,7 @@ fn create_recursive_function( return_data: vec![], line_number, }); - body.push(SimpleLine::FunctionRet { - return_data: vec![], - }); + body.push(SimpleLine::FunctionRet { return_data: vec![] }); let diff_var = format!("@diff_{iterator}"); @@ -1285,9 +1174,7 @@ fn create_recursive_function( SimpleLine::IfNotZero { condition: diff_var.into(), then_branch: body, - else_branch: vec![SimpleLine::FunctionRet { - return_data: vec![], - }], + else_branch: vec![SimpleLine::FunctionRet { return_data: vec![] }], line_number, }, ]; @@ -1326,38 +1213,14 @@ fn replace_vars_for_unroll_in_expr( } } - replace_vars_for_unroll_in_expr( - index, - iterator, - unroll_index, - iterator_value, - internal_vars, - ); + replace_vars_for_unroll_in_expr(index, iterator, unroll_index, iterator_value, internal_vars); } Expression::Binary { left, right, .. } => { - replace_vars_for_unroll_in_expr( - left, - iterator, - unroll_index, - iterator_value, - internal_vars, - ); - replace_vars_for_unroll_in_expr( - right, - iterator, - unroll_index, - iterator_value, - internal_vars, - ); + replace_vars_for_unroll_in_expr(left, iterator, unroll_index, iterator_value, internal_vars); + replace_vars_for_unroll_in_expr(right, iterator, unroll_index, iterator_value, internal_vars); } Expression::Log2Ceil { value } => { - replace_vars_for_unroll_in_expr( - value, - iterator, - unroll_index, - iterator_value, - internal_vars, - ); + replace_vars_for_unroll_in_expr(value, iterator, unroll_index, iterator_value, internal_vars); } } } @@ -1372,33 +1235,15 @@ fn replace_vars_for_unroll( for line in lines { match line { Line::Match { value, arms } => { - replace_vars_for_unroll_in_expr( - value, - iterator, - unroll_index, - iterator_value, - internal_vars, - ); + replace_vars_for_unroll_in_expr(value, iterator, unroll_index, iterator_value, internal_vars); for (_, statements) in arms { - replace_vars_for_unroll( - statements, - iterator, - unroll_index, - iterator_value, - internal_vars, - ); + replace_vars_for_unroll(statements, iterator, unroll_index, iterator_value, internal_vars); } } Line::Assignment { var, value } => { assert!(var != iterator, "Weird"); *var = format!("@unrolled_{unroll_index}_{iterator_value}_{var}"); - replace_vars_for_unroll_in_expr( - value, - iterator, - unroll_index, - iterator_value, - internal_vars, - ); + replace_vars_for_unroll_in_expr(value, iterator, unroll_index, iterator_value, internal_vars); } Line::ArrayAssign { // array[index] = value @@ -1409,43 +1254,15 @@ fn replace_vars_for_unroll( if let SimpleExpr::Var(array_var) = array { assert!(array_var != iterator, "Weird"); if internal_vars.contains(array_var) { - *array_var = - format!("@unrolled_{unroll_index}_{iterator_value}_{array_var}"); + *array_var = format!("@unrolled_{unroll_index}_{iterator_value}_{array_var}"); } } - replace_vars_for_unroll_in_expr( - index, - iterator, - unroll_index, - iterator_value, - internal_vars, - ); - replace_vars_for_unroll_in_expr( - value, - iterator, - unroll_index, - iterator_value, - internal_vars, - ); + replace_vars_for_unroll_in_expr(index, iterator, unroll_index, iterator_value, internal_vars); + replace_vars_for_unroll_in_expr(value, iterator, unroll_index, iterator_value, internal_vars); } - Line::Assert( - Boolean::Equal { left, right } | Boolean::Different { left, right }, - _line_number, - ) => { - replace_vars_for_unroll_in_expr( - left, - iterator, - unroll_index, - 
iterator_value, - internal_vars, - ); - replace_vars_for_unroll_in_expr( - right, - iterator, - unroll_index, - iterator_value, - internal_vars, - ); + Line::Assert(Boolean::Equal { left, right } | Boolean::Different { left, right }, _line_number) => { + replace_vars_for_unroll_in_expr(left, iterator, unroll_index, iterator_value, internal_vars); + replace_vars_for_unroll_in_expr(right, iterator, unroll_index, iterator_value, internal_vars); } Line::IfCondition { condition, @@ -1454,49 +1271,17 @@ fn replace_vars_for_unroll( line_number: _, } => { match condition { - Condition::Comparison( - Boolean::Equal { left, right } | Boolean::Different { left, right }, - ) => { - replace_vars_for_unroll_in_expr( - left, - iterator, - unroll_index, - iterator_value, - internal_vars, - ); - replace_vars_for_unroll_in_expr( - right, - iterator, - unroll_index, - iterator_value, - internal_vars, - ); + Condition::Comparison(Boolean::Equal { left, right } | Boolean::Different { left, right }) => { + replace_vars_for_unroll_in_expr(left, iterator, unroll_index, iterator_value, internal_vars); + replace_vars_for_unroll_in_expr(right, iterator, unroll_index, iterator_value, internal_vars); } Condition::Expression(expr, _assume_bool) => { - replace_vars_for_unroll_in_expr( - expr, - iterator, - unroll_index, - iterator_value, - internal_vars, - ); + replace_vars_for_unroll_in_expr(expr, iterator, unroll_index, iterator_value, internal_vars); } } - replace_vars_for_unroll( - then_branch, - iterator, - unroll_index, - iterator_value, - internal_vars, - ); - replace_vars_for_unroll( - else_branch, - iterator, - unroll_index, - iterator_value, - internal_vars, - ); + replace_vars_for_unroll(then_branch, iterator, unroll_index, iterator_value, internal_vars); + replace_vars_for_unroll(else_branch, iterator, unroll_index, iterator_value, internal_vars); } Line::ForLoop { iterator: other_iterator, @@ -1508,29 +1293,10 @@ fn replace_vars_for_unroll( line_number: _, } => { assert!(other_iterator != iterator); - *other_iterator = - format!("@unrolled_{unroll_index}_{iterator_value}_{other_iterator}"); - replace_vars_for_unroll_in_expr( - start, - iterator, - unroll_index, - iterator_value, - internal_vars, - ); - replace_vars_for_unroll_in_expr( - end, - iterator, - unroll_index, - iterator_value, - internal_vars, - ); - replace_vars_for_unroll( - body, - iterator, - unroll_index, - iterator_value, - internal_vars, - ); + *other_iterator = format!("@unrolled_{unroll_index}_{iterator_value}_{other_iterator}"); + replace_vars_for_unroll_in_expr(start, iterator, unroll_index, iterator_value, internal_vars); + replace_vars_for_unroll_in_expr(end, iterator, unroll_index, iterator_value, internal_vars); + replace_vars_for_unroll(body, iterator, unroll_index, iterator_value, internal_vars); } Line::FunctionCall { function_name: _, @@ -1540,13 +1306,7 @@ fn replace_vars_for_unroll( } => { // Function calls are not unrolled, so we don't need to change them for arg in args { - replace_vars_for_unroll_in_expr( - arg, - iterator, - unroll_index, - iterator_value, - internal_vars, - ); + replace_vars_for_unroll_in_expr(arg, iterator, unroll_index, iterator_value, internal_vars); } for ret in return_data { *ret = format!("@unrolled_{unroll_index}_{iterator_value}_{ret}"); @@ -1554,37 +1314,19 @@ fn replace_vars_for_unroll( } Line::FunctionRet { return_data } => { for ret in return_data { - replace_vars_for_unroll_in_expr( - ret, - iterator, - unroll_index, - iterator_value, - internal_vars, - ); + 
replace_vars_for_unroll_in_expr(ret, iterator, unroll_index, iterator_value, internal_vars); } } Line::Precompile { table: _, args } => { for arg in args { - replace_vars_for_unroll_in_expr( - arg, - iterator, - unroll_index, - iterator_value, - internal_vars, - ); + replace_vars_for_unroll_in_expr(arg, iterator, unroll_index, iterator_value, internal_vars); } } Line::Print { line_info, content } => { // Print statements are not unrolled, so we don't need to change them *line_info += &format!(" (unrolled {unroll_index} {iterator_value})"); for var in content { - replace_vars_for_unroll_in_expr( - var, - iterator, - unroll_index, - iterator_value, - internal_vars, - ); + replace_vars_for_unroll_in_expr(var, iterator, unroll_index, iterator_value, internal_vars); } } Line::MAlloc { @@ -1595,33 +1337,14 @@ fn replace_vars_for_unroll( } => { assert!(var != iterator, "Weird"); *var = format!("@unrolled_{unroll_index}_{iterator_value}_{var}"); - replace_vars_for_unroll_in_expr( - size, - iterator, - unroll_index, - iterator_value, - internal_vars, - ); - replace_vars_for_unroll_in_expr( - vectorized_len, - iterator, - unroll_index, - iterator_value, - internal_vars, - ); + replace_vars_for_unroll_in_expr(size, iterator, unroll_index, iterator_value, internal_vars); + replace_vars_for_unroll_in_expr(vectorized_len, iterator, unroll_index, iterator_value, internal_vars); } - Line::DecomposeBits { var, to_decompose } - | Line::DecomposeCustom { var, to_decompose } => { + Line::DecomposeBits { var, to_decompose } | Line::DecomposeCustom { var, to_decompose } => { assert!(var != iterator, "Weird"); *var = format!("@unrolled_{unroll_index}_{iterator_value}_{var}"); for expr in to_decompose { - replace_vars_for_unroll_in_expr( - expr, - iterator, - unroll_index, - iterator_value, - internal_vars, - ); + replace_vars_for_unroll_in_expr(expr, iterator, unroll_index, iterator_value, internal_vars); } } Line::CounterHint { var } => { @@ -1660,12 +1383,7 @@ fn handle_inlined_functions(program: &mut Program) { let mut counter2 = Counter::new(); let old_body = func.body.clone(); - handle_inlined_functions_helper( - &mut func.body, - &inlined_functions, - &mut counter1, - &mut counter2, - ); + handle_inlined_functions_helper(&mut func.body, &inlined_functions, &mut counter1, &mut counter2); if func.body != old_body { any_changes = true; @@ -1681,12 +1399,7 @@ fn handle_inlined_functions(program: &mut Program) { let mut counter2 = Counter::new(); let old_body = func.body.clone(); - handle_inlined_functions_helper( - &mut func.body, - &inlined_functions, - &mut counter1, - &mut counter2, - ); + handle_inlined_functions_helper(&mut func.body, &inlined_functions, &mut counter1, &mut counter2); if func.body != old_body { any_changes = true; @@ -1701,10 +1414,7 @@ fn handle_inlined_functions(program: &mut Program) { max_iterations -= 1; } - assert!( - max_iterations > 0, - "Too many iterations processing inline functions" - ); + assert!(max_iterations > 0, "Too many iterations processing inline functions"); // Remove all inlined functions from the program (they've been inlined) for func_name in inlined_functions.keys() { @@ -1750,12 +1460,7 @@ fn handle_inlined_functions_helper( .map(|((var, _), expr)| (var.clone(), expr.clone())) .collect::>(); let mut func_body = func.body.clone(); - inline_lines( - &mut func_body, - &inlined_args, - return_data, - total_inlined_counter.next(), - ); + inline_lines(&mut func_body, &inlined_args, return_data, total_inlined_counter.next()); inlined_lines.extend(func_body); 
lines.remove(i); // remove the call to the inlined function @@ -1780,24 +1485,12 @@ fn handle_inlined_functions_helper( total_inlined_counter, ); } - Line::ForLoop { - body, unroll: _, .. - } => { - handle_inlined_functions_helper( - body, - inlined_functions, - inlined_var_counter, - total_inlined_counter, - ); + Line::ForLoop { body, unroll: _, .. } => { + handle_inlined_functions_helper(body, inlined_functions, inlined_var_counter, total_inlined_counter); } Line::Match { arms, .. } => { for (_, arm) in arms { - handle_inlined_functions_helper( - arm, - inlined_functions, - inlined_var_counter, - total_inlined_counter, - ); + handle_inlined_functions_helper(arm, inlined_functions, inlined_var_counter, total_inlined_counter); } } _ => {} @@ -1836,11 +1529,7 @@ fn handle_const_arguments(program: &mut Program) { for name in function_names { if let Some(func) = new_functions.get_mut(&name) { let initial_count = additional_functions.len(); - handle_const_arguments_helper( - &mut func.body, - &constant_functions, - &mut additional_functions, - ); + handle_const_arguments_helper(&mut func.body, &constant_functions, &mut additional_functions); if additional_functions.len() > initial_count { changed = true; } @@ -1883,9 +1572,9 @@ fn handle_const_arguments_helper( let mut const_evals = Vec::new(); for (arg_expr, (arg_var, is_constant)) in args.iter().zip(&func.arguments) { if *is_constant { - let const_eval = arg_expr.naive_eval().unwrap_or_else(|| { - panic!("Failed to evaluate constant argument: {arg_expr}") - }); + let const_eval = arg_expr + .naive_eval() + .unwrap_or_else(|| panic!("Failed to evaluate constant argument: {arg_expr}")); const_evals.push((arg_var.clone(), const_eval)); } } @@ -1913,10 +1602,7 @@ fn handle_const_arguments_helper( } let mut new_body = func.body.clone(); - replace_vars_by_const_in_lines( - &mut new_body, - &const_evals.iter().cloned().collect(), - ); + replace_vars_by_const_in_lines(&mut new_body, &const_evals.iter().cloned().collect()); new_functions.insert( const_funct_name.clone(), Function { @@ -1942,9 +1628,7 @@ fn handle_const_arguments_helper( handle_const_arguments_helper(then_branch, constant_functions, new_functions); handle_const_arguments_helper(else_branch, constant_functions, new_functions); } - Line::ForLoop { - body, unroll: _, .. - } => { + Line::ForLoop { body, unroll: _, .. } => { // TODO we should unroll before const arguments handling handle_const_arguments_helper(body, constant_functions, new_functions); } @@ -1968,10 +1652,7 @@ fn replace_vars_by_const_in_expr(expr: &mut Expression, map: &BTreeMap) }, Expression::ArrayAccess { array, index } => { if let SimpleExpr::Var(array_var) = array { - assert!( - !map.contains_key(array_var), - "Array {array_var} is a constant" - ); + assert!(!map.contains_key(array_var), "Array {array_var} is a constant"); } replace_vars_by_const_in_expr(index, map); } @@ -2037,31 +1718,19 @@ fn replace_vars_by_const_in_lines(lines: &mut [Line], map: &BTreeMap) { assert!(!map.contains_key(var), "Variable {var} is a constant"); replace_vars_by_const_in_expr(value, map); } - Line::ArrayAssign { - array, - index, - value, - } => { + Line::ArrayAssign { array, index, value } => { if let SimpleExpr::Var(array_var) = array { - assert!( - !map.contains_key(array_var), - "Array {array_var} is a constant" - ); + assert!(!map.contains_key(array_var), "Array {array_var} is a constant"); } replace_vars_by_const_in_expr(index, map); replace_vars_by_const_in_expr(value, map); } - Line::FunctionCall { - args, return_data, .. 
- } => { + Line::FunctionCall { args, return_data, .. } => { for arg in args { replace_vars_by_const_in_expr(arg, map); } for ret in return_data { - assert!( - !map.contains_key(ret), - "Return variable {ret} is a constant" - ); + assert!(!map.contains_key(ret), "Return variable {ret} is a constant"); } } Line::IfCondition { @@ -2083,9 +1752,7 @@ fn replace_vars_by_const_in_lines(lines: &mut [Line], map: &BTreeMap) { replace_vars_by_const_in_lines(then_branch, map); replace_vars_by_const_in_lines(else_branch, map); } - Line::ForLoop { - body, start, end, .. - } => { + Line::ForLoop { body, start, end, .. } => { replace_vars_by_const_in_expr(start, map); replace_vars_by_const_in_expr(end, map); replace_vars_by_const_in_lines(body, map); @@ -2111,8 +1778,7 @@ fn replace_vars_by_const_in_lines(lines: &mut [Line], map: &BTreeMap) { replace_vars_by_const_in_expr(var, map); } } - Line::DecomposeBits { var, to_decompose } - | Line::DecomposeCustom { var, to_decompose } => { + Line::DecomposeBits { var, to_decompose } | Line::DecomposeCustom { var, to_decompose } => { assert!(!map.contains_key(var), "Variable {var} is a constant"); for expr in to_decompose { replace_vars_by_const_in_expr(expr, map); @@ -2139,10 +1805,7 @@ impl Display for VarOrConstMallocAccess { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { match self { Self::Var(var) => write!(f, "{var}"), - Self::ConstMallocAccess { - malloc_label, - offset, - } => { + Self::ConstMallocAccess { malloc_label, offset } => { write!(f, "ConstMallocAccess({malloc_label}, {offset})") } } @@ -2216,11 +1879,7 @@ impl SimpleLine { Self::RawAccess { res, index, shift } => { format!("memory[{index} + {shift}] = {res}") } - Self::TestZero { - operation, - arg0, - arg1, - } => { + Self::TestZero { operation, arg0, arg1 } => { format!("0 = {arg0} {operation} {arg1}") } Self::IfNotZero { @@ -2244,9 +1903,7 @@ impl SimpleLine { if else_branch.is_empty() { format!("if {condition} != 0 {{\n{then_str}\n{spaces}}}") } else { - format!( - "if {condition} != 0 {{\n{then_str}\n{spaces}}} else {{\n{else_str}\n{spaces}}}" - ) + format!("if {condition} != 0 {{\n{then_str}\n{spaces}}} else {{\n{else_str}\n{spaces}}}") } } Self::FunctionCall { @@ -2255,11 +1912,7 @@ impl SimpleLine { return_data, line_number: _, } => { - let args_str = args - .iter() - .map(|arg| format!("{arg}")) - .collect::>() - .join(", "); + let args_str = args.iter().map(|arg| format!("{arg}")).collect::>().join(", "); let return_data_str = return_data .iter() .map(|var| var.to_string()) @@ -2287,21 +1940,11 @@ impl SimpleLine { format!( "{}({})", &precompile.name(), - args.iter() - .map(|arg| format!("{arg}")) - .collect::>() - .join(", ") + args.iter().map(|arg| format!("{arg}")).collect::>().join(", ") ) } - Self::Print { - line_info: _, - content, - } => { - let content_str = content - .iter() - .map(|c| format!("{c}")) - .collect::>() - .join(", "); + Self::Print { line_info: _, content } => { + let content_str = content.iter().map(|c| format!("{c}")).collect::>().join(", "); format!("print({content_str})") } Self::HintMAlloc { @@ -2316,11 +1959,7 @@ impl SimpleLine { format!("{var} = malloc({size})") } } - Self::ConstMalloc { - var, - size, - label: _, - } => { + Self::ConstMalloc { var, size, label: _ } => { format!("{var} = malloc({size})") } Self::Panic => "panic".to_string(), @@ -2347,11 +1986,7 @@ impl Display for SimpleFunction { .join("\n"); if self.instructions.is_empty() { - write!( - f, - "fn {}({}) -> {} {{}}", - self.name, args_str, self.n_returned_vars - ) + write!(f, 
"fn {}({}) -> {} {{}}", self.name, args_str, self.n_returned_vars) } else { write!( f, diff --git a/crates/lean_compiler/src/b_compile_intermediate.rs b/crates/lean_compiler/src/b_compile_intermediate.rs index f67c8225..e77ac8ae 100644 --- a/crates/lean_compiler/src/b_compile_intermediate.rs +++ b/crates/lean_compiler/src/b_compile_intermediate.rs @@ -28,10 +28,7 @@ impl Compiler { .get(var) .unwrap_or_else(|| panic!("Variable {var} not in scope"))) .into(), - VarOrConstMallocAccess::ConstMallocAccess { - malloc_label, - offset, - } => ConstExpression::Binary { + VarOrConstMallocAccess::ConstMallocAccess { malloc_label, offset } => ConstExpression::Binary { left: Box::new( self.const_mallocs .get(malloc_label) @@ -53,10 +50,7 @@ impl SimpleExpr { offset: compiler.get_offset(&var.clone().into()), }, Self::Constant(c) => IntermediaryMemOrFpOrConstant::Constant(c.clone()), - Self::ConstMallocAccess { - malloc_label, - offset, - } => IntermediaryMemOrFpOrConstant::MemoryAfterFp { + Self::ConstMallocAccess { malloc_label, offset } => IntermediaryMemOrFpOrConstant::MemoryAfterFp { offset: compiler.get_offset(&VarOrConstMallocAccess::ConstMallocAccess { malloc_label: *malloc_label, offset: offset.clone(), @@ -73,10 +67,7 @@ impl IntermediateValue { offset: compiler.get_offset(&var.clone().into()), }, SimpleExpr::Constant(c) => Self::Constant(c.clone()), - SimpleExpr::ConstMallocAccess { - malloc_label, - offset, - } => Self::MemoryAfterFp { + SimpleExpr::ConstMallocAccess { malloc_label, offset } => Self::MemoryAfterFp { offset: ConstExpression::Binary { left: Box::new( compiler @@ -93,25 +84,18 @@ impl IntermediateValue { } } - fn from_var_or_const_malloc_access( - var_or_const: &VarOrConstMallocAccess, - compiler: &Compiler, - ) -> Self { + fn from_var_or_const_malloc_access(var_or_const: &VarOrConstMallocAccess, compiler: &Compiler) -> Self { Self::from_simple_expr(&var_or_const.clone().into(), compiler) } } -pub fn compile_to_intermediate_bytecode( - simple_program: SimpleProgram, -) -> Result { +pub fn compile_to_intermediate_bytecode(simple_program: SimpleProgram) -> Result { let mut compiler = Compiler::default(); let mut memory_sizes = BTreeMap::new(); for function in simple_program.functions.values() { let instructions = compile_function(function, &mut compiler)?; - compiler - .bytecode - .insert(Label::function(&function.name), instructions); + compiler.bytecode.insert(Label::function(&function.name), instructions); memory_sizes.insert(function.name.clone(), compiler.stack_size); } @@ -191,11 +175,7 @@ fn compile_lines( } } - SimpleLine::TestZero { - operation, - arg0, - arg1, - } => { + SimpleLine::TestZero { operation, arg0, arg1 } => { instructions.push(IntermediateInstruction::computation( *operation, IntermediateValue::from_simple_expr(arg0, compiler), @@ -230,10 +210,7 @@ fn compile_lines( *declared_vars = if i == 0 { arm_declared_vars } else { - declared_vars - .intersection(&arm_declared_vars) - .cloned() - .collect() + declared_vars.intersection(&arm_declared_vars).cloned().collect() }; } compiler.stack_size = new_stack_size; @@ -249,8 +226,7 @@ fn compile_lines( instructions.push(IntermediateInstruction::Computation { operation: Operation::Mul, arg_a: value_simplified, - arg_c: ConstExpression::Value(ConstantValue::MatchBlockSize { match_index }) - .into(), + arg_c: ConstExpression::Value(ConstantValue::MatchBlockSize { match_index }).into(), res: value_scaled_offset.clone(), }); @@ -261,10 +237,7 @@ fn compile_lines( instructions.push(IntermediateInstruction::Computation { 
operation: Operation::Add, arg_a: value_scaled_offset, - arg_c: ConstExpression::Value(ConstantValue::MatchFirstBlockStart { - match_index, - }) - .into(), + arg_c: ConstExpression::Value(ConstantValue::MatchFirstBlockStart { match_index }).into(), res: jump_dest_offset.clone(), }); instructions.push(IntermediateInstruction::Jump { @@ -272,13 +245,7 @@ fn compile_lines( updated_fp: None, }); - let remaining = compile_lines( - function_name, - &lines[i + 1..], - compiler, - final_jump, - declared_vars, - )?; + let remaining = compile_lines(function_name, &lines[i + 1..], compiler, final_jump, declared_vars)?; compiler.bytecode.insert(end_label, remaining); return Ok(instructions); @@ -386,21 +353,12 @@ fn compile_lines( let else_stack = compiler.stack_size; compiler.stack_size = then_stack.max(else_stack); - *declared_vars = then_declared_vars - .intersection(&else_declared_vars) - .cloned() - .collect(); + *declared_vars = then_declared_vars.intersection(&else_declared_vars).cloned().collect(); compiler.bytecode.insert(if_label, then_instructions); compiler.bytecode.insert(else_label, else_instructions); - let remaining = compile_lines( - function_name, - &lines[i + 1..], - compiler, - final_jump, - declared_vars, - )?; + let remaining = compile_lines(function_name, &lines[i + 1..], compiler, final_jump, declared_vars)?; compiler.bytecode.insert(end_label, remaining); return Ok(instructions); @@ -487,11 +445,7 @@ fn compile_lines( arg_a: IntermediateValue::from_simple_expr(&args[0], compiler), arg_b: IntermediateValue::from_simple_expr(&args[1], compiler), arg_c: IntermediateValue::from_simple_expr(&args[2], compiler), - aux: args - .get(3) - .unwrap_or(&SimpleExpr::zero()) - .as_constant() - .unwrap(), + aux: args.get(3).unwrap_or(&SimpleExpr::zero()).as_constant().unwrap(), }); } @@ -599,9 +553,7 @@ fn compile_lines( }); } SimpleLine::LocationReport { location } => { - instructions.push(IntermediateInstruction::LocationReport { - location: *location, - }); + instructions.push(IntermediateInstruction::LocationReport { location: *location }); } } } @@ -646,10 +598,7 @@ fn mark_vars_as_declared>(vocs: &[VoC], declared: &mut B } } -fn validate_vars_declared>( - vocs: &[VoC], - declared: &BTreeSet, -) -> Result<(), String> { +fn validate_vars_declared>(vocs: &[VoC], declared: &BTreeSet) -> Result<(), String> { for voc in vocs { if let SimpleExpr::Var(v) = voc.borrow() && !declared.contains(v) @@ -677,9 +626,7 @@ fn setup_function_call( IntermediateInstruction::Deref { shift_0: new_fp_pos.into(), shift_1: ConstExpression::zero(), - res: IntermediaryMemOrFpOrConstant::Constant(ConstExpression::label( - return_label.clone(), - )), + res: IntermediaryMemOrFpOrConstant::Constant(ConstExpression::label(return_label.clone())), }, IntermediateInstruction::Deref { shift_0: new_fp_pos.into(), diff --git a/crates/lean_compiler/src/c_compile_final.rs b/crates/lean_compiler/src/c_compile_final.rs index 3e875144..1b8934b3 100644 --- a/crates/lean_compiler/src/c_compile_final.rs +++ b/crates/lean_compiler/src/c_compile_final.rs @@ -44,10 +44,7 @@ pub fn compile_to_low_level_bytecode( }], ); - let starting_frame_memory = *intermediate_bytecode - .memory_size_per_function - .get("main") - .unwrap(); + let starting_frame_memory = *intermediate_bytecode.memory_size_per_function.get("main").unwrap(); let mut hints = BTreeMap::new(); let mut label_to_pc = BTreeMap::new(); @@ -81,16 +78,13 @@ pub fn compile_to_low_level_bytecode( for (label, instructions) in &intermediate_bytecode.bytecode { 
label_to_pc.insert(label.clone(), pc); if let Label::Function(function_name) = label { - hints - .entry(pc) - .or_insert_with(Vec::new) - .push(Hint::StackFrame { - label: label.clone(), - size: *intermediate_bytecode - .memory_size_per_function - .get(function_name) - .unwrap(), - }); + hints.entry(pc).or_insert_with(Vec::new).push(Hint::StackFrame { + label: label.clone(), + size: *intermediate_bytecode + .memory_size_per_function + .get(function_name) + .unwrap(), + }); } code_blocks.push((label.clone(), pc, instructions.clone())); pc += count_real_instructions(instructions); @@ -125,10 +119,7 @@ pub fn compile_to_low_level_bytecode( let mut low_level_bytecode = Vec::new(); for (label, pc) in label_to_pc.clone() { - hints - .entry(pc) - .or_insert_with(Vec::new) - .push(Hint::Label { label }); + hints.entry(pc).or_insert_with(Vec::new).push(Hint::Label { label }); } let compiler = Compiler { @@ -191,8 +182,7 @@ fn compile_block( condition: IntermediateValue, dest: IntermediateValue, updated_fp: Option| { - let dest = - try_as_mem_or_constant(&dest).expect("Fatal: Could not materialize jump destination"); + let dest = try_as_mem_or_constant(&dest).expect("Fatal: Could not materialize jump destination"); let label = match dest { MemOrConstant::Constant(dest) => hints .get(&usize::try_from(dest.as_canonical_u32()).unwrap()) @@ -265,20 +255,14 @@ fn compile_block( res: MemOrConstant::one(), }); } - IntermediateInstruction::Deref { - shift_0, - shift_1, - res, - } => { + IntermediateInstruction::Deref { shift_0, shift_1, res } => { low_level_bytecode.push(Instruction::Deref { shift_0: eval_const_expression(&shift_0, compiler).to_usize(), shift_1: eval_const_expression(&shift_1, compiler).to_usize(), res: match res { - IntermediaryMemOrFpOrConstant::MemoryAfterFp { offset } => { - MemOrFpOrConstant::MemoryAfterFp { - offset: eval_const_expression_usize(&offset, compiler), - } - } + IntermediaryMemOrFpOrConstant::MemoryAfterFp { offset } => MemOrFpOrConstant::MemoryAfterFp { + offset: eval_const_expression_usize(&offset, compiler), + }, IntermediaryMemOrFpOrConstant::Fp => MemOrFpOrConstant::Fp, IntermediaryMemOrFpOrConstant::Constant(c) => { MemOrFpOrConstant::Constant(eval_const_expression(&c, compiler)) @@ -292,8 +276,7 @@ fn compile_block( updated_fp, } => codegen_jump(hints, low_level_bytecode, condition, dest, updated_fp), IntermediateInstruction::Jump { dest, updated_fp } => { - let one = - IntermediateValue::Constant(ConstExpression::Value(ConstantValue::Scalar(1))); + let one = IntermediateValue::Constant(ConstExpression::Value(ConstantValue::Scalar(1))); codegen_jump(hints, low_level_bytecode, one, dest, updated_fp) } IntermediateInstruction::Precompile { @@ -355,9 +338,7 @@ fn compile_block( vectorized_len, } => { let size = try_as_mem_or_constant(&size).unwrap(); - let vectorized_len = try_as_constant(&vectorized_len, compiler) - .unwrap() - .to_usize(); + let vectorized_len = try_as_constant(&vectorized_len, compiler).unwrap().to_usize(); let hint = Hint::RequestMemory { function_name: function_name.clone(), offset: eval_const_expression_usize(&offset, compiler), @@ -411,9 +392,7 @@ fn eval_constant_value(constant: &ConstantValue, compiler: &Compiler) -> usize { } ConstantValue::Label(label) => compiler.label_to_pc.get(label).copied().unwrap(), ConstantValue::MatchBlockSize { match_index } => compiler.match_block_sizes[*match_index], - ConstantValue::MatchFirstBlockStart { match_index } => { - compiler.match_first_block_starts[*match_index] - } + 
ConstantValue::MatchFirstBlockStart { match_index } => compiler.match_first_block_starts[*match_index], } } diff --git a/crates/lean_compiler/src/ir/instruction.rs b/crates/lean_compiler/src/ir/instruction.rs index a5d38335..5b193356 100644 --- a/crates/lean_compiler/src/ir/instruction.rs +++ b/crates/lean_compiler/src/ir/instruction.rs @@ -63,7 +63,7 @@ pub enum IntermediateInstruction { res_offset: usize, }, Print { - line_info: String, // information about the line where the print occurs + line_info: String, // information about the line where the print occurs content: Vec, // values to print }, // noop, debug purpose only @@ -132,11 +132,7 @@ impl Display for IntermediateInstruction { } => { write!(f, "{res} = {arg_a} {operation} {arg_c}") } - Self::Deref { - shift_0, - shift_1, - res, - } => write!(f, "{res} = m[m[fp + {shift_0}] + {shift_1}]"), + Self::Deref { shift_0, shift_1, res } => write!(f, "{res} = m[m[fp + {shift_0}] + {shift_1}]"), Self::Panic => write!(f, "panic"), Self::Jump { dest, updated_fp } => { if let Some(fp) = updated_fp { @@ -175,10 +171,7 @@ impl Display for IntermediateInstruction { vectorized_len, } => { if *vectorized { - write!( - f, - "m[fp + {offset}] = request_memory_vec({size}, {vectorized_len})" - ) + write!(f, "m[fp + {offset}] = request_memory_vec({size}, {vectorized_len})") } else { write!(f, "m[fp + {offset}] = request_memory({size})") } diff --git a/crates/lean_compiler/src/lang.rs b/crates/lean_compiler/src/lang.rs index efebac6b..ba1c001a 100644 --- a/crates/lean_compiler/src/lang.rs +++ b/crates/lean_compiler/src/lang.rs @@ -138,11 +138,7 @@ impl TryFrom for ConstExpression { Expression::Value(SimpleExpr::Constant(const_expr)) => Ok(const_expr), Expression::Value(_) => Err(()), Expression::ArrayAccess { .. } => Err(()), - Expression::Binary { - left, - operation, - right, - } => { + Expression::Binary { left, operation, right } => { let left_expr = Self::try_from(*left)?; let right_expr = Self::try_from(*right)?; Ok(Self::Binary { @@ -187,11 +183,9 @@ impl ConstExpression { { match self { Self::Value(value) => func(value), - Self::Binary { - left, - operation, - right, - } => Some(operation.eval(left.eval_with(func)?, right.eval_with(func)?)), + Self::Binary { left, operation, right } => { + Some(operation.eval(left.eval_with(func)?, right.eval_with(func)?)) + } Self::Log2Ceil { value } => { let value = value.eval_with(func)?; Some(F::from_usize(log2_ceil_usize(value.to_usize()))) @@ -276,10 +270,7 @@ impl From for Expression { impl Expression { pub fn naive_eval(&self) -> Option { - self.eval_with( - &|value: &SimpleExpr| value.as_constant()?.naive_eval(), - &|_, _| None, - ) + self.eval_with(&|value: &SimpleExpr| value.as_constant()?.naive_eval(), &|_, _| None) } pub fn eval_with(&self, value_fn: &ValueFn, array_fn: &ArrayFn) -> Option @@ -289,14 +280,8 @@ impl Expression { { match self { Self::Value(value) => value_fn(value), - Self::ArrayAccess { array, index } => { - array_fn(array, index.eval_with(value_fn, array_fn)?) 
- } - Self::Binary { - left, - operation, - right, - } => Some(operation.eval( + Self::ArrayAccess { array, index } => array_fn(array, index.eval_with(value_fn, array_fn)?), + Self::Binary { left, operation, right } => Some(operation.eval( left.eval_with(value_fn, array_fn)?, right.eval_with(value_fn, array_fn)?, )), @@ -401,11 +386,7 @@ impl Display for Expression { Self::ArrayAccess { array, index } => { write!(f, "{array}[{index}]") } - Self::Binary { - left, - operation, - right, - } => { + Self::Binary { left, operation, right } => { write!(f, "({left} {operation} {right})") } Self::Log2Ceil { value } => { @@ -441,11 +422,7 @@ impl Line { Self::Assignment { var, value } => { format!("{var} = {value}") } - Self::ArrayAssign { - array, - index, - value, - } => { + Self::ArrayAssign { array, index, value } => { format!("{array}[{index}] = {value}") } Self::Assert(condition, _line_number) => format!("assert {condition}"), @@ -470,9 +447,7 @@ impl Line { if else_branch.is_empty() { format!("if {condition} {{\n{then_str}\n{spaces}}}") } else { - format!( - "if {condition} {{\n{then_str}\n{spaces}}} else {{\n{else_str}\n{spaces}}}" - ) + format!("if {condition} {{\n{then_str}\n{spaces}}} else {{\n{else_str}\n{spaces}}}") } } Self::CounterHint { var } => { @@ -509,11 +484,7 @@ impl Line { return_data, line_number: _, } => { - let args_str = args - .iter() - .map(|arg| format!("{arg}")) - .collect::>() - .join(", "); + let args_str = args.iter().map(|arg| format!("{arg}")).collect::>().join(", "); let return_data_str = return_data .iter() .map(|var| var.to_string()) @@ -541,21 +512,11 @@ impl Line { format!( "{}({})", precompile.name(), - args.iter() - .map(|arg| format!("{arg}")) - .collect::>() - .join(", ") + args.iter().map(|arg| format!("{arg}")).collect::>().join(", ") ) } - Self::Print { - line_info: _, - content, - } => { - let content_str = content - .iter() - .map(|c| format!("{c}")) - .collect::>() - .join(", "); + Self::Print { line_info: _, content } => { + let content_str = content.iter().map(|c| format!("{c}")).collect::>().join(", "); format!("print({content_str})") } Self::MAlloc { @@ -638,10 +599,7 @@ impl Display for SimpleExpr { match self { Self::Var(var) => write!(f, "{var}"), Self::Constant(constant) => write!(f, "{constant}"), - Self::ConstMallocAccess { - malloc_label, - offset, - } => { + Self::ConstMallocAccess { malloc_label, offset } => { write!(f, "malloc_access({malloc_label}, {offset})") } } @@ -652,11 +610,7 @@ impl Display for ConstExpression { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { match self { Self::Value(value) => write!(f, "{value}"), - Self::Binary { - left, - operation, - right, - } => { + Self::Binary { left, operation, right } => { write!(f, "({left} {operation} {right})") } Self::Log2Ceil { value } => { @@ -706,11 +660,7 @@ impl Display for Function { .join("\n"); if self.body.is_empty() { - write!( - f, - "fn {}({}) -> {} {{}}", - self.name, args_str, self.n_returned_vars - ) + write!(f, "fn {}({}) -> {} {{}}", self.name, args_str, self.n_returned_vars) } else { write!( f, diff --git a/crates/lean_compiler/src/parser/grammar.rs b/crates/lean_compiler/src/parser/grammar.rs index 78e25f68..a236c99c 100644 --- a/crates/lean_compiler/src/parser/grammar.rs +++ b/crates/lean_compiler/src/parser/grammar.rs @@ -19,10 +19,7 @@ pub fn get_location(pair: &ParsePair<'_>) -> (usize, usize) { } /// Utility function to safely get the next inner element from a parser. 
-pub fn next_inner<'i>( - mut pairs: impl Iterator>, - expected: &str, -) -> Option> { +pub fn next_inner<'i>(mut pairs: impl Iterator>, expected: &str) -> Option> { pairs.next().or_else(|| { eprintln!("Warning: Expected {expected} but found nothing"); None diff --git a/crates/lean_compiler/src/parser/parsers/expression.rs b/crates/lean_compiler/src/parser/parsers/expression.rs index 39c5c70d..6d0c0d6e 100644 --- a/crates/lean_compiler/src/parser/parsers/expression.rs +++ b/crates/lean_compiler/src/parser/parsers/expression.rs @@ -19,30 +19,14 @@ impl Parse for ExpressionParser { let inner = next_inner_pair(&mut pair.into_inner(), "expression body")?; Self::parse(inner, ctx) } - Rule::neq_expr => { - BinaryExpressionParser::parse_with_op(pair, ctx, HighLevelOperation::NotEqual) - } - Rule::eq_expr => { - BinaryExpressionParser::parse_with_op(pair, ctx, HighLevelOperation::Equal) - } - Rule::add_expr => { - BinaryExpressionParser::parse_with_op(pair, ctx, HighLevelOperation::Add) - } - Rule::sub_expr => { - BinaryExpressionParser::parse_with_op(pair, ctx, HighLevelOperation::Sub) - } - Rule::mul_expr => { - BinaryExpressionParser::parse_with_op(pair, ctx, HighLevelOperation::Mul) - } - Rule::mod_expr => { - BinaryExpressionParser::parse_with_op(pair, ctx, HighLevelOperation::Mod) - } - Rule::div_expr => { - BinaryExpressionParser::parse_with_op(pair, ctx, HighLevelOperation::Div) - } - Rule::exp_expr => { - BinaryExpressionParser::parse_with_op(pair, ctx, HighLevelOperation::Exp) - } + Rule::neq_expr => BinaryExpressionParser::parse_with_op(pair, ctx, HighLevelOperation::NotEqual), + Rule::eq_expr => BinaryExpressionParser::parse_with_op(pair, ctx, HighLevelOperation::Equal), + Rule::add_expr => BinaryExpressionParser::parse_with_op(pair, ctx, HighLevelOperation::Add), + Rule::sub_expr => BinaryExpressionParser::parse_with_op(pair, ctx, HighLevelOperation::Sub), + Rule::mul_expr => BinaryExpressionParser::parse_with_op(pair, ctx, HighLevelOperation::Mul), + Rule::mod_expr => BinaryExpressionParser::parse_with_op(pair, ctx, HighLevelOperation::Mod), + Rule::div_expr => BinaryExpressionParser::parse_with_op(pair, ctx, HighLevelOperation::Div), + Rule::exp_expr => BinaryExpressionParser::parse_with_op(pair, ctx, HighLevelOperation::Exp), Rule::primary => PrimaryExpressionParser::parse(pair, ctx), other_rule => Err(ParseError::SemanticError(SemanticError::new(format!( "ExpressionParser: Unexpected rule {other_rule:?}" @@ -102,9 +86,7 @@ pub struct ArrayAccessParser; impl Parse for ArrayAccessParser { fn parse(pair: ParsePair<'_>, ctx: &mut ParseContext) -> ParseResult { let mut inner = pair.into_inner(); - let array = next_inner_pair(&mut inner, "array name")? 
- .as_str() - .to_string(); + let array = next_inner_pair(&mut inner, "array name")?.as_str().to_string(); let index = ExpressionParser::parse(next_inner_pair(&mut inner, "array index")?, ctx)?; Ok(Expression::ArrayAccess { @@ -122,8 +104,6 @@ impl Parse for Log2CeilParser { let mut inner = pair.into_inner(); let expr = ExpressionParser::parse(next_inner_pair(&mut inner, "log2_ceil value")?, ctx)?; - Ok(Expression::Log2Ceil { - value: Box::new(expr), - }) + Ok(Expression::Log2Ceil { value: Box::new(expr) }) } } diff --git a/crates/lean_compiler/src/parser/parsers/function.rs b/crates/lean_compiler/src/parser/parsers/function.rs index c09a8a6d..4e43477e 100644 --- a/crates/lean_compiler/src/parser/parsers/function.rs +++ b/crates/lean_compiler/src/parser/parsers/function.rs @@ -18,9 +18,7 @@ pub struct FunctionParser; impl Parse for FunctionParser { fn parse(pair: ParsePair<'_>, ctx: &mut ParseContext) -> ParseResult { let mut inner = pair.into_inner(); - let name = next_inner_pair(&mut inner, "function name")? - .as_str() - .to_string(); + let name = next_inner_pair(&mut inner, "function name")?.as_str().to_string(); let mut arguments = Vec::new(); let mut n_returned_vars = 0; @@ -122,13 +120,7 @@ impl Parse for FunctionCallParser { if res_item.as_rule() == Rule::var_list { return_data = VarListParser::parse(res_item, ctx)? .into_iter() - .filter_map(|v| { - if let SimpleExpr::Var(var) = v { - Some(var) - } else { - None - } - }) + .filter_map(|v| if let SimpleExpr::Var(var) = v { Some(var) } else { None }) .collect(); } } @@ -190,9 +182,7 @@ impl FunctionCallParser { } "print" => { if !return_data.is_empty() { - return Err( - SemanticError::new("Print function should not return values").into(), - ); + return Err(SemanticError::new("Print function should not return values").into()); } Ok(Line::Print { line_info: function_name.clone(), @@ -227,9 +217,7 @@ impl FunctionCallParser { } "panic" => { if !return_data.is_empty() || !args.is_empty() { - return Err( - SemanticError::new("Panic has no args and returns no values").into(), - ); + return Err(SemanticError::new("Panic has no args and returns no values").into()); } Ok(Line::Panic) } diff --git a/crates/lean_compiler/src/parser/parsers/literal.rs b/crates/lean_compiler/src/parser/parsers/literal.rs index aac154af..df7d4cf3 100644 --- a/crates/lean_compiler/src/parser/parsers/literal.rs +++ b/crates/lean_compiler/src/parser/parsers/literal.rs @@ -17,9 +17,7 @@ pub struct ConstantDeclarationParser; impl Parse<(String, usize)> for ConstantDeclarationParser { fn parse(pair: ParsePair<'_>, ctx: &mut ParseContext) -> ParseResult<(String, usize)> { let mut inner = pair.into_inner(); - let name = next_inner_pair(&mut inner, "constant name")? - .as_str() - .to_string(); + let name = next_inner_pair(&mut inner, "constant name")?.as_str().to_string(); let value_pair = next_inner_pair(&mut inner, "constant value")?; // Parse the expression and evaluate it @@ -35,10 +33,7 @@ impl Parse<(String, usize)> for ConstantDeclarationParser { &|_, _| None, ) .ok_or_else(|| { - SemanticError::with_context( - format!("Failed to evaluate constant: {name}"), - "constant declaration", - ) + SemanticError::with_context(format!("Failed to evaluate constant: {name}"), "constant declaration") })? 
.to_usize(); @@ -58,9 +53,7 @@ impl Parse for VarOrConstantParser { let inner = pair.into_inner().next().unwrap(); Self::parse(inner, ctx) } - Rule::identifier | Rule::constant_value => { - Self::parse_identifier_or_constant(text, ctx) - } + Rule::identifier | Rule::constant_value => Self::parse_identifier_or_constant(text, ctx), _ => Err(SemanticError::new("Expected identifier or constant").into()), } } @@ -82,15 +75,15 @@ impl VarOrConstantParser { _ => { // Try to resolve as defined constant if let Some(value) = ctx.get_constant(text) { - Ok(SimpleExpr::Constant(ConstExpression::Value( - ConstantValue::Scalar(value), - ))) + Ok(SimpleExpr::Constant(ConstExpression::Value(ConstantValue::Scalar( + value, + )))) } // Try to parse as numeric literal else if let Ok(value) = text.parse::() { - Ok(SimpleExpr::Constant(ConstExpression::Value( - ConstantValue::Scalar(value), - ))) + Ok(SimpleExpr::Constant(ConstExpression::Value(ConstantValue::Scalar( + value, + )))) } // Otherwise treat as variable reference else { @@ -112,10 +105,9 @@ impl Parse for ConstExprParser { Rule::constant_value => { let text = inner.as_str(); match text { - "public_input_start" => Err(SemanticError::new( - "public_input_start cannot be used as match pattern", - ) - .into()), + "public_input_start" => { + Err(SemanticError::new("public_input_start cannot be used as match pattern").into()) + } _ => { if let Some(value) = ctx.get_constant(text) { Ok(value) @@ -132,10 +124,7 @@ impl Parse for ConstExprParser { } } _ => Err(SemanticError::with_context( - format!( - "Only constant values are allowed in match patterns: {}", - inner.as_str() - ), + format!("Only constant values are allowed in match patterns: {}", inner.as_str()), "match pattern", ) .into()), diff --git a/crates/lean_compiler/src/parser/parsers/mod.rs b/crates/lean_compiler/src/parser/parsers/mod.rs index 1288b559..ffd64c54 100644 --- a/crates/lean_compiler/src/parser/parsers/mod.rs +++ b/crates/lean_compiler/src/parser/parsers/mod.rs @@ -68,12 +68,7 @@ pub fn expect_rule(pair: &ParsePair<'_>, expected: Rule) -> ParseResult<()> { if pair.as_rule() == expected { Ok(()) } else { - Err(SemanticError::new(format!( - "Expected {:?} but found {:?}", - expected, - pair.as_rule() - )) - .into()) + Err(SemanticError::new(format!("Expected {:?} but found {:?}", expected, pair.as_rule())).into()) } } diff --git a/crates/lean_compiler/src/parser/parsers/program.rs b/crates/lean_compiler/src/parser/parsers/program.rs index 5bb3b6e2..e2106023 100644 --- a/crates/lean_compiler/src/parser/parsers/program.rs +++ b/crates/lean_compiler/src/parser/parsers/program.rs @@ -14,10 +14,7 @@ use std::collections::BTreeMap; pub struct ProgramParser; impl Parse<(Program, BTreeMap)> for ProgramParser { - fn parse( - pair: ParsePair<'_>, - _ctx: &mut ParseContext, - ) -> ParseResult<(Program, BTreeMap)> { + fn parse(pair: ParsePair<'_>, _ctx: &mut ParseContext) -> ParseResult<(Program, BTreeMap)> { let mut ctx = ParseContext::new(); let mut functions = BTreeMap::new(); let mut function_locations = BTreeMap::new(); diff --git a/crates/lean_compiler/src/parser/parsers/statement.rs b/crates/lean_compiler/src/parser/parsers/statement.rs index de055693..eb13e585 100644 --- a/crates/lean_compiler/src/parser/parsers/statement.rs +++ b/crates/lean_compiler/src/parser/parsers/statement.rs @@ -29,9 +29,7 @@ impl Parse for StatementParser { Rule::assert_eq_statement => AssertEqParser::parse(inner, ctx), Rule::assert_not_eq_statement => AssertNotEqParser::parse(inner, ctx), Rule::break_statement => 
Ok(Line::Break), - Rule::continue_statement => { - Err(SemanticError::new("Continue statement not implemented yet").into()) - } + Rule::continue_statement => Err(SemanticError::new("Continue statement not implemented yet").into()), _ => Err(SemanticError::new("Unknown statement").into()), } } @@ -43,9 +41,7 @@ pub struct AssignmentParser; impl Parse for AssignmentParser { fn parse(pair: ParsePair<'_>, ctx: &mut ParseContext) -> ParseResult { let mut inner = pair.into_inner(); - let var = next_inner_pair(&mut inner, "variable name")? - .as_str() - .to_string(); + let var = next_inner_pair(&mut inner, "variable name")?.as_str().to_string(); let expr = next_inner_pair(&mut inner, "assignment value")?; let value = ExpressionParser::parse(expr, ctx)?; @@ -59,9 +55,7 @@ pub struct ArrayAssignParser; impl Parse for ArrayAssignParser { fn parse(pair: ParsePair<'_>, ctx: &mut ParseContext) -> ParseResult { let mut inner = pair.into_inner(); - let array = next_inner_pair(&mut inner, "array name")? - .as_str() - .to_string(); + let array = next_inner_pair(&mut inner, "array name")?.as_str().to_string(); let index = ExpressionParser::parse(next_inner_pair(&mut inner, "array index")?, ctx)?; let value = ExpressionParser::parse(next_inner_pair(&mut inner, "array value")?, ctx)?; @@ -133,11 +127,8 @@ impl Parse for ConditionParser { fn parse(pair: ParsePair<'_>, ctx: &mut ParseContext) -> ParseResult { let inner_pair = next_inner_pair(&mut pair.into_inner(), "inner expression")?; if inner_pair.as_rule() == Rule::assumed_bool_expr { - ExpressionParser::parse( - next_inner_pair(&mut inner_pair.into_inner(), "inner expression")?, - ctx, - ) - .map(|e| Condition::Expression(e, AssumeBoolean::AssumeBoolean)) + ExpressionParser::parse(next_inner_pair(&mut inner_pair.into_inner(), "inner expression")?, ctx) + .map(|e| Condition::Expression(e, AssumeBoolean::AssumeBoolean)) } else { let expr_result = ExpressionParser::parse(inner_pair, ctx); match expr_result { @@ -158,10 +149,7 @@ impl Parse for ConditionParser { left: *left, right: *right, })), - Ok(expr) => Ok(Condition::Expression( - expr, - AssumeBoolean::DoNotAssumeBoolean, - )), + Ok(expr) => Ok(Condition::Expression(expr, AssumeBoolean::DoNotAssumeBoolean)), } } } @@ -174,9 +162,7 @@ impl Parse for ForStatementParser { fn parse(pair: ParsePair<'_>, ctx: &mut ParseContext) -> ParseResult { let line_number = pair.line_col().0; let mut inner = pair.into_inner(); - let iterator = next_inner_pair(&mut inner, "loop iterator")? 
- .as_str() - .to_string(); + let iterator = next_inner_pair(&mut inner, "loop iterator")?.as_str().to_string(); // Check for optional reverse clause let mut rev = false; @@ -321,9 +307,6 @@ impl Parse for AssertNotEqParser { let left = ExpressionParser::parse(next_inner_pair(&mut inner, "left assertion")?, ctx)?; let right = ExpressionParser::parse(next_inner_pair(&mut inner, "right assertion")?, ctx)?; - Ok(Line::Assert( - Boolean::Different { left, right }, - line_number, - )) + Ok(Line::Assert(Boolean::Different { left, right }, line_number)) } } diff --git a/crates/lean_compiler/tests/test_compiler.rs b/crates/lean_compiler/tests/test_compiler.rs index 2016e296..cf9edafa 100644 --- a/crates/lean_compiler/tests/test_compiler.rs +++ b/crates/lean_compiler/tests/test_compiler.rs @@ -21,12 +21,7 @@ fn test_duplicate_function_name() { return; } "#; - compile_and_run( - program.to_string(), - (&[], &[]), - DEFAULT_NO_VEC_RUNTIME_MEMORY, - false, - ); + compile_and_run(program.to_string(), (&[], &[]), DEFAULT_NO_VEC_RUNTIME_MEMORY, false); } #[test] @@ -40,12 +35,7 @@ fn test_duplicate_constant_name() { return; } "#; - compile_and_run( - program.to_string(), - (&[], &[]), - DEFAULT_NO_VEC_RUNTIME_MEMORY, - false, - ); + compile_and_run(program.to_string(), (&[], &[]), DEFAULT_NO_VEC_RUNTIME_MEMORY, false); } #[test] @@ -66,12 +56,7 @@ fn test_fibonacci_program() { return; } "#; - compile_and_run( - program.to_string(), - (&[], &[]), - DEFAULT_NO_VEC_RUNTIME_MEMORY, - false, - ); + compile_and_run(program.to_string(), (&[], &[]), DEFAULT_NO_VEC_RUNTIME_MEMORY, false); } #[test] @@ -89,12 +74,7 @@ fn test_edge_case_0() { return; } "#; - compile_and_run( - program.to_string(), - (&[], &[]), - DEFAULT_NO_VEC_RUNTIME_MEMORY, - false, - ); + compile_and_run(program.to_string(), (&[], &[]), DEFAULT_NO_VEC_RUNTIME_MEMORY, false); } #[test] @@ -107,12 +87,7 @@ fn test_edge_case_1() { return; } "#; - compile_and_run( - program.to_string(), - (&[], &[]), - DEFAULT_NO_VEC_RUNTIME_MEMORY, - false, - ); + compile_and_run(program.to_string(), (&[], &[]), DEFAULT_NO_VEC_RUNTIME_MEMORY, false); } #[test] @@ -130,12 +105,7 @@ fn test_edge_case_2() { return; } "#; - compile_and_run( - program.to_string(), - (&[], &[]), - DEFAULT_NO_VEC_RUNTIME_MEMORY, - false, - ); + compile_and_run(program.to_string(), (&[], &[]), DEFAULT_NO_VEC_RUNTIME_MEMORY, false); } #[test] @@ -151,12 +121,7 @@ fn test_decompose_bits() { return; } "#; - compile_and_run( - program.to_string(), - (&[], &[]), - DEFAULT_NO_VEC_RUNTIME_MEMORY, - false, - ); + compile_and_run(program.to_string(), (&[], &[]), DEFAULT_NO_VEC_RUNTIME_MEMORY, false); } #[test] @@ -172,12 +137,7 @@ fn test_unroll() { return; } "#; - compile_and_run( - program.to_string(), - (&[], &[]), - DEFAULT_NO_VEC_RUNTIME_MEMORY, - false, - ); + compile_and_run(program.to_string(), (&[], &[]), DEFAULT_NO_VEC_RUNTIME_MEMORY, false); } #[test] @@ -189,12 +149,7 @@ fn test_rev_unroll() { return; } "#; - compile_and_run( - program.to_string(), - (&[], &[]), - DEFAULT_NO_VEC_RUNTIME_MEMORY, - false, - ); + compile_and_run(program.to_string(), (&[], &[]), DEFAULT_NO_VEC_RUNTIME_MEMORY, false); } #[test] @@ -214,12 +169,7 @@ fn test_mini_program_0() { return; } "#; - compile_and_run( - program.to_string(), - (&[], &[]), - DEFAULT_NO_VEC_RUNTIME_MEMORY, - false, - ); + compile_and_run(program.to_string(), (&[], &[]), DEFAULT_NO_VEC_RUNTIME_MEMORY, false); } #[test] @@ -262,12 +212,7 @@ fn test_mini_program_1() { return; } "#; - compile_and_run( - program.to_string(), - (&[], &[]), - 
DEFAULT_NO_VEC_RUNTIME_MEMORY, - false, - ); + compile_and_run(program.to_string(), (&[], &[]), DEFAULT_NO_VEC_RUNTIME_MEMORY, false); } #[test] @@ -295,12 +240,7 @@ fn test_mini_program_2() { return sum, product; } "#; - compile_and_run( - program.to_string(), - (&[], &[]), - DEFAULT_NO_VEC_RUNTIME_MEMORY, - false, - ); + compile_and_run(program.to_string(), (&[], &[]), DEFAULT_NO_VEC_RUNTIME_MEMORY, false); } #[test] @@ -438,12 +378,7 @@ fn test_inlined() { return; } "#; - compile_and_run( - program.to_string(), - (&[], &[]), - DEFAULT_NO_VEC_RUNTIME_MEMORY, - false, - ); + compile_and_run(program.to_string(), (&[], &[]), DEFAULT_NO_VEC_RUNTIME_MEMORY, false); } #[test] @@ -495,12 +430,7 @@ fn test_match() { return x * x * x * x * x * x; } "#; - compile_and_run( - program.to_string(), - (&[], &[]), - DEFAULT_NO_VEC_RUNTIME_MEMORY, - false, - ); + compile_and_run(program.to_string(), (&[], &[]), DEFAULT_NO_VEC_RUNTIME_MEMORY, false); } // #[test] @@ -541,12 +471,7 @@ fn test_const_functions_calling_const_functions() { } "#; - compile_and_run( - program.to_string(), - (&[], &[]), - DEFAULT_NO_VEC_RUNTIME_MEMORY, - false, - ); + compile_and_run(program.to_string(), (&[], &[]), DEFAULT_NO_VEC_RUNTIME_MEMORY, false); } #[test] @@ -569,12 +494,7 @@ fn test_inline_functions_calling_inline_functions() { } "#; - compile_and_run( - program.to_string(), - (&[], &[]), - DEFAULT_NO_VEC_RUNTIME_MEMORY, - false, - ); + compile_and_run(program.to_string(), (&[], &[]), DEFAULT_NO_VEC_RUNTIME_MEMORY, false); } #[test] @@ -601,10 +521,5 @@ fn test_nested_inline_functions() { } "#; - compile_and_run( - program.to_string(), - (&[], &[]), - DEFAULT_NO_VEC_RUNTIME_MEMORY, - false, - ); + compile_and_run(program.to_string(), (&[], &[]), DEFAULT_NO_VEC_RUNTIME_MEMORY, false); } diff --git a/crates/lean_prover/src/common.rs b/crates/lean_prover/src/common.rs index 35f73c70..5adeb3fd 100644 --- a/crates/lean_prover/src/common.rs +++ b/crates/lean_prover/src/common.rs @@ -27,21 +27,11 @@ pub(crate) fn get_base_dims( ], p16_default_cubes .iter() - .map(|&c| { - ColDims::padded( - table_heights[Table::poseidon16().index()].n_rows_non_padded_maxed(), - c, - ) - }) + .map(|&c| ColDims::padded(table_heights[Table::poseidon16().index()].n_rows_non_padded_maxed(), c)) .collect::>(), // commited cubes for poseidon16 p24_default_cubes .iter() - .map(|&c| { - ColDims::padded( - table_heights[Table::poseidon24().index()].n_rows_non_padded_maxed(), - c, - ) - }) + .map(|&c| ColDims::padded(table_heights[Table::poseidon24().index()].n_rows_non_padded_maxed(), c)) .collect::>(), ] .concat(); @@ -51,10 +41,7 @@ pub(crate) fn get_base_dims( dims } -pub(crate) fn fold_bytecode( - bytecode: &Bytecode, - folding_challenges: &MultilinearPoint, -) -> Vec { +pub(crate) fn fold_bytecode(bytecode: &Bytecode, folding_challenges: &MultilinearPoint) -> Vec { let encoded_bytecode = padd_with_zero_to_next_power_of_two( &bytecode .instructions @@ -65,13 +52,9 @@ pub(crate) fn fold_bytecode( fold_multilinear_chunks(&encoded_bytecode, folding_challenges) } -pub(crate) fn initial_and_final_pc_conditions( - log_n_cycles: usize, -) -> (Evaluation, Evaluation) { - let initial_pc_statement = - Evaluation::new(EF::zero_vec(log_n_cycles), EF::from_usize(STARTING_PC)); - let final_pc_statement = - Evaluation::new(vec![EF::ONE; log_n_cycles], EF::from_usize(ENDING_PC)); +pub(crate) fn initial_and_final_pc_conditions(log_n_cycles: usize) -> (Evaluation, Evaluation) { + let initial_pc_statement = Evaluation::new(EF::zero_vec(log_n_cycles), 
EF::from_usize(STARTING_PC)); + let final_pc_statement = Evaluation::new(vec![EF::ONE; log_n_cycles], EF::from_usize(ENDING_PC)); (initial_pc_statement, final_pc_statement) } @@ -81,9 +64,7 @@ fn split_at(stmt: &MultiEvaluation, start: usize, end: usize) -> Vec Vec>> { +pub(crate) fn poseidon_16_vectorized_lookup_statements(p16_gkr: &GKRPoseidonResult) -> Vec>> { vec![ split_at(&p16_gkr.input_statements, 0, VECTOR_LEN), split_at(&p16_gkr.input_statements, VECTOR_LEN, VECTOR_LEN * 2), @@ -92,9 +73,7 @@ pub(crate) fn poseidon_16_vectorized_lookup_statements( ] } -pub(crate) fn poseidon_24_vectorized_lookup_statements( - p24_gkr: &GKRPoseidonResult, -) -> Vec>> { +pub(crate) fn poseidon_24_vectorized_lookup_statements(p24_gkr: &GKRPoseidonResult) -> Vec>> { vec![ split_at(&p24_gkr.input_statements, 0, VECTOR_LEN), split_at(&p24_gkr.input_statements, VECTOR_LEN, VECTOR_LEN * 2), diff --git a/crates/lean_prover/src/prove_execution.rs b/crates/lean_prover/src/prove_execution.rs index aeac0f86..2c90fea0 100644 --- a/crates/lean_prover/src/prove_execution.rs +++ b/crates/lean_prover/src/prove_execution.rs @@ -41,8 +41,7 @@ pub fn prove_execution( ) }); exec_summary = std::mem::take(&mut execution_result.summary); - info_span!("Building execution trace") - .in_scope(|| get_execution_trace(bytecode, execution_result)) + info_span!("Building execution trace").in_scope(|| get_execution_trace(bytecode, execution_result)) }); if memory.len() < 1 << MIN_LOG_MEMORY_SIZE { @@ -87,10 +86,7 @@ pub fn prove_execution( prover_state.add_base_scalars( &[ vec![private_memory.len()], - traces - .iter() - .map(|t| t.n_rows_non_padded()) - .collect::>(), + traces.iter().map(|t| t.n_rows_non_padded()).collect::>(), ] .concat() .into_iter() @@ -120,9 +116,7 @@ pub fn prove_execution( ] .concat(); for i in 0..N_TABLES { - base_pols.extend( - ALL_TABLES[i].committed_columns(&traces[i], commitmenent_extension_helper[i].as_ref()), - ); + base_pols.extend(ALL_TABLES[i].committed_columns(&traces[i], commitmenent_extension_helper[i].as_ref())); } // 1st Commitment @@ -134,8 +128,7 @@ pub fn prove_execution( LOG_SMALLEST_DECOMPOSITION_CHUNK, ); - let random_point_p16 = - MultilinearPoint(prover_state.sample_vec(traces[Table::poseidon16().index()].log_padded())); + let random_point_p16 = MultilinearPoint(prover_state.sample_vec(traces[Table::poseidon16().index()].log_padded())); let p16_gkr = prove_poseidon_gkr( &mut prover_state, &p16_witness, @@ -144,8 +137,7 @@ pub fn prove_execution( &p16_gkr_layers, ); - let random_point_p24 = - MultilinearPoint(prover_state.sample_vec(traces[Table::poseidon24().index()].log_padded())); + let random_point_p24 = MultilinearPoint(prover_state.sample_vec(traces[Table::poseidon24().index()].log_padded())); let p24_gkr = prove_poseidon_gkr( &mut prover_state, &p24_witness, @@ -182,10 +174,8 @@ pub fn prove_execution( let bytecode_lookup_claim_1 = Evaluation::new( air_points[Table::execution().index()].clone(), - padd_with_zero_to_next_power_of_two( - &evals_f[Table::execution().index()][..N_INSTRUCTION_COLUMNS], - ) - .evaluate(&bytecode_compression_challenges), + padd_with_zero_to_next_power_of_two(&evals_f[Table::execution().index()][..N_INSTRUCTION_COLUMNS]) + .evaluate(&bytecode_compression_challenges), ); let bytecode_poly_eq_point = eval_eq(&air_points[Table::execution().index()]); let bytecode_pushforward = compute_pushforward( @@ -204,14 +194,10 @@ pub fn prove_execution( .flat_map(|i| ALL_TABLES[i].normal_lookup_index_columns_ef(&traces[i])) .collect(), (0..N_TABLES) - .flat_map(|i| { 
- vec![traces[i].n_rows_non_padded_maxed(); ALL_TABLES[i].num_normal_lookups_f()] - }) + .flat_map(|i| vec![traces[i].n_rows_non_padded_maxed(); ALL_TABLES[i].num_normal_lookups_f()]) .collect(), (0..N_TABLES) - .flat_map(|i| { - vec![traces[i].n_rows_non_padded_maxed(); ALL_TABLES[i].num_normal_lookups_ef()] - }) + .flat_map(|i| vec![traces[i].n_rows_non_padded_maxed(); ALL_TABLES[i].num_normal_lookups_ef()]) .collect(), (0..N_TABLES) .flat_map(|i| ALL_TABLES[i].normal_lookup_default_indexes_f()) @@ -241,9 +227,7 @@ pub fn prove_execution( .flat_map(|i| ALL_TABLES[i].vector_lookup_index_columns(&traces[i])) .collect(), (0..N_TABLES) - .flat_map(|i| { - vec![traces[i].n_rows_non_padded_maxed(); ALL_TABLES[i].num_vector_lookups()] - }) + .flat_map(|i| vec![traces[i].n_rows_non_padded_maxed(); ALL_TABLES[i].num_vector_lookups()]) .collect(), (0..N_TABLES) .flat_map(|i| ALL_TABLES[i].vector_lookup_default_indexes()) @@ -262,10 +246,8 @@ pub fn prove_execution( statements.extend(poseidon_24_vectorized_lookup_statements(&p24_gkr)); // special case continue; } - statements.extend(table.vectorized_lookups_statements( - &air_points[table.index()], - &evals_f[table.index()], - )); + statements + .extend(table.vectorized_lookups_statements(&air_points[table.index()], &evals_f[table.index()])); } statements }, @@ -297,11 +279,9 @@ pub fn prove_execution( LOG_SMALLEST_DECOMPOSITION_CHUNK, ); - let mut normal_lookup_statements = - normal_lookup_into_memory.step_2(&mut prover_state, non_zero_memory_size); + let mut normal_lookup_statements = normal_lookup_into_memory.step_2(&mut prover_state, non_zero_memory_size); - let vectorized_lookup_statements = - vectorized_lookup_into_memory.step_2(&mut prover_state, non_zero_memory_size); + let vectorized_lookup_statements = vectorized_lookup_into_memory.step_2(&mut prover_state, non_zero_memory_size); let bytecode_logup_star_statements = prove_logup_star( &mut prover_state, @@ -348,13 +328,13 @@ pub fn prove_execution( let (initial_pc_statement, final_pc_statement) = initial_and_final_pc_conditions(traces[Table::execution().index()].log_padded()); - final_statements[Table::execution().index()] - [ExecutionTable.find_committed_column_index_f(COL_INDEX_PC)] - .extend(vec![ - bytecode_logup_star_statements.on_indexes.clone(), - initial_pc_statement, - final_pc_statement, - ]); + final_statements[Table::execution().index()][ExecutionTable.find_committed_column_index_f(COL_INDEX_PC)].extend( + vec![ + bytecode_logup_star_statements.on_indexes.clone(), + initial_pc_statement, + final_pc_statement, + ], + ); // First Opening let mut all_base_statements = [ @@ -438,35 +418,28 @@ fn prove_bus_and_air( let mut denominators = unsafe { uninitialized_vec(n_buses_padded * n_rows) }; for (bus, denomniators_chunk) in t.buses().iter().zip(denominators.chunks_exact_mut(n_rows)) { - denomniators_chunk - .par_iter_mut() - .enumerate() - .for_each(|(i, v)| { - *v = bus_challenge - + finger_print( - match &bus.table { - BusTable::Constant(table) => table.embed(), - BusTable::Variable(col) => trace.base[*col][i], - }, - bus.data - .iter() - .map(|col| trace.base[*col][i]) - .collect::>() - .as_slice(), - fingerprint_challenge, - ); - }); + denomniators_chunk.par_iter_mut().enumerate().for_each(|(i, v)| { + *v = bus_challenge + + finger_print( + match &bus.table { + BusTable::Constant(table) => table.embed(), + BusTable::Variable(col) => trace.base[*col][i], + }, + bus.data + .iter() + .map(|col| trace.base[*col][i]) + .collect::>() + .as_slice(), + fingerprint_challenge, + ); + }); 
} denominators[n_rows * n_buses..] .par_iter_mut() .for_each(|v| *v = EF::ONE); // TODO avoid embedding !! - let numerators_embedded = numerators - .par_iter() - .copied() - .map(EF::from) - .collect::>(); + let numerators_embedded = numerators.par_iter().copied().map(EF::from).collect::>(); // TODO avoid reallocation due to packing (pack directly when constructing) let numerators_packed = pack_extension(&numerators_embedded); @@ -490,11 +463,7 @@ fn prove_bus_and_air( let sub_numerators_evals = numerators .par_chunks_exact(1 << (log_n_rows - UNIVARIATE_SKIPS)) .take(n_buses << UNIVARIATE_SKIPS) - .map(|chunk| { - chunk.evaluate(&MultilinearPoint( - bus_point_global[1 + log_n_buses..].to_vec(), - )) - }) + .map(|chunk| chunk.evaluate(&MultilinearPoint(bus_point_global[1 + log_n_buses..].to_vec()))) .collect::>(); prover_state.add_extension_scalars(&sub_numerators_evals); // sanity check: @@ -511,11 +480,7 @@ fn prove_bus_and_air( let sub_denominators_evals = denominators .par_chunks_exact(1 << (log_n_rows - UNIVARIATE_SKIPS)) .take(n_buses << UNIVARIATE_SKIPS) - .map(|chunk| { - chunk.evaluate(&MultilinearPoint( - bus_point_global[1 + log_n_buses..].to_vec(), - )) - }) + .map(|chunk| chunk.evaluate(&MultilinearPoint(bus_point_global[1 + log_n_buses..].to_vec()))) .collect::>(); prover_state.add_extension_scalars(&sub_denominators_evals); // sanity check: @@ -530,31 +495,15 @@ fn prove_bus_and_air( ); let epsilon = prover_state.sample(); - let bus_point = MultilinearPoint( - [vec![epsilon], bus_point_global[1 + log_n_buses..].to_vec()].concat(), - ); + let bus_point = MultilinearPoint([vec![epsilon], bus_point_global[1 + log_n_buses..].to_vec()].concat()); let bus_selector_values = sub_numerators_evals .chunks_exact(1 << UNIVARIATE_SKIPS) - .map(|chunk| { - evaluate_univariate_multilinear::<_, _, _, false>( - chunk, - &[epsilon], - &uni_selectors, - None, - ) - }) + .map(|chunk| evaluate_univariate_multilinear::<_, _, _, false>(chunk, &[epsilon], &uni_selectors, None)) .collect(); let bus_data_values = sub_denominators_evals .chunks_exact(1 << UNIVARIATE_SKIPS) - .map(|chunk| { - evaluate_univariate_multilinear::<_, _, _, false>( - chunk, - &[epsilon], - &uni_selectors, - None, - ) - }) + .map(|chunk| evaluate_univariate_multilinear::<_, _, _, false>(chunk, &[epsilon], &uni_selectors, None)) .collect(); (bus_point, bus_selector_values, bus_data_values) @@ -579,8 +528,7 @@ fn prove_bus_and_air( let bus_virtual_statement = MultiEvaluation::new(bus_point, bus_final_values); for bus in t.buses() { - quotient -= - bus.padding_contribution(t, trace.padding_len(), bus_challenge, fingerprint_challenge); + quotient -= bus.padding_contribution(t, trace.padding_len(), bus_challenge, fingerprint_challenge); } let extra_data = ExtraDataForBuses { @@ -588,26 +536,25 @@ fn prove_bus_and_air( bus_beta, alpha_powers: vec![], // filled later }; - let (air_point, evals_f, evals_ef) = - info_span!("Table AIR proof", table = t.name()).in_scope(|| { - macro_rules! prove_air_for_table { - ($t:expr) => { - prove_air( - prover_state, - $t, - extra_data, - UNIVARIATE_SKIPS, - &trace.base[..$t.n_columns_f_air()], - &trace.ext[..$t.n_columns_ef_air()], - &$t.air_padding_row_f(), - &$t.air_padding_row_ef(), - Some(bus_virtual_statement), - $t.n_columns_air() + $t.total_n_down_columns_air() > 5, // heuristic - ) - }; - } - delegate_to_inner!(t => prove_air_for_table) - }); + let (air_point, evals_f, evals_ef) = info_span!("Table AIR proof", table = t.name()).in_scope(|| { + macro_rules! 
prove_air_for_table { + ($t:expr) => { + prove_air( + prover_state, + $t, + extra_data, + UNIVARIATE_SKIPS, + &trace.base[..$t.n_columns_f_air()], + &trace.ext[..$t.n_columns_ef_air()], + &$t.air_padding_row_f(), + &$t.air_padding_row_ef(), + Some(bus_virtual_statement), + $t.n_columns_air() + $t.total_n_down_columns_air() > 5, // heuristic + ) + }; + } + delegate_to_inner!(t => prove_air_for_table) + }); (quotient, air_point, evals_f, evals_ef) } diff --git a/crates/lean_prover/src/verify_execution.rs b/crates/lean_prover/src/verify_execution.rs index 5dfecd39..e3c79c77 100644 --- a/crates/lean_prover/src/verify_execution.rs +++ b/crates/lean_prover/src/verify_execution.rs @@ -59,9 +59,8 @@ pub fn verify_execution( LOG_SMALLEST_DECOMPOSITION_CHUNK, )?; - let random_point_p16 = MultilinearPoint( - verifier_state.sample_vec(table_heights[Table::poseidon16().index()].log_padded()), - ); + let random_point_p16 = + MultilinearPoint(verifier_state.sample_vec(table_heights[Table::poseidon16().index()].log_padded())); let p16_gkr = verify_poseidon_gkr( &mut verifier_state, table_heights[Table::poseidon16().index()].log_padded(), @@ -71,9 +70,8 @@ pub fn verify_execution( true, ); - let random_point_p24 = MultilinearPoint( - verifier_state.sample_vec(table_heights[Table::poseidon24().index()].log_padded()), - ); + let random_point_p24 = + MultilinearPoint(verifier_state.sample_vec(table_heights[Table::poseidon24().index()].log_padded())); let p24_gkr = verify_poseidon_gkr( &mut verifier_state, table_heights[Table::poseidon24().index()].log_padded(), @@ -110,29 +108,17 @@ pub fn verify_execution( let bytecode_lookup_claim_1 = Evaluation::new( air_points[Table::execution().index()].clone(), - padd_with_zero_to_next_power_of_two( - &evals_f[Table::execution().index()][..N_INSTRUCTION_COLUMNS], - ) - .evaluate(&bytecode_compression_challenges), + padd_with_zero_to_next_power_of_two(&evals_f[Table::execution().index()][..N_INSTRUCTION_COLUMNS]) + .evaluate(&bytecode_compression_challenges), ); let normal_lookup_into_memory = NormalPackedLookupVerifier::step_1( &mut verifier_state, (0..N_TABLES) - .flat_map(|i| { - vec![ - table_heights[i].n_rows_non_padded_maxed(); - ALL_TABLES[i].num_normal_lookups_f() - ] - }) + .flat_map(|i| vec![table_heights[i].n_rows_non_padded_maxed(); ALL_TABLES[i].num_normal_lookups_f()]) .collect(), (0..N_TABLES) - .flat_map(|i| { - vec![ - table_heights[i].n_rows_non_padded_maxed(); - ALL_TABLES[i].num_normal_lookups_ef() - ] - }) + .flat_map(|i| vec![table_heights[i].n_rows_non_padded_maxed(); ALL_TABLES[i].num_normal_lookups_ef()]) .collect(), (0..N_TABLES) .flat_map(|i| ALL_TABLES[i].normal_lookup_default_indexes_f()) @@ -153,9 +139,7 @@ pub fn verify_execution( let vectorized_lookup_into_memory = VectorizedPackedLookupVerifier::<_, VECTOR_LEN>::step_1( &mut verifier_state, (0..N_TABLES) - .flat_map(|i| { - vec![table_heights[i].n_rows_non_padded_maxed(); ALL_TABLES[i].num_vector_lookups()] - }) + .flat_map(|i| vec![table_heights[i].n_rows_non_padded_maxed(); ALL_TABLES[i].num_vector_lookups()]) .collect(), (0..N_TABLES) .flat_map(|i| ALL_TABLES[i].vector_lookup_default_indexes()) @@ -171,10 +155,8 @@ pub fn verify_execution( statements.extend(poseidon_24_vectorized_lookup_statements(&p24_gkr)); // special case continue; } - statements.extend(table.vectorized_lookups_statements( - &air_points[table.index()], - &evals_f[table.index()], - )); + statements + .extend(table.vectorized_lookups_statements(&air_points[table.index()], &evals_f[table.index()])); } statements }, @@ 
-202,11 +184,9 @@ pub fn verify_execution( LOG_SMALLEST_DECOMPOSITION_CHUNK, )?; - let mut normal_lookup_statements = - normal_lookup_into_memory.step_2(&mut verifier_state, log_memory)?; + let mut normal_lookup_statements = normal_lookup_into_memory.step_2(&mut verifier_state, log_memory)?; - let vectorized_lookup_statements = - vectorized_lookup_into_memory.step_2(&mut verifier_state, log_memory)?; + let vectorized_lookup_statements = vectorized_lookup_into_memory.step_2(&mut verifier_state, log_memory)?; let bytecode_logup_star_statements = verify_logup_star( &mut verifier_state, @@ -256,13 +236,13 @@ pub fn verify_execution( let (initial_pc_statement, final_pc_statement) = initial_and_final_pc_conditions(table_heights[Table::execution().index()].log_padded()); - final_statements[Table::execution().index()] - [ExecutionTable.find_committed_column_index_f(COL_INDEX_PC)] - .extend(vec![ - bytecode_logup_star_statements.on_indexes.clone(), - initial_pc_statement, - final_pc_statement, - ]); + final_statements[Table::execution().index()][ExecutionTable.find_committed_column_index_f(COL_INDEX_PC)].extend( + vec![ + bytecode_logup_star_statements.on_indexes.clone(), + initial_pc_statement, + final_pc_statement, + ], + ); let mut all_base_statements = [ vec![memory_statements], @@ -317,10 +297,7 @@ fn verify_bus_and_air( assert!(n_buses > 0, "Table {} has no buses", t.name()); let (mut quotient, bus_point_global, numerator_value_global, denominator_value_global) = - verify_gkr_quotient::<_, TWO_POW_UNIVARIATE_SKIPS>( - verifier_state, - log_n_rows + log_n_buses, - )?; + verify_gkr_quotient::<_, TWO_POW_UNIVARIATE_SKIPS>(verifier_state, log_n_rows + log_n_buses)?; let (bus_point, bus_selector_values, bus_data_values) = if n_buses == 1 { // easy case @@ -332,8 +309,7 @@ fn verify_bus_and_air( } else { let uni_selectors = univariate_selectors::(UNIVARIATE_SKIPS); - let sub_numerators_evals = - verifier_state.next_extension_scalars_vec(n_buses << UNIVARIATE_SKIPS)?; + let sub_numerators_evals = verifier_state.next_extension_scalars_vec(n_buses << UNIVARIATE_SKIPS)?; assert_eq!( numerator_value_global, evaluate_univariate_multilinear::<_, _, _, false>( @@ -344,8 +320,7 @@ fn verify_bus_and_air( ), ); - let sub_denominators_evals = - verifier_state.next_extension_scalars_vec(n_buses << UNIVARIATE_SKIPS)?; + let sub_denominators_evals = verifier_state.next_extension_scalars_vec(n_buses << UNIVARIATE_SKIPS)?; assert_eq!( denominator_value_global, evaluate_univariate_multilinear::<_, _, _, false>( @@ -356,31 +331,15 @@ fn verify_bus_and_air( ), ); let epsilon = verifier_state.sample(); - let bus_point = MultilinearPoint( - [vec![epsilon], bus_point_global[1 + log_n_buses..].to_vec()].concat(), - ); + let bus_point = MultilinearPoint([vec![epsilon], bus_point_global[1 + log_n_buses..].to_vec()].concat()); let bus_selector_values = sub_numerators_evals .chunks_exact(1 << UNIVARIATE_SKIPS) - .map(|chunk| { - evaluate_univariate_multilinear::<_, _, _, false>( - chunk, - &[epsilon], - &uni_selectors, - None, - ) - }) + .map(|chunk| evaluate_univariate_multilinear::<_, _, _, false>(chunk, &[epsilon], &uni_selectors, None)) .collect(); let bus_data_values = sub_denominators_evals .chunks_exact(1 << UNIVARIATE_SKIPS) - .map(|chunk| { - evaluate_univariate_multilinear::<_, _, _, false>( - chunk, - &[epsilon], - &uni_selectors, - None, - ) - }) + .map(|chunk| evaluate_univariate_multilinear::<_, _, _, false>(chunk, &[epsilon], &uni_selectors, None)) .collect(); (bus_point, bus_selector_values, bus_data_values) 
@@ -405,12 +364,7 @@ fn verify_bus_and_air( let bus_virtual_statement = MultiEvaluation::new(bus_point, bus_final_values); for bus in t.buses() { - quotient -= bus.padding_contribution( - t, - table_height.padding_len(), - bus_challenge, - fingerprint_challenge, - ); + quotient -= bus.padding_contribution(t, table_height.padding_len(), bus_challenge, fingerprint_challenge); } let extra_data = ExtraDataForBuses { diff --git a/crates/lean_prover/tests/hash_chain.rs b/crates/lean_prover/tests/hash_chain.rs index c375df64..04794701 100644 --- a/crates/lean_prover/tests/hash_chain.rs +++ b/crates/lean_prover/tests/hash_chain.rs @@ -1,7 +1,5 @@ use lean_compiler::*; -use lean_prover::{ - prove_execution::prove_execution, verify_execution::verify_execution, whir_config_builder, -}; +use lean_prover::{prove_execution::prove_execution, verify_execution::verify_execution, whir_config_builder}; use lean_vm::{F, execute_bytecode}; use multilinear_toolkit::prelude::*; use std::time::Instant; @@ -50,10 +48,7 @@ fn benchmark_poseidon_chain() { const LOG_CHAIN_LENGTH: usize = 17; const CHAIN_LENGTH: usize = 1 << LOG_CHAIN_LENGTH; - let program_str = program_str.replace( - "LOG_CHAIN_LENGTH_PLACEHOLDER", - &LOG_CHAIN_LENGTH.to_string(), - ); + let program_str = program_str.replace("LOG_CHAIN_LENGTH_PLACEHOLDER", &LOG_CHAIN_LENGTH.to_string()); let mut public_input = F::zero_vec(1 << 13); public_input[0..8].copy_from_slice(&iterate_hash(&Default::default(), CHAIN_LENGTH)); diff --git a/crates/lean_prover/tests/test_zkvm.rs b/crates/lean_prover/tests/test_zkvm.rs index b3d31a43..8ee67ecd 100644 --- a/crates/lean_prover/tests/test_zkvm.rs +++ b/crates/lean_prover/tests/test_zkvm.rs @@ -1,7 +1,5 @@ use lean_compiler::*; -use lean_prover::{ - prove_execution::prove_execution, verify_execution::verify_execution, whir_config_builder, -}; +use lean_prover::{prove_execution::prove_execution, verify_execution::verify_execution, whir_config_builder}; use lean_vm::*; use multilinear_toolkit::prelude::*; use rand::{Rng, SeedableRng, rngs::StdRng}; @@ -62,19 +60,11 @@ fn test_zk_vm_all_precompiles() { .flat_map(|&x| x.as_basis_coefficients_slice().to_vec()) .collect::>(), ); - let dot_product_base_ext: EF = dot_product( - dot_product_slice_ext_a.into_iter(), - dot_product_slice_base.into_iter(), - ); - let dot_product_ext_ext: EF = dot_product( - dot_product_slice_ext_a.into_iter(), - dot_product_slice_ext_b.into_iter(), - ); + let dot_product_base_ext: EF = dot_product(dot_product_slice_ext_a.into_iter(), dot_product_slice_base.into_iter()); + let dot_product_ext_ext: EF = dot_product(dot_product_slice_ext_a.into_iter(), dot_product_slice_ext_b.into_iter()); - public_input[1000..][..DIMENSION] - .copy_from_slice(dot_product_base_ext.as_basis_coefficients_slice()); - public_input[1000 + DIMENSION..][..DIMENSION] - .copy_from_slice(dot_product_ext_ext.as_basis_coefficients_slice()); + public_input[1000..][..DIMENSION].copy_from_slice(dot_product_base_ext.as_basis_coefficients_slice()); + public_input[1000 + DIMENSION..][..DIMENSION].copy_from_slice(dot_product_ext_ext.as_basis_coefficients_slice()); test_zk_vm_helper(program_str, (&public_input, &[]), 0); } @@ -121,11 +111,7 @@ fn test_prove_fibonacci() { test_zk_vm_helper(&program_str, (&[F::ZERO; 1 << 14], &[]), 0); } -fn test_zk_vm_helper( - program_str: &str, - (public_input, private_input): (&[F], &[F]), - no_vec_runtime_memory: usize, -) { +fn test_zk_vm_helper(program_str: &str, (public_input, private_input): (&[F], &[F]), no_vec_runtime_memory: usize) { 
utils::init_tracing(); let bytecode = compile_program(program_str.to_string()); let time = std::time::Instant::now(); diff --git a/crates/lean_prover/witness_generation/src/execution_trace.rs b/crates/lean_prover/witness_generation/src/execution_trace.rs index 7504d6fe..83b9b048 100644 --- a/crates/lean_prover/witness_generation/src/execution_trace.rs +++ b/crates/lean_prover/witness_generation/src/execution_trace.rs @@ -12,10 +12,7 @@ pub struct ExecutionTrace { pub memory: Vec, // of length a multiple of public_memory_size } -pub fn get_execution_trace( - bytecode: &Bytecode, - mut execution_result: ExecutionResult, -) -> ExecutionTrace { +pub fn get_execution_trace(bytecode: &Bytecode, mut execution_result: ExecutionResult) -> ExecutionTrace { assert_eq!(execution_result.pcs.len(), execution_result.fps.len()); // padding to make proof work even on small programs (TODO make this more elegant) @@ -78,8 +75,8 @@ pub fn get_execution_trace( + (F::ONE - field_repr[COL_INDEX_FLAG_A]) * value_a; let nu_b = field_repr[COL_INDEX_FLAG_B] * field_repr[COL_INDEX_OPERAND_B] + (F::ONE - field_repr[COL_INDEX_FLAG_B]) * value_b; - let nu_c = field_repr[COL_INDEX_FLAG_C] * F::from_usize(fp) - + (F::ONE - field_repr[COL_INDEX_FLAG_C]) * value_c; + let nu_c = + field_repr[COL_INDEX_FLAG_C] * F::from_usize(fp) + (F::ONE - field_repr[COL_INDEX_FLAG_C]) * value_c; *trace_row[COL_INDEX_EXEC_NU_A] = nu_a; *trace_row[COL_INDEX_EXEC_NU_B] = nu_b; *trace_row[COL_INDEX_EXEC_NU_C] = nu_c; @@ -94,11 +91,7 @@ pub fn get_execution_trace( *trace_row[COL_INDEX_MEM_ADDRESS_C] = addr_c; }); - let mut memory_padded = memory - .0 - .par_iter() - .map(|&v| v.unwrap_or(F::ZERO)) - .collect::>(); + let mut memory_padded = memory.0.par_iter().map(|&v| v.unwrap_or(F::ZERO)).collect::>(); memory_padded.resize(memory.0.len().next_power_of_two(), F::ZERO); let ExecutionResult { mut traces, .. 
} = execution_result; diff --git a/crates/lean_prover/witness_generation/src/instruction_encoder.rs b/crates/lean_prover/witness_generation/src/instruction_encoder.rs index c14029bd..92d3412a 100644 --- a/crates/lean_prover/witness_generation/src/instruction_encoder.rs +++ b/crates/lean_prover/witness_generation/src/instruction_encoder.rs @@ -23,11 +23,7 @@ pub fn field_representation(instr: &Instruction) -> [F; N_INSTRUCTION_COLUMNS] { set_nu_b(&mut fields, res); set_nu_c(&mut fields, arg_c); } - Instruction::Deref { - shift_0, - shift_1, - res, - } => { + Instruction::Deref { shift_0, shift_1, res } => { fields[COL_INDEX_DEREF] = F::ONE; fields[COL_INDEX_FLAG_A] = F::ZERO; fields[COL_INDEX_OPERAND_A] = F::from_usize(*shift_0); diff --git a/crates/lean_prover/witness_generation/src/poseidon_tables.rs b/crates/lean_prover/witness_generation/src/poseidon_tables.rs index 7555b7b7..c9f343f4 100644 --- a/crates/lean_prover/witness_generation/src/poseidon_tables.rs +++ b/crates/lean_prover/witness_generation/src/poseidon_tables.rs @@ -20,8 +20,7 @@ where let inputs: [_; WIDTH] = array::from_fn(|i| &trace.base[start_index + i][..]); let n_poseidons = inputs[0].len(); assert!(n_poseidons.is_power_of_two()); - let inputs_packed: [_; WIDTH] = - array::from_fn(|i| PFPacking::::pack_slice(inputs[i]).to_vec()); // TODO avoid cloning + let inputs_packed: [_; WIDTH] = array::from_fn(|i| PFPacking::::pack_slice(inputs[i]).to_vec()); // TODO avoid cloning generate_poseidon_witness::, WIDTH, N_COMMITED_CUBES>( inputs_packed, layers, diff --git a/crates/lean_vm/src/core/label.rs b/crates/lean_vm/src/core/label.rs index e4224299..1ff8c0ef 100644 --- a/crates/lean_vm/src/core/label.rs +++ b/crates/lean_vm/src/core/label.rs @@ -59,11 +59,7 @@ pub enum AuxKind { /// @inlined_var_{count}_{var} InlinedVar { count: usize, var: String }, /// @unrolled_{index}_{value}_{var} - UnrolledVar { - index: usize, - value: usize, - var: String, - }, + UnrolledVar { index: usize, value: usize, var: String }, /// @incremented_{var} Incremented(String), /// @trash_{id} @@ -75,11 +71,7 @@ impl std::fmt::Display for Label { match self { Self::Function(name) => write!(f, "@function_{name}"), Self::EndProgram => write!(f, "@end_program"), - Self::If { - id, - kind, - line_number, - } => match kind { + Self::If { id, kind, line_number } => match kind { IfKind::If => write!(f, "@if_{id}_line_{line_number}"), IfKind::Else => write!(f, "@else_{id}_line_{line_number}"), IfKind::End => write!(f, "@if_else_end_{id}_line_{line_number}"), @@ -180,10 +172,7 @@ impl Label { pub fn inlined_var(count: usize, var: impl Into) -> Self { Self::AuxVar { - kind: AuxKind::InlinedVar { - count, - var: var.into(), - }, + kind: AuxKind::InlinedVar { count, var: var.into() }, id: 0, // Not used for this variant } } diff --git a/crates/lean_vm/src/diagnostics/profiler.rs b/crates/lean_vm/src/diagnostics/profiler.rs index 633ed3ad..3451328b 100644 --- a/crates/lean_vm/src/diagnostics/profiler.rs +++ b/crates/lean_vm/src/diagnostics/profiler.rs @@ -33,9 +33,7 @@ pub(crate) fn profiling_report( } else { // New function call call_stack.push(current_function_name.clone()); - let stats = function_stats - .entry(current_function_name.clone()) - .or_default(); + let stats = function_stats.entry(current_function_name.clone()).or_default(); stats.call_count += 1; } prev_function_name = current_function_name.clone(); @@ -58,14 +56,9 @@ pub(crate) fn profiling_report( let mut report = String::new(); - report.push_str( - 
"\n╔═════════════════════════════════════════════════════════════════════════╗\n", - ); - report - .push_str("║ PROFILING REPORT ║\n"); - report.push_str( - "╚═════════════════════════════════════════════════════════════════════════╝\n\n", - ); + report.push_str("\n╔═════════════════════════════════════════════════════════════════════════╗\n"); + report.push_str("║ PROFILING REPORT ║\n"); + report.push_str("╚═════════════════════════════════════════════════════════════════════════╝\n\n"); report.push_str("──────────────────────────────────────────────────────────────────────────\n"); report.push_str(" │ Exclusive │ Inclusive │ \n"); @@ -155,10 +148,7 @@ pub(crate) fn memory_profiling_report(profile: &MemoryProfile) -> String { report.push_str("= DETAILED MEMORY PROFILING =\n"); report.push_str("============================================\n"); report.push('\n'); - report.push_str(&format!( - "Total memory footprint: {}\n", - pretty_integer(footprint) - )); + report.push_str(&format!("Total memory footprint: {}\n", pretty_integer(footprint))); report.push_str(&format!( "Total allocated memory: {} ({:.2}% of footprint)\n", pretty_integer(allocated), @@ -218,9 +208,7 @@ pub(crate) fn memory_profiling_report(profile: &MemoryProfile) -> String { report } -fn function_allocations( - profile: &MemoryProfile, -) -> BTreeMap> { +fn function_allocations(profile: &MemoryProfile) -> BTreeMap> { let mut allocations = BTreeMap::new(); for (addr, object) in profile.objects.iter() { @@ -306,10 +294,7 @@ fn all_allocated_memory(profile: &MemoryProfile) -> BTreeSet { } /// Get the number of used memory addresses which are not allocated. -fn count_used_but_not_allocated( - used: &BTreeSet, - allocated: &BTreeSet, -) -> usize { +fn count_used_but_not_allocated(used: &BTreeSet, allocated: &BTreeSet) -> usize { let diff = BTreeSet::from_iter(used.difference(allocated)); let len = diff.len(); if len > 0 { @@ -322,9 +307,6 @@ fn count_used_but_not_allocated( } /// Get the number of allocated memory addresses which are not used. 
-fn count_allocated_but_not_used( - used: &BTreeSet, - allocated: &BTreeSet, -) -> usize { +fn count_allocated_but_not_used(used: &BTreeSet, allocated: &BTreeSet) -> usize { allocated.difference(used).count() } diff --git a/crates/lean_vm/src/diagnostics/stack_trace.rs b/crates/lean_vm/src/diagnostics/stack_trace.rs index e52b56f1..e9100de7 100644 --- a/crates/lean_vm/src/diagnostics/stack_trace.rs +++ b/crates/lean_vm/src/diagnostics/stack_trace.rs @@ -17,26 +17,18 @@ pub(crate) fn pretty_stack_trace( let mut prev_function_line = usize::MAX; let mut skipped_lines: usize = 0; // Track skipped lines for current function - result - .push_str("╔═════════════════════════════════════════════════════════════════════════╗\n"); - result - .push_str("║ STACK TRACE ║\n"); - result.push_str( - "╚═════════════════════════════════════════════════════════════════════════╝\n\n", - ); + result.push_str("╔═════════════════════════════════════════════════════════════════════════╗\n"); + result.push_str("║ STACK TRACE ║\n"); + result.push_str("╚═════════════════════════════════════════════════════════════════════════╝\n\n"); for (idx, &line_num) in instructions.iter().enumerate() { - let (current_function_line, current_function_name) = - find_function_for_line(line_num, function_locations); + let (current_function_line, current_function_name) = find_function_for_line(line_num, function_locations); if prev_function_line != current_function_line { assert_eq!(skipped_lines, 0); // Check if we're returning to a previous function or calling a new one - if let Some(pos) = call_stack - .iter() - .position(|(_, f)| f == &current_function_name) - { + if let Some(pos) = call_stack.iter().position(|(_, f)| f == &current_function_name) { // Returning to a previous function - pop the stack while call_stack.len() > pos + 1 { call_stack.pop(); @@ -64,12 +56,8 @@ pub(crate) fn pretty_stack_trace( true } else { // Count remaining lines in this function - let remaining_in_function = count_remaining_lines_in_function( - idx, - instructions, - function_locations, - current_function_line, - ); + let remaining_in_function = + count_remaining_lines_in_function(idx, instructions, function_locations, current_function_line); remaining_in_function < STACK_TRACE_MAX_LINES_PER_FUNCTION }; @@ -78,9 +66,7 @@ pub(crate) fn pretty_stack_trace( // Show skipped lines message if transitioning from skipping to showing if skipped_lines > 0 { let indent = "│ ".repeat(call_stack.len()); - result.push_str(&format!( - "{indent}├─ ... ({skipped_lines} lines skipped) ...\n" - )); + result.push_str(&format!("{indent}├─ ... 
({skipped_lines} lines skipped) ...\n")); skipped_lines = 0; } @@ -118,10 +104,7 @@ pub(crate) fn pretty_stack_trace( result } -pub(crate) fn find_function_for_line( - line_num: usize, - function_locations: &BTreeMap, -) -> (usize, String) { +pub(crate) fn find_function_for_line(line_num: usize, function_locations: &BTreeMap) -> (usize, String) { function_locations .range(..=line_num) .next_back() diff --git a/crates/lean_vm/src/execution/context.rs b/crates/lean_vm/src/execution/context.rs index 17a78a3b..b30be5eb 100644 --- a/crates/lean_vm/src/execution/context.rs +++ b/crates/lean_vm/src/execution/context.rs @@ -40,11 +40,7 @@ pub struct ExecutionContext<'a> { } impl<'a> ExecutionContext<'a> { - pub fn new( - source_code: &'a str, - function_locations: &'a BTreeMap, - profiler_enabled: bool, - ) -> Self { + pub fn new(source_code: &'a str, function_locations: &'a BTreeMap, profiler_enabled: bool) -> Self { Self { source_code, function_locations, diff --git a/crates/lean_vm/src/execution/memory.rs b/crates/lean_vm/src/execution/memory.rs index 7830d99a..aa19397d 100644 --- a/crates/lean_vm/src/execution/memory.rs +++ b/crates/lean_vm/src/execution/memory.rs @@ -25,11 +25,7 @@ impl Memory { /// /// Returns an error if the address is uninitialized pub fn get(&self, index: usize) -> Result { - self.0 - .get(index) - .copied() - .flatten() - .ok_or(RunnerError::UndefinedMemory) + self.0.get(index).copied().flatten().ok_or(RunnerError::UndefinedMemory) } /// Sets a value at a memory address @@ -85,9 +81,7 @@ impl Memory { index: usize, // normal pointer len: usize, ) -> Result, RunnerError> { - (0..len) - .map(|i| self.get_ef_element(index + i * DIMENSION)) - .collect() + (0..len).map(|i| self.get_ef_element(index + i * DIMENSION)).collect() } /// Set an extension field element in memory diff --git a/crates/lean_vm/src/execution/runner.rs b/crates/lean_vm/src/execution/runner.rs index a8ca32eb..2abb0bbe 100644 --- a/crates/lean_vm/src/execution/runner.rs +++ b/crates/lean_vm/src/execution/runner.rs @@ -1,16 +1,16 @@ //! 
VM execution runner use crate::core::{ - DIMENSION, F, NONRESERVED_PROGRAM_INPUT_START, ONE_VEC_PTR, POSEIDON_16_NULL_HASH_PTR, - POSEIDON_24_NULL_HASH_PTR, VECTOR_LEN, ZERO_VEC_PTR, + DIMENSION, F, NONRESERVED_PROGRAM_INPUT_START, ONE_VEC_PTR, POSEIDON_16_NULL_HASH_PTR, POSEIDON_24_NULL_HASH_PTR, + VECTOR_LEN, ZERO_VEC_PTR, }; use crate::diagnostics::{ExecutionResult, MemoryProfile, RunnerError, memory_profiling_report}; use crate::execution::{ExecutionHistory, Memory}; use crate::isa::Bytecode; use crate::isa::instruction::InstructionContext; use crate::{ - ALL_TABLES, CodeAddress, ENDING_PC, HintExecutionContext, N_TABLES, STARTING_PC, - SourceLineNumber, Table, TableTrace, + ALL_TABLES, CodeAddress, ENDING_PC, HintExecutionContext, N_TABLES, STARTING_PC, SourceLineNumber, Table, + TableTrace, }; use multilinear_toolkit::prelude::*; use std::array; @@ -24,38 +24,26 @@ const STACK_TRACE_INSTRUCTIONS: usize = 5000; /// Build public memory with standard initialization pub fn build_public_memory(public_input: &[F]) -> Vec { // padded to a power of two - let public_memory_len = - (NONRESERVED_PROGRAM_INPUT_START + public_input.len()).next_power_of_two(); + let public_memory_len = (NONRESERVED_PROGRAM_INPUT_START + public_input.len()).next_power_of_two(); let mut public_memory = F::zero_vec(public_memory_len); - public_memory[NONRESERVED_PROGRAM_INPUT_START..][..public_input.len()] - .copy_from_slice(public_input); + public_memory[NONRESERVED_PROGRAM_INPUT_START..][..public_input.len()].copy_from_slice(public_input); // "zero" vector let zero_start = ZERO_VEC_PTR * VECTOR_LEN; - for slot in public_memory - .iter_mut() - .skip(zero_start) - .take(2 * VECTOR_LEN) - { + for slot in public_memory.iter_mut().skip(zero_start).take(2 * VECTOR_LEN) { *slot = F::ZERO; } // "one" vector public_memory[ONE_VEC_PTR * VECTOR_LEN] = F::ONE; let one_start = ONE_VEC_PTR * VECTOR_LEN + 1; - for slot in public_memory - .iter_mut() - .skip(one_start) - .take(VECTOR_LEN - 1) - { + for slot in public_memory.iter_mut().skip(one_start).take(VECTOR_LEN - 1) { *slot = F::ZERO; } - public_memory - [POSEIDON_16_NULL_HASH_PTR * VECTOR_LEN..(POSEIDON_16_NULL_HASH_PTR + 2) * VECTOR_LEN] + public_memory[POSEIDON_16_NULL_HASH_PTR * VECTOR_LEN..(POSEIDON_16_NULL_HASH_PTR + 2) * VECTOR_LEN] .copy_from_slice(&poseidon16_permute([F::ZERO; 16])); - public_memory - [POSEIDON_24_NULL_HASH_PTR * VECTOR_LEN..(POSEIDON_24_NULL_HASH_PTR + 1) * VECTOR_LEN] + public_memory[POSEIDON_24_NULL_HASH_PTR * VECTOR_LEN..(POSEIDON_24_NULL_HASH_PTR + 1) * VECTOR_LEN] .copy_from_slice(&poseidon24_permute([F::ZERO; 24])[16..]); public_memory } @@ -84,8 +72,7 @@ pub fn execute_bytecode( ) .unwrap_or_else(|err| { let lines_history = &instruction_history.lines; - let latest_instructions = - &lines_history[lines_history.len().saturating_sub(STACK_TRACE_INSTRUCTIONS)..]; + let latest_instructions = &lines_history[lines_history.len().saturating_sub(STACK_TRACE_INSTRUCTIONS)..]; println!( "\n{}", crate::diagnostics::pretty_stack_trace( @@ -163,8 +150,7 @@ fn execute_bytecode_helper( // set public memory let mut memory = Memory::new(build_public_memory(public_input)); - let public_memory_size = - (NONRESERVED_PROGRAM_INPUT_START + public_input.len()).next_power_of_two(); + let public_memory_size = (NONRESERVED_PROGRAM_INPUT_START + public_input.len()).next_power_of_two(); let mut fp = public_memory_size; for (i, value) in private_input.iter().enumerate() { @@ -173,8 +159,7 @@ fn execute_bytecode_helper( let mut mem_profile = MemoryProfile { used: 
BTreeSet::new(), - public_inputs: NONRESERVED_PROGRAM_INPUT_START - ..NONRESERVED_PROGRAM_INPUT_START + public_memory_size, + public_inputs: NONRESERVED_PROGRAM_INPUT_START..NONRESERVED_PROGRAM_INPUT_START + public_memory_size, private_inputs: fp..fp + private_input.len(), objects: BTreeMap::new(), }; @@ -183,8 +168,7 @@ fn execute_bytecode_helper( fp = fp.next_multiple_of(DIMENSION); let initial_ap = fp + bytecode.starting_frame_memory; - let initial_ap_vec = - (initial_ap + no_vec_runtime_memory).next_multiple_of(VECTOR_LEN) / VECTOR_LEN; + let initial_ap_vec = (initial_ap + no_vec_runtime_memory).next_multiple_of(VECTOR_LEN) / VECTOR_LEN; let mut pc = STARTING_PC; let mut ap = initial_ap; @@ -299,49 +283,32 @@ fn execute_bytecode_helper( let mut summary = String::new(); if profiling { - let report = - crate::diagnostics::profiling_report(instruction_history, &bytecode.function_locations); + let report = crate::diagnostics::profiling_report(instruction_history, &bytecode.function_locations); summary.push_str(&report); } if !std_out.is_empty() { - summary.push_str( - "╔═════════════════════════════════════════════════════════════════════════╗\n", - ); - summary.push_str( - "║ STD-OUT ║\n", - ); - summary.push_str( - "╚═════════════════════════════════════════════════════════════════════════╝\n", - ); + summary.push_str("╔═════════════════════════════════════════════════════════════════════════╗\n"); + summary.push_str("║ STD-OUT ║\n"); + summary.push_str("╚═════════════════════════════════════════════════════════════════════════╝\n"); summary.push_str(&format!("\n{std_out}")); - summary.push_str( - "──────────────────────────────────────────────────────────────────────────\n\n", - ); + summary.push_str("──────────────────────────────────────────────────────────────────────────\n\n"); } - summary - .push_str("╔═════════════════════════════════════════════════════════════════════════╗\n"); - summary - .push_str("║ STATS ║\n"); - summary.push_str( - "╚═════════════════════════════════════════════════════════════════════════╝\n\n", - ); + summary.push_str("╔═════════════════════════════════════════════════════════════════════════╗\n"); + summary.push_str("║ STATS ║\n"); + summary.push_str("╚═════════════════════════════════════════════════════════════════════════╝\n\n"); summary.push_str(&format!("CYCLES: {}\n", pretty_integer(cpu_cycles))); summary.push_str(&format!("MEMORY: {}\n", pretty_integer(memory.0.len()))); summary.push('\n'); - let runtime_memory_size = - memory.0.len() - (NONRESERVED_PROGRAM_INPUT_START + public_input.len()); + let runtime_memory_size = memory.0.len() - (NONRESERVED_PROGRAM_INPUT_START + public_input.len()); summary.push_str(&format!( "Bytecode size: {}\n", pretty_integer(bytecode.instructions.len()) )); - summary.push_str(&format!( - "Public input size: {}\n", - pretty_integer(public_input.len()) - )); + summary.push_str(&format!("Public input size: {}\n", pretty_integer(public_input.len()))); summary.push_str(&format!( "Private input size: {}\n", pretty_integer(private_input.len()) @@ -365,10 +332,7 @@ fn execute_bytecode_helper( summary.push('\n'); - if traces[Table::poseidon16().index()].base[0].len() - + traces[Table::poseidon24().index()].base[0].len() - > 0 - { + if traces[Table::poseidon16().index()].base[0].len() + traces[Table::poseidon24().index()].base[0].len() > 0 { summary.push_str(&format!( "Poseidon2_16 calls: {}, Poseidon2_24 calls: {}, (1 poseidon per {} instructions)\n", pretty_integer(traces[Table::poseidon16().index()].base[0].len()), @@ -397,8 
+361,7 @@ fn execute_bytecode_helper( summary.push_str(&format!("JUMP: {jump_counts}\n")); } - summary - .push_str("──────────────────────────────────────────────────────────────────────────\n"); + summary.push_str("──────────────────────────────────────────────────────────────────────────\n"); if profiling { for (addr, val) in (0..).zip(memory.0.iter()) { diff --git a/crates/lean_vm/src/execution/tests.rs b/crates/lean_vm/src/execution/tests.rs index 428299d7..ca03ba2c 100644 --- a/crates/lean_vm/src/execution/tests.rs +++ b/crates/lean_vm/src/execution/tests.rs @@ -25,10 +25,7 @@ fn test_memory_already_set_error() { memory.set(0, F::ONE).unwrap(); // Setting different value should fail - assert!(matches!( - memory.set(0, F::ZERO), - Err(RunnerError::MemoryAlreadySet) - )); + assert!(matches!(memory.set(0, F::ZERO), Err(RunnerError::MemoryAlreadySet))); } #[test] diff --git a/crates/lean_vm/src/isa/hint.rs b/crates/lean_vm/src/isa/hint.rs index 68cccfc0..09a5d034 100644 --- a/crates/lean_vm/src/isa/hint.rs +++ b/crates/lean_vm/src/isa/hint.rs @@ -132,8 +132,7 @@ impl Hint { } } else { let allocation_start_addr = *ctx.ap; - ctx.memory - .set(ctx.fp + *offset, F::from_usize(allocation_start_addr))?; + ctx.memory.set(ctx.fp + *offset, F::from_usize(allocation_start_addr))?; *ctx.ap += size; if ctx.profiling { @@ -179,8 +178,7 @@ impl Hint { } } Self::CounterHint { res_offset } => { - ctx.memory - .set(ctx.fp + *res_offset, F::from_usize(*ctx.counter_hint))?; + ctx.memory.set(ctx.fp + *res_offset, F::from_usize(*ctx.counter_hint))?; *ctx.counter_hint += 1; } Self::Inverse { arg, res_offset } => { @@ -206,8 +204,7 @@ impl Hint { values[1], pretty_integer(ctx.cpu_cycles - *ctx.last_checkpoint_cpu_cycles), pretty_integer(new_no_vec_memory + new_vec_memory), - new_vec_memory as f64 / (new_no_vec_memory + new_vec_memory) as f64 - * 100.0 + new_vec_memory as f64 / (new_no_vec_memory + new_vec_memory) as f64 * 100.0 ); } @@ -255,10 +252,7 @@ impl Display for Hint { vectorized_len, } => { if *vectorized { - write!( - f, - "m[fp + {offset}] = request_memory_vec({size}, {vectorized_len})" - ) + write!(f, "m[fp + {offset}] = request_memory_vec({size}, {vectorized_len})") } else { write!(f, "m[fp + {offset}] = request_memory({size})") } @@ -305,9 +299,7 @@ impl Display for Hint { Self::Inverse { arg, res_offset } => { write!(f, "m[fp + {res_offset}] = inverse({arg})") } - Self::LocationReport { - location: line_number, - } => { + Self::LocationReport { location: line_number } => { write!(f, "source line number: {line_number}") } Self::Label { label } => { diff --git a/crates/lean_vm/src/isa/instruction.rs b/crates/lean_vm/src/isa/instruction.rs index 613aa42a..2e2cd141 100644 --- a/crates/lean_vm/src/isa/instruction.rs +++ b/crates/lean_vm/src/isa/instruction.rs @@ -126,11 +126,7 @@ impl Instruction { *ctx.pc += 1; Ok(()) } - Self::Deref { - shift_0, - shift_1, - res, - } => { + Self::Deref { shift_0, shift_1, res } => { if res.is_value_unknown(ctx.memory, *ctx.fp) { let memory_address_res = res.memory_address(*ctx.fp)?; let ptr = ctx.memory.get(*ctx.fp + shift_0)?; @@ -198,11 +194,7 @@ impl Display for Instruction { } => { write!(f, "{res} = {arg_a} {operation} {arg_c}") } - Self::Deref { - shift_0, - shift_1, - res, - } => { + Self::Deref { shift_0, shift_1, res } => { write!(f, "{res} = m[m[fp + {shift_0}] + {shift_1}]") } Self::Jump { diff --git a/crates/lean_vm/src/tables/dot_product/air.rs b/crates/lean_vm/src/tables/dot_product/air.rs index e6deafe5..693ef9b9 100644 --- 
a/crates/lean_vm/src/tables/dot_product/air.rs +++ b/crates/lean_vm/src/tables/dot_product/air.rs @@ -1,6 +1,5 @@ use crate::{ - DIMENSION, EF, ExtraDataForBuses, TableT, eval_virtual_bus_column, - tables::dot_product::DotProductPrecompile, + DIMENSION, EF, ExtraDataForBuses, TableT, eval_virtual_bus_column, tables::dot_product::DotProductPrecompile, }; use multilinear_toolkit::prelude::*; use p3_air::{Air, AirBuilder}; @@ -123,8 +122,7 @@ impl Air for DotProductPrecompile { builder.assert_zero(flag_down * (len - AB::F::ONE)); let index_a_increment = AB::F::from_usize(if BE { 1 } else { DIMENSION }); builder.assert_zero(not_flag_down.clone() * (index_a - (index_a_down - index_a_increment))); - builder - .assert_zero(not_flag_down * (index_b - (index_b_down - AB::F::from_usize(DIMENSION)))); + builder.assert_zero(not_flag_down * (index_b - (index_b_down - AB::F::from_usize(DIMENSION)))); builder.assert_zero_ef((computation - res) * flag); } diff --git a/crates/lean_vm/src/tables/dot_product/exec.rs b/crates/lean_vm/src/tables/dot_product/exec.rs index 5a5994ae..e774b523 100644 --- a/crates/lean_vm/src/tables/dot_product/exec.rs +++ b/crates/lean_vm/src/tables/dot_product/exec.rs @@ -18,8 +18,7 @@ pub(super) fn exec_dot_product_be( let slice_0 = memory.slice(ptr_arg_0.to_usize(), size)?; let slice_1 = memory.get_continuous_slice_of_ef_elements(ptr_arg_1.to_usize(), size)?; - let dot_product_result = - dot_product::(slice_1.iter().copied(), slice_0.iter().copied()); + let dot_product_result = dot_product::(slice_1.iter().copied(), slice_0.iter().copied()); memory.set_ef_element(ptr_res.to_usize(), dot_product_result)?; @@ -38,12 +37,10 @@ pub(super) fn exec_dot_product_be( trace.base[DOT_PRODUCT_AIR_COL_FLAG].push(F::ONE); trace.base[DOT_PRODUCT_AIR_COL_FLAG].extend(F::zero_vec(size - 1)); trace.base[DOT_PRODUCT_AIR_COL_LEN].extend(((1..=size).rev()).map(F::from_usize)); - trace.base[DOT_PRODUCT_AIR_COL_INDEX_A] - .extend((0..size).map(|i| F::from_usize(ptr_arg_0.to_usize() + i))); + trace.base[DOT_PRODUCT_AIR_COL_INDEX_A].extend((0..size).map(|i| F::from_usize(ptr_arg_0.to_usize() + i))); trace.base[DOT_PRODUCT_AIR_COL_INDEX_B] .extend((0..size).map(|i| F::from_usize(ptr_arg_1.to_usize() + i * DIMENSION))); - trace.base[DOT_PRODUCT_AIR_COL_INDEX_RES] - .extend(vec![F::from_usize(ptr_res.to_usize()); size]); + trace.base[DOT_PRODUCT_AIR_COL_INDEX_RES].extend(vec![F::from_usize(ptr_res.to_usize()); size]); trace.base[dot_product_air_col_value_a(true)].extend(slice_0); trace.ext[DOT_PRODUCT_AIR_COL_VALUE_B].extend(slice_1); trace.ext[DOT_PRODUCT_AIR_COL_VALUE_RES].extend(vec![dot_product_result; size]); @@ -71,8 +68,7 @@ pub(super) fn exec_dot_product_ee( (vec![EF::ONE], slice_0[0]) } else { let slice_1 = memory.get_continuous_slice_of_ef_elements(ptr_arg_1.to_usize(), size)?; - let dot_product_result = - dot_product::(slice_1.iter().copied(), slice_0.iter().copied()); + let dot_product_result = dot_product::(slice_1.iter().copied(), slice_0.iter().copied()); (slice_1, dot_product_result) }; @@ -97,8 +93,7 @@ pub(super) fn exec_dot_product_ee( .extend((0..size).map(|i| F::from_usize(ptr_arg_0.to_usize() + i * DIMENSION))); trace.base[DOT_PRODUCT_AIR_COL_INDEX_B] .extend((0..size).map(|i| F::from_usize(ptr_arg_1.to_usize() + i * DIMENSION))); - trace.base[DOT_PRODUCT_AIR_COL_INDEX_RES] - .extend(vec![F::from_usize(ptr_res.to_usize()); size]); + trace.base[DOT_PRODUCT_AIR_COL_INDEX_RES].extend(vec![F::from_usize(ptr_res.to_usize()); size]); 
trace.ext[dot_product_air_col_value_a(false)].extend(slice_0); trace.ext[DOT_PRODUCT_AIR_COL_VALUE_B].extend(slice_1); trace.ext[DOT_PRODUCT_AIR_COL_VALUE_RES].extend(vec![dot_product_result; size]); diff --git a/crates/lean_vm/src/tables/dot_product/mod.rs b/crates/lean_vm/src/tables/dot_product/mod.rs index 17b4167f..3954abca 100644 --- a/crates/lean_vm/src/tables/dot_product/mod.rs +++ b/crates/lean_vm/src/tables/dot_product/mod.rs @@ -14,11 +14,7 @@ pub struct DotProductPrecompile; // BE = true for base-extension impl TableT for DotProductPrecompile { fn name(&self) -> &'static str { - if BE { - "dot_product_be" - } else { - "dot_product_ee" - } + if BE { "dot_product_be" } else { "dot_product_ee" } } fn identifier(&self) -> Table { diff --git a/crates/lean_vm/src/tables/execution/air.rs b/crates/lean_vm/src/tables/execution/air.rs index d5ea82f1..aca81b7a 100644 --- a/crates/lean_vm/src/tables/execution/air.rs +++ b/crates/lean_vm/src/tables/execution/air.rs @@ -6,8 +6,7 @@ use crate::{EF, ExecutionTable, ExtraDataForBuses, eval_virtual_bus_column}; pub const N_INSTRUCTION_COLUMNS: usize = 13; pub const N_COMMITTED_EXEC_COLUMNS: usize = 5; pub const N_MEMORY_VALUE_COLUMNS: usize = 3; // virtual (lookup into memory, with logup*) -pub const N_EXEC_AIR_COLUMNS: usize = - N_INSTRUCTION_COLUMNS + N_COMMITTED_EXEC_COLUMNS + N_MEMORY_VALUE_COLUMNS; +pub const N_EXEC_AIR_COLUMNS: usize = N_INSTRUCTION_COLUMNS + N_COMMITTED_EXEC_COLUMNS + N_MEMORY_VALUE_COLUMNS; // Instruction columns pub const COL_INDEX_OPERAND_A: usize = 0; @@ -132,12 +131,9 @@ impl Air for ExecutionTable { builder.assert_zero(add * (nu_b.clone() - (nu_a.clone() + nu_c.clone()))); builder.assert_zero(mul * (nu_b.clone() - nu_a.clone() * nu_c.clone())); - builder - .assert_zero(deref.clone() * (addr_c.clone() - (value_a.clone() + operand_c.clone()))); + builder.assert_zero(deref.clone() * (addr_c.clone() - (value_a.clone() + operand_c.clone()))); builder.assert_zero(deref.clone() * aux.clone() * (value_c.clone() - nu_b.clone())); - builder.assert_zero( - deref.clone() * (aux.clone() - AB::F::ONE) * (value_c.clone() - fp.clone()), - ); + builder.assert_zero(deref.clone() * (aux.clone() - AB::F::ONE) * (value_c.clone() - fp.clone())); builder.assert_zero((jump.clone() - AB::F::ONE) * (next_pc.clone() - pc_plus_one.clone())); builder.assert_zero((jump.clone() - AB::F::ONE) * (next_fp.clone() - fp.clone())); @@ -145,9 +141,7 @@ impl Air for ExecutionTable { builder.assert_zero(jump.clone() * nu_a.clone() * nu_a_minus_one.clone()); builder.assert_zero(jump.clone() * nu_a.clone() * (next_pc.clone() - nu_b.clone())); builder.assert_zero(jump.clone() * nu_a.clone() * (next_fp.clone() - nu_c.clone())); - builder.assert_zero( - jump.clone() * nu_a_minus_one.clone() * (next_pc.clone() - pc_plus_one.clone()), - ); + builder.assert_zero(jump.clone() * nu_a_minus_one.clone() * (next_pc.clone() - pc_plus_one.clone())); builder.assert_zero(jump.clone() * nu_a_minus_one.clone() * (next_fp.clone() - fp.clone())); } } diff --git a/crates/lean_vm/src/tables/execution/mod.rs b/crates/lean_vm/src/tables/execution/mod.rs index 8d9dcb22..643dcc7f 100644 --- a/crates/lean_vm/src/tables/execution/mod.rs +++ b/crates/lean_vm/src/tables/execution/mod.rs @@ -90,14 +90,7 @@ impl TableT for ExecutionTable { } #[inline(always)] - fn execute( - &self, - _: F, - _: F, - _: F, - _: usize, - _: &mut InstructionContext<'_>, - ) -> Result<(), RunnerError> { + fn execute(&self, _: F, _: F, _: F, _: usize, _: &mut InstructionContext<'_>) -> Result<(), RunnerError> 
{ unreachable!() } } diff --git a/crates/lean_vm/src/tables/poseidon_16/mod.rs b/crates/lean_vm/src/tables/poseidon_16/mod.rs index c48e05b6..ee9a36e5 100644 --- a/crates/lean_vm/src/tables/poseidon_16/mod.rs +++ b/crates/lean_vm/src/tables/poseidon_16/mod.rs @@ -139,10 +139,7 @@ impl TableT for Poseidon16Precompile { input[..VECTOR_LEN].copy_from_slice(&arg0); input[VECTOR_LEN..].copy_from_slice(&arg1); - let output = match ctx - .poseidon16_precomputed - .get(*ctx.n_poseidon16_precomputed_used) - { + let output = match ctx.poseidon16_precomputed.get(*ctx.n_poseidon16_precomputed_used) { Some(precomputed) if precomputed.0 == input => { *ctx.n_poseidon16_precomputed_used += 1; precomputed.1 @@ -154,10 +151,7 @@ impl TableT for Poseidon16Precompile { let (index_res_b, res_b): (F, [F; VECTOR_LEN]) = if is_compression { (F::from_usize(ZERO_VEC_PTR), [F::ZERO; VECTOR_LEN]) } else { - ( - index_res_a + F::ONE, - output[VECTOR_LEN..].try_into().unwrap(), - ) + (index_res_a + F::ONE, output[VECTOR_LEN..].try_into().unwrap()) }; ctx.memory.set_vector(index_res_a.to_usize(), res_a)?; @@ -223,9 +217,6 @@ impl Air for Poseidon16Precompile { builder.assert_bool(flag.clone()); builder.assert_bool(compression.clone()); - builder.assert_eq( - index_res_bis, - (index_res + AB::F::ONE) * (AB::F::ONE - compression), - ); + builder.assert_eq(index_res_bis, (index_res + AB::F::ONE) * (AB::F::ONE - compression)); } } diff --git a/crates/lean_vm/src/tables/poseidon_24/mod.rs b/crates/lean_vm/src/tables/poseidon_24/mod.rs index 0f1db254..3573d6d7 100644 --- a/crates/lean_vm/src/tables/poseidon_24/mod.rs +++ b/crates/lean_vm/src/tables/poseidon_24/mod.rs @@ -125,10 +125,7 @@ impl TableT for Poseidon24Precompile { input[VECTOR_LEN..2 * VECTOR_LEN].copy_from_slice(&arg1); input[2 * VECTOR_LEN..].copy_from_slice(&arg2); - let output = match ctx - .poseidon24_precomputed - .get(*ctx.n_poseidon24_precomputed_used) - { + let output = match ctx.poseidon24_precomputed.get(*ctx.n_poseidon24_precomputed_used) { Some(precomputed) if precomputed.0 == input => { *ctx.n_poseidon24_precomputed_used += 1; precomputed.1 diff --git a/crates/lean_vm/src/tables/table_trait.rs b/crates/lean_vm/src/tables/table_trait.rs index 7c71bf6e..6a5ea5dd 100644 --- a/crates/lean_vm/src/tables/table_trait.rs +++ b/crates/lean_vm/src/tables/table_trait.rs @@ -5,8 +5,7 @@ use std::{any::TypeId, array, mem::transmute_copy}; use utils::ToUsize; use sub_protocols::{ - ColDims, ExtensionCommitmentFromBaseProver, ExtensionCommitmentFromBaseVerifier, - committed_dims_extension_from_base, + ColDims, ExtensionCommitmentFromBaseProver, ExtensionCommitmentFromBaseVerifier, committed_dims_extension_from_base, }; // Zero padding will be added to each at least, if this minimum is not reached @@ -131,8 +130,7 @@ impl>> ExtraDataForBuses { assert_eq!(TypeId::of::(), TypeId::of::>()); unsafe { transmute_copy::<_, _>(&( - self.fingerprint_challenge_powers - .map(|c| EFPacking::::from(c)), + self.fingerprint_challenge_powers.map(|c| EFPacking::::from(c)), EFPacking::::from(self.bus_beta), )) } @@ -287,10 +285,7 @@ pub trait TableT: Air { assert_eq!(air_values_f.len(), self.n_columns_f_air()); let mut statements = Vec::new(); for lookup in self.normal_lookups_f() { - statements.push(vec![Evaluation::new( - air_point.clone(), - air_values_f[lookup.values], - )]); + statements.push(vec![Evaluation::new(air_point.clone(), air_values_f[lookup.values])]); } statements } @@ -302,10 +297,7 @@ pub trait TableT: Air { assert_eq!(air_values_ef.len(), self.n_columns_ef_air()); 
let mut statements = Vec::new(); for lookup in self.normal_lookups_ef() { - statements.push(vec![Evaluation::new( - air_point.clone(), - air_values_ef[lookup.values], - )]); + statements.push(vec![Evaluation::new(air_point.clone(), air_values_ef[lookup.values])]); } statements } @@ -388,10 +380,7 @@ pub trait TableT: Air { } cols } - fn vector_lookup_values_columns<'a>( - &self, - trace: &'a TableTrace, - ) -> Vec<[&'a [F]; VECTOR_LEN]> { + fn vector_lookup_values_columns<'a>(&self, trace: &'a TableTrace) -> Vec<[&'a [F]; VECTOR_LEN]> { let mut cols = Vec::new(); for lookup in self.vector_lookups() { cols.push(array::from_fn(|i| &trace.base[lookup.values[i]][..])); @@ -420,10 +409,7 @@ pub trait TableT: Air { default_indexes } fn find_committed_column_index_f(&self, col: ColIndex) -> usize { - self.commited_columns_f() - .iter() - .position(|&c| c == col) - .unwrap() + self.commited_columns_f().iter().position(|&c| c == col).unwrap() } } @@ -445,11 +431,7 @@ impl Bus { BusTable::Constant(t) => F::from_usize(t.index()), BusTable::Variable(col) => padding_row_f[*col], }; - let default_data = self - .data - .iter() - .map(|&col| padding_row_f[col]) - .collect::>(); + let default_data = self.data.iter().map(|&col| padding_row_f[col]).collect::>(); EF::from(default_selector * self.direction.to_field_flag() * F::from_usize(padding)) / (bus_challenge + finger_print(default_table, &default_data, fingerprint_challenge)) } diff --git a/crates/lean_vm/tests/test_lean_vm.rs b/crates/lean_vm/tests/test_lean_vm.rs index a4cfe933..5b123787 100644 --- a/crates/lean_vm/tests/test_lean_vm.rs +++ b/crates/lean_vm/tests/test_lean_vm.rs @@ -34,8 +34,7 @@ const POSEIDON24_ARG_A_VALUES: [[u64; VECTOR_LEN]; 2] = [ ]; const POSEIDON24_ARG_B_VALUES: [u64; VECTOR_LEN] = [221, 222, 223, 224, 225, 226, 227, 228]; const DOT_ARG0_VALUES: [[u64; DIMENSION]; DOT_PRODUCT_LEN] = [[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]; -const DOT_ARG1_VALUES: [[u64; DIMENSION]; DOT_PRODUCT_LEN] = - [[11, 12, 13, 14, 15], [16, 17, 18, 19, 20]]; +const DOT_ARG1_VALUES: [[u64; DIMENSION]; DOT_PRODUCT_LEN] = [[11, 12, 13, 14, 15], [16, 17, 18, 19, 20]]; const MLE_COEFF_VALUES: [u64; 1 << MLE_N_VARS] = [7, 9]; const MLE_POINT_VALUES: [u64; DIMENSION] = [21, 22, 23, 24, 25]; @@ -83,27 +82,12 @@ fn set_base_slice(public_input: &mut [F], start_index: usize, values: &[u64]) { fn build_test_case() -> (Bytecode, Vec) { let mut public_input = vec![F::ZERO; PUBLIC_INPUT_LEN]; - set_vector( - &mut public_input, - POSEIDON16_ARG_A_PTR, - &POSEIDON16_ARG_A_VALUES, - ); - set_vector( - &mut public_input, - POSEIDON16_ARG_B_PTR, - &POSEIDON16_ARG_B_VALUES, - ); + set_vector(&mut public_input, POSEIDON16_ARG_A_PTR, &POSEIDON16_ARG_A_VALUES); + set_vector(&mut public_input, POSEIDON16_ARG_B_PTR, &POSEIDON16_ARG_B_VALUES); - let poseidon24_chunks = [ - &POSEIDON24_ARG_A_VALUES[0][..], - &POSEIDON24_ARG_A_VALUES[1][..], - ]; + let poseidon24_chunks = [&POSEIDON24_ARG_A_VALUES[0][..], &POSEIDON24_ARG_A_VALUES[1][..]]; set_multivector(&mut public_input, POSEIDON24_ARG_A_PTR, &poseidon24_chunks); - set_vector( - &mut public_input, - POSEIDON24_ARG_B_PTR, - &POSEIDON24_ARG_B_VALUES, - ); + set_vector(&mut public_input, POSEIDON24_ARG_B_PTR, &POSEIDON24_ARG_B_VALUES); set_ef_slice(&mut public_input, DOT_ARG0_PTR, &DOT_ARG0_VALUES); set_ef_slice(&mut public_input, DOT_ARG1_PTR, &DOT_ARG1_VALUES); @@ -180,9 +164,7 @@ fn build_test_case() -> (Bytecode, Vec) { table: Table::dot_product_ee(), arg_a: MemOrConstant::Constant(f(DOT_ARG0_PTR as u64)), arg_b: 
MemOrConstant::Constant(f(DOT_ARG1_PTR as u64)), - arg_c: MemOrFp::MemoryAfterFp { - offset: DOT_RES_OFFSET, - }, + arg_c: MemOrFp::MemoryAfterFp { offset: DOT_RES_OFFSET }, aux: DOT_PRODUCT_LEN, }, ]; @@ -200,13 +182,7 @@ fn build_test_case() -> (Bytecode, Vec) { fn run_program() -> (Bytecode, ExecutionResult) { let (bytecode, public_input) = build_test_case(); - let result = execute_bytecode( - &bytecode, - (&public_input, &[]), - 1 << 20, - false, - (&vec![], &vec![]), - ); + let result = execute_bytecode(&bytecode, (&public_input, &[]), 1 << 20, false, (&vec![], &vec![])); println!("{}", result.summary); (bytecode, result) } @@ -224,12 +200,6 @@ fn test_operation_compute() { let add = Operation::Add; let mul = Operation::Mul; - assert_eq!( - add.compute(F::from_usize(2), F::from_usize(3)), - F::from_usize(5) - ); - assert_eq!( - mul.compute(F::from_usize(2), F::from_usize(3)), - F::from_usize(6) - ); + assert_eq!(add.compute(F::from_usize(2), F::from_usize(3)), F::from_usize(5)); + assert_eq!(mul.compute(F::from_usize(2), F::from_usize(3)), F::from_usize(6)); } diff --git a/crates/lookup/src/logup_star.rs b/crates/lookup/src/logup_star.rs index 478bcad4..aeb96bb7 100644 --- a/crates/lookup/src/logup_star.rs +++ b/crates/lookup/src/logup_star.rs @@ -48,14 +48,13 @@ where // TODO use max_index let _ = max_index; - let (poly_eq_point_packed, pushforward_packed, table_packed) = - info_span!("packing").in_scope(|| { - ( - MleRef::Extension(poly_eq_point).pack_if(packing), - MleRef::Extension(pushforward).pack_if(packing), - table.pack_if(packing), - ) - }); + let (poly_eq_point_packed, pushforward_packed, table_packed) = info_span!("packing").in_scope(|| { + ( + MleRef::Extension(poly_eq_point).pack_if(packing), + MleRef::Extension(pushforward).pack_if(packing), + table.pack_if(packing), + ) + }); let (sc_point, inner_evals, prod) = info_span!("logup_star sumcheck", table_length, indexes_length).in_scope(|| { @@ -96,10 +95,7 @@ where let (_, claim_point_left, _, eval_c_minus_indexes) = prove_gkr_quotient::<_, 2>( prover_state, - &MleGroupRef::merge(&[ - &poly_eq_point_packed.by_ref(), - &c_minus_indexes_packed.by_ref(), - ]), + &MleGroupRef::merge(&[&poly_eq_point_packed.by_ref(), &c_minus_indexes_packed.by_ref()]), ); let c_minus_increments = MleRef::Extension( @@ -111,10 +107,7 @@ where let c_minus_increments_packed = c_minus_increments.pack_if(packing); let (_, claim_point_right, pushforward_final_eval, _) = prove_gkr_quotient::<_, 2>( prover_state, - &MleGroupRef::merge(&[ - &pushforward_packed.by_ref(), - &c_minus_increments_packed.by_ref(), - ]), + &MleGroupRef::merge(&[&pushforward_packed.by_ref(), &c_minus_increments_packed.by_ref()]), ); let on_indexes = Evaluation::new(claim_point_left, c - eval_c_minus_indexes); @@ -140,16 +133,9 @@ where EF: ExtensionField>, PF: PrimeField64, { - let (sum, postponed) = - sumcheck_verify(verifier_state, log_table_len, 2).map_err(|_| ProofError::InvalidProof)?; + let (sum, postponed) = sumcheck_verify(verifier_state, log_table_len, 2).map_err(|_| ProofError::InvalidProof)?; - if sum - != claims - .iter() - .zip(alpha.powers()) - .map(|(c, a)| c.value * a) - .sum::() - { + if sum != claims.iter().zip(alpha.powers()).map(|(c, a)| c.value * a).sum::() { return Err(ProofError::InvalidProof); } @@ -185,10 +171,7 @@ where return Err(ProofError::InvalidProof); } - on_pushforward.push(Evaluation::new( - claim_point_right.clone(), - pushforward_final_eval, - )); + on_pushforward.push(Evaluation::new(claim_point_right.clone(), pushforward_final_eval)); let 
big_endian_mle = claim_point_right .iter() @@ -275,11 +258,7 @@ mod tests { let challenger = build_challenger(); - let point = MultilinearPoint( - (0..log_indexes_len) - .map(|_| rng.random()) - .collect::>(), - ); + let point = MultilinearPoint((0..log_indexes_len).map(|_| rng.random()).collect::>()); let mut prover_state = FSProver::new(challenger.clone()); let eval = values.evaluate(&point); @@ -301,20 +280,11 @@ mod tests { println!("Proving logup_star took {} ms", time.elapsed().as_millis()); let mut verifier_state = FSVerifier::new(prover_state.proof_data().to_vec(), challenger); - let verifier_statements = verify_logup_star( - &mut verifier_state, - log_table_len, - log_indexes_len, - &[claim], - EF::ONE, - ) - .unwrap(); + let verifier_statements = + verify_logup_star(&mut verifier_state, log_table_len, log_indexes_len, &[claim], EF::ONE).unwrap(); assert_eq!(&verifier_statements, &prover_statements); - assert_eq!( - prover_state.challenger().state(), - verifier_state.challenger().state() - ); + assert_eq!(prover_state.challenger().state(), verifier_state.challenger().state()); assert_eq!( indexes.evaluate(&verifier_statements.on_indexes.point), @@ -339,10 +309,7 @@ mod tests { .map(|x| (0..n_muls).map(|_| *x).product::>()) .sum::>(); assert!(sum != EFPacking::::ONE); - println!( - "Optimal time we can hope for: {} ms", - time.elapsed().as_millis() - ); + println!("Optimal time we can hope for: {} ms", time.elapsed().as_millis()); } } } diff --git a/crates/lookup/src/quotient_gkr.rs b/crates/lookup/src/quotient_gkr.rs index 9f25f2a8..a23b088f 100644 --- a/crates/lookup/src/quotient_gkr.rs +++ b/crates/lookup/src/quotient_gkr.rs @@ -17,8 +17,7 @@ pub fn prove_gkr_quotient>, const N_GROUPS: usize>( ) -> (EF, MultilinearPoint, EF, EF) { assert!(N_GROUPS.is_power_of_two() && N_GROUPS >= 2); assert_eq!(numerators_and_denominators.n_columns(), 2); - let mut layers: Vec> = - vec![split_mle_group(numerators_and_denominators, N_GROUPS / 2).into()]; + let mut layers: Vec> = vec![split_mle_group(numerators_and_denominators, N_GROUPS / 2).into()]; loop { let prev_layer: MleGroup<'_, EF> = layers.last().unwrap().by_ref().into(); @@ -37,8 +36,7 @@ pub fn prove_gkr_quotient>, const N_GROUPS: usize>( let last_layer = last_layer.as_owned_or_clone().as_extension().unwrap(); assert_eq!(last_layer.len(), N_GROUPS); - let last_nums_and_dens: [[_; 2]; N_GROUPS] = - array::from_fn(|i| last_layer[i].to_vec().try_into().unwrap()); + let last_nums_and_dens: [[_; 2]; N_GROUPS] = array::from_fn(|i| last_layer[i].to_vec().try_into().unwrap()); for nd in &last_nums_and_dens { prover_state.add_extension_scalars(nd); } @@ -56,21 +54,9 @@ pub fn prove_gkr_quotient>, const N_GROUPS: usize>( .collect::>(); for layer in layers[1..].iter().rev() { - (point, claims) = prove_gkr_quotient_step::<_, N_GROUPS>( - prover_state, - layer.by_ref(), - &point, - claims, - false, - ); + (point, claims) = prove_gkr_quotient_step::<_, N_GROUPS>(prover_state, layer.by_ref(), &point, claims, false); } - (point, claims) = prove_gkr_quotient_step::<_, N_GROUPS>( - prover_state, - layers[0].by_ref(), - &point, - claims, - true, - ); + (point, claims) = prove_gkr_quotient_step::<_, N_GROUPS>(prover_state, layers[0].by_ref(), &point, claims, true); assert_eq!(claims.len(), 2); (quotient, point, claims[0], claims[1]) } @@ -113,18 +99,8 @@ fn prove_gkr_quotient_step>, const N_GROUPS: usize>( let next_claims = if univariate_skip { let selectors = univariate_selectors(log2_strict_usize(N_GROUPS)); vec![ - evaluate_univariate_multilinear::<_, 
_, _, false>( - &inner_evals[..N_GROUPS], - &[beta], - &selectors, - None, - ), - evaluate_univariate_multilinear::<_, _, _, false>( - &inner_evals[N_GROUPS..], - &[beta], - &selectors, - None, - ), + evaluate_univariate_multilinear::<_, _, _, false>(&inner_evals[..N_GROUPS], &[beta], &selectors, None), + evaluate_univariate_multilinear::<_, _, _, false>(&inner_evals[N_GROUPS..], &[beta], &selectors, None), ] } else { inner_evals @@ -160,8 +136,7 @@ pub fn verify_gkr_quotient>, const N_GROUPS: usize>( .collect::>(); for i in 1..n_vars - log2_strict_usize(N_GROUPS) { - (point, claims) = - verify_gkr_quotient_step::<_, N_GROUPS>(verifier_state, i, &point, claims, false)?; + (point, claims) = verify_gkr_quotient_step::<_, N_GROUPS>(verifier_state, i, &point, claims, false)?; } (point, claims) = verify_gkr_quotient_step::<_, N_GROUPS>( verifier_state, @@ -209,18 +184,8 @@ fn verify_gkr_quotient_step>, const N_GROUPS: usize>( let next_claims = if univariate_skip { let selectors = univariate_selectors(log2_strict_usize(N_GROUPS)); vec![ - evaluate_univariate_multilinear::<_, _, _, false>( - &inner_evals[..N_GROUPS], - &[beta], - &selectors, - None, - ), - evaluate_univariate_multilinear::<_, _, _, false>( - &inner_evals[N_GROUPS..], - &[beta], - &selectors, - None, - ), + evaluate_univariate_multilinear::<_, _, _, false>(&inner_evals[..N_GROUPS], &[beta], &selectors, None), + evaluate_univariate_multilinear::<_, _, _, false>(&inner_evals[N_GROUPS..], &[beta], &selectors, None), ] } else { inner_evals @@ -240,10 +205,7 @@ fn sum_quotients>>( ) -> MleGroupOwned { match numerators_and_denominators { MleGroupRef::ExtensionPacked(numerators_and_denominators) => { - MleGroupOwned::ExtensionPacked(sum_quotients_helper( - &numerators_and_denominators, - n_groups, - )) + MleGroupOwned::ExtensionPacked(sum_quotients_helper(&numerators_and_denominators, n_groups)) } MleGroupRef::Extension(numerators_and_denominators) => { MleGroupOwned::Extension(sum_quotients_helper(&numerators_and_denominators, n_groups)) @@ -263,8 +225,7 @@ fn sum_quotients_helper( let mut new_denominators = Vec::new(); let (prev_numerators, prev_denominators) = numerators_and_denominators.split_at(n_groups / 2); for i in 0..n_groups / 2 { - let (new_num, new_den) = - sum_quotients_2_by_2::(prev_numerators[i], prev_denominators[i]); + let (new_num, new_den) = sum_quotients_2_by_2::(prev_numerators[i], prev_denominators[i]); new_numerators.push(new_num); new_denominators.push(new_den); } @@ -297,9 +258,7 @@ fn split_mle_group<'a, EF: ExtensionField>>( ) -> MleGroupRef<'a, EF> { match polys { MleGroupRef::Extension(polys) => MleGroupRef::Extension(split_chunks(polys, n_groups)), - MleGroupRef::ExtensionPacked(polys) => { - MleGroupRef::ExtensionPacked(split_chunks(polys, n_groups)) - } + MleGroupRef::ExtensionPacked(polys) => MleGroupRef::ExtensionPacked(split_chunks(polys, n_groups)), _ => unreachable!(), } } @@ -314,9 +273,7 @@ fn split_chunks<'a, A>(numerators_and_denominators: &[&'a [A]], num_groups: usiz assert_eq!(slice.len(), n); res.extend(split_at_many( slice, - &(1..num_groups) - .map(|i| i * n / num_groups) - .collect::>(), + &(1..num_groups).map(|i| i * n / num_groups).collect::>(), )); } res @@ -352,48 +309,31 @@ mod tests { let denominators_indexes = (0..n) .map(|_| PF::::from_usize(rng.random_range(..n))) .collect::>(); - let denominators = denominators_indexes - .iter() - .map(|&i| c - i) - .collect::>(); + let denominators = denominators_indexes.iter().map(|&i| c - i).collect::>(); let real_quotient = 
sum_all_quotients(&numerators, &denominators); let mut prover_state = build_prover_state(); let time = Instant::now(); let prover_statements = prove_gkr_quotient::( &mut prover_state, - &MleGroupRef::ExtensionPacked(vec![ - &pack_extension(&numerators), - &pack_extension(&denominators), - ]), + &MleGroupRef::ExtensionPacked(vec![&pack_extension(&numerators), &pack_extension(&denominators)]), ); println!("Proving time: {:?}", time.elapsed()); let mut verifier_state = build_verifier_state(&prover_state); - let verifier_statements = - verify_gkr_quotient::(&mut verifier_state, log_n).unwrap(); + let verifier_statements = verify_gkr_quotient::(&mut verifier_state, log_n).unwrap(); assert_eq!(&verifier_statements, &prover_statements); let (retrieved_quotient, claim_point, claim_num, claim_den) = verifier_statements; assert_eq!(retrieved_quotient, real_quotient); let selectors = univariate_selectors::>(log2_strict_usize(N_GROUPS)); assert_eq!( - evaluate_univariate_multilinear::<_, _, _, true>( - &numerators, - &claim_point, - &selectors, - None - ), + evaluate_univariate_multilinear::<_, _, _, true>(&numerators, &claim_point, &selectors, None), claim_num ); assert_eq!( - evaluate_univariate_multilinear::<_, _, _, true>( - &denominators, - &claim_point, - &selectors, - None - ), + evaluate_univariate_multilinear::<_, _, _, true>(&denominators, &claim_point, &selectors, None), claim_den ); } diff --git a/crates/poseidon_circuit/src/gkr_layers/batch_partial_rounds.rs b/crates/poseidon_circuit/src/gkr_layers/batch_partial_rounds.rs index c0d7edac..4cecb47e 100644 --- a/crates/poseidon_circuit/src/gkr_layers/batch_partial_rounds.rs +++ b/crates/poseidon_circuit/src/gkr_layers/batch_partial_rounds.rs @@ -1,9 +1,7 @@ use std::array; use multilinear_toolkit::prelude::*; -use p3_koala_bear::{ - GenericPoseidon2LinearLayersKoalaBear, KoalaBearInternalLayerParameters, KoalaBearParameters, -}; +use p3_koala_bear::{GenericPoseidon2LinearLayersKoalaBear, KoalaBearInternalLayerParameters, KoalaBearParameters}; use p3_monty_31::InternalLayerBaseParameters; use p3_poseidon2::GenericPoseidon2LinearLayers; diff --git a/crates/poseidon_circuit/src/gkr_layers/compression.rs b/crates/poseidon_circuit/src/gkr_layers/compression.rs index b823a7e1..d8a55f63 100644 --- a/crates/poseidon_circuit/src/gkr_layers/compression.rs +++ b/crates/poseidon_circuit/src/gkr_layers/compression.rs @@ -41,9 +41,7 @@ where res += EFPacking::::from(alpha_powers[i]) * point[i]; } for i in self.compressed_output..WIDTH { - res += EFPacking::::from(alpha_powers[i]) - * point[i] - * (PFPacking::::ONE - compressed); + res += EFPacking::::from(alpha_powers[i]) * point[i] * (PFPacking::::ONE - compressed); } res diff --git a/crates/poseidon_circuit/src/gkr_layers/mod.rs b/crates/poseidon_circuit/src/gkr_layers/mod.rs index 53c171ca..b5ddff97 100644 --- a/crates/poseidon_circuit/src/gkr_layers/mod.rs +++ b/crates/poseidon_circuit/src/gkr_layers/mod.rs @@ -31,8 +31,7 @@ impl PoseidonGKRLayers unsafe { Self::build_generic( - &*(&KOALABEAR_RC16_EXTERNAL_INITIAL as *const [[F; 16]] - as *const [[F; WIDTH]]), + &*(&KOALABEAR_RC16_EXTERNAL_INITIAL as *const [[F; 16]] as *const [[F; WIDTH]]), &KOALABEAR_RC16_INTERNAL, &*(&KOALABEAR_RC16_EXTERNAL_FINAL as *const [[F; 16]] as *const [[F; WIDTH]]), compressed_output, @@ -40,8 +39,7 @@ impl PoseidonGKRLayers unsafe { Self::build_generic( - &*(&KOALABEAR_RC24_EXTERNAL_INITIAL as *const [[F; 24]] - as *const [[F; WIDTH]]), + &*(&KOALABEAR_RC24_EXTERNAL_INITIAL as *const [[F; 24]] as *const [[F; WIDTH]]), 
&KOALABEAR_RC24_INTERNAL, &*(&KOALABEAR_RC24_EXTERNAL_FINAL as *const [[F; 24]] as *const [[F; WIDTH]]), compressed_output, diff --git a/crates/poseidon_circuit/src/lib.rs b/crates/poseidon_circuit/src/lib.rs index ef815677..965d238f 100644 --- a/crates/poseidon_circuit/src/lib.rs +++ b/crates/poseidon_circuit/src/lib.rs @@ -26,8 +26,8 @@ pub(crate) type EF = QuinticExtensionFieldKB; /// remain to be proven #[derive(Debug, Clone)] pub struct GKRPoseidonResult { - pub output_statements: MultiEvaluation, // of length width - pub input_statements: MultiEvaluation, // of length width - pub cubes_statements: MultiEvaluation, // of length n_committed_cubes + pub output_statements: MultiEvaluation, // of length width + pub input_statements: MultiEvaluation, // of length width + pub cubes_statements: MultiEvaluation, // of length n_committed_cubes pub on_compression_selector: Option>, // univariate_skips = 1 here (TODO dont do this) } diff --git a/crates/poseidon_circuit/src/prove.rs b/crates/poseidon_circuit/src/prove.rs index 84d68311..1b7f1d09 100644 --- a/crates/poseidon_circuit/src/prove.rs +++ b/crates/poseidon_circuit/src/prove.rs @@ -52,12 +52,8 @@ where let mut output_claims = vec![]; let mut claims = vec![]; for evals in inner_evals { - output_claims - .push(evals.evaluate(&MultilinearPoint(point[..univariate_skips].to_vec()))); - claims.push(dot_product( - selectors_at_alpha.iter().copied(), - evals.into_iter(), - )) + output_claims.push(evals.evaluate(&MultilinearPoint(point[..univariate_skips].to_vec()))); + claims.push(dot_product(selectors_at_alpha.iter().copied(), evals.into_iter())) } point = [vec![alpha], point[univariate_skips..].to_vec()].concat(); (output_claims, claims) @@ -95,12 +91,7 @@ where None }; - for (layer, full_round_constants) in witness - .final_full_layers - .iter() - .zip(&layers.final_full_rounds) - .rev() - { + for (layer, full_round_constants) in witness.final_full_layers.iter().zip(&layers.final_full_rounds).rev() { claims = apply_matrix(&inv_mds_matrix, &claims); (point, claims) = prove_gkr_round( @@ -216,10 +207,7 @@ fn prove_gkr_round> + 'static>( ) -> (Vec, Vec) { let batching_scalar = prover_state.sample(); let batching_scalars_powers = batching_scalar.powers().collect_n(output_claims.len()); - let batched_claim: EF = dot_product( - output_claims.iter().copied(), - batching_scalars_powers.iter().copied(), - ); + let batched_claim: EF = dot_product(output_claims.iter().copied(), batching_scalars_powers.iter().copied()); let (sumcheck_point, sumcheck_inner_evals, sumcheck_final_sum) = sumcheck_prove( univariate_skips, @@ -326,12 +314,8 @@ fn inner_evals_on_commited_columns( .map(|col| { col.chunks_exact(eq_mle.len()) .map(|chunk| { - let ef_sum = dot_product::, _, _>( - eq_mle.iter().copied(), - chunk.iter().copied(), - ); - as PackedFieldExtension>::to_ext_iter([ef_sum]) - .sum::() + let ef_sum = dot_product::, _, _>(eq_mle.iter().copied(), chunk.iter().copied()); + as PackedFieldExtension>::to_ext_iter([ef_sum]).sum::() }) .collect::>() }) @@ -340,11 +324,9 @@ fn inner_evals_on_commited_columns( prover_state.add_extension_scalars(&inner_evals); let mut values_to_prove = vec![]; let pcs_batching_scalars_inputs = prover_state.sample_vec(univariate_skips); - let point_to_prove = - MultilinearPoint([pcs_batching_scalars_inputs.clone(), point[1..].to_vec()].concat()); + let point_to_prove = MultilinearPoint([pcs_batching_scalars_inputs.clone(), point[1..].to_vec()].concat()); for col_inner_evals in inner_evals.chunks_exact(1 << univariate_skips) { - 
values_to_prove - .push(col_inner_evals.evaluate(&MultilinearPoint(pcs_batching_scalars_inputs.clone()))); + values_to_prove.push(col_inner_evals.evaluate(&MultilinearPoint(pcs_batching_scalars_inputs.clone()))); } MultiEvaluation::new(point_to_prove, values_to_prove) } diff --git a/crates/poseidon_circuit/src/tests.rs b/crates/poseidon_circuit/src/tests.rs index ac8c99e6..a0d3959a 100644 --- a/crates/poseidon_circuit/src/tests.rs +++ b/crates/poseidon_circuit/src/tests.rs @@ -1,25 +1,21 @@ use multilinear_toolkit::prelude::*; -use p3_koala_bear::{ - KoalaBear, KoalaBearInternalLayerParameters, KoalaBearParameters, QuinticExtensionFieldKB, -}; +use p3_koala_bear::{KoalaBear, KoalaBearInternalLayerParameters, KoalaBearParameters, QuinticExtensionFieldKB}; use p3_monty_31::InternalLayerBaseParameters; use rand::{Rng, SeedableRng, rngs::StdRng}; use std::{array, time::Instant}; use sub_protocols::{ - ColDims, packed_pcs_commit, packed_pcs_global_statements_for_prover, - packed_pcs_global_statements_for_verifier, packed_pcs_parse_commitment, + ColDims, packed_pcs_commit, packed_pcs_global_statements_for_prover, packed_pcs_global_statements_for_verifier, + packed_pcs_parse_commitment, }; use utils::{ - build_prover_state, build_verifier_state, init_tracing, poseidon16_permute_mut, - poseidon24_permute_mut, transposed_par_iter_mut, -}; -use whir_p3::{ - FoldingFactor, SecurityAssumption, WhirConfig, WhirConfigBuilder, precompute_dft_twiddles, + build_prover_state, build_verifier_state, init_tracing, poseidon16_permute_mut, poseidon24_permute_mut, + transposed_par_iter_mut, }; +use whir_p3::{FoldingFactor, SecurityAssumption, WhirConfig, WhirConfigBuilder, precompute_dft_twiddles}; use crate::{ - GKRPoseidonResult, default_cube_layers, generate_poseidon_witness, - gkr_layers::PoseidonGKRLayers, prove_poseidon_gkr, verify_poseidon_gkr, + GKRPoseidonResult, default_cube_layers, generate_poseidon_witness, gkr_layers::PoseidonGKRLayers, + prove_poseidon_gkr, verify_poseidon_gkr, }; type F = KoalaBear; @@ -35,11 +31,7 @@ fn test_poseidon_benchmark() { run_poseidon_benchmark::<16, 16, 3>(12, true); } -pub fn run_poseidon_benchmark< - const WIDTH: usize, - const N_COMMITED_CUBES: usize, - const UNIVARIATE_SKIPS: usize, ->( +pub fn run_poseidon_benchmark( log_n_poseidons: usize, compress: bool, ) where @@ -64,17 +56,11 @@ pub fn run_poseidon_benchmark< let n_poseidons = 1 << log_n_poseidons; let n_compressions = if compress { n_poseidons / 3 } else { 0 }; - let perm_inputs = (0..n_poseidons) - .map(|_| rng.random()) - .collect::>(); - let input: [_; WIDTH] = - array::from_fn(|i| perm_inputs.par_iter().map(|x| x[i]).collect::>()); - let input_packed: [_; WIDTH] = - array::from_fn(|i| PFPacking::::pack_slice(&input[i]).to_vec()); + let perm_inputs = (0..n_poseidons).map(|_| rng.random()).collect::>(); + let input: [_; WIDTH] = array::from_fn(|i| perm_inputs.par_iter().map(|x| x[i]).collect::>()); + let input_packed: [_; WIDTH] = array::from_fn(|i| PFPacking::::pack_slice(&input[i]).to_vec()); - let layers = PoseidonGKRLayers::::build( - compress.then_some(COMPRESSION_OUTPUT_WIDTH), - ); + let layers = PoseidonGKRLayers::::build(compress.then_some(COMPRESSION_OUTPUT_WIDTH)); let default_cubes = default_cube_layers::(&layers); @@ -87,14 +73,7 @@ pub fn run_poseidon_benchmark< let log_smallest_decomposition_chunk = 0; // unused because everything is a power of 2 - let ( - mut verifier_state, - proof_size_pcs, - proof_size_gkr, - output_layer, - prover_duration, - output_statements_prover, - ) = { + let (mut 
verifier_state, proof_size_pcs, proof_size_gkr, output_layer, prover_duration, output_statements_prover) = { // ---------------------------------------------------- PROVER ---------------------------------------------------- let prover_time = Instant::now(); @@ -153,10 +132,7 @@ pub fn run_poseidon_benchmark< if let Some(on_compression_selector) = on_compression_selector { assert_eq!( on_compression_selector.value, - mle_of_zeros_then_ones( - (1 << log_n_poseidons) - n_compressions, - &on_compression_selector.point, - ) + mle_of_zeros_then_ones((1 << log_n_poseidons) - n_compressions, &on_compression_selector.point,) ); } @@ -236,10 +212,7 @@ pub fn run_poseidon_benchmark< if let Some(on_compression_selector) = on_compression_selector { assert_eq!( on_compression_selector.value, - mle_of_zeros_then_ones( - (1 << log_n_poseidons) - n_compressions, - &on_compression_selector.point, - ) + mle_of_zeros_then_ones((1 << log_n_poseidons) - n_compressions, &on_compression_selector.point,) ); } @@ -264,11 +237,7 @@ pub fn run_poseidon_benchmark< .unwrap(); whir_config - .verify::( - &mut verifier_state, - &parsed_pcs_commitment, - global_statements, - ) + .verify::(&mut verifier_state, &parsed_pcs_commitment, global_statements) .unwrap(); output_statements }; @@ -329,8 +298,7 @@ pub fn run_poseidon_benchmark< &output_statements_verifier.values, &output_layer .iter() - .map(|layer| PFPacking::::unpack_slice(layer) - .evaluate(&output_statements_verifier.point)) + .map(|layer| PFPacking::::unpack_slice(layer).evaluate(&output_statements_verifier.point)) .collect::>() ); diff --git a/crates/poseidon_circuit/src/utils.rs b/crates/poseidon_circuit/src/utils.rs index 36674132..a610b926 100644 --- a/crates/poseidon_circuit/src/utils.rs +++ b/crates/poseidon_circuit/src/utils.rs @@ -1,9 +1,7 @@ use std::array; use multilinear_toolkit::prelude::*; -use p3_koala_bear::{ - GenericPoseidon2LinearLayersKoalaBear, KoalaBearInternalLayerParameters, KoalaBearParameters, -}; +use p3_koala_bear::{GenericPoseidon2LinearLayersKoalaBear, KoalaBearInternalLayerParameters, KoalaBearParameters}; use p3_monty_31::InternalLayerBaseParameters; use p3_poseidon2::GenericPoseidon2LinearLayers; use tracing::instrument; @@ -11,8 +9,7 @@ use tracing::instrument; use crate::F; #[instrument(skip_all)] -pub fn build_poseidon_inv_matrices() --> ([[F; WIDTH]; WIDTH], [[F; WIDTH]; WIDTH]) +pub fn build_poseidon_inv_matrices() -> ([[F; WIDTH]; WIDTH], [[F; WIDTH]; WIDTH]) where KoalaBearInternalLayerParameters: InternalLayerBaseParameters, { diff --git a/crates/poseidon_circuit/src/verify.rs b/crates/poseidon_circuit/src/verify.rs index aa5863d9..e3e8eca6 100644 --- a/crates/poseidon_circuit/src/verify.rs +++ b/crates/poseidon_circuit/src/verify.rs @@ -3,8 +3,8 @@ use p3_koala_bear::{KoalaBearInternalLayerParameters, KoalaBearParameters}; use p3_monty_31::InternalLayerBaseParameters; use crate::{ - CompressionComputation, EF, F, FullRoundComputation, GKRPoseidonResult, - PartialRoundComputation, build_poseidon_inv_matrices, gkr_layers::PoseidonGKRLayers, + CompressionComputation, EF, F, FullRoundComputation, GKRPoseidonResult, PartialRoundComputation, + build_poseidon_inv_matrices, gkr_layers::PoseidonGKRLayers, }; pub fn verify_poseidon_gkr( @@ -38,13 +38,8 @@ where .map(|selector| selector.evaluate(alpha)) .collect::>(); for evals in inner_evals { - output_claims.push(evals.evaluate(&MultilinearPoint( - output_claim_point[..univariate_skips].to_vec(), - ))); - claims.push(dot_product( - selectors_at_alpha.iter().copied(), - 
evals.into_iter(), - )) + output_claims.push(evals.evaluate(&MultilinearPoint(output_claim_point[..univariate_skips].to_vec()))); + claims.push(dot_product(selectors_at_alpha.iter().copied(), evals.into_iter())) } [vec![alpha], output_claim_point[univariate_skips..].to_vec()].concat() }; @@ -65,12 +60,8 @@ where let inner_evals = verifier_state .next_extension_scalars_vec(1 << univariate_skips) .unwrap(); - let recomputed_value = evaluate_univariate_multilinear::<_, _, _, false>( - &inner_evals, - &[point[0]], - &selectors, - None, - ); + let recomputed_value = + evaluate_univariate_multilinear::<_, _, _, false>(&inner_evals, &[point[0]], &selectors, None); assert_eq!(claims.pop().unwrap(), recomputed_value); let epsilons = verifier_state.sample_vec(univariate_skips); let new_point = MultilinearPoint([epsilons.clone(), point[1..].to_vec()].concat()); @@ -118,9 +109,7 @@ where let mut pcs_point_for_cubes = vec![]; let mut pcs_evals_for_cubes = vec![]; if N_COMMITED_CUBES > 0 { - let claimed_cubes_evals = verifier_state - .next_extension_scalars_vec(N_COMMITED_CUBES) - .unwrap(); + let claimed_cubes_evals = verifier_state.next_extension_scalars_vec(N_COMMITED_CUBES).unwrap(); (point, claims) = verify_gkr_round( verifier_state, @@ -171,16 +160,10 @@ where let cubes_statements = if N_COMMITED_CUBES == 0 { Default::default() } else { - verify_inner_evals_on_commited_columns( - verifier_state, - &pcs_point_for_cubes, - &pcs_evals_for_cubes, - &selectors, - ) + verify_inner_evals_on_commited_columns(verifier_state, &pcs_point_for_cubes, &pcs_evals_for_cubes, &selectors) }; - let output_statements = - MultiEvaluation::new(MultilinearPoint(output_claim_point.to_vec()), output_claims); + let output_statements = MultiEvaluation::new(MultilinearPoint(output_claim_point.to_vec()), output_claims); GKRPoseidonResult { output_statements, input_statements, @@ -215,11 +198,7 @@ fn verify_gkr_round>>( let sumcheck_inner_evals = verifier_state.next_extension_scalars_vec(n_inputs).unwrap(); assert_eq!( computation.eval_extension(&sumcheck_inner_evals, &[], &batching_scalars_powers) - * eq_poly_with_skip( - &sumcheck_postponed_claim.point, - claim_point, - univariate_skips - ), + * eq_poly_with_skip(&sumcheck_postponed_claim.point, claim_point, univariate_skips), sumcheck_postponed_claim.value ); @@ -238,23 +217,16 @@ fn verify_inner_evals_on_commited_columns( .unwrap(); let pcs_batching_scalars_inputs = verifier_state.sample_vec(univariate_skips); let mut values_to_verif = vec![]; - let point_to_verif = - MultilinearPoint([pcs_batching_scalars_inputs.clone(), point[1..].to_vec()].concat()); + let point_to_verif = MultilinearPoint([pcs_batching_scalars_inputs.clone(), point[1..].to_vec()].concat()); for (&eval, col_inner_evals) in claimed_evals .iter() .zip(inner_evals_inputs.chunks_exact(1 << univariate_skips)) { assert_eq!( eval, - evaluate_univariate_multilinear::<_, _, _, false>( - col_inner_evals, - &point[..1], - selectors, - None - ) + evaluate_univariate_multilinear::<_, _, _, false>(col_inner_evals, &point[..1], selectors, None) ); - values_to_verif - .push(col_inner_evals.evaluate(&MultilinearPoint(pcs_batching_scalars_inputs.clone()))); + values_to_verif.push(col_inner_evals.evaluate(&MultilinearPoint(pcs_batching_scalars_inputs.clone()))); } MultiEvaluation::new(point_to_verif, values_to_verif) } diff --git a/crates/poseidon_circuit/src/witness_gen.rs b/crates/poseidon_circuit/src/witness_gen.rs index 8f79ce8c..98eba6b8 100644 --- a/crates/poseidon_circuit/src/witness_gen.rs +++ 
b/crates/poseidon_circuit/src/witness_gen.rs @@ -14,19 +14,17 @@ use crate::gkr_layers::PoseidonGKRLayers; #[derive(Debug, Hash)] pub struct PoseidonWitness { - pub input_layer: [Vec; WIDTH], // input of the permutation - pub initial_full_layers: Vec<[Vec; WIDTH]>, // just before cubing - pub batch_partial_round_input: Option<[Vec; WIDTH]>, // again, the input of the batch (partial) round - pub committed_cubes: [Vec; N_COMMITED_CUBES], // the cubes commited in the batch (partial) rounds + pub input_layer: [Vec; WIDTH], // input of the permutation + pub initial_full_layers: Vec<[Vec; WIDTH]>, // just before cubing + pub batch_partial_round_input: Option<[Vec; WIDTH]>, // again, the input of the batch (partial) round + pub committed_cubes: [Vec; N_COMMITED_CUBES], // the cubes commited in the batch (partial) rounds pub remaining_partial_round_layers: Vec<[Vec; WIDTH]>, // the input of each remaining partial round, just before cubing the first element pub final_full_layers: Vec<[Vec; WIDTH]>, // just before cubing pub output_layer: [Vec; WIDTH], // output of the permutation - pub compression: Option<(Vec, [Vec; WIDTH])>, // compression indicator column, compressed output + pub compression: Option<(Vec, [Vec; WIDTH])>, // compression indicator column, compressed output } -impl - PoseidonWitness, WIDTH, N_COMMITED_CUBES> -{ +impl PoseidonWitness, WIDTH, N_COMMITED_CUBES> { pub fn n_poseidons(&self) -> usize { self.input_layer[0].len() * packing_width::() } @@ -126,13 +124,7 @@ where } // #[instrument(skip_all)] -fn apply_full_round< - A, - const WIDTH: usize, - const CUBE: bool, - const MDS: bool, - const ADD_CONSTANTS: bool, ->( +fn apply_full_round( input_layers: &[Vec; WIDTH], constants: &[F; WIDTH], ) -> [Vec; WIDTH] diff --git a/crates/rec_aggregation/src/recursion.rs b/crates/rec_aggregation/src/recursion.rs index a01dfed2..bfb9ae06 100644 --- a/crates/rec_aggregation/src/recursion.rs +++ b/crates/rec_aggregation/src/recursion.rs @@ -11,12 +11,9 @@ use rand::Rng; use rand::SeedableRng; use rand::rngs::StdRng; use utils::{ - build_prover_state, build_verifier_state, padd_with_zero_to_next_multiple_of, - padd_with_zero_to_next_power_of_two, -}; -use whir_p3::{ - FoldingFactor, SecurityAssumption, WhirConfig, WhirConfigBuilder, precompute_dft_twiddles, + build_prover_state, build_verifier_state, padd_with_zero_to_next_multiple_of, padd_with_zero_to_next_power_of_two, }; +use whir_p3::{FoldingFactor, SecurityAssumption, WhirConfig, WhirConfigBuilder, precompute_dft_twiddles}; const NUM_VARIABLES: usize = 25; @@ -33,8 +30,7 @@ pub fn run_whir_recursion_benchmark() { rs_domain_initial_reduction_factor: 3, }; - let mut recursion_config = - WhirConfig::::new(recursion_config_builder.clone(), NUM_VARIABLES); + let mut recursion_config = WhirConfig::::new(recursion_config_builder.clone(), NUM_VARIABLES); // TODO remove overriding this { @@ -48,14 +44,8 @@ pub fn run_whir_recursion_benchmark() { // println!("Whir parameters: {}", params.to_string()); for (i, round) in recursion_config.round_parameters.iter().enumerate() { program_str = program_str - .replace( - &format!("NUM_QUERIES_{i}_PLACEHOLDER"), - &round.num_queries.to_string(), - ) - .replace( - &format!("GRINDING_BITS_{i}_PLACEHOLDER"), - &round.pow_bits.to_string(), - ); + .replace(&format!("NUM_QUERIES_{i}_PLACEHOLDER"), &round.num_queries.to_string()) + .replace(&format!("GRINDING_BITS_{i}_PLACEHOLDER"), &round.pow_bits.to_string()); } program_str = program_str .replace( @@ -70,25 +60,16 @@ pub fn run_whir_recursion_benchmark() { for 
round in 0..=recursion_config.n_rounds() { program_str = program_str.replace( &format!("FOLDING_FACTOR_{round}_PLACEHOLDER"), - &recursion_config_builder - .folding_factor - .at_round(round) - .to_string(), + &recursion_config_builder.folding_factor.at_round(round).to_string(), ); } program_str = program_str.replace( "RS_REDUCTION_FACTOR_0_PLACEHOLDER", - &recursion_config_builder - .rs_domain_initial_reduction_factor - .to_string(), + &recursion_config_builder.rs_domain_initial_reduction_factor.to_string(), ); let mut rng = StdRng::seed_from_u64(0); - let polynomial = MleOwned::Base( - (0..1 << NUM_VARIABLES) - .map(|_| rng.random()) - .collect::>(), - ); + let polynomial = MleOwned::Base((0..1 << NUM_VARIABLES).map(|_| rng.random()).collect::>()); let point = MultilinearPoint::((0..NUM_VARIABLES).map(|_| rng.random()).collect()); @@ -116,12 +97,7 @@ pub fn run_whir_recursion_benchmark() { >::as_basis_coefficients_slice(&eval), )); - recursion_config.prove( - &mut prover_state, - statement.clone(), - witness, - &polynomial.by_ref(), - ); + recursion_config.prove(&mut prover_state, statement.clone(), witness, &polynomial.by_ref()); let first_folding_factor = recursion_config_builder.folding_factor.at_round(0); @@ -166,9 +142,7 @@ pub fn run_whir_recursion_benchmark() { { let mut verifier_state = build_verifier_state(&prover_state); - let parsed_commitment = recursion_config - .parse_commitment::(&mut verifier_state) - .unwrap(); + let parsed_commitment = recursion_config.parse_commitment::(&mut verifier_state).unwrap(); recursion_config .verify(&mut verifier_state, &parsed_commitment, statement) .unwrap(); @@ -180,14 +154,8 @@ pub fn run_whir_recursion_benchmark() { // in practice we will precompute all the possible values // (depending on the number of recursions + the number of xmss signatures) // (or even better: find a linear relation) - let no_vec_runtime_memory = execute_bytecode( - &bytecode, - (&public_input, &[]), - 1 << 20, - false, - (&vec![], &vec![]), - ) - .no_vec_runtime_memory; + let no_vec_runtime_memory = + execute_bytecode(&bytecode, (&public_input, &[]), 1 << 20, false, (&vec![], &vec![])).no_vec_runtime_memory; let time = Instant::now(); diff --git a/crates/rec_aggregation/src/xmss_aggregate.rs b/crates/rec_aggregation/src/xmss_aggregate.rs index 2dd2388a..24b554d3 100644 --- a/crates/rec_aggregation/src/xmss_aggregate.rs +++ b/crates/rec_aggregation/src/xmss_aggregate.rs @@ -7,9 +7,7 @@ use rand::{Rng, SeedableRng, rngs::StdRng}; use std::time::Instant; use tracing::instrument; use whir_p3::precompute_dft_twiddles; -use xmss::{ - PhonyXmssSecretKey, Poseidon16History, Poseidon24History, V, XmssPublicKey, XmssSignature, -}; +use xmss::{PhonyXmssSecretKey, Poseidon16History, Poseidon24History, V, XmssPublicKey, XmssSignature}; const LOG_LIFETIME: usize = 32; @@ -188,10 +186,7 @@ pub fn run_xmss_benchmark(n_xmss: usize) { program_str = program_str .replace("LOG_LIFETIME_PLACE_HOLDER", &LOG_LIFETIME.to_string()) .replace("N_PUBLIC_KEYS_PLACE_HOLDER", &n_xmss.to_string()) - .replace( - "XMSS_SIG_SIZE_PLACE_HOLDER", - &xmss_signature_size_padded.to_string(), - ); + .replace("XMSS_SIG_SIZE_PLACE_HOLDER", &xmss_signature_size_padded.to_string()); let bitfield = vec![true; n_xmss]; // for now we use a dense bitfield @@ -204,8 +199,7 @@ pub fn run_xmss_benchmark(n_xmss: usize) { let mut rng = StdRng::seed_from_u64(i as u64); if bitfield[i] { let signature_index = rng.random_range(0..1 << LOG_LIFETIME); - let xmss_secret_key = - PhonyXmssSecretKey::::random(&mut rng, 
signature_index); + let xmss_secret_key = PhonyXmssSecretKey::::random(&mut rng, signature_index); let signature = xmss_secret_key.sign(&message_hash, &mut rng); (xmss_secret_key.public_key, Some(signature)) } else { @@ -220,16 +214,11 @@ pub fn run_xmss_benchmark(n_xmss: usize) { for bit in bitfield { public_input.push(F::from_bool(bit)); } - let min_public_input_size = - (1 << LOG_SMALLEST_DECOMPOSITION_CHUNK) - NONRESERVED_PROGRAM_INPUT_START; - public_input.extend(F::zero_vec( - min_public_input_size.saturating_sub(public_input.len()), - )); + let min_public_input_size = (1 << LOG_SMALLEST_DECOMPOSITION_CHUNK) - NONRESERVED_PROGRAM_INPUT_START; + public_input.extend(F::zero_vec(min_public_input_size.saturating_sub(public_input.len()))); public_input.insert( 0, - F::from_usize( - (public_input.len() + 8 + NONRESERVED_PROGRAM_INPUT_START).next_power_of_two(), - ), + F::from_usize((public_input.len() + 8 + NONRESERVED_PROGRAM_INPUT_START).next_power_of_two()), ); public_input.splice(1..1, F::zero_vec(7)); @@ -243,12 +232,7 @@ pub fn run_xmss_benchmark(n_xmss: usize) { .iter() .flat_map(|digest| digest.to_vec()), ); - private_input.extend( - signature - .merkle_proof - .iter() - .flat_map(|(_, neighbour)| *neighbour), - ); + private_input.extend(signature.merkle_proof.iter().flat_map(|(_, neighbour)| *neighbour)); private_input.extend( signature .merkle_proof @@ -310,11 +294,7 @@ fn precompute_poseidons( let (poseidon_16_traces, poseidon_24_traces): (Vec<_>, Vec<_>) = xmss_pub_keys .par_iter() .zip(all_signatures.par_iter()) - .map(|(pub_key, sig)| { - pub_key - .verify_with_poseidon_trace(message_hash, sig) - .unwrap() - }) + .map(|(pub_key, sig)| pub_key.verify_with_poseidon_trace(message_hash, sig).unwrap()) .unzip(); ( poseidon_16_traces.into_par_iter().flatten().collect(), diff --git a/crates/sub_protocols/src/commit_extension_from_base.rs b/crates/sub_protocols/src/commit_extension_from_base.rs index cb83a896..66aa6c40 100644 --- a/crates/sub_protocols/src/commit_extension_from_base.rs +++ b/crates/sub_protocols/src/commit_extension_from_base.rs @@ -29,13 +29,9 @@ impl>> ExtensionCommitmentFromBaseProver { pub fn before_commitment(extension_columns: Vec<&[EF]>) -> Self { let mut sub_columns_to_commit = Vec::new(); for extension_column in extension_columns { - sub_columns_to_commit.extend(transpose_slice_to_basis_coefficients::, EF>( - extension_column, - )); - } - Self { - sub_columns_to_commit, + sub_columns_to_commit.extend(transpose_slice_to_basis_coefficients::, EF>(extension_column)); } + Self { sub_columns_to_commit } } pub fn after_commitment( @@ -66,8 +62,7 @@ impl ExtensionCommitmentFromBaseVerifier { verifier_state: &mut FSVerifier>, claim: &MultiEvaluation, ) -> ProofResult>>> { - let sub_evals = - verifier_state.next_extension_scalars_vec(EF::DIMENSION * claim.num_values())?; + let sub_evals = verifier_state.next_extension_scalars_vec(EF::DIMENSION * claim.num_values())?; let mut statements_remaning_to_verify = Vec::new(); for (chunk, claim_value) in sub_evals.chunks_exact(EF::DIMENSION).zip(&claim.values) { diff --git a/crates/sub_protocols/src/generic_packed_lookup.rs b/crates/sub_protocols/src/generic_packed_lookup.rs index 3b2325ef..4e815573 100644 --- a/crates/sub_protocols/src/generic_packed_lookup.rs +++ b/crates/sub_protocols/src/generic_packed_lookup.rs @@ -9,8 +9,7 @@ use utils::{FSProver, assert_eq_many}; use crate::{ColDims, MultilinearChunks, packed_pcs_global_statements_for_prover}; #[derive(Debug)] -pub struct GenericPackedLookupProver<'a, TF: Field, EF: 
ExtensionField + ExtensionField>> -{ +pub struct GenericPackedLookupProver<'a, TF: Field, EF: ExtensionField + ExtensionField>> { // inputs pub(crate) table: VecOrSlice<'a, TF>, pub(crate) index_columns: Vec<&'a [PF]>, @@ -31,8 +30,7 @@ pub struct PackedLookupStatements { pub on_indexes: Vec>>, // contain sparse points (TODO take advantage of it) } -impl<'a, TF: Field, EF: ExtensionField + ExtensionField>> - GenericPackedLookupProver<'a, TF, EF> +impl<'a, TF: Field, EF: ExtensionField + ExtensionField>> GenericPackedLookupProver<'a, TF, EF> where PF: PrimeField64, { @@ -62,17 +60,11 @@ where value_columns.len(), statements.len() ); - value_columns - .iter() - .zip(&statements) - .for_each(|(cols, evals)| { - assert_eq!(cols.len(), evals[0].num_values()); - }); + value_columns.iter().zip(&statements).for_each(|(cols, evals)| { + assert_eq!(cols.len(), evals[0].num_values()); + }); let n_groups = value_columns.len(); - let n_cols_per_group = value_columns - .iter() - .map(|cols| cols.len()) - .collect::>(); + let n_cols_per_group = value_columns.iter().map(|cols| cols.len()).collect::>(); let flatened_value_columns = value_columns .iter() @@ -82,10 +74,7 @@ where let mut all_dims = vec![]; for (i, (default_index, height)) in default_indexes.iter().zip(heights.iter()).enumerate() { for col_index in 0..n_cols_per_group[i] { - all_dims.push(ColDims::padded( - *height, - table_ref[col_index + default_index], - )); + all_dims.push(ColDims::padded(*height, table_ref[col_index + default_index])); } } @@ -129,8 +118,7 @@ where for (alpha_power, statement) in batching_scalar.powers().zip(&packed_statements) { compute_sparse_eval_eq(&statement.point, &mut poly_eq_point, alpha_power); } - let pushforward = - compute_pushforward(&packed_lookup_indexes, table_ref.len(), &poly_eq_point); + let pushforward = compute_pushforward(&packed_lookup_indexes, table_ref.len(), &poly_eq_point); let batched_value: EF = batching_scalar .powers() @@ -181,9 +169,10 @@ where offset += n_cols; assert!(my_chunks.iter().all(|col_chunks| { - col_chunks.iter().zip(my_chunks[0].iter()).all(|(c1, c2)| { - c1.offset_in_original == c2.offset_in_original && c1.n_vars == c2.n_vars - }) + col_chunks + .iter() + .zip(my_chunks[0].iter()) + .all(|(c1, c2)| c1.offset_in_original == c2.offset_in_original && c1.n_vars == c2.n_vars) })); let mut inner_statements = vec![]; let mut inner_evals = vec![]; @@ -191,9 +180,7 @@ where let sparse_point = MultilinearPoint( [ chunk.bits_offset_in_original(), - logup_star_statements.on_indexes.point - [self.chunks.packed_n_vars - chunk.n_vars..] 
- .to_vec(), + logup_star_statements.on_indexes.point[self.chunks.packed_n_vars - chunk.n_vars..].to_vec(), ] .concat(), ); @@ -208,20 +195,15 @@ where for (&inner_eval, chunk) in inner_evals.iter().zip(chunks_for_col) { let missing_vars = self.chunks.packed_n_vars - chunk.n_vars; value_on_packed_indexes += (inner_eval + PF::::from_usize(col_index)) - * MultilinearPoint( - logup_star_statements.on_indexes.point[..missing_vars].to_vec(), - ) - .eq_poly_outside(&MultilinearPoint( - chunk.bits_offset_in_packed(self.chunks.packed_n_vars), - )); + * MultilinearPoint(logup_star_statements.on_indexes.point[..missing_vars].to_vec()) + .eq_poly_outside(&MultilinearPoint( + chunk.bits_offset_in_packed(self.chunks.packed_n_vars), + )); } } } // sanity check - assert_eq!( - value_on_packed_indexes, - logup_star_statements.on_indexes.value - ); + assert_eq!(value_on_packed_indexes, logup_star_statements.on_indexes.value); PackedLookupStatements { on_table: logup_star_statements.on_table, @@ -312,9 +294,10 @@ where // sanity check assert!(my_chunks.iter().all(|col_chunks| { - col_chunks.iter().zip(my_chunks[0].iter()).all(|(c1, c2)| { - c1.offset_in_original == c2.offset_in_original && c1.n_vars == c2.n_vars - }) + col_chunks + .iter() + .zip(my_chunks[0].iter()) + .all(|(c1, c2)| c1.offset_in_original == c2.offset_in_original && c1.n_vars == c2.n_vars) })); let mut inner_statements = vec![]; let inner_evals = verifier_state.next_extension_scalars_vec(my_chunks[0].len())?; @@ -322,9 +305,7 @@ where let sparse_point = MultilinearPoint( [ chunk.bits_offset_in_original(), - logup_star_statements.on_indexes.point - [self.chunks.packed_n_vars - chunk.n_vars..] - .to_vec(), + logup_star_statements.on_indexes.point[self.chunks.packed_n_vars - chunk.n_vars..].to_vec(), ] .concat(), ); @@ -336,12 +317,10 @@ where for (&inner_eval, chunk) in inner_evals.iter().zip(chunks_for_col) { let missing_vars = self.chunks.packed_n_vars - chunk.n_vars; value_on_packed_indexes += (inner_eval + PF::::from_usize(col_index)) - * MultilinearPoint( - logup_star_statements.on_indexes.point[..missing_vars].to_vec(), - ) - .eq_poly_outside(&MultilinearPoint( - chunk.bits_offset_in_packed(self.chunks.packed_n_vars), - )); + * MultilinearPoint(logup_star_statements.on_indexes.point[..missing_vars].to_vec()) + .eq_poly_outside(&MultilinearPoint( + chunk.bits_offset_in_packed(self.chunks.packed_n_vars), + )); } } } @@ -357,9 +336,7 @@ where } } -fn expand_multi_evals( - statements: &[Vec>], -) -> Vec>> { +fn expand_multi_evals(statements: &[Vec>]) -> Vec>> { statements .iter() .flat_map(|multi_evals| { diff --git a/crates/sub_protocols/src/normal_packed_lookup.rs b/crates/sub_protocols/src/normal_packed_lookup.rs index f98b9f39..71770317 100644 --- a/crates/sub_protocols/src/normal_packed_lookup.rs +++ b/crates/sub_protocols/src/normal_packed_lookup.rs @@ -86,9 +86,7 @@ where ); } - for (eval_group, extension_column_split) in - statements_ef.iter().zip(&all_value_columns[n_cols_f..]) - { + for (eval_group, extension_column_split) in statements_ef.iter().zip(&all_value_columns[n_cols_f..]) { let mut multi_evals = vec![]; for eval in eval_group { let sub_evals = extension_column_split @@ -166,11 +164,7 @@ where EF: ExtensionField, { assert_eq_many!(heights_f.len(), default_indexes_f.len(), statements_f.len()); - assert_eq_many!( - heights_ef.len(), - default_indexes_ef.len(), - statements_ef.len() - ); + assert_eq_many!(heights_ef.len(), default_indexes_ef.len(), statements_ef.len()); let n_cols_f = statements_f.len(); let mut 
multi_eval_statements = vec![]; @@ -185,8 +179,8 @@ where for eval_group in &statements_ef { let mut multi_evals = vec![]; for eval in eval_group { - let sub_evals = verifier_state - .next_extension_scalars_vec(>>::DIMENSION)?; + let sub_evals = + verifier_state.next_extension_scalars_vec(>>::DIMENSION)?; if dot_product_with_base(&sub_evals) != eval.value { return Err(ProofError::InvalidProof); } diff --git a/crates/sub_protocols/src/packed_pcs.rs b/crates/sub_protocols/src/packed_pcs.rs index f6968e8c..ca0a2f73 100644 --- a/crates/sub_protocols/src/packed_pcs.rs +++ b/crates/sub_protocols/src/packed_pcs.rs @@ -4,8 +4,7 @@ use multilinear_toolkit::prelude::*; use p3_util::{log2_ceil_usize, log2_strict_usize}; use tracing::instrument; use utils::{ - FSProver, FSVerifier, from_end, multilinear_eval_constants_at_right, to_big_endian_bits, - to_big_endian_in_field, + FSProver, FSVerifier, from_end, multilinear_eval_constants_at_right, to_big_endian_bits, to_big_endian_in_field, }; use whir_p3::*; @@ -32,11 +31,7 @@ impl Chunk { packed_n_vars - self.n_vars, ) } - fn global_point_for_statement( - &self, - point: &[F], - packed_n_vars: usize, - ) -> MultilinearPoint { + fn global_point_for_statement(&self, point: &[F], packed_n_vars: usize) -> MultilinearPoint { MultilinearPoint([self.bits_offset_in_packed(packed_n_vars), point.to_vec()].concat()) } } @@ -107,12 +102,11 @@ fn split_in_chunks( let mut remaining = dims.committed_size; loop { - let mut chunk_size = - if remaining.next_power_of_two() - remaining <= 1 << log_smallest_decomposition_chunk { - log2_ceil_usize(remaining) - } else { - remaining.ilog2() as usize - }; + let mut chunk_size = if remaining.next_power_of_two() - remaining <= 1 << log_smallest_decomposition_chunk { + log2_ceil_usize(remaining) + } else { + remaining.ilog2() as usize + }; if let Some(log_public) = dims.log_public_data_size { chunk_size = chunk_size.min(log_public); } @@ -133,10 +127,7 @@ fn split_in_chunks( } } -pub fn num_packed_vars_for_dims( - dims: &[ColDims], - log_smallest_decomposition_chunk: usize, -) -> usize { +pub fn num_packed_vars_for_dims(dims: &[ColDims], log_smallest_decomposition_chunk: usize) -> usize { MultilinearChunks::compute(dims, log_smallest_decomposition_chunk).packed_n_vars } @@ -203,13 +194,10 @@ impl MultilinearChunks { let end = start + (1 << chunk.n_vars); let original_poly = &polynomials[chunk.original_poly_index]; unsafe { - let slice = std::slice::from_raw_parts_mut( - (packed_polynomial.as_ptr() as *mut F).add(start), - end - start, - ); + let slice = + std::slice::from_raw_parts_mut((packed_polynomial.as_ptr() as *mut F).add(start), end - start); slice.copy_from_slice( - &original_poly[chunk.offset_in_original - ..chunk.offset_in_original + (1 << chunk.n_vars)], + &original_poly[chunk.offset_in_original..chunk.offset_in_original + (1 << chunk.n_vars)], ); } }); @@ -277,11 +265,8 @@ where PF: TwoAdicField, EF: ExtensionField + TwoAdicField + ExtensionField>, { - let (packed_polynomial, _chunks_decomposition) = compute_multilinear_chunks_and_apply::( - polynomials, - dims, - log_smallest_decomposition_chunk, - ); + let (packed_polynomial, _chunks_decomposition) = + compute_multilinear_chunks_and_apply::(polynomials, dims, log_smallest_decomposition_chunk); let packed_n_vars = log2_strict_usize(packed_polynomial.len()); let mle = if TypeId::of::() == TypeId::of::>() { @@ -291,14 +276,10 @@ where std::mem::transmute::, Vec>(packed_polynomial) })) // TODO this is innefficient (this transposes everything...) 
} else { - panic!( - "Unsupported field type for packed PCS: {}", - std::any::type_name::() - ); + panic!("Unsupported field type for packed PCS: {}", std::any::type_name::()); }; - let inner_witness = - WhirConfig::new(whir_config_builder.clone(), packed_n_vars).commit(prover_state, &mle); + let inner_witness = WhirConfig::new(whir_config_builder.clone(), packed_n_vars).commit(prover_state, &mle); MultiCommitmentWitness { inner_witness, packed_polynomial: mle, @@ -306,10 +287,7 @@ where } #[instrument(skip_all)] -pub fn packed_pcs_global_statements_for_prover< - F: Field, - EF: ExtensionField + ExtensionField>, ->( +pub fn packed_pcs_global_statements_for_prover + ExtensionField>>( polynomials: &[&[F]], dims: &[ColDims], log_smallest_decomposition_chunk: usize, @@ -325,11 +303,7 @@ pub fn packed_pcs_global_statements_for_prover< let statements_flattened = statements_per_polynomial .iter() .enumerate() - .flat_map(|(poly_index, poly_statements)| { - poly_statements - .iter() - .map(move |statement| (poly_index, statement)) - }) + .flat_map(|(poly_index, poly_statements)| poly_statements.iter().map(move |statement| (poly_index, statement))) .collect::>(); let sub_packed_statements_and_evals_to_send = statements_flattened @@ -344,11 +318,7 @@ pub fn packed_pcs_global_statements_for_prover< let mut evals_to_send = Vec::new(); if chunks.len() == 1 { assert!(!chunks[0].public_data, "TODO"); - assert_eq!( - chunks[0].n_vars, - statement.point.0.len(), - "poly: {poly_index}" - ); + assert_eq!(chunks[0].n_vars, statement.point.0.len(), "poly: {poly_index}"); assert!( chunks[0] .offset_in_packed @@ -357,8 +327,7 @@ pub fn packed_pcs_global_statements_for_prover< ); sub_packed_statements.push(Evaluation::new( - chunks[0] - .global_point_for_statement(&statement.point, all_chunks.packed_n_vars), + chunks[0].global_point_for_statement(&statement.point, all_chunks.packed_n_vars), statement.value, )); } else { @@ -377,15 +346,12 @@ pub fn packed_pcs_global_statements_for_prover< .map(|chunk| { let missing_vars = statement.point.0.len() - chunk.n_vars; - let offset_in_original_booleans = to_big_endian_bits( - chunk.offset_in_original >> chunk.n_vars, - missing_vars, - ); + let offset_in_original_booleans = + to_big_endian_bits(chunk.offset_in_original >> chunk.n_vars, missing_vars); if !initial_booleans.is_empty() && initial_booleans.len() < offset_in_original_booleans.len() - && initial_booleans - == offset_in_original_booleans[..initial_booleans.len()] + && initial_booleans == offset_in_original_booleans[..initial_booleans.len()] { tracing::warn!("TODO: sparse statement accroos mutiple chunks"); } @@ -400,17 +366,13 @@ pub fn packed_pcs_global_statements_for_prover< } } - let sub_point = - MultilinearPoint(statement.point.0[missing_vars..].to_vec()); - let sub_value = (&pol[chunk.offset_in_original - ..chunk.offset_in_original + (1 << chunk.n_vars)]) + let sub_point = MultilinearPoint(statement.point.0[missing_vars..].to_vec()); + let sub_value = (&pol + [chunk.offset_in_original..chunk.offset_in_original + (1 << chunk.n_vars)]) .evaluate_sparse(&sub_point); // `evaluate_sparse` because sometime (typically due to packed lookup protocol, the original statement is already sparse) ( Some(Evaluation::new( - chunk.global_point_for_statement( - &sub_point, - all_chunks.packed_n_vars, - ), + chunk.global_point_for_statement(&sub_point, all_chunks.packed_n_vars), sub_value, )), sub_value, @@ -427,10 +389,8 @@ pub fn packed_pcs_global_statements_for_prover< }); let initial_missing_vars = statement.point.0.len() 
- chunks[0].n_vars; - let initial_offset_in_original_booleans = to_big_endian_bits( - chunks[0].offset_in_original >> chunks[0].n_vars, - initial_missing_vars, - ); + let initial_offset_in_original_booleans = + to_big_endian_bits(chunks[0].offset_in_original >> chunks[0].n_vars, initial_missing_vars); if initial_booleans.len() < initial_offset_in_original_booleans.len() // if the statement only concern the first chunk, no need to send more data && dim.log_public_data_size.is_none() // if the first value is public, no need to recompute it @@ -446,16 +406,12 @@ pub fn packed_pcs_global_statements_for_prover< let initial_missing_vars = statement.point.0.len() - chunks[0].n_vars; let initial_sub_value = (statement.value - retrieved_eval) / MultilinearPoint(statement.point.0[..initial_missing_vars].to_vec()) - .eq_poly_outside(&MultilinearPoint( - chunks[0].bits_offset_in_original(), - )); - let initial_sub_point = - MultilinearPoint(statement.point.0[initial_missing_vars..].to_vec()); - - let initial_packed_point = chunks[0] - .global_point_for_statement(&initial_sub_point, all_chunks.packed_n_vars); - sub_packed_statements - .insert(0, Evaluation::new(initial_packed_point, initial_sub_value)); + .eq_poly_outside(&MultilinearPoint(chunks[0].bits_offset_in_original())); + let initial_sub_point = MultilinearPoint(statement.point.0[initial_missing_vars..].to_vec()); + + let initial_packed_point = + chunks[0].global_point_for_statement(&initial_sub_point, all_chunks.packed_n_vars); + sub_packed_statements.insert(0, Evaluation::new(initial_packed_point, initial_sub_value)); evals_to_send.insert(0, initial_sub_value); } } @@ -484,14 +440,10 @@ where PF: TwoAdicField, { let all_chunks = MultilinearChunks::compute(dims, log_smallest_decomposition_chunk); - WhirConfig::new(whir_config_builder.clone(), all_chunks.packed_n_vars) - .parse_commitment(verifier_state) + WhirConfig::new(whir_config_builder.clone(), all_chunks.packed_n_vars).parse_commitment(verifier_state) } -pub fn packed_pcs_global_statements_for_verifier< - F: Field, - EF: ExtensionField + ExtensionField>, ->( +pub fn packed_pcs_global_statements_for_verifier + ExtensionField>>( dims: &[ColDims], log_smallest_decomposition_chunk: usize, statements_per_polynomial: &[Vec>], @@ -517,8 +469,7 @@ pub fn packed_pcs_global_statements_for_verifier< .is_multiple_of(1 << chunks[0].n_vars) ); packed_statements.push(Evaluation::new( - chunks[0] - .global_point_for_statement(&statement.point, all_chunks.packed_n_vars), + chunks[0].global_point_for_statement(&statement.point, all_chunks.packed_n_vars), statement.value, )); } else { @@ -530,9 +481,10 @@ pub fn packed_pcs_global_statements_for_verifier< .collect::>(); let mut sub_values = vec![]; if has_public_data { - sub_values.push(public_data[&poly_index].evaluate(&MultilinearPoint( - from_end(&statement.point, chunks[0].n_vars).to_vec(), - ))); + sub_values.push( + public_data[&poly_index] + .evaluate(&MultilinearPoint(from_end(&statement.point, chunks[0].n_vars).to_vec())), + ); } for chunk in chunks { if chunk.public_data { @@ -553,8 +505,7 @@ pub fn packed_pcs_global_statements_for_verifier< } else { let sub_value = verifier_state.next_extension_scalar()?; sub_values.push(sub_value); - let sub_point = - MultilinearPoint(statement.point.0[missing_vars..].to_vec()); + let sub_point = MultilinearPoint(statement.point.0[missing_vars..].to_vec()); packed_statements.push(Evaluation::new( chunk.global_point_for_statement(&sub_point, all_chunks.packed_n_vars), sub_value, @@ -627,11 +578,7 @@ mod tests { 
let mut rng = StdRng::seed_from_u64(0); let log_smallest_decomposition_chunk = 4; - let committed_length_lengths_and_default_value_and_log_public_data: [( - usize, - F, - Option, - ); _] = [ + let committed_length_lengths_and_default_value_and_log_public_data: [(usize, F, Option); _] = [ (916, F::from_usize(8), Some(5)), (854, F::from_usize(0), Some(7)), (854, F::from_usize(1), Some(5)), @@ -678,8 +625,7 @@ mod tests { let n_points = rng.random_range(1..5); let mut statements = Vec::new(); for _ in 0..n_points { - let point = - MultilinearPoint((0..n_vars).map(|_| rng.random()).collect::>()); + let point = MultilinearPoint((0..n_vars).map(|_| rng.random()).collect::>()); let value = poly.evaluate(&point); statements.push(Evaluation { point, value }); } diff --git a/crates/sub_protocols/src/vectorized_packed_lookup.rs b/crates/sub_protocols/src/vectorized_packed_lookup.rs index bfc43a1e..8fbb10ba 100644 --- a/crates/sub_protocols/src/vectorized_packed_lookup.rs +++ b/crates/sub_protocols/src/vectorized_packed_lookup.rs @@ -13,8 +13,7 @@ pub struct VectorizedPackedLookupProver<'a, EF: ExtensionField>, const VE folding_scalars: MultilinearPoint, } -impl<'a, EF: ExtensionField>, const VECTOR_LEN: usize> - VectorizedPackedLookupProver<'a, EF, VECTOR_LEN> +impl<'a, EF: ExtensionField>, const VECTOR_LEN: usize> VectorizedPackedLookupProver<'a, EF, VECTOR_LEN> where PF: PrimeField64, { @@ -34,8 +33,7 @@ where statements: Vec>>, log_smallest_decomposition_chunk: usize, ) -> Self { - let folding_scalars = - MultilinearPoint(prover_state.sample_vec(log2_strict_usize(VECTOR_LEN))); + let folding_scalars = MultilinearPoint(prover_state.sample_vec(log2_strict_usize(VECTOR_LEN))); let folded_table = fold_multilinear_chunks(table, &folding_scalars); let folding_poly_eq = eval_eq(&folding_scalars); @@ -86,10 +84,7 @@ where let mut statements = self .generic .step_2(prover_state, non_zero_memory_size.div_ceil(VECTOR_LEN)); - statements - .on_table - .point - .extend(self.folding_scalars.0.clone()); + statements.on_table.point.extend(self.folding_scalars.0.clone()); statements } } @@ -100,8 +95,7 @@ pub struct VectorizedPackedLookupVerifier>, const VECT folding_scalars: MultilinearPoint, } -impl>, const VECTOR_LEN: usize> - VectorizedPackedLookupVerifier +impl>, const VECTOR_LEN: usize> VectorizedPackedLookupVerifier where PF: PrimeField64, { @@ -114,8 +108,7 @@ where log_smallest_decomposition_chunk: usize, table_initial_values: &[PF], ) -> ProofResult { - let folding_scalars = - MultilinearPoint(verifier_state.sample_vec(log2_strict_usize(VECTOR_LEN))); + let folding_scalars = MultilinearPoint(verifier_state.sample_vec(log2_strict_usize(VECTOR_LEN))); let folded_table_initial_values = fold_multilinear_chunks( &table_initial_values[..(table_initial_values.len() / VECTOR_LEN) * VECTOR_LEN], &folding_scalars, @@ -142,14 +135,10 @@ where verifier_state: &mut FSVerifier>, log_memory_size: usize, ) -> ProofResult> { - let mut statements = self.generic.step_2( - verifier_state, - log_memory_size - log2_strict_usize(VECTOR_LEN), - )?; - statements - .on_table - .point - .extend(self.folding_scalars.0.clone()); + let mut statements = self + .generic + .step_2(verifier_state, log_memory_size - log2_strict_usize(VECTOR_LEN))?; + statements.on_table.point.extend(self.folding_scalars.0.clone()); Ok(statements) } } @@ -163,12 +152,7 @@ fn get_folded_statements( .map(|sub_statements| { sub_statements .iter() - .map(|meval| { - MultiEvaluation::new( - meval.point.clone(), - vec![meval.values.evaluate(folding_scalars)], - 
) - }) + .map(|meval| MultiEvaluation::new(meval.point.clone(), vec![meval.values.evaluate(folding_scalars)])) .collect::>() }) .collect::>() diff --git a/crates/sub_protocols/tests/test_generic_packed_lookup.rs b/crates/sub_protocols/tests/test_generic_packed_lookup.rs index a0726538..ffaa0b22 100644 --- a/crates/sub_protocols/tests/test_generic_packed_lookup.rs +++ b/crates/sub_protocols/tests/test_generic_packed_lookup.rs @@ -12,15 +12,10 @@ const LOG_SMALLEST_DECOMPOSITION_CHUNK: usize = 5; #[test] fn test_generic_packed_lookup() { let non_zero_memory_size: usize = 37412; - let lookups_height_and_cols: Vec<(usize, usize)> = - vec![(4587, 1), (1234, 3), (9411, 1), (7890, 2)]; + let lookups_height_and_cols: Vec<(usize, usize)> = vec![(4587, 1), (1234, 3), (9411, 1), (7890, 2)]; let default_indexes = vec![7, 11, 0, 2]; let n_statements = [1, 5, 2, 1]; - assert_eq_many!( - lookups_height_and_cols.len(), - default_indexes.len(), - n_statements.len() - ); + assert_eq_many!(lookups_height_and_cols.len(), default_indexes.len(), n_statements.len()); let mut rng = StdRng::seed_from_u64(0); let mut memory = F::zero_vec(non_zero_memory_size.next_power_of_two()); @@ -50,10 +45,7 @@ fn test_generic_packed_lookup() { let mut statements = vec![]; for _ in 0..n_statements[i] { let point = MultilinearPoint::::random(&mut rng, log2_ceil_usize(*n_lines)); - let values = columns - .iter() - .map(|col| col.evaluate(&point)) - .collect::>(); + let values = columns.iter().map(|col| col.evaluate(&point)).collect::>(); statements.push(MultiEvaluation::new(point, values)); } all_statements.push(statements); @@ -79,8 +71,7 @@ fn test_generic_packed_lookup() { // phony commitment to pushforward prover_state.hint_extension_scalars(packed_lookup_prover.pushforward_to_commit()); - let remaining_claims_to_prove = - packed_lookup_prover.step_2(&mut prover_state, non_zero_memory_size); + let remaining_claims_to_prove = packed_lookup_prover.step_2(&mut prover_state, non_zero_memory_size); let mut verifier_state = build_verifier_state(&prover_state); diff --git a/crates/sub_protocols/tests/test_normal_packed_lookup.rs b/crates/sub_protocols/tests/test_normal_packed_lookup.rs index 92470a1b..7578fd07 100644 --- a/crates/sub_protocols/tests/test_normal_packed_lookup.rs +++ b/crates/sub_protocols/tests/test_normal_packed_lookup.rs @@ -37,9 +37,8 @@ fn test_normal_packed_lookup() { for (i, height) in cols_heights_ef.iter().enumerate() { let mut indexes = vec![F::from_usize(default_indexes_ef[i]); height.next_power_of_two()]; for idx in indexes.iter_mut().take(*height) { - *idx = F::from_usize(rng.random_range( - 0..non_zero_memory_size - >>::DIMENSION, - )); + *idx = + F::from_usize(rng.random_range(0..non_zero_memory_size - >>::DIMENSION)); } all_indexe_columns_ef.push(indexes); } @@ -67,8 +66,7 @@ fn test_normal_packed_lookup() { for (value_col_f, n_statements) in value_columns_f.iter().zip(&n_statements_f) { let mut statements = vec![]; for _ in 0..*n_statements { - let point = - MultilinearPoint::::random(&mut rng, log2_strict_usize(value_col_f.len())); + let point = MultilinearPoint::::random(&mut rng, log2_strict_usize(value_col_f.len())); let value = value_col_f.evaluate(&point); statements.push(Evaluation::new(point, value)); } @@ -78,8 +76,7 @@ fn test_normal_packed_lookup() { for (value_col_ef, n_statements) in value_columns_ef.iter().zip(&n_statements_ef) { let mut statements = vec![]; for _ in 0..*n_statements { - let point = - MultilinearPoint::::random(&mut rng, log2_strict_usize(value_col_ef.len())); + let 
point = MultilinearPoint::::random(&mut rng, log2_strict_usize(value_col_ef.len())); let value = value_col_ef.evaluate(&point); statements.push(Evaluation::new(point, value)); } @@ -107,8 +104,7 @@ fn test_normal_packed_lookup() { // phony commitment to pushforward prover_state.hint_extension_scalars(packed_lookup_prover.pushforward_to_commit()); - let remaining_claims_to_prove = - packed_lookup_prover.step_2(&mut prover_state, non_zero_memory_size); + let remaining_claims_to_prove = packed_lookup_prover.step_2(&mut prover_state, non_zero_memory_size); let mut verifier_state = build_verifier_state(&prover_state); diff --git a/crates/sub_protocols/tests/test_vectorized_packed_lookup.rs b/crates/sub_protocols/tests/test_vectorized_packed_lookup.rs index 9ac6a648..3029109e 100644 --- a/crates/sub_protocols/tests/test_vectorized_packed_lookup.rs +++ b/crates/sub_protocols/tests/test_vectorized_packed_lookup.rs @@ -19,19 +19,11 @@ fn test_vectorized_packed_lookup() { let cols_heights: Vec = vec![785, 1022, 4751]; let default_indexes = vec![7, 11, 0]; let n_statements = vec![1, 5, 2]; - assert_eq_many!( - cols_heights.len(), - default_indexes.len(), - n_statements.len() - ); + assert_eq_many!(cols_heights.len(), default_indexes.len(), n_statements.len()); let mut rng = StdRng::seed_from_u64(0); let mut memory = F::zero_vec(non_zero_memory_size.next_power_of_two()); - for mem in memory - .iter_mut() - .take(non_zero_memory_size) - .skip(VECTOR_LEN) - { + for mem in memory.iter_mut().take(non_zero_memory_size).skip(VECTOR_LEN) { *mem = rng.random(); } @@ -59,12 +51,8 @@ fn test_vectorized_packed_lookup() { for (value_cols, n_statements) in all_value_columns.iter().zip(&n_statements) { let mut statements = vec![]; for _ in 0..*n_statements { - let point = - MultilinearPoint::::random(&mut rng, log2_strict_usize(value_cols[0].len())); - let values = value_cols - .iter() - .map(|col| col.evaluate(&point)) - .collect::>(); + let point = MultilinearPoint::::random(&mut rng, log2_strict_usize(value_cols[0].len())); + let values = value_cols.iter().map(|col| col.evaluate(&point)).collect::>(); statements.push(MultiEvaluation::new(point, values)); } all_statements.push(statements); @@ -89,8 +77,7 @@ fn test_vectorized_packed_lookup() { // phony commitment to pushforward prover_state.hint_extension_scalars(packed_lookup_prover.pushforward_to_commit()); - let remaining_claims_to_prove = - packed_lookup_prover.step_2(&mut prover_state, non_zero_memory_size); + let remaining_claims_to_prove = packed_lookup_prover.step_2(&mut prover_state, non_zero_memory_size); let mut verifier_state = build_verifier_state(&prover_state); diff --git a/crates/utils/src/misc.rs b/crates/utils/src/misc.rs index f8ba826a..9254bb74 100644 --- a/crates/utils/src/misc.rs +++ b/crates/utils/src/misc.rs @@ -5,10 +5,7 @@ use tracing::instrument; pub fn transmute_slice(slice: &[Before]) -> &[After] { let new_len = std::mem::size_of_val(slice) / std::mem::size_of::(); - assert_eq!( - std::mem::size_of_val(slice), - new_len * std::mem::size_of::() - ); + assert_eq!(std::mem::size_of_val(slice), new_len * std::mem::size_of::()); assert_eq!(slice.as_ptr() as usize % std::mem::align_of::(), 0); unsafe { std::slice::from_raw_parts(slice.as_ptr() as *const After, new_len) } } @@ -18,9 +15,7 @@ pub fn from_end(slice: &[A], n: usize) -> &[A] { &slice[slice.len() - n..] 
} -pub fn transpose_slice_to_basis_coefficients>( - slice: &[EF], -) -> Vec> { +pub fn transpose_slice_to_basis_coefficients>(slice: &[EF]) -> Vec> { let res = vec![F::zero_vec(slice.len()); EF::DIMENSION]; slice.par_iter().enumerate().for_each(|(i, row)| { let coeffs = EF::as_basis_coefficients_slice(row); @@ -42,10 +37,7 @@ pub fn dot_product_with_base>>(slice: &[EF]) -> EF { } pub fn to_big_endian_bits(value: usize, bit_count: usize) -> Vec { - (0..bit_count) - .rev() - .map(|i| (value >> i) & 1 == 1) - .collect() + (0..bit_count).rev().map(|i| (value >> i) & 1 == 1).collect() } pub fn to_big_endian_in_field(value: usize, bit_count: usize) -> Vec { @@ -80,11 +72,7 @@ pub fn powers_const(base: F) -> [F; N] { } #[instrument(skip_all)] -pub fn transpose( - matrix: &[F], - width: usize, - column_extra_capacity: usize, -) -> Vec> { +pub fn transpose(matrix: &[F], width: usize, column_extra_capacity: usize) -> Vec> { assert!((matrix.len().is_multiple_of(width))); let height = matrix.len() / width; let res = vec![ @@ -98,17 +86,14 @@ pub fn transpose( }; width ]; - matrix - .par_chunks_exact(width) - .enumerate() - .for_each(|(row, chunk)| { - for (&value, col) in chunk.iter().zip(&res) { - unsafe { - let ptr = col.as_ptr() as *mut F; - ptr.add(row).write(value); - } + matrix.par_chunks_exact(width).enumerate().for_each(|(row, chunk)| { + for (&value, col) in chunk.iter().zip(&res) { + unsafe { + let ptr = col.as_ptr() as *mut F; + ptr.add(row).write(value); } - }); + } + }); res } @@ -118,9 +103,9 @@ pub fn transposed_par_iter_mut( let len = array[0].len(); let data_ptrs: [AtomicPtr; N] = array.each_mut().map(|v| AtomicPtr::new(v.as_mut_ptr())); - (0..len).into_par_iter().map(move |i| unsafe { - std::array::from_fn(|j| &mut *data_ptrs[j].load(Ordering::Relaxed).add(i)) - }) + (0..len) + .into_par_iter() + .map(move |i| unsafe { std::array::from_fn(|j| &mut *data_ptrs[j].load(Ordering::Relaxed).add(i)) }) } #[derive(Debug)] diff --git a/crates/utils/src/multilinear.rs b/crates/utils/src/multilinear.rs index 2246cb8a..b4d50991 100644 --- a/crates/utils/src/multilinear.rs +++ b/crates/utils/src/multilinear.rs @@ -7,20 +7,13 @@ use multilinear_toolkit::prelude::*; use tracing::instrument; #[instrument(skip_all)] -pub fn multilinears_linear_combination< - F: Field, - EF: ExtensionField, - P: Borrow<[F]> + Send + Sync, ->( +pub fn multilinears_linear_combination, P: Borrow<[F]> + Send + Sync>( pols: &[P], scalars: &[EF], ) -> Vec { assert_eq!(pols.len(), scalars.len()); let n_vars = log2_strict_usize(pols[0].borrow().len()); - assert!( - pols.iter() - .all(|p| log2_strict_usize(p.borrow().len()) == n_vars) - ); + assert!(pols.iter().all(|p| log2_strict_usize(p.borrow().len()) == n_vars)); (0..1 << n_vars) .into_par_iter() .map(|i| dot_product(scalars.iter().copied(), pols.iter().map(|p| p.borrow()[i]))) @@ -32,10 +25,7 @@ pub fn multilinear_eval_constants_at_right(limit: usize, point: &[F]) // multilinear polynomial = [0 0 --- 0][1 1 --- 1] (`limit` times 0, then `2^n_vars - limit` times 1) evaluated at `point` - assert!( - limit <= (1 << n_vars), - "limit {limit} is too large for n_vars {n_vars}" - ); + assert!(limit <= (1 << n_vars), "limit {limit} is too large for n_vars {n_vars}"); if limit == 1 << n_vars { return F::ZERO; @@ -88,10 +78,7 @@ pub fn padd_with_zero_to_next_multiple_of(pol: &[F], multiple: usize) padded } -pub fn evaluate_as_larger_multilinear_pol>( - pol: &[F], - point: &[EF], -) -> EF { +pub fn evaluate_as_larger_multilinear_pol>(pol: &[F], point: &[EF]) -> EF { // [[-pol-] 0 0 
0 0 ... 0 0 0 0 0] evaluated at point let pol_n_vars = log2_strict_usize(pol.len()); assert!(point.len() >= pol_n_vars); @@ -103,10 +90,7 @@ pub fn evaluate_as_larger_multilinear_pol>( * pol.evaluate(&MultilinearPoint(from_end(point, pol_n_vars).to_vec())) } -pub fn evaluate_as_smaller_multilinear_pol>( - pol: &[F], - point: &[EF], -) -> EF { +pub fn evaluate_as_smaller_multilinear_pol>(pol: &[F], point: &[EF]) -> EF { let pol_n_vars = log2_strict_usize(pol.len()); assert!(point.len() <= pol_n_vars); (&pol[..1 << point.len()]).evaluate(&MultilinearPoint(point.to_vec())) @@ -140,9 +124,7 @@ mod tests { let n_point_vars = 7; let mut rng = StdRng::seed_from_u64(0); let mut pol = F::zero_vec(1 << n_point_vars); - pol.iter_mut() - .take(1 << n_vars) - .for_each(|coeff| *coeff = rng.random()); + pol.iter_mut().take(1 << n_vars).for_each(|coeff| *coeff = rng.random()); let point = (0..n_point_vars).map(|_| rng.random()).collect::>(); assert_eq!( evaluate_as_larger_multilinear_pol(&pol[..1 << n_vars], &point), diff --git a/crates/utils/src/wrappers.rs b/crates/utils/src/wrappers.rs index 154e2c69..50cd6138 100644 --- a/crates/utils/src/wrappers.rs +++ b/crates/utils/src/wrappers.rs @@ -14,8 +14,7 @@ pub fn build_challenger() -> MyChallenger { MyChallenger::new(get_poseidon16().clone()) } -pub fn build_prover_state>() --> ProverState { +pub fn build_prover_state>() -> ProverState { ProverState::new(build_challenger()) } diff --git a/crates/xmss/src/lib.rs b/crates/xmss/src/lib.rs index 0628e472..08305a1f 100644 --- a/crates/xmss/src/lib.rs +++ b/crates/xmss/src/lib.rs @@ -28,11 +28,7 @@ fn poseidon16_compress(a: &Digest, b: &Digest) -> Digest { .unwrap() } -fn poseidon16_compress_with_trace( - a: &Digest, - b: &Digest, - poseidon_16_trace: &mut Vec<([F; 16], [F; 16])>, -) -> Digest { +fn poseidon16_compress_with_trace(a: &Digest, b: &Digest, poseidon_16_trace: &mut Vec<([F; 16], [F; 16])>) -> Digest { let input: [F; 16] = [*a, *b].concat().try_into().unwrap(); let output = poseidon16_permute(input); poseidon_16_trace.push((input, output)); diff --git a/crates/xmss/src/wots.rs b/crates/xmss/src/wots.rs index 2b5ad6bf..90931e4c 100644 --- a/crates/xmss/src/wots.rs +++ b/crates/xmss/src/wots.rs @@ -39,20 +39,14 @@ impl WotsSecretKey { pub fn sign(&self, message_hash: &Digest, rng: &mut impl Rng) -> WotsSignature { let (randomness, encoding) = find_randomness_for_wots_encoding(message_hash, rng); WotsSignature { - chain_tips: std::array::from_fn(|i| { - iterate_hash(&self.pre_images[i], encoding[i] as usize) - }), + chain_tips: std::array::from_fn(|i| iterate_hash(&self.pre_images[i], encoding[i] as usize)), randomness, } } } impl WotsSignature { - pub fn recover_public_key( - &self, - message_hash: &Digest, - signature: &Self, - ) -> Option { + pub fn recover_public_key(&self, message_hash: &Digest, signature: &Self) -> Option { self.recover_public_key_with_poseidon_trace(message_hash, signature, &mut Vec::new()) } @@ -62,17 +56,9 @@ impl WotsSignature { signature: &Self, poseidon_16_trace: &mut Vec<([F; 16], [F; 16])>, ) -> Option { - let encoding = wots_encode_with_poseidon_trace( - message_hash, - &signature.randomness, - poseidon_16_trace, - )?; + let encoding = wots_encode_with_poseidon_trace(message_hash, &signature.randomness, poseidon_16_trace)?; Some(WotsPublicKey(std::array::from_fn(|i| { - iterate_hash_with_poseidon_trace( - &self.chain_tips[i], - W - 1 - encoding[i] as usize, - poseidon_16_trace, - ) + iterate_hash_with_poseidon_trace(&self.chain_tips[i], W - 1 - encoding[i] as usize, 
poseidon_16_trace) }))) } } @@ -84,11 +70,9 @@ impl WotsPublicKey { pub fn hash_with_poseidon_trace(&self, poseidon_24_trace: &mut Poseidon24History) -> Digest { assert!(V.is_multiple_of(2), "V must be even for hashing pairs."); - self.0 - .chunks_exact(2) - .fold(Digest::default(), |digest, chunk| { - poseidon24_compress_with_trace(&chunk[0], &chunk[1], &digest, poseidon_24_trace) - }) + self.0.chunks_exact(2).fold(Digest::default(), |digest, chunk| { + poseidon24_compress_with_trace(&chunk[0], &chunk[1], &digest, poseidon_24_trace) + }) } } @@ -106,10 +90,7 @@ pub fn iterate_hash_with_poseidon_trace( }) } -pub fn find_randomness_for_wots_encoding( - message: &Digest, - rng: &mut impl Rng, -) -> (Digest, [u8; V]) { +pub fn find_randomness_for_wots_encoding(message: &Digest, rng: &mut impl Rng) -> (Digest, [u8; V]) { loop { let randomness = rng.random(); if let Some(encoding) = wots_encode(message, &randomness) { diff --git a/crates/xmss/src/xmss.rs b/crates/xmss/src/xmss.rs index ae14f0e9..8556d609 100644 --- a/crates/xmss/src/xmss.rs +++ b/crates/xmss/src/xmss.rs @@ -19,9 +19,7 @@ pub struct XmssPublicKey(pub Digest); impl XmssSecretKey { pub fn random(rng: &mut impl Rng) -> Self { - let wots_secret_keys: Vec<_> = (0..1 << LOG_LIFETIME) - .map(|_| WotsSecretKey::random(rng)) - .collect(); + let wots_secret_keys: Vec<_> = (0..1 << LOG_LIFETIME).map(|_| WotsSecretKey::random(rng)).collect(); let leaves = wots_secret_keys .iter() @@ -43,10 +41,7 @@ impl XmssSecretKey { } pub fn sign(&self, message_hash: &Digest, index: usize, rng: &mut impl Rng) -> XmssSignature { - assert!( - index < (1 << LOG_LIFETIME), - "Index out of bounds for XMSS signature" - ); + assert!(index < (1 << LOG_LIFETIME), "Index out of bounds for XMSS signature"); let wots_signature = self.wots_secret_keys[index].sign(message_hash, rng); let merkle_proof = (0..LOG_LIFETIME) .scan(index, |current_idx, level| { @@ -71,8 +66,7 @@ impl XmssSecretKey { impl XmssPublicKey { pub fn verify(&self, message_hash: &Digest, signature: &XmssSignature) -> Option<()> { - self.verify_with_poseidon_trace(message_hash, signature) - .map(|_| ()) + self.verify_with_poseidon_trace(message_hash, signature).map(|_| ()) } pub fn verify_with_poseidon_trace( @@ -82,13 +76,11 @@ impl XmssPublicKey { ) -> Option<(Poseidon16History, Poseidon24History)> { let mut poseidon_16_trace = Vec::new(); let mut poseidon_24_trace = Vec::new(); - let wots_public_key = signature - .wots_signature - .recover_public_key_with_poseidon_trace( - message_hash, - &signature.wots_signature, - &mut poseidon_16_trace, - )?; + let wots_public_key = signature.wots_signature.recover_public_key_with_poseidon_trace( + message_hash, + &signature.wots_signature, + &mut poseidon_16_trace, + )?; // merkle root verification let mut current_hash = wots_public_key.hash_with_poseidon_trace(&mut poseidon_24_trace); if signature.merkle_proof.len() != LOG_LIFETIME { @@ -96,14 +88,9 @@ impl XmssPublicKey { } for (is_left, neighbour) in &signature.merkle_proof { if *is_left { - current_hash = - poseidon16_compress_with_trace(¤t_hash, neighbour, &mut poseidon_16_trace) + current_hash = poseidon16_compress_with_trace(¤t_hash, neighbour, &mut poseidon_16_trace) } else { - current_hash = poseidon16_compress_with_trace( - neighbour, - ¤t_hash, - &mut poseidon_16_trace, - ); + current_hash = poseidon16_compress_with_trace(neighbour, ¤t_hash, &mut poseidon_16_trace); } } if current_hash == self.0 { diff --git a/crates/xmss/tests/test_xmss.rs b/crates/xmss/tests/test_xmss.rs index 96e77bea..55c2d06e 
100644 --- a/crates/xmss/tests/test_xmss.rs +++ b/crates/xmss/tests/test_xmss.rs @@ -13,9 +13,7 @@ fn test_wots_signature() { let message_hash: [F; 8] = rng.random(); let signature = sk.sign(&message_hash, &mut rng); assert_eq!( - signature - .recover_public_key(&message_hash, &signature,) - .unwrap(), + signature.recover_public_key(&message_hash, &signature,).unwrap(), *sk.public_key() ); } diff --git a/rustfmt.toml b/rustfmt.toml index d100efd0..866c7561 100644 --- a/rustfmt.toml +++ b/rustfmt.toml @@ -1 +1 @@ -max_width = 100 \ No newline at end of file +max_width = 120 \ No newline at end of file diff --git a/src/main.rs b/src/main.rs index 4b87d4db..644077e2 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,8 +1,6 @@ use clap::Parser; use poseidon_circuit::tests::run_poseidon_benchmark; -use rec_aggregation::{ - recursion::run_whir_recursion_benchmark, xmss_aggregate::run_xmss_benchmark, -}; +use rec_aggregation::{recursion::run_whir_recursion_benchmark, xmss_aggregate::run_xmss_benchmark}; #[derive(Parser)] enum Cli { @@ -24,17 +22,13 @@ fn main() { let cli = Cli::parse(); match cli { - Cli::Xmss { - n_signatures: count, - } => { + Cli::Xmss { n_signatures: count } => { run_xmss_benchmark(count); } Cli::Recursion => { run_whir_recursion_benchmark(); } - Cli::Poseidon { - log_n_perms: log_count, - } => { + Cli::Poseidon { log_n_perms: log_count } => { run_poseidon_benchmark::<16, 16, 3>(log_count, false); } }