1
1
use std ::array ;
2
2
use std ::check::assert ;
3
3
use std ::utils::unchanged_until ;
4
+ use std ::utils::new_bool ;
4
5
use std ::utils::force_bool ;
5
6
use std ::utils::sum ;
6
7
use std ::convert::expr ;
@@ -19,8 +20,8 @@ use super::poseidon2_common::poseidon2;
19
20
// state size of 8 field elements instead of 12 , matching Plonky3's implementation.
20
21
//
21
22
// This machine assumes each memory word contains a full field element , and it
22
- // writes one field element per memory word. Use SplitGLVec8 to split the output
23
- // into 32 - bit words.
23
+ // writes one field element per memory word. Use SplitGLVec4 to split the output
24
+ // into 32 - bit words.
24
25
machine Poseidon2GL(mem: Memory) with
25
26
latch: latch ,
26
27
// Allow this machine to be connected via a permutation
@@ -33,19 +34,32 @@ machine Poseidon2GL(mem: Memory) with
33
34
// The input data is passed via a memory pointer: the machine will read STATE_SIZE
34
35
// field elements from memory.
35
36
//
36
- // Similarly , the output data is written to memory at the provided pointer.
37
+ // Similarly , the output data is written to memory at the provided pointer. We don't
38
+ // have any use for writing the full state as output , so depending on the operation ,
39
+ // it will either write the first half of the state (used in sponge squeeze) or the
40
+ // second half (used in sponge absorb and on merkle tree compression).
37
41
//
38
- // Reads happen at the provided time step ; writes happen at the next time step .
42
+ // Memory reads happen at input_time_step and memory writes happens at output_time_step .
39
43
//
40
44
// The addresses must be multiple of 4 .
41
- operation poseidon2_permutation
45
+ //
46
+ // This operation can output any combination of the first and second half of the final
47
+ // state , depending on the value of output_halves:
48
+ // 0 : no output
49
+ // 1 : first half
50
+ // 2 : second half
51
+ // 3 : the entire state
52
+ operation permute
42
53
input_addr ,
54
+ input_time_step ,
43
55
output_addr ,
44
- time_step - > ;
56
+ output_time_step ,
57
+ output_halves - > ;
45
58
46
59
let latch = 1 ;
47
60
48
- let time_step ;
61
+ let input_time_step ;
62
+ let output_time_step ;
49
63
50
64
// Poseidon2 parameters , compatible with our powdr - plonky3 implementation.
51
65
//
@@ -120,14 +134,14 @@ machine Poseidon2GL(mem: Memory) with
120
134
let input: col [ STATE_SIZE ] ;
121
135
122
136
// TODO: when link is available inside functions , we can turn this into array operations.
123
- link if is_used ~> input [ 0 ] = mem.mload(input_addr + 0 , time_step ) ;
124
- link if is_used ~> input [ 1 ] = mem.mload(input_addr + 4 , time_step ) ;
125
- link if is_used ~> input [ 2 ] = mem.mload(input_addr + 8 , time_step ) ;
126
- link if is_used ~> input [ 3 ] = mem.mload(input_addr + 12 , time_step ) ;
127
- link if is_used ~> input [ 4 ] = mem.mload(input_addr + 16 , time_step ) ;
128
- link if is_used ~> input [ 5 ] = mem.mload(input_addr + 20 , time_step ) ;
129
- link if is_used ~> input [ 6 ] = mem.mload(input_addr + 24 , time_step ) ;
130
- link if is_used ~> input [ 7 ] = mem.mload(input_addr + 28 , time_step ) ;
137
+ link if is_used ~> input [ 0 ] = mem.mload(input_addr + 0 , input_time_step ) ;
138
+ link if is_used ~> input [ 1 ] = mem.mload(input_addr + 4 , input_time_step ) ;
139
+ link if is_used ~> input [ 2 ] = mem.mload(input_addr + 8 , input_time_step ) ;
140
+ link if is_used ~> input [ 3 ] = mem.mload(input_addr + 12 , input_time_step ) ;
141
+ link if is_used ~> input [ 4 ] = mem.mload(input_addr + 16 , input_time_step ) ;
142
+ link if is_used ~> input [ 5 ] = mem.mload(input_addr + 20 , input_time_step ) ;
143
+ link if is_used ~> input [ 6 ] = mem.mload(input_addr + 24 , input_time_step ) ;
144
+ link if is_used ~> input [ 7 ] = mem.mload(input_addr + 28 , input_time_step ) ;
131
145
132
146
// Generate the Poseidon2 permutation
133
147
let output = poseidon2(
@@ -143,16 +157,25 @@ machine Poseidon2GL(mem: Memory) with
143
157
input ,
144
158
) ;
145
159
146
- // Write the output to memory at the next time step
147
- let output_addr ;
160
+ // Decide which halves to output:
161
+ let output_halves ;
162
+ let output_first_half = new_bool() ;
163
+ let output_second_half = new_bool() ;
164
+ output_halves = output_first_half + 2 * output_second_half ;
148
165
149
- // TODO: turn this into array operations
150
- link if is_used ~> mem.mstore(output_addr + 0 , time_step + 1 , output [ 0 ] ) ;
151
- link if is_used ~> mem.mstore(output_addr + 4 , time_step + 1 , output [ 1 ] ) ;
152
- link if is_used ~> mem.mstore(output_addr + 8 , time_step + 1 , output [ 2 ] ) ;
153
- link if is_used ~> mem.mstore(output_addr + 12 , time_step + 1 , output [ 3 ] ) ;
154
- link if is_used ~> mem.mstore(output_addr + 16 , time_step + 1 , output [ 4 ] ) ;
155
- link if is_used ~> mem.mstore(output_addr + 20 , time_step + 1 , output [ 5 ] ) ;
156
- link if is_used ~> mem.mstore(output_addr + 24 , time_step + 1 , output [ 6 ] ) ;
157
- link if is_used ~> mem.mstore(output_addr + 28 , time_step + 1 , output [ 7 ] ) ;
166
+ // TODO: turn these into array operations:
167
+
168
+ // Write the first half of the output
169
+ let output_addr ;
170
+ link if is_used * output_first_half ~> mem.mstore(output_addr + 0 , output_time_step , output [ 0 ] ) ;
171
+ link if is_used * output_first_half ~> mem.mstore(output_addr + 4 , output_time_step , output [ 1 ] ) ;
172
+ link if is_used * output_first_half ~> mem.mstore(output_addr + 8 , output_time_step , output [ 2 ] ) ;
173
+ link if is_used * output_first_half ~> mem.mstore(output_addr + 12 , output_time_step , output [ 3 ] ) ;
174
+
175
+ // Write the second half of the output
176
+ let second_half_addr = output_addr + 16 * output_first_half ;
177
+ link if is_used * output_second_half ~> mem.mstore(second_half_addr + 0 , output_time_step , output [ 4 ] ) ;
178
+ link if is_used * output_second_half ~> mem.mstore(second_half_addr + 4 , output_time_step , output [ 5 ] ) ;
179
+ link if is_used * output_second_half ~> mem.mstore(second_half_addr + 8 , output_time_step , output [ 6 ] ) ;
180
+ link if is_used * output_second_half ~> mem.mstore(second_half_addr + 12 , output_time_step , output [ 7 ] ) ;
158
181
}
0 commit comments