Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 11 additions & 35 deletions crates/air/src/prove.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,7 @@ where

*extra_data.alpha_powers_mut() = alpha
.powers()
.take(
air.n_constraints()
+ virtual_column_statements
.as_ref()
.map_or(0, |s| s.values.len()),
)
.take(air.n_constraints() + virtual_column_statements.as_ref().map_or(0, |s| s.values.len()))
.collect();

let n_sc_rounds = log_n_rows + 1 - univariate_skips;
Expand All @@ -64,9 +59,7 @@ where
.down_column_indexes_f()
.par_iter()
.zip_eq(last_row_shifted_f)
.map(|(&col_index, &final_value)| {
column_shifted(columns_f[col_index], final_value.as_base().unwrap())
})
.map(|(&col_index, &final_value)| column_shifted(columns_f[col_index], final_value.as_base().unwrap()))
.collect::<Vec<_>>();
let shifted_rows_ef = air
.down_column_indexes_ef()
Expand All @@ -81,10 +74,8 @@ where
let mut columns_up_down_ef = columns_ef.to_vec(); // orginal columns, followed by shifted ones
columns_up_down_ef.extend(shifted_rows_ef.iter().map(Vec::as_slice));

let columns_up_down_group_f: MleGroupRef<'_, EF> =
MleGroupRef::<'_, EF>::Base(columns_up_down_f);
let columns_up_down_group_ef: MleGroupRef<'_, EF> =
MleGroupRef::<'_, EF>::Extension(columns_up_down_ef);
let columns_up_down_group_f: MleGroupRef<'_, EF> = MleGroupRef::<'_, EF>::Base(columns_up_down_f);
let columns_up_down_group_ef: MleGroupRef<'_, EF> = MleGroupRef::<'_, EF>::Extension(columns_up_down_ef);

let columns_up_down_group_f_packed = columns_up_down_group_f.pack();
let columns_up_down_group_ef_packed = columns_up_down_group_ef.pack();
Expand Down Expand Up @@ -130,10 +121,8 @@ fn open_columns<EF: ExtensionField<PF<EF>>>(
columns_ef: &[&[EF]],
outer_sumcheck_challenge: &[EF],
) -> (MultilinearPoint<EF>, Vec<EF>, Vec<EF>) {
let n_up_down_columns = columns_f.len()
+ columns_ef.len()
+ columns_with_shift_f.len()
+ columns_with_shift_ef.len();
let n_up_down_columns =
columns_f.len() + columns_ef.len() + columns_with_shift_f.len() + columns_with_shift_ef.len();
let batching_scalars = prover_state.sample_vec(log2_ceil_usize(n_up_down_columns));

let eval_eq_batching_scalars = eval_eq(&batching_scalars)[..n_up_down_columns].to_vec();
Expand All @@ -153,31 +142,23 @@ fn open_columns<EF: ExtensionField<PF<EF>>>(
});
}

let columns_shifted_f = &columns_with_shift_f
.iter()
.map(|&i| columns_f[i])
.collect::<Vec<_>>();
let columns_shifted_ef = &columns_with_shift_ef
.iter()
.map(|&i| columns_ef[i])
.collect::<Vec<_>>();
let columns_shifted_f = &columns_with_shift_f.iter().map(|&i| columns_f[i]).collect::<Vec<_>>();
let columns_shifted_ef = &columns_with_shift_ef.iter().map(|&i| columns_ef[i]).collect::<Vec<_>>();

let mut batched_column_down = if columns_shifted_f.is_empty() {
tracing::warn!("TODO optimize open_columns when no shifted F columns");
vec![EF::ZERO; batched_column_up.len()]
} else {
multilinears_linear_combination(
columns_shifted_f,
&eval_eq_batching_scalars[columns_f.len() + columns_ef.len()..]
[..columns_shifted_f.len()],
&eval_eq_batching_scalars[columns_f.len() + columns_ef.len()..][..columns_shifted_f.len()],
)
};

