Skip to content

Commit 77fdf21

Browse files
committed
Merge branch 'correct-rust-tools-on-ci' into fix-cache
2 parents a3ed2b3 + 671095c commit 77fdf21

File tree

4 files changed

+89
-47
lines changed

4 files changed

+89
-47
lines changed

.github/workflows/build-cache.yml

+2
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ jobs:
1919
##### The block below is shared between cache build and PR build workflows #####
2020
- name: Install EStarkPolygon prover dependencies
2121
run: sudo apt-get update && sudo apt-get install -y nlohmann-json3-dev libpqxx-dev nasm libgrpc++-dev libprotobuf-dev protobuf-compiler-grpc uuid-dev build-essential cmake pkg-config git
22+
- name: Clean stale rust installation
23+
run: rm -rf ~/.cargo/bin/rust-analyzer ~/.cargo/bin/rustfmt ~/.cargo/bin/cargo-fmt
2224
- name: Install Rust toolchain nightly-2024-12-17 (with clippy and rustfmt)
2325
run: rustup toolchain install nightly-2024-12-17-x86_64-unknown-linux-gnu && rustup component add clippy --toolchain nightly-2024-12-17-x86_64-unknown-linux-gnu && rustup component add rustfmt --toolchain nightly-2024-12-17-x86_64-unknown-linux-gnu
2426
- name: Install Rust toolchain 1.81 (stable)

.github/workflows/pr-tests.yml

+2
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,8 @@ jobs:
4444
##### The block below is shared between cache build and PR build workflows #####
4545
- name: Install EStarkPolygon prover dependencies
4646
run: sudo apt-get update && sudo apt-get install -y nlohmann-json3-dev libpqxx-dev nasm libgrpc++-dev uuid-dev
47+
- name: Clean stale rust installation
48+
run: rm -rf ~/.cargo/bin/rust-analyzer ~/.cargo/bin/rustfmt ~/.cargo/bin/cargo-fmt
4749
- name: Install Rust toolchain nightly-2024-12-17 (with clippy and rustfmt)
4850
run: rustup toolchain install nightly-2024-12-17-x86_64-unknown-linux-gnu && rustup component add clippy --toolchain nightly-2024-12-17-x86_64-unknown-linux-gnu && rustup component add rustfmt --toolchain nightly-2024-12-17-x86_64-unknown-linux-gnu
4951
- name: Install Rust toolchain 1.81 (stable)

executor/src/witgen/jit/function_cache.rs

