@@ -18,11 +18,18 @@ use crate::witgen::{
1818use super :: {
1919 block_machine_processor:: BlockMachineProcessor ,
2020 compiler:: { compile_effects, CompiledFunction } ,
21+ effect:: Effect ,
2122 interpreter:: EffectsInterpreter ,
23+ prover_function_heuristics:: ProverFunction ,
2224 variable:: Variable ,
2325 witgen_inference:: CanProcessCall ,
2426} ;
2527
28+ /// Inferred witness generation routines that are larger than
29+ /// this number of "statements" will use the interpreter instead of the compiler
30+ /// due to the large compilation ressources required.
31+ const MAX_COMPILED_CODE_SIZE : usize = 500 ;
32+
2633#[ derive( Debug , Clone , Hash , PartialEq , Eq ) ]
2734struct CacheKey < T : FieldElement > {
2835 bus_id : T ,
@@ -136,28 +143,26 @@ impl<'a, T: FieldElement> FunctionCache<'a, T> {
136143 ) -> & Option < CacheEntry < T > > {
137144 if !self . witgen_functions . contains_key ( cache_key) {
138145 record_start ( "Auto-witgen code derivation" ) ;
139- let f = match T :: known_field ( ) {
140- // TODO: Currently, code generation only supports the Goldilocks
141- // fields. We can't enable the interpreter for non-goldilocks
142- // fields due to a limitation of autowitgen.
143- Some ( KnownField :: GoldilocksField ) => {
144- self . compile_witgen_function ( can_process, cache_key, false )
145- }
146- _ => None ,
147- } ;
148- assert ! ( self . witgen_functions. insert( cache_key. clone( ) , f) . is_none( ) ) ;
146+ let compiled = self
147+ . derive_witgen_function ( can_process, cache_key)
148+ . and_then ( |( result, prover_functions) | {
149+ self . compile_witgen_function ( result, prover_functions, cache_key)
150+ } ) ;
151+ assert ! ( self
152+ . witgen_functions
153+ . insert( cache_key. clone( ) , compiled)
154+ . is_none( ) ) ;
149155
150156 record_end ( "Auto-witgen code derivation" ) ;
151157 }
152158 self . witgen_functions . get ( cache_key) . unwrap ( )
153159 }
154160
155- fn compile_witgen_function (
161+ fn derive_witgen_function (
156162 & self ,
157163 can_process : impl CanProcessCall < T > ,
158164 cache_key : & CacheKey < T > ,
159- interpreted : bool ,
160- ) -> Option < CacheEntry < T > > {
165+ ) -> Option < ( ProcessorResult < T > , Vec < ProverFunction < ' a , T > > ) > {
161166 log:: debug!(
162167 "Compiling JIT function for\n Machine: {}\n Connection: {}\n Inputs: {:?}{}" ,
163168 self . machine_name,
@@ -169,13 +174,7 @@ impl<'a, T: FieldElement> FunctionCache<'a, T> {
169174 . unwrap_or_default( )
170175 ) ;
171176
172- let (
173- ProcessorResult {
174- code,
175- range_constraints,
176- } ,
177- prover_functions,
178- ) = self
177+ let ( processor_result, prover_functions) = self
179178 . processor
180179 . generate_code (
181180 can_process,
@@ -193,7 +192,8 @@ impl<'a, T: FieldElement> FunctionCache<'a, T> {
193192 . ok ( ) ?;
194193
195194 log:: debug!( "=> Success!" ) ;
196- let out_of_bounds_vars = code
195+ let out_of_bounds_vars = processor_result
196+ . code
197197 . iter ( )
198198 . flat_map ( |effect| effect. referenced_variables ( ) )
199199 . filter_map ( |var| match var {
@@ -203,7 +203,7 @@ impl<'a, T: FieldElement> FunctionCache<'a, T> {
203203 . filter ( |cell| cell. row_offset < -1 || cell. row_offset >= self . block_size as i32 )
204204 . collect_vec ( ) ;
205205 if !out_of_bounds_vars. is_empty ( ) {
206- log:: debug!( "Code:\n {}" , format_code( & code) ) ;
206+ log:: debug!( "Code:\n {}" , format_code( & processor_result . code) ) ;
207207 panic ! (
208208 "Expected JITed code to only reference cells in the block + the last row \
209209 of the previous block, i.e. rows -1 until (including) {}, but it does reference the following:\n {}",
@@ -212,25 +212,51 @@ impl<'a, T: FieldElement> FunctionCache<'a, T> {
212212 ) ;
213213 }
214214
215- log:: trace!( "Generated code ({} steps)" , code. len( ) ) ;
215+ log:: trace!( "Generated code ({} steps)" , processor_result. code. len( ) ) ;
216+ Some ( ( processor_result, prover_functions) )
217+ }
218+
219+ fn compile_witgen_function (
220+ & self ,
221+ result : ProcessorResult < T > ,
222+ prover_functions : Vec < ProverFunction < ' a , T > > ,
223+ cache_key : & CacheKey < T > ,
224+ ) -> Option < CacheEntry < T > > {
216225 let known_inputs = cache_key
217226 . known_args
218227 . iter ( )
219228 . enumerate ( )
220229 . filter_map ( |( i, b) | if b { Some ( Variable :: Param ( i) ) } else { None } )
221230 . collect :: < Vec < _ > > ( ) ;
222231
232+ let has_prover_function_call = has_prover_function_call ( & result. code ) ;
233+
234+ // TODO This is the goal, but we need to implement prover unctions for the interpreter first.
235+
236+ // Use the compiler for goldilocks with at most MAX_COMPILED_CODE_SIZE statements and
237+ // the interpreter otherwise.
238+ #[ allow( unused) ]
239+ let interpreted = !matches ! ( T :: known_field( ) , Some ( KnownField :: GoldilocksField ) )
240+ || code_size ( & result. code ) > MAX_COMPILED_CODE_SIZE ;
241+
242+ let interpreted = !matches ! ( T :: known_field( ) , Some ( KnownField :: GoldilocksField ) ) ;
243+
244+ if interpreted && has_prover_function_call {
245+ log:: debug!( "Interpreter does not yet implement prover functions." ) ;
246+ return None ;
247+ }
248+
223249 let function = if interpreted {
224250 log:: trace!( "Building effects interpreter..." ) ;
225- WitgenFunction :: Interpreted ( EffectsInterpreter :: try_new ( & known_inputs, & code) ? )
251+ WitgenFunction :: Interpreted ( EffectsInterpreter :: new ( & known_inputs, & result . code ) )
226252 } else {
227253 log:: trace!( "Compiling effects..." ) ;
228254 WitgenFunction :: Compiled (
229255 compile_effects (
230256 self . fixed_data . analyzed ,
231257 self . column_layout . clone ( ) ,
232258 & known_inputs,
233- & code,
259+ & result . code ,
234260 prover_functions,
235261 )
236262 . unwrap ( ) ,
@@ -240,7 +266,7 @@ impl<'a, T: FieldElement> FunctionCache<'a, T> {
240266
241267 Some ( CacheEntry {
242268 function,
243- range_constraints,
269+ range_constraints : result . range_constraints ,
244270 } )
245271 }
246272
@@ -281,3 +307,33 @@ impl<'a, T: FieldElement> FunctionCache<'a, T> {
281307 Ok ( true )
282308 }
283309}
310+
311+ /// Returns the elements in the code and thus a rough estimate of the number of steps
312+ fn code_size < T : FieldElement > ( code : & [ Effect < T , Variable > ] ) -> usize {
313+ code. iter ( )
314+ . map ( |effect| match effect {
315+ Effect :: Assignment ( ..)
316+ | Effect :: Assertion ( ..)
317+ | Effect :: MachineCall ( ..)
318+ | Effect :: ProverFunctionCall ( ..) => 1 ,
319+ Effect :: RangeConstraint ( ..) => unreachable ! ( ) ,
320+ Effect :: Branch ( _, first, second) => code_size ( first) + code_size ( second) + 1 ,
321+ } )
322+ . sum ( )
323+ }
324+
325+ /// Returns true if there is any prover function call in the code.
326+ fn has_prover_function_call < ' a , T : FieldElement > (
327+ code : impl IntoIterator < Item = & ' a Effect < T , Variable > > + ' a ,
328+ ) -> bool {
329+ code. into_iter ( ) . any ( |effect| match effect {
330+ Effect :: ProverFunctionCall ( ..) => true ,
331+ Effect :: Branch ( _, if_branch, else_branch) => {
332+ has_prover_function_call ( if_branch) || has_prover_function_call ( else_branch)
333+ }
334+ Effect :: Assignment ( ..)
335+ | Effect :: RangeConstraint ( ..)
336+ | Effect :: Assertion ( ..)
337+ | Effect :: MachineCall ( ..) => false ,
338+ } )
339+ }
0 commit comments