|
| 1 | +from dataclasses import dataclass |
| 2 | +import kdrag.smt as smt |
| 3 | +import kdrag as kd |
| 4 | + |
| 5 | + |
| 6 | +@dataclass(frozen=True) |
| 7 | +class Contract: |
| 8 | + f: smt.FuncDeclRef |
| 9 | + args: list[smt.ExprRef] |
| 10 | + pre: smt.BoolRef |
| 11 | + post: smt.BoolRef |
| 12 | + proof: kd.Proof |
| 13 | + |
| 14 | + |
| 15 | +contracts: dict[smt.FuncDeclRef, Contract] = {} |
| 16 | + |
| 17 | +""" |
| 18 | +def add_contract(f: smt.FuncDeclRef, proof: kd.Proof): |
| 19 | + assert f not in contracts |
| 20 | + contracts[f] = proof |
| 21 | +""" |
| 22 | + |
| 23 | + |
| 24 | +def lemmas(e: smt.ExprRef) -> list[kd.Proof]: |
| 25 | + """ |
| 26 | + Instantiate all contract lemmas found in |
| 27 | + """ |
| 28 | + res = [] |
| 29 | + seen: set[int] = set([e.get_id()]) |
| 30 | + todo = [e] |
| 31 | + decls = set() |
| 32 | + while todo: |
| 33 | + e = todo.pop() |
| 34 | + if smt.is_app(e): |
| 35 | + f = e.decl() |
| 36 | + children = e.children() |
| 37 | + if f in contracts: |
| 38 | + # we know how this should be instantiated |
| 39 | + res.append(contracts[f].proof(*children)) |
| 40 | + for c in children: |
| 41 | + idx = c.get_id() |
| 42 | + if idx not in seen: |
| 43 | + seen.add(idx) |
| 44 | + todo.append(c) |
| 45 | + elif isinstance(e, smt.QuantifierRef): |
| 46 | + # There may be variables inside here. Fallback to just giving z3 |
| 47 | + decls.update(kd.utils.decls(e.body())) |
| 48 | + else: |
| 49 | + raise ValueError(f"Unexpected term type: {e}") |
| 50 | + res.extend(contracts[decl].proof for decl in decls if decl in contracts) |
| 51 | + return res |
| 52 | + |
| 53 | + |
| 54 | +def contract( |
| 55 | + f: smt.FuncDeclRef, args: list[smt.ExprRef], pre, post, by=None, **kwargs |
| 56 | +) -> kd.Proof: |
| 57 | + """ |
| 58 | + Register the contract for function f: for all args, pre => post. |
| 59 | +
|
| 60 | + >>> x = smt.Int("x") |
| 61 | + >>> foo = kd.define("foo4392", [x], x + 1) |
| 62 | + >>> c = contract(foo, [x], x > 0, foo(x) > 0, by=[foo.defn]) |
| 63 | + >>> c |
| 64 | + |= ForAll(x, Implies(x > 0, foo4392(x) > 0)) |
| 65 | + >>> c.thm.pattern(0) |
| 66 | + foo4392(Var(0)) |
| 67 | + >>> prove(foo(5) > 0) |
| 68 | + |= foo4392(5) > 0 |
| 69 | + >>> prove(foo(5) > 5) |
| 70 | + Traceback (most recent call last): |
| 71 | + ... |
| 72 | + LemmaError: ... |
| 73 | + """ |
| 74 | + assert f not in contracts |
| 75 | + if by is None: |
| 76 | + by = lemmas(pre) + lemmas(post) |
| 77 | + else: |
| 78 | + by = by + lemmas(pre) + lemmas(post) |
| 79 | + pf = kd.prove(smt.ForAll(args, pre, post, patterns=[f(*args)]), by=by, **kwargs) |
| 80 | + contracts[f] = Contract(f, args, pre, post, pf) |
| 81 | + return pf |
| 82 | + |
| 83 | + |
| 84 | +# def define(name, args, body, pre=None, post=None): ... |
| 85 | + |
| 86 | + |
| 87 | +def prove(thm: smt.BoolRef, by=[], **kwargs) -> kd.Proof: |
| 88 | + by = by + [ |
| 89 | + contracts[decl].proof for decl in kd.utils.decls(thm) if decl in contracts |
| 90 | + ] |
| 91 | + return kd.prove(thm, by=by, **kwargs) |
0 commit comments