Skip to content

Commit 446622c

Browse files
committed
lambda lift for solvers
1 parent cb07678 commit 446622c

File tree

4 files changed

+105
-24
lines changed

4 files changed

+105
-24
lines changed

src/kdrag/solvers/__init__.py

Lines changed: 20 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -126,15 +126,21 @@ class BaseSolver:
126126
def __init__(self):
127127
self.adds = []
128128
self.assert_tracks = []
129-
self.options = {}
129+
self.options: dict = {"lambda_lift": True}
130130
self.res: Optional[subprocess.CompletedProcess] = None
131131

132132
def add(self, thm: smt.BoolRef):
133133
if isinstance(thm, list):
134-
self.adds.extend(thm)
135-
return
134+
for t in thm:
135+
self.add(t)
136136
else:
137137
assert isinstance(thm, smt.BoolRef)
138+
if self.options["lambda_lift"]:
139+
thm1 = kd.utils.curry_arrays(thm)
140+
thm1, extra = kd.utils.lambda_lift(thm1)
141+
assert isinstance(thm1, smt.BoolRef)
142+
thm = thm1
143+
self.adds.extend(extra)
138144
self.adds.append(thm)
139145

140146
def assert_and_track(self, thm: smt.BoolRef, name: str):
@@ -200,15 +206,15 @@ def write_tptp(self, filename, format="thf"):
200206
if f not in predefined and f.name() not in predefined_names:
201207
if f.arity() == 0:
202208
fp.write(
203-
f"{format}({f.name()}_type, type, {f.name() if f in no_mangle else mangle_decl(f)} : {sort_to_tptp(f.range())} ).\n"
209+
f"{format}({f.name().replace('!', '__')}_type, type, {f.name().replace('!', '__') if f in no_mangle else mangle_decl(f)} : {sort_to_tptp(f.range())} ).\n"
204210
)
205211
else:
206212
joiner = " > " if format == "thf" else " * "
207213
dom_tptp = joiner.join(
208214
[sort_to_tptp(f.domain(i)) for i in range(f.arity())]
209215
)
210216
fp.write(
211-
f"{format}({f.name()}_decl, type, {f.name() if f in no_mangle else mangle_decl(f)} : {dom_tptp} > {sort_to_tptp(f.range())}).\n"
217+
f"{format}({f.name().replace('!', '__')}_decl, type, {f.name().replace('!', '__') if f in no_mangle else mangle_decl(f)} : {dom_tptp} > {sort_to_tptp(f.range())}).\n"
212218
)
213219

214220
# Write axioms and assertions in TPTP THF format
@@ -240,11 +246,10 @@ def check_tptp_status(self, res):
240246

241247
def write_smt(self, fp):
242248
fp.write("(set-logic ALL)\n")
249+
thms = self.adds + [thm for thm, name in self.assert_tracks]
243250
# Gather up all datatypes referenced
244251
predefined = set()
245-
for sort in collect_sorts(
246-
self.adds + [thm for thm, name in self.assert_tracks]
247-
):
252+
for sort in collect_sorts(thms):
248253
if isinstance(sort, smt.DatatypeSortRef):
249254
fp.write(smtlib_datatypes([sort]))
250255
for i in range(sort.num_constructors()):
@@ -258,16 +263,17 @@ def write_smt(self, fp):
258263
fp.write(f"(declare-sort {sort.name()} 0)\n")
259264
# Declare all function symbols
260265
fp.write(";;declarations\n")
261-
for f in collect_decls(self.adds + [thm for thm, name in self.assert_tracks]):
266+
for f in collect_decls(thms):
262267
if f not in predefined and f.name() not in predefined_names:
263268
fp.write(funcdecl_smtlib(f))
264269
fp.write("\n")
265270
fp.write(";;axioms\n")
266-
for e in self.adds:
271+
# TODO: Add back assert tracking
272+
for e in thms: # self.adds:
267273
# We can't use e.sexpr() because we need to mangle overloaded names
268274
fp.write("(assert " + expr_to_smtlib(e) + ")\n")
269-
for thm, name in self.assert_tracks:
270-
fp.write("(assert (! " + expr_to_smtlib(thm) + " :named " + name + "))\n")
275+
# for thm, name in self.assert_tracks:
276+
# fp.write("(assert (! " + expr_to_smtlib(thm) + " :named " + name + "))\n")
271277

