Skip to content

Commit 90c329b

Browse files
committed
made at parameters take exprref. started rfun
1 parent 1dda599 commit 90c329b

File tree

2 files changed

+95
-29
lines changed

2 files changed

+95
-29
lines changed

src/kdrag/tactics.py

Lines changed: 33 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -434,23 +434,25 @@ def to_expr(self):
434434
else:
435435
return self.goal
436436

437-
def ctx_find(self, n: int | smt.BoolRef) -> smt.BoolRef:
437+
def ctx_find(self, n: int | smt.BoolRef) -> tuple[int, smt.BoolRef]:
438438
"""
439439
Find a hypothesis in the context by index or by matching expression.
440440
441441
>>> x = smt.Int("x")
442442
>>> g = Goal(sig=[], ctx=[x > 0, x < 10], goal=x == 5)
443443
>>> g.ctx_find(0)
444-
x > 0
444+
(0, x > 0)
445445
>>> g.ctx_find(x < 10)
446-
x < 10
446+
(1, x < 10)
447447
"""
448448
if isinstance(n, int):
449-
return self.ctx[n]
449+
if n < 0:
450+
n += len(self.ctx)
451+
return n, self.ctx[n]
450452
else:
451-
for h in self.ctx:
453+
for i, h in enumerate(self.ctx):
452454
if h.eq(n):
453-
return h
455+
return i, h
454456
raise KeyError(f"Hypothesis {n} not found in context")
455457

