Skip to content

Commit f198341

Browse files
authored
Handle assignments in identity queue as well (#2453)
1 parent 2db83db commit f198341

File tree

4 files changed

+115
-61
lines changed

4 files changed

+115
-61
lines changed

ast/src/analyzed/mod.rs

+10
Original file line numberDiff line numberDiff line change
@@ -1209,6 +1209,16 @@ impl From<&AlgebraicReference> for AlgebraicReferenceThin {
12091209
}
12101210
}
12111211

1212+
impl AlgebraicReferenceThin {
1213+
pub fn with_name(&self, name: String) -> AlgebraicReference {
1214+
AlgebraicReference {
1215+
name,
1216+
poly_id: self.poly_id,
1217+
next: self.next,
1218+
}
1219+
}
1220+
}
1221+
12121222
#[derive(Debug, Clone, Eq, Serialize, Deserialize, JsonSchema)]
12131223
pub struct AlgebraicReference {
12141224
/// Name of the polynomial - just for informational purposes.

executor/src/witgen/jit/identity_queue.rs

+94-53
Original file line numberDiff line numberDiff line change
@@ -5,17 +5,17 @@ use std::{
55

66
use itertools::Itertools;
77
use powdr_ast::{
8-
analyzed::{
9-
AlgebraicExpression as Expression, AlgebraicReference, AlgebraicReferenceThin,
10-
PolynomialType,
11-
},
8+
analyzed::{AlgebraicExpression as Expression, AlgebraicReferenceThin, PolynomialType},
129
parsed::visitor::{AllChildren, Children},
1310
};
1411
use powdr_number::FieldElement;
1512

1613
use crate::witgen::{data_structures::identity::Identity, FixedData};
1714

18-
use super::variable::Variable;
15+
use super::{
16+
variable::Variable,
17+
witgen_inference::{Assignment, VariableOrValue},
18+
};
1919

2020
/// Keeps track of identities that still need to be processed and
2121
/// updates this list based on the occurrence of updated variables
@@ -27,98 +27,120 @@ pub struct IdentityQueue<'a, T: FieldElement> {
2727
}
2828

2929
impl<'a, T: FieldElement> IdentityQueue<'a, T> {
30-
pub fn new(fixed_data: &'a FixedData<'a, T>, identities: &[(&'a Identity<T>, i32)]) -> Self {
31-
let occurrences = compute_occurrences_map(fixed_data, identities).into();
32-
Self {
33-
queue: identities
34-
.iter()
35-
.map(|(id, row)| QueueItem(id, *row))
36-
.collect(),
37-
occurrences,
38-
}
30+
pub fn new(
31+
fixed_data: &'a FixedData<'a, T>,
32+
identities: &[(&'a Identity<T>, i32)],
33+
assignments: &[Assignment<'a, T>],
34+
) -> Self {
35+
let queue: BTreeSet<_> = identities
36+
.iter()
37+
.map(|(id, row)| QueueItem::Identity(id, *row))
38+
.chain(assignments.iter().map(|a| QueueItem::Assignment(a.clone())))
39+
.collect();
40+
let occurrences = compute_occurrences_map(fixed_data, &queue).into();
41+
Self { queue, occurrences }
3942
}
4043

4144
/// Returns the next identity to be processed and its row and
4245
/// removes it from the queue.
43-
pub fn next(&mut self) -> Option<(&'a Identity<T>, i32)> {
44-
self.queue.pop_first().map(|QueueItem(id, row)| (id, row))
46+
pub fn next(&mut self) -> Option<QueueItem<'a, T>> {
47+
self.queue.pop_first()
4548
}
4649

4750
pub fn variables_updated(
4851
&mut self,
4952
variables: impl IntoIterator<Item = Variable>,
50-
skip_identity: Option<(&'a Identity<T>, i32)>,
53+
skip_item: Option<QueueItem<'a, T>>,
5154
) {
5255
self.queue.extend(
5356
variables
5457
.into_iter()
5558
.flat_map(|var| self.occurrences.get(&var))
5659
.flatten()
57-
.filter(|QueueItem(id, row)| match skip_identity {
58-
Some((id2, row2)) => (id.id(), *row) != (id2.id(), row2),
60+
.filter(|item| match &skip_item {
61+
Some(it) => *item != it,
5962
None => true,
60-
}),
63+
})
64+
.cloned(),
6165
)
6266
}
6367
}
6468

65-
/// Sorts identities by row and then by ID.
66-
#[derive(Clone, Copy)]
67-
struct QueueItem<'a, T>(&'a Identity<T>, i32);
68-
69-
impl<T> QueueItem<'_, T> {
70-
fn key(&self) -> (i32, u64) {
71-
let QueueItem(id, row) = self;
72-
(*row, id.id())
73-
}
69+
#[derive(Clone)]
70+
pub enum QueueItem<'a, T: FieldElement> {
71+
Identity(&'a Identity<T>, i32),
72+
Assignment(Assignment<'a, T>),
7473
}
7574

76-
impl<T> Ord for QueueItem<'_, T> {
75+
/// Sorts identities by row and then by ID, preceded by assignments.
76+
impl<T: FieldElement> Ord for QueueItem<'_, T> {
7777
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
78-
self.key().cmp(&other.key())
78+
match (self, other) {
79+
(QueueItem::Identity(id1, row1), QueueItem::Identity(id2, row2)) => {
80+
(row1, id1.id()).cmp(&(row2, id2.id()))
81+
}
82+
(QueueItem::Assignment(a1), QueueItem::Assignment(a2)) => a1.cmp(a2),
83+
(QueueItem::Assignment(_), QueueItem::Identity(_, _)) => std::cmp::Ordering::Less,
84+
(QueueItem::Identity(_, _), QueueItem::Assignment(_)) => std::cmp::Ordering::Greater,
85+
}
7986
}
8087
}
8188

82-
impl<T> PartialOrd for QueueItem<'_, T> {
89+
impl<T: FieldElement> PartialOrd for QueueItem<'_, T> {
8390
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
8491
Some(self.cmp(other))
8592
}
8693
}
8794

88-
impl<T> PartialEq for QueueItem<'_, T> {
95+
impl<T: FieldElement> PartialEq for QueueItem<'_, T> {
8996
fn eq(&self, other: &Self) -> bool {
90-
self.key() == other.key()
97+
self.cmp(other) == std::cmp::Ordering::Equal
9198
}
9299
}
93100

94-
impl<T> Eq for QueueItem<'_, T> {}
101+
impl<T: FieldElement> Eq for QueueItem<'_, T> {}
95102

96-
/// Computes a map from each variable to the identity-row-offset pairs it occurs in.
97-
fn compute_occurrences_map<'a, T: FieldElement>(
103+
/// Computes a map from each variable to the queue items it occurs in.
104+
fn compute_occurrences_map<'b, 'a: 'b, T: FieldElement>(
98105
fixed_data: &'a FixedData<'a, T>,
99-
identities: &[(&'a Identity<T>, i32)],
106+
items: &BTreeSet<QueueItem<'a, T>>,
100107
) -> HashMap<Variable, Vec<QueueItem<'a, T>>> {
101-
let mut references_per_identity = HashMap::new();
102108
let mut intermediate_cache = HashMap::new();
103-
for id in identities.iter().map(|(id, _)| *id).unique_by(|id| id.id()) {
109+
110+
// Compute references only once per identity.
111+
let mut references_per_identity = HashMap::new();
112+
for id in items
113+
.iter()
114+
.filter_map(|item| match item {
115+
QueueItem::Identity(id, _) => Some(id),
116+
_ => None,
117+
})
118+
.unique_by(|id| id.id())
119+
{
104120
references_per_identity.insert(
105-
id,
121+
id.id(),
106122
references_in_identity(id, fixed_data, &mut intermediate_cache),
107123
);
108124
}
109-
identities
125+
126+
items
110127
.iter()
111-
.flat_map(|(id, row)| {
112-
references_per_identity[id].iter().map(move |reference| {
113-
let name = fixed_data.column_name(&reference.poly_id).to_string();
114-
let fat_ref = AlgebraicReference {
115-
name,
116-
poly_id: reference.poly_id,
117-
next: reference.next,
118-
};
119-
let var = Variable::from_reference(&fat_ref, *row);
120-
(var, QueueItem(*id, *row))
121-
})
128+
.flat_map(|item| {
129+
let variables = match item {
130+
QueueItem::Identity(id, row) => {
131+
references_in_identity(id, fixed_data, &mut intermediate_cache)
132+
.into_iter()
133+
.map(|r| {
134+
let name = fixed_data.column_name(&r.poly_id).to_string();
135+
Variable::from_reference(&r.with_name(name), *row)
136+
})
137+
.collect_vec()
138+
}
139+
QueueItem::Assignment(a) => {
140+
variables_in_assignment(a, fixed_data, &mut intermediate_cache)
141+
}
142+
};
143+
variables.into_iter().map(move |v| (v, item.clone()))
122144
})
123145
.into_group_map()
124146
}
@@ -184,3 +206,22 @@ fn references_in_expression<'a, T: FieldElement>(
184206
)
185207
.unique()
186208
}
209+
210+
/// Returns a vector of all variables that occur in the assignment.
211+
fn variables_in_assignment<'a, T: FieldElement>(
212+
assignment: &Assignment<'a, T>,
213+
fixed_data: &'a FixedData<'a, T>,
214+
intermediate_cache: &mut HashMap<AlgebraicReferenceThin, Vec<AlgebraicReferenceThin>>,
215+
) -> Vec<Variable> {
216+
let rhs_var = match &assignment.rhs {
217+
VariableOrValue::Variable(v) => Some(v.clone()),
218+
VariableOrValue::Value(_) => None,
219+
};
220+
references_in_expression(assignment.lhs, fixed_data, intermediate_cache)
221+
.map(|r| {
222+
let name = fixed_data.column_name(&r.poly_id).to_string();
223+
Variable::from_reference(&r.with_name(name), assignment.row_offset)
224+
})
225+
.chain(rhs_var)
226+
.collect()
227+
}

executor/src/witgen/jit/processor.rs

+6-3
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ use crate::witgen::{
1818
use super::{
1919
affine_symbolic_expression,
2020
effect::{format_code, Effect},
21-
identity_queue::IdentityQueue,
21+
identity_queue::{IdentityQueue, QueueItem},
2222
prover_function_heuristics::ProverFunction,
2323
variable::{Cell, MachineCallVariable, Variable},
2424
witgen_inference::{BranchResult, CanProcessCall, FixedEvaluator, Value, WitgenInference},
@@ -127,7 +127,7 @@ impl<'a, T: FieldElement, FixedEval: FixedEvaluator<T>> Processor<'a, T, FixedEv
127127
}
128128
}
129129
let branch_depth = 0;
130-
let identity_queue = IdentityQueue::new(self.fixed_data, &self.identities);
130+
let identity_queue = IdentityQueue::new(self.fixed_data, &self.identities, &[]);
131131
self.generate_code_for_branch(can_process, witgen, identity_queue, branch_depth)
132132
}
133133

@@ -298,7 +298,7 @@ impl<'a, T: FieldElement, FixedEval: FixedEvaluator<T>> Processor<'a, T, FixedEv
298298
loop {
299299
let identity = identity_queue.next();
300300
let updated_vars = match identity {
301-
Some((identity, row_offset)) => match identity {
301+
Some(QueueItem::Identity(identity, row_offset)) => match identity {
302302
Identity::Polynomial(PolynomialIdentity { id, expression, .. }) => {
303303
witgen.process_polynomial_identity(*id, expression, row_offset)
304304
}
@@ -317,6 +317,9 @@ impl<'a, T: FieldElement, FixedEval: FixedEvaluator<T>> Processor<'a, T, FixedEv
317317
},
318318
// TODO Also add prover functions to the queue (activated by their variables)
319319
// and sort them so that they are always last.
320+
Some(QueueItem::Assignment(_assignment)) => {
321+
todo!()
322+
}
320323
None => self.process_prover_functions(witgen),
321324
}?;
322325
if updated_vars.is_empty() && identity.is_none() {

executor/src/witgen/jit/witgen_inference.rs

+5-5
Original file line numberDiff line numberDiff line change
@@ -680,14 +680,14 @@ fn is_known_zero<T: FieldElement>(x: &Option<AffineSymbolicExpression<T, Variabl
680680
/// An equality constraint between an algebraic expression evaluated
681681
/// on a certain row offset and a variable or fixed constant value.
682682
#[derive(Clone, Ord, PartialOrd, Eq, PartialEq, Debug)]
683-
struct Assignment<'a, T: FieldElement> {
684-
lhs: &'a Expression<T>,
685-
row_offset: i32,
686-
rhs: VariableOrValue<T, Variable>,
683+
pub struct Assignment<'a, T: FieldElement> {
684+
pub lhs: &'a Expression<T>,
685+
pub row_offset: i32,
686+
pub rhs: VariableOrValue<T, Variable>,
687687
}
688688

689689
#[derive(Clone, derive_more::Display, Ord, PartialOrd, Eq, PartialEq, Debug)]
690-
enum VariableOrValue<T, V> {
690+
pub enum VariableOrValue<T, V> {
691691
Variable(V),
692692
Value(T),
693693
}

0 commit comments

Comments
 (0)