Skip to content

Commit eadb231

Browse files
committed
improvements to Lemma
1 parent fbaf56f commit eadb231

File tree

4 files changed

+298
-36
lines changed

4 files changed

+298
-36
lines changed

kdrag/kernel.py

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -204,17 +204,18 @@ def instan(ts: list[smt.ExprRef], pf: Proof) -> Proof:
204204
Instantiate a universally quantified formula.
205205
This is forall elimination
206206
"""
207-
assert is_proof(pf) and pf.thm.is_forall()
207+
assert is_proof(pf) and pf.thm.is_forall() and len(ts) == pf.thm.num_vars()
208208

209209
return __Proof(smt.substitute_vars(pf.thm.body(), *reversed(ts)), reason=[pf])
210210

211211

212212
def instan2(ts: list[smt.ExprRef], thm: smt.BoolRef) -> Proof:
213213
"""
214-
Instantiate a universally quantified formula.
214+
Instantiate a universally quantified formula
215+
`forall xs, P(xs) -> P(ts)`
215216
This is forall elimination
216217
"""
217-
assert smt.is_quantifier(thm) and thm.is_forall()
218+
assert smt.is_quantifier(thm) and thm.is_forall() and len(ts) == thm.num_vars()
218219

219220
return __Proof(
220221
smt.Implies(thm, smt.substitute_vars(thm.body(), *reversed(ts))),
@@ -232,19 +233,37 @@ def forget(ts: list[smt.ExprRef], pf: Proof) -> Proof:
232233
return __Proof(smt.Exists(vs, smt.substitute(pf.thm, *zip(ts, vs))), reason=[pf])
233234

234235

235-
def forget2(ts: list[smt.ExprRef], thm: smt.BoolRef) -> Proof:
236+
def forget2(ts: list[smt.ExprRef], thm: smt.QuantifierRef) -> Proof:
236237
"""
237238
"Forget" a term using existentials. This is existential introduction.
239+
`P(ts) -> exists xs, P(xs)`
238240
`thm` is an existential formula, and `ts` are terms to substitute those variables with.
239241
forget easily follows.
242+
https://en.wikipedia.org/wiki/Existential_generalization
240243
"""
241-
assert smt.is_quantifier(thm) and thm.is_exists()
244+
assert smt.is_quantifier(thm) and thm.is_exists() and len(ts) == thm.num_vars()
242245
return __Proof(
243246
smt.Implies(smt.substitute_vars(thm.body(), *reversed(ts)), thm),
244247
reason="exists_intro",
245248
)
246249

247250

251+
def einstan(thm: smt.QuantifierRef) -> tuple[list[smt.ExprRef], Proof]:
252+
"""
253+
Skolemize an existential quantifier.
254+
`exists xs, P(xs) -> P(cs)` for fresh cs
255+
https://en.wikipedia.org/wiki/Existential_instantiation
256+
"""
257+
# TODO: Hmm. Maybe we don't need to have a Proof? Lessen this to thm.
258+
assert smt.is_quantifier(thm) and thm.is_exists()
259+
260+
skolems = fresh_const(thm)
261+
return skolems, __Proof(
262+
smt.Implies(thm, smt.substitute_vars(thm.body(), *reversed(skolems))),
263+
reason=["einstan"],
264+
)
265+
266+
248267
def skolem(pf: Proof) -> tuple[list[smt.ExprRef], Proof]:
249268
"""
250269
Skolemize an existential quantifier.
@@ -262,7 +281,7 @@ def herb(thm: smt.QuantifierRef) -> tuple[list[smt.ExprRef], Proof]:
262281
"""
263282
Herbrandize a theorem.
264283
It is sufficient to prove a theorem for fresh consts to prove a universal.
265-
Note: Perhaps lambdaized form is better?
284+
Note: Perhaps lambdaized form is better? Return vars and lamda that could receive `|- P[vars]`
266285
"""
267286
assert smt.is_quantifier(thm) and thm.is_forall()
268287
herbs = fresh_const(thm)

kdrag/tactics.py

Lines changed: 169 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from enum import IntEnum
55
import operator as op
66
from . import config
7+
from typing import NamedTuple
78

