Skip to content

Commit 28307fb

Browse files
chrisethgeorgwiese
andauthored
Use regular queue for identity queue. (#2588)
This avoids cloning QueueItems and processes the items in a queue instead of always looking at the "smallest" activated item first. --------- Co-authored-by: Georg Wiese <[email protected]>
1 parent a991eeb commit 28307fb

File tree

4 files changed

+123
-55
lines changed

4 files changed

+123
-55
lines changed

executor/src/witgen/jit/block_machine_processor.rs

+26-26
Original file line numberDiff line numberDiff line change
@@ -512,29 +512,31 @@ main_binary::operation_id[3] = params[0];
512512
main_binary::A[3] = params[1];
513513
main_binary::B[3] = params[2];
514514
main_binary::operation_id[2] = main_binary::operation_id[3];
515-
main_binary::operation_id[1] = main_binary::operation_id[2];
516-
main_binary::operation_id[0] = main_binary::operation_id[1];
517-
main_binary::operation_id_next[-1] = main_binary::operation_id[0];
518-
call_var(9, -1, 0) = main_binary::operation_id_next[-1];
519-
main_binary::operation_id_next[0] = main_binary::operation_id[1];
520-
call_var(9, 0, 0) = main_binary::operation_id_next[0];
521-
main_binary::operation_id_next[1] = main_binary::operation_id[2];
522-
call_var(9, 1, 0) = main_binary::operation_id_next[1];
523515
2**0 * main_binary::A[2] + 2**24 * main_binary::A_byte[2] := main_binary::A[3];
524-
call_var(9, 2, 1) = main_binary::A_byte[2];
525-
2**0 * main_binary::A[1] + 2**16 * main_binary::A_byte[1] := main_binary::A[2];
526-
call_var(9, 1, 1) = main_binary::A_byte[1];
527-
2**0 * main_binary::A[0] + 2**8 * main_binary::A_byte[0] := main_binary::A[1];
528-
call_var(9, 0, 1) = main_binary::A_byte[0];
529-
main_binary::A_byte[-1] = main_binary::A[0];
530-
call_var(9, -1, 1) = main_binary::A_byte[-1];
531516
2**0 * main_binary::B[2] + 2**24 * main_binary::B_byte[2] := main_binary::B[3];
517+
main_binary::operation_id_next[2] = main_binary::operation_id[3];
518+
call_var(9, 2, 1) = main_binary::A_byte[2];
532519
call_var(9, 2, 2) = main_binary::B_byte[2];
520+
call_var(9, 2, 0) = main_binary::operation_id_next[2];
521+
main_binary::operation_id[1] = main_binary::operation_id[2];
522+
2**0 * main_binary::A[1] + 2**16 * main_binary::A_byte[1] := main_binary::A[2];
533523
2**0 * main_binary::B[1] + 2**16 * main_binary::B_byte[1] := main_binary::B[2];
534-
call_var(9, 1, 2) = main_binary::B_byte[1];
524+
main_binary::operation_id_next[1] = main_binary::operation_id[2];
525+
main_binary::operation_id[0] = main_binary::operation_id[1];
526+
main_binary::operation_id_next[0] = main_binary::operation_id[1];
527+
2**0 * main_binary::A[0] + 2**8 * main_binary::A_byte[0] := main_binary::A[1];
528+
call_var(9, 1, 1) = main_binary::A_byte[1];
535529
2**0 * main_binary::B[0] + 2**8 * main_binary::B_byte[0] := main_binary::B[1];
536-
call_var(9, 0, 2) = main_binary::B_byte[0];
530+
call_var(9, 1, 2) = main_binary::B_byte[1];
531+
call_var(9, 1, 0) = main_binary::operation_id_next[1];
532+
main_binary::operation_id_next[-1] = main_binary::operation_id[0];
533+
call_var(9, 0, 0) = main_binary::operation_id_next[0];
534+
main_binary::A_byte[-1] = main_binary::A[0];
535+
call_var(9, 0, 1) = main_binary::A_byte[0];
537536
main_binary::B_byte[-1] = main_binary::B[0];
537+
call_var(9, 0, 2) = main_binary::B_byte[0];
538+
call_var(9, -1, 0) = main_binary::operation_id_next[-1];
539+
call_var(9, -1, 1) = main_binary::A_byte[-1];
538540
call_var(9, -1, 2) = main_binary::B_byte[-1];
539541
machine_call(2, [Known(call_var(9, -1, 0)), Known(call_var(9, -1, 1)), Known(call_var(9, -1, 2)), Unknown(call_var(9, -1, 3))]);
540542
main_binary::C_byte[-1] = call_var(9, -1, 3);
@@ -545,8 +547,6 @@ main_binary::C[1] = (main_binary::C[0] + (main_binary::C_byte[0] * 256));
545547
machine_call(2, [Known(call_var(9, 1, 0)), Known(call_var(9, 1, 1)), Known(call_var(9, 1, 2)), Unknown(call_var(9, 1, 3))]);
546548
main_binary::C_byte[1] = call_var(9, 1, 3);
547549
main_binary::C[2] = (main_binary::C[1] + (main_binary::C_byte[1] * 65536));
548-
main_binary::operation_id_next[2] = main_binary::operation_id[3];
549-
call_var(9, 2, 0) = main_binary::operation_id_next[2];
550550
machine_call(2, [Known(call_var(9, 2, 0)), Known(call_var(9, 2, 1)), Known(call_var(9, 2, 2)), Unknown(call_var(9, 2, 3))]);
551551
main_binary::C_byte[2] = call_var(9, 2, 3);
552552
main_binary::C[3] = (main_binary::C[2] + (main_binary::C_byte[2] * 16777216));
@@ -611,12 +611,12 @@ params[1] = Sub::b[0];"
611611
params[1] = SubM::b[0];
612612
params[2] = SubM::c[0];
613613
call_var(1, 0, 0) = SubM::c[0];
614-
machine_call(2, [Known(call_var(1, 0, 0))]);
614+
SubM::c[1] = SubM::c[0];
615615
SubM::b[1] = SubM::b[0];
616616
call_var(1, 1, 0) = SubM::b[1];
617-
SubM::c[1] = SubM::c[0];
618-
machine_call(2, [Known(call_var(1, 1, 0))]);
619-
SubM::a[1] = ((SubM::b[1] * 256) + SubM::c[1]);"
617+
SubM::a[1] = ((SubM::b[1] * 256) + SubM::c[1]);
618+
machine_call(2, [Known(call_var(1, 0, 0))]);
619+
machine_call(2, [Known(call_var(1, 1, 0))]);"
620620
);
621621
}
622622

