Skip to content

Commit 3fdb066

Browse files
committed
sca: building a gate polynomial from the bottom up does not seem to work
1 parent 859a0f1 commit 3fdb066

File tree

5 files changed

+154
-76
lines changed

5 files changed

+154
-76
lines changed

patronus-sca/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,5 +14,5 @@ patronus = { path = "../patronus" }
1414
baa = {version = "0.17.1", features = ["bigint"]}
1515
rustc-hash.workspace = true
1616
smallvec.workspace = true
17-
polysub = "0.2.4"
17+
polysub = "0.2.5"
1818
bit-set = "0.8.0"

patronus-sca/src/forward.rs

Lines changed: 136 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,25 +7,73 @@
77
88
use crate::{Poly, expr_to_var, extract_bit};
99
use baa::{BitVecOps, BitVecValueRef};
10-
use patronus::expr::{Context, Expr, ExprRef, ExprSet, SparseExprSet, TypeCheck, traversal};
11-
use polysub::Coef;
12-
use rustc_hash::FxHashSet;
10+
use patronus::expr::{
11+
Context, Expr, ExprRef, ExprSet, ForEachChild, SerializableIrNode, SparseExprMap,
12+
SparseExprSet, TypeCheck, count_expr_uses, traversal,
13+
};
14+
use polysub::{Coef, Mod};
15+
use rustc_hash::{FxHashMap, FxHashSet};
16+
17+
#[derive(Debug, Copy, Clone)]
18+
pub enum BuildPolyMode {
19+
Arithmetic,
20+
Gates(Mod),
21+
}
1322

1423
/// Returns a polynomial representation of the expression + all input expressions if possible.
1524
/// Returns `None` if the conversion fails.
1625
pub fn build_bottom_up_poly(
1726
ctx: &mut Context,
1827
inputs: &FxHashSet<ExprRef>,
1928
e: ExprRef,
29+
mode: BuildPolyMode,
2030
) -> Option<Poly> {
21-
traversal::bottom_up_mut(ctx, e, |ctx, e, c| to_poly(ctx, inputs, e, c)).map(|(p, _)| p)
31+
// we use a custom traversal that caches polynomials until they are no longer used
32+
let mut uses = count_expr_uses(ctx, vec![e]);
33+
let mut todo = vec![e];
34+
let mut result: FxHashMap<ExprRef, Option<(Poly, bool)>> = FxHashMap::default();
35+
let mut child_vec = Vec::with_capacity(4);
36+
37+
while let Some(e) = todo.pop() {
38+
assert!(!result.contains_key(&e));
39+
40+
let expr = &ctx[e];
41+
// find children that are not available yet.
42+
debug_assert!(child_vec.is_empty());
43+
expr.collect_children(&mut child_vec);
44+
let all_available = child_vec.iter().all(|c| result.contains_key(c));
45+
46+
if all_available {
47+
let child_results: Vec<Option<&(Poly, bool)>> =
48+
child_vec.iter().map(|c| result[c].as_ref()).collect();
49+
let r = to_poly(ctx, inputs, mode, e, &child_results);
50+
result.insert(e, r);
51+
for child in child_vec.drain(..) {
52+
let old_use = uses[child];
53+
if old_use == 1 {
54+
result.remove(&child);
55+
}
56+
uses[child] = old_use - 1;
57+
}
58+
} else {
59+
todo.push(e);
60+
for child in child_vec.drain(..) {
61+
if !result.contains_key(&child) {
62+
todo.push(child);
63+
}
64+
}
65+
}
66+
}
67+
68+
result[&e].as_ref().map(|(p, _)| p.clone())
2269
}
2370