89

910
class Calc:
@@ -197,58 +198,191 @@ def simp(t: smt.ExprRef, by: list[kd.kernel.Proof] = [], **kwargs) -> kd.kernel.
197198
return lemma(t == t1, by=by, **kwargs)
198199

199200

201+
class Sequent(NamedTuple):
202+
ctx: list[smt.BoolRef]
203+
goal: smt.BoolRef
204+
205+
def __repr__(self):
206+
return repr(self.ctx) + " ?|- " + repr(self.goal)
207+
208+
200209
class Lemma:
201-
# Isar style forward proof
202210
def __init__(self, goal: smt.BoolRef):
203-
# self.cur_goal = goal
204211
self.lemmas = []
205212
self.thm = goal
206-
self.goals = [([], goal)]
213+
self.goals = [Sequent([], goal)]
214+
215+
def fixes(self):
216+
ctx, goal = self.goals[-1]
217+
if smt.is_quantifier(goal) and goal.is_forall():
218+
self.goals.pop()
219+
vs, herb_lemma = kd.kernel.herb(goal)
220+
self.lemmas.append(herb_lemma)
221+
self.goals.append(Sequent(ctx, herb_lemma.thm.arg(0)))
222+
return vs
223+
else:
224+
raise ValueError(f"fixes tactic failed. Not a forall {goal}")
207225

208226
def intros(self):
209227
ctx, goal = self.goals.pop()
210228
if smt.is_quantifier(goal) and goal.is_forall():
211229
vs, herb_lemma = kd.kernel.herb(goal)
212230
self.lemmas.append(herb_lemma)
213-
self.goals.append((ctx, herb_lemma.thm.arg(0)))
214-
return vs
231+
self.goals.append(Sequent(ctx, herb_lemma.thm.arg(0)))
232+
if len(vs) == 1:
233+
return vs[0]
234+
else:
235+
return vs
215236
elif smt.is_implies(goal):
216-
self.goals.append((ctx + [goal.arg(0)], goal.arg(1)))
217-
return self
237+
self.goals.append(Sequent(ctx + [goal.arg(0)], goal.arg(1)))
238+
return self.top_goal()
239+
elif smt.is_not(goal):
240+
self.goals.append((ctx + [goal.arg(0)], smt.BoolVal(False)))
241+
return
242+
else:
243+
raise ValueError("Intros failed.")
218244

219245
def cases(self, t):
220246
ctx, goal = self.goals.pop()
221247
if t.sort() == smt.BoolSort():
222-
self.goals.append((ctx + [smt.Not(t)], goal))
223-
self.goals.append((ctx + [t], goal))
248+
self.goals.append(Sequent(ctx + [smt.Not(t)], goal))
249+
self.goals.append(Sequent(ctx + [t], goal))
224250
elif isinstance(t, smt.DatatypeRef):
225251
dsort = t.sort()
226252
for i in reversed(range(dsort.num_constructors())):
227-
self.goals.append((ctx + [dsort.recognizer(i)(t)], goal))
253+
self.goals.append(Sequent(ctx + [dsort.recognizer(i)(t)], goal))
228254
else:
229255
raise ValueError("Cases failed. Not a bool or datatype")
230-
return self
256+
return self.top_goal()
231257

232258
def auto(self):
233259
ctx, goal = self.goals[-1]
234260
self.lemmas.append(lemma(smt.Implies(smt.And(ctx), goal)))
235261
self.goals.pop()
236-
return self
262+
return self.top_goal()
237263

