Skip to content

Commit 2748963

Browse files
committed
python: smt solver interface
1 parent a03f9bd commit 2748963

File tree

5 files changed

+211
-2
lines changed

5 files changed

+211
-2
lines changed

python/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,4 @@ pyo3 = { version = "0.26.0" , features = ["num-bigint"]}
1515
patronus = { path = "../patronus" }
1616
baa = { version = "0.17.1", features = ["rand1", "bigint"] }
1717
num-bigint = "0.4.6"
18+
rustc-hash.workspace = true

python/src/expr.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,9 @@ use num_bigint::BigInt;
55
use patronus::expr::{TypeCheck, WidthInt};
66
use pyo3::exceptions::PyTypeError;
77
use pyo3::prelude::*;
8-
use std::fmt::format;
98

109
#[pyclass]
11-
#[derive(Clone)]
10+
#[derive(Clone, Copy)]
1211
pub struct ExprRef(pub(crate) patronus::expr::ExprRef);
1312

1413
/// Helper for binary ops that require a and b to be bitvectors of the same width

python/src/lib.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
mod ctx;
22
mod expr;
33
mod sim;
4+
mod smt;
45

56
pub use ctx::Context;
67
use ctx::{ContextGuardRead, ContextGuardWrite};
78
pub use expr::*;
89
pub use sim::{Simulator, interpreter};
10+
pub use smt::*;
911
use std::path::PathBuf;
1012

1113
use ::patronus::btor2;
@@ -259,5 +261,7 @@ fn patronus(_py: Python<'_>, m: &pyo3::Bound<'_, pyo3::types::PyModule>) -> PyRe
259261
m.add_function(wrap_pyfunction!(bit_vec, m)?)?;
260262
m.add_function(wrap_pyfunction!(bit_vec_val, m)?)?;
261263
m.add_function(wrap_pyfunction!(if_expr, m)?)?;
264+
// smt
265+
m.add_function(wrap_pyfunction!(solver, m)?)?;
262266
Ok(())
263267
}

python/src/smt.rs

