1313import kdrag .smt as smt
1414from lark import Tree
1515import kdrag as kd
16- from typing import NamedTuple
16+ from typing import NamedTuple , Optional
1717
1818
1919# TODO: let syntax
2626?match: "match" expr "with" match_case+ -> match_
2727match_case: "|" pattern ("=>" | "↦") expr -> match_case
2828?ite: quantifier | "if" expr "then" expr "else" expr -> if
29- ?quantifier: implication | ("forall" | "∀") binders "," expr -> forall_ | \
29+ ?quantifier: iff | ("forall" | "∀") binders "," expr -> forall_ | \
3030 ("exists" | "∃") binders "," expr -> exists_ | ("fun" | "λ") binders ("=>" | "↦") expr -> fun_ | set
31+ ?iff: implication | implication ("<->" | "↔") iff -> iff_
3132?implication: disjunction | disjunction ("->" | "→") implication -> implies
3233?disjunction: conjunction | disjunction ("\\/" | "∨" | "||" | "∪" ) conjunction -> or_
3334?conjunction: comparison | conjunction ("/\\" | "∧" | "&&" | "∩") comparison -> and_
4243 | additive "+" multiplicative -> add | additive "-" multiplicative -> sub
4344?multiplicative: unary
4445 | multiplicative "*" unary -> mul | multiplicative "/" unary -> div
45- ?unary : application | "-" unary -> neg
46+ ?unary : application | "-" unary -> neg | "!" unary -> not_ | "~~~" unary -> bvnot
4647?application: atom atom* -> app
4748?atom: const | num | bool_ | "(" expr ")" | seq | char | string
4849
@@ -254,7 +255,7 @@ def pattern(tree, env: Env, expected_sort: smt.SortRef | None) -> smt.ExprRef:
254255 raise ValueError ("Unknown pattern tree" , tree )
255256
256257
257- def expr (tree , env : Env ) -> smt .ExprRef :
258+ def expr (tree , env : Env , expected_sort : Optional [ smt . SortRef ] = None ) -> smt .ExprRef :
258259 match tree :
259260 # TODO: obviously this is not well typed.
260261 case Tree ("num" , [n ]):
@@ -289,6 +290,10 @@ def expr(tree, env: Env) -> smt.ExprRef:
289290 return smt .And (expr (left , env ), expr (right , env ))
290291 case Tree ("or_" , [left , right ]):
291292 return smt .Or (expr (left , env ), expr (right , env ))
293+ case Tree ("iff_" , [left , right ]):
294+ l = expr (left , env )
295+ r = expr (right , env )
296+ return smt .Eq (l , r )
292297 case Tree ("add" , [left , right ]):
293298 return expr (left , env ) + expr (right , env )
294299 case Tree ("sub" , [left , right ]):
@@ -299,6 +304,8 @@ def expr(tree, env: Env) -> smt.ExprRef:
299304 return expr (left , env ) / expr (right , env )
300305 case Tree ("neg" , [e ]):
301306 return - expr (e , env )
307+ case Tree ("not_" , [e ]):
308+ return smt .Not (expr (e , env ))
302309 case Tree ("eq" , [left , right ]):
303310 return smt .Eq (expr (left , env ), expr (right , env ))
304311 case Tree ("neq" , [left , right ]):
@@ -416,10 +423,10 @@ def parse(s: str, locals=None, globals=None) -> smt.ExprRef:
416423 Lambda(x, x > 0)
417424 >>> parse("if true && false then 1 + 1 else 0")
418425 If(And(True, False), 2, 0)
419- >>> parse("'a'").eq(smt.CharVal('a'))
420- True
421- >>> parse("\\ "hello world \\ "" ).eq(smt.StringVal("hello world" ))
422- True
426+ >>> assert parse("'a'").eq(smt.CharVal('a'))
427+ >>> assert parse(" \\ "hello world \\ "").eq(smt.StringVal("hello world"))
428+ >>> assert parse("!true" ).eq(smt.Not(smt.BoolVal(True) ))
429+ >>> assert parse("true <-> false").eq(smt.Eq(smt.BoolVal( True), smt.BoolVal(False)))
423430 """
424431 env = Env (locals = locals or {}, globals = globals or {})
425432 return start (parser .parse (s ), env )
0 commit comments