Skip to content

Commit d906a72

Browse files
committed
contracts
1 parent 873057e commit d906a72

File tree

4 files changed

+78
-25
lines changed

4 files changed

+78
-25
lines changed
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
A very nice lean course on real analysis. It is particularly suited to knuckledragger because it avoids mathlib and highly generic polymorphic definitions or crazy typeclasses.
2+
3+
<https://github.com/AlexKontorovich/RealAnalysisGame>

src/kdrag/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from . import utils
1515
from . import datatype
1616
from . import rewrite
17+
from . import contracts
1718
from . import tactics
1819
import functools
1920
from .parsers import microlean
@@ -46,6 +47,8 @@ def define_const(name: str, body: smt.ExprRef) -> smt.ExprRef:
4647
return kernel.define(name, [], body) # type: ignore
4748

4849

50+
contract = contracts.contract
51+
4952
FreshVar = kernel.FreshVar
5053

5154
FreshVars = tactics.FreshVars
Lines changed: 27 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,15 @@
1+
"""
2+
3+
Contracts are a database of lemmas associated with a function symbol.
4+
They are usually used to specify intended pre and postconditions for functions as an abstraction over their definition.
5+
This can also be though of as a subtyping refinement.
6+
7+
There is a registry which can be queried.
8+
9+
This registry is not trusted code.
10+
11+
"""
12+
113
from dataclasses import dataclass
214
import kdrag.smt as smt
315
import kdrag as kd
@@ -9,21 +21,17 @@ class Contract:
921
args: list[smt.ExprRef]
1022
pre: smt.BoolRef
1123
post: smt.BoolRef
12-
proof: kd.Proof
24+
proof: kd.kernel.Proof
1325

1426

1527
contracts: dict[smt.FuncDeclRef, Contract] = {}
1628

17-
"""
18-
def add_contract(f: smt.FuncDeclRef, proof: kd.Proof):
19-
assert f not in contracts
20-
contracts[f] = proof
21-
"""
2229

23-
24-
def lemmas(e: smt.ExprRef) -> list[kd.Proof]:
30+
def lemmas(e: smt.ExprRef, into_binders=True) -> list[kd.kernel.Proof]:
2531
"""
26-
Instantiate all contract lemmas found in
32+
Instantiate all contract lemmas found in a pexression.
33+
This sweeps the expression and instantiates the contract lemma with the arguments to the function.
34+
Once it goes under binders, this becomes more difficult, so it returns the quantified form of the lemmas
2735
"""
2836
res = []
2937
seen: set[int] = set([e.get_id()])
@@ -44,7 +52,8 @@ def lemmas(e: smt.ExprRef) -> list[kd.Proof]:
4452
todo.append(c)
4553
elif isinstance(e, smt.QuantifierRef):
4654
# There may be variables inside here. Fallback to just giving z3
47-
decls.update(kd.utils.decls(e.body()))
55+
if into_binders:
56+
decls.update(kd.utils.decls(e.body()))
4857
else:
4958
raise ValueError(f"Unexpected term type: {e}")
5059
res.extend(contracts[decl].proof for decl in decls if decl in contracts)
@@ -53,7 +62,7 @@ def lemmas(e: smt.ExprRef) -> list[kd.Proof]:
5362

5463
def contract(
5564
f: smt.FuncDeclRef, args: list[smt.ExprRef], pre, post, by=None, **kwargs
56-
) -> kd.Proof:
65+
) -> kd.kernel.Proof:
5766
"""
5867
Register the contract for function f: for all args, pre => post.
5968
@@ -64,28 +73,21 @@ def contract(
6473
|= ForAll(x, Implies(x > 0, foo4392(x) > 0))
6574
>>> c.thm.pattern(0)
6675
foo4392(Var(0))
67-
>>> prove(foo(5) > 0)
76+
>>> kd.prove(foo(5) > 0, contracts=True)
6877
|= foo4392(5) > 0
69-
>>> prove(foo(5) > 5)
78+
>>> kd.prove(foo(5) > 5, contracts=True)
7079
Traceback (most recent call last):
7180
...
7281
LemmaError: ...
7382
"""
7483
assert f not in contracts
7584
if by is None:
76-
by = lemmas(pre) + lemmas(post)
77-
else:
78-
by = by + lemmas(pre) + lemmas(post)
79-
pf = kd.prove(smt.ForAll(args, pre, post, patterns=[f(*args)]), by=by, **kwargs)
85+
by = []
86+
thm = smt.ForAll(args, smt.Implies(pre, post), patterns=[f(*args)])
87+
by = by + lemmas(thm)
88+
pf = kd.kernel.prove(thm, by=by, **kwargs) # Do we want kd.tactics.prove here?
8089
contracts[f] = Contract(f, args, pre, post, pf)
8190
return pf
8291

