2727smt .ArrayRef .__call__ = lambda self , arg : self [arg ]
2828
2929
30+ def quantifier_call (self , * args ):
31+ """
32+ Instantiate a quantifier. This does substitution
33+ >>> x,y = smt.Ints("x y")
34+ >>> smt.Lambda([x,y], x + 1)(2,3)
35+ 2 + 1
36+
37+ To apply a Lambda without substituting, use square brackets
38+ >>> smt.Lambda([x,y], x + 1)[2,3]
39+ Select(Lambda([x, y], x + 1), 2, 3)
40+ """
41+ if self .num_vars () != len (args ):
42+ raise TypeError ("Wrong number of arguments" , self , args )
43+ return smt .substitute_vars (
44+ self .body (), * (smt ._py2expr (arg ) for arg in reversed (args ))
45+ )
46+
47+
48+ smt .QuantifierRef .__call__ = quantifier_call
49+
50+
3051class SortDispatch :
3152 """
3253 Sort dispatch is modeled after functools.singledispatch
@@ -114,6 +135,12 @@ def QForAll(vs: list[smt.ExprRef], *hyp_conc) -> smt.BoolRef:
114135
115136 If variables have a property `wf` attached, this is added as a hypothesis.
116137
138+ There is no downside to always using this compared to `smt.ForAll` and it can avoid some errors.
139+
140+ >>> x,y = smt.Ints("x y")
141+ >>> QForAll([x,y], x > 0, y > 0, x + y > 0)
142+ ForAll([x, y], Implies(And(x > 0, y > 0), x + y > 0))
143+
117144 """
118145 conc = hyp_conc [- 1 ]
119146 hyps = hyp_conc [:- 1 ]
@@ -187,13 +214,13 @@ def datatype_call(self, *args):
187214records = {}
188215
189216
190- def Record (name : str , * fields , pred = None ) -> smt .DatatypeSortRef :
217+ def Record (name : str , * fields , pred = None , admit = False ) -> smt .DatatypeSortRef :
191218 """
192219 Define a record datatype.
193220 The optional argument `pred` will add a well-formedness condition to the record
194221 giving something akin to a refinement type.
195222 """
196- if name in records :
223+ if not admit and name in records :
197224 raise Exception ("Record already defined" , name )
198225 rec = smt .Datatype (name )
199226 rec .declare (name , * fields )
@@ -219,9 +246,24 @@ def Record(name: str, *fields, pred=None) -> smt.DatatypeSortRef:
219246 return rec
220247
221248
222- def NewType (name : str , sort : smt .SortRef , pred = None ) -> smt .DatatypeSortRef :
249+ def NewType (
250+ name : str , sort : smt .SortRef , pred = None , admit = False
251+ ) -> smt .DatatypeSortRef :
223252 """Minimal wrapper around a sort for sort based overloading"""
224- return Record (name , ("val" , sort ), pred = pred )
253+ return Record (name , ("val" , sort ), pred = pred , admit = admit )
254+
255+
256+ def Enum (name , args , admit = False ):
257+ """Shorthand for simple enumeration datatypes. Similar to python's Enum.
258+ >>> Color = Enum("Color", "Red Green Blue")
259+ >>> smt.And(Color.Red != Color.Green, Color.Red != Color.Blue)
260+ And(Red != Green, Red != Blue)
261+ """
262+ T = kd .Inductive (name , admit = admit )
263+ for c in args .split ():
264+ T .declare (c )
265+ T = T .create ()
266+ return T
225267
226268
227269def induct_inductive (DT : smt .DatatypeSortRef , x = None , P = None ) -> kd .kernel .Proof :
@@ -257,9 +299,9 @@ def induct_inductive(DT: smt.DatatypeSortRef, x=None, P=None) -> kd.kernel.Proof
257299 )
258300
259301
260- def Inductive (name : str , strict = True ) -> smt .DatatypeSortRef :
302+ def Inductive (name : str , admit = False ) -> smt .DatatypeSortRef :
261303 """Declare datatypes with auto generated induction principles. Wrapper around z3.Datatype"""
262- if strict and name in records :
304+ if not admit and name in records :
263305 raise Exception (
264306 "Datatype with that name already defined. Use keyword strict=False to override" ,
265307 name ,
@@ -271,7 +313,7 @@ def Inductive(name: str, strict=True) -> smt.DatatypeSortRef:
271313 def create ():
272314 dt = oldcreate ()
273315 # Sanity check no duplicate names. Causes confusion.
274- if strict :
316+ if not admit :
275317 names = set ()
276318 for i in range (dt .num_constructors ()):
277319 cons = dt .constructor (i )
@@ -284,7 +326,6 @@ def create():
284326 if n in names :
285327 raise Exception ("Duplicate field name" , n )
286328 names .add (n )
287- x = smt .FreshConst (dt , prefix = "x" )
288329 kd .notation .induct .register (dt , lambda x : induct_inductive (dt , x = x ))
289330 records [name ] = dt
290331 return dt
@@ -318,6 +359,10 @@ def cond(*cases, default=None) -> smt.ExprRef:
318359 return acc
319360
320361
362+ def conde (* cases ):
363+ return smt .Or ([smt .And (c ) for c in cases ])
364+
365+
321366class Cond :
322367 def __init__ (self ):
323368 self .cases = []
0 commit comments