Skip to content

Commit c5202a9

Browse files
authored
Specialize functions for operation id. (#2440)
If one item in the RHS of the connection has `operation_id` in its name and has a known concrete value, create a specialized function for this value.
1 parent 5075011 commit c5202a9

File tree

4 files changed

+86
-21
lines changed

4 files changed

+86
-21
lines changed

executor/src/witgen/jit/block_machine_processor.rs

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ impl<'a, T: FieldElement> BlockMachineProcessor<'a, T> {
5454
can_process: impl CanProcessCall<T>,
5555
identity_id: u64,
5656
known_args: &BitVec,
57+
known_concrete: Option<(usize, T)>,
5758
) -> Result<(ProcessorResult<T>, Vec<ProverFunction<'a, T>>), String> {
5859
let connection = self.machine_parts.connections[&identity_id];
5960
assert_eq!(connection.right.expressions.len(), known_args.len());
@@ -77,6 +78,15 @@ impl<'a, T: FieldElement> BlockMachineProcessor<'a, T> {
7778
T::one(),
7879
));
7980

81+
if let Some((index, value)) = known_concrete {
82+
// Set the known argument to the concrete value.
83+
assignments.push(Assignment::assign_constant(
84+
&connection.right.expressions[index],
85+
self.latch_row as i32,
86+
value,
87+
));
88+
}
89+
8090
// Set all other selectors to 0 in the latch row.
8191
for other_connection in self.machine_parts.connections.values() {
8292
let other_selector = &other_connection.right.selector;
@@ -263,7 +273,7 @@ mod test {
263273
);
264274

265275
processor
266-
.generate_code(&mutable_state, connection_id, &known_values)
276+
.generate_code(&mutable_state, connection_id, &known_values, None)
267277
.map(|(result, _)| result)
268278
}
269279

executor/src/witgen/jit/function_cache.rs

Lines changed: 28 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,11 @@ use super::{
2626
};
2727

2828
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
29-
struct CacheKey {
29+
struct CacheKey<T: FieldElement> {
3030
identity_id: u64,
31+
/// If `Some((index, value))`, then this function is used only if the
32+
/// `index`th argument is set to `value`.
33+
known_concrete: Option<(usize, T)>,
3134
known_args: BitVec,
3235
}
3336

@@ -37,7 +40,7 @@ pub struct FunctionCache<'a, T: FieldElement> {
3740
processor: BlockMachineProcessor<'a, T>,
3841
/// The cache of JIT functions and the returned range constraints.
3942
/// If the entry is None, we attempted to generate the function but failed.
40-
witgen_functions: HashMap<CacheKey, Option<CacheEntry<T>>>,
43+
witgen_functions: HashMap<CacheKey<T>, Option<CacheEntry<T>>>,
4144
column_layout: ColumnLayout,
4245
block_size: usize,
4346
machine_name: String,
@@ -72,23 +75,26 @@ impl<'a, T: FieldElement> FunctionCache<'a, T> {
7275
}
7376
}
7477

75-
/// Compiles the JIT function for the given identity and known arguments.
78+
/// Compiles the JIT function for the given identity and known arguments, and a potentially
79+
/// fully known argument.
7680
/// Returns the function and the output range constraints if the function was successfully compiled.
7781
pub fn compile_cached(
7882
&mut self,
7983
can_process: impl CanProcessCall<T>,
8084
identity_id: u64,
8185
known_args: &BitVec,
86+
known_concrete: Option<(usize, T)>,
8287
) -> Option<&CacheEntry<T>> {
8388
let cache_key = CacheKey {
8489
identity_id,
8590
known_args: known_args.clone(),
91+
known_concrete,
8692
};
8793
self.ensure_cache(can_process, &cache_key);
8894
self.witgen_functions.get(&cache_key).unwrap().as_ref()
8995
}
9096

91-
fn ensure_cache(&mut self, can_process: impl CanProcessCall<T>, cache_key: &CacheKey) {
97+
fn ensure_cache(&mut self, can_process: impl CanProcessCall<T>, cache_key: &CacheKey<T>) {
9298
if self.witgen_functions.contains_key(cache_key) {
9399
return;
94100
}
@@ -108,13 +114,17 @@ impl<'a, T: FieldElement> FunctionCache<'a, T> {
108114
fn compile_witgen_function(
109115
&self,
110116
can_process: impl CanProcessCall<T>,
111-
cache_key: &CacheKey,
117+
cache_key: &CacheKey<T>,
112118
) -> Option<CacheEntry<T>> {
113119
log::debug!(
114-
"Compiling JIT function for\n Machine: {}\n Connection: {}\n Inputs: {:?}",
120+
"Compiling JIT function for\n Machine: {}\n Connection: {}\n Inputs: {:?}{}",
115121
self.machine_name,
116122
self.parts.connections[&cache_key.identity_id],
117-
cache_key.known_args
123+
cache_key.known_args,
124+
cache_key
125+
.known_concrete
126+
.map(|(i, v)| format!("\n Input {i} = {v}"))
127+
.unwrap_or_default()
118128
);
119129

120130
let (
@@ -125,7 +135,12 @@ impl<'a, T: FieldElement> FunctionCache<'a, T> {
125135
prover_functions,
126136
) = self
127137
.processor
128-
.generate_code(can_process, cache_key.identity_id, &cache_key.known_args)
138+
.generate_code(
139+
can_process,
140+
cache_key.identity_id,
141+
&cache_key.known_args,
142+
cache_key.known_concrete,
143+
)
129144
.map_err(|e| {
130145
// These errors can be pretty verbose and are quite common currently.
131146
log::debug!(
@@ -202,6 +217,7 @@ impl<'a, T: FieldElement> FunctionCache<'a, T> {
202217
connection_id: u64,
203218
values: &mut [LookupCell<'c, T>],
204219
data: CompactDataRef<'d, T>,
220+
known_concrete: Option<(usize, T)>,
205221
) -> Result<bool, EvalError<T>> {
206222
let known_args = values
207223
.iter()
@@ -211,8 +227,12 @@ impl<'a, T: FieldElement> FunctionCache<'a, T> {
211227
let cache_key = CacheKey {
212228
identity_id: connection_id,
213229
known_args,
230+
known_concrete,
214231
};
215232

233+
// TODO If the function is not in the cache, we should also try with
234+
// known_concrete set to None.
235+
216236
self.witgen_functions
217237
.get(&cache_key)
218238
.expect("Need to call compile_cached() first!")

executor/src/witgen/jit/interpreter.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -535,7 +535,7 @@ mod test {
535535

536536
// TODO we cannot compile the prover functions here, but we can evaluate them.
537537
let (result, _prover_functions) = processor
538-
.generate_code(&mutable_state, connection_id, &known_values)
538+
.generate_code(&mutable_state, connection_id, &known_values, None)
539539
.unwrap();
540540

541541
let known_inputs = (0..12).map(Variable::Param).collect::<Vec<_>>();

executor/src/witgen/machines/block_machine.rs

Lines changed: 46 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -170,12 +170,17 @@ impl<'a, T: FieldElement> Machine<'a, T> for BlockMachine<'a, T> {
170170
can_process: impl CanProcessCall<T>,
171171
identity_id: u64,
172172
known_arguments: &BitVec,
173-
_range_constraints: &[RangeConstraint<T>],
173+
range_constraints: &[RangeConstraint<T>],
174174
) -> Option<Vec<RangeConstraint<T>>> {
175-
// We do not use the input range constraints because then we would need
176-
// to generate new code depending on the range constraints as well.
175+
// We use the input range constraints to see if there is a column
176+
// containing the substring "operation_id" which is constrained to a
177+
// single value and use that value as part of the cache key.
178+
let operation_id = self.find_operation_id(identity_id).and_then(|index| {
179+
let v = range_constraints[index].try_to_single_value()?;
180+
Some((index, v))
181+
});
177182
self.function_cache
178-
.compile_cached(can_process, identity_id, known_arguments)
183+
.compile_cached(can_process, identity_id, known_arguments, operation_id)
179184
.map(|r| r.range_constraints.clone())
180185
}
181186

@@ -189,12 +194,23 @@ impl<'a, T: FieldElement> Machine<'a, T> for BlockMachine<'a, T> {
189194
return Err(EvalError::RowsExhausted(self.name.clone()));
190195
}
191196

197+
let operation_id =
198+
self.find_operation_id(identity_id)
199+
.and_then(|index| match &values[index] {
200+
LookupCell::Input(v) => Some((index, **v)),
201+
LookupCell::Output(_) => None,
202+
});
203+
192204
self.data.finalize_all();
193205
let data = self.data.append_new_finalized_rows(self.block_size);
194206

195-
let success =
196-
self.function_cache
197-
.process_lookup_direct(mutable_state, identity_id, values, data)?;
207+
let success = self.function_cache.process_lookup_direct(
208+
mutable_state,
209+
identity_id,
210+
values,
211+
data,
212+
operation_id,
213+
)?;
198214
assert!(success);
199215
self.block_count_jit += 1;
200216
Ok(true)
@@ -443,9 +459,13 @@ impl<'a, T: FieldElement> BlockMachine<'a, T> {
443459
.iter()
444460
.map(|e| e.is_constant())
445461
.collect();
462+
let operation_id = self.find_operation_id(identity_id).and_then(|index| {
463+
let v = arguments[index].constant_value()?;
464+
Some((index, v))
465+
});
446466
if self
447467
.function_cache
448-
.compile_cached(mutable_state, identity_id, &known_inputs)
468+
.compile_cached(mutable_state, identity_id, &known_inputs, operation_id)
449469
.is_some()
450470
{
451471
let updates = self.process_lookup_via_jit(mutable_state, identity_id, outer_query)?;
@@ -510,26 +530,41 @@ impl<'a, T: FieldElement> BlockMachine<'a, T> {
510530
identity_id: u64,
511531
outer_query: OuterQuery<'a, 'b, T>,
512532
) -> EvalResult<'a, T> {
513-
let mut values = CallerData::from(&outer_query);
514-
515533
assert!(
516534
(self.rows() + self.block_size as DegreeType) <= self.degree,
517535
"Block machine is full (this should have been checked before)"
518536
);
519537
self.data.finalize_all();
538+
539+
let mut values = CallerData::from(&outer_query);
540+
let mut lookup_cells = values.as_lookup_cells();
541+
let operation_id =
542+
self.find_operation_id(identity_id)
543+
.and_then(|index| match &lookup_cells[index] {
544+
LookupCell::Input(v) => Some((index, **v)),
545+
LookupCell::Output(_) => None,
546+
});
520547
let data = self.data.append_new_finalized_rows(self.block_size);
521548

522549
let success = self.function_cache.process_lookup_direct(
523550
mutable_state,
524551
identity_id,
525-
&mut values.as_lookup_cells(),
552+
&mut lookup_cells,
526553
data,
554+
operation_id,
527555
)?;
528556
assert!(success);
529557

530558
values.into()
531559
}
532560

561+
fn find_operation_id(&self, identity_id: u64) -> Option<usize> {
562+
let right = &self.parts.connections[&identity_id].right.expressions;
563+
right.iter().position(|r| {
564+
try_to_simple_poly(r).is_some_and(|poly| poly.name.contains("operation_id"))
565+
})
566+
}
567+
533568
fn process<'b, Q: QueryCallback<T>>(
534569
&self,
535570
mutable_state: &MutableState<'a, T, Q>,

0 commit comments

Comments
 (0)