Skip to content

Commit 53f016c

Browse files
committed
made inductive relations callable. Zipper pmatching
1 parent a001799 commit 53f016c

File tree

3 files changed

+125
-40
lines changed

3 files changed

+125
-40
lines changed

src/kdrag/datatype.py

Lines changed: 24 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -429,6 +429,7 @@ def Enum(name: str, args: str | list[str]) -> smt.DatatypeSortRef:
429429
rel = kd.notation.SortDispatch(name="rel")
430430
"""SortDispatch for the relation associated with a Datatype of evidence"""
431431
smt.DatatypeRef.rel = lambda *args: rel(*args)
432+
# smt.DatatypeRef.rel = lambda self, *args: self.rel(*args)
432433

433434

434435
def InductiveRel(name: str, *params: smt.ExprRef) -> smt.Datatype:
@@ -439,8 +440,13 @@ def InductiveRel(name: str, *params: smt.ExprRef) -> smt.Datatype:
439440
>>> Even.declare("Ev_Z", pred = x == 0)
440441
>>> Even.declare("Ev_SS", ("sub2_evidence", Even), pred = lambda evid: evid.rel(x-2))
441442
>>> Even = Even.create()
442-
>>> smt.Const("ev", Even).rel(4)
443-
even(ev, 4)
443+
>>> ev = smt.Const("ev", Even)
444+
>>> ev.rel(4)
445+
even(4)[ev]
446+
>>> ev(4)
447+
even(4)[ev]
448+
>>> Even(4)
449+
even(4)
444450
"""
445451

446452
dt = Inductive(name)
@@ -477,19 +483,26 @@ def create_relation(dt):
477483
args = [dt.accessor(i, j)(ev) for j in range(dt.constructor(i).arity())]
478484
res = pred(*args)
479485
cases.append((precond, res))
480-
args = [ev]
481-
args.extend(params)
482-
rel = kd.define(relname, args, kd.cond(*cases))
486+
rel = kd.define(relname, list(params), smt.Lambda([ev], kd.cond(*cases)))
483487
return rel
484488

485489
def create():
486490
dt = oldcreate()
487-
dtrel = smt.Function(relname, dt, *[x.sort() for x in params], smt.BoolSort())
488-
rel.register(
489-
dt, lambda *args: dtrel(*args)
490-
) # doing this here let's us tie the knot inside of lambdas and refer to the predicate.
491-
dtrel = create_relation(dt)
492-
dt.rel = dtrel
491+
if any(p is not None for p in preds):
492+
dtrel = smt.Function(
493+
relname, *[x.sort() for x in params], smt.ArraySort(dt, smt.BoolSort())
494+
)
495+
kd.notation.call.register(dt, lambda self, *args: dtrel(*args)[self])
496+
rel.register(
497+
dt, lambda self, *args: dtrel(*args)[self]
498+
) # doing this here let's us tie the knot inside of lambdas and refer to the predicate.
499+
dtrel = create_relation(dt)
500+
dt.rel = dtrel
501+
# ev = smt.FreshConst(dt, prefix=name.lower())
502+
call_dict[dt] = dtrel # lambda *args: smt.Lambda([ev], ev.rel(*args))
503+
504+
if len(params) == 0:
505+
kd.notation.wf.register(dt, lambda x: x.rel())
493506
return dt
494507

495508
dt.create = create

src/kdrag/notation.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,15 @@ def define(self, args, body):
9898
return defn
9999

100100

101+
call = SortDispatch(name="call")
102+
"""Sort based dispatch for `()` call syntax"""
103+
smt.ExprRef.__call__ = lambda x, *y, **kwargs: call(x, *y, **kwargs)
104+
105+
getitem = SortDispatch(name="getitem")
106+
"""Sort based dispatch for `[]` getitem syntax"""
107+
smt.ExprRef.__getitem__ = lambda x, y: getitem(x, y) # type: ignore
108+
109+
101110
add = SortDispatch(name="add")
102111
"""Sort based dispatch for `+` syntax"""
103112
smt.ExprRef.__add__ = lambda x, y: add(x, y) # type: ignore
@@ -184,10 +193,6 @@ def define(self, args, body):
184193
"""Sort based dispatch for induction principles. Should instantiate an induction scheme for variable x and predicate P"""
185194
smt.ExprRef.induct = lambda x, P: induct(x, P) # type: ignore
186195

187-
getitem = SortDispatch(name="getitem")
188-
"""Sort based dispatch for `[]` getitem syntax"""
189-
smt.ExprRef.__getitem__ = lambda x, y: getitem(x, y) # type: ignore
190-
191196

192197
to_int = SortDispatch(name="to_int")
193198
"""Sort based dispatch for `to_int`"""

src/kdrag/utils.py

Lines changed: 92 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -579,7 +579,7 @@ def prune(
579579

580580

581581
def bysect(
582-
thm, by0: list[kd.kernel.Proof] | dict[object, kd.kernel.Proof], **kwargs
582+
thm, by: list[kd.kernel.Proof] | dict[object, kd.kernel.Proof], **kwargs
583583
) -> Sequence[tuple[object, kd.kernel.Proof]]:
584584
"""
585585
Bisect the `by` list to find a minimal set of premises that prove `thm`. Presents the same interface as `prove`
@@ -589,29 +589,29 @@ def bysect(
589589
>>> bysect(x == z, by=by)
590590
[(1, |= x == y), (3, |= y == z)]
591591
"""
592-
if isinstance(by0, list):
593-
by = list(enumerate(by0))
594-
elif isinstance(by0, dict):
595-
by = list(by0.items())
592+
if isinstance(by, list):
593+
by1 = list(enumerate(by))
594+
elif isinstance(by, dict):
595+
by1 = list(by.items())
596596
else:
597597
raise ValueError("by must be a list or dict")
598598
n = 2
599-
while len(by) >= 2:
600-
subset_size = len(by) // n
601-
for i in range(0, len(by), subset_size):
602-
rest = by[:i] + by[i + subset_size :]
599+
while len(by1) >= 2:
600+
subset_size = len(by1) // n
601+
for i in range(0, len(by1), subset_size):
602+
rest = by1[:i] + by1[i + subset_size :]
603603
try:
604604
kd.prove(thm, by=[b for _, b in rest], **kwargs)
605-
by = rest
605+
by1 = rest
606606
n = max(n - 1, 2)
607607
break
608608
except Exception as _:
609609
pass
610610
else:
611-
if n == len(by):
611+
if n == len(by1):
612612
break
613-
n = min(len(by), n * 2)
614-
return by
613+
n = min(len(by1), n * 2)
614+
return by1
615615

616616

617617
def subterms(t: smt.ExprRef, into_binder=False):
@@ -780,22 +780,28 @@ def ast_size_sexpr(t: smt.AstRef) -> int:
780780
@dataclass(frozen=True)
781781
class QuantifierHole:
782782
vs: list[smt.ExprRef]
783-
# orig_vs : list[smt.ExprRef] to be able to exactly reconstruct original term?
783+
orig_vs: list[smt.ExprRef] # to be able to exactly reconstruct original term?
784+
785+
def has_right(self) -> bool:
786+
return False
784787

785788

786789
class LambdaHole(QuantifierHole):
787790
def wrap(self, body: smt.ExprRef) -> smt.ExprRef:
788-
return smt.Lambda(self.vs, body)
791+
body = smt.substitute(body, *zip(self.vs, self.orig_vs))
792+
return smt.Lambda(self.orig_vs, body)
789793

790794

791795
class ForAllHole(QuantifierHole):
792796
def wrap(self, body: smt.ExprRef) -> smt.ExprRef:
793-
return smt.ForAll(self.vs, body)
797+
body = smt.substitute(body, *zip(self.vs, self.orig_vs))
798+
return smt.ForAll(self.orig_vs, body)
794799

795800

796801
class ExistsHole(QuantifierHole):
797802
def wrap(self, body: smt.ExprRef) -> smt.ExprRef:
798-
return smt.Exists(self.vs, body)
803+
body = smt.substitute(body, *zip(self.vs, self.orig_vs))
804+
return smt.Exists(self.orig_vs, body)
799805

800806

801807
@dataclass(frozen=True)
@@ -835,23 +841,28 @@ class Zipper:
835841
>>> t = smt.Lambda([x,y], (x + y) * (y + z))
836842
>>> z1 = Zipper.from_term(t)
837843
>>> z1.open_binder().arg(1).left().arg(0)
838-
Zipper(ctx=[LambdaHole(vs=[X!..., Y!...]), DeclHole(f=*, _left=(), _right=(Y!... + z,)), DeclHole(f=+, _left=(), _right=(Y!...,))], t=X!...)
839-
>>> z1.pop().pop().pop()
840-
Zipper(ctx=[], t=Lambda([X!..., Y!...], (X!... + Y!...)*(Y!... + z)))
844+
Zipper(ctx=[LambdaHole(vs=[X!..., Y!...], orig_vs=[x, y]), DeclHole(f=*, _left=(), _right=(Y!... + z,)), DeclHole(f=+, _left=(), _right=(Y!...,))], t=X!...)
845+
>>> z1.up().up().up()
846+
Zipper(ctx=[], t=Lambda([x, y], (x + y)*(y + z)))
841847
"""
842848

843-
ctx: list[Hole] # trail / stack
849+
ctx: list[Hole] # trail / stack,. Consider saving old term
844850
t: smt.ExprRef
845851

846852
@classmethod
847853
def from_term(cls, t: smt.ExprRef) -> "Zipper":
848854
return cls([], t)
849855

850-
def pop(self) -> "Zipper": # up?
856+
def up(self) -> "Zipper": # up?
851857
hole = self.ctx.pop()
852858
self.t = hole.wrap(self.t)
853859
return self
854860

861+
def rebuild(self) -> smt.ExprRef:
862+
while self.ctx:
863+
self.up()
864+
return self.t
865+
855866
def copy(self) -> "Zipper":
856867
return Zipper(self.ctx.copy(), self.t)
857868

@@ -879,19 +890,75 @@ def arg(self, n: int) -> "Zipper":
879890

880891
def open_binder(self) -> "Zipper":
881892
assert isinstance(self.t, smt.QuantifierRef)
893+
t = self.t
894+
orig_vs, _body = kd.utils.open_binder_unhygienic(
895+
t
896+
) # TODO: don't need to build body
882897
vs, body = kd.utils.open_binder(self.t)
883898
if self.t.is_forall():
884-
hole = ForAllHole(vs)
899+
hole = ForAllHole(vs, orig_vs)
885900
elif self.t.is_exists():
886-
hole = ExistsHole(vs)
901+
hole = ExistsHole(vs, orig_vs)
887902
elif self.t.is_lambda():
888-
hole = LambdaHole(vs)
903+
hole = LambdaHole(vs, orig_vs)
889904
else:
890905
raise NotImplementedError("Unknown quantifier type", self.t)
891906
self.ctx.append(hole)
892907
self.t = body
893908
return self
894909

910+
def __iter__(self):
911+
return self
912+
913+
def __next__(self) -> smt.ExprRef:
914+
"""
915+
All subterms of the term in a pre-order traversal.
916+
917+
>>> x,y,z = smt.Ints("x y z")
918+
>>> list(Zipper([], x + y*z))
919+
[x, y*z, y, z]
920+
"""
921+
if isinstance(self.t, smt.QuantifierRef):
922+
self.open_binder()
923+
return self.t
924+
elif smt.is_const(self.t):
925+
while len(self.ctx) != 0 and not self.ctx[-1].has_right():
926+
self.up()
927+
if len(self.ctx) == 0:
928+
raise StopIteration
929+
else:
930+
self.right()
931+
return self.t
932+
elif smt.is_app(self.t):
933+
self.arg(0)
934+
return self.t
935+
else:
936+
raise ValueError("Unexpected term in Zipper iteration", self.t)
937+
938+
def pmatch(
939+
self, vs: list[smt.ExprRef], pat: smt.ExprRef
940+
) -> Optional[dict[smt.ExprRef, smt.ExprRef]]:
941+
"""
942+
Pattern match the current term against a pattern with variables vs.
943+
Leaves the zipper in a context state.
944+
This can be used to replace but rebuild using original context
945+
946+
>>> x,y,z,a,b,c = smt.Ints("x y z a b c")
947+
>>> zip = Zipper([], x + smt.Lambda([y], y*z)[x])
948+
>>> (subst := zip.pmatch([a,b], a*b))
949+
{b: z, a: Y!...}
950+
>>> zip.t = smt.IntVal(1) * subst[a]
951+
>>> zip.rebuild()
952+
x + Lambda(y, 1*y)[x]
953+
"""
954+
subst = pmatch(vs, pat, self.t)
955+
if subst is not None:
956+
return subst
957+
for t in self:
958+
subst = pmatch(vs, pat, t)
959+
if subst is not None:
960+
return subst
961+
895962
def __hash__(self):
896963
"""
897964
Warning: If you are hashing Zippers, make sure you are copying them.

0 commit comments

Comments
 (0)