272278

273279
def collect_sorts(exprs) -> set[smt.SortRef]:
@@ -454,13 +460,12 @@ def check(self):
454460
return self.check_tptp_status(self.res.stdout)
455461

456462

457-
"""
458463
class EProverSolver(BaseSolver):
459464
def __init__(self):
460-
new = download(
465+
download(
461466
"https://github.com/philzook58/eprover/releases/download/E.3.2.5-ho/eprover-ho",
462467
"eprover-ho",
463-
"56a1dd51be0ba3194851cfb6f4ecc563c82cd5e2f5009dd1a7268af91150c9cd",
468+
"840170d1cb80cc3796b0f209e5879d23d1a19577fd922d84029c30783687b2c6",
464469
)
465470
super().__init__()
466471
self.options["format"] = "tff"
@@ -479,7 +484,6 @@ def check(self):
479484
if len(self.res.stderr) > 0:
480485
raise Exception("Eprover error", self.res.stderr)
481486
return self.check_tptp_status(self.res.stdout)
482-
"""
483487

484488

485489
class EProverTHFSolver(BaseSolver):

src/kdrag/solvers/sat.py

Whitespace-only changes.

src/kdrag/utils.py

Lines changed: 57 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -481,12 +481,12 @@ def lambda_lift(expr: smt.ExprRef) -> tuple[smt.ExprRef, list[smt.BoolRef]]:
481481
>>> lambda_lift(smt.ForAll([x,y], smt.Exists([z], x + y + 1 == smt.Lambda([z], x)[z])))
482482
(ForAll([x, y], Exists(z, x + y + 1 == f!...(x, y)[z])), [ForAll([x, y, z], f!...(x, y)[z] == x)])
483483
"""
484-
lift_defs = []
484+
lift_defs: list[smt.BoolRef] = []
485485

486486
def worker(expr, env):
487487
if isinstance(expr, smt.QuantifierRef):
488488
vs, body = kd.utils.open_binder_unhygienic(expr)
489-
env = [v for v in env if v not in vs] # shadowing
489+
env = [v for v in env if not any(v.eq(v1) for v1 in vs)] # shadowing
490490
env1 = env + vs
491491
if expr.is_lambda():
492492
f = smt.FreshFunction(*[v.sort() for v in env], expr.sort())
@@ -496,18 +496,71 @@ def worker(expr, env):
496496
return smt.ForAll(vs, worker(body, env1))
497497
elif expr.is_exists():
498498
return smt.Exists(vs, worker(body, env1))
499+
else:
500+
raise NotImplementedError("Unknown quantifier type", expr)
499501
elif smt.is_const(expr):
500502
return expr
501-
if smt.is_app(expr):
503+
elif smt.is_app(expr):
502504
args = [worker(arg, env) for arg in expr.children()]
503505
return expr.decl()(*args)
504506
else:
505-
return expr
507+
raise Exception("Unexpected term in lambda_lift", expr)
506508

507509
env = []
508510
return worker(expr, env), lift_defs
509511

510512

513+
def curry_arrays(e: smt.ExprRef) -> smt.ExprRef:
514+
"""
515+
Curry all selects and lambdas into single argument versions.
516+
>>> f = smt.Array("f", smt.IntSort(), smt.RealSort(), smt.BoolSort())
517+
>>> curry_arrays(smt.Select(f, smt.IntVal(3), smt.RealVal(2.0)))
518+
f[3][2]
519+
>>> x,y,z = smt.Ints("x y z")
520+
>>> curry_arrays(smt.Lambda([x,y], x + y)[2,3])
521+
Lambda(x, Lambda(y, x + y))[2][3]
522+
"""
523+
# TODO Possibility of clashing names here.
524+
# If you have the same name for the curried version you're nuts though.
525+
if isinstance(e, smt.QuantifierRef):
526+
vs, body = open_binder_unhygienic(e)
527+
body = curry_arrays(body)
528+
if e.is_lambda():
529+
for v in reversed(vs):
530+
body = smt.Lambda([v], body)
531+
return body
532+
elif e.is_forall():
533+
return smt.ForAll(vs, body)
534+
elif e.is_exists():
535+
return smt.Exists(vs, body)
536+
else:
537+
raise NotImplementedError("Unknown quantifier type", e)
538+
elif smt.is_const(e):
539+
if isinstance(e, smt.ArrayRef):
540+
doms = smt.domains(e)
541+
if len(doms) == 1:
542+
return e
543+
else:
544+
sort = e.range()
545+
for d in reversed(doms):
546+
sort = smt.ArraySort(d, sort)
547+
return smt.Const(e.decl().name(), sort)
548+
else:
549+
return e
550+
elif smt.is_app(e):
551+
f, children = e.decl(), e.children()
552+
children = [curry_arrays(c) for c in children]
553+
if smt.is_select(e):
554+
arr = children[0]
555+
for index in children[1:]:
556+
arr = smt.Select(arr, index)
557+
return arr
558+
else:
559+
return f(*children)
560+
else:
561+
raise Exception("Unexpected term in curry_arrays", e)
562+
563+
511564
def generate(sort: smt.SortRef, pred=None) -> Generator[smt.ExprRef, None, None]:
512565
"""
513566
A generator of values for a sort. Repeatedly calls z3 to get a new value.

