Skip to content

Commit 9613842

Browse files
committed
start contracts
1 parent 806d935 commit 9613842

File tree

1 file changed

+91
-0
lines changed

1 file changed

+91
-0
lines changed

src/kdrag/contract.py

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
from dataclasses import dataclass
2+
import kdrag.smt as smt
3+
import kdrag as kd
4+
5+
6+
@dataclass(frozen=True)
7+
class Contract:
8+
f: smt.FuncDeclRef
9+
args: list[smt.ExprRef]
10+
pre: smt.BoolRef
11+
post: smt.BoolRef
12+
proof: kd.Proof
13+
14+
15+
contracts: dict[smt.FuncDeclRef, Contract] = {}
16+
17+
"""
18+
def add_contract(f: smt.FuncDeclRef, proof: kd.Proof):
19+
assert f not in contracts
20+
contracts[f] = proof
21+
"""
22+
23+
24+
def lemmas(e: smt.ExprRef) -> list[kd.Proof]:
25+
"""
26+
Instantiate all contract lemmas found in
27+
"""
28+
res = []
29+
seen: set[int] = set([e.get_id()])
30+
todo = [e]
31+
decls = set()
32+
while todo:
33+
e = todo.pop()
34+
if smt.is_app(e):
35+
f = e.decl()
36+
children = e.children()
37+
if f in contracts:
38+
# we know how this should be instantiated
39+
res.append(contracts[f].proof(*children))
40+
for c in children:
41+
idx = c.get_id()
42+
if idx not in seen:
43+
seen.add(idx)
44+
todo.append(c)
45+
elif isinstance(e, smt.QuantifierRef):
46+
# There may be variables inside here. Fallback to just giving z3
47+
decls.update(kd.utils.decls(e.body()))
48+
else:
49+
raise ValueError(f"Unexpected term type: {e}")
50+
res.extend(contracts[decl].proof for decl in decls if decl in contracts)
51+
return res
52+
53+
54+
def contract(
55+
f: smt.FuncDeclRef, args: list[smt.ExprRef], pre, post, by=None, **kwargs
56+
) -> kd.Proof:
57+
"""
58+
Register the contract for function f: for all args, pre => post.
59+
60+
>>> x = smt.Int("x")
61+
>>> foo = kd.define("foo4392", [x], x + 1)
62+
>>> c = contract(foo, [x], x > 0, foo(x) > 0, by=[foo.defn])
63+
>>> c
64+
|= ForAll(x, Implies(x > 0, foo4392(x) > 0))
65+
>>> c.thm.pattern(0)
66+
foo4392(Var(0))
67+
>>> prove(foo(5) > 0)
68+
|= foo4392(5) > 0
69+
>>> prove(foo(5) > 5)
70+
Traceback (most recent call last):
71+
...
72+
LemmaError: ...
73+
"""
74+
assert f not in contracts
75+
if by is None:
76+
by = lemmas(pre) + lemmas(post)
77+
else:
78+
by = by + lemmas(pre) + lemmas(post)
79+
pf = kd.prove(smt.ForAll(args, pre, post, patterns=[f(*args)]), by=by, **kwargs)
80+
contracts[f] = Contract(f, args, pre, post, pf)
81+
return pf
82+
83+
84+
# def define(name, args, body, pre=None, post=None): ...
85+
86+
87+
def prove(thm: smt.BoolRef, by=[], **kwargs) -> kd.Proof:
88+
by = by + [
89+
contracts[decl].proof for decl in kd.utils.decls(thm) if decl in contracts
90+
]
91+
return kd.prove(thm, by=by, **kwargs)

0 commit comments

Comments
 (0)