Skip to content

Commit fdce496

Browse files
committed
kleene, vprove, abs_sum, more ir
1 parent 9613842 commit fdce496

File tree

7 files changed

+252
-4
lines changed

7 files changed

+252
-4
lines changed

src/kdrag/contrib/ir/__init__.py

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from dataclasses import dataclass, field
1919
import kdrag as kd
2020
import kdrag.smt as smt
21+
from collections import defaultdict
2122

2223

2324
def pp_sort(s: smt.SortRef) -> str:
@@ -100,13 +101,44 @@ def __repr__(self) -> str:
100101
res.append(f"%{i} = {rhs}")
101102
return "\n".join(res)
102103

104+
def succ_calls(self) -> list[smt.ExprRef]:
105+
jmp = self.insns[-1]
106+
if smt.is_if(jmp):
107+
return jmp.children()
108+
else:
109+
return [jmp]
110+
111+
112+
type Label = str
113+
103114

104115
@dataclass
105116
class Function:
106117
""" """
107118

108-
entry: str # smt.FuncDeclRef?
109-
blocks: dict[str, Block] # 0th block is entry. Or "entry" is entry? Naw. 0th.
119+
entry: Label # smt.FuncDeclRef?
120+
blocks: dict[Label, Block] # 0th block is entry. Or "entry" is entry? Naw. 0th.
121+
122+
def calls_of(self) -> dict[str, list[tuple[Label, smt.ExprRef]]]:
123+
"""
124+
Returns a mapping from labels to a list of calls to that label
125+
"""
126+
p = defaultdict(list)
127+
for label, blk in self.blocks.items():
128+
for call in blk.succ_calls():
129+
p[call.decl().name()].append((label, call))
130+
return p
131+
132+
def phis(self):
133+
"""
134+
Return the analog a mapping from labels to phi nodes in that block
135+
"""
136+
137+
preds = self.calls_of()
138+
phis = {}
139+
for label, blk in self.blocks.items():
140+
phis[label] = zip(*[call.children() for _, call in preds[label]])
141+
return phis
110142

111143

112144
@dataclass

