Skip to content

Commit 913dc20

Browse files
committed
apply update.
1 parent 9fd9ad7 commit 913dc20

File tree

4 files changed

+113
-73
lines changed

4 files changed

+113
-73
lines changed

executor/src/witgen/jit/effect.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
use std::fmt::Formatter;
2-
use std::sync::Arc;
32
use std::{cmp::Ordering, fmt::Display};
43

54
use std::{fmt, iter};

executor/src/witgen/jit/interpreter.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -711,7 +711,7 @@ impl<T: FieldElement> RPNExpression<T, usize> {
711711
SymbolicExpression::BinaryOperation(lhs, op, rhs, _) => {
712712
inner(lhs, elems, var_mapper);
713713
inner(rhs, elems, var_mapper);
714-
elems.push(RPNExpressionElem::BinaryOperation(op.clone()));
714+
elems.push(RPNExpressionElem::BinaryOperation(*op));
715715
}
716716
SymbolicExpression::UnaryOperation(op, expr, _) => {
717717
inner(expr, elems, var_mapper);

executor/src/witgen/jit/quadratic_symbolic_expression.rs

Lines changed: 99 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
use std::{
2-
collections::{BTreeMap, BTreeSet, HashMap, HashSet},
2+
collections::BTreeMap,
33
fmt::Display,
44
hash::Hash,
5-
ops::{Add, Mul, MulAssign, Neg, Sub},
5+
ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub},
66
};
77

88
use itertools::Itertools;
@@ -34,15 +34,8 @@ pub struct QuadraticSymbolicExpression<T: FieldElement, V> {
3434
linear: BTreeMap<V, SymbolicExpression<T, V>>,
3535
/// Constant term, a (symbolically) known value.
3636
constant: SymbolicExpression<T, V>,
37-
occurrences: HashMap<V, HashSet<VariableOccurrence<V>>>,
3837
}
3938

40-
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
41-
enum VariableOccurrence<V> {
42-
Quadratic { index: usize, first: bool },
43-
Linear(V),
44-
Constant,
45-
}
4639
// TODO We need occurrence lists for all variables, both in their unknon
4740
// version and in their known version (in the symbolic expressions),
4841
// because range constraints therein can also change.
@@ -62,16 +55,10 @@ impl<T: FieldElement, V: Clone + Hash + Eq> From<SymbolicExpression<T, V>>
6255
for QuadraticSymbolicExpression<T, V>
6356
{
6457
fn from(k: SymbolicExpression<T, V>) -> Self {
65-
let occurrences = k
66-
.referenced_symbols()
67-
.map(|v| ((*v).clone(), VariableOccurrence::Constant))
68-
.into_grouping_map()
69-
.collect();
7058
Self {
7159
quadratic: Default::default(),
7260
linear: Default::default(),
7361
constant: k,
74-
occurrences,
7562
}
7663
}
7764
}
@@ -91,7 +78,6 @@ impl<T: FieldElement, V: Ord + Clone + Hash + Eq> QuadraticSymbolicExpression<T,
9178
quadratic: Default::default(),
9279
linear: [(var.clone(), T::from(1).into())].into_iter().collect(),
9380
constant: T::from(0).into(),
94-
occurrences: Default::default(),
9581
}
9682
}
9783

