Skip to content

Commit dc4086a

Browse files
Copilotphilzook58
andcommitted
Refactor Inductive to use subclass instead of monkey-patching
Co-authored-by: philzook58 <[email protected]>
1 parent 2bb56f2 commit dc4086a

File tree

2 files changed

+91
-76
lines changed

2 files changed

+91
-76
lines changed

src/kdrag/datatype.py

Lines changed: 57 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -432,81 +432,88 @@ def Enum(name: str, args: str | list[str]) -> smt.DatatypeSortRef:
432432
# smt.DatatypeRef.rel = lambda self, *args: self.rel(*args)
433433

434434

435-
def InductiveRel(name: str, *params: smt.ExprRef) -> smt.Datatype:
436-
"""Define an inductive type of evidence and a relation the recurses on that evidence
437-
438-
>>> x = smt.Int("x")
439-
>>> Even = InductiveRel("Even", x)
440-
>>> Even.declare("Ev_Z", pred = x == 0)
441-
>>> Even.declare("Ev_SS", ("sub2_evidence", Even), pred = lambda evid: evid.rel(x-2))
442-
>>> Even = Even.create()
443-
>>> ev = smt.Const("ev", Even)
444-
>>> ev.rel(4)
445-
even(4)[ev]
446-
>>> ev(4)
447-
even(4)[ev]
448-
>>> Even(4)
449-
even(4)
435+
class _InductiveRelDatatype(kd.kernel._InductiveDatatype):
436+
"""
437+
Subclass of _InductiveDatatype for inductive relations.
438+
This replaces the previous approach of monkey-patching declare and create methods.
450439
"""
451440

452-
dt = Inductive(name)
453-
454-
relname = name.lower()
455-
olddeclare = dt.declare
456-
preds = [] # tuck away extra predicate
457-
458-
def declare(
459-
name, *args, pred=None
460-
): # TODO: would it ever make sense to not have a pred?
461-
olddeclare(name, *args)
462-
preds.append(pred)
463-
464-
dt.declare = declare
465-
466-
oldcreate = dt.create
467-
468-
def create_relation(dt):
469-
"""
470-
When inductive is done being defined, call this function
471-
"""
472-
ev = smt.FreshConst(dt, prefix=name.lower())
473-
rel = smt.Function(relname, dt, *[x.sort() for x in params], smt.BoolSort())
441+
def __init__(self, name: str, *params: smt.ExprRef):
442+
"""Initialize an inductive relation datatype."""
443+
super().__init__(name)
444+
self._params = params
445+
self._relname = name.lower()
446+
self._preds = [] # Store predicates for each constructor
447+
448+
def declare(self, name, *args, pred=None):
449+
"""Override declare to also store predicates."""
450+
super().declare(name, *args)
451+
self._preds.append(pred)
452+
453+
def _create_relation(self, dt):
454+
"""Create the relation definition for this inductive type."""
455+
ev = smt.FreshConst(dt, prefix=self._relname)
456+
rel = smt.Function(
457+
self._relname, dt, *[x.sort() for x in self._params], smt.BoolSort()
458+
)
474459
cases = []
475460
for i in range(dt.num_constructors()):
476461
precond = dt.recognizer(i)(ev) # recognize case of the evidence
477-
pred = preds[i] # In this case, this predicate should be true
462+
pred = self._preds[i] # In this case, this predicate should be true
478463
if pred is None:
479464
res = smt.BoolVal(True)
480465
elif isinstance(pred, smt.ExprRef):
481466
res = pred
482467
else:
483-
args = [dt.accessor(i, j)(ev) for j in range(dt.constructor(i).arity())]
468+
args = [
469+
dt.accessor(i, j)(ev) for j in range(dt.constructor(i).arity())
470+
]
484471
res = pred(*args)
485472
cases.append((precond, res))
486-
rel = kd.define(relname, list(params), smt.Lambda([ev], kd.cond(*cases)))
473+
rel = kd.define(
474+
self._relname, list(self._params), smt.Lambda([ev], kd.cond(*cases))
475+
)
487476
return rel
488477

489-
def create():
490-
dt = oldcreate()
491-
if any(p is not None for p in preds):
478+
def create(self):
479+
"""Override create to also set up the relation."""
480+
dt = super().create()
481+
if any(p is not None for p in self._preds):
492482
dtrel = smt.Function(
493-
relname, *[x.sort() for x in params], smt.ArraySort(dt, smt.BoolSort())
483+
self._relname,
484+
*[x.sort() for x in self._params],
485+
smt.ArraySort(dt, smt.BoolSort()),
494486
)
495487
kd.notation.call.register(dt, lambda self, *args: dtrel(*args)[self])
496488
rel.register(
497489
dt, lambda self, *args: dtrel(*args)[self]
498490
) # doing this here let's us tie the knot inside of lambdas and refer to the predicate.
499-
dtrel = create_relation(dt)
491+
dtrel = self._create_relation(dt)
500492
dt.rel = dtrel
501-
# ev = smt.FreshConst(dt, prefix=name.lower())
502-
call_dict[dt] = dtrel # lambda *args: smt.Lambda([ev], ev.rel(*args))
493+
call_dict[dt] = dtrel
503494

504-
if len(params) == 0:
495+
if len(self._params) == 0:
505496
kd.notation.wf.register(dt, lambda x: x.rel())
506497
return dt
507498

508-
dt.create = create
509-
return dt
499+
500+
def InductiveRel(name: str, *params: smt.ExprRef) -> _InductiveRelDatatype:
501+
"""Define an inductive type of evidence and a relation the recurses on that evidence
502+
503+
>>> x = smt.Int("x")
504+
>>> Even = InductiveRel("Even", x)
505+
>>> Even.declare("Ev_Z", pred = x == 0)
506+
>>> Even.declare("Ev_SS", ("sub2_evidence", Even), pred = lambda evid: evid.rel(x-2))
507+
>>> Even = Even.create()
508+
>>> ev = smt.Const("ev", Even)
509+
>>> ev.rel(4)
510+
even(4)[ev]
511+
>>> ev(4)
512+
even(4)[ev]
513+
>>> Even(4)
514+
even(4)
515+
"""
516+
return _InductiveRelDatatype(name, *params)
510517

511518

512519
def inj_lemmas(dt: smt.DatatypeSortRef) -> list[kd.kernel.Proof]:

src/kdrag/kernel.py

Lines changed: 34 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -543,29 +543,26 @@ def induct_inductive(x: smt.DatatypeRef, P: smt.QuantifierRef) -> Proof:
543543
return axiom(smt.Implies(smt.And(hyps), conc), by="induction_axiom_schema")
544544

545545

546-
def Inductive(name: str) -> smt.Datatype:
547-
"""
548-
Declare datatypes with auto generated induction principles. Wrapper around z3.Datatype
549-
550-
>>> Nat = Inductive("Nat")
551-
>>> Nat.declare("zero")
552-
>>> Nat.declare("succ", ("pred", Nat))
553-
>>> Nat = Nat.create()
554-
>>> Nat.succ(Nat.zero)
555-
succ(zero)
556-
"""
557-
counter = 0
558-
n = name
559-
while n in _datatypes:
560-
counter += 1
561-
n = name + "!" + str(counter)
562-
name = n
563-
assert name not in _datatypes
564-
dt = smt.Datatype(name)
565-
oldcreate = dt.create
566-
567-
def create():
568-
dt = oldcreate()
546+
class _InductiveDatatype(smt.Datatype):
547+
"""
548+
Subclass of z3.Datatype that auto-generates induction principles.
549+
This replaces the previous approach of monkey-patching the create method.
550+
"""
551+
552+
def __init__(self, name: str):
553+
"""Initialize an inductive datatype with a unique name."""
554+
counter = 0
555+
n = name
556+
while n in _datatypes:
557+
counter += 1
558+
n = name + "!" + str(counter)
559+
self._unique_name = n
560+
assert self._unique_name not in _datatypes
561+
super().__init__(self._unique_name)
562+
563+
def create(self):
564+
"""Create the datatype with additional validation and registration."""
565+
dt = super().create()
569566
# Sanity check no duplicate names. Causes confusion.
570567
names = set()
571568
for i in range(dt.num_constructors()):
@@ -580,12 +577,23 @@ def create():
580577
raise Exception("Duplicate field name", n)
581578
names.add(n)
582579
kd.notation.induct.register(dt, induct_inductive)
583-
_datatypes[name] = dt
580+
_datatypes[self._unique_name] = dt
584581
smt.sort_registry[dt.get_id()] = dt
585582
return dt
586583

587-
dt.create = create
588-
return dt
584+
585+
def Inductive(name: str) -> _InductiveDatatype:
586+
"""
587+
Declare datatypes with auto generated induction principles. Wrapper around z3.Datatype
588+
589+
>>> Nat = Inductive("Nat")
590+
>>> Nat.declare("zero")
591+
>>> Nat.declare("succ", ("pred", Nat))
592+
>>> Nat = Nat.create()
593+
>>> Nat.succ(Nat.zero)
594+
succ(zero)
595+
"""
596+
return _InductiveDatatype(name)
589597

590598

591599
_overapproximate_fresh_ids = set()

0 commit comments

Comments
 (0)