2471
fn to_poly(
2572
ctx: &mut Context,
2673
inputs: &FxHashSet<ExprRef>,
74+
mode: BuildPolyMode,
2775
e: ExprRef,
28-
children: &[Option<(Poly, bool)>],
76+
children: &[Option<&(Poly, bool)>],
2977
) -> Option<(Poly, bool)> {
3078
// have we given up yet?
3179
if children.iter().any(|c| c.is_none()) {
@@ -37,6 +85,89 @@ fn to_poly(
3785
return Some((poly_for_bv_expr(ctx, e), false));
3886
}
3987

88+
match mode {
89+
BuildPolyMode::Arithmetic => to_poly_arithmetic(ctx, inputs, e, children),
90+
BuildPolyMode::Gates(m) => Some((to_poly_gate(ctx, inputs, m, e, children), false)),
91+
}
92+
}
93+
94+
/// For bit-level polynomials from gates, the normal arithmetic rules and overflow checking do not
95+
/// apply. Instead, we use the modulo coefficient of the top-level expression.
96+
fn to_poly_gate(
97+
ctx: &mut Context,
98+
inputs: &FxHashSet<ExprRef>,
99+
m: Mod,
100+
e: ExprRef,
101+
children: &[Option<&(Poly, bool)>],
102+
) -> Poly {
103+
debug_assert!(
104+
children
105+
.iter()
106+
.all(|c| c.map(|(_, ov)| !*ov).unwrap_or(true))
107+
);
108+
109+
match (ctx[e].clone(), children) {
110+
(Expr::BVSymbol { .. }, _) => unreachable!("all symbols should be in inputs"),
111+
(Expr::BVLiteral(value), _) => {
112+
let mut r = poly_for_bv_literal(value.get(ctx));
113+
r.change_mod(m);
114+
r
115+
}
116+
(Expr::BVSlice { e, hi, lo }, _) if hi == lo && inputs.contains(&e) => {
117+
// special case: bit_slice of an input
118+
let var = expr_to_var(extract_bit(ctx, e, hi));
119+
Poly::from_monoms(m, [(Coef::from_i64(1, m), vec![var].into())].into_iter())
120+
}
121+
(Expr::BVConcat(_, be, w), [Some((a, _)), Some((b, _))]) => {
122+
// left shift a
123+
let shift_by = be.get_bv_type(ctx).unwrap();
124+
let shift_coef = Coef::pow2(shift_by, m);
125+
let mut r: Poly = a.clone();
126+
r.scale(&shift_coef);
127+
r.add_assign(b);
128+
r
129+
}
130+
(Expr::BVOr(_, _, 1), [Some((a, _)), Some((b, _))]) => {
131+
// a + b - ab
132+
let mut r = a.mul(b);
133+
r.scale(&Coef::from_i64(-1, a.get_mod()));
134+
r.add_assign(a);
135+
r.add_assign(b);
136+
r
137+
}
138+
(Expr::BVXor(_, _, 1), [Some((a, _)), Some((b, _))]) => {
139+
// a + b - 2ab
140+
let mut r = a.mul(b);
141+
let minus_2 = Coef::from_i64(-2, a.get_mod());
142+
r.scale(&minus_2);
143+
r.add_assign(a);
144+
r.add_assign(b);
145+
r
146+
}
147+
(Expr::BVAnd(_, _, 1), [Some((a, _)), Some((b, _))]) => {
148+
// ab
149+
a.clone().mul(b)
150+
}
151+
(Expr::BVNot(_, 1), [Some((a, _))]) => {
152+
// 1 - a
153+
let one = Poly::from_monoms(m, [(Coef::from_i64(1, m), vec![].into())].into_iter());
154+
let mut r = a.clone();
155+
r.scale(&Coef::from_i64(-1, a.get_mod()));
156+
r.add_assign(&one);
157+
r
158+
}
159+
(other, cs) => todo!("{other:?}: {cs:?}"),
160+
}
161+
}
162+
163+
/// When building a polynomial over an arithmetic circuit, we are tracking overflow and
164+
/// determining the modulo coefficient from the bit-widths.
165+
fn to_poly_arithmetic(
166+
ctx: &mut Context,
167+
_inputs: &FxHashSet<ExprRef>,
168+
e: ExprRef,
169+
children: &[Option<&(Poly, bool)>],
170+
) -> Option<(Poly, bool)> {
40171
match (ctx[e].clone(), children) {
41172
(Expr::BVSymbol { .. }, _) => unreachable!("all symbols should be in inputs"),
42173
(Expr::BVLiteral(value), _) => Some((poly_for_bv_literal(value.get(ctx)), false)),

patronus-sca/src/lib.rs

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@
55
mod forward;
66
mod rewrite;
77

8-
use crate::forward::{build_bottom_up_poly, poly_for_bv_expr};
9-
use crate::rewrite::{backwards_sub, build_gate_polynomial};
8+
use crate::forward::{BuildPolyMode, build_bottom_up_poly, poly_for_bv_expr};
9+
use crate::rewrite::backwards_sub;
1010
use baa::{BitVecOps, BitVecValue, BitVecValueRef};
1111
use patronus::expr::*;
1212
use polysub::{Coef, Term, VarIndex};
@@ -32,11 +32,13 @@ pub fn verify_word_level_equality(ctx: &mut Context, p: ScaEqualityProblem) -> S
3232
let inputs = find_symbols(ctx, p.word_level);
3333

3434
// create a reference polynomial from the word level side
35-
let mut word_poly = match build_bottom_up_poly(ctx, &inputs, p.word_level) {
36-
None => return ScaVerifyResult::Unknown,
37-
Some(p) => p,
38-
};
35+
let mut word_poly =
36+
match build_bottom_up_poly(ctx, &inputs, p.word_level, BuildPolyMode::Arithmetic) {
37+
None => return ScaVerifyResult::Unknown,
38+
Some(p) => p,
39+
};
3940
println!("word-level polynomial: {word_poly}");
41+
println!("word-level bits: {}", word_poly.get_mod().bits());
4042

4143
// collect all (bit-level) input variables
4244
let input_vars: FxHashSet<VarIndex> = inputs
@@ -50,7 +52,11 @@ pub fn verify_word_level_equality(ctx: &mut Context, p: ScaEqualityProblem) -> S
5052
})
5153
.collect();
5254

53-
//let gate_poly = build_gate_polynomial(ctx, &input_vars, word_poly.get_mod(), p.gate_level);
55+
// TODO: how do we calculate a correct polynomial from the gate level?
56+
// let gate_poly = build_bottom_up_poly(ctx, &inputs, p.gate_level, BuildPolyMode::Gates(word_poly.get_mod()));
57+
// if let Some(gate_poly) = gate_poly {
58+
// println!("GATE POLY: {gate_poly}");
59+
// }
5460

5561
// the actual reference polynomial needs to contain the output bits as well
5662
let output_poly = poly_for_bv_expr(ctx, p.word_level);
@@ -395,7 +401,9 @@ mod tests {
395401
let cs = find_sca_simplification_candidates(&ctx, e);
396402
for c in cs {
397403
let inputs = find_symbols(&ctx, c.word_level);
398-
if let Some(word_poly) = build_bottom_up_poly(&mut ctx, &inputs, c.word_level) {
404+
if let Some(word_poly) =
405+
build_bottom_up_poly(&mut ctx, &inputs, c.word_level, BuildPolyMode::Arithmetic)
406+
{
399407
println!("{filename}:\n{}\n", PrettyPoly::n(&ctx, &word_poly));
400408
} else {
401409
println!(

patronus-sca/src/rewrite.rs

Lines changed: 0 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -317,67 +317,6 @@ fn try_exhaustive(
317317
min_with_lowest_use
318318
}
319319

320-
/// tries to build the gate-level polynomial from the bottom up.
321-
pub fn build_gate_polynomial(
322-
ctx: &Context,
323-
input_vars: &FxHashSet<VarIndex>,
324-
m: Mod,
325-
e: ExprRef,
326-
) -> Poly {
327-
traversal::bottom_up(ctx, e, |ctx, gate, children: &[Poly]| {
328-
let var = expr_to_var(gate);
329-
330-
// if this is an input, we just want to return the variable
331-
if input_vars.contains(&var) {
332-
return Poly::from_monoms(m, [(Coef::from_i64(1, m), vec![var].into())].into_iter());
333-
}
334-
335-
match (ctx[gate].clone(), children) {
336-
(Expr::BVOr(_, _, 1), [a, b]) => {
337-
// a + b - ab
338-
let mut r = a.mul(b);
339-
r.scale(&Coef::from_i64(-1, m));
340-
r.add_assign(a);
341-
r.add_assign(b);
342-
r
343-
}
344-
(Expr::BVXor(_, _, 1), [a, b]) => {
345-
// a + b - 2ab
346-
let mut r = a.mul(b);
347-
r.scale(&Coef::from_i64(-2, m));
348-
r.add_assign(a);
349-
r.add_assign(b);
350-
r
351-
}
352-
(Expr::BVAnd(_, _, 1), [a, b]) => {
353-
// ab
354-
a.mul(b)
355-
}
356-
(Expr::BVNot(_, 1), [a]) => {
357-
// 1 - a
358-
let one = Poly::from_monoms(m, [(Coef::from_i64(1, m), vec![].into())].into_iter());
359-
let mut r = a.clone();
360-
r.scale(&Coef::from_i64(-1, m));
361-
r.add_assign(&one);
362-
r
363-
}
364-
(Expr::BVSlice { .. }, [_]) => {
365-
todo!("should not get here!")
366-
}
367-
(Expr::BVLiteral(value), _) => {
368-
let value = value.get(ctx);
369-
debug_assert_eq!(value.width(), 1);
370-
if value.is_true() {
371-
Poly::from_monoms(m, [(Coef::from_i64(1, m), vec![].into())].into_iter())
372-
} else {
373-
Poly::from_monoms(m, [(Coef::from_i64(0, m), vec![].into())].into_iter())
374-
}
375-
}
376-
other => todo!("add support for {other:?}"),
377-
}
378-
})
379-
}
380-
381320
/// Calculates for each expression which root depends on it.
382321
fn analyze_uses(ctx: &Context, roots: &[ExprRef]) -> impl ExprMap<BitSet> {
383322
let mut out = DenseExprMetaData::<BitSet>::default();

patronus/src/expr/analysis.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ use rustc_hash::FxHashSet;
99
pub type UseCountInt = u16;
1010

1111
/// Counts how often expressions in the DAGs characterized by the provided roots are used.
12-
pub fn count_expr_uses(ctx: &Context, roots: Vec<ExprRef>) -> impl ExprMap<UseCountInt> {
12+
pub fn count_expr_uses(ctx: &Context, roots: Vec<ExprRef>) -> impl ExprMap<UseCountInt> + use<> {
1313
let mut use_count = SparseExprMap::default();
1414
let mut todo = roots;
1515

0 commit comments

Comments
 (0)