Skip to content

Commit f992c7e

Browse files
committed
__call__ notation, new is_* predicates, rename keyword to admit
1 parent 8aa6211 commit f992c7e

File tree

4 files changed

+99
-10
lines changed

4 files changed

+99
-10
lines changed

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ assert not isinstance(simple_taut, smt.ExprRef)
4141
# kd.lemma will throw an error if the theorem is not provable
4242
try:
4343
false_lemma = kd.lemma(smt.Implies(p, smt.And(p, q)))
44-
assert False # This will not be reached
44+
print("This will not be reached")
4545
except kd.kernel.LemmaError as e:
4646
pass
4747

@@ -58,7 +58,7 @@ or_idem = kd.lemma(smt.ForAll([x], x | x == x))
5858
# But the point of Knuckledragger is really for the things Z3 can't do in one shot
5959

6060
# Knuckledragger support algebraic datatypes and induction
61-
Nat = kd.Inductive("Nat", strict=False)
61+
Nat = kd.Inductive("Nat", admit=True)
6262
Zero = Nat.declare("Zero")
6363
Succ = Nat.declare("Succ", ("pred", Nat))
6464
Nat = Nat.create()

kdrag/config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
"""
44

55
solver = None
6+
# admit_enabled = True
67
# timeout = 1000
78

89
# TODO: Someday, when it is annoyingly slow to check built in theories, we can add a flag to disable them

kdrag/notation.py

Lines changed: 53 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,27 @@
2727
smt.ArrayRef.__call__ = lambda self, arg: self[arg]
2828

2929

30+
def quantifier_call(self, *args):
31+
"""
32+
Instantiate a quantifier. This does substitution
33+
>>> x,y = smt.Ints("x y")
34+
>>> smt.Lambda([x,y], x + 1)(2,3)
35+
2 + 1
36+
37+
To apply a Lambda without substituting, use square brackets
38+
>>> smt.Lambda([x,y], x + 1)[2,3]
39+
Select(Lambda([x, y], x + 1), 2, 3)
40+
"""
41+
if self.num_vars() != len(args):
42+
raise TypeError("Wrong number of arguments", self, args)
43+
return smt.substitute_vars(
44+
self.body(), *(smt._py2expr(arg) for arg in reversed(args))
45+
)
46+
47+
48+
smt.QuantifierRef.__call__ = quantifier_call
49+
50+
3051
class SortDispatch:
3152
"""
3253
Sort dispatch is modeled after functools.singledispatch
@@ -114,6 +135,12 @@ def QForAll(vs: list[smt.ExprRef], *hyp_conc) -> smt.BoolRef:
114135
115136
If variables have a property `wf` attached, this is added as a hypothesis.
116137
138+
There is no downside to always using this compared to `smt.ForAll` and it can avoid some errors.
139+
140+
>>> x,y = smt.Ints("x y")
141+
>>> QForAll([x,y], x > 0, y > 0, x + y > 0)
142+
ForAll([x, y], Implies(And(x > 0, y > 0), x + y > 0))
143+
117144
"""
118145
conc = hyp_conc[-1]
119146
hyps = hyp_conc[:-1]
@@ -187,13 +214,13 @@ def datatype_call(self, *args):
187214
records = {}
188215

189216

190-
def Record(name: str, *fields, pred=None) -> smt.DatatypeSortRef:
217+
def Record(name: str, *fields, pred=None, admit=False) -> smt.DatatypeSortRef:
191218
"""
192219
Define a record datatype.
193220
The optional argument `pred` will add a well-formedness condition to the record
194221
giving something akin to a refinement type.
195222
"""
196-
if name in records:
223+
if not admit and name in records:
197224
raise Exception("Record already defined", name)
198225
rec = smt.Datatype(name)
199226
rec.declare(name, *fields)
@@ -219,9 +246,24 @@ def Record(name: str, *fields, pred=None) -> smt.DatatypeSortRef:
219246
return rec
220247

221248

222-
def NewType(name: str, sort: smt.SortRef, pred=None) -> smt.DatatypeSortRef:
249+
def NewType(
250+
name: str, sort: smt.SortRef, pred=None, admit=False
251+
) -> smt.DatatypeSortRef:
223252
"""Minimal wrapper around a sort for sort based overloading"""
224-
return Record(name, ("val", sort), pred=pred)
253+
return Record(name, ("val", sort), pred=pred, admit=admit)
254+
255+
256+
def Enum(name, args, admit=False):
257+
"""Shorthand for simple enumeration datatypes. Similar to python's Enum.
258+
>>> Color = Enum("Color", "Red Green Blue")
259+
>>> smt.And(Color.Red != Color.Green, Color.Red != Color.Blue)
260+
And(Red != Green, Red != Blue)
261+
"""
262+
T = kd.Inductive(name, admit=admit)
263+
for c in args.split():
264+
T.declare(c)
265+
T = T.create()
266+
return T
225267

226268

227269
def induct_inductive(DT: smt.DatatypeSortRef, x=None, P=None) -> kd.kernel.Proof:
@@ -257,9 +299,9 @@ def induct_inductive(DT: smt.DatatypeSortRef, x=None, P=None) -> kd.kernel.Proof
257299
)
258300

259301

260-
def Inductive(name: str, strict=True) -> smt.DatatypeSortRef:
302+
def Inductive(name: str, admit=False) -> smt.DatatypeSortRef:
261303
"""Declare datatypes with auto generated induction principles. Wrapper around z3.Datatype"""
262-
if strict and name in records:
304+
if not admit and name in records:
263305
raise Exception(
264306
"Datatype with that name already defined. Use keyword strict=False to override",
265307
name,
@@ -271,7 +313,7 @@ def Inductive(name: str, strict=True) -> smt.DatatypeSortRef:
271313
def create():
272314
dt = oldcreate()
273315
# Sanity check no duplicate names. Causes confusion.
274-
if strict:
316+
if not admit:
275317
names = set()
276318
for i in range(dt.num_constructors()):
277319
cons = dt.constructor(i)
@@ -284,7 +326,6 @@ def create():
284326
if n in names:
285327
raise Exception("Duplicate field name", n)
286328
names.add(n)
287-
x = smt.FreshConst(dt, prefix="x")
288329
kd.notation.induct.register(dt, lambda x: induct_inductive(dt, x=x))
289330
records[name] = dt
290331
return dt
@@ -318,6 +359,10 @@ def cond(*cases, default=None) -> smt.ExprRef:
318359
return acc
319360

320361

362+
def conde(*cases):
363+
return smt.Or([smt.And(c) for c in cases])
364+
365+
321366
class Cond:
322367
def __init__(self):
323368
self.cases = []

kdrag/smt.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,49 @@
1515
solver = "z3"
1616
from z3 import *
1717

18+
_py2expr = z3.z3._py2expr
19+
20+
def is_if(x: z3.ExprRef) -> bool:
21+
"""
22+
Check if an expression is an if-then-else.
23+
>>> is_if(z3.If(True, 1, 2))
24+
True
25+
"""
26+
return z3.is_app_of(x, z3.Z3_OP_ITE)
27+
28+
def is_constructor(x: z3.ExprRef) -> bool:
29+
"""
30+
Check if an expression is a constructor.
31+
>>> Color = z3.Datatype("Color")
32+
>>> Color.declare("red")
33+
>>> Color = Color.create()
34+
>>> is_constructor(Color.red)
35+
True
36+
"""
37+
return z3.is_app_of(x, z3.Z3_OP_DT_CONSTRUCTOR)
38+
39+
def is_accessor(x: z3.ExprRef) -> bool:
40+
"""
41+
Check if an expression is an accessor.
42+
>>> Color = z3.Datatype("Color")
43+
>>> Color.declare("red", ("r", z3.IntSort()))
44+
>>> Color = Color.create()
45+
>>> is_accessor(Color.r(Color.red(3)))
46+
True
47+
"""
48+
return z3.is_app_of(x, z3.Z3_OP_DT_ACCESSOR)
49+
50+
def is_recognizer(x: z3.ExprRef) -> bool:
51+
"""
52+
Check if recognizer.
53+
>>> Color = z3.Datatype("Color")
54+
>>> Color.declare("red")
55+
>>> Color = Color.create()
56+
>>> is_recognizer(Color.is_red(Color.red))
57+
True
58+
"""
59+
return z3.is_app_of(x, z3.Z3_OP_DT_IS)
60+
1861
Z3Solver = Solver
1962
elif solver == VAMPIRESOLVER:
2063
from z3 import *

0 commit comments

Comments
 (0)