Skip to content

Commit e6c29c4

Browse files
pachecochriseth
andauthored
jit interpreter branch handling (#2481)
Co-authored-by: chriseth <[email protected]>
1 parent d7653ce commit e6c29c4

File tree

3 files changed

+391
-123
lines changed

3 files changed

+391
-123
lines changed

executor/src/witgen/jit/compiler.rs

+8-11
Original file line numberDiff line numberDiff line change
@@ -29,19 +29,15 @@ use super::{
2929
variable::Variable,
3030
};
3131

32-
pub struct WitgenFunction<T> {
32+
pub struct CompiledFunction<T> {
3333
// TODO We might want to pass arguments as direct function parameters
3434
// (instead of a struct), so that
3535
// they are stored in registers instead of the stack. Should be checked.
3636
function: extern "C" fn(WitgenFunctionParams<T>),
3737
_library: Arc<Library>,
3838
}
3939

40-
impl<T: FieldElement> WitgenFunction<T> {
41-
/// Call the witgen function to fill the data and "known" tables
42-
/// given a slice of parameters.
43-
/// The `row_offset` is the index inside `data` of the row considered to be "row zero".
44-
/// This function always succeeds (unless it panics).
40+
impl<T: FieldElement> CompiledFunction<T> {
4541
pub fn call<Q: QueryCallback<T>>(
4642
&self,
4743
fixed_data: &FixedData<'_, T>,
@@ -51,7 +47,7 @@ impl<T: FieldElement> WitgenFunction<T> {
5147
) {
5248
let row_offset = data.row_offset.try_into().unwrap();
5349
let (data, known) = data.as_mut_slices();
54-
(self.function)(WitgenFunctionParams {
50+
let params = WitgenFunctionParams {
5551
data: data.into(),
5652
known: known.as_mut_ptr(),
5753
row_offset,
@@ -62,7 +58,8 @@ impl<T: FieldElement> WitgenFunction<T> {
6258
get_fixed_value: get_fixed_value::<T>,
6359
input_from_channel: input_from_channel::<T, Q>,
6460
output_to_channel: output_to_channel::<T, Q>,
65-
});
61+
};
62+
(self.function)(params);
6663
}
6764
}
6865

@@ -119,7 +116,7 @@ pub fn compile_effects<T: FieldElement, D: DefinitionFetcher>(
119116
known_inputs: &[Variable],
120117
effects: &[Effect<T, Variable>],
121118
prover_functions: Vec<ProverFunction<'_, T>>,
122-
) -> Result<WitgenFunction<T>, String> {
119+
) -> Result<CompiledFunction<T>, String> {
123120
let utils = util_code::<T>()?;
124121
let interface = interface_code(column_layout);
125122
let mut codegen = CodeGenerator::<T, _>::new(definitions);
@@ -153,7 +150,7 @@ pub fn compile_effects<T: FieldElement, D: DefinitionFetcher>(
153150

154151
let library = Arc::new(unsafe { libloading::Library::new(&lib_path.path).unwrap() });
155152
let witgen_fun = unsafe { library.get(b"witgen\0") }.unwrap();
156-
Ok(WitgenFunction {
153+
Ok(CompiledFunction {
157154
function: *witgen_fun,
158155
_library: library,
159156
})
@@ -660,7 +657,7 @@ mod tests {
660657
column_count: usize,
661658
known_inputs: &[Variable],
662659
effects: &[Effect<GoldilocksField, Variable>],
663-
) -> Result<WitgenFunction<GoldilocksField>, String> {
660+
) -> Result<CompiledFunction<GoldilocksField>, String> {
664661
super::compile_effects(
665662
&NoDefinitions,
666663
ColumnLayout {

executor/src/witgen/jit/function_cache.rs

+53-13
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@ use crate::witgen::{
1717

1818
use super::{
1919
block_machine_processor::BlockMachineProcessor,
20-
compiler::{compile_effects, WitgenFunction},
20+
compiler::{compile_effects, CompiledFunction},
21+
interpreter::EffectsInterpreter,
2122
variable::Variable,
2223
witgen_inference::CanProcessCall,
2324
};
@@ -44,8 +45,36 @@ pub struct FunctionCache<'a, T: FieldElement> {
4445
parts: MachineParts<'a, T>,
4546
}
4647

48+
enum WitgenFunction<T: FieldElement> {
49+
Compiled(CompiledFunction<T>),
50+
Interpreted(EffectsInterpreter<T>),
51+
}
52+
53+
impl<T: FieldElement> WitgenFunction<T> {
54+
/// Call the witgen function to fill the data and "known" tables
55+
/// given a slice of parameters.
56+
/// The `row_offset` is the index inside `data` of the row considered to be "row zero".
57+
/// This function always succeeds (unless it panics).
58+
pub fn call<Q: QueryCallback<T>>(
59+
&self,
60+
fixed_data: &FixedData<'_, T>,
61+
mutable_state: &MutableState<'_, T, Q>,
62+
params: &mut [LookupCell<T>],
63+
data: CompactDataRef<'_, T>,
64+
) {
65+
match self {
66+
WitgenFunction::Compiled(compiled_function) => {
67+
compiled_function.call(fixed_data, mutable_state, params, data);
68+
}
69+
WitgenFunction::Interpreted(interpreter) => {
70+
interpreter.call::<Q>(fixed_data, mutable_state, params, data)
71+
}
72+
}
73+
}
74+
}
75+
4776
pub struct CacheEntry<T: FieldElement> {
48-
pub function: WitgenFunction<T>,
77+
function: WitgenFunction<T>,
4978
pub range_constraints: Vec<RangeConstraint<T>>,
5079
}
5180

@@ -108,13 +137,16 @@ impl<'a, T: FieldElement> FunctionCache<'a, T> {
108137
if !self.witgen_functions.contains_key(cache_key) {
109138
record_start("Auto-witgen code derivation");
110139
let f = match T::known_field() {
111-
// Currently, we only support the Goldilocks fields
140+
// TODO: Currently, code generation only supports the Goldilocks
141+
// fields. We can't enable the interpreter for non-goldilocks
142+
// fields due to a limitation of autowitgen.
112143
Some(KnownField::GoldilocksField) => {
113-
self.compile_witgen_function(can_process, cache_key)
144+
self.compile_witgen_function(can_process, cache_key, false)
114145
}
115146
_ => None,
116147
};
117148
assert!(self.witgen_functions.insert(cache_key.clone(), f).is_none());
149+
118150
record_end("Auto-witgen code derivation");
119151
}
120152
self.witgen_functions.get(cache_key).unwrap()
@@ -124,6 +156,7 @@ impl<'a, T: FieldElement> FunctionCache<'a, T> {
124156
&self,
125157
can_process: impl CanProcessCall<T>,
126158
cache_key: &CacheKey<T>,
159+
interpreted: bool,
127160
) -> Option<CacheEntry<T>> {
128161
log::debug!(
129162
"Compiling JIT function for\n Machine: {}\n Connection: {}\n Inputs: {:?}{}",
@@ -187,15 +220,22 @@ impl<'a, T: FieldElement> FunctionCache<'a, T> {
187220
.filter_map(|(i, b)| if b { Some(Variable::Param(i)) } else { None })
188221
.collect::<Vec<_>>();
189222

190-
log::trace!("Compiling effects...");
191-
let function = compile_effects(
192-
self.fixed_data.analyzed,
193-
self.column_layout.clone(),
194-
&known_inputs,
195-
&code,
196-
prover_functions,
197-
)
198-
.unwrap();
223+
let function = if interpreted {
224+
log::trace!("Building effects interpreter...");
225+
WitgenFunction::Interpreted(EffectsInterpreter::try_new(&known_inputs, &code)?)
226+
} else {
227+
log::trace!("Compiling effects...");
228+
WitgenFunction::Compiled(
229+
compile_effects(
230+
self.fixed_data.analyzed,
231+
self.column_layout.clone(),
232+
&known_inputs,
233+
&code,
234+
prover_functions,
235+
)
236+
.unwrap(),
237+
)
238+
};
199239
log::trace!("Compilation done.");
200240

201241
Some(CacheEntry {

0 commit comments

Comments
 (0)