456458
def proof(self) -> "ProofState":
@@ -782,9 +784,7 @@ def simp(self, at=None, unfold=False, path=None) -> "ProofState":
782784
self.goals[-1] = goalctx._replace(goal=newgoal)
783785
else:
784786
oldctx = goalctx.ctx
785-
if at < 0:
786-
at = len(oldctx) + at
787-
old = oldctx[at]
787+
(at, old) = goalctx.ctx_find(at)
788788
new = kd.utils.pathmap(smt.simplify, old, path)
789789
if new.eq(old):
790790
raise ValueError("Simplify failed. Ctx is already simplified.")
@@ -1239,15 +1239,14 @@ def cb():
12391239
raise ValueError("Unexpected case in goal for split tactic", goal)
12401240
return self
12411241
else:
1242-
if at < 0:
1243-
at = len(ctx) + at
1244-
if smt.is_or(ctx[at]):
1242+
(at, hyp) = goalctx.ctx_find(at)
1243+
if smt.is_or(hyp):
12451244
self.pop_goal()
1246-
for c in ctx[at].children():
1245+
for c in hyp.children():
12471246
self.goals.append(
12481247
goalctx._replace(ctx=ctx[:at] + [c] + ctx[at + 1 :], goal=goal)
12491248
)
1250-
elif smt.is_and(ctx[at]):
1249+
elif smt.is_and(hyp):
12511250
self.pop_goal()
12521251
self.goals.append(
12531252
goalctx._replace(
@@ -1323,7 +1322,7 @@ def exists(self, *ts) -> "ProofState":
13231322
return self
13241323

13251324
def rw(
1326-
self, rule: kd.kernel.Proof | int, at=None, rev=False, **kwargs
1325+
self, rule: kd.kernel.Proof | int | smt.BoolRef, at=None, rev=False, **kwargs
13271326
) -> "ProofState":
13281327
"""
13291328
`rewrite` allows you to apply rewrite rule (which may either be a Proof or an index into the context) to the goal or to the context.
@@ -1340,8 +1339,10 @@ def rw(
13401339
"""
13411340
goalctx = self.top_goal()
13421341
ctx, goal = goalctx.ctx, goalctx.goal
1343-
if isinstance(rule, int):
1344-
rulethm = ctx[rule]
1342+
if at is not None:
1343+
(at, hyp) = goalctx.ctx_find(at)
1344+
if isinstance(rule, int) or isinstance(rule, smt.ExprRef):
1345+
_, rulethm = goalctx.ctx_find(rule)
13451346
elif kd.kernel.is_proof(rule):
13461347
rulethm = rule.thm
13471348
else:
@@ -1365,12 +1366,8 @@ def rw(
13651366
raise ValueError(f"Rewrite tactic failed. Not an equality {rulethm}")
13661367
if at is None:
13671368
target = goal
1368-
elif isinstance(at, int):
1369-
target = ctx[at]
13701369
else:
1371-
raise ValueError(
1372-
"Rewrite tactic failed. `at` is not an index into the context"
1373-
)
1370+
at, target = goalctx.ctx_find(at)
13741371
t_subst = kd.utils.pmatch_rec(vs, lhs, target)
13751372
if t_subst is None:
13761373
raise ValueError(
@@ -1383,13 +1380,11 @@ def rw(
13831380
target: smt.BoolRef = smt.substitute(target, (lhs1, rhs1))
13841381
if isinstance(rulethm, smt.QuantifierRef) and rulethm.is_forall():
13851382
self.add_lemma(kd.kernel.specialize([subst[v] for v in vs], rulethm))
1386-
if not isinstance(rule, int) and kd.kernel.is_proof(rule):
1383+
if isinstance(rule, kd.kernel.Proof):
13871384
self.add_lemma(rule)
13881385
if at is None:
13891386
self.goals.append(goalctx._replace(ctx=ctx, goal=target))
13901387
else:
1391-
if at == -1:
1392-
at = len(ctx) - 1
13931388
self.goals.append(
13941389
goalctx._replace(ctx=ctx[:at] + [target] + ctx[at + 1 :], goal=goal)
13951390
)
@@ -1474,13 +1469,13 @@ def beta(self, at=None):
14741469
self.goals[-1] = goalctx._replace(goal=newgoal)
14751470
else:
14761471
oldctx = goalctx.ctx
1477-
old = oldctx[at]
1472+
at, old = goalctx.ctx_find(at)
14781473
new = kd.rewrite.beta(old)
14791474
if new.eq(old):
14801475
raise ValueError(
14811476
"Beta tactic failed. Ctx is already beta reduced.", old
14821477
)
1483-
self.add_lemma(kd.kernel.prove(old == new))
1478+
self.add_lemma(kd.kernel.prove(smt.Eq(old, new)))
14841479
self.goals[-1] = goalctx._replace(
14851480
ctx=oldctx[:at] + [new] + oldctx[at + 1 :]
14861481
)
@@ -1518,7 +1513,7 @@ def unfold(self, *decls: smt.FuncDeclRef, at=None, keep=False) -> "ProofState":
15181513
else:
15191514
self.goals.append(goalctx._replace(goal=e2))
15201515
else:
1521-
e = goalctx.ctx[at]
1516+
at, e = goalctx.ctx_find(at)
15221517
trace = []
15231518
e2 = kd.rewrite.unfold(e, decls=decls, trace=trace)
15241519
for lem in trace:
@@ -1639,12 +1634,21 @@ def contra(self):
16391634
)
16401635
return self.top_goal()
16411636

1642-
def clear(self, n: int):
1637+
def clear(self, n: int | smt.BoolRef):
16431638
"""
16441639
Remove a hypothesis from the context
1640+
1641+
>>> p,q = smt.Bools("p q")
1642+
>>> l = Lemma(smt.Implies(p, q))
1643+
>>> h = l.intros()
1644+
>>> l
1645+
[p] ?|= q
1646+
>>> l.clear(h)
1647+
[] ?|= q
16451648
"""
16461649
ctxgoal = self.pop_goal()
16471650
ctx = ctxgoal.ctx.copy()
1651+
n, _ = ctxgoal.ctx_find(n)
16481652
ctx.pop(n)
16491653
self.goals.append(ctxgoal._replace(ctx=ctx))
16501654
return self.top_goal()

src/kdrag/theories/real/fun.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
import kdrag as kd
2+
import kdrag.smt as smt
3+
import kdrag.theories.real as real
4+
5+
RFun = smt.ArraySort(real.R, real.R)
6+
f, g, h = smt.Consts("f g h", RFun)
7+
x, y, c = smt.Consts("x y c", real.R)
8+
9+
smul = kd.define("smul", [c, f], smt.Lambda([x], real.mul(c, f[x])))
10+
11+
12+
# TODO: if I leave these add and mul bare, I have problems. Hmm.
13+
def Linear(f):
14+
return smt.And(
15+
smt.ForAll([x, y], f[real.add(x, y)] == real.add(f[x], f[y])),
16+
smt.ForAll([x, c], f[real.mul(c, x)] == real.mul(c, f[x])),
17+
)
18+
19+
20+
is_linear = kd.define("is_linear", [f], Linear(f))
21+
22+
23+
@kd.Theorem(smt.ForAll([f, g], is_linear(f), is_linear(g), is_linear(f + g)))
24+
def linear_add(l):
25+
f, g = l.fixes()
26+
l.unfold()
27+
_h = l.intros()
28+
l.split(at=-1)
29+
l.split(at=-1)
30+
l.split(at=0)
31+
l.split()
32+
33+
_x, _y = l.fixes()
34+
l.simp()
35+
l.rw(0)
36+
l.rw(2)
37+
l.auto(unfold=[real.add])
38+
39+
_x, _c = l.fixes()
40+
l.simp()
41+
l.rw(1)
42+
l.rw(3)
43+
l.auto(unfold=[real.mul])
44+
45+
46+
@kd.Theorem(smt.ForAll([f, c], is_linear(f), is_linear(smul(c, f))))
47+
def linear_smul(l):
48+
f, c = l.fixes()
49+
l.unfold()
50+
_h = l.intros()
51+
l.split(at=-1) # Fix this. giving h doesn't work
52+
l.split()
53+
54+
_x, _y = l.fixes()
55+
l.simp()
56+
l.rw(0)
57+
l.auto(unfold=[real.add, real.mul])
58+
59+
_x, _c2 = l.fixes()
60+
l.simp()
61+
l.rw(1)
62+
l.auto(unfold=[real.mul])

0 commit comments

Comments
 (0)