@@ -18,11 +18,18 @@ use crate::witgen::{
18
18
use super :: {
19
19
block_machine_processor:: BlockMachineProcessor ,
20
20
compiler:: { compile_effects, CompiledFunction } ,
21
+ effect:: Effect ,
21
22
interpreter:: EffectsInterpreter ,
23
+ prover_function_heuristics:: ProverFunction ,
22
24
variable:: Variable ,
23
25
witgen_inference:: CanProcessCall ,
24
26
} ;
25
27
28
+ /// Inferred witness generation routines that are larger than
29
+ /// this number of "statements" will use the interpreter instead of the compiler
30
+ /// due to the large compilation ressources required.
31
+ const MAX_COMPILED_CODE_SIZE : usize = 500 ;
32
+
26
33
#[ derive( Debug , Clone , Hash , PartialEq , Eq ) ]
27
34
struct CacheKey < T : FieldElement > {
28
35
bus_id : T ,
@@ -136,28 +143,26 @@ impl<'a, T: FieldElement> FunctionCache<'a, T> {
136
143
) -> & Option < CacheEntry < T > > {
137
144
if !self . witgen_functions . contains_key ( cache_key) {
138
145
record_start ( "Auto-witgen code derivation" ) ;
139
- let f = match T :: known_field ( ) {
140
- // TODO: Currently, code generation only supports the Goldilocks
141
- // fields. We can't enable the interpreter for non-goldilocks
142
- // fields due to a limitation of autowitgen.
143
- Some ( KnownField :: GoldilocksField ) => {
144
- self . compile_witgen_function ( can_process, cache_key, false )
145
- }
146
- _ => None ,
147
- } ;
148
- assert ! ( self . witgen_functions. insert( cache_key. clone( ) , f) . is_none( ) ) ;
146
+ let compiled = self
147
+ . derive_witgen_function ( can_process, cache_key)
148
+ . and_then ( |( result, prover_functions) | {
149
+ self . compile_witgen_function ( result, prover_functions, cache_key)
150
+ } ) ;
151
+ assert ! ( self
152
+ . witgen_functions
153
+ . insert( cache_key. clone( ) , compiled)
154
+ . is_none( ) ) ;
149
155
150
156
record_end ( "Auto-witgen code derivation" ) ;
151
157
}
152
158
self . witgen_functions . get ( cache_key) . unwrap ( )
153
159
}
154
160
155
- fn compile_witgen_function (
161
+ fn derive_witgen_function (
156
162
& self ,
157
163
can_process : impl CanProcessCall < T > ,
158
164
cache_key : & CacheKey < T > ,
159
- interpreted : bool ,
160
- ) -> Option < CacheEntry < T > > {
165
+ ) -> Option < ( ProcessorResult < T > , Vec < ProverFunction < ' a , T > > ) > {
161
166
log:: debug!(
162
167
"Compiling JIT function for\n Machine: {}\n Connection: {}\n Inputs: {:?}{}" ,
163
168
self . machine_name,
@@ -169,13 +174,7 @@ impl<'a, T: FieldElement> FunctionCache<'a, T> {
169
174
. unwrap_or_default( )
170
175
) ;
171
176
172
- let (
173
- ProcessorResult {
174
- code,
175
- range_constraints,
176
- } ,
177
- prover_functions,
178
- ) = self
177
+ let ( processor_result, prover_functions) = self
179
178
. processor
180
179
. generate_code (
181
180
can_process,
@@ -193,7 +192,8 @@ impl<'a, T: FieldElement> FunctionCache<'a, T> {
193
192
. ok ( ) ?;
194
193
195
194
log:: debug!( "=> Success!" ) ;
196
- let out_of_bounds_vars = code
195
+ let out_of_bounds_vars = processor_result
196
+ . code
197
197
. iter ( )
198
198
. flat_map ( |effect| effect. referenced_variables ( ) )
199
199
. filter_map ( |var| match var {
@@ -203,7 +203,7 @@ impl<'a, T: FieldElement> FunctionCache<'a, T> {
203
203
. filter ( |cell| cell. row_offset < -1 || cell. row_offset >= self . block_size as i32 )
204
204
. collect_vec ( ) ;
205
205
if !out_of_bounds_vars. is_empty ( ) {
206
- log:: debug!( "Code:\n {}" , format_code( & code) ) ;
206
+ log:: debug!( "Code:\n {}" , format_code( & processor_result . code) ) ;
207
207
panic ! (
208
208
"Expected JITed code to only reference cells in the block + the last row \
209
209
of the previous block, i.e. rows -1 until (including) {}, but it does reference the following:\n {}",
@@ -212,25 +212,51 @@ impl<'a, T: FieldElement> FunctionCache<'a, T> {
212
212
) ;
213
213
}
214
214
215
- log:: trace!( "Generated code ({} steps)" , code. len( ) ) ;
215
+ log:: trace!( "Generated code ({} steps)" , processor_result. code. len( ) ) ;
216
+ Some ( ( processor_result, prover_functions) )
217
+ }
218
+
219
+ fn compile_witgen_function (
220
+ & self ,
221
+ result : ProcessorResult < T > ,
222
+ prover_functions : Vec < ProverFunction < ' a , T > > ,
223
+ cache_key : & CacheKey < T > ,
224
+ ) -> Option < CacheEntry < T > > {
216
225
let known_inputs = cache_key
217
226
. known_args
218
227
. iter ( )
219
228
. enumerate ( )
220
229
. filter_map ( |( i, b) | if b { Some ( Variable :: Param ( i) ) } else { None } )
221
230
. collect :: < Vec < _ > > ( ) ;
222
231
232
+ let has_prover_function_call = has_prover_function_call ( & result. code ) ;
233
+
234
+ // TODO This is the goal, but we need to implement prover unctions for the interpreter first.
235
+
236
+ // Use the compiler for goldilocks with at most MAX_COMPILED_CODE_SIZE statements and
237
+ // the interpreter otherwise.
238
+ #[ allow( unused) ]
239
+ let interpreted = !matches ! ( T :: known_field( ) , Some ( KnownField :: GoldilocksField ) )
240
+ || code_size ( & result. code ) > MAX_COMPILED_CODE_SIZE ;
241
+
242
+ let interpreted = !matches ! ( T :: known_field( ) , Some ( KnownField :: GoldilocksField ) ) ;
243
+
244
+ if interpreted && has_prover_function_call {
245
+ log:: debug!( "Interpreter does not yet implement prover functions." ) ;
246
+ return None ;
247
+ }
248
+
223
249
let function = if interpreted {
224
250
log:: trace!( "Building effects interpreter..." ) ;
225
- WitgenFunction :: Interpreted ( EffectsInterpreter :: try_new ( & known_inputs, & code) ? )
251
+ WitgenFunction :: Interpreted ( EffectsInterpreter :: new ( & known_inputs, & result . code ) )
226
252
} else {
227
253
log:: trace!( "Compiling effects..." ) ;
228
254
WitgenFunction :: Compiled (
229
255
compile_effects (
230
256
self . fixed_data . analyzed ,
231
257
self . column_layout . clone ( ) ,
232
258
& known_inputs,
233
- & code,
259
+ & result . code ,
234
260
prover_functions,
235
261
)
236
262
. unwrap ( ) ,
@@ -240,7 +266,7 @@ impl<'a, T: FieldElement> FunctionCache<'a, T> {
240
266
241
267
Some ( CacheEntry {
242
268
function,
243
- range_constraints,
269
+ range_constraints : result . range_constraints ,
244
270
} )
245
271
}
246
272
@@ -281,3 +307,33 @@ impl<'a, T: FieldElement> FunctionCache<'a, T> {
281
307
Ok ( true )
282
308
}
283
309
}
310
+
311
+ /// Returns the elements in the code and thus a rough estimate of the number of steps
312
+ fn code_size < T : FieldElement > ( code : & [ Effect < T , Variable > ] ) -> usize {
313
+ code. iter ( )
314
+ . map ( |effect| match effect {
315
+ Effect :: Assignment ( ..)
316
+ | Effect :: Assertion ( ..)
317
+ | Effect :: MachineCall ( ..)
318
+ | Effect :: ProverFunctionCall ( ..) => 1 ,
319
+ Effect :: RangeConstraint ( ..) => unreachable ! ( ) ,
320
+ Effect :: Branch ( _, first, second) => code_size ( first) + code_size ( second) + 1 ,
321
+ } )
322
+ . sum ( )
323
+ }
324
+
325
+ /// Returns true if there is any prover function call in the code.
326
+ fn has_prover_function_call < ' a , T : FieldElement > (
327
+ code : impl IntoIterator < Item = & ' a Effect < T , Variable > > + ' a ,
328
+ ) -> bool {
329
+ code. into_iter ( ) . any ( |effect| match effect {
330
+ Effect :: ProverFunctionCall ( ..) => true ,
331
+ Effect :: Branch ( _, if_branch, else_branch) => {
332
+ has_prover_function_call ( if_branch) || has_prover_function_call ( else_branch)
333
+ }
334
+ Effect :: Assignment ( ..)
335
+ | Effect :: RangeConstraint ( ..)
336
+ | Effect :: Assertion ( ..)
337
+ | Effect :: MachineCall ( ..) => false ,
338
+ } )
339
+ }
0 commit comments