@@ -670,15 +670,15 @@ machine_call(3, [Known(call_var(3, 0, 0))]);"
670670
code,
671671
"S::a[0] = params[0];
672672
S::b[0] = 0;
673-
params[1] = 0;
674673
S::b[1] = 0;
675674
S::c[0] = 1;
676-
params[2] = 1;
677675
S::b[2] = 0;
678676
S::c[1] = 1;
679677
S::b[3] = 8;
680678
S::c[2] = 1;
681-
S::c[3] = 9;"
679+
S::c[3] = 9;
680+
params[1] = 0;
681+
params[2] = 1;"
682682
);
683683
}
684684

executor/src/witgen/jit/identity_queue.rs

+86-20
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
use std::{
2-
collections::{BTreeSet, HashMap},
2+
collections::{BTreeSet, HashMap, VecDeque},
33
rc::Rc,
44
};
55

@@ -20,50 +20,116 @@ use super::{prover_function_heuristics::ProverFunction, variable::Variable};
2020
/// updates this list based on the occurrence of updated variables
2121
/// in identities.
2222
#[derive(Clone)]
23-
pub struct IdentityQueue<'a, T: FieldElement> {
24-
queue: BTreeSet<QueueItem<'a, T>>,
25-
occurrences: Rc<HashMap<Variable, Vec<QueueItem<'a, T>>>>,
23+
pub struct IdentityQueue<'ast, 'queue, T: FieldElement> {
24+
items: &'queue Vec<QueueItem<'ast, T>>,
25+
in_queue: Vec<bool>,
26+
identity_queue: VecDeque<usize>,
27+
/// This is a priority queue because we always want to process
28+
/// the "first" machine call that received a variable update
29+
/// in the hope that this results in the machine calls to be
30+
/// in source order, where possible.
31+
machine_call_queue: BTreeSet<usize>,
32+
prover_function_queue: VecDeque<usize>,
33+
/// Maps a variable to a list of indices in `items`, pointing to the items where they are referenced.
34+
occurrences: Rc<HashMap<Variable, Vec<usize>>>,
2635
}
2736

