Skip to content

Commit 3a6a962

Browse files
committed
moved complex up. Universal type as open datatype
1 parent 5e34b48 commit 3a6a962

File tree

5 files changed

+81
-143
lines changed

5 files changed

+81
-143
lines changed

src/kdrag/__init__.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,38 @@ def Some(x: smt.ExprRef) -> smt.DatatypeRef:
224224
return OptionSort(x.sort()).Some(x)
225225

226226

227+
Complex = datatype.Struct("C", ("re", smt.RealSort()), ("im", smt.RealSort()))
228+
229+
z, w, u, z1, z2 = smt.Consts("z w u z1 z2", Complex)
230+
complex_add = notation.add.define([z1, z2], Complex.C(z1.re + z2.re, z1.im + z2.im))
231+
complex_mul = notation.mul.define(
232+
[z1, z2], Complex.C(z1.re * z2.re - z1.im * z2.im, z1.re * z2.im + z1.im * z2.re)
233+
)
234+
complex_div = notation.div.define(
235+
[z1, z2],
236+
Complex.C(
237+
(z1.re * z2.re + z1.im * z2.im) / (z2.re**2 + z2.im**2),
238+
(z1.im * z2.re - z1.re * z2.im) / (z2.re**2 + z2.im**2),
239+
),
240+
)
241+
J = Complex.C(0, 1)
242+
complex_one = Complex.C(1, 0)
243+
244+
245+
def ComplexSort() -> smt.DatatypeSortRef:
246+
"""
247+
>>> C = ComplexSort()
248+
>>> z, w = smt.Consts("z w", C)
249+
>>> full_simp(J + J)
250+
C(0, 2)
251+
>>> full_simp(J * J)
252+
C(-1, 0)
253+
>>> full_simp(J / J)
254+
C(1, 0)
255+
"""
256+
return Complex
257+
258+
227259
def Assoc(f, T=None) -> smt.BoolRef:
228260
"""
229261
>>> Assoc(smt.Function("f", smt.IntSort(), smt.IntSort(), smt.IntSort()))

src/kdrag/contrib/pcode/asmspec.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -661,8 +661,18 @@ def __init__(self, spec: AsmSpec):
661661
def auto(self, n, **kwargs):
662662
vc = self.vcs[n]
663663
self.pfs.append(
664-
vc.verify(self.ctx, **kwargs)
664+
kd.kernel.prove(
665+
vc.vc(self.ctx),
666+
by=(
667+
[vc.verify(self.ctx, **kwargs)]
668+
+ [f.defn for f in self.ctx.definitions]
669+
),
670+
)
665671
) # TODO: This makes an ctx.definitions unfolded proof, which isn't what is expected?
672+
return self
673+
674+
def __repr__(self):
675+
return f"AsmProofState(num_vcs={len(self.vcs)}, num_pfs={len(self.pfs)})"
666676

667677
def lemma(self, n) -> VCProofState:
668678
return VCProofState(self.vcs[n], self.ctx, _parent=self)

src/kdrag/parsers/microlean.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,8 @@ def expr(tree, env: Env) -> smt.ExprRef:
192192
return expr(left, env) / expr(right, env)
193193
case Tree("eq", [left, right]):
194194
return smt.Eq(expr(left, env), expr(right, env))
195+
case Tree("neq", [left, right]):
196+
return expr(left, env) != expr(right, env) # type: ignore
195197
case Tree("le", [left, right]):
196198
return expr(left, env) <= expr(right, env)
197199
case Tree("lt", [left, right]):

src/kdrag/theories/real/complex.py

Lines changed: 4 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2,26 +2,13 @@
22
import kdrag.smt as smt
33
import kdrag.theories.real as real
44

5-
C = kd.Struct("C", ("re", smt.RealSort()), ("im", smt.RealSort()))
6-
5+
C = kd.ComplexSort()
76
z, w, u, z1, z2 = smt.Consts("z w u z1 z2", C)
8-
add = kd.define("add", [z1, z2], C.C(z1.re + z2.re, z1.im + z2.im))
9-
kd.notation.add.register(C, add)
10-
mul = kd.define(
11-
"mul", [z1, z2], C.C(z1.re * z2.re - z1.im * z2.im, z1.re * z2.im + z1.im * z2.re)
12-
)
13-
kd.notation.mul.register(C, mul)
7+
add = kd.complex_add
8+
mul = kd.complex_mul
9+
div = kd.complex_div
1410
conj = kd.define("conj", [z], C.C(z.re, -z.im))
1511

16-
17-
div = kd.notation.div.define(
18-
[z1, z2],
19-
C.C(
20-
(z1.re * z2.re + z1.im * z2.im) / (z2.re**2 + z2.im**2),
21-
(z1.im * z2.re - z1.re * z2.im) / (z2.re**2 + z2.im**2),
22-
),
23-
)
24-
2512
C0 = C.C(0, 0)
2613
C1 = C.C(1, 0)
2714
Ci = C.C(0, 1)

src/kdrag/theories/univ.py

Lines changed: 32 additions & 125 deletions
Original file line numberDiff line numberDiff line change
@@ -2,131 +2,6 @@
22
import kdrag.smt as smt
33
import functools
44

5-
_Type0 = kd.Inductive("Type0")
6-
_Type0.declare("Int", ("int", smt.IntSort()))
7-
_Type0.declare("Bool", ("bool", smt.BoolSort()))
8-
_Type0.declare("Real", ("real", smt.RealSort()))
9-
Type0 = _Type0.create()
10-
_x = smt.Const("x", Type0)
11-
12-
Type0Set = smt.FullSet(Type0)
13-
Int0 = smt.Lambda([_x], _x.is_Int)
14-
Nat0 = smt.Lambda([_x], smt.And(Int0(_x), _x.int >= 0))
15-
Bool0 = smt.Lambda([_x], _x.is_Bool)
16-
Real0 = smt.Lambda([_x], _x.is_Real)
17-
Unit0 = smt.Lambda([_x], smt.And(_x.is_Bool, _x.bool == smt.BoolVal(True)))
18-
19-
# Polymorphic definitions
20-
neg = kd.notation.neg.define(
21-
[_x],
22-
kd.cond(
23-
(_x.is_Bool, Type0.Bool(_x.bool)),
24-
(_x.is_Int, Type0.Int(-_x.int)),
25-
(_x.is_Real, Type0.Real(-_x.real)),
26-
),
27-
)
28-
29-
"""
30-
31-
def from_dt(dt : smt.DatatypeSortRef) -> smt.ExprRef:
32-
# encode x to tagged sequences
33-
kd.cond(
34-
*[dt.recognizer(i), smt.Concat(smt.Unit(Type0.IntVal(i)), smt.Unit(dt.accessor(i,0)(x))) for i in range(dt.num_constructors())])
35-
)
36-
def to_set(dt : smt.DatatypeSortRef) -> smt.ExprRef:
37-
# set corresponding to encoding
38-
# smt.Exists([z], from_dt(dt)(z) == x)
39-
# but there is a more computational version.
40-
kd.prove(forall([x], to_set(dt)(from_dt(dt))) # encoding is in encoding.
41-
def to_dt(dt : smt.DatatypeSortRef) -> smt.ExprRef:
42-
# decode tagged sequences to x
43-
44-
class OpenFuncDecl:
45-
name : str
46-
domains : tuple[smt.SortRef]
47-
range : smt.SortRef
48-
head : smt.FuncDeclRef
49-
last : smt.FuncDeclRef
50-
defns : list[smt.FuncDeclRef] = []
51-
undef : smt.FuncDeclRef
52-
def __init__(self, name : str, *sorts : smt.SortRef):
53-
self.name = name
54-
self.defns = []
55-
self.undef = kd.FuncDecl(name, sorts, sorts[-1])
56-
def define(self, vs, cond, body) -> smt.FuncDeclRef:
57-
assert [v.sort() == s for v,s in zip(vs, self.sorts]
58-
assert body.sort() == self.range
59-
newundef = smt.FreshFunction(*[v.sort() for v in vs], self.range)
60-
self.defns.append(kd.define(self.last.name(), vs, smt.If(cond, body, newundef(*vs)))
61-
self.undef = newundef
62-
63-
64-
def Seq(T : smt.FuncRef) -> smt.QuantifierRef:
65-
x = smt.FreshConst("x", T.domain())
66-
return smt.Lambda([x], smt.And(x.is_Seq, smt.SeqFoldLeftsmt.True, T, (x.seq)
67-
def Vec(n, T : smt.FuncRef) -> smt.QuantifierRef:
68-
x = smt.FreshConst("x", T.domain())
69-
return smt.Lambda([x], smt.And(Seq(T)(x), smt.Length(x.seq) == n))
70-
def Id(T, x, y) -> smt.QuantifierRef:
71-
p = smt.FreshConst("p", T.domain())
72-
return smt.Lambda([p], smt.And(Unit(p) , x == y))
73-
74-
75-
76-
Type1 = kd.Inductive("Type1")
77-
Type1.declare("Type0", ("type0", Type0))
78-
Type1.declare("Seq", ("seq", Type1))
79-
Type1Sort = smt.DatatypeSort("Type1")
80-
# We could have deeper recursion at the same universe level
81-
# Type1.declare(
82-
# "Array", ("array", smt.ArraySort(smt.ArraySort(Type1Sort, Type0), Type1Sort))
83-
# )
84-
Type1.declare("Array", ("array", smt.ArraySort(Type0, Type1Sort)))
85-
Type1 = Type1.create()
86-
87-
def Int(x : smt.ExprRef) -> smt.BoolRef:
88-
89-
90-
def level(x : smt.ExprRef) -> smt.IntRef:
91-
x.sort().name()
92-
"""
93-
94-
95-
def Int(l: int) -> smt.QuantifierRef:
96-
"""
97-
>>> Int(0)
98-
Lambda(x, is(Int, x))
99-
>>> Int(1)
100-
Lambda(x, And(is(Type0, x), is(Int, type0(x))))
101-
"""
102-
assert l >= 0
103-
if l == 0:
104-
return Int0
105-
else:
106-
Typel = Type(l)
107-
x = smt.Const("x", Typel)
108-
return smt.Lambda(
109-
[x], smt.And(Typel.recognizer(0)(x), Int(l - 1)(Typel.accessor(0, 0)(x)))
110-
)
111-
112-
113-
@functools.cache
114-
def Type(l: int) -> smt.DatatypeSortRef:
115-
"""
116-
A generic value type at universe level l.
117-
>>> Type(1)
118-
Type1
119-
"""
120-
if l == 0:
121-
return Type0
122-
else:
123-
TypeN = kd.Inductive(f"Type{l}")
124-
TypeN.declare(f"Type{l - 1}", (f"type{l - 1}", Type(l - 1)))
125-
TypeNSort = smt.DatatypeSort(f"Type{l}")
126-
TypeN.declare("Seq", ("seq", smt.SeqSort(TypeNSort)))
127-
TypeN.declare("Array", ("array", smt.ArraySort(Type(l - 1), TypeNSort)))
128-
return TypeN.create()
129-
1305

1316
"""
1327
A different style using an open datatype instead
@@ -164,3 +39,35 @@ def poly_ax(s: smt.SortRef) -> kd.Proof:
16439
# assert positive_poly(s)
16540
x = smt.Const("x", s)
16641
return kd.axiom(smt.ForAll([x], cast(s, box(x)) == x, patterns=[cast(s, box(x))]))
42+
43+
44+
"""
45+
Specialize some to an inductive datatype so we get built in support for injectors.
46+
"""
47+
48+
Type = kd.Inductive("Type")
49+
_Type = smt.DatatypeSort("Type")
50+
Type.declare("Bool", ("bool", smt.BoolSort()))
51+
Type.declare("Int", ("int", smt.IntSort()))
52+
Type.declare("Real", ("real", smt.RealSort()))
53+
Type.declare("Seq", ("seq", smt.SeqSort(_Type)))
54+
Type.declare("Array", ("array", smt.ArraySort(Poly, _Type)))
55+
# Maybe RFun Real -> Type
56+
# IntFun Int -> Type
57+
58+
Type = Type.create()
59+
60+
# Probably unsound
61+
# type_poly_ax = poly_ax(Type)
62+
_x = smt.Const("x", Type)
63+
Int = smt.Lambda([_x], _x.is_Int)
64+
Real = smt.Lambda([_x], _x.is_Real)
65+
Bool = smt.Lambda([_x], _x.is_Bool)
66+
67+
"""
68+
S = smt.Const("S", smt.SetSort(Type))
69+
mul = smt.Const("mul", smt.ArraySort(Type, Type, Type))
70+
semigroup = kd.define(
71+
"semigroup", [S, mul], smt.And(kd.Closed(S, mul), kd.Assoc(mul, T=S))
72+
)
73+
"""

0 commit comments

Comments
 (0)