Skip to content

Commit d039c7b

Browse files
authored
Witgen: Match machine calls by Bus ID (#2488)
With this PR, machines are no longer called by the send's identity ID, but by the bus ID. We still assume that a send maps to a bus ID statically, i.e., `BusSend::bus_id()` returns `Some(...)`. This can be relaxed in a future PR, allowing the receiver to be dynamic. I recommend starting the review by looking into the [changes of the `Machine` trait](https://github.com/powdr-labs/powdr/pull/2488/files#diff-9f434c590cca98f5e5198ee0a62044e930b78487ee751ab211a7ee641501e330).
1 parent ce10fe7 commit d039c7b

24 files changed

+370
-362
lines changed

executor/src/witgen/analysis/mod.rs

+14-9
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,7 @@ use powdr_ast::{
1010
use powdr_number::FieldElement;
1111

1212
use super::{
13-
machines::{Connection, ConnectionKind},
14-
util::try_to_simple_poly,
13+
data_structures::identity::BusReceive, machines::ConnectionKind, util::try_to_simple_poly,
1514
FixedData,
1615
};
1716

@@ -20,14 +19,20 @@ use super::{
2019
/// On success, return the connection kind, block size and latch row.
2120
pub fn detect_connection_type_and_block_size<'a, T: FieldElement>(
2221
fixed_data: &'a FixedData<'a, T>,
23-
connections: &BTreeMap<u64, Connection<'a, T>>,
22+
receives: &BTreeMap<T, &'a BusReceive<T>>,
2423
) -> Option<(ConnectionKind, usize, usize)> {
2524
// TODO we should check that the other constraints/fixed columns are also periodic.
2625

2726
// Connecting identities should either all be permutations or all lookups.
28-
let connection_type = connections
27+
let connection_type = receives
2928
.values()
30-
.map(|id| id.kind)
29+
.map(|receive| {
30+
if receive.has_arbitrary_multiplicity() {
31+
ConnectionKind::Lookup
32+
} else {
33+
ConnectionKind::Permutation
34+
}
35+
})
3136
.unique()
3237
.exactly_one()
3338
.ok()?;
@@ -36,9 +41,9 @@ pub fn detect_connection_type_and_block_size<'a, T: FieldElement>(
3641
let (latch_row, block_size) = match connection_type {
3742
ConnectionKind::Lookup => {
3843
// We'd expect all RHS selectors to be fixed columns of the same period.
39-
connections
44+
receives
4045
.values()
41-
.map(|id| try_to_period(&id.right.selector, fixed_data))
46+
.map(|receive| try_to_period(&receive.selected_payload.selector, fixed_data))
4247
.unique()
4348
.exactly_one()
4449
.ok()??
@@ -54,8 +59,8 @@ pub fn detect_connection_type_and_block_size<'a, T: FieldElement>(
5459
.max_by_key(|&(_, period)| period)
5560
};
5661
let mut latch_candidates = BTreeSet::new();
57-
for id in connections.values() {
58-
collect_fixed_cols(&id.right.selector, &mut latch_candidates);
62+
for receive in receives.values() {
63+
collect_fixed_cols(&receive.selected_payload.selector, &mut latch_candidates);
5964
}
6065
if latch_candidates.is_empty() {
6166
(0, 1)

executor/src/witgen/data_structures/mutable_state.rs

+22-26
Original file line numberDiff line numberDiff line change
@@ -18,26 +18,21 @@ use crate::witgen::{
1818
/// This struct uses interior mutability for accessing the machines.
1919
pub struct MutableState<'a, T: FieldElement, Q: QueryCallback<T>> {
2020
machines: Vec<RefCell<KnownMachine<'a, T>>>,
21-
identity_to_machine_index: BTreeMap<u64, usize>,
21+
bus_to_machine_index: BTreeMap<T, usize>,
2222
query_callback: &'a Q,
2323
}
2424

2525
impl<'a, T: FieldElement, Q: QueryCallback<T>> MutableState<'a, T, Q> {
2626
pub fn new(machines: impl Iterator<Item = KnownMachine<'a, T>>, query_callback: &'a Q) -> Self {
2727
let machines: Vec<_> = machines.map(RefCell::new).collect();
28-
let identity_to_machine_index = machines
28+
let bus_to_machine_index = machines
2929
.iter()
3030
.enumerate()
31-
.flat_map(|(index, m)| {
32-
m.borrow()
33-
.identity_ids()
34-
.into_iter()
35-
.map(move |id| (id, index))
36-
})
31+
.flat_map(|(index, m)| m.borrow().bus_ids().into_iter().map(move |id| (id, index)))
3732
.collect();
3833
Self {
3934
machines,
40-
identity_to_machine_index,
35+
bus_to_machine_index,
4136
query_callback,
4237
}
4338
}
@@ -53,48 +48,49 @@ impl<'a, T: FieldElement, Q: QueryCallback<T>> MutableState<'a, T, Q> {
5348

5449
pub fn can_process_call_fully(
5550
&self,
56-
identity_id: u64,
51+
bus_id: T,
5752
known_inputs: &BitVec,
5853
range_constraints: Vec<RangeConstraint<T>>,
5954
) -> (bool, Vec<RangeConstraint<T>>) {
60-
let mut machine = self.responsible_machine(identity_id).ok().unwrap();
61-
machine.can_process_call_fully(self, identity_id, known_inputs, range_constraints)
55+
let mut machine = self.responsible_machine(bus_id).ok().unwrap();
56+
machine.can_process_call_fully(self, bus_id, known_inputs, range_constraints)
6257
}
6358

6459
/// Call the machine responsible for the right-hand-side of an identity given its ID,
6560
/// the evaluated arguments and the caller's range constraints.
6661
pub fn call(
6762
&self,
68-
identity_id: u64,
63+
bus_id: T,
6964
arguments: &[AffineExpression<AlgebraicVariable<'a>, T>],
7065
range_constraints: &dyn RangeConstraintSet<AlgebraicVariable<'a>, T>,
7166
) -> EvalResult<'a, T> {
72-
self.responsible_machine(identity_id)?
73-
.process_plookup_timed(self, identity_id, arguments, range_constraints)
67+
self.responsible_machine(bus_id)?.process_plookup_timed(
68+
self,
69+
bus_id,
70+
arguments,
71+
range_constraints,
72+
)
7473
}
7574

7675
/// Call the machine responsible for the right-hand-side of an identity given its ID,
7776
/// use the direct interface.
7877
pub fn call_direct(
7978
&self,
80-
identity_id: u64,
79+
bus_id: T,
8180
values: &mut [LookupCell<'_, T>],
8281
) -> Result<bool, EvalError<T>> {
83-
self.responsible_machine(identity_id)?
84-
.process_lookup_direct_timed(self, identity_id, values)
82+
self.responsible_machine(bus_id)?
83+
.process_lookup_direct_timed(self, bus_id, values)
8584
}
8685

87-
fn responsible_machine(
88-
&self,
89-
identity_id: u64,
90-
) -> Result<RefMut<KnownMachine<'a, T>>, EvalError<T>> {
86+
fn responsible_machine(&self, bus_id: T) -> Result<RefMut<KnownMachine<'a, T>>, EvalError<T>> {
9187
let machine_index = *self
92-
.identity_to_machine_index
93-
.get(&identity_id)
94-
.unwrap_or_else(|| panic!("No executor machine matched identity ID: {identity_id}"));
88+
.bus_to_machine_index
89+
.get(&bus_id)
90+
.unwrap_or_else(|| panic!("No executor machine matched identity ID: {bus_id}"));
9591
self.machines[machine_index].try_borrow_mut().map_err(|_| {
9692
EvalError::RecursiveMachineCalls(format!(
97-
"Detected when processing identity with ID {identity_id}"
93+
"Detected when processing machine call with bus ID {bus_id}"
9894
))
9995
})
10096
}

executor/src/witgen/identity_processor.rs

+12-8
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,8 @@ impl<'a, 'c, T: FieldElement, Q: QueryCallback<T>> IdentityProcessor<'a, 'c, T,
3737
) -> EvalResult<'a, T> {
3838
let result = match identity {
3939
Identity::Polynomial(identity) => self.process_polynomial_identity(identity, rows),
40-
Identity::BusSend(bus_interaction) => self.process_lookup_or_permutation(
41-
bus_interaction.identity_id,
40+
Identity::BusSend(bus_interaction) => self.process_machine_call(
41+
bus_interaction.bus_id().unwrap(),
4242
&bus_interaction.selected_payload,
4343
rows,
4444
),
@@ -65,9 +65,9 @@ impl<'a, 'c, T: FieldElement, Q: QueryCallback<T>> IdentityProcessor<'a, 'c, T,
6565
}
6666
}
6767

68-
fn process_lookup_or_permutation(
68+
fn process_machine_call(
6969
&mut self,
70-
id: u64,
70+
bus_id: T,
7171
left: &'a powdr_ast::analyzed::SelectedExpressions<T>,
7272
rows: &RowPair<'_, 'a, T>,
7373
) -> EvalResult<'a, T> {
@@ -85,7 +85,7 @@ impl<'a, 'c, T: FieldElement, Q: QueryCallback<T>> IdentityProcessor<'a, 'c, T,
8585
Err(incomplete_cause) => return Ok(EvalValue::incomplete(incomplete_cause)),
8686
};
8787

88-
self.mutable_state.call(id, &left, rows)
88+
self.mutable_state.call(bus_id, &left, rows)
8989
}
9090

9191
/// Handles the lookup that connects the current machine to the calling machine.
@@ -102,10 +102,10 @@ impl<'a, 'c, T: FieldElement, Q: QueryCallback<T>> IdentityProcessor<'a, 'c, T,
102102
outer_query: &OuterQuery<'a, '_, T>,
103103
current_rows: &RowPair<'_, 'a, T>,
104104
) -> EvalResult<'a, T> {
105-
let right = outer_query.connection.right;
105+
let receive_payload = &outer_query.bus_receive.selected_payload;
106106
// sanity check that the right hand side selector is active
107107
current_rows
108-
.evaluate(&right.selector)
108+
.evaluate(&receive_payload.selector)
109109
.ok()
110110
.and_then(|affine_expression| affine_expression.constant_value())
111111
.and_then(|v| v.is_one().then_some(()))
@@ -116,7 +116,11 @@ impl<'a, 'c, T: FieldElement, Q: QueryCallback<T>> IdentityProcessor<'a, 'c, T,
116116

117117
let mut updates = EvalValue::complete(vec![]);
118118

119-
for (l, r) in outer_query.arguments.iter().zip(right.expressions.iter()) {
119+
for (l, r) in outer_query
120+
.arguments
121+
.iter()
122+
.zip(receive_payload.expressions.iter())
123+
{
120124
match current_rows.evaluate(r) {
121125
Ok(r) => {
122126
let result = (l.clone() - r).solve_with_range_constraints(&range_constraint)?;

executor/src/witgen/jit/block_machine_processor.rs

+28-28
Original file line numberDiff line numberDiff line change
@@ -48,17 +48,20 @@ impl<'a, T: FieldElement> BlockMachineProcessor<'a, T> {
4848
}
4949
}
5050

51-
/// Generates the JIT code for a given combination of connection and known arguments.
51+
/// Generates the JIT code for a given combination of bus and known arguments.
5252
/// Fails if it cannot solve for the outputs, or if any sub-machine calls cannot be completed.
5353
pub fn generate_code(
5454
&self,
5555
can_process: impl CanProcessCall<T>,
56-
identity_id: u64,
56+
bus_id: T,
5757
known_args: &BitVec,
5858
known_concrete: Option<(usize, T)>,
5959
) -> Result<(ProcessorResult<T>, Vec<ProverFunction<'a, T>>), String> {
60-
let connection = self.machine_parts.connections[&identity_id];
61-
assert_eq!(connection.right.expressions.len(), known_args.len());
60+
let bus_receive = self.machine_parts.bus_receives[&bus_id];
61+
assert_eq!(
62+
bus_receive.selected_payload.expressions.len(),
63+
known_args.len()
64+
);
6265

6366
// Set up WitgenInference with known arguments.
6467
let known_variables = known_args
@@ -73,7 +76,7 @@ impl<'a, T: FieldElement> BlockMachineProcessor<'a, T> {
7376
let mut queue_items = vec![];
7477

7578
// In the latch row, set the RHS selector to 1.
76-
let selector = &connection.right.selector;
79+
let selector = &bus_receive.selected_payload.selector;
7780
queue_items.push(QueueItem::constant_assignment(
7881
selector,
7982
T::one(),
@@ -83,15 +86,15 @@ impl<'a, T: FieldElement> BlockMachineProcessor<'a, T> {
8386
if let Some((index, value)) = known_concrete {
8487
// Set the known argument to the concrete value.
8588
queue_items.push(QueueItem::constant_assignment(
86-
&connection.right.expressions[index],
89+
&bus_receive.selected_payload.expressions[index],
8790
value,
8891
self.latch_row as i32,
8992
));
9093
}
9194

9295
// Set all other selectors to 0 in the latch row.
93-
for other_connection in self.machine_parts.connections.values() {
94-
let other_selector = &other_connection.right.selector;
96+
for other_receive in self.machine_parts.bus_receives.values() {
97+
let other_selector = &other_receive.selected_payload.selector;
9598
if other_selector != selector {
9699
queue_items.push(QueueItem::constant_assignment(
97100
other_selector,
@@ -102,7 +105,7 @@ impl<'a, T: FieldElement> BlockMachineProcessor<'a, T> {
102105
}
103106

104107
// For each argument, connect the expression on the RHS with the formal parameter.
105-
for (index, expr) in connection.right.expressions.iter().enumerate() {
108+
for (index, expr) in bus_receive.selected_payload.expressions.iter().enumerate() {
106109
queue_items.push(QueueItem::variable_assignment(
107110
expr,
108111
Variable::Param(index),
@@ -175,14 +178,14 @@ impl<'a, T: FieldElement> BlockMachineProcessor<'a, T> {
175178
.generate_code(can_process, witgen)
176179
.map_err(|e| {
177180
let err_str = e.to_string_with_variable_formatter(|var| match var {
178-
Variable::Param(i) => format!("{} (connection param)", &connection.right.expressions[*i]),
181+
Variable::Param(i) => format!("{} (receive param)", &bus_receive.selected_payload.expressions[*i]),
179182
_ => var.to_string(),
180183
});
181-
log::trace!("\nCode generation failed for connection:\n {connection}");
184+
log::trace!("\nCode generation failed for bus receive:\n {bus_receive}");
182185
let known_args_str = known_args
183186
.iter()
184187
.enumerate()
185-
.filter_map(|(i, b)| b.then_some(connection.right.expressions[i].to_string()))
188+
.filter_map(|(i, b)| b.then_some(bus_receive.selected_payload.expressions[i].to_string()))
186189
.join("\n ");
187190
log::trace!("Known arguments:\n {known_args_str}");
188191
log::trace!("Error:\n {err_str}");
@@ -306,7 +309,7 @@ fn written_rows_per_column<T: FieldElement>(
306309
})
307310
}
308311

309-
/// Returns, for each bus send ID, the collection of row offsets that have a machine call
312+
/// Returns, for each bus send *identity* ID, the collection of row offsets that have a machine call
310313
/// and if in all the calls or that row, all the arguments are known.
311314
/// Combines calls from branches.
312315
fn completed_rows_for_bus_send<T: FieldElement>(
@@ -330,21 +333,18 @@ fn fully_known_call<T: FieldElement>(e: &Effect<T, Variable>) -> bool {
330333
}
331334
}
332335

333-
/// Returns all machine calls (bus identity and row offset) found in the effect.
336+
/// Returns all machine calls (bus send identity ID and row offset) found in the effect.
334337
/// Recurses into branches.
335338
fn machine_calls<T: FieldElement>(
336339
e: &Effect<T, Variable>,
337340
) -> Box<dyn Iterator<Item = (u64, i32, &Effect<T, Variable>)> + '_> {
338341
match e {
339-
Effect::MachineCall(id, _, arguments) => match &arguments[0] {
342+
Effect::MachineCall(_, _, arguments) => match &arguments[0] {
340343
Variable::MachineCallParam(MachineCallVariable {
341344
identity_id,
342345
row_offset,
343346
..
344-
}) => {
345-
assert_eq!(*id, *identity_id);
346-
Box::new(std::iter::once((*identity_id, *row_offset, e)))
347-
}
347+
}) => Box::new(std::iter::once((*identity_id, *row_offset, e))),
348348
_ => panic!("Expected machine call variable."),
349349
},
350350
Effect::Branch(_, first, second) => {
@@ -420,8 +420,8 @@ mod test {
420420
panic!("Expected exactly one matching block machine")
421421
};
422422
let (machine_parts, block_size, latch_row) = machine.machine_info();
423-
assert_eq!(machine_parts.connections.len(), 1);
424-
let connection_id = *machine_parts.connections.keys().next().unwrap();
423+
assert_eq!(machine_parts.bus_receives.len(), 1);
424+
let bus_id = *machine_parts.bus_receives.keys().next().unwrap();
425425
let processor = BlockMachineProcessor {
426426
fixed_data: &fixed_data,
427427
machine_parts: machine_parts.clone(),
@@ -440,7 +440,7 @@ mod test {
440440
);
441441

442442
processor
443-
.generate_code(&mutable_state, connection_id, &known_values, None)
443+
.generate_code(&mutable_state, bus_id, &known_values, None)
444444
.map(|(result, _)| result)
445445
}
446446

@@ -549,18 +549,18 @@ assert (main_binary::B[1] & 0xffffffffffff0000) == 0;
549549
call_var(9, 0, 2) = main_binary::B_byte[0];
550550
main_binary::B_byte[-1] = main_binary::B[0];
551551
call_var(9, -1, 2) = main_binary::B_byte[-1];
552-
machine_call(9, [Known(call_var(9, -1, 0)), Known(call_var(9, -1, 1)), Known(call_var(9, -1, 2)), Unknown(call_var(9, -1, 3))]);
552+
machine_call(2, [Known(call_var(9, -1, 0)), Known(call_var(9, -1, 1)), Known(call_var(9, -1, 2)), Unknown(call_var(9, -1, 3))]);
553553
main_binary::C_byte[-1] = call_var(9, -1, 3);
554554
main_binary::C[0] = main_binary::C_byte[-1];
555-
machine_call(9, [Known(call_var(9, 0, 0)), Known(call_var(9, 0, 1)), Known(call_var(9, 0, 2)), Unknown(call_var(9, 0, 3))]);
555+
machine_call(2, [Known(call_var(9, 0, 0)), Known(call_var(9, 0, 1)), Known(call_var(9, 0, 2)), Unknown(call_var(9, 0, 3))]);
556556
main_binary::C_byte[0] = call_var(9, 0, 3);
557557
main_binary::C[1] = (main_binary::C[0] + (main_binary::C_byte[0] * 256));
558-
machine_call(9, [Known(call_var(9, 1, 0)), Known(call_var(9, 1, 1)), Known(call_var(9, 1, 2)), Unknown(call_var(9, 1, 3))]);
558+
machine_call(2, [Known(call_var(9, 1, 0)), Known(call_var(9, 1, 1)), Known(call_var(9, 1, 2)), Unknown(call_var(9, 1, 3))]);
559559
main_binary::C_byte[1] = call_var(9, 1, 3);
560560
main_binary::C[2] = (main_binary::C[1] + (main_binary::C_byte[1] * 65536));
561561
main_binary::operation_id_next[2] = main_binary::operation_id[3];
562562
call_var(9, 2, 0) = main_binary::operation_id_next[2];
563-
machine_call(9, [Known(call_var(9, 2, 0)), Known(call_var(9, 2, 1)), Known(call_var(9, 2, 2)), Unknown(call_var(9, 2, 3))]);
563+
machine_call(2, [Known(call_var(9, 2, 0)), Known(call_var(9, 2, 1)), Known(call_var(9, 2, 2)), Unknown(call_var(9, 2, 3))]);
564564
main_binary::C_byte[2] = call_var(9, 2, 3);
565565
main_binary::C[3] = (main_binary::C[2] + (main_binary::C_byte[2] * 16777216));
566566
params[3] = main_binary::C[3];"
@@ -624,11 +624,11 @@ assert (SubM::a[0] & 0xffffffffffff0000) == 0;
624624
params[1] = SubM::b[0];
625625
params[2] = SubM::c[0];
626626
call_var(1, 0, 0) = SubM::c[0];
627-
machine_call(1, [Known(call_var(1, 0, 0))]);
627+
machine_call(2, [Known(call_var(1, 0, 0))]);
628628
SubM::b[1] = SubM::b[0];
629629
call_var(1, 1, 0) = SubM::b[1];
630630
SubM::c[1] = SubM::c[0];
631-
machine_call(1, [Known(call_var(1, 1, 0))]);
631+
machine_call(2, [Known(call_var(1, 1, 0))]);
632632
SubM::a[1] = ((SubM::b[1] * 256) + SubM::c[1]);"
633633
);
634634
}

0 commit comments

Comments
 (0)