77
88use crate :: { Poly , expr_to_var, extract_bit} ;
99use baa:: { BitVecOps , BitVecValueRef } ;
10- use patronus:: expr:: { Context , Expr , ExprRef , ExprSet , SparseExprSet , TypeCheck , traversal} ;
11- use polysub:: Coef ;
12- use rustc_hash:: FxHashSet ;
10+ use patronus:: expr:: {
11+ Context , Expr , ExprRef , ExprSet , ForEachChild , SerializableIrNode , SparseExprMap ,
12+ SparseExprSet , TypeCheck , count_expr_uses, traversal,
13+ } ;
14+ use polysub:: { Coef , Mod } ;
15+ use rustc_hash:: { FxHashMap , FxHashSet } ;
16+
17+ #[ derive( Debug , Copy , Clone ) ]
18+ pub enum BuildPolyMode {
19+ Arithmetic ,
20+ Gates ( Mod ) ,
21+ }
1322
1423/// Returns a polynomial representation of the expression + all input expressions if possible.
1524/// Returns `None` if the conversion fails.
1625pub fn build_bottom_up_poly (
1726 ctx : & mut Context ,
1827 inputs : & FxHashSet < ExprRef > ,
1928 e : ExprRef ,
29+ mode : BuildPolyMode ,
2030) -> Option < Poly > {
21- traversal:: bottom_up_mut ( ctx, e, |ctx, e, c| to_poly ( ctx, inputs, e, c) ) . map ( |( p, _) | p)
31+ // we use a custom traversal that caches polynomials until they are no longer used
32+ let mut uses = count_expr_uses ( ctx, vec ! [ e] ) ;
33+ let mut todo = vec ! [ e] ;
34+ let mut result: FxHashMap < ExprRef , Option < ( Poly , bool ) > > = FxHashMap :: default ( ) ;
35+ let mut child_vec = Vec :: with_capacity ( 4 ) ;
36+
37+ while let Some ( e) = todo. pop ( ) {
38+ assert ! ( !result. contains_key( & e) ) ;
39+
40+ let expr = & ctx[ e] ;
41+ // find children that are not available yet.
42+ debug_assert ! ( child_vec. is_empty( ) ) ;
43+ expr. collect_children ( & mut child_vec) ;
44+ let all_available = child_vec. iter ( ) . all ( |c| result. contains_key ( c) ) ;
45+
46+ if all_available {
47+ let child_results: Vec < Option < & ( Poly , bool ) > > =
48+ child_vec. iter ( ) . map ( |c| result[ c] . as_ref ( ) ) . collect ( ) ;
49+ let r = to_poly ( ctx, inputs, mode, e, & child_results) ;
50+ result. insert ( e, r) ;
51+ for child in child_vec. drain ( ..) {
52+ let old_use = uses[ child] ;
53+ if old_use == 1 {
54+ result. remove ( & child) ;
55+ }
56+ uses[ child] = old_use - 1 ;
57+ }
58+ } else {
59+ todo. push ( e) ;
60+ for child in child_vec. drain ( ..) {
61+ if !result. contains_key ( & child) {
62+ todo. push ( child) ;
63+ }
64+ }
65+ }
66+ }
67+
68+ result[ & e] . as_ref ( ) . map ( |( p, _) | p. clone ( ) )
2269}
2370
2471fn to_poly (
2572 ctx : & mut Context ,
2673 inputs : & FxHashSet < ExprRef > ,
74+ mode : BuildPolyMode ,
2775 e : ExprRef ,
28- children : & [ Option < ( Poly , bool ) > ] ,
76+ children : & [ Option < & ( Poly , bool ) > ] ,
2977) -> Option < ( Poly , bool ) > {
3078 // have we given up yet?
3179 if children. iter ( ) . any ( |c| c. is_none ( ) ) {
@@ -37,6 +85,89 @@ fn to_poly(
3785 return Some ( ( poly_for_bv_expr ( ctx, e) , false ) ) ;
3886 }
3987
88+ match mode {
89+ BuildPolyMode :: Arithmetic => to_poly_arithmetic ( ctx, inputs, e, children) ,
90+ BuildPolyMode :: Gates ( m) => Some ( ( to_poly_gate ( ctx, inputs, m, e, children) , false ) ) ,
91+ }
92+ }
93+
94+ /// For bit-level polynomials from gates, the normal arithmetic rules and overflow checking do not
95+ /// apply. Instead, we use the modulo coefficient of the top-level expression.
96+ fn to_poly_gate (
97+ ctx : & mut Context ,
98+ inputs : & FxHashSet < ExprRef > ,
99+ m : Mod ,
100+ e : ExprRef ,
101+ children : & [ Option < & ( Poly , bool ) > ] ,
102+ ) -> Poly {
103+ debug_assert ! (
104+ children
105+ . iter( )
106+ . all( |c| c. map( |( _, ov) | !* ov) . unwrap_or( true ) )
107+ ) ;
108+
109+ match ( ctx[ e] . clone ( ) , children) {
110+ ( Expr :: BVSymbol { .. } , _) => unreachable ! ( "all symbols should be in inputs" ) ,
111+ ( Expr :: BVLiteral ( value) , _) => {
112+ let mut r = poly_for_bv_literal ( value. get ( ctx) ) ;
113+ r. change_mod ( m) ;
114+ r
115+ }
116+ ( Expr :: BVSlice { e, hi, lo } , _) if hi == lo && inputs. contains ( & e) => {
117+ // special case: bit_slice of an input
118+ let var = expr_to_var ( extract_bit ( ctx, e, hi) ) ;
119+ Poly :: from_monoms ( m, [ ( Coef :: from_i64 ( 1 , m) , vec ! [ var] . into ( ) ) ] . into_iter ( ) )
120+ }
121+ ( Expr :: BVConcat ( _, be, w) , [ Some ( ( a, _) ) , Some ( ( b, _) ) ] ) => {
122+ // left shift a
123+ let shift_by = be. get_bv_type ( ctx) . unwrap ( ) ;
124+ let shift_coef = Coef :: pow2 ( shift_by, m) ;
125+ let mut r: Poly = a. clone ( ) ;
126+ r. scale ( & shift_coef) ;
127+ r. add_assign ( b) ;
128+ r
129+ }
130+ ( Expr :: BVOr ( _, _, 1 ) , [ Some ( ( a, _) ) , Some ( ( b, _) ) ] ) => {
131+ // a + b - ab
132+ let mut r = a. mul ( b) ;
133+ r. scale ( & Coef :: from_i64 ( -1 , a. get_mod ( ) ) ) ;
134+ r. add_assign ( a) ;
135+ r. add_assign ( b) ;
136+ r
137+ }
138+ ( Expr :: BVXor ( _, _, 1 ) , [ Some ( ( a, _) ) , Some ( ( b, _) ) ] ) => {
139+ // a + b - 2ab
140+ let mut r = a. mul ( b) ;
141+ let minus_2 = Coef :: from_i64 ( -2 , a. get_mod ( ) ) ;
142+ r. scale ( & minus_2) ;
143+ r. add_assign ( a) ;
144+ r. add_assign ( b) ;
145+ r
146+ }
147+ ( Expr :: BVAnd ( _, _, 1 ) , [ Some ( ( a, _) ) , Some ( ( b, _) ) ] ) => {
148+ // ab
149+ a. clone ( ) . mul ( b)
150+ }
151+ ( Expr :: BVNot ( _, 1 ) , [ Some ( ( a, _) ) ] ) => {
152+ // 1 - a
153+ let one = Poly :: from_monoms ( m, [ ( Coef :: from_i64 ( 1 , m) , vec ! [ ] . into ( ) ) ] . into_iter ( ) ) ;
154+ let mut r = a. clone ( ) ;
155+ r. scale ( & Coef :: from_i64 ( -1 , a. get_mod ( ) ) ) ;
156+ r. add_assign ( & one) ;
157+ r
158+ }
159+ ( other, cs) => todo ! ( "{other:?}: {cs:?}" ) ,
160+ }
161+ }
162+
163+ /// When building a polynomial over an arithmetic circuit, we are tracking overflow and
164+ /// determining the modulo coefficient from the bit-widths.
165+ fn to_poly_arithmetic (
166+ ctx : & mut Context ,
167+ _inputs : & FxHashSet < ExprRef > ,
168+ e : ExprRef ,
169+ children : & [ Option < & ( Poly , bool ) > ] ,
170+ ) -> Option < ( Poly , bool ) > {
40171 match ( ctx[ e] . clone ( ) , children) {
41172 ( Expr :: BVSymbol { .. } , _) => unreachable ! ( "all symbols should be in inputs" ) ,
42173 ( Expr :: BVLiteral ( value) , _) => Some ( ( poly_for_bv_literal ( value. get ( ctx) ) , false ) ) ,
0 commit comments