+82-26
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,18 @@ use crate::witgen::{
1818
use super::{
1919
block_machine_processor::BlockMachineProcessor,
2020
compiler::{compile_effects, CompiledFunction},
21+
effect::Effect,
2122
interpreter::EffectsInterpreter,
23+
prover_function_heuristics::ProverFunction,
2224
variable::Variable,
2325
witgen_inference::CanProcessCall,
2426
};
2527

28+
/// Inferred witness generation routines that are larger than
29+
/// this number of "statements" will use the interpreter instead of the compiler
30+
/// due to the large compilation ressources required.
31+
const MAX_COMPILED_CODE_SIZE: usize = 500;
32+
2633
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
2734
struct CacheKey<T: FieldElement> {
2835
bus_id: T,
@@ -136,28 +143,26 @@ impl<'a, T: FieldElement> FunctionCache<'a, T> {
136143
) -> &Option<CacheEntry<T>> {
137144
if !self.witgen_functions.contains_key(cache_key) {
138145
record_start("Auto-witgen code derivation");
139-
let f = match T::known_field() {
140-
// TODO: Currently, code generation only supports the Goldilocks
141-
// fields. We can't enable the interpreter for non-goldilocks
142-
// fields due to a limitation of autowitgen.
143-
Some(KnownField::GoldilocksField) => {
144-
self.compile_witgen_function(can_process, cache_key, false)
145-
}
146-
_ => None,
147-
};
148-
assert!(self.witgen_functions.insert(cache_key.clone(), f).is_none());
146+
let compiled = self
147+
.derive_witgen_function(can_process, cache_key)
148+
.and_then(|(result, prover_functions)| {
149+
self.compile_witgen_function(result, prover_functions, cache_key)
150+
});
151+
assert!(self
152+
.witgen_functions
153+
.insert(cache_key.clone(), compiled)
154+
.is_none());
149155

150156
record_end("Auto-witgen code derivation");
151157
}
152158
self.witgen_functions.get(cache_key).unwrap()
153159
}
154160

155-
fn compile_witgen_function(
161+
fn derive_witgen_function(
156162
&self,
157163
can_process: impl CanProcessCall<T>,
158164
cache_key: &CacheKey<T>,
159-
interpreted: bool,
160-
) -> Option<CacheEntry<T>> {
165+
) -> Option<(ProcessorResult<T>, Vec<ProverFunction<'a, T>>)> {
161166
log::debug!(
162167
"Compiling JIT function for\n Machine: {}\n Connection: {}\n Inputs: {:?}{}",
163168
self.machine_name,
@@ -169,13 +174,7 @@ impl<'a, T: FieldElement> FunctionCache<'a, T> {
169174
.unwrap_or_default()
170175
);
171176

172-
let (
173-
ProcessorResult {
174-
code,
175-
range_constraints,
176-
},
177-
prover_functions,
178-
) = self
177+
let (processor_result, prover_functions) = self
179178
.processor
180179
.generate_code(
181180
can_process,
@@ -193,7 +192,8 @@ impl<'a, T: FieldElement> FunctionCache<'a, T> {
193192
.ok()?;
194193

195194
log::debug!("=> Success!");
196-
let out_of_bounds_vars = code
195+
let out_of_bounds_vars = processor_result
196+
.code
197197
.iter()
198198
.flat_map(|effect| effect.referenced_variables())
199199
.filter_map(|var| match var {
@@ -203,7 +203,7 @@ impl<'a, T: FieldElement> FunctionCache<'a, T> {
203203
.filter(|cell| cell.row_offset < -1 || cell.row_offset >= self.block_size as i32)
204204
.collect_vec();
205205
if !out_of_bounds_vars.is_empty() {
206-
log::debug!("Code:\n{}", format_code(&code));
206+
log::debug!("Code:\n{}", format_code(&processor_result.code));
207207
panic!(
208208
"Expected JITed code to only reference cells in the block + the last row \
209209
of the previous block, i.e. rows -1 until (including) {}, but it does reference the following:\n{}",
@@ -212,25 +212,51 @@ impl<'a, T: FieldElement> FunctionCache<'a, T> {
212212
);
213213
}
214214

215-
log::trace!("Generated code ({} steps)", code.len());
215+
log::trace!("Generated code ({} steps)", processor_result.code.len());
216+
Some((processor_result, prover_functions))
217+
}
218+
219+
fn compile_witgen_function(
220+
&self,
221+
result: ProcessorResult<T>,
222+
prover_functions: Vec<ProverFunction<'a, T>>,
223+
cache_key: &CacheKey<T>,
224+
) -> Option<CacheEntry<T>> {
216225
let known_inputs = cache_key
217226
.known_args
218227
.iter()
219228
.enumerate()
220229
.filter_map(|(i, b)| if b { Some(Variable::Param(i)) } else { None })
221230
.collect::<Vec<_>>();
222231

232+
let has_prover_function_call = has_prover_function_call(&result.code);
233+
234+
// TODO This is the goal, but we need to implement prover unctions for the interpreter first.
235+
236+
// Use the compiler for goldilocks with at most MAX_COMPILED_CODE_SIZE statements and
237+
// the interpreter otherwise.
238+
#[allow(unused)]
239+
let interpreted = !matches!(T::known_field(), Some(KnownField::GoldilocksField))
240+
|| code_size(&result.code) > MAX_COMPILED_CODE_SIZE;
241+
242+
let interpreted = !matches!(T::known_field(), Some(KnownField::GoldilocksField));
243+
244+
if interpreted && has_prover_function_call {
245+
log::debug!("Interpreter does not yet implement prover functions.");
246+
return None;
247+
}
248+
223249
let function = if interpreted {
224250
log::trace!("Building effects interpreter...");
225-
WitgenFunction::Interpreted(EffectsInterpreter::try_new(&known_inputs, &code)?)
251+
WitgenFunction::Interpreted(EffectsInterpreter::new(&known_inputs, &result.code))
226252
} else {
227253
log::trace!("Compiling effects...");
228254
WitgenFunction::Compiled(
229255
compile_effects(
230256
self.fixed_data.analyzed,
231257
self.column_layout.clone(),
232258
&known_inputs,
233-
&code,
259+
&result.code,
234260
prover_functions,
235261
)
236262
.unwrap(),
@@ -240,7 +266,7 @@ impl<'a, T: FieldElement> FunctionCache<'a, T> {
240266

241267
Some(CacheEntry {
242268
function,
243-
range_constraints,
269+
range_constraints: result.range_constraints,
244270
})
245271
}
246272

@@ -281,3 +307,33 @@ impl<'a, T: FieldElement> FunctionCache<'a, T> {
281307
Ok(true)
282308
}
283309
}
310+
311+
/// Returns the elements in the code and thus a rough estimate of the number of steps
312+
fn code_size<T: FieldElement>(code: &[Effect<T, Variable>]) -> usize {
313+
code.iter()
314+
.map(|effect| match effect {
315+
Effect::Assignment(..)
316+
| Effect::Assertion(..)
317+
| Effect::MachineCall(..)
318+
| Effect::ProverFunctionCall(..) => 1,
319+
Effect::RangeConstraint(..) => unreachable!(),
320+
Effect::Branch(_, first, second) => code_size(first) + code_size(second) + 1,
321+
})
322+
.sum()
323+
}
324+
325+
/// Returns true if there is any prover function call in the code.
326+
fn has_prover_function_call<'a, T: FieldElement>(
327+
code: impl IntoIterator<Item = &'a Effect<T, Variable>> + 'a,
328+
) -> bool {
329+
code.into_iter().any(|effect| match effect {
330+
Effect::ProverFunctionCall(..) => true,
331+
Effect::Branch(_, if_branch, else_branch) => {
332+
has_prover_function_call(if_branch) || has_prover_function_call(else_branch)
333+
}
334+
Effect::Assignment(..)
335+
| Effect::RangeConstraint(..)
336+
| Effect::Assertion(..)
337+
| Effect::MachineCall(..) => false,
338+
})
339+
}

executor/src/witgen/jit/interpreter.rs

+3-21
Original file line numberDiff line numberDiff line change
@@ -89,25 +89,7 @@ enum MachineCallArgumentIdx {
8989
}
9090

9191
impl<T: FieldElement> EffectsInterpreter<T> {
92-
pub fn try_new(known_inputs: &[Variable], effects: &[Effect<T, Variable>]) -> Option<Self> {
93-
// TODO: interpreter doesn't support prover functions yet
94-
fn has_prover_fn<T: FieldElement>(effect: &Effect<T, Variable>) -> bool {
95-
match effect {
96-
Effect::ProverFunctionCall(..) => true,
97-
Effect::Branch(_, if_branch, else_branch) => {
98-
if if_branch.iter().any(has_prover_fn) || else_branch.iter().any(has_prover_fn)
99-
{
100-
return true;
101-
}
102-
false
103-
}
104-
_ => false,
105-
}
106-
}
107-
if effects.iter().any(has_prover_fn) {
108-
return None;
109-
}
110-
92+
pub fn new(known_inputs: &[Variable], effects: &[Effect<T, Variable>]) -> Self {
11193
let mut actions = vec![];
11294
let mut var_mapper = VariableMapper::new();
11395

@@ -121,7 +103,7 @@ impl<T: FieldElement> EffectsInterpreter<T> {
121103
actions,
122104
};
123105
assert!(actions_are_valid(&ret.actions, BTreeSet::new()));
124-
Some(ret)
106+
ret
125107
}
126108

127109
/// Returns an iterator of actions to load all accessed fixed column values into variables.
@@ -680,7 +662,7 @@ mod test {
680662
.unwrap();
681663

682664
// generate and call the interpreter
683-
let interpreter = EffectsInterpreter::try_new(&known_inputs, &result.code).unwrap();
665+
let interpreter = EffectsInterpreter::new(&known_inputs, &result.code);
684666

685667
Self {
686668
analyzed,

0 commit comments

Comments
 (0)