Skip to content

Commit 3dab871

Browse files
authored
Rewrite multiplicity generator, remove Connection (#2538)
The multiplicity column generator was the last place where we used `Connection`. This PR rewrites it and removes `Connection`.
1 parent 199dd1a commit 3dab871

File tree

2 files changed

+60
-125
lines changed

2 files changed

+60
-125
lines changed

Diff for: executor/src/witgen/machines/mod.rs

+1-61
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ use std::fmt::Display;
33

44
use bit_vec::BitVec;
55
use dynamic_machine::DynamicMachine;
6-
use powdr_ast::analyzed::{self, AlgebraicExpression, ContainsNextRef, DegreeRange, PolyID};
6+
use powdr_ast::analyzed::{self, ContainsNextRef, DegreeRange, PolyID};
77

88
use powdr_number::DegreeType;
99
use powdr_number::FieldElement;
@@ -244,66 +244,6 @@ impl<'a, T: FieldElement> Machine<'a, T> for KnownMachine<'a, T> {
244244
}
245245
}
246246

247-
#[derive(Clone, Copy, Debug)]
248-
/// A connection is a witness generation directive to propagate rows across machines
249-
pub struct Connection<'a, T> {
250-
pub left: &'a analyzed::SelectedExpressions<T>,
251-
pub right: &'a analyzed::SelectedExpressions<T>,
252-
/// For [ConnectionKind::Permutation], rows of `left` are a permutation of rows of `right`. For [ConnectionKind::Lookup], all rows in `left` are in `right`.
253-
pub kind: ConnectionKind,
254-
/// If the connection comes from a phantom lookup, this is the multiplicity column.
255-
/// For each row of `right` it counts how often that row occurs in `left`.
256-
/// Note that multiple connections can share the same multiplicity column.
257-
pub multiplicity_column: Option<PolyID>,
258-
}
259-
260-
impl<T: Display> Display for Connection<'_, T> {
261-
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
262-
write!(f, "{} {} {}", self.left, self.kind, self.right)
263-
}
264-
}
265-
266-
impl<'a, T: FieldElement> Connection<'a, T> {
267-
/// Creates a connection if the identity is a bus send.
268-
pub fn try_new(
269-
identity: &'a Identity<T>,
270-
bus_receives: &'a BTreeMap<T, BusReceive<T>>,
271-
) -> Option<Self> {
272-
match identity {
273-
Identity::BusSend(bus_interaction) => {
274-
let send = bus_interaction;
275-
let receive = send
276-
.try_match_static(bus_receives)
277-
.expect("No matching receive!");
278-
let multiplicity_column = if receive.has_arbitrary_multiplicity() {
279-
receive
280-
.multiplicity
281-
.as_ref()
282-
.and_then(|multiplicity| match multiplicity {
283-
AlgebraicExpression::Reference(reference) => Some(reference.poly_id),
284-
// For receives of permutations, we would have complex expressions here.
285-
_ => None,
286-
})
287-
} else {
288-
// For permutations, the selector is already generated by "normal" witgen.
289-
None
290-
};
291-
Some(Connection {
292-
left: &send.selected_payload,
293-
right: &receive.selected_payload,
294-
kind: if receive.has_arbitrary_multiplicity() {
295-
ConnectionKind::Lookup
296-
} else {
297-
ConnectionKind::Permutation
298-
},
299-
multiplicity_column,
300-
})
301-
}
302-
_ => None,
303-
}
304-
}
305-
}
306-
307247
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
308248
pub enum ConnectionKind {
309249
Lookup,

Diff for: executor/src/witgen/multiplicity_column_generator.rs

+59-64
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,19 @@
11
use std::collections::{BTreeMap, HashMap};
22

33
use powdr_ast::{
4-
analyzed::{AlgebraicExpression, PolynomialType, SelectedExpressions},
4+
analyzed::{AlgebraicExpression, PolyID, PolynomialType, SelectedExpressions},
55
parsed::visitor::AllChildren,
66
};
77
use powdr_executor_utils::expression_evaluator::{ExpressionEvaluator, OwnedTerminalValues};
88
use powdr_number::FieldElement;
99
use rayon::iter::{IntoParallelIterator, ParallelIterator};
1010

1111
use crate::witgen::{
12-
data_structures::identity::convert_identities,
13-
machines::{
14-
profiling::{record_end, record_start},
15-
Connection,
16-
},
12+
data_structures::identity::{convert_identities, Identity},
13+
machines::profiling::{record_end, record_start},
1714
};
1815

19-
use super::FixedData;
16+
use super::{util::try_to_simple_poly, FixedData};
2017

2118
static MULTIPLICITY_WITGEN_NAME: &str = "multiplicity witgen";
2219

@@ -30,28 +27,18 @@ impl<'a, T: FieldElement> MultiplicityColumnGenerator<'a, T> {
3027
}
3128

3229
/// Takes a map of witness columns and extends it with the multiplicity columns
33-
/// reference in the phantom lookups.
30+
/// referenced by bus sends with a non-binary multiplicity.
3431
pub fn generate(
3532
&self,
3633
witness_columns: HashMap<String, Vec<T>>,
3734
publics: BTreeMap<String, Option<T>>,
3835
) -> HashMap<String, Vec<T>> {
3936
record_start(MULTIPLICITY_WITGEN_NAME);
4037

41-
log::trace!("Starting multiplicity witness generation.");
42-
let start = std::time::Instant::now();
43-
44-
// Several range constraints might point to the same target
38+
// A map from multiplicity column ID to the vector of multiplicities.
4539
let mut multiplicity_columns = BTreeMap::new();
4640

4741
let (identities, _) = convert_identities(self.fixed.analyzed);
48-
let phantom_lookups = identities
49-
.iter()
50-
.filter_map(|identity| {
51-
Connection::try_new(identity, &self.fixed.bus_receives)
52-
.and_then(|connection| connection.multiplicity_column.map(|_| connection))
53-
})
54-
.collect::<Vec<_>>();
5542

5643
let all_columns = witness_columns
5744
.into_iter()
@@ -78,64 +65,65 @@ impl<'a, T: FieldElement> MultiplicityColumnGenerator<'a, T> {
7865
challenge_values: self.fixed.challenges.clone(),
7966
};
8067

81-
log::trace!(
82-
" Done building trace values, took: {}s",
83-
start.elapsed().as_secs_f64()
84-
);
85-
86-
for lookup in phantom_lookups {
87-
log::trace!(" Updating multiplicity for phantom lookup: {lookup}");
88-
let start = std::time::Instant::now();
89-
90-
let (rhs_machine_size, rhs_tuples) = self.get_tuples(&terminal_values, lookup.right);
91-
92-
log::trace!(
93-
" Done collecting RHS tuples, took {}s",
94-
start.elapsed().as_secs_f64()
95-
);
96-
let start = std::time::Instant::now();
97-
98-
let index = rhs_tuples
99-
.iter()
100-
.map(|(i, tuple)| {
101-
// There might be multiple identical rows, but it's fine, we can pick any.
102-
(tuple, *i)
103-
})
104-
.collect::<HashMap<_, _>>();
105-
106-
log::trace!(
107-
" Done building index, took {}s",
108-
start.elapsed().as_secs_f64()
109-
);
110-
let start = std::time::Instant::now();
68+
// Index all bus receives with arbitrary multiplicity.
69+
let receive_infos = self
70+
.fixed
71+
.bus_receives
72+
.iter()
73+
.filter(|(_, bus_receive)| {
74+
bus_receive.has_arbitrary_multiplicity() && bus_receive.multiplicity.is_some()
75+
})
76+
.map(|(bus_id, bus_receive)| {
77+
let (size, rhs_tuples) =
78+
self.get_tuples(&terminal_values, &bus_receive.selected_payload);
11179

112-
let (_, lhs_tuples) = self.get_tuples(&terminal_values, lookup.left);
80+
let index = rhs_tuples
81+
.into_iter()
82+
.map(|(i, tuple)| {
83+
// There might be multiple identical rows, but it's fine, we can pick any.
84+
(tuple, i)
85+
})
86+
.collect::<HashMap<_, _>>();
87+
88+
let multiplicity = bus_receive.multiplicity.as_ref().unwrap();
89+
(
90+
*bus_id,
91+
ReceiveInfo {
92+
multiplicity_column: try_to_simple_poly(multiplicity)
93+
.unwrap_or_else(|| {
94+
panic!("Expected simple reference, got: {multiplicity}")
95+
})
96+
.poly_id,
97+
size,
98+
index,
99+
},
100+
)
101+
})
102+
.collect::<BTreeMap<_, _>>();
113103

114-
log::trace!(
115-
" Done collecting LHS tuples, took: {}s",
116-
start.elapsed().as_secs_f64()
117-
);
118-
let start = std::time::Instant::now();
104+
// Increment multiplicities for all bus sends.
105+
for (bus_send, bus_receive) in identities.iter().filter_map(|i| match i {
106+
Identity::BusSend(bus_send) => receive_infos
107+
.get(&bus_send.bus_id().unwrap())
108+
.map(|bus_receive| (bus_send, bus_receive)),
109+
_ => None,
110+
}) {
111+
let (_, lhs_tuples) = self.get_tuples(&terminal_values, &bus_send.selected_payload);
119112

120-
let multiplicity_column_id = lookup.multiplicity_column.unwrap();
121113
let multiplicities = multiplicity_columns
122-
.entry(multiplicity_column_id)
123-
.or_insert_with(|| vec![0; rhs_machine_size]);
124-
assert_eq!(multiplicities.len(), rhs_machine_size);
114+
.entry(bus_receive.multiplicity_column)
115+
.or_insert_with(|| vec![0; bus_receive.size]);
116+
assert_eq!(multiplicities.len(), bus_receive.size);
125117

126118
// Looking up the index is slow, so we do it in parallel.
127119
let indices = lhs_tuples
128120
.into_par_iter()
129-
.map(|(_, tuple)| index[&tuple])
121+
.map(|(_, tuple)| bus_receive.index[&tuple])
130122
.collect::<Vec<_>>();
131123

132124
for index in indices {
133125
multiplicities[index] += 1;
134126
}
135-
log::trace!(
136-
" Done updating multiplicities, took: {}s",
137-
start.elapsed().as_secs_f64()
138-
);
139127
}
140128

141129
let columns = terminal_values
@@ -206,3 +194,10 @@ impl<'a, T: FieldElement> MultiplicityColumnGenerator<'a, T> {
206194
(machine_size, tuples)
207195
}
208196
}
197+
198+
struct ReceiveInfo<T> {
199+
multiplicity_column: PolyID,
200+
size: usize,
201+
/// Maps a tuple of values to its index in the trace.
202+
index: HashMap<Vec<T>, usize>,
203+
}

0 commit comments

Comments
 (0)