@@ -110,60 +96,73 @@ impl<T: FieldElement, V: Ord + Clone + Hash + Eq> QuadraticSymbolicExpression<T,
11096
known,
11197
range_constraint,
11298
} = var_update;
113-
if let Some(occurrences) = self.occurrences.get(variable) {
114-
for occurence in occurrences {
115-
match occurence {
116-
VariableOccurrence::Quadratic { index, first } => {}
117-
VariableOccurrence::Linear(v) => todo!(),
118-
VariableOccurrence::Constant => todo!(), //self.constant.apply_update(var_update),
119-
}
99+
self.constant.apply_update(var_update);
100+
// If the variable is a key in `linear`, it must be unknown
101+
// and thus can only occur there. Otherwise, it can be in
102+
// any symbolic expression.
103+
if self.linear.contains_key(variable) {
104+
if *known {
105+
let coeff = self.linear.remove(variable).unwrap();
106+
let expr =
107+
SymbolicExpression::from_symbol(variable.clone(), range_constraint.clone());
108+
self.constant += expr * coeff;
109+
self.linear.remove(variable);
110+
}
111+
} else {
112+
for coeff in self.linear.values_mut() {
113+
coeff.apply_update(var_update);
120114
}
121115
}
122-
if *known {
123-
// TODO if it turns into a constant, we should remove all occurrences.
124-
if let Some(coefficient) = self.linear.remove(variable) {
125-
// TODO update occurrences
126-
self.constant +=
127-
SymbolicExpression::from_symbol(variable.clone(), range_constraint.clone())
128-
* coefficient
116+
117+
// TODO can we do that without moving everything?
118+
// In the end, the order does not matter much.
119+
120+
let mut to_add = QuadraticSymbolicExpression::from(T::zero());
121+
self.quadratic.retain_mut(|(l, r)| {
122+
l.apply_update(var_update);
123+
r.apply_update(var_update);
124+
match (l.try_to_known(), r.try_to_known()) {
125+
(Some(l), Some(r)) => {
126+
to_add += (l * r).into();
127+
false
128+
}
129+
(Some(l), None) => {
130+
to_add += r.clone() * l;
131+
false
132+
}
133+
(None, Some(r)) => {
134+
to_add += l.clone() * r;
135+
false
136+
}
137+
_ => true,
129138
}
139+
});
140+
if to_add.try_to_known().map(|ta| ta.is_known_zero()) != Some(true) {
141+
*self += to_add;
130142
}
131143
}
132144

133145
/// Returns the set of referenced variables, both know and unknown.
134-
pub fn referenced_variables(&self) -> impl Iterator<Item = &V> {
135-
self.occurrences.keys()
146+
pub fn referenced_variables(&self) -> Box<dyn Iterator<Item = &V> + '_> {
147+
let quadr = self
148+
.quadratic
149+
.iter()
150+
.flat_map(|(a, b)| a.referenced_variables().chain(b.referenced_variables()));
151+
152+
let linear = self
153+
.linear
154+
.iter()
155+
.flat_map(|(var, coeff)| std::iter::once(var).chain(coeff.referenced_symbols()));
156+
let constant = self.constant.referenced_symbols();
157+
Box::new(quadr.chain(linear).chain(constant))
136158
}
137159
}
138160

139161
impl<T: FieldElement, V: Clone + Ord + Hash + Eq> Add for QuadraticSymbolicExpression<T, V> {
140162
type Output = QuadraticSymbolicExpression<T, V>;
141163

142164
fn add(mut self, rhs: Self) -> Self {
143-
self.quadratic.extend(rhs.quadratic);
144-
for (var, coeff) in rhs.linear {
145-
self.linear
146-
.entry(var)
147-
.and_modify(|f| *f += coeff.clone())
148-
.or_insert_with(|| coeff);
149-
}
150-
self.constant += rhs.constant;
151-
152-
// Update occurrences.
153-
for (var, occurrences) in rhs.occurrences {
154-
let occurrences = occurrences.into_iter().map(|occurrence| match &occurrence {
155-
VariableOccurrence::Quadratic { index, first } => VariableOccurrence::Quadratic {
156-
index: index + self.quadratic.len(),
157-
first: *first,
158-
},
159-
VariableOccurrence::Linear(_) | VariableOccurrence::Constant => occurrence,
160-
});
161-
self.occurrences.entry(var).or_default().extend(occurrences);
162-
}
163-
164-
// TODO remove all occurrences that point to "linear(v)", where
165-
// va was removed.
166-
self.linear.retain(|_, f| f.is_known_zero());
165+
self += rhs;
167166
self
168167
}
169168
}
@@ -176,6 +175,22 @@ impl<T: FieldElement, V: Clone + Ord + Hash + Eq> Add for &QuadraticSymbolicExpr
176175
}
177176
}
178177

