Skip to content

Commit 90fe8c9

Browse files
committed
sca: exhaustive check for testing
1 parent c1b201b commit 90fe8c9

File tree

4 files changed

+99
-61
lines changed

4 files changed

+99
-61
lines changed

patronus-sca/src/lib.rs

Lines changed: 84 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,9 @@
55
use baa::{BitVecOps, BitVecValue, BitVecValueRef};
66
use patronus::expr::{
77
Context, DenseExprMetaData, Expr, ExprRef, ForEachChild, SerializableIrNode, Simplifier,
8-
TypeCheck, WidthInt, count_expr_uses, simplify_single_expression, traversal,
8+
TypeCheck, WidthInt, count_expr_uses, traversal,
99
};
1010
use polysub::{Coef, Term, VarIndex};
11-
use rustc_hash::FxHashMap;
1211
use std::collections::VecDeque;
1312
use std::fmt::{Display, Formatter};
1413

@@ -85,24 +84,15 @@ fn backwards_sub(ctx: &Context, mut todo: VecDeque<(VarIndex, ExprRef)>, mut spe
8584
var_roots.sort();
8685

8786
let m = spec.get_mod();
88-
println!("MOD={m:?} ({} bits)", m.bits());
8987
let one: DefaultCoef = Coef::from_i64(1, m);
90-
println!("one = {one:?}");
9188
let minus_one: DefaultCoef = Coef::from_i64(-1, m);
92-
println!("minus_one = {minus_one:?}");
9389
let minus_two: DefaultCoef = Coef::from_i64(-2, m);
9490
// first, we count how often expressions are used
9591
let roots: Vec<_> = todo.iter().map(|(_, e)| *e).collect();
9692
let mut uses = count_expr_uses(ctx, roots);
9793
let mut replaced = vec![];
9894

9995
while let Some((output_var, gate)) = todo.pop_back() {
100-
println!(
101-
"{output_var}, {}, {:?} ({})",
102-
expr_to_var(gate) == output_var,
103-
&ctx[gate],
104-
spec.size()
105-
);
10696
replaced.push(output_var);
10797

10898
let add_children = match ctx[gate].clone() {
@@ -171,7 +161,6 @@ fn backwards_sub(ctx: &Context, mut todo: VecDeque<(VarIndex, ExprRef)>, mut spe
171161
}
172162
});
173163
}
174-
println!("{spec}");
175164
}
176165

