Skip to content

Commit d39e313

Browse files
authored
Batch public input check with ring-switch oracle relation (#1441)
## Summary - Remove public input checking from the shift reduction protocol (no longer samples inout_eval_point or batch_coeff during shift reduction) - Re-introduce public input verification batched with the ring-switch oracle relation during PCS opening, modifying both the transparent polynomial and the claim - Extract `compute_batched_transparent` helper on the prover side and `eval_pubcheck_eq` helper on the verifier side ## Test plan - `cargo test -p binius-prover --test shift` - `cargo test -p binius-examples --tests` - `cargo clippy -p binius-verifier -p binius-prover --tests --benches -- -D warnings`
1 parent 6f8039b commit d39e313

File tree

7 files changed

+134
-246
lines changed

7 files changed

+134
-246
lines changed

crates/prover/benches/shift_reduction.rs

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -91,8 +91,6 @@ fn bench_prove_and_verify(c: &mut Criterion) {
9191
let intmul_evals = [F::random(&mut rng); 4];
9292
let key_collection = build_key_collection(&cs);
9393

94-
let inout_n_vars = strict_log_2(cs.value_vec_layout.offset_witness).unwrap();
95-
9694
let mut group = c.benchmark_group(format!(
9795
"shift_reduction_log2_{log_message_len_bytes}_bytes_{message_len_bytes}"
9896
));
@@ -114,7 +112,6 @@ fn bench_prove_and_verify(c: &mut Criterion) {
114112
let mut prover_transcript = ProverTranscript::<StdChallenger>::default();
115113

116114
prove::<F, P, _>(
117-
inout_n_vars,
118115
&key_collection,
119116
value_vec.combined_witness(),
120117
prover_bitand_data,
@@ -140,7 +137,6 @@ fn bench_prove_and_verify(c: &mut Criterion) {
140137
let mut prover_transcript = ProverTranscript::<StdChallenger>::default();
141138

142139
prove::<F, P, _>(
143-
inout_n_vars,
144140
&key_collection,
145141
value_vec.combined_witness(),
146142
prover_bitand_data,
@@ -166,14 +162,8 @@ fn bench_prove_and_verify(c: &mut Criterion) {
166162
intmul_evals,
167163
);
168164

169-
verify(
170-
&cs,
171-
value_vec.public(),
172-
&verifier_bitand_data,
173-
&verifier_intmul_data,
174-
&mut verifier_transcript,
175-
)
176-
.unwrap();
165+
verify(&cs, &verifier_bitand_data, &verifier_intmul_data, &mut verifier_transcript)
166+
.unwrap();
177167
})
178168
});
179169
}

crates/prover/src/protocols/shift/phase_2.rs

Lines changed: 4 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,9 @@
11
// Copyright 2025 Irreducible Inc.
22

3-
use std::iter;
4-
53
use binius_core::word::Word;
64
use binius_field::{AESTowerField8b, BinaryField, Field, PackedField};
75
use binius_ip_prover::channel::IPProverChannel;
8-
use binius_math::{
9-
FieldBuffer,
10-
multilinear::{eq::eq_ind_partial_eval, evaluate::evaluate},
11-
};
6+
use binius_math::{FieldBuffer, multilinear::eq::eq_ind_partial_eval};
127
use binius_verifier::{config::LOG_WORD_SIZE_BITS, protocols::sumcheck::SumcheckOutput};
138
use tracing::instrument;
149

@@ -50,7 +45,6 @@ use crate::{
5045
/// or an error if the protocol fails.
5146
#[instrument(skip_all, name = "prove_phase_2")]
5247
pub fn prove_phase_2<F, P: PackedField<Scalar = F>, Channel>(
53-
inout_n_vars: usize,
5448
key_collection: &KeyCollection,
5549
words: &[Word],
5650
bitand_data: &PreparedOperatorData<F>,
@@ -78,78 +72,15 @@ where
7872
let monster_multilinear =
7973
build_monster_multilinear(key_collection, bitand_data, intmul_data, &r_j, &r_s)?;
8074

81-
run_sumcheck(inout_n_vars, r_j_witness, monster_multilinear, r_j, gamma, channel)
82-
}
83-
84-
/// Evaluates the r_j_witness multilinear at the inout evaluation point.
85-
///
86-
/// This helper function extracts the public input chunk from the r_j_witness
87-
/// and evaluates it at the given inout_eval_point, corresponding to evaluating
88-
/// the witness at `(r_j, inout_eval_point || 0)`.
89-
///
90-
/// # Parameters
91-
/// - `r_j_witness`: The witness folded at challenges `r_j`
92-
/// - `inout_eval_point`: Challenge point for the public input/output variables
93-
///
94-
/// # Returns
95-
/// The evaluation of the public chunk at `inout_eval_point`.
96-
fn evaluate_public_at_inout<F: Field, P: PackedField<Scalar = F>>(
97-
r_j_witness: &FieldBuffer<P>,
98-
inout_eval_point: &[F],
99-
) -> F {
100-
let public_chunk = r_j_witness.chunk(inout_eval_point.len(), 0);
101-
evaluate(&public_chunk, inout_eval_point)
102-
}
103-
104-
/// Computes the batched polynomial m = monster + χ eq(inout_eval_point || 0),
105-
/// where monster is the monster multilinear and χ is `batch_coeff`.
106-
///
107-
/// Since eq(inout_eval_point || 0) takes the value 0 on all hypercube vertices
108-
/// except the first 2^inout_eval_point.len(), only that first chunk needs to be updated.
109-
///
110-
/// # Parameters
111-
/// - `monster_multilinear`: The monster multilinear polynomial
112-
/// - `inout_eval_point`: Challenge point for the public input/output variables
113-
/// - `batch_coeff`: Random batching coefficient
114-
///
115-
/// # Returns
116-
/// The combined polynomial `monster_multilinear + batch_coeff * eq(inout_eval_point || 0)`.
117-
fn compute_monster_with_inout<F: Field, P: PackedField<Scalar = F>>(
118-
monster_multilinear: FieldBuffer<P>,
119-
inout_eval_point: &[F],
120-
batch_coeff: F,
121-
) -> FieldBuffer<P> {
122-
let mut combined_monster = monster_multilinear;
123-
124-
{
125-
let mut public_chunk = combined_monster.chunk_mut(inout_eval_point.len(), 0);
126-
let mut public_chunk = public_chunk.get();
127-
128-
let eq_inout = eq_ind_partial_eval::<P>(inout_eval_point);
129-
130-
let batch_coeff_packed = P::broadcast(batch_coeff);
131-
for (dst, src) in iter::zip(public_chunk.as_mut(), eq_inout.as_ref()) {
132-
*dst += *src * batch_coeff_packed;
133-
}
134-
}
135-
136-
combined_monster
75+
run_sumcheck(r_j_witness, monster_multilinear, r_j, gamma, channel)
13776
}
13877

13978
/// Executes the bivariate product sumcheck for the witness and monster multilinear relationship.
14079
///
14180
/// This helper function runs the sumcheck protocol to prove the relationship between
142-
/// the witness and monster multilinear, batched with the public input check.
143-
///
144-
/// # Protocol Details
145-
/// - Samples challenge point for public inputs and computes public evaluation
146-
/// - Samples batching coefficient and merges public check into monster multilinear
147-
/// - Uses single `BivariateProductSumcheckProver` for the batched relationship
148-
/// - Extracts witness evaluation from the sumcheck output
149-
/// - In debug mode, verifies the witness evaluation against expected value
81+
/// the witness and monster multilinear.
15082
///
15183
/// # Parameters
152-
/// - `inout_n_vars`: Number of variables for the public input/output point
15384
/// - `r_j_witness`: The witness folded at challenges `r_j`
15485
/// - `monster_multilinear`: The monster multilinear polynomial constructed from constraints
15586
/// - `r_j`: Challenge vector from phase 1 (first `LOG_WORD_SIZE_BITS` challenges)
@@ -160,7 +91,6 @@ fn compute_monster_with_inout<F: Field, P: PackedField<Scalar = F>>(
16091
/// Returns `SumcheckOutput` with concatenated challenges `[r_j, r_y]` and witness evaluation.
16192
#[instrument(skip_all, name = "run_sumcheck")]
16293
fn run_sumcheck<F: Field, P: PackedField<Scalar = F>, Channel: IPProverChannel<F>>(
163-
inout_n_vars: usize,
16494
r_j_witness: FieldBuffer<P>,
16595
monster_multilinear: FieldBuffer<P>,
16696
r_j: Vec<F>,
@@ -170,20 +100,8 @@ fn run_sumcheck<F: Field, P: PackedField<Scalar = F>, Channel: IPProverChannel<F
170100
#[cfg(debug_assertions)]
171101
let cloned_r_j_witness_for_debugging = r_j_witness.clone();
172102

173-
// Sample inout evaluation point and compute public evaluation
174-
let inout_eval_point = channel.sample_many(inout_n_vars);
175-
let public_eval = evaluate_public_at_inout(&r_j_witness, &inout_eval_point);
176-
177-
// Sample batching coefficient
178-
let batch_coeff = channel.sample();
179-
180-
// Compute the batched polynomial m = monster + χ eq(inout_eval_point || 0)
181-
let combined_monster =
182-
compute_monster_with_inout(monster_multilinear, &inout_eval_point, batch_coeff);
183-
184103
// Run sumcheck on bivariate product
185-
let batched_sum = gamma + batch_coeff * public_eval;
186-
let prover = BivariateProductSumcheckProver::new([r_j_witness, combined_monster], batched_sum)?;
104+
let prover = BivariateProductSumcheckProver::new([r_j_witness, monster_multilinear], gamma)?;
187105

188106
let ProveSingleOutput {
189107
multilinear_evals,

crates/prover/src/protocols/shift/prove.rs

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,6 @@ impl<F: Field> PreparedOperatorData<F> {
8585
/// 3. **Phase 2**: Reduces to witness evaluation using monster multilinear polynomial
8686
///
8787
/// # Parameters
88-
/// - `log_public_words`: log2 the number of public words
8988
/// - `key_collection`: Prover's key collection representing the constraint system
9089
/// - `words`: The witness words (must have power-of-2 length)
9190
/// - `bitand_data`: Operator data for bit multiplication (AND) constraints
@@ -99,7 +98,6 @@ impl<F: Field> PreparedOperatorData<F> {
9998
/// # Requirements
10099
/// - `words` must have power-of-2 length for efficient multilinear operations
101100
pub fn prove<F, P, Channel>(
102-
log_public_words: usize,
103101
key_collection: &KeyCollection,
104102
words: &[Word],
105103
bitand_data: OperatorData<F>,
@@ -137,7 +135,6 @@ where
137135
// the witness at oblong point had by univariate
138136
// variable `r_j` and multilinear variable `r_y`.
139137
let SumcheckOutput { challenges, eval } = prove_phase_2::<_, P, _>(
140-
log_public_words,
141138
key_collection,
142139
words,
143140
&prepared_bitand_data,

crates/prover/src/prove.rs

Lines changed: 51 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -6,16 +6,17 @@ use binius_core::{
66
word::Word,
77
};
88
use binius_field::{
9-
AESTowerField8b as B8, BinaryField, PackedAESBinaryField16x8b, PackedExtension, PackedField,
10-
UnderlierWithBitOps, WithUnderlier,
9+
AESTowerField8b as B8, BinaryField, ExtensionField, PackedAESBinaryField16x8b, PackedExtension,
10+
PackedField, UnderlierWithBitOps, WithUnderlier,
1111
};
1212
use binius_iop_prover::{
1313
basefold_channel::BaseFoldProverChannel, basefold_compiler::BaseFoldProverCompiler,
1414
channel::IOPProverChannel,
1515
};
1616
use binius_math::{
17-
BinarySubspace, FieldBuffer,
17+
BinarySubspace, FieldBuffer, FieldSlice,
1818
inner_product::inner_product,
19+
multilinear::{eq::eq_ind_partial_eval, evaluate::evaluate},
1920
ntt::{NeighborsLastMultiThread, domain_context::GenericPreExpanded},
2021
univariate::lagrange_evals,
2122
};
@@ -144,24 +145,6 @@ where
144145
witness: ValueVec,
145146
transcript: &mut ProverTranscript<Challenger_>,
146147
) -> Result<(), Error> {
147-
let verifier = &self.verifier;
148-
149-
// Check that the public input length is correct
150-
let public = witness.public().to_vec();
151-
if public.len() != 1 << self.verifier.log_public_words() {
152-
return Err(Error::ArgumentError {
153-
arg: "witness".to_string(),
154-
msg: format!(
155-
"witness layout has {} words, expected {}",
156-
public.len(),
157-
1 << verifier.log_public_words()
158-
),
159-
});
160-
}
161-
162-
// Prover observes the public input (includes it in Fiat-Shamir).
163-
transcript.observe().write_slice(&public);
164-
165148
// Create channel and delegate to prove_iop
166149
let channel = BaseFoldProverChannel::from_compiler(&self.basefold_compiler, transcript);
167150
self.prove_iop(witness, channel)
@@ -191,6 +174,14 @@ where
191174
let witness_packed = pack_witness::<P>(verifier.log_witness_elems(), &witness)?;
192175
drop(setup_guard);
193176

177+
// Observe the public input as B128 elements (includes it in Fiat-Shamir).
178+
let n_public_elems = 1 << (verifier.log_public_words() - LOG_WORDS_PER_ELEM);
179+
let public_elems = witness_packed
180+
.iter_scalars()
181+
.take(n_public_elems)
182+
.collect::<Vec<_>>();
183+
channel.observe_many(&public_elems);
184+
194185
// [phase] Witness Commit - witness generation and commitment
195186
let witness_commit_guard = tracing::info_span!(
196187
"[phase] Witness Commit",
@@ -280,7 +271,6 @@ where
280271
challenges: eval_point,
281272
eval: _,
282273
} = prove_shift_reduction::<_, P, _>(
283-
verifier.log_public_words(),
284274
&self.key_collection,
285275
witness.combined_witness(),
286276
bitand_claim,
@@ -303,15 +293,53 @@ where
303293
sumcheck_claim,
304294
} = ring_switch::prove(&witness_packed, &eval_point, &mut channel);
305295

296+
// Public input check batched with ring-switch
297+
let log_packing = <B128 as ExtensionField<B1>>::LOG_DEGREE;
298+
299+
let log_public_elems = verifier.log_public_words() - LOG_WORDS_PER_ELEM;
300+
let pubcheck_point = &eval_point[log_packing..][..log_public_elems];
301+
let pubcheck_claim = {
302+
let public_elems_buf = FieldSlice::from_slice(log_public_elems, &public_elems);
303+
evaluate(&public_elems_buf, pubcheck_point)
304+
};
305+
306+
let batch_coeff: B128 = channel.sample();
307+
let batched_claim = sumcheck_claim + batch_coeff * pubcheck_claim;
308+
309+
// Batch the pubcheck transparent with the ring-switch transparent
310+
let batched_transparent =
311+
compute_batched_transparent(rs_eq_ind, pubcheck_point, batch_coeff);
312+
306313
// Prove oracle relations via channel (runs BaseFold internally)
307-
channel.prove_oracle_relations([(trace_oracle, rs_eq_ind, sumcheck_claim)]);
314+
channel.prove_oracle_relations([(trace_oracle, batched_transparent, batched_claim)]);
308315

309316
drop(pcs_guard);
310317

311318
Ok(())
312319
}
313320
}
314321

322+
/// Batches the pubcheck transparent polynomial with the ring-switch equality indicator.
323+
///
324+
/// Computes `rs_eq_ind + batch_coeff * eq(pubcheck_point || 0, ·)`, adding the scaled
325+
/// pubcheck equality indicator to the first `2^log_public_elems` entries of `rs_eq_ind`.
326+
fn compute_batched_transparent<P: PackedField<Scalar = B128>>(
327+
mut rs_eq_ind: FieldBuffer<P>,
328+
pubcheck_point: &[B128],
329+
batch_coeff: B128,
330+
) -> FieldBuffer<P> {
331+
let log_public_elems = pubcheck_point.len();
332+
let pubcheck_eq = eq_ind_partial_eval::<P>(pubcheck_point);
333+
let mut chunk = rs_eq_ind.chunk_mut(log_public_elems, 0);
334+
let mut chunk_data = chunk.get();
335+
let batch = P::broadcast(batch_coeff);
336+
for (dst, src) in std::iter::zip(chunk_data.as_mut(), pubcheck_eq.as_ref()) {
337+
*dst += *src * batch;
338+
}
339+
drop(chunk);
340+
rs_eq_ind
341+
}
342+
315343
fn pack_witness<P: PackedField<Scalar = B128>>(
316344
log_witness_elems: usize,
317345
witness: &ValueVec,

crates/prover/tests/shift.rs

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -296,10 +296,7 @@ fn test_shift_prove_and_verify() {
296296
r_x_prime: r_x_prime_intmul.clone(),
297297
};
298298

299-
let inout_n_vars = strict_log_2(cs.value_vec_layout.offset_witness).unwrap();
300-
301299
let prover_output = prove::<F, P, _>(
302-
inout_n_vars,
303300
&key_collection,
304301
value_vec.combined_witness(),
305302
prover_bitand_data.clone(),
@@ -316,14 +313,9 @@ fn test_shift_prove_and_verify() {
316313
let verifier_intmul_data =
317314
VerifierOperatorData::new(r_zhat_prime_intmul, r_x_prime_intmul, intmul_evals);
318315

319-
let verifier_output = verify(
320-
&cs,
321-
value_vec.public(),
322-
&verifier_bitand_data,
323-
&verifier_intmul_data,
324-
&mut verifier_transcript,
325-
)
326-
.unwrap();
316+
let verifier_output =
317+
verify(&cs, &verifier_bitand_data, &verifier_intmul_data, &mut verifier_transcript)
318+
.unwrap();
327319
verifier_transcript.finalize().unwrap();
328320

329321
// Check consistency with verifier output

0 commit comments

Comments
 (0)