tests/test_solver.py

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
TweeSolver,
99
SATSolver,
1010
LeanSolver,
11+
EProverSolver,
12+
VampireSolver
1113
)
1214
import kdrag.solvers as solvers
1315
import kdrag.smt as smt
@@ -19,6 +21,15 @@
1921
import kdrag.solvers.kb as kb
2022
import kdrag.rewrite as rw
2123

24+
@pytest.mark.slow
25+
def test_eprover():
26+
s = EProverSolver()
27+
a,b,c = smt.Bools("a b c")
28+
s.add(smt.And(a,b))
29+
assert s.check() == smt.sat
30+
s.add(smt.Not(a))
31+
assert s.check() == smt.unsat
32+
2233
@pytest.mark.slow
2334
def test_vampirethf():
2435
s = VampireTHFSolver()
@@ -46,9 +57,11 @@ def test_vampirethf():
4657
s.add(f == g)
4758
s.add(smt.Not(smt.ForAll([x], (f(x) == g(x)))))
4859
assert s.check() == smt.unsat
49-
60+
"""
61+
62+
"""
5063
x, y, z = smt.Reals("x y z")
51-
s = VampireTHFSolver()
64+
s = VampireSolver()
5265
s.add(x + y == z)
5366
s.add(x == y)
5467
assert s.check() == smt.sat
@@ -59,12 +72,13 @@ def test_vampirethf():
5972
x, y, z = smt.Consts("x y z", S)
6073
f, g = smt.Consts("f g", S >> T)
6174

62-
"""
75+
6376
s = VampireTHFSolver()
6477
s.add(f == g)
6578
s.add(smt.Not(smt.ForAll([x], (f(x) == g(x)))))
6679
assert s.check() == smt.unsat
67-
"""
80+
81+
6882
s = EProverTHFSolver()
6983
s.add(f == g)
7084
s.add(smt.Not(smt.ForAll([x], (f(x) == g(x)))))
@@ -112,13 +126,23 @@ def test_vampirethf():
112126
f = smt.Const("f", smt.BoolSort() >> smt.BoolSort())
113127
p, q = smt.Bools("p q")
114128

129+
s = EProverTHFSolver()
130+
s.set("lambda_lift", False)
131+
s.add(smt.Not(smt.ForAll(
132+
[p],
133+
p == (smt.Lambda([f], f(p)) == smt.Lambda([f], f(smt.BoolVal(True)))),
134+
)))
135+
assert s.check() == smt.unsat
136+
"""
137+
# lambda lifting makes this come back unknown
115138
kd.kernel.prove(
116139
smt.ForAll(
117140
[p],
118141
p == (smt.Lambda([f], f(p)) == smt.Lambda([f], f(smt.BoolVal(True)))),
119142
),
120143
solver=EProverTHFSolver,
121144
)
145+
"""
122146

123147
s = TweeSolver()
124148
x, y, z = smt.Consts("x y z", S)

0 commit comments

Comments
 (0)