Skip to content

Commit 8a348f8

Browse files
committed
cauchy schwartz for vec3. andE tactic. speiclaize doesn't keep by default. More tactics return formulas
1 parent 649a72f commit 8a348f8

File tree

11 files changed

+433
-93
lines changed

11 files changed

+433
-93
lines changed

src/kdrag/solvers/synth/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,11 @@ def cegis_simple(spec: smt.ExprRef) -> Optional[list[smt.ExprRef]]:
4848
raise Exception("Unknown result from solver", res)
4949

5050

51+
"""
52+
cegis with solve definitions exposed to verifier vs
53+
54+
"""
55+
5156
"""
5257
Top down
5358
Bottom up / Contextual compression

src/kdrag/tactics.py

Lines changed: 70 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -844,6 +844,33 @@ def boolsimp(self) -> "ProofState":
844844
self.goals[-1] = goalctx._replace(ctx=newctx, goal=newgoal)
845845
return self
846846

847+
def forward(self, n: int | smt.BoolRef) -> smt.BoolRef:
848+
"""
849+
Remove the hypothesis of an implication in the context
850+
851+
>>> p,q,r = smt.Bools("p q r")
852+
>>> l = Lemma(smt.Implies(smt.And(p, smt.Implies(p, q)), r))
853+
>>> _ = l.intros()
854+
>>> l.split(at=0)
855+
[p, Implies(p, q)] ?|= r
856+
>>> l.forward(1)
857+
q
858+
>>> l
859+
[p, q] ?|= r
860+
"""
861+
# TODO: extra by paramters? Give an exact index to use modus ponens?
862+
goalctx = self.top_goal()
863+
(at, formula) = goalctx.ctx_find(n)
864+
if smt.is_implies(formula):
865+
hyp, conc = formula.children()
866+
self.have(hyp, by=[])
867+
self.goals[-1] = goalctx._replace(
868+
ctx=goalctx.ctx[:at] + [conc] + goalctx.ctx[at + 1 :]
869+
)
870+
return conc
871+
else:
872+
raise ValueError("forward failed. Not an implication", formula)
873+
847874
def emt(self):
848875
"""
849876
Use egraph based equality modulo theories to simplify the goal.
@@ -1059,17 +1086,7 @@ def obtain(self, n: int | smt.QuantifierRef) -> smt.ExprRef | list[smt.ExprRef]:
10591086
"""
10601087
goalctx = self.top_goal()
10611088
ctx, goal = goalctx.ctx, goalctx.goal
1062-
if isinstance(n, smt.QuantifierRef):
1063-
for i, f in enumerate(ctx):
1064-
if f.eq(n):
1065-
n = i
1066-
break
1067-
else:
1068-
raise ValueError("obtain failed. Formula not in context", n)
1069-
assert isinstance(n, int)
1070-
if n < 0:
1071-
n = len(ctx) + n
1072-
formula = ctx[n]
1089+
n, formula = goalctx.ctx_find(n)
10731090
if isinstance(formula, smt.QuantifierRef) and formula.is_exists():
10741091
self.pop_goal()
10751092
fs, obtain_lemma = kd.kernel.obtain(formula)
@@ -1123,33 +1140,36 @@ def contract(self) -> "ProofState":
11231140
self.goals[-1] = self.goalctx._replace(ctx=newctx)
11241141
return self
11251142

1126-
def specialize(self, n: int | smt.QuantifierRef, *ts):
1143+
def specialize(self, n: int | smt.QuantifierRef, *ts, keep=False) -> smt.BoolRef:
11271144
"""
11281145
Instantiate a universal quantifier in the context.
11291146
11301147
>>> x,y = smt.Ints("x y")
11311148
>>> l = Lemma(smt.Implies(smt.ForAll([x],x == y), True))
1132-
>>> l.intros()
1149+
>>> hyp = l.intros()
1150+
>>> hyp
11331151
ForAll(x, x == y)
11341152
>>> l
11351153
[ForAll(x, x == y)] ?|= True
1136-
>>> l.specialize(0, smt.IntVal(42))
1137-
[ForAll(x, x == y), 42 == y] ?|= True
1154+
>>> l.specialize(hyp, smt.IntVal(42))
1155+
42 == y
1156+
>>> l
1157+
[42 == y] ?|= True
11381158
"""
11391159
goalctx = self.top_goal()
1140-
if isinstance(n, smt.QuantifierRef):
1141-
for i, f in enumerate(goalctx.ctx):
1142-
if f.eq(n):
1143-
n = i
1144-
break
1145-
else:
1146-
raise ValueError("Specialize failed. Formula not in context", n)
1147-
thm = goalctx.ctx[n]
1160+
(n, thm) = goalctx.ctx_find(n)
11481161
if isinstance(thm, smt.QuantifierRef) and thm.is_forall():
11491162
l = kd.kernel.specialize(ts, thm)
11501163
self.add_lemma(l)
1151-
self.goals[-1] = goalctx._replace(ctx=goalctx.ctx + [l.thm.arg(1)])
1152-
return self
1164+
# kernel.specialize returns Implies(forall x, P, P[t/x])
1165+
newformula = l.thm.arg(1)
1166+
if keep:
1167+
self.goals[-1] = goalctx._replace(ctx=goalctx.ctx + [newformula])
1168+
else:
1169+
self.goals[-1] = goalctx._replace(
1170+
ctx=goalctx.ctx[:n] + [newformula] + goalctx.ctx[n + 1 :]
1171+
)
1172+
return newformula
11531173
else:
11541174
foralls = {
11551175
n: formula
@@ -1184,6 +1204,28 @@ def ext(self, at=None):
11841204
else:
11851205
raise ValueError("Ext failed. Target is not an equality", target)
11861206

1207+
def andE(self, n: int | smt.BoolRef) -> list[smt.BoolRef]:
1208+
"""
1209+
Eliminate an `And` in the context.
1210+
1211+
>>> p,q = smt.Bools("p q")
1212+
>>> l = Lemma(smt.Implies(smt.And(p, q), p))
1213+
>>> _ = l.intros()
1214+
>>> p,q = l.andE(0)
1215+
>>> l
1216+
[p, q] ?|= p
1217+
"""
1218+
goalctx = self.top_goal()
1219+
(at, formula) = goalctx.ctx_find(n)
1220+
if smt.is_and(formula):
1221+
children = formula.children()
1222+
self.goals[-1] = goalctx._replace(
1223+
ctx=goalctx.ctx[:at] + children + goalctx.ctx[at + 1 :]
1224+
)
1225+
return children
1226+
else:
1227+
raise ValueError("andE failed. Not an and", formula)
1228+
11871229
def split(self, at=None) -> "ProofState":
11881230
"""
11891231
`split` breaks apart an `And` or bi-implication `==` goal.
@@ -1243,21 +1285,22 @@ def cb():
12431285
else:
12441286
(at, hyp) = goalctx.ctx_find(at)
12451287
if smt.is_or(hyp):
1288+
# Make N new goals for each disjunct
12461289
self.pop_goal()
12471290
for c in hyp.children():
12481291
self.goals.append(
12491292
goalctx._replace(ctx=ctx[:at] + [c] + ctx[at + 1 :], goal=goal)
12501293
)
1251-
elif smt.is_and(hyp):
1294+
elif smt.is_and(hyp): # TODO: phase this out in favor of andE.
12521295
self.pop_goal()
12531296
self.goals.append(
12541297
goalctx._replace(
12551298
ctx=ctx[:at] + ctx[at].children() + ctx[at + 1 :], goal=goal
12561299
)
12571300
)
1301+
return self
12581302
else:
12591303
raise ValueError("Split failed on", ctx[at], "in context", ctx)
1260-
return self
12611304