178+
impl<T: FieldElement, V: Clone + Ord + Hash + Eq> AddAssign<QuadraticSymbolicExpression<T, V>>
179+
for QuadraticSymbolicExpression<T, V>
180+
{
181+
fn add_assign(&mut self, rhs: Self) {
182+
self.quadratic.extend(rhs.quadratic);
183+
for (var, coeff) in rhs.linear {
184+
self.linear
185+
.entry(var.clone())
186+
.and_modify(|f| *f += coeff.clone())
187+
.or_insert_with(|| coeff);
188+
}
189+
self.constant += rhs.constant.clone();
190+
self.linear.retain(|_, f| !f.is_known_zero());
191+
}
192+
}
193+
179194
impl<T: FieldElement, V: Clone + Ord + Hash + Eq> Sub for &QuadraticSymbolicExpression<T, V> {
180195
type Output = QuadraticSymbolicExpression<T, V>;
181196

@@ -252,13 +267,10 @@ impl<T: FieldElement, V: Clone + Ord + Hash + Eq> MulAssign<&SymbolicExpression<
252267
} else {
253268
for (first, _) in &mut self.quadratic {
254269
*first *= rhs;
255-
// TODO update occurrences
256270
}
257271
for coeff in self.linear.values_mut() {
258-
// TODO update occurrences
259272
*coeff *= rhs.clone();
260273
}
261-
// TODO update occurrences
262274
self.constant *= rhs.clone();
263275
}
264276
}
@@ -273,16 +285,10 @@ impl<T: FieldElement, V: Clone + Ord + Hash + Eq> Mul for QuadraticSymbolicExpre
273285
} else if let Some(k) = self.try_to_known() {
274286
rhs * k
275287
} else {
276-
let occurrences = (self.referenced_variables().map(|v| ((*v).clone(), true)))
277-
.chain(rhs.referenced_variables().map(|v| ((*v).clone(), false)))
278-
.map(|(v, first)| (v, VariableOccurrence::Quadratic { index: 0, first }))
279-
.into_grouping_map()
280-
.collect();
281288
Self {
282289
quadratic: vec![(self, rhs)],
283290
linear: Default::default(),
284291
constant: T::from(0).into(),
285-
occurrences,
286292
}
287293
}
288294
}
@@ -375,4 +381,32 @@ mod tests {
375381
assert_eq!(t.to_string(), "(X) * (Y) + A");
376382
assert_eq!((t.clone() * zero).to_string(), "0");
377383
}
384+
385+
#[test]
386+
fn test_apply_update() {
387+
let x = Qse::from_unknown_variable("X".to_string());
388+
let y = Qse::from_unknown_variable("Y".to_string());
389+
let a = Qse::from_known_symbol("A".to_string(), RangeConstraint::default());
390+
let b = Qse::from_known_symbol("B".to_string(), RangeConstraint::default());
391+
let mut t: Qse = (x * y + a) * b;
392+
assert_eq!(t.to_string(), "(B * X) * (Y) + (A * B)");
393+
t.apply_update(&VariableUpdate {
394+
variable: "B".to_string(),
395+
known: true,
396+
range_constraint: RangeConstraint::from_value(7.into()),
397+
});
398+
assert_eq!(t.to_string(), "(7 * X) * (Y) + (A * 7)");
399+
t.apply_update(&VariableUpdate {
400+
variable: "X".to_string(),
401+
known: true,
402+
range_constraint: RangeConstraint::from_range(1.into(), 2.into()),
403+
});
404+
assert_eq!(t.to_string(), "(X * 7) * Y + (A * 7)");
405+
t.apply_update(&VariableUpdate {
406+
variable: "Y".to_string(),
407+
known: true,
408+
range_constraint: RangeConstraint::from_value(3.into()),
409+
});
410+
assert_eq!(t.to_string(), "((A * 7) + (3 * (X * 7)))");
411+
}
378412
}

executor/src/witgen/jit/symbolic_expression.rs

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -113,23 +113,23 @@ impl<T: FieldElement, S> SymbolicExpression<T, S> {
113113

114114
impl<T: FieldElement, S: Clone + Eq> SymbolicExpression<T, S> {
115115
/// Applies a variable update and returns a modified version if there was a change.
116-
pub fn apply_update(&self, variable_update: &VariableUpdate<T, S>) -> Option<Self> {
116+
pub fn compute_updated(&self, variable_update: &VariableUpdate<T, S>) -> Option<Self> {
117117
match self {
118118
SymbolicExpression::Concrete(_) => None,
119-
SymbolicExpression::Symbol(v, range_constraint) => {
119+
SymbolicExpression::Symbol(v, _) => {
120120
if *v == variable_update.variable {
121121
Some(SymbolicExpression::from_symbol(
122122
v.clone(),
123-
range_constraint.clone(),
123+
variable_update.range_constraint.clone(),
124124
))
125125
} else {
126126
None
127127
}
128128
}
129129
SymbolicExpression::BinaryOperation(left, op, right, _) => {
130130
let (l, r) = match (
131-
left.apply_update(variable_update),
132-
right.apply_update(variable_update),
131+
left.compute_updated(variable_update),
132+
right.compute_updated(variable_update),
133133
) {
134134
(None, None) => return None,
135135
(Some(l), None) => (l, (**right).clone()),
@@ -144,12 +144,19 @@ impl<T: FieldElement, S: Clone + Eq> SymbolicExpression<T, S> {
144144
}
145145
}
146146
SymbolicExpression::UnaryOperation(op, inner, _) => {
147-
let inner = inner.apply_update(variable_update)?;
147+
let inner = inner.compute_updated(variable_update)?;
148148
assert!(matches!(op, UnaryOperator::Neg));
149149
Some(-inner)
150150
}
151151
}
152152
}
153+
154+
/// Applies a variable update in place.
155+
pub fn apply_update(&mut self, variable_update: &VariableUpdate<T, S>) {
156+
if let Some(updated) = self.compute_updated(variable_update) {
157+
*self = updated;
158+
}
159+
}
153160
}
154161

155162
impl<T: FieldElement, S: Hash + Eq> SymbolicExpression<T, S> {

0 commit comments

Comments
 (0)