177166
println!("Roots: {var_roots:?}");
@@ -183,7 +172,7 @@ fn backwards_sub(ctx: &Context, mut todo: VecDeque<(VarIndex, ExprRef)>, mut spe
183172
use patronus::expr::ExprMap;
184173
let mut still_used: Vec<_> = uses
185174
.iter()
186-
.filter(|(k, v)| **v > 0)
175+
.filter(|(_, v)| **v > 0)
187176
.map(|(k, _)| expr_to_var(k))
188177
.collect();
189178
still_used.sort();
@@ -274,7 +263,7 @@ fn build_bottom_up_poly(ctx: &mut Context, e: ExprRef) -> Poly {
274263
poly
275264
}
276265

277-
#[derive(Debug, Clone, PartialEq)]
266+
#[derive(Debug, Clone, Copy, PartialEq)]
278267
pub struct ScaEqualityProblem {
279268
equality: ExprRef,
280269
gate_level: ExprRef,
@@ -486,11 +475,10 @@ impl<'a, 'b> Display for PrettyPoly<'a, 'b> {
486475
#[cfg(test)]
487476
mod tests {
488477
use super::*;
489-
use patronus::expr::{eval_bv_expr, eval_expr, simplify_single_expression};
490-
use patronus::smt::{SmtCommand, read_command, serialize_cmd};
478+
use patronus::expr::{eval_bv_expr, find_symbols, simplify_single_expression};
479+
use patronus::smt::{SmtCommand, read_command};
491480
use rustc_hash::FxHashMap;
492-
use std::io::{BufReader, BufWriter};
493-
use std::ptr::eq;
481+
use std::io::BufReader;
494482

495483
fn read_first_assert_expr(
496484
ctx: &mut Context,
@@ -516,7 +504,13 @@ mod tests {
516504
let candidates = find_sca_simplification_candidates(&ctx, e);
517505
let simplified: Vec<_> = candidates
518506
.into_iter()
519-
.flat_map(|c| simplify_word_level_equality(&mut ctx, c))
507+
.flat_map(|p| {
508+
let exhaustive = is_eq_exhaustive(&ctx, p.clone());
509+
let sca_based = simplify_word_level_equality(&mut ctx, p).unwrap();
510+
let sca_based_bool = ctx[sca_based].is_true();
511+
assert_eq!(exhaustive, sca_based_bool);
512+
Some(sca_based)
513+
})
520514
.collect();
521515
if simplified.is_empty() {
522516
None
@@ -551,7 +545,70 @@ mod tests {
551545

552546
/// Performs an exhaustive check of all input values
553547
fn is_eq_exhaustive(ctx: &Context, p: ScaEqualityProblem) -> bool {
554-
todo!()
548+
let word_symbols = find_symbols(ctx, p.word_level);
549+
let gate_symbols = find_symbols(ctx, p.word_level);
550+
debug_assert_eq!(word_symbols, gate_symbols);
551+
let inputs: Vec<_> = word_symbols
552+
.iter()
553+
.map(|&s| {
554+
let width = s.get_bv_type(ctx).unwrap();
555+
debug_assert!(width <= 16);
556+
let max_value = (1u64 << width) - 1;
557+
(
558+
ctx[s].get_symbol_name(ctx).unwrap().to_string(),
559+
s,
560+
width,
561+
max_value,
562+
)
563+
})
564+
.collect();
565+
566+
let mut values = vec![0u64; inputs.len()];
567+
let max_values: Vec<_> = inputs.iter().map(|(_, _, _, v)| *v).collect();
568+
569+
let mut count = 0;
570+
while values != max_values {
571+
count += 1;
572+
// perform check
573+
let symbols: Vec<_> = inputs
574+
.iter()
575+
.zip(values.iter())
576+
.map(|((_, s, w, _), v)| (*s, BitVecValue::from_u64(*v, *w)))
577+
.collect();
578+
579+
let word_value = eval_bv_expr(&ctx, symbols.as_slice(), p.word_level);
580+
let gate_value = eval_bv_expr(&ctx, symbols.as_slice(), p.gate_level);
581+
let is_equal = eval_bv_expr(&ctx, symbols.as_slice(), p.equality);
582+
583+
if !is_equal.is_true() {
584+
let syms: Vec<_> = inputs
585+
.iter()
586+
.zip(values.iter())
587+
.map(|((n, _, _, _), v)| format!("{n}={v}"))
588+
.collect();
589+
println!(
590+
"Not equal! GATE: {} =/= WORD: {} w/ {}",
591+
gate_value.to_dec_str(),
592+
word_value.to_dec_str(),
593+
syms.join(", ")
594+
);
595+
return false;
596+
}
597+
598+
// increment
599+
for (value, max_value) in values.iter_mut().zip(max_values.iter()) {
600+
debug_assert!(*value <= *max_value);
601+
if *value == *max_value {
602+
*value = 0;
603+
// set to zero and go to next "digit"
604+
} else {
605+
*value += 1;
606+
break; // done
607+
}
608+
}
609+
}
610+
println!("expressions appear to be equivalent after {count} iterations");
611+
true
555612
}
556613

557614
#[test]
@@ -578,11 +635,10 @@ mod tests {
578635
let gate_level = ctx.concat(c0, s0);
579636
let equality = ctx.equal(word_level, gate_level);
580637

581-
let problem = ScaEqualityProblem {
582-
equality,
583-
gate_level,
584-
word_level,
585-
};
638+
let problem = find_sca_simplification_candidates(&ctx, equality)[0];
639+
640+
// manually check that our problem is actually correct
641+
assert!(is_eq_exhaustive(&ctx, problem));
586642

587643
let result = simplify_word_level_equality(&mut ctx, problem).unwrap();
588644
assert_eq!(result, ctx.get_true());
@@ -604,30 +660,10 @@ mod tests {
604660
let gate_level = ctx.concat(gate_level_1, gate_level_0);
605661
let equality = ctx.equal(word_level, gate_level);
606662

607-
// manually check that our problem is actually correct
608-
for a_value in 0..3 {
609-
for b_value in 0..3 {
610-
let symbols = [
611-
(a, BitVecValue::from_u64(a_value, 2)),
612-
(b, BitVecValue::from_u64(b_value, 2)),
613-
];
614-
let word_value = eval_bv_expr(&ctx, symbols.as_slice(), word_level);
615-
let gate_value = eval_bv_expr(&ctx, symbols.as_slice(), gate_level);
616-
let is_equal = eval_bv_expr(&ctx, symbols.as_slice(), equality);
617-
assert!(
618-
is_equal.is_true(),
619-
"a={a_value}, b={b_value}, gate={}, word={}",
620-
gate_value.to_dec_str(),
621-
word_value.to_dec_str()
622-
);
623-
}
624-
}
663+
let problem = find_sca_simplification_candidates(&ctx, equality)[0];
625664

626-
let problem = ScaEqualityProblem {
627-
equality,
628-
gate_level,
629-
word_level,
630-
};
665+
// manually check that our problem is actually correct
666+
assert!(is_eq_exhaustive(&ctx, problem));
631667

632668
let result = simplify_word_level_equality(&mut ctx, problem).unwrap();
633669
assert_eq!(result, ctx.get_true());

patronus/src/expr.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ mod transform;
1414
pub mod traversal;
1515
mod types;
1616

17-
pub use analysis::{UseCountInt, count_expr_uses, update_expr_child_uses};
17+
pub use analysis::{UseCountInt, count_expr_uses, find_symbols, update_expr_child_uses};
1818
pub use context::{Builder, Context, ExprRef, StringRef};
1919
pub use eval::{SymbolValueStore, eval_array_expr, eval_bv_expr, eval_expr};
2020
pub use foreach::ForEachChild;

patronus/src/expr/analysis.rs

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33
// released under BSD 3-Clause License
44
// author: Kevin Laeufer <laeufer@cornell.edu>
55

6-
use crate::expr::{Context, ExprMap, ExprRef, ForEachChild, SparseExprMap};
6+
use crate::expr::{Context, ExprMap, ExprRef, ForEachChild, SparseExprMap, traversal};
7+
use rustc_hash::FxHashSet;
78

89
pub type UseCountInt = u16;
910

@@ -24,6 +25,17 @@ pub fn count_expr_uses(ctx: &Context, roots: Vec<ExprRef>) -> impl ExprMap<UseCo
2425
use_count
2526
}
2627

28+
/// Returns all symbols in the given expression.
29+
pub fn find_symbols(ctx: &Context, e: ExprRef) -> FxHashSet<ExprRef> {
30+
let mut out = FxHashSet::default();
31+
traversal::bottom_up(ctx, e, |ctx, e, _| {
32+
if ctx[e].is_symbol() {
33+
out.insert(e);
34+
}
35+
});
36+
out
37+
}
38+
2739
/// Increments the use counts for all children of the expression `expr` and
2840
/// adds any child encountered for the first time to the `todo` list.
2941
#[inline]

python/src/smt.rs

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -38,16 +38,6 @@ impl SolverCtx {
3838
}
3939
}
4040

41-
fn find_symbols(ctx: &Context, e: patronus::expr::ExprRef) -> FxHashSet<patronus::expr::ExprRef> {
42-
let mut out = FxHashSet::default();
43-
patronus::expr::traversal::bottom_up(ctx, e, |ctx, e, _| {
44-
if ctx[e].is_symbol() {
45-
out.insert(e);
46-
}
47-
});
48-
out
49-
}
50-
5141
#[pymethods]
5242
impl SolverCtx {
5343
#[pyo3(signature = (*assertions))]
@@ -86,7 +76,7 @@ impl SolverCtx {
8676
let ctx = ctx_guard.deref();
8777
let a = assertion.0;
8878
// scan the expression for any unknown symbols and declare them
89-
let symbols = find_symbols(ctx, a);
79+
let symbols = patronus::expr::find_symbols(ctx, a);
9080
for symbol in symbols.into_iter() {
9181
let tpe = ctx[symbol].get_type(ctx);
9282
let name = ctx[symbol].get_symbol_name(ctx).unwrap();

0 commit comments

Comments
 (0)