if !columns_shifted_ef.is_empty() {
let batched_column_down_ef = multilinears_linear_combination(
columns_shifted_ef,
&eval_eq_batching_scalars
[columns_f.len() + columns_ef.len() + columns_shifted_f.len()..],
&eval_eq_batching_scalars[columns_f.len() + columns_ef.len() + columns_shifted_f.len()..],
);
batched_column_down
.par_iter_mut()
Expand Down Expand Up @@ -277,12 +258,7 @@ impl<EF: ExtensionField<PF<EF>>> SumcheckComputation<EF> for MySumcheck {
point[0] * point[1] + point[2] * point[3]
}
#[inline(always)]
fn eval_packed_base(
&self,
_: &[PFPacking<EF>],
_: &[EFPacking<EF>],
_: &Self::ExtraData,
) -> EFPacking<EF> {
fn eval_packed_base(&self, _: &[PFPacking<EF>], _: &[EFPacking<EF>], _: &Self::ExtraData) -> EFPacking<EF> {
unreachable!()
}
#[inline(always)]
Expand Down
9 changes: 3 additions & 6 deletions crates/air/src/uni_skip_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ pub fn matrix_next_mle_folded<F: ExtensionField<PF<F>>>(outer_challenges: &[F])
let n = outer_challenges.len();
let mut res = F::zero_vec(1 << n);
for k in 0..n {
let outer_challenges_prod = (F::ONE - outer_challenges[n - k - 1])
* outer_challenges[n - k..].iter().copied().product::<F>();
let outer_challenges_prod =
(F::ONE - outer_challenges[n - k - 1]) * outer_challenges[n - k..].iter().copied().product::<F>();
let mut eq_mle = eval_eq_scaled(&outer_challenges[0..n - k - 1], outer_challenges_prod);
for (mut i, v) in eq_mle.iter_mut().enumerate() {
i <<= k + 1;
Expand Down Expand Up @@ -36,10 +36,7 @@ mod tests {
for y in 0..1 << n_vars {
let y_bools = to_big_endian_in_field::<F>(y, n_vars);
let expected = F::from_bool(x + 1 == y);
assert_eq!(
matrix.evaluate(&MultilinearPoint(y_bools.clone())),
expected
);
assert_eq!(matrix.evaluate(&MultilinearPoint(y_bools.clone())), expected);
assert_eq!(next_mle(&[x_bools.clone(), y_bools].concat()), expected);
}
}
Expand Down
14 changes: 3 additions & 11 deletions crates/air/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,7 @@ use multilinear_toolkit::prelude::*;
/// Field element: 1 if y = x + 1, 0 otherwise.
pub(crate) fn next_mle<F: Field>(point: &[F]) -> F {
// Check that the point length is even: we split into x and y of equal length.
assert_eq!(
point.len() % 2,
0,
"Input point must have an even number of variables."
);
assert_eq!(point.len() % 2, 0, "Input point must have an even number of variables.");
let n = point.len() / 2;

// Split point into x (first n) and y (last n).
Expand All @@ -56,9 +52,7 @@ pub(crate) fn next_mle<F: Field>(point: &[F]) -> F {
//
// Indices are reversed because bits are big-endian.
let eq_high_bits = (k + 1..n)
.map(|i| {
x[n - 1 - i] * y[n - 1 - i] + (F::ONE - x[n - 1 - i]) * (F::ONE - y[n - 1 - i])
})
.map(|i| x[n - 1 - i] * y[n - 1 - i] + (F::ONE - x[n - 1 - i]) * (F::ONE - y[n - 1 - i]))
.product::<F>();

// Term 2: carry bit at position k
Expand All @@ -71,9 +65,7 @@ pub(crate) fn next_mle<F: Field>(point: &[F]) -> F {
//
// For i < k, enforce x_i = 1 and y_i = 0.
// Condition: x_i * (1 - y_i).
let low_bits_are_one_zero = (0..k)
.map(|i| x[n - 1 - i] * (F::ONE - y[n - 1 - i]))
.product::<F>();
let low_bits_are_one_zero = (0..k).map(|i| x[n - 1 - i] * (F::ONE - y[n - 1 - i])).product::<F>();

// Multiply the three terms for this k, representing one "carry pattern".
eq_high_bits * carry_bit * low_bits_are_one_zero
Expand Down
66 changes: 19 additions & 47 deletions crates/air/src/verify.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,7 @@ where

*extra_data.alpha_powers_mut() = alpha
.powers()
.take(
air.n_constraints()
+ virtual_column_statements
.as_ref()
.map_or(0, |s| s.values.len()),
)
.take(air.n_constraints() + virtual_column_statements.as_ref().map_or(0, |s| s.values.len()))
.collect();

let n_sc_rounds = log_n_rows + 1 - univariate_skips;
Expand All @@ -38,12 +33,8 @@ where
.unwrap_or_else(|| verifier_state.sample_vec(n_sc_rounds));
assert_eq!(zerocheck_challenges.len(), n_sc_rounds);

let (sc_sum, outer_statement) = sumcheck_verify_with_univariate_skip::<EF>(
verifier_state,
air.degree() + 1,
log_n_rows,
univariate_skips,
)?;
let (sc_sum, outer_statement) =
sumcheck_verify_with_univariate_skip::<EF>(verifier_state, air.degree() + 1, log_n_rows, univariate_skips)?;
if sc_sum
!= virtual_column_statements
.as_ref()
Expand All @@ -59,9 +50,7 @@ where
.collect::<Vec<_>>();

let mut inner_sums = verifier_state.next_extension_scalars_vec(
air.n_columns_air()
+ air.down_column_indexes_f().len()
+ air.down_column_indexes_ef().len(),
air.n_columns_air() + air.down_column_indexes_f().len() + air.down_column_indexes_ef().len(),
)?;

let n_columns_down_f = air.down_column_indexes_f().len();
Expand All @@ -72,11 +61,7 @@ where
&extra_data,
);

if eq_poly_with_skip(
&zerocheck_challenges,
&outer_statement.point,
univariate_skips,
) * constraint_evals
if eq_poly_with_skip(&zerocheck_challenges, &outer_statement.point, univariate_skips) * constraint_evals
!= outer_statement.value
{
return Err(ProofError::InvalidProof);
Expand Down Expand Up @@ -128,14 +113,8 @@ fn open_columns<EF: ExtensionField<PF<EF>>>(
evals_up_and_down.len()
);
let last_row_selector = outer_selector_evals[(1 << univariate_skips) - 1]
* outer_sumcheck_challenge
.point
.iter()
.copied()
.product::<EF>();
for (&last_row_value, down_col_eval) in
last_row_f.iter().zip(&mut evals_up_and_down[n_columns..])
{
* outer_sumcheck_challenge.point.iter().copied().product::<EF>();
for (&last_row_value, down_col_eval) in last_row_f.iter().zip(&mut evals_up_and_down[n_columns..]) {
*down_col_eval -= last_row_selector * last_row_value;
}
for (&last_row_value, down_col_eval) in last_row_ef
Expand All @@ -145,23 +124,20 @@ fn open_columns<EF: ExtensionField<PF<EF>>>(
*down_col_eval -= last_row_selector * last_row_value;
}

let batching_scalars = verifier_state.sample_vec(log2_ceil_usize(
n_columns + last_row_f.len() + last_row_ef.len(),
));
let batching_scalars = verifier_state.sample_vec(log2_ceil_usize(n_columns + last_row_f.len() + last_row_ef.len()));

let eval_eq_batching_scalars = eval_eq(&batching_scalars);
let batching_scalars_up = &eval_eq_batching_scalars[..n_columns];
let batching_scalars_down = &eval_eq_batching_scalars[n_columns..];

let sub_evals = verifier_state.next_extension_scalars_vec(1 << univariate_skips)?;

if dot_product::<EF, _, _>(
sub_evals.iter().copied(),
outer_selector_evals.iter().copied(),
) != dot_product::<EF, _, _>(
evals_up_and_down.iter().copied(),
eval_eq_batching_scalars.iter().copied(),
) {
if dot_product::<EF, _, _>(sub_evals.iter().copied(), outer_selector_evals.iter().copied())
!= dot_product::<EF, _, _>(
evals_up_and_down.iter().copied(),
eval_eq_batching_scalars.iter().copied(),
)
{
return Err(ProofError::InvalidProof);
}

Expand All @@ -173,9 +149,8 @@ fn open_columns<EF: ExtensionField<PF<EF>>>(
return Err(ProofError::InvalidProof);
}

let matrix_up_sc_eval =
MultilinearPoint([epsilons.0.clone(), outer_sumcheck_challenge.point.0.clone()].concat())
.eq_poly_outside(&inner_sumcheck_stement.point);
let matrix_up_sc_eval = MultilinearPoint([epsilons.0.clone(), outer_sumcheck_challenge.point.0.clone()].concat())
.eq_poly_outside(&inner_sumcheck_stement.point);
let matrix_down_sc_eval = next_mle(
&[
epsilons.0,
Expand All @@ -185,10 +160,8 @@ fn open_columns<EF: ExtensionField<PF<EF>>>(
.concat(),
);

let evaluations_remaining_to_verify_f =
verifier_state.next_extension_scalars_vec(n_columns_f)?;
let evaluations_remaining_to_verify_ef =
verifier_state.next_extension_scalars_vec(n_columns_ef)?;
let evaluations_remaining_to_verify_f = verifier_state.next_extension_scalars_vec(n_columns_f)?;
let evaluations_remaining_to_verify_ef = verifier_state.next_extension_scalars_vec(n_columns_ef)?;
let evaluations_remaining_to_verify = [
evaluations_remaining_to_verify_f.clone(),
evaluations_remaining_to_verify_ef.clone(),
Expand All @@ -211,8 +184,7 @@ fn open_columns<EF: ExtensionField<PF<EF>>>(
.sum::<EF>();

if inner_sumcheck_stement.value
!= matrix_up_sc_eval * batched_col_up_sc_eval
+ matrix_down_sc_eval * batched_col_down_sc_eval
!= matrix_up_sc_eval * batched_col_up_sc_eval + matrix_down_sc_eval * batched_col_down_sc_eval
{
return Err(ProofError::InvalidProof);
}
Expand Down
Loading