Skip to content

Commit e699a98

Browse files
authored
chore(recursion): poseidon2 loose ends (#672)
1 parent 227bbdd commit e699a98

9 files changed

Lines changed: 64 additions & 23 deletions

File tree

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

prover/src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -557,6 +557,7 @@ impl SP1Prover {
557557
runtime.witness_stream = witness_stream.into();
558558
runtime.run();
559559
let mut checkpoint = runtime.memory.clone();
560+
let checkpoint_uninit = runtime.uninitialized_memory.clone();
560561

561562
// Execute runtime.
562563
let machine = RecursionAirWideDeg3::machine(InnerSC::default());
@@ -568,6 +569,7 @@ impl SP1Prover {
568569
e.1.timestamp = BabyBear::zero();
569570
});
570571
runtime.memory = checkpoint;
572+
runtime.uninitialized_memory = checkpoint_uninit;
571573
runtime.run();
572574
runtime.print_stats();
573575
tracing::info!(

recursion/core/Cargo.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,3 +30,6 @@ serde_with = "3.6.1"
3030
backtrace = { version = "0.3.71", features = ["serde"] }
3131
arrayref = "0.3.6"
3232
static_assertions = "1.1.0"
33+
34+
[dev-dependencies]
35+
rand = "0.8.5"

recursion/core/src/multi/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,7 @@ where
190190
poseidon2_chip.eval_poseidon2(
191191
&mut sub_builder,
192192
local.poseidon2(),
193+
next.poseidon2(),
193194
local.poseidon2_receive_table,
194195
local.poseidon2_memory_access.into(),
195196
);

recursion/core/src/poseidon2/columns.rs

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,6 @@ pub struct Poseidon2Cols<T: Copy> {
1111
pub left_input: T,
1212
pub right_input: T,
1313
pub rounds: [T; 24], // 1 round for memory input; 1 round for initialize; 8 rounds for external; 13 rounds for internal; 1 round for memory output
14-
pub is_computation: T,
15-
pub is_memory_access: T,
1614
pub round_specific_cols: RoundSpecificCols<T>,
1715
}
1816

recursion/core/src/poseidon2/external.rs

Lines changed: 52 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ use sp1_primitives::RC_16_30_U32;
99
use std::ops::Add;
1010

1111
use crate::air::{RecursionInteractionAirBuilder, RecursionMemoryAirBuilder};
12+
use crate::memory::MemoryCols;
1213
use crate::poseidon2_wide::{apply_m_4, internal_linear_layer};
1314
use crate::runtime::Opcode;
1415

@@ -37,6 +38,7 @@ impl Poseidon2Chip {
3738
&self,
3839
builder: &mut AB,
3940
local: &Poseidon2Cols<AB::Var>,
41+
next: &Poseidon2Cols<AB::Var>,
4042
receive_table: AB::Var,
4143
memory_access: AB::Expr,
4244
) {
@@ -65,6 +67,7 @@ impl Poseidon2Chip {
6567
self.eval_mem(
6668
builder,
6769
local,
70+
next,
6871
is_memory_read,
6972
is_memory_write,
7073
memory_access,
@@ -73,6 +76,7 @@ impl Poseidon2Chip {
7376
self.eval_computation(
7477
builder,
7578
local,
79+
next,
7680
is_initial.into(),
7781
is_external_layer.clone(),
7882
is_internal_layer.clone(),
@@ -94,12 +98,12 @@ impl Poseidon2Chip {
9498
&self,
9599
builder: &mut AB,
96100
local: &Poseidon2Cols<AB::Var>,
101+
next: &Poseidon2Cols<AB::Var>,
97102
is_memory_read: AB::Var,
98103
is_memory_write: AB::Var,
99104
memory_access: AB::Expr,
100105
) {
101106
let memory_access_cols = local.round_specific_cols.memory_access();
102-
103107
builder
104108
.when(is_memory_read)
105109
.assert_eq(local.left_input, memory_access_cols.addr_first_half);
@@ -128,12 +132,24 @@ impl Poseidon2Chip {
128132
memory_access.clone(),
129133
);
130134
}
135+
136+
// For the memory read round, need to connect the memory val to the input of the next
137+
// computation round.
138+
let next_computation_col = next.round_specific_cols.computation();
139+
for i in 0..WIDTH {
140+
builder.when_transition().when(is_memory_read).assert_eq(
141+
*memory_access_cols.mem_access[i].value(),
142+
next_computation_col.input[i],
143+
);
144+
}
131145
}
132146

147+
#[allow(clippy::too_many_arguments)]
133148
fn eval_computation<AB: BaseAirBuilder + ExtensionAirBuilder>(
134149
&self,
135150
builder: &mut AB,
136151
local: &Poseidon2Cols<AB::Var>,
152+
next: &Poseidon2Cols<AB::Var>,
137153
is_initial: AB::Expr,
138154
is_external_layer: AB::Expr,
139155
is_internal_layer: AB::Expr,
@@ -158,11 +174,11 @@ impl Poseidon2Chip {
158174
let mut result: AB::Expr = computation_cols.input[i].into();
159175
for r in 0..rounds {
160176
if i == 0 {
161-
result += local.rounds[r + 1]
177+
result += local.rounds[r + 2]
162178
* constants[r][i]
163179
* (is_external_layer.clone() + is_internal_layer.clone());
164180
} else {
165-
result += local.rounds[r + 1] * constants[r][i] * is_external_layer.clone();
181+
result += local.rounds[r + 2] * constants[r][i] * is_external_layer.clone();
166182
}
167183
}
168184
builder
@@ -251,9 +267,26 @@ impl Poseidon2Chip {
251267
let mut state: [AB::Expr; WIDTH] = sbox_result.clone();
252268
internal_linear_layer(&mut state);
253269
builder
254-
.when(is_internal_layer)
270+
.when(is_internal_layer.clone())
255271
.assert_all_eq(state.clone(), computation_cols.output);
256272
}
273+
274+
// Assert that the round's output values are equal the the next round's input values. For the
275+
// last computation round, assert athat the output values are equal to the output memory values.
276+
let next_row_computation = next.round_specific_cols.computation();
277+
let next_row_memory_access = next.round_specific_cols.memory_access();
278+
for i in 0..WIDTH {
279+
let next_round_value = builder.if_else(
280+
local.rounds[22],
281+
*next_row_memory_access.mem_access[i].value(),
282+
next_row_computation.input[i],
283+
);
284+
285+
builder
286+
.when_transition()
287+
.when(is_initial.clone() + is_external_layer.clone() + is_internal_layer.clone())
288+
.assert_eq(computation_cols.output[i], next_round_value);
289+
}
257290
}
258291

259292
fn eval_syscall<AB: BaseAirBuilder + ExtensionAirBuilder>(
@@ -295,9 +328,13 @@ where
295328
let main = builder.main();
296329
let local = main.row_slice(0);
297330
let local: &Poseidon2Cols<AB::Var> = (*local).borrow();
331+
let next = main.row_slice(1);
332+
let next: &Poseidon2Cols<AB::Var> = (*next).borrow();
333+
298334
self.eval_poseidon2::<AB>(
299335
builder,
300336
local,
337+
next,
301338
Self::do_receive_table::<AB::Var>(local),
302339
Self::do_memory_access::<AB::Var, AB::Expr>(local),
303340
);
@@ -309,10 +346,10 @@ mod tests {
309346
use itertools::Itertools;
310347
use std::borrow::Borrow;
311348
use std::time::Instant;
349+
use zkhash::ark_ff::UniformRand;
312350

313351
use p3_baby_bear::BabyBear;
314352
use p3_baby_bear::DiffusionMatrixBabyBear;
315-
use p3_field::AbstractField;
316353
use p3_matrix::{dense::RowMajorMatrix, Matrix};
317354
use p3_poseidon2::Poseidon2;
318355
use p3_poseidon2::Poseidon2ExternalMatrixGeneral;
@@ -324,7 +361,7 @@ mod tests {
324361
};
325362

326363
use crate::{
327-
poseidon2::{Poseidon2Chip, Poseidon2Event, WIDTH},
364+
poseidon2::{Poseidon2Chip, Poseidon2Event},
328365
runtime::ExecutionRecord,
329366
};
330367
use p3_symmetric::Permutation;
@@ -338,12 +375,12 @@ mod tests {
338375
let chip = Poseidon2Chip {
339376
fixed_log2_rows: None,
340377
};
341-
let test_inputs = vec![
342-
[BabyBear::from_canonical_u32(1); WIDTH],
343-
[BabyBear::from_canonical_u32(2); WIDTH],
344-
[BabyBear::from_canonical_u32(3); WIDTH],
345-
[BabyBear::from_canonical_u32(4); WIDTH],
346-
];
378+
379+
let rng = &mut rand::thread_rng();
380+
381+
let test_inputs: Vec<[BabyBear; 16]> = (0..16)
382+
.map(|_| core::array::from_fn(|_| BabyBear::rand(rng)))
383+
.collect_vec();
347384

348385
let gt: Poseidon2<
349386
BabyBear,
@@ -384,9 +421,10 @@ mod tests {
384421
let chip = Poseidon2Chip {
385422
fixed_log2_rows: None,
386423
};
424+
let rng = &mut rand::thread_rng();
387425

388-
let test_inputs = (0..16)
389-
.map(|i| [BabyBear::from_canonical_u32(i); WIDTH])
426+
let test_inputs: Vec<[BabyBear; 16]> = (0..16)
427+
.map(|_| core::array::from_fn(|_| BabyBear::rand(rng)))
390428
.collect_vec();
391429

392430
let gt: Poseidon2<

recursion/core/src/poseidon2/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ impl<F: PrimeField32> Poseidon2Event<F> {
3030
MemoryRecord::new_read(F::zero(), Block::from(input[i]), F::one(), F::zero())
3131
});
3232
let output_records: [MemoryRecord<F>; WIDTH] = core::array::from_fn(|i| {
33-
MemoryRecord::new_read(F::zero(), Block::from(output[i]), F::one(), F::zero())
33+
MemoryRecord::new_read(F::zero(), Block::from(output[i]), F::two(), F::zero())
3434
});
3535

3636
Self {

recursion/core/src/poseidon2/trace.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -115,15 +115,15 @@ impl<F: PrimeField32> MachineAir<F> for Poseidon2Chip {
115115
// Apply the round constants.
116116
for j in 0..WIDTH {
117117
computation_cols.add_rc[j] = computation_cols.input[j]
118-
+ F::from_wrapped_u32(RC_16_30_U32[r - 1][j]);
118+
+ F::from_wrapped_u32(RC_16_30_U32[r - 2][j]);
119119
}
120120
} else {
121121
// Apply the round constants only on the first element.
122122
computation_cols
123123
.add_rc
124124
.copy_from_slice(&computation_cols.input);
125125
computation_cols.add_rc[0] =
126-
computation_cols.input[0] + F::from_wrapped_u32(RC_16_30_U32[r - 1][0]);
126+
computation_cols.input[0] + F::from_wrapped_u32(RC_16_30_U32[r - 2][0]);
127127
};
128128

129129
// Apply the sbox.

recursion/core/src/poseidon2_wide/external.rs

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -493,10 +493,8 @@ mod tests {
493493
.push(Poseidon2Event::dummy_from_input(input, output));
494494
}
495495

496-
let trace: RowMajorMatrix<BabyBear> =
497-
chip.generate_trace(&input_exec, &mut ExecutionRecord::<BabyBear>::default());
498-
499-
assert_eq!(trace.height(), test_inputs.len());
496+
// Generate trace will assert for the expected outputs.
497+
chip.generate_trace(&input_exec, &mut ExecutionRecord::<BabyBear>::default());
500498
}
501499

502500
/// A test generating a trace for a single permutation that checks that the output is correct

0 commit comments

Comments
 (0)