11use std:: {
2- collections:: { BTreeMap , BTreeSet , HashMap , HashSet } ,
2+ collections:: BTreeMap ,
33 fmt:: Display ,
44 hash:: Hash ,
5- ops:: { Add , Mul , MulAssign , Neg , Sub } ,
5+ ops:: { Add , AddAssign , Mul , MulAssign , Neg , Sub } ,
66} ;
77
88use itertools:: Itertools ;
@@ -34,15 +34,8 @@ pub struct QuadraticSymbolicExpression<T: FieldElement, V> {
3434 linear : BTreeMap < V , SymbolicExpression < T , V > > ,
3535 /// Constant term, a (symbolically) known value.
3636 constant : SymbolicExpression < T , V > ,
37- occurrences : HashMap < V , HashSet < VariableOccurrence < V > > > ,
3837}
3938
40- #[ derive( Debug , Clone , Hash , PartialEq , Eq ) ]
41- enum VariableOccurrence < V > {
42- Quadratic { index : usize , first : bool } ,
43- Linear ( V ) ,
44- Constant ,
45- }
4639// TODO We need occurrence lists for all variables, both in their unknon
4740// version and in their known version (in the symbolic expressions),
4841// because range constraints therein can also change.
@@ -62,16 +55,10 @@ impl<T: FieldElement, V: Clone + Hash + Eq> From<SymbolicExpression<T, V>>
6255 for QuadraticSymbolicExpression < T , V >
6356{
6457 fn from ( k : SymbolicExpression < T , V > ) -> Self {
65- let occurrences = k
66- . referenced_symbols ( )
67- . map ( |v| ( ( * v) . clone ( ) , VariableOccurrence :: Constant ) )
68- . into_grouping_map ( )
69- . collect ( ) ;
7058 Self {
7159 quadratic : Default :: default ( ) ,
7260 linear : Default :: default ( ) ,
7361 constant : k,
74- occurrences,
7562 }
7663 }
7764}
@@ -91,7 +78,6 @@ impl<T: FieldElement, V: Ord + Clone + Hash + Eq> QuadraticSymbolicExpression<T,
9178 quadratic : Default :: default ( ) ,
9279 linear : [ ( var. clone ( ) , T :: from ( 1 ) . into ( ) ) ] . into_iter ( ) . collect ( ) ,
9380 constant : T :: from ( 0 ) . into ( ) ,
94- occurrences : Default :: default ( ) ,
9581 }
9682 }
9783
@@ -110,60 +96,73 @@ impl<T: FieldElement, V: Ord + Clone + Hash + Eq> QuadraticSymbolicExpression<T,
11096 known,
11197 range_constraint,
11298 } = var_update;
113- if let Some ( occurrences) = self . occurrences . get ( variable) {
114- for occurence in occurrences {
115- match occurence {
116- VariableOccurrence :: Quadratic { index, first } => { }
117- VariableOccurrence :: Linear ( v) => todo ! ( ) ,
118- VariableOccurrence :: Constant => todo ! ( ) , //self.constant.apply_update(var_update),
119- }
99+ self . constant . apply_update ( var_update) ;
100+ // If the variable is a key in `linear`, it must be unknown
101+ // and thus can only occur there. Otherwise, it can be in
102+ // any symbolic expression.
103+ if self . linear . contains_key ( variable) {
104+ if * known {
105+ let coeff = self . linear . remove ( variable) . unwrap ( ) ;
106+ let expr =
107+ SymbolicExpression :: from_symbol ( variable. clone ( ) , range_constraint. clone ( ) ) ;
108+ self . constant += expr * coeff;
109+ self . linear . remove ( variable) ;
110+ }
111+ } else {
112+ for coeff in self . linear . values_mut ( ) {
113+ coeff. apply_update ( var_update) ;
120114 }
121115 }
122- if * known {
123- // TODO if it turns into a constant, we should remove all occurrences.
124- if let Some ( coefficient) = self . linear . remove ( variable) {
125- // TODO update occurrences
126- self . constant +=
127- SymbolicExpression :: from_symbol ( variable. clone ( ) , range_constraint. clone ( ) )
128- * coefficient
116+
117+ // TODO can we do that without moving everything?
118+ // In the end, the order does not matter much.
119+
120+ let mut to_add = QuadraticSymbolicExpression :: from ( T :: zero ( ) ) ;
121+ self . quadratic . retain_mut ( |( l, r) | {
122+ l. apply_update ( var_update) ;
123+ r. apply_update ( var_update) ;
124+ match ( l. try_to_known ( ) , r. try_to_known ( ) ) {
125+ ( Some ( l) , Some ( r) ) => {
126+ to_add += ( l * r) . into ( ) ;
127+ false
128+ }
129+ ( Some ( l) , None ) => {
130+ to_add += r. clone ( ) * l;
131+ false
132+ }
133+ ( None , Some ( r) ) => {
134+ to_add += l. clone ( ) * r;
135+ false
136+ }
137+ _ => true ,
129138 }
139+ } ) ;
140+ if to_add. try_to_known ( ) . map ( |ta| ta. is_known_zero ( ) ) != Some ( true ) {
141+ * self += to_add;
130142 }
131143 }
132144
133145 /// Returns the set of referenced variables, both know and unknown.
134- pub fn referenced_variables ( & self ) -> impl Iterator < Item = & V > {
135- self . occurrences . keys ( )
146+ pub fn referenced_variables ( & self ) -> Box < dyn Iterator < Item = & V > + ' _ > {
147+ let quadr = self
148+ . quadratic
149+ . iter ( )
150+ . flat_map ( |( a, b) | a. referenced_variables ( ) . chain ( b. referenced_variables ( ) ) ) ;
151+
152+ let linear = self
153+ . linear
154+ . iter ( )
155+ . flat_map ( |( var, coeff) | std:: iter:: once ( var) . chain ( coeff. referenced_symbols ( ) ) ) ;
156+ let constant = self . constant . referenced_symbols ( ) ;
157+ Box :: new ( quadr. chain ( linear) . chain ( constant) )
136158 }
137159}
138160
139161impl < T : FieldElement , V : Clone + Ord + Hash + Eq > Add for QuadraticSymbolicExpression < T , V > {
140162 type Output = QuadraticSymbolicExpression < T , V > ;
141163
142164 fn add ( mut self , rhs : Self ) -> Self {
143- self . quadratic . extend ( rhs. quadratic ) ;
144- for ( var, coeff) in rhs. linear {
145- self . linear
146- . entry ( var)
147- . and_modify ( |f| * f += coeff. clone ( ) )
148- . or_insert_with ( || coeff) ;
149- }
150- self . constant += rhs. constant ;
151-
152- // Update occurrences.
153- for ( var, occurrences) in rhs. occurrences {
154- let occurrences = occurrences. into_iter ( ) . map ( |occurrence| match & occurrence {
155- VariableOccurrence :: Quadratic { index, first } => VariableOccurrence :: Quadratic {
156- index : index + self . quadratic . len ( ) ,
157- first : * first,
158- } ,
159- VariableOccurrence :: Linear ( _) | VariableOccurrence :: Constant => occurrence,
160- } ) ;
161- self . occurrences . entry ( var) . or_default ( ) . extend ( occurrences) ;
162- }
163-
164- // TODO remove all occurrences that point to "linear(v)", where
165- // va was removed.
166- self . linear . retain ( |_, f| f. is_known_zero ( ) ) ;
165+ self += rhs;
167166 self
168167 }
169168}
@@ -176,6 +175,22 @@ impl<T: FieldElement, V: Clone + Ord + Hash + Eq> Add for &QuadraticSymbolicExpr
176175 }
177176}
178177
178+ impl < T : FieldElement , V : Clone + Ord + Hash + Eq > AddAssign < QuadraticSymbolicExpression < T , V > >
179+ for QuadraticSymbolicExpression < T , V >
180+ {
181+ fn add_assign ( & mut self , rhs : Self ) {
182+ self . quadratic . extend ( rhs. quadratic ) ;
183+ for ( var, coeff) in rhs. linear {
184+ self . linear
185+ . entry ( var. clone ( ) )
186+ . and_modify ( |f| * f += coeff. clone ( ) )
187+ . or_insert_with ( || coeff) ;
188+ }
189+ self . constant += rhs. constant . clone ( ) ;
190+ self . linear . retain ( |_, f| !f. is_known_zero ( ) ) ;
191+ }
192+ }
193+
179194impl < T : FieldElement , V : Clone + Ord + Hash + Eq > Sub for & QuadraticSymbolicExpression < T , V > {
180195 type Output = QuadraticSymbolicExpression < T , V > ;
181196
@@ -252,13 +267,10 @@ impl<T: FieldElement, V: Clone + Ord + Hash + Eq> MulAssign<&SymbolicExpression<
252267 } else {
253268 for ( first, _) in & mut self . quadratic {
254269 * first *= rhs;
255- // TODO update occurrences
256270 }
257271 for coeff in self . linear . values_mut ( ) {
258- // TODO update occurrences
259272 * coeff *= rhs. clone ( ) ;
260273 }
261- // TODO update occurrences
262274 self . constant *= rhs. clone ( ) ;
263275 }
264276 }
@@ -273,16 +285,10 @@ impl<T: FieldElement, V: Clone + Ord + Hash + Eq> Mul for QuadraticSymbolicExpre
273285 } else if let Some ( k) = self . try_to_known ( ) {
274286 rhs * k
275287 } else {
276- let occurrences = ( self . referenced_variables ( ) . map ( |v| ( ( * v) . clone ( ) , true ) ) )
277- . chain ( rhs. referenced_variables ( ) . map ( |v| ( ( * v) . clone ( ) , false ) ) )
278- . map ( |( v, first) | ( v, VariableOccurrence :: Quadratic { index : 0 , first } ) )
279- . into_grouping_map ( )
280- . collect ( ) ;
281288 Self {
282289 quadratic : vec ! [ ( self , rhs) ] ,
283290 linear : Default :: default ( ) ,
284291 constant : T :: from ( 0 ) . into ( ) ,
285- occurrences,
286292 }
287293 }
288294 }
@@ -375,4 +381,32 @@ mod tests {
375381 assert_eq ! ( t. to_string( ) , "(X) * (Y) + A" ) ;
376382 assert_eq ! ( ( t. clone( ) * zero) . to_string( ) , "0" ) ;
377383 }
384+
385+ #[ test]
386+ fn test_apply_update ( ) {
387+ let x = Qse :: from_unknown_variable ( "X" . to_string ( ) ) ;
388+ let y = Qse :: from_unknown_variable ( "Y" . to_string ( ) ) ;
389+ let a = Qse :: from_known_symbol ( "A" . to_string ( ) , RangeConstraint :: default ( ) ) ;
390+ let b = Qse :: from_known_symbol ( "B" . to_string ( ) , RangeConstraint :: default ( ) ) ;
391+ let mut t: Qse = ( x * y + a) * b;
392+ assert_eq ! ( t. to_string( ) , "(B * X) * (Y) + (A * B)" ) ;
393+ t. apply_update ( & VariableUpdate {
394+ variable : "B" . to_string ( ) ,
395+ known : true ,
396+ range_constraint : RangeConstraint :: from_value ( 7 . into ( ) ) ,
397+ } ) ;
398+ assert_eq ! ( t. to_string( ) , "(7 * X) * (Y) + (A * 7)" ) ;
399+ t. apply_update ( & VariableUpdate {
400+ variable : "X" . to_string ( ) ,
401+ known : true ,
402+ range_constraint : RangeConstraint :: from_range ( 1 . into ( ) , 2 . into ( ) ) ,
403+ } ) ;
404+ assert_eq ! ( t. to_string( ) , "(X * 7) * Y + (A * 7)" ) ;
405+ t. apply_update ( & VariableUpdate {
406+ variable : "Y" . to_string ( ) ,
407+ known : true ,
408+ range_constraint : RangeConstraint :: from_value ( 3 . into ( ) ) ,
409+ } ) ;
410+ assert_eq ! ( t. to_string( ) , "((A * 7) + (3 * (X * 7)))" ) ;
411+ }
378412}
0 commit comments