Skip to content

Commit 4fa031b

Browse files
committed
Compile prover functions.
1 parent 927a87b commit 4fa031b

File tree

5 files changed

+69
-11
lines changed

5 files changed

+69
-11
lines changed

executor/src/witgen/jit/block_machine_processor.rs

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ use crate::witgen::{
1313

1414
use super::{
1515
processor::ProcessorResult,
16+
prover_function_heuristics::ProverFunction,
1617
variable::{Cell, Variable},
1718
witgen_inference::{CanProcessCall, FixedEvaluator, WitgenInference},
1819
};
@@ -50,7 +51,7 @@ impl<'a, T: FieldElement> BlockMachineProcessor<'a, T> {
5051
can_process: impl CanProcessCall<T>,
5152
identity_id: u64,
5253
known_args: &BitVec,
53-
) -> Result<ProcessorResult<T>, String> {
54+
) -> Result<(ProcessorResult<T>, Vec<ProverFunction<'a, T>>), String> {
5455
let connection = self.machine_parts.connections[&identity_id];
5556
assert_eq!(connection.right.expressions.len(), known_args.len());
5657

@@ -119,7 +120,7 @@ impl<'a, T: FieldElement> BlockMachineProcessor<'a, T> {
119120
.iter()
120121
.enumerate()
121122
.filter_map(|(i, is_input)| (!is_input).then_some(Variable::Param(i)));
122-
Processor::new(
123+
let result= Processor::new(
123124
self.fixed_data,
124125
self,
125126
identities,
@@ -129,7 +130,12 @@ impl<'a, T: FieldElement> BlockMachineProcessor<'a, T> {
129130
.with_block_shape_check()
130131
.with_block_size(self.block_size)
131132
.with_requested_range_constraints((0..known_args.len()).map(Variable::Param))
132-
.with_prover_functions(prover_functions)
133+
.with_prover_functions(
134+
prover_functions
135+
.iter()
136+
.flat_map(|f| (0..self.block_size).map(move |row| (f.clone(), row as i32)))
137+
.collect_vec()
138+
)
133139
.generate_code(can_process, witgen)
134140
.map_err(|e| {
135141
let err_str = e.to_string_with_variable_formatter(|var| match var {
@@ -149,7 +155,8 @@ impl<'a, T: FieldElement> BlockMachineProcessor<'a, T> {
149155
.take(10)
150156
.format("\n ");
151157
format!("Code generation failed: {shortened_error}\nRun with RUST_LOG=trace to see the code generated so far.")
152-
})
158+
})?;
159+
Ok((result, prover_functions))
153160
}
154161
}
155162

executor/src/witgen/jit/compiler.rs

Lines changed: 43 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ use powdr_ast::{
66
analyzed::{PolyID, PolynomialType},
77
indent,
88
};
9-
use powdr_jit_compiler::util_code::util_code;
9+
use powdr_jit_compiler::{util_code::util_code, CodeGenerator, DefinitionFetcher};
1010
use powdr_number::FieldElement;
1111

1212
use crate::witgen::{
@@ -23,6 +23,7 @@ use crate::witgen::{
2323

2424
use super::{
2525
effect::{Assertion, BranchCondition, Effect, ProverFunctionCall},
26+
prover_function_heuristics::ProverFunction,
2627
symbolic_expression::{BinaryOperator, BitOperator, SymbolicExpression, UnaryOperator},
2728
variable::Variable,
2829
};
@@ -87,17 +88,28 @@ extern "C" fn call_machine<T: FieldElement, Q: QueryCallback<T>>(
8788
}
8889

8990
/// Compile the given inferred effects into machine code and load it.
90-
pub fn compile_effects<T: FieldElement>(
91+
pub fn compile_effects<T: FieldElement, D: DefinitionFetcher>(
92+
definitions: &D,
9193
column_layout: ColumnLayout,
9294
known_inputs: &[Variable],
9395
effects: &[Effect<T, Variable>],
96+
prover_functions: Vec<ProverFunction<T>>,
9497
) -> Result<WitgenFunction<T>, String> {
9598
let utils = util_code::<T>()?;
9699
let interface = interface_code(column_layout);
100+
let mut codegen = CodeGenerator::new(definitions);
101+
let prover_functions = prover_functions
102+
.iter()
103+
.map(|f| prover_function_code(f, &mut codegen))
104+
.collect::<Result<Vec<_>, _>>()?
105+
.into_iter()
106+
.format("\n");
97107
let witgen_code = witgen_code(known_inputs, effects);
98108
let code = format!(
99109
"{utils}\n\
100110
//-------------------------------\n\
111+
{prover_functions}\n\
112+
//-------------------------------\n\
101113
{interface}\n\
102114
//-------------------------------\n\
103115
{witgen_code}"
@@ -522,11 +534,31 @@ fn interface_code(column_layout: ColumnLayout) -> String {
522534
)
523535
}
524536

537+
fn prover_function_code<T: FieldElement, D: DefinitionFetcher>(
538+
f: &ProverFunction<'_, T>,
539+
codegen: &mut CodeGenerator<'_, T, D>,
540+
) -> Result<String, String> {
541+
let ProverFunction::ComputeFrom(f) = f else {
542+
return Err("ProvideIfUnknown functions are not supported".to_string());
543+
};
544+
545+
let code = codegen.generate_code_for_expresson(f.computation)?;
546+
547+
let index = f.index;
548+
Ok(format!(
549+
"fn prover_function_{index}(i: u64, args: &[FieldElement]) -> FieldElement {{\n\
550+
let i: ibig::IBig = i.into();\n\
551+
({code}).call(args.to_vec().into())\n\
552+
}}"
553+
))
554+
}
555+
525556
#[cfg(test)]
526557
mod tests {
527558

528559
use std::ptr::null;
529560

561+
use powdr_ast::analyzed::FunctionValueDefinition;
530562
use pretty_assertions::assert_eq;
531563
use test_log::test;
532564

@@ -538,18 +570,27 @@ mod tests {
538570

539571
use super::*;
540572

573+
struct NoDefinitions;
574+
impl DefinitionFetcher for NoDefinitions {
575+
fn get_definition(&self, _: &str) -> Option<&FunctionValueDefinition> {
576+
None
577+
}
578+
}
579+
541580
fn compile_effects(
542581
column_count: usize,
543582
known_inputs: &[Variable],
544583
effects: &[Effect<GoldilocksField, Variable>],
545584
) -> Result<WitgenFunction<GoldilocksField>, String> {
546585
super::compile_effects(
586+
&NoDefinitions,
547587
ColumnLayout {
548588
column_count,
549589
first_column_id: 0,
550590
},
551591
known_inputs,
552592
effects,
593+
vec![],
553594
)
554595
}
555596

executor/src/witgen/jit/function_cache.rs

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -114,10 +114,13 @@ impl<'a, T: FieldElement> FunctionCache<'a, T> {
114114
cache_key.known_args
115115
);
116116

117-
let ProcessorResult {
118-
code,
119-
range_constraints,
120-
} = self
117+
let (
118+
ProcessorResult {
119+
code,
120+
range_constraints,
121+
},
122+
prover_functions,
123+
) = self
121124
.processor
122125
.generate_code(can_process, cache_key.identity_id, &cache_key.known_args)
123126
.map_err(|e| {

jit-compiler/src/codegen.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,13 @@ impl<'a, T: FieldElement, Def: DefinitionFetcher> CodeGenerator<'a, T, Def> {
8080
Ok(self.symbol_reference(name, type_args))
8181
}
8282

83+
/// Generates code for an isolated expression. This might request code generation
84+
/// for referenced symbols, this the returned code is only valid code in connection with
85+
/// the code returned by `self.generated_code`.
86+
pub fn generate_code_for_expresson(&mut self, e: &Expression) -> Result<String, String> {
87+
self.format_expr(e, 0)
88+
}
89+
8390
/// Returns the concatenation of all successfully compiled symbols.
8491
pub fn generated_code(self) -> String {
8592
self.symbols

jit-compiler/src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,13 @@ use std::{
88
sync::Arc,
99
};
1010

11-
use codegen::CodeGenerator;
1211
use compiler::{generate_glue_code, load_library};
1312

1413
use itertools::Itertools;
1514
use powdr_ast::analyzed::Analyzed;
1615
use powdr_number::FieldElement;
1716

17+
pub use codegen::{CodeGenerator, DefinitionFetcher};
1818
pub use compiler::call_cargo;
1919

2020
pub struct CompiledPIL {

0 commit comments

Comments
 (0)