Lines changed: 185 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,185 @@
1+
use crate::ExprRef;
2+
use crate::ctx::{ContextGuardRead, ContextGuardWrite};
3+
use baa::{BitVecOps, Value};
4+
use num_bigint::BigInt;
5+
use patronus::expr::{Context, ForEachChild, TypeCheck};
6+
use patronus::mc::get_smt_value;
7+
use patronus::smt::*;
8+
use pyo3::exceptions::PyRuntimeError;
9+
use pyo3::prelude::*;
10+
use rustc_hash::{FxHashMap, FxHashSet};
11+
use std::fs::File;
12+
13+
#[pyclass]
14+
pub struct SolverCtx {
15+
underlying: SmtLibSolverCtx<File>,
16+
declared_symbols: Vec<FxHashMap<String, patronus::expr::ExprRef>>,
17+
}
18+
19+
impl SolverCtx {
20+
fn new(underlying: SmtLibSolverCtx<File>) -> Self {
21+
Self {
22+
underlying,
23+
declared_symbols: vec![FxHashMap::default()],
24+
}
25+
}
26+
27+
fn symbol_by_name(&self, name: &str) -> Option<patronus::expr::ExprRef> {
28+
// from inner to outer
29+
for map in self.declared_symbols.iter().rev() {
30+
if let Some(symbol) = map.get(name) {
31+
return Some(*symbol);
32+
}
33+
}
34+
None
35+
}
36+
}
37+
38+
fn find_symbols(ctx: &Context, e: patronus::expr::ExprRef) -> FxHashSet<patronus::expr::ExprRef> {
39+
let mut out = FxHashSet::default();
40+
patronus::expr::traversal::bottom_up(ctx, e, |ctx, e, _| {
41+
if ctx[e].is_symbol() {
42+
out.insert(e);
43+
}
44+
});
45+
out
46+
}
47+
48+
#[pymethods]
49+
impl SolverCtx {
50+
#[pyo3(signature = (*assertions))]
51+
fn check(&mut self, assertions: Vec<ExprRef>) -> PyResult<CheckSatResult> {
52+
if !assertions.is_empty() {
53+
self.push()?;
54+
for a in assertions.iter() {
55+
self.add(*a)?;
56+
}
57+
}
58+
let r = self
59+
.underlying
60+
.check_sat()
61+
.map(CheckSatResult)
62+
.map_err(convert_smt_err)?;
63+
if !assertions.is_empty() {
64+
self.pop()?;
65+
}
66+
Ok(r)
67+
}
68+
69+
fn push(&mut self) -> PyResult<()> {
70+
self.underlying.push().map_err(convert_smt_err)?;
71+
self.declared_symbols.push(FxHashMap::default());
72+
Ok(())
73+
}
74+
75+
fn pop(&mut self) -> PyResult<()> {
76+
self.underlying.pop().map_err(convert_smt_err)?;
77+
self.declared_symbols.pop();
78+
Ok(())
79+
}
80+
81+
fn add(&mut self, assertion: ExprRef) -> PyResult<()> {
82+
let ctx_guard = ContextGuardRead::default();
83+
let ctx = ctx_guard.deref();
84+
let a = assertion.0;
85+
// scan the expression for any unknown symbols and declare them
86+
let symbols = find_symbols(ctx, a);
87+
for symbol in symbols.into_iter() {
88+
let tpe = ctx[symbol].get_type(ctx);
89+
let name = ctx[symbol].get_symbol_name(ctx).unwrap();
90+
if let Some(existing) = self.symbol_by_name(name) {
91+
// check for compatible type for existing symbols
92+
let existing_tpe = ctx[existing].get_type(ctx);
93+
if existing_tpe != tpe {
94+
return Err(PyRuntimeError::new_err(format!(
95+
"There is already a symbol `{name}` with incompatible type {existing_tpe} != {tpe}"
96+
)));
97+
}
98+
} else {
99+
// declare if symbol does not exist
100+
self.underlying
101+
.declare_const(ctx, symbol)
102+
.map_err(convert_smt_err)?;
103+
self.declared_symbols
104+
.last_mut()
105+
.unwrap()
106+
.insert(name.to_string(), symbol);
107+
}
108+
}
109+
110+
self.underlying.assert(ctx, a).map_err(convert_smt_err)?;
111+
Ok(())
112+
}
113+
114+
fn model(&mut self) -> PyResult<Model> {
115+
let mut ctx_guard = ContextGuardWrite::default();
116+
let ctx = ctx_guard.deref_mut();
117+
let mut entries = vec![];
118+
for s in self.declared_symbols.iter().flat_map(|m| m.values()) {
119+
let value = get_smt_value(ctx, &mut self.underlying, *s).map_err(convert_smt_err)?;
120+
entries.push((*s, value));
121+
}
122+
Ok(Model(entries))
123+
}
124+
}
125+
126+
#[pyclass]
127+
pub struct Model(Vec<(patronus::expr::ExprRef, Value)>);
128+
129+
#[pymethods]
130+
impl Model {
131+
fn __str__(&self) -> String {
132+
"TODO".to_string()
133+
}
134+
135+
fn __len__(&self) -> usize {
136+
self.0.len()
137+
}
138+
139+
fn __getitem__(&self, symbol: ExprRef) -> Option<BigInt> {
140+
self.0
141+
.iter()
142+
.find(|(e, _)| *e == symbol.0)
143+
.map(|(_, value)| match value {
144+
Value::Array(_) => {
145+
todo!("Array support!")
146+
}
147+
Value::BitVec(bv) => bv.to_big_int(),
148+
})
149+
}
150+
}
151+
152+
#[pyclass]
153+
pub struct CheckSatResult(CheckSatResponse);
154+
155+
#[pymethods]
156+
impl CheckSatResult {
157+
fn __str__(&self) -> String {
158+
match self.0 {
159+
CheckSatResponse::Sat => "sat".to_string(),
160+
CheckSatResponse::Unsat => "unsat".to_string(),
161+
CheckSatResponse::Unknown => "unknonw".to_string(),
162+
}
163+
}
164+
}
165+
166+
#[pyfunction]
167+
#[pyo3(name = "Solver")]
168+
pub fn solver(name: &str) -> PyResult<SolverCtx> {
169+
match name.to_ascii_lowercase().as_str() {
170+
"z3" => Ok(SolverCtx::new(Z3.start(None).map_err(convert_smt_err)?)),
171+
"bitwuzla" => Ok(SolverCtx::new(
172+
BITWUZLA.start(None).map_err(convert_smt_err)?,
173+
)),
174+
"yices" | "yices2" | "yices2-smt" => {
175+
Ok(SolverCtx::new(YICES2.start(None).map_err(convert_smt_err)?))
176+
}
177+
_ => Err(PyRuntimeError::new_err(format!(
178+
"Unknonw or unsupported solver: {name}"
179+
))),
180+
}
181+
}
182+
183+
fn convert_smt_err(e: Error) -> PyErr {
184+
PyRuntimeError::new_err(format!("smt: {e}"))
185+
}
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,3 +88,23 @@ def test_transition_system_builder():
8888
[next] ite(en, add(count, 8'b00000001), count)
8989
"""
9090
assert str(sys).strip() == expected_system.strip()
91+
92+
93+
def test_call_smt_solver():
94+
a = BitVec('a', 3)
95+
b = BitVec('b', 3)
96+
s = Solver('z3')
97+
r = s.check(a < b)
98+
assert str(r) == "sat"
99+
100+
r = s.check(a < b, a > b)
101+
assert str(r) == "unsat"
102+
103+
# to generate a model, we need to actually add the assertion!
104+
s.add(a < b)
105+
s.check()
106+
m = s.model()
107+
assert len(m) == 2
108+
assert isinstance(m[a], int)
109+
assert isinstance(m[b], int)
110+
assert m[a] < m[b]

0 commit comments

Comments
 (0)