44
55from xdsl .dialects import pdl , smt
66from xdsl .ir import SSAValue
7- from xdsl_smt .dialects .smt_bitvector_dialect import AddOp
7+ from xdsl .context import Context
8+ from xdsl_smt .dialects import get_all_dialects
89
910
1011def get_replace_terminator (pattern : pdl .PatternOp ) -> pdl .ReplaceOp :
@@ -19,33 +20,61 @@ def get_replace_terminator(pattern: pdl.PatternOp) -> pdl.ReplaceOp:
1920 return replace
2021
2122
23+ unary_lean_name : dict [str , str ] = {
24+ "smt.bv.neg" : "-" ,
25+ "smt.bv.not" : "BitVec.not" ,
26+ }
27+
2228op_name_to_lean_infix : dict [str , str ] = {
2329 "smt.eq" : "=" ,
2430 "smt.distinct" : "≠" ,
2531 "smt.bv.sub" : "-" ,
2632 "smt.bv.mul" : "*" ,
2733 "smt.bv.add" : "+" ,
34+ "smt.bv.and" : "&&&" ,
35+ "smt.bv.or" : "|||" ,
36+ "smt.bv.xor" : "^^^" ,
37+ "smt.bv.lshr" : ">>>" ,
38+ "smt.bv.shl" : "<<<" ,
39+ }
40+
41+ op_name_to_lean_prefix : dict [str , str ] = {
42+ "smt.bv.ashr" : "BitVec.sshiftRight'" ,
2843}
2944
3045
3146def get_lean_expression (value : SSAValue , operand_to_str : dict [SSAValue , str ]) -> str :
3247 if value in operand_to_str :
3348 return operand_to_str [value ]
49+ if isinstance ((pdl_op := value .owner ), pdl .ResultOp ):
50+ return get_lean_expression (pdl_op .parent_ , operand_to_str )
3451 if isinstance ((pdl_op := value .owner ), pdl .OperationOp ):
52+ operands = [
53+ get_lean_expression (operand , operand_to_str )
54+ for operand in pdl_op .operand_values
55+ ]
3556 assert pdl_op .opName is not None
3657 op_name = pdl_op .opName .data
37- lean_infix = op_name_to_lean_infix [op_name ]
38- lhs = get_lean_expression (pdl_op .operand_values [0 ], operand_to_str )
39- rhs = get_lean_expression (pdl_op .operand_values [1 ], operand_to_str )
40- return f"({ lhs } { lean_infix } { rhs } )"
58+ if len (operands ) == 1 :
59+ return f"({ unary_lean_name [op_name ]} { operands [0 ]} )"
60+ elif len (operands ) == 2 :
61+ lhs , rhs = operands
62+ if op_name in op_name_to_lean_infix .keys ():
63+ lean_infix = op_name_to_lean_infix [op_name ]
64+ return f"({ lhs } { lean_infix } { rhs } )"
65+ if op_name in op_name_to_lean_prefix .keys ():
66+ lean_prefix = op_name_to_lean_prefix [op_name ]
67+ return f"({ lean_prefix } { lhs } { rhs } )"
68+ assert False , f"Unsupported operation: { op_name } "
4169 assert False , f"Unsupported owner: { value .owner } "
4270
4371
4472def get_lean_type (value : SSAValue ) -> str :
45- assert isinstance (value .owner , pdl .TypeOp )
73+ if not isinstance (value .owner , pdl .TypeOp ):
74+ raise ValueError (f"Only pdl.type values are supported, got { value .owner } " )
4675 assert (MlirType := value .owner .constantType ) is not None
4776 assert isinstance (MlirType , smt .BitVectorType )
48- return f"BitVec { MlirType .width } "
77+ return f"BitVec { MlirType .width . data } "
4978
5079
5180def pdl_to_lean (pattern : pdl .PatternOp ) -> str :
@@ -57,7 +86,8 @@ def pdl_to_lean(pattern: pdl.PatternOp) -> str:
5786 for op in pattern .walk ():
5887 if isinstance (op , pdl .OperandOp ):
5988 operand_to_str [op .value ] = names [num_names ]
60- operand_to_type [op .value ] = get_lean_type (op .value )
89+ assert op .value_type is not None
90+ operand_to_type [op .value ] = get_lean_type (op .value_type )
6191 num_names += 1
6292
6393 replace_op = get_replace_terminator (pattern )
@@ -70,9 +100,35 @@ def pdl_to_lean(pattern: pdl.PatternOp) -> str:
70100 rhs_expr = get_lean_expression (replace_op .repl_values [0 ], operand_to_str )
71101
72102 # Convert the PDL pattern to a Lean theorem statement.
73- assert pattern .sym_name is not None
103+ # assert pattern.sym_name is not None
74104 arguments_str = " " .join (
75105 f"({ operand_to_str [value ]} : { operand_to_type [value ]} )"
76106 for value in operand_to_str .keys ()
77107 )
78- return f"theorem { pattern .sym_name } { arguments_str } :\n { lhs_expr } = { rhs_expr } := by\n sorry"
108+ return f"example { arguments_str } :\n { lhs_expr } = { rhs_expr } := by\n bv_decide +acNf"
109+
110+
111+ if __name__ == "__main__" :
112+ import sys
113+
114+ from xdsl .parser import Parser
115+
116+ if len (sys .argv ) != 2 :
117+ print ("Usage: pdl_to_lean.py <pdl_file>" )
118+ sys .exit (1 )
119+
120+ pdl_file = sys .argv [1 ]
121+ with open (pdl_file , "r" ) as f :
122+ pdl_text = f .read ()
123+
124+ ctx = Context ()
125+ for name , factory in get_all_dialects ().items ():
126+ ctx .register_dialect (name , factory )
127+
128+ module = Parser (ctx , pdl_text ).parse_module ()
129+
130+ for pattern in module .walk ():
131+ if isinstance (pattern , pdl .PatternOp ):
132+ lean_theorem = pdl_to_lean (pattern )
133+ # aprint(pattern)
134+ print (lean_theorem )
0 commit comments