28-
impl<'a, T: FieldElement> IdentityQueue<'a, T> {
29-
pub fn new(items: impl IntoIterator<Item = QueueItem<'a, T>>) -> Self {
30-
let queue: BTreeSet<_> = items.into_iter().collect();
37+
impl<'ast, 'queue, T: FieldElement> IdentityQueue<'ast, 'queue, T> {
38+
/// Creates a new queue based on the given identities.
39+
/// The order of identities in this queue matters for the
40+
/// order in which they are processed.
41+
pub fn new(items: &'queue Vec<QueueItem<'ast, T>>) -> Self {
3142
let mut references = ReferencesComputer::default();
3243
let occurrences = Rc::new(
33-
queue
44+
items
3445
.iter()
35-
.flat_map(|item| {
46+
.enumerate()
47+
.flat_map(|(id, item)| {
3648
references
3749
.references(item)
3850
.iter()
39-
.map(|v| (v.clone(), item.clone()))
51+
.map(|v| (v.clone(), id))
4052
.collect_vec()
4153
})
4254
.into_group_map(),
4355
);
44-
Self { queue, occurrences }
56+
Self {
57+
items,
58+
in_queue: vec![true; items.len()],
59+
identity_queue: collect_filtered_indices(items, &is_polynomial_identity_or_assignment),
60+
machine_call_queue: collect_filtered_indices(items, &is_submachine_call),
61+
prover_function_queue: collect_filtered_indices(items, &is_prover_function),
62+
occurrences,
63+
}
4564
}
4665

4766
/// Returns the next identity to be processed and its row and
4867
/// removes it from the queue.
49-
pub fn next(&mut self) -> Option<QueueItem<'a, T>> {
50-
self.queue.pop_first()
68+
pub fn next(&mut self) -> Option<&'queue QueueItem<'ast, T>> {
69+
self.identity_queue
70+
.pop_front()
71+
.or_else(|| self.machine_call_queue.pop_first())
72+
.or_else(|| self.prover_function_queue.pop_front())
73+
.map(|id| {
74+
self.in_queue[id] = false;
75+
&self.items[id]
76+
})
5177
}
5278

5379
pub fn variables_updated(&mut self, variables: impl IntoIterator<Item = Variable>) {
5480
// Note that this will usually re-add the item that caused the update,
5581
// which is fine, since there are situations where we can further process
5682
// it from an update (for example a range constraint).
57-
self.queue.extend(
58-
variables
59-
.into_iter()
60-
.flat_map(|var| self.occurrences.get(&var))
61-
.flatten()
62-
.cloned(),
63-
)
83+
for id in variables
84+
.into_iter()
85+
.flat_map(|var| self.occurrences.get(&var))
86+
.flatten()
87+
{
88+
if !self.in_queue[*id] {
89+
self.in_queue[*id] = true;
90+
if is_polynomial_identity_or_assignment(&self.items[*id]) {
91+
self.identity_queue.push_back(*id);
92+
} else if is_submachine_call(&self.items[*id]) {
93+
self.machine_call_queue.insert(*id);
94+
} else {
95+
assert!(is_prover_function(&self.items[*id]));
96+
self.prover_function_queue.push_back(*id);
97+
}
98+
}
99+
}
64100
}
65101
}
66102

103+
fn is_polynomial_identity_or_assignment<T: FieldElement>(item: &QueueItem<'_, T>) -> bool {
104+
match item {
105+
QueueItem::Identity(Identity::Polynomial(..), _)
106+
| QueueItem::VariableAssignment(..)
107+
| QueueItem::ConstantAssignment(..) => true,
108+
QueueItem::Identity(Identity::BusSend(..), _) | QueueItem::ProverFunction(..) => false,
109+
QueueItem::Identity(Identity::Connect(..), _) => unreachable!(),
110+
}
111+
}
112+
113+
fn is_submachine_call<T: FieldElement>(item: &QueueItem<'_, T>) -> bool {
114+
matches!(item, QueueItem::Identity(Identity::BusSend(..), _))
115+
}
116+
117+
fn is_prover_function<T: FieldElement>(item: &QueueItem<'_, T>) -> bool {
118+
matches!(item, QueueItem::ProverFunction(..))
119+
}
120+
121+
/// Filters a slice by a boolean selector and returns a collection of the matching indices.
122+
fn collect_filtered_indices<D, F, C: FromIterator<usize>>(items: &[D], mut filter: F) -> C
123+
where
124+
F: FnMut(&D) -> bool,
125+
{
126+
items
127+
.iter()
128+
.enumerate()
129+
.filter_map(|(index, item)| if filter(item) { Some(index) } else { None })
130+
.collect()
131+
}
132+
67133
#[derive(Clone)]
68134
pub enum QueueItem<'a, T: FieldElement> {
69135
Identity(&'a Identity<T>, i32),

executor/src/witgen/jit/processor.rs

+9-7
Original file line numberDiff line numberDiff line change
@@ -99,15 +99,17 @@ impl<'a, T: FieldElement> Processor<'a, T> {
9999
}
100100
}));
101101
let branch_depth = 0;
102-
let identity_queue = IdentityQueue::new(queue_items);
102+
// Sort the queue so that we have proper source order.
103+
queue_items.sort();
104+
let identity_queue = IdentityQueue::new(&queue_items);
103105
self.generate_code_for_branch(can_process, witgen, identity_queue, branch_depth)
104106
}
105107

106108
fn generate_code_for_branch<FixedEval: FixedEvaluator<T>>(
107109
&self,
108110
can_process: impl CanProcessCall<T>,
109111
mut witgen: WitgenInference<'a, T, FixedEval>,
110-
mut identity_queue: IdentityQueue<'a, T>,
112+
mut identity_queue: IdentityQueue<'a, '_, T>,
111113
branch_depth: usize,
112114
) -> Result<ProcessorResult<T>, Error<'a, T, FixedEval>> {
113115
if self
@@ -267,16 +269,16 @@ impl<'a, T: FieldElement> Processor<'a, T> {
267269
&self,
268270
can_process: impl CanProcessCall<T>,
269271
witgen: &mut WitgenInference<'a, T, FixedEval>,
270-
mut identity_queue: IdentityQueue<'a, T>,
272+
mut identity_queue: IdentityQueue<'a, '_, T>,
271273
) -> Result<(), affine_symbolic_expression::Error> {
272274
while let Some(item) = identity_queue.next() {
273275
let updated_vars = match item {
274276
QueueItem::Identity(identity, row_offset) => match identity {
275277
Identity::Polynomial(PolynomialIdentity { expression, .. }) => {
276-
witgen.process_equation_on_row(expression, None, 0.into(), row_offset)
278+
witgen.process_equation_on_row(expression, None, 0.into(), *row_offset)
277279
}
278280
Identity::BusSend(bus_send) => {
279-
witgen.process_call(can_process.clone(), bus_send, row_offset)
281+
witgen.process_call(can_process.clone(), bus_send, *row_offset)
280282
}
281283
Identity::Connect(..) => Ok(vec![]),
282284
},
@@ -293,7 +295,7 @@ impl<'a, T: FieldElement> Processor<'a, T> {
293295
assignment.row_offset,
294296
),
295297
QueueItem::ProverFunction(prover_function, row_offset) => {
296-
witgen.process_prover_function(&prover_function, row_offset)
298+
witgen.process_prover_function(prover_function, *row_offset)
297299
}
298300
}?;
299301
identity_queue.variables_updated(updated_vars);
@@ -358,7 +360,7 @@ impl<'a, T: FieldElement> Processor<'a, T> {
358360
incomplete_machine_calls: &[(&Identity<T>, i32)],
359361
can_process: impl CanProcessCall<T>,
360362
witgen: &mut WitgenInference<'a, T, FixedEval>,
361-
mut identity_queue: IdentityQueue<'a, T>,
363+
mut identity_queue: IdentityQueue<'a, '_, T>,
362364
) -> bool {
363365
let missing_sends_in_block = incomplete_machine_calls
364366
.iter()

executor/src/witgen/jit/single_step_processor.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -263,8 +263,8 @@ call_var(1, 0, 1) = VM::instr_add[0];
263263
call_var(1, 0, 2) = VM::instr_mul[0];
264264
call_var(1, 0, 0) = VM::pc[0];
265265
VM::pc[1] = (VM::pc[0] + 1);
266-
call_var(1, 1, 0) = VM::pc[1];
267266
VM::B[1] = VM::B[0];
267+
call_var(1, 1, 0) = VM::pc[1];
268268
machine_call(1, [Known(call_var(1, 1, 0)), Unknown(call_var(1, 1, 1)), Unknown(call_var(1, 1, 2))]);
269269
VM::instr_add[1] = call_var(1, 1, 1);
270270
VM::instr_mul[1] = call_var(1, 1, 2);
@@ -306,9 +306,9 @@ call_var(2, 0, 1) = VM::instr_add[0];
306306
call_var(2, 0, 2) = VM::instr_mul[0];
307307
call_var(2, 0, 0) = VM::pc[0];
308308
VM::pc[1] = VM::pc[0];
309-
call_var(2, 1, 0) = VM::pc[1];
310309
VM::instr_add[1] = 0;
311310
call_var(2, 1, 1) = 0;
311+
call_var(2, 1, 0) = VM::pc[1];
312312
call_var(2, 1, 2) = 1;
313313
machine_call(1, [Known(call_var(2, 1, 0)), Known(call_var(2, 1, 1)), Unknown(call_var(2, 1, 2))]);
314314
VM::instr_mul[1] = 1;"

0 commit comments

Comments
 (0)