55use baa:: { BitVecOps , BitVecValue , BitVecValueRef } ;
66use patronus:: expr:: {
77 Context , DenseExprMetaData , Expr , ExprRef , ForEachChild , SerializableIrNode , Simplifier ,
8- TypeCheck , WidthInt , count_expr_uses, simplify_single_expression , traversal,
8+ TypeCheck , WidthInt , count_expr_uses, traversal,
99} ;
1010use polysub:: { Coef , Term , VarIndex } ;
11- use rustc_hash:: FxHashMap ;
1211use std:: collections:: VecDeque ;
1312use std:: fmt:: { Display , Formatter } ;
1413
@@ -85,24 +84,15 @@ fn backwards_sub(ctx: &Context, mut todo: VecDeque<(VarIndex, ExprRef)>, mut spe
8584 var_roots. sort ( ) ;
8685
8786 let m = spec. get_mod ( ) ;
88- println ! ( "MOD={m:?} ({} bits)" , m. bits( ) ) ;
8987 let one: DefaultCoef = Coef :: from_i64 ( 1 , m) ;
90- println ! ( "one = {one:?}" ) ;
9188 let minus_one: DefaultCoef = Coef :: from_i64 ( -1 , m) ;
92- println ! ( "minus_one = {minus_one:?}" ) ;
9389 let minus_two: DefaultCoef = Coef :: from_i64 ( -2 , m) ;
9490 // first, we count how often expressions are used
9591 let roots: Vec < _ > = todo. iter ( ) . map ( |( _, e) | * e) . collect ( ) ;
9692 let mut uses = count_expr_uses ( ctx, roots) ;
9793 let mut replaced = vec ! [ ] ;
9894
9995 while let Some ( ( output_var, gate) ) = todo. pop_back ( ) {
100- println ! (
101- "{output_var}, {}, {:?} ({})" ,
102- expr_to_var( gate) == output_var,
103- & ctx[ gate] ,
104- spec. size( )
105- ) ;
10696 replaced. push ( output_var) ;
10797
10898 let add_children = match ctx[ gate] . clone ( ) {
@@ -171,7 +161,6 @@ fn backwards_sub(ctx: &Context, mut todo: VecDeque<(VarIndex, ExprRef)>, mut spe
171161 }
172162 } ) ;
173163 }
174- println ! ( "{spec}" ) ;
175164 }
176165
177166 println ! ( "Roots: {var_roots:?}" ) ;
@@ -183,7 +172,7 @@ fn backwards_sub(ctx: &Context, mut todo: VecDeque<(VarIndex, ExprRef)>, mut spe
183172 use patronus:: expr:: ExprMap ;
184173 let mut still_used: Vec < _ > = uses
185174 . iter ( )
186- . filter ( |( k , v) | * * v > 0 )
175+ . filter ( |( _ , v) | * * v > 0 )
187176 . map ( |( k, _) | expr_to_var ( k) )
188177 . collect ( ) ;
189178 still_used. sort ( ) ;
@@ -274,7 +263,7 @@ fn build_bottom_up_poly(ctx: &mut Context, e: ExprRef) -> Poly {
274263 poly
275264}
276265
277- #[ derive( Debug , Clone , PartialEq ) ]
266+ #[ derive( Debug , Clone , Copy , PartialEq ) ]
278267pub struct ScaEqualityProblem {
279268 equality : ExprRef ,
280269 gate_level : ExprRef ,
@@ -486,11 +475,10 @@ impl<'a, 'b> Display for PrettyPoly<'a, 'b> {
486475#[ cfg( test) ]
487476mod tests {
488477 use super :: * ;
489- use patronus:: expr:: { eval_bv_expr, eval_expr , simplify_single_expression} ;
490- use patronus:: smt:: { SmtCommand , read_command, serialize_cmd } ;
478+ use patronus:: expr:: { eval_bv_expr, find_symbols , simplify_single_expression} ;
479+ use patronus:: smt:: { SmtCommand , read_command} ;
491480 use rustc_hash:: FxHashMap ;
492- use std:: io:: { BufReader , BufWriter } ;
493- use std:: ptr:: eq;
481+ use std:: io:: BufReader ;
494482
495483 fn read_first_assert_expr (
496484 ctx : & mut Context ,
@@ -516,7 +504,13 @@ mod tests {
516504 let candidates = find_sca_simplification_candidates ( & ctx, e) ;
517505 let simplified: Vec < _ > = candidates
518506 . into_iter ( )
519- . flat_map ( |c| simplify_word_level_equality ( & mut ctx, c) )
507+ . flat_map ( |p| {
508+ let exhaustive = is_eq_exhaustive ( & ctx, p. clone ( ) ) ;
509+ let sca_based = simplify_word_level_equality ( & mut ctx, p) . unwrap ( ) ;
510+ let sca_based_bool = ctx[ sca_based] . is_true ( ) ;
511+ assert_eq ! ( exhaustive, sca_based_bool) ;
512+ Some ( sca_based)
513+ } )
520514 . collect ( ) ;
521515 if simplified. is_empty ( ) {
522516 None
@@ -551,7 +545,70 @@ mod tests {
551545
552546 /// Performs an exhaustive check of all input values
553547 fn is_eq_exhaustive ( ctx : & Context , p : ScaEqualityProblem ) -> bool {
554- todo ! ( )
548+ let word_symbols = find_symbols ( ctx, p. word_level ) ;
549+ let gate_symbols = find_symbols ( ctx, p. word_level ) ;
550+ debug_assert_eq ! ( word_symbols, gate_symbols) ;
551+ let inputs: Vec < _ > = word_symbols
552+ . iter ( )
553+ . map ( |& s| {
554+ let width = s. get_bv_type ( ctx) . unwrap ( ) ;
555+ debug_assert ! ( width <= 16 ) ;
556+ let max_value = ( 1u64 << width) - 1 ;
557+ (
558+ ctx[ s] . get_symbol_name ( ctx) . unwrap ( ) . to_string ( ) ,
559+ s,
560+ width,
561+ max_value,
562+ )
563+ } )
564+ . collect ( ) ;
565+
566+ let mut values = vec ! [ 0u64 ; inputs. len( ) ] ;
567+ let max_values: Vec < _ > = inputs. iter ( ) . map ( |( _, _, _, v) | * v) . collect ( ) ;
568+
569+ let mut count = 0 ;
570+ while values != max_values {
571+ count += 1 ;
572+ // perform check
573+ let symbols: Vec < _ > = inputs
574+ . iter ( )
575+ . zip ( values. iter ( ) )
576+ . map ( |( ( _, s, w, _) , v) | ( * s, BitVecValue :: from_u64 ( * v, * w) ) )
577+ . collect ( ) ;
578+
579+ let word_value = eval_bv_expr ( & ctx, symbols. as_slice ( ) , p. word_level ) ;
580+ let gate_value = eval_bv_expr ( & ctx, symbols. as_slice ( ) , p. gate_level ) ;
581+ let is_equal = eval_bv_expr ( & ctx, symbols. as_slice ( ) , p. equality ) ;
582+
583+ if !is_equal. is_true ( ) {
584+ let syms: Vec < _ > = inputs
585+ . iter ( )
586+ . zip ( values. iter ( ) )
587+ . map ( |( ( n, _, _, _) , v) | format ! ( "{n}={v}" ) )
588+ . collect ( ) ;
589+ println ! (
590+ "Not equal! GATE: {} =/= WORD: {} w/ {}" ,
591+ gate_value. to_dec_str( ) ,
592+ word_value. to_dec_str( ) ,
593+ syms. join( ", " )
594+ ) ;
595+ return false ;
596+ }
597+
598+ // increment
599+ for ( value, max_value) in values. iter_mut ( ) . zip ( max_values. iter ( ) ) {
600+ debug_assert ! ( * value <= * max_value) ;
601+ if * value == * max_value {
602+ * value = 0 ;
603+ // set to zero and go to next "digit"
604+ } else {
605+ * value += 1 ;
606+ break ; // done
607+ }
608+ }
609+ }
610+ println ! ( "expressions appear to be equivalent after {count} iterations" ) ;
611+ true
555612 }
556613
557614 #[ test]
@@ -578,11 +635,10 @@ mod tests {
578635 let gate_level = ctx. concat ( c0, s0) ;
579636 let equality = ctx. equal ( word_level, gate_level) ;
580637
581- let problem = ScaEqualityProblem {
582- equality,
583- gate_level,
584- word_level,
585- } ;
638+ let problem = find_sca_simplification_candidates ( & ctx, equality) [ 0 ] ;
639+
640+ // manually check that our problem is actually correct
641+ assert ! ( is_eq_exhaustive( & ctx, problem) ) ;
586642
587643 let result = simplify_word_level_equality ( & mut ctx, problem) . unwrap ( ) ;
588644 assert_eq ! ( result, ctx. get_true( ) ) ;
@@ -604,30 +660,10 @@ mod tests {
604660 let gate_level = ctx. concat ( gate_level_1, gate_level_0) ;
605661 let equality = ctx. equal ( word_level, gate_level) ;
606662
607- // manually check that our problem is actually correct
608- for a_value in 0 ..3 {
609- for b_value in 0 ..3 {
610- let symbols = [
611- ( a, BitVecValue :: from_u64 ( a_value, 2 ) ) ,
612- ( b, BitVecValue :: from_u64 ( b_value, 2 ) ) ,
613- ] ;
614- let word_value = eval_bv_expr ( & ctx, symbols. as_slice ( ) , word_level) ;
615- let gate_value = eval_bv_expr ( & ctx, symbols. as_slice ( ) , gate_level) ;
616- let is_equal = eval_bv_expr ( & ctx, symbols. as_slice ( ) , equality) ;
617- assert ! (
618- is_equal. is_true( ) ,
619- "a={a_value}, b={b_value}, gate={}, word={}" ,
620- gate_value. to_dec_str( ) ,
621- word_value. to_dec_str( )
622- ) ;
623- }
624- }
663+ let problem = find_sca_simplification_candidates ( & ctx, equality) [ 0 ] ;
625664
626- let problem = ScaEqualityProblem {
627- equality,
628- gate_level,
629- word_level,
630- } ;
665+ // manually check that our problem is actually correct
666+ assert ! ( is_eq_exhaustive( & ctx, problem) ) ;
631667
632668 let result = simplify_word_level_equality ( & mut ctx, problem) . unwrap ( ) ;
633669 assert_eq ! ( result, ctx. get_true( ) ) ;
0 commit comments