src/kdrag/solvers/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -343,8 +343,8 @@ def check(self):
343343
cmd = [
344344
self.options["solver_path"],
345345
fp.name,
346-
"--mode",
347-
"casc",
346+
# "--mode", # This adds portfolio mode, but it was slower. Maybe more useful for hard questions?
347+
# "casc",
348348
"--input_syntax",
349349
"smtlib2",
350350
# "--ignore_unrecognized_logic", "on",

src/kdrag/tactics.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -324,6 +324,17 @@ def prove(
324324
raise e
325325

326326

327+
def vprove(thm: smt.BoolRef, **kwargs) -> kd.kernel.Proof:
328+
"""
329+
Prove a theorem using an egraph-based solver.
330+
331+
>>> x = smt.Int("x")
332+
>>> vprove(x + 1 > x)
333+
|= x + 1 > x
334+
"""
335+
return prove(thm, solver=solvers.VampireSolver, **kwargs)
336+
337+
327338
def simp(t: smt.ExprRef, by: list[kd.kernel.Proof] = [], **kwargs) -> kd.kernel.Proof:
328339
rules = [kd.rewrite.rewrite_of_expr(lem.thm) for lem in by]
329340
t1 = kd.rewrite.rewrite_once(t, rules)
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
"""
2+
Kleene Algebra
3+
4+
5+
- https://www.philipzucker.com/bryzzowski_kat/
6+
- https://www.cs.uoregon.edu/research/summerschool/summer24/topics.php#Silva
7+
8+
"""
9+
10+
import kdrag as kd
11+
import kdrag.smt as smt
12+
13+
K = smt.DeclareSort("Kleene")
14+
zero = smt.Const("_0", K)
15+
one = smt.Const("_1", K)
16+
x, y, z, w, e, f = smt.Consts("x y z w e f", K)
17+
par = smt.Function("par", K, K, K)
18+
seq = smt.Function("seq", K, K, K)
19+
star = smt.Function("star", K, K)
20+
kd.notation.or_.register(K, par)
21+
kd.notation.add.register(K, par)
22+
kd.notation.matmul.register(K, seq)
23+
kd.notation.mul.register(K, seq)
24+
kd.notation.le.register(K, lambda x, y: smt.Eq(par(x, y), y))
25+
26+
27+
par_assoc = kd.axiom(smt.ForAll([x, y, z], x + (y + z) == (x + y) + z))
28+
par_comm = kd.axiom(smt.ForAll([x, y], x + y == y + x))
29+
par_idem = kd.axiom(smt.ForAll([x], x + x == x))
30+
par_zero = kd.axiom(smt.ForAll([x], x + zero == x))
31+
32+
zero_par = kd.prove(smt.ForAll([x], zero + x == x), by=[par_comm, par_zero])
33+
34+
seq_assoc = kd.axiom(smt.ForAll([x, y, z], x * (y * z) == (x * y) * z))
35+
seq_zero = kd.axiom(smt.ForAll([x], x * zero == zero))
36+
seq_one = kd.axiom(smt.ForAll([x], x * one == x))
37+
38+
rdistrib = kd.axiom(smt.ForAll([x, y, z], x * (y + z) == x * y + x * z))
39+
ldistrib = kd.axiom(smt.ForAll([x, y, z], (y + z) * x == y * x + z * x))
40+
41+
unfold = kd.axiom(smt.ForAll([e], star(e) == one + e * star(e)))
42+
43+
# If a set shrinks, star(e) is less than it
44+
45+
least_fix = kd.axiom(smt.ForAll([x, e, f], f + e * x <= x, star(e) * f <= x))
46+
47+
KLEENE = [
48+
par_assoc,
49+
par_comm,
50+
par_idem,
51+
par_zero,
52+
seq_assoc,
53+
seq_zero,
54+
seq_one,
55+
rdistrib,
56+
ldistrib,
57+
unfold,
58+
least_fix,
59+
]
60+
par_monotone = kd.prove(
61+
smt.ForAll([x, y, z, w], x <= y, z <= w, x + z <= y + w),
62+
by=[par_assoc, par_comm],
63+
)
64+
65+
seq_monotone = kd.prove(
66+
smt.ForAll([x, y, z, w], x <= y, z <= w, x * z <= y * w),
67+
by=[rdistrib, ldistrib, par_assoc],
68+
)
69+
70+
star_seq_star = kd.tactics.vprove(
71+
smt.ForAll([x], star(x) * star(x) == star(x)), by=KLEENE
72+
)
73+
# z3 takes 0.5 seconds. Vampire is actually faster despite all the overhead
74+
star_star = kd.tactics.vprove(smt.ForAll([x], star(star(x)) == star(x)), by=KLEENE)
75+
76+
Test = smt.DeclareSort("Test")
77+
guard = smt.Function("guard", Test, K, K)
78+
and_ = smt.Function("and", Test, Test, Test)
79+
or_ = smt.Function("or", Test, Test, Test)
80+
not_ = smt.Function("not", Test, Test)
81+
true = smt.Const("true", Test)
82+
false = smt.Const("false", Test)
83+
kd.notation.invert.register(Test, not_)
84+
85+
a, b, c, d = smt.Consts("a b c d", Test)
86+
87+
88+
def while_(t, c):
89+
return star(guard(t, c)) * guard(~t, one)

src/kdrag/theories/real/__init__.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,9 @@ def abstract_arith(t: smt.ExprRef) -> smt.ExprRef:
137137
ForAll([x, y], smt.Implies(y != 0, abs(x / y) == abs(x) / abs(y))),
138138
by=[abs.defn],
139139
)
140+
abs_Lipschitz = kd.prove(
141+
ForAll([x, y], abs(abs(x) - abs(y)) <= abs(x - y)), by=[abs.defn]
142+
)
140143
abs_triangle = kd.prove(
141144
ForAll([x, y, z], abs(x - y) <= abs(x - z) + abs(z - y)), by=[abs.defn]
142145
)
@@ -220,6 +223,10 @@ def abstract_arith(t: smt.ExprRef) -> smt.ExprRef:
220223
floor_le = kd.prove(ForAll([x], floor(x) <= x), by=[floor.defn])
221224
floor_gt = kd.prove(ForAll([x], x < floor(x) + 1), by=[floor.defn])
222225

226+
227+
ceil = kd.define("ceil", [x], -floor(-x))
228+
ceil_le = kd.prove(ForAll([x], x <= ceil(x)), by=[ceil.defn, floor_le])
229+
ceil_gt = kd.prove(ForAll([x], ceil(x) - 1 < x), by=[ceil.defn, floor_gt])
223230
# c = kd.Calc([n, x], smt.ToReal(n) <= x)
224231
# c.eq(n <= smt.ToInt(x))
225232
# c.eq(smt.ToReal(n) <= floor(x), by=[floor.defn])
@@ -237,6 +244,23 @@ def abstract_arith(t: smt.ExprRef) -> smt.ExprRef:
237244
kd.kernel.prove(smt.Implies(x > 0, x**0 == 1))
238245
# pow_pos = kd.prove(kd.QForAll([x, y], x > 0, pow(x, y) > 0), by=[pow.defn])
239246

247+
pownat = smt.Function("pownat", smt.RealSort(), smt.IntSort(), smt.RealSort())
248+
pownat = kd.define(
249+
"pownat",
250+
[x, n],
251+
kd.cond((n == 0, 1), (n > 0, x * pownat(x, n - 1)), default=pownat(x, n + 1) / x),
252+
)
253+
254+
# Basic lemmas for pownat.
255+
pownat_zero = kd.prove(
256+
smt.ForAll([x], pownat(x, 0) == 1),
257+
by=[pownat.defn],
258+
)
259+
pownat_succ = kd.prove(
260+
kd.QForAll([x, n], n >= 0, pownat(x, n + 1) == x * pownat(x, n)),
261+
by=[pownat.defn],
262+
)
263+
240264
sqr = kd.define("sqr", [x], x * x)
241265

242266

@@ -385,6 +409,29 @@ def abstract_arith(t: smt.ExprRef) -> smt.ExprRef:
385409
)
386410

