@@ -9,6 +9,7 @@ use sp1_primitives::RC_16_30_U32;
99use std:: ops:: Add ;
1010
1111use crate :: air:: { RecursionInteractionAirBuilder , RecursionMemoryAirBuilder } ;
12+ use crate :: memory:: MemoryCols ;
1213use crate :: poseidon2_wide:: { apply_m_4, internal_linear_layer} ;
1314use crate :: runtime:: Opcode ;
1415
@@ -37,6 +38,7 @@ impl Poseidon2Chip {
3738 & self ,
3839 builder : & mut AB ,
3940 local : & Poseidon2Cols < AB :: Var > ,
41+ next : & Poseidon2Cols < AB :: Var > ,
4042 receive_table : AB :: Var ,
4143 memory_access : AB :: Expr ,
4244 ) {
@@ -65,6 +67,7 @@ impl Poseidon2Chip {
6567 self . eval_mem (
6668 builder,
6769 local,
70+ next,
6871 is_memory_read,
6972 is_memory_write,
7073 memory_access,
@@ -73,6 +76,7 @@ impl Poseidon2Chip {
7376 self . eval_computation (
7477 builder,
7578 local,
79+ next,
7680 is_initial. into ( ) ,
7781 is_external_layer. clone ( ) ,
7882 is_internal_layer. clone ( ) ,
@@ -94,12 +98,12 @@ impl Poseidon2Chip {
9498 & self ,
9599 builder : & mut AB ,
96100 local : & Poseidon2Cols < AB :: Var > ,
101+ next : & Poseidon2Cols < AB :: Var > ,
97102 is_memory_read : AB :: Var ,
98103 is_memory_write : AB :: Var ,
99104 memory_access : AB :: Expr ,
100105 ) {
101106 let memory_access_cols = local. round_specific_cols . memory_access ( ) ;
102-
103107 builder
104108 . when ( is_memory_read)
105109 . assert_eq ( local. left_input , memory_access_cols. addr_first_half ) ;
@@ -128,12 +132,24 @@ impl Poseidon2Chip {
128132 memory_access. clone ( ) ,
129133 ) ;
130134 }
135+
136+ // For the memory read round, need to connect the memory val to the input of the next
137+ // computation round.
138+ let next_computation_col = next. round_specific_cols . computation ( ) ;
139+ for i in 0 ..WIDTH {
140+ builder. when_transition ( ) . when ( is_memory_read) . assert_eq (
141+ * memory_access_cols. mem_access [ i] . value ( ) ,
142+ next_computation_col. input [ i] ,
143+ ) ;
144+ }
131145 }
132146
147+ #[ allow( clippy:: too_many_arguments) ]
133148 fn eval_computation < AB : BaseAirBuilder + ExtensionAirBuilder > (
134149 & self ,
135150 builder : & mut AB ,
136151 local : & Poseidon2Cols < AB :: Var > ,
152+ next : & Poseidon2Cols < AB :: Var > ,
137153 is_initial : AB :: Expr ,
138154 is_external_layer : AB :: Expr ,
139155 is_internal_layer : AB :: Expr ,
@@ -158,11 +174,11 @@ impl Poseidon2Chip {
158174 let mut result: AB :: Expr = computation_cols. input [ i] . into ( ) ;
159175 for r in 0 ..rounds {
160176 if i == 0 {
161- result += local. rounds [ r + 1 ]
177+ result += local. rounds [ r + 2 ]
162178 * constants[ r] [ i]
163179 * ( is_external_layer. clone ( ) + is_internal_layer. clone ( ) ) ;
164180 } else {
165- result += local. rounds [ r + 1 ] * constants[ r] [ i] * is_external_layer. clone ( ) ;
181+ result += local. rounds [ r + 2 ] * constants[ r] [ i] * is_external_layer. clone ( ) ;
166182 }
167183 }
168184 builder
@@ -251,9 +267,26 @@ impl Poseidon2Chip {
251267 let mut state: [ AB :: Expr ; WIDTH ] = sbox_result. clone ( ) ;
252268 internal_linear_layer ( & mut state) ;
253269 builder
254- . when ( is_internal_layer)
270+ . when ( is_internal_layer. clone ( ) )
255271 . assert_all_eq ( state. clone ( ) , computation_cols. output ) ;
256272 }
273+
274+ // Assert that the round's output values are equal the the next round's input values. For the
275+ // last computation round, assert athat the output values are equal to the output memory values.
276+ let next_row_computation = next. round_specific_cols . computation ( ) ;
277+ let next_row_memory_access = next. round_specific_cols . memory_access ( ) ;
278+ for i in 0 ..WIDTH {
279+ let next_round_value = builder. if_else (
280+ local. rounds [ 22 ] ,
281+ * next_row_memory_access. mem_access [ i] . value ( ) ,
282+ next_row_computation. input [ i] ,
283+ ) ;
284+
285+ builder
286+ . when_transition ( )
287+ . when ( is_initial. clone ( ) + is_external_layer. clone ( ) + is_internal_layer. clone ( ) )
288+ . assert_eq ( computation_cols. output [ i] , next_round_value) ;
289+ }
257290 }
258291
259292 fn eval_syscall < AB : BaseAirBuilder + ExtensionAirBuilder > (
@@ -295,9 +328,13 @@ where
295328 let main = builder. main ( ) ;
296329 let local = main. row_slice ( 0 ) ;
297330 let local: & Poseidon2Cols < AB :: Var > = ( * local) . borrow ( ) ;
331+ let next = main. row_slice ( 1 ) ;
332+ let next: & Poseidon2Cols < AB :: Var > = ( * next) . borrow ( ) ;
333+
298334 self . eval_poseidon2 :: < AB > (
299335 builder,
300336 local,
337+ next,
301338 Self :: do_receive_table :: < AB :: Var > ( local) ,
302339 Self :: do_memory_access :: < AB :: Var , AB :: Expr > ( local) ,
303340 ) ;
@@ -309,10 +346,10 @@ mod tests {
309346 use itertools:: Itertools ;
310347 use std:: borrow:: Borrow ;
311348 use std:: time:: Instant ;
349+ use zkhash:: ark_ff:: UniformRand ;
312350
313351 use p3_baby_bear:: BabyBear ;
314352 use p3_baby_bear:: DiffusionMatrixBabyBear ;
315- use p3_field:: AbstractField ;
316353 use p3_matrix:: { dense:: RowMajorMatrix , Matrix } ;
317354 use p3_poseidon2:: Poseidon2 ;
318355 use p3_poseidon2:: Poseidon2ExternalMatrixGeneral ;
@@ -324,7 +361,7 @@ mod tests {
324361 } ;
325362
326363 use crate :: {
327- poseidon2:: { Poseidon2Chip , Poseidon2Event , WIDTH } ,
364+ poseidon2:: { Poseidon2Chip , Poseidon2Event } ,
328365 runtime:: ExecutionRecord ,
329366 } ;
330367 use p3_symmetric:: Permutation ;
@@ -338,12 +375,12 @@ mod tests {
338375 let chip = Poseidon2Chip {
339376 fixed_log2_rows : None ,
340377 } ;
341- let test_inputs = vec ! [
342- [ BabyBear :: from_canonical_u32 ( 1 ) ; WIDTH ] ,
343- [ BabyBear :: from_canonical_u32 ( 2 ) ; WIDTH ] ,
344- [ BabyBear :: from_canonical_u32 ( 3 ) ; WIDTH ] ,
345- [ BabyBear :: from_canonical_u32 ( 4 ) ; WIDTH ] ,
346- ] ;
378+
379+ let rng = & mut rand :: thread_rng ( ) ;
380+
381+ let test_inputs : Vec < [ BabyBear ; 16 ] > = ( 0 .. 16 )
382+ . map ( |_| core :: array :: from_fn ( |_| BabyBear :: rand ( rng ) ) )
383+ . collect_vec ( ) ;
347384
348385 let gt: Poseidon2 <
349386 BabyBear ,
@@ -384,9 +421,10 @@ mod tests {
384421 let chip = Poseidon2Chip {
385422 fixed_log2_rows : None ,
386423 } ;
424+ let rng = & mut rand:: thread_rng ( ) ;
387425
388- let test_inputs = ( 0 ..16 )
389- . map ( |i| [ BabyBear :: from_canonical_u32 ( i ) ; WIDTH ] )
426+ let test_inputs: Vec < [ BabyBear ; 16 ] > = ( 0 ..16 )
427+ . map ( |_| core :: array :: from_fn ( |_| BabyBear :: rand ( rng ) ) )
390428 . collect_vec ( ) ;
391429
392430 let gt: Poseidon2 <
0 commit comments