Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 57 additions & 50 deletions src/kdrag/datatype.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,81 +432,88 @@ def Enum(name: str, args: str | list[str]) -> smt.DatatypeSortRef:
# smt.DatatypeRef.rel = lambda self, *args: self.rel(*args)


def InductiveRel(name: str, *params: smt.ExprRef) -> smt.Datatype:
"""Define an inductive type of evidence and a relation the recurses on that evidence

>>> x = smt.Int("x")
>>> Even = InductiveRel("Even", x)
>>> Even.declare("Ev_Z", pred = x == 0)
>>> Even.declare("Ev_SS", ("sub2_evidence", Even), pred = lambda evid: evid.rel(x-2))
>>> Even = Even.create()
>>> ev = smt.Const("ev", Even)
>>> ev.rel(4)
even(4)[ev]
>>> ev(4)
even(4)[ev]
>>> Even(4)
even(4)
class _InductiveRelDatatype(kd.kernel._InductiveDatatype):
"""
Subclass of _InductiveDatatype for inductive relations.
This replaces the previous approach of monkey-patching declare and create methods.
"""

dt = Inductive(name)

relname = name.lower()
olddeclare = dt.declare
preds = [] # tuck away extra predicate

def declare(
name, *args, pred=None
): # TODO: would it ever make sense to not have a pred?
olddeclare(name, *args)
preds.append(pred)

dt.declare = declare

oldcreate = dt.create

def create_relation(dt):
"""
When inductive is done being defined, call this function
"""
ev = smt.FreshConst(dt, prefix=name.lower())
rel = smt.Function(relname, dt, *[x.sort() for x in params], smt.BoolSort())
def __init__(self, name: str, *params: smt.ExprRef):
"""Initialize an inductive relation datatype."""
super().__init__(name)
self._params = params
self._relname = name.lower()
self._preds = [] # Store predicates for each constructor

def declare(self, name, *args, pred=None):
"""Override declare to also store predicates."""
super().declare(name, *args)
self._preds.append(pred)

def _create_relation(self, dt):
"""Create the relation definition for this inductive type."""
ev = smt.FreshConst(dt, prefix=self._relname)
rel = smt.Function(
self._relname, dt, *[x.sort() for x in self._params], smt.BoolSort()
)
cases = []
for i in range(dt.num_constructors()):
precond = dt.recognizer(i)(ev) # recognize case of the evidence
pred = preds[i] # In this case, this predicate should be true
pred = self._preds[i] # In this case, this predicate should be true
if pred is None:
res = smt.BoolVal(True)
elif isinstance(pred, smt.ExprRef):
res = pred
else:
args = [dt.accessor(i, j)(ev) for j in range(dt.constructor(i).arity())]
args = [
dt.accessor(i, j)(ev) for j in range(dt.constructor(i).arity())
]
res = pred(*args)
cases.append((precond, res))
rel = kd.define(relname, list(params), smt.Lambda([ev], kd.cond(*cases)))
rel = kd.define(
self._relname, list(self._params), smt.Lambda([ev], kd.cond(*cases))
)
return rel

def create():
dt = oldcreate()
if any(p is not None for p in preds):
def create(self):
"""Override create to also set up the relation."""
dt = super().create()
if any(p is not None for p in self._preds):
dtrel = smt.Function(
relname, *[x.sort() for x in params], smt.ArraySort(dt, smt.BoolSort())
self._relname,
*[x.sort() for x in self._params],
smt.ArraySort(dt, smt.BoolSort()),
)
kd.notation.call.register(dt, lambda self, *args: dtrel(*args)[self])
rel.register(
dt, lambda self, *args: dtrel(*args)[self]
) # doing this here let's us tie the knot inside of lambdas and refer to the predicate.
dtrel = create_relation(dt)
dtrel = self._create_relation(dt)
dt.rel = dtrel
# ev = smt.FreshConst(dt, prefix=name.lower())
call_dict[dt] = dtrel # lambda *args: smt.Lambda([ev], ev.rel(*args))
call_dict[dt] = dtrel

if len(params) == 0:
if len(self._params) == 0:
kd.notation.wf.register(dt, lambda x: x.rel())
return dt

dt.create = create
return dt

def InductiveRel(name: str, *params: smt.ExprRef) -> _InductiveRelDatatype:
"""Define an inductive type of evidence and a relation the recurses on that evidence

>>> x = smt.Int("x")
>>> Even = InductiveRel("Even", x)
>>> Even.declare("Ev_Z", pred = x == 0)
>>> Even.declare("Ev_SS", ("sub2_evidence", Even), pred = lambda evid: evid.rel(x-2))
>>> Even = Even.create()
>>> ev = smt.Const("ev", Even)
>>> ev.rel(4)
even(4)[ev]
>>> ev(4)
even(4)[ev]
>>> Even(4)
even(4)
"""
return _InductiveRelDatatype(name, *params)


def inj_lemmas(dt: smt.DatatypeSortRef) -> list[kd.kernel.Proof]:
Expand Down
60 changes: 34 additions & 26 deletions src/kdrag/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -543,29 +543,26 @@ def induct_inductive(x: smt.DatatypeRef, P: smt.QuantifierRef) -> Proof:
return axiom(smt.Implies(smt.And(hyps), conc), by="induction_axiom_schema")


def Inductive(name: str) -> smt.Datatype:
"""
Declare datatypes with auto generated induction principles. Wrapper around z3.Datatype

>>> Nat = Inductive("Nat")
>>> Nat.declare("zero")
>>> Nat.declare("succ", ("pred", Nat))
>>> Nat = Nat.create()
>>> Nat.succ(Nat.zero)
succ(zero)
"""
counter = 0
n = name
while n in _datatypes:
counter += 1
n = name + "!" + str(counter)
name = n
assert name not in _datatypes
dt = smt.Datatype(name)
oldcreate = dt.create

def create():
dt = oldcreate()
class _InductiveDatatype(smt.Datatype):
"""
Subclass of z3.Datatype that auto-generates induction principles.
This replaces the previous approach of monkey-patching the create method.
"""

def __init__(self, name: str):
"""Initialize an inductive datatype with a unique name."""
counter = 0
n = name
while n in _datatypes:
counter += 1
n = name + "!" + str(counter)
self._unique_name = n
assert self._unique_name not in _datatypes
super().__init__(self._unique_name)

def create(self):
"""Create the datatype with additional validation and registration."""
dt = super().create()
# Sanity check no duplicate names. Causes confusion.
names = set()
for i in range(dt.num_constructors()):
Expand All @@ -580,12 +577,23 @@ def create():
raise Exception("Duplicate field name", n)
names.add(n)
kd.notation.induct.register(dt, induct_inductive)
_datatypes[name] = dt
_datatypes[self._unique_name] = dt
smt.sort_registry[dt.get_id()] = dt
return dt

dt.create = create
return dt

def Inductive(name: str) -> _InductiveDatatype:
"""
Declare datatypes with auto generated induction principles. Wrapper around z3.Datatype

>>> Nat = Inductive("Nat")
>>> Nat.declare("zero")
>>> Nat.declare("succ", ("pred", Nat))
>>> Nat = Nat.create()
>>> Nat.succ(Nat.zero)
succ(zero)
"""
return _InductiveDatatype(name)


_overapproximate_fresh_ids = set()
Expand Down