387411

412+
# x, y = smt.Reals("x y")
413+
N = smt.Int("N")
414+
415+
416+
@kd.Theorem(
417+
smt.ForAll(
418+
[x, y],
419+
smt.Implies(
420+
smt.And(x > 0, y > 0),
421+
smt.Exists([N], y < x * N),
422+
),
423+
)
424+
)
425+
def archimedes(l):
426+
# https://en.wikipedia.org/wiki/Archimedean_property
427+
x, y = l.fixes()
428+
l.intros()
429+
430+
# Choose N = floor(y/x) + 1.
431+
l.exists(smt.ToInt(y / x) + 1)
432+
l.auto()
433+
434+
388435
# smt.Function("cont_at", RFun, R, smt.BoolSort())
389436

390437
is_diff = kd.define("is_diff", [f], smt.ForAll([x], diff_at(f, x)))

src/kdrag/theories/real/seq.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import kdrag as kd
33
import kdrag.smt as smt
44

5+
import kdrag.theories.int as int_
56
import kdrag.theories.real as real
67

78

@@ -209,6 +210,72 @@ def max_suffix_bound(l):
209210
cos = kd.define("cos", [a], smt.Map(real.cos, a))
210211

211212
abs = kd.define("abs", [a], smt.Map(real.abs, a))
213+
abs_ext = kd.prove(smt.ForAll([a, n], abs(a)[n] == real.abs(a[n])), by=[abs.defn])
214+
abs_ge_0 = kd.prove(
215+
kd.QForAll([a, n], abs(a)[n] >= 0),
216+
by=[abs_ext, real.abs_pos],
217+
)
218+
219+
220+
@kd.Theorem(
221+
kd.QForAll(
222+
[a, N],
223+
N >= 0,
224+
kd.QForAll([n], n >= 0, n <= N, abs(a)[n] <= finsum(abs(a), N)),
225+
)
226+
)
227+
def abs_le_finsum_abs(l):
228+
a, N = l.fixes()
229+
230+
l.induct(N, using=int_.induct_nat)
231+
# base case: N = 0
232+
l.intros()
233+
n = l.fix()
234+
l.intros()
235+
l.split(at=-1)
236+
l.have(n == 0, by=[])
237+
l.rw(-1)
238+
l.unfold(finsum)
239+
l.unfold(cumsum)
240+
l.simp()
241+
l.auto()
242+
243+
# step: N >= 0, IH -> N + 1
244+
N = l.fix()
245+
l.intros()
246+
l.intros()
247+
248+
n = l.fix()
249+
l.intros()
250+
l.split(at=0)
251+
252+
# Case split: n <= N or n = N + 1
253+
l.cases(n <= N)
254+
255+
# n <= N: use IH and monotonicity of finsum
256+
l.have(abs(a)[n] <= finsum(abs(a), N), by=[])
257+
l.have(N + 1 > 0, by=[])
258+
l.have(
259+
finsum(abs(a), N + 1) == finsum(abs(a), N) + abs(a)[N + 1],
260+
by=[finsum.defn, cumsum.defn],
261+
)
262+
l.have(abs(a)[N + 1] >= 0, by=[abs_ge_0])
263+
l.have(finsum(abs(a), N) <= finsum(abs(a), N + 1), by=[])
264+
l.auto()
265+
266+
# n > N: then n = N + 1
267+
l.have(n == N + 1, by=[])
268+
l.rw(-1)
269+
l.have(N + 1 > 0, by=[])
270+
l.have(
271+
finsum(abs(a), N + 1) == finsum(abs(a), N) + abs(a)[N + 1],
272+
by=[finsum.defn, cumsum.defn],
273+
)
274+
l.have(abs(a)[N + 1] >= 0, by=[abs_ge_0])
275+
l.have(abs(a)[0] <= finsum(abs(a), N), by=[])
276+
l.have(abs(a)[0] >= 0, by=[abs_ge_0])
277+
l.have(finsum(abs(a), N + 1) >= abs(a)[N + 1], by=[])
278+
l.auto()
212279

213280

214281
has_bound = kd.define("has_bound", [a, M], kd.QForAll([n], real.abs(a[n]) <= M))
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
import kdrag as kd
2+
import kdrag.smt as smt

0 commit comments

Comments
 (0)