8392

84-
# def define(name, args, body, pre=None, post=None): ...
85-
86-
87-
def prove(thm: smt.BoolRef, by=[], **kwargs) -> kd.Proof:
88-
by = by + [
89-
contracts[decl].proof for decl in kd.utils.decls(thm) if decl in contracts
90-
]
91-
return kd.prove(thm, by=by, **kwargs)
93+
# Special define?

src/kdrag/tactics.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,7 @@ def prove(
229229
timeout=1000,
230230
dump=False,
231231
solver=None,
232+
contracts=None,
232233
# defns=True,
233234
# induct=False,
234235
# simps=simps,
@@ -265,6 +266,10 @@ def prove(
265266
|= succ2(x) == x + 2
266267
>>> prove(succ(x) == x + 1, unfold=[succ])
267268
|= succ(x) == x + 1
269+
>>> kd.contract(succ, [x], x >= 0 , succ(x) >= 0, by=[succ.defn])
270+
|= ForAll(x, Implies(x >= 0, succ(x) >= 0))
271+
>>> kd.prove(smt.Implies(x >=0, succ(x) >= 0), contracts=True)
272+
|= Implies(x >= 0, succ(x) >= 0)
268273
"""
269274
start_time = time.perf_counter()
270275

@@ -278,6 +283,18 @@ def prove(
278283
elif not isinstance(by, list):
279284
by = list(by)
280285

286+
if contracts is None:
287+
pass
288+
elif contracts is True:
289+
by = by + kd.contracts.lemmas(thm)
290+
elif isinstance(contracts, list):
291+
by = list(by)
292+
for decl in contracts:
293+
if decl in kd.contracts.contracts:
294+
by.append(kd.contracts.contracts[decl].proof)
295+
else:
296+
raise KeyError(f"No contract found for {decl}")
297+
281298
if unfold is None:
282299
pass
283300
elif isinstance(unfold, int):
@@ -1076,6 +1093,34 @@ def obtain(self, n: int | smt.QuantifierRef) -> smt.ExprRef | list[smt.ExprRef]:
10761093
"obtain failed. Not an exists", formula, "Available exists:", exists_f
10771094
)
10781095

1096+
def contract(self) -> "ProofState":
1097+
"""
1098+
Add contract lemmas to the context
1099+
1100+
>>> x = smt.Int("x")
1101+
>>> succ = kd.define("mysucc72", [x], x + 1)
1102+
>>> kd.contract(succ, [x], x >= 0, succ(x) >= 0, by=[succ.defn])
1103+
|= ForAll(x, Implies(x >= 0, mysucc72(x) >= 0))
1104+
>>> l = Lemma(smt.Implies(smt.And(x >= 0, succ(x) >= 0), succ(succ(x)) >= 0))
1105+
>>> _ = l.intros()
1106+
>>> l.contract()
1107+
[And(x >= 0, mysucc72(x) >= 0),
1108+
Implies(mysucc72(x) >= 0, mysucc72(mysucc72(x)) >= 0),
1109+
Implies(x >= 0, mysucc72(x) >= 0)]
1110+
?|= mysucc72(mysucc72(x)) >= 0
1111+
1112+
"""
1113+
self.goalctx = self.top_goal()
1114+
ctx, goal = self.goalctx.ctx, self.goalctx.goal
1115+
clemmas = kd.contracts.lemmas(goal)
1116+
if not clemmas:
1117+
raise ValueError("No contract lemmas available for goal", goal)
1118+
for l in clemmas:
1119+
self.add_lemma(l)
1120+
newctx = ctx + [l.thm for l in clemmas]
1121+
self.goals[-1] = self.goalctx._replace(ctx=newctx)
1122+
return self
1123+
10791124
def specialize(self, n: int | smt.QuantifierRef, *ts):
10801125
"""
10811126
Instantiate a universal quantifier in the context.

0 commit comments

Comments
 (0)