@@ -69,23 +69,40 @@ impl Poseidon2 {
6969 state [2 ] += input [i * RATE + 2 ];
7070 state = crate ::poseidon2_permutation (state , 4 );
7171 }
72- // handle remaining (<3) elements
72+
73+ // handle remaining elements after last full RATE-sized chunk
7374 let remainder_start = (in_len / RATE ) * RATE ;
7475 for j in remainder_start ..in_len {
7576 state [j - remainder_start ] += input [j ];
7677 }
7778 } else {
78- for i in 0 ..input .len () {
79- if i < in_len {
80- state [i % RATE ] += input [i ];
81- if (i + 1 ) % RATE == 0 {
82- state = crate ::poseidon2_permutation (state , 4 );
83- }
79+ let mut states : [[Field ; 4 ]; N / RATE + 1 ] = [[0 ; 4 ]; N / RATE + 1 ];
80+ states [0 ] = state ;
81+
82+ // process all full RATE-sized chunks, storing state after each permutation
83+ for chunk_idx in 0 ..(N / RATE ) {
84+ for i in 0 ..RATE {
85+ state [i ] += input [chunk_idx * RATE + i ];
86+ }
87+ state = crate ::poseidon2_permutation (state , 4 );
88+ states [chunk_idx + 1 ] = state ;
89+ }
90+
91+ // get state at the last full block before in_len
92+ let first_partially_filled_chunk = in_len / RATE ;
93+ state = states [first_partially_filled_chunk ];
94+
95+ // handle remaining elements after last full RATE-sized chunk
96+ let remainder_start = (in_len / RATE ) * RATE ;
97+ for j in 0 ..RATE {
98+ let idx = remainder_start + j ;
99+ if idx < in_len {
100+ state [j ] += input [idx ];
84101 }
85102 }
86103 }
87104
88- // Always run final permutation unless we just completed a full chunk
105+ // always run final permutation unless we just completed a full chunk
89106 // still need to permute once if in_len is 0
90107 if (in_len == 0 ) | (in_len % RATE != 0 ) {
91108 state = crate ::poseidon2_permutation (state , 4 )
0 commit comments