238-
def split(self):
264+
def einstan(self, n):
239265
ctx, goal = self.goals[-1]
240-
if smt.is_and(goal):
266+
formula = ctx[n]
267+
if smt.is_quantifier(formula) and formula.is_exists():
241268
self.goals.pop()
242-
self.goals.extend([(ctx, c) for c in goal.children()])
269+
fs, einstan_lemma = kd.kernel.einstan(formula)
270+
self.lemmas.append(einstan_lemma)
271+
self.goals.append(
272+
Sequent(ctx[:n] + [einstan_lemma.thm.arg(1)] + ctx[n + 1 :], goal)
273+
)
274+
if len(fs) == 1:
275+
return fs[0]
276+
else:
277+
return fs
278+
else:
279+
raise ValueError("Einstan failed. Not an exists")
280+
281+
def split(self, at=None):
282+
ctx, goal = self.goals[-1]
283+
if at is None:
284+
if smt.is_and(goal):
285+
self.goals.pop()
286+
self.goals.extend([Sequent(ctx, c) for c in goal.children()])
287+
if smt.is_eq(goal):
288+
self.goals.pop()
289+
self.goals.append(Sequent(ctx, smt.Implies(goal.arg(0), goal.arg(1))))
290+
self.goals.append(Sequent(ctx, smt.Implies(goal.arg(1), goal.arg(0))))
291+
else:
292+
raise ValueError("Split failed")
293+
else:
294+
if smt.is_or(ctx[at]):
295+
self.goals.pop()
296+
for c in ctx[at].children():
297+
self.goals.append(Sequent(ctx[:at] + [c] + ctx[at + 1 :], goal))
298+
if smt.is_and(ctx[at]):
299+
self.goals.pop()
300+
self.goals.append(
301+
Sequent(ctx[:at] + ctx[at].children() + ctx[at + 1 :], goal)
302+
)
303+
else:
304+
raise ValueError("Split failed")
305+
306+
def left(self, n=0):
307+
ctx, goal = self.goals[-1]
308+
if smt.is_or(goal):
309+
if n is None:
310+
n = 0
311+
self.goals[-1] = Sequent(ctx, goal.arg(n))
312+
return self.top_goal()
243313
else:
244-
raise ValueError("Split failed. Not an and")
314+
raise ValueError("Left failed. Not an or")
315+
316+
def right(self):
317+
ctx, goal = self.goals[-1]
318+
if smt.is_or(goal):
319+
self.goals[-1] = Sequent(ctx, goal.arg(goal.num_args() - 1))
320+
return self.top_goal()
321+
else:
322+
raise ValueError("Right failed. Not an or")
245323

246324
def exists(self, *ts):
247325
ctx, goal = self.goals[-1]
248326
lemma = kd.kernel.forget2(ts, goal)
249327
self.lemmas.append(lemma)
250-
self.goals[-1] = (ctx, lemma.thm.arg(0))
251-
return self
328+
self.goals[-1] = Sequent(ctx, lemma.thm.arg(0))
329+
return self.top_goal()
330+
331+
def rewrite(self, rule, at=None, rev=False):
332+
"""
333+
`rewrite` allows you to apply rewrite rule (which may either be a Proof or an index into the context) to the goal or to the context.
334+
"""
335+
ctx, goal = self.goals[-1]
336+
if isinstance(rule, int):
337+
rulethm = ctx[rule]
338+
elif kd.kernel.is_proof(rule):
339+
rulethm = rule.thm
340+
if smt.is_quantifier(rulethm) and rulethm.is_forall():
341+
vs, body = kd.utils.open_binder(rulethm)
342+
else:
343+
vs = []
344+
body = rulethm
345+
if smt.is_eq(body):
346+
lhs, rhs = body.arg(0), body.arg(1)
347+
if rev:
348+
lhs, rhs = rhs, lhs
349+
else:
350+
raise ValueError(f"Rewrite tactic failed. Not an equality {rulethm}")
351+
if at is None:
352+
target = goal
353+
elif isinstance(at, int):
354+
target = ctx[at]
355+
else:
356+
raise ValueError(
357+
"Rewrite tactic failed. `at` is not an index into the context"
358+
)
359+
subst = kd.utils.pmatch_rec(vs, lhs, target)
360+
if subst is None:
361+
raise ValueError(
362+
f"Rewrite tactic failed to apply lemma {rulethm} to goal {goal}"
363+
)
364+
else:
365+
self.goals.pop()
366+
lhs1 = smt.substitute(lhs, *[(v, t) for v, t in subst.items()])
367+
rhs1 = smt.substitute(rhs, *[(v, t) for v, t in subst.items()])
368+
target: smt.BoolRef = smt.substitute(target, (lhs1, rhs1))
369+
self.lemmas.append(kd.kernel.instan2([subst[v] for v in vs], rulethm))
370+
if kd.kernel.is_proof(rule):
371+
self.lemmas.append(rule)
372+
if at is None:
373+
self.goals.append(Sequent(ctx, target))
374+
else:
375+
self.goals.append(Sequent(ctx[:at] + [target] + ctx[at + 1 :], goal))
376+
return self.top_goal()
377+
378+
def rw(self, rule, at=None, rev=False):
379+
return self.rewrite(rule, at=at, rev=rev)
380+
381+
def unfold(self, decl: smt.FuncDeclRef):
382+
if hasattr(decl, "defn"):
383+
return self.rewrite(decl.defn)
384+
else:
385+
raise ValueError("Unfold failed. Not a defined function")
252386