12621305
def left(self, n=0):
12631306
"""

src/kdrag/theories/real/__init__.py

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,7 @@ def abstract_arith(t: smt.ExprRef) -> smt.ExprRef:
124124
abs = kd.define("absR", [x], smt.If(x >= 0, x, -x))
125125
sgn = kd.define("sgn", [x], smt.If(x > 0, 1, smt.If(x < 0, -1, 0)))
126126

127+
abs_if_pos = kd.prove(ForAll([x], x >= 0, abs(x) == x), by=[abs.defn])
127128
sgn_abs = kd.prove(ForAll([x], abs(x) * sgn(x) == x), by=[abs.defn, sgn.defn])
128129
abs_le = kd.prove(
129130
ForAll([x, y], (abs(x) <= y) == smt.And(-y <= x, x <= y)), by=[abs.defn]
@@ -262,9 +263,13 @@ def abstract_arith(t: smt.ExprRef) -> smt.ExprRef:
262263
)
263264

264265
sqr = kd.define("sqr", [x], x * x)
266+
sqr_pos = kd.prove(
267+
ForAll([x], sqr(x) >= 0),
268+
by=[sqr.defn],
269+
)
270+
sqr_neg = kd.prove(smt.ForAll([x], sqr(-x) == sqr(x)), by=[sqr.defn])
265271

266-
267-
sqrt = kd.define("sqrt", [x], x**0.5)
272+
sqrt = kd.define("sqrt", [x], x ** "1/2")
268273

269274
_l = kd.Lemma(kd.QForAll([x], x >= 0, sqrt(x) >= 0))
270275
_ = _l.fix()
@@ -274,6 +279,8 @@ def abstract_arith(t: smt.ExprRef) -> smt.ExprRef:
274279

275280
sqrt_define = kd.prove(smt.ForAll([x], sqrt(x) == x**0.5), by=[sqrt.defn, pow.defn])
276281

282+
sqrt_one = kd.prove(sqrt(1) == 1, by=[sqrt.defn])
283+
sqrt_zero = kd.prove(sqrt(0) == 0, by=[sqrt.defn])
277284
_l = kd.Lemma(kd.QForAll([x], x >= 0, sqrt(x) ** 2 == x))
278285
_ = _l.fix()
279286
_l.unfold()
@@ -284,12 +291,31 @@ def abstract_arith(t: smt.ExprRef) -> smt.ExprRef:
284291
kd.QForAll([x], x >= 0, sqr(sqrt(x)) == x), by=[sqrt_square, sqr.defn]
285292
)
286293

294+
295+
sqrt_mul = kd.prove(
296+
kd.QForAll([x, y], x >= 0, y >= 0, sqrt(x * y) == sqrt(x) * sqrt(y)),
297+
by=[sqrt.defn, mul.defn],
298+
)
299+
sqrt_mono = kd.prove(
300+
kd.QForAll([x, y], x >= 0, y >= 0, x <= y, sqrt(x) <= sqrt(y)),
301+
by=[sqrt.defn],
302+
)
303+
287304
_l = kd.Lemma(kd.QForAll([x], x >= 0, sqrt(sqr(x)) == x))
288305
_ = _l.fix()
289306
_l.unfold()
290307
_l.auto()
291308
sqrt_sqr = _l.qed()
292309

310+
sqrt_sqr_neg = kd.prove(
311+
kd.QForAll([x], x <= 0, sqrt(sqr(x)) == -x), by=[sqrt_sqr, sqr_neg]
312+
)
313+
314+
sqrt_sqr_abs = kd.prove(
315+
kd.QForAll([x], sqrt(sqr(x)) == abs(x)),
316+
by=[abs.defn, sqrt_sqr, sqrt_sqr_neg, sqr_pos],
317+
)
318+
293319
exp = smt.Function("exp", R, R) # smt.Const("exp", kd.R >> kd.R)
294320
exp_add = kd.axiom(smt.ForAll([x, y], exp(x + y) == exp(x) * exp(y)))
295321
exp_lower = kd.axiom(

src/kdrag/theories/real/complex.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,11 @@
1010
conj = kd.define("conj", [z], C.C(z.re, -z.im))
1111

1212
C0 = C.C(0, 0)
13+
zero = C0
1314
C1 = C.C(1, 0)
15+
one = C1
1416
Ci = C.C(0, 1)
17+
j = Ci
1518

1619
add_zero = kd.prove(smt.ForAll([z], z + C0 == z), by=[add.defn])
1720
mul_zero = kd.prove(smt.ForAll([z], z * C0 == C0), by=[mul.defn])
@@ -21,6 +24,9 @@
2124
smt.ForAll([z, w, u], (z + (w + u)) == ((z + w) + u)), by=[add.defn]
2225
)
2326
mul_comm = kd.prove(smt.ForAll([z, w], z * w == w * z), by=[mul.defn])
27+
mul_assoc = kd.prove(
28+
smt.ForAll([z, w, u], (z * (w * u)) == ((z * w) * u)), by=[mul.defn]
29+
)
2430

2531
# unstable perfoamnce.
2632
# mul_div = kd.prove(ForAll([z,w], Implies(w != C0, z == z * w / w)), by=[div.defn, mul.defn], timeout=1000)

src/kdrag/theories/real/geometry.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,15 @@
22
import kdrag.smt as smt
33

44
# import kdrag.theories.real as real
5-
import kdrag.theories.real.vec as vec
5+
import kdrag.theories.real.vec2 as vec2
66
import kdrag.theories.set as set_
77

88

9-
Point2D = vec.Vec2
9+
Point2D = vec2.Vec2
1010
p, q, a, b, c = smt.Consts("p q a b c", Point2D)
1111

1212
r = smt.Real("r")
13-
circle = kd.define("circle", [c, r], smt.Lambda([p], vec.norm2(p - c) == r * r))
13+
circle = kd.define("circle", [c, r], smt.Lambda([p], vec2.norm2(p - c) == r * r))
1414

1515
Shape = set_.Set(Point2D)
1616

src/kdrag/theories/real/lim_algebra.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def has_lim_smul(l):
6262
n = l.fix()
6363
l.assumes(n > N1)
6464
forall_n = kd.QForAll([n_var], n_var > N1, real.abs(a[n_var] - x) < eps1)
65-
l.specialize(forall_n, n)
65+
l.specialize(forall_n, n, keep=True)
6666
l.have(real.abs(a[n] - x) < eps1, by=[])
6767
l.unfold(seq.smul)
6868
l.simp()
@@ -167,10 +167,10 @@ def has_lim_mul(l):
167167
l.unfold(seq.has_lim, at=0)
168168
l.unfold(seq.has_lim, at=1)
169169
l.specialize(0, eps1)
170-
l.specialize(1, eps2)
170+
l.specialize(1, eps2, keep=True)
171171

172172
l.have(smt.RealVal(1) > 0, by=[])
173-
l.specialize(1, smt.RealVal(1))
173+
l.specialize(1, smt.RealVal(1), keep=True)
174174

175175
N1 = smt.Int("N")
176176
N2 = smt.Int("N")

src/kdrag/theories/real/seq.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,8 +171,12 @@ def cumsum_const(l):
171171
# cumsum pownat = 1 - x^(n + 1) / (1 - x)
172172
# cumsum_diff - the fundamental theorem of calculus for sequences
173173

174+
"""
174175
175176
# TODO: cumsum_comm = cumsum(lambda x, cumsum(lammbda y, a[x,y]) ) ???
177+
178+
179+
# TODO: unstable
176180
@kd.Theorem(
177181
"forall (a : RSeq) (x : Real) (n : Int), cumsum (smul x a) n = smul x (cumsum a) n"
178182
)
@@ -200,6 +204,7 @@ def cumsum_smul(l):
200204
l.unfold(cumsum, smul)
201205
l.simp()
202206
l.auto(by=[smul.defn, cumsum.defn])
207+
"""
203208

204209

205210
@kd.Theorem(

0 commit comments

Comments
 (0)