253387
def apply(self, pf: kd.kernel.Proof, rev=False):
254388
ctx, goal = self.goals.pop()
@@ -273,32 +407,38 @@ def apply(self, pf: kd.kernel.Proof, rev=False):
273407
pf1 = kd.kernel.instan([subst[v] for v in vs], pf)
274408
self.lemmas.append(pf1)
275409
if smt.is_implies(pf1.thm):
276-
self.goals.append((ctx, pf1.thm.arg(0)))
410+
self.goals.append(Sequent(ctx, pf1.thm.arg(0)))
277411
elif smt.is_eq(pf1.thm):
278412
if rev:
279-
self.goals.append((ctx, pf1.thm.arg(0)))
413+
self.goals.append(Sequent(ctx, pf1.thm.arg(0)))
280414
else:
281-
self.goals.append((ctx, pf1.thm.arg(1)))
282-
return self
415+
self.goals.append(Sequent(ctx, pf1.thm.arg(1)))
416+
return self.top_goal()
283417

284418
def assumption(self):
285419
ctx, goal = self.goals.pop()
286420
if any([goal.eq(h) for h in ctx]):
287-
return self
421+
return self.top_goal()
288422
else:
289423
raise ValueError("Assumption tactic failed", goal, ctx)
290424

291425
def have(self, conc, **kwargs):
292426
ctx, goal = self.goals.pop()
293427
self.lemmas.append(lemma(smt.Implies(smt.And(ctx), conc)), **kwargs)
294-
self.goals.append((ctx + [conc], conc))
295-
return self
428+
self.goals.append(Sequent(ctx + [conc], conc))
429+
return self.top_goal()
296430

297-
def __repr__(self):
431+
# TODO
432+
# def search():
433+
# def calc
434+
435+
def top_goal(self):
298436
if len(self.goals) == 0:
299437
return "Nothing to do. Hooray!"
300-
ctx, goal = self.goals[-1]
301-
return repr(ctx) + " ?|- " + repr(goal)
438+
return self.goals[-1]
439+
440+
def __repr__(self):
441+
return repr(self.top_goal())
302442

303443
def qed(self):
304-
return lemma(self.thm, by=self.lemmas)
444+
return kd.kernel.lemma(self.thm, by=self.lemmas)

kdrag/utils.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def pmatch(
4949
https://www.philipzucker.com/ho_unify/
5050
"""
5151
if pat.sort() != t.sort():
52-
raise Exception("Sort mismatch", pat, t)
52+
return None
5353
subst = {}
5454
todo = [(pat, t)]
5555
no_escape = []
@@ -107,6 +107,19 @@ def check_escape(x):
107107
return subst
108108

109109

110+
def pmatch_rec(
111+
vs: list[smt.ExprRef], pat: smt.ExprRef, t: smt.ExprRef
112+
) -> Optional[dict[smt.ExprRef, smt.ExprRef]]:
113+
todo = [t]
114+
while todo:
115+
t = todo.pop()
116+
subst = pmatch(vs, pat, t)
117+
if subst is not None:
118+
return subst
119+
elif smt.is_app(t):
120+
todo.extend(t.children())
121+
122+
110123
def rewrite1(
111124
t: smt.ExprRef, vs: list[smt.ExprRef], lhs: smt.ExprRef, rhs: smt.ExprRef
112125
) -> Optional[smt.ExprRef]:

0 commit comments

Comments
 (0)