44from xdsl .parser import Parser
55from xdsl .context import Context
66from xdsl .dialects import pdl
7+ from xdsl .ir import Attribute
78
89from xdsl_smt .dialects import get_all_dialects
9- from xdsl .dialects .builtin import StringAttr
10+ from xdsl_smt .dialects import smt_bitvector_dialect as smt_bv
11+ from xdsl .dialects .builtin import StringAttr , IntegerType
1012
1113
1214def iterate_pdl_patterns (file_path : str , ctx : Context ) -> Iterator [pdl .PatternOp ]:
@@ -24,6 +26,14 @@ def convert_name_to_dialect(name: str) -> str | None:
2426 return "arith.addi"
2527 case "smt.bv.sub" :
2628 return "arith.subi"
29+ case "smt.bv.mul" :
30+ return "arith.muli"
31+ case "smt.bv.lshr" :
32+ return "arith.shrui"
33+ case "smt.bv.ashr" :
34+ return "arith.shri"
35+ case "smt.bv.shl" :
36+ return "arith.shli"
2737 case "smt.bv.and" :
2838 return "arith.andi"
2939 case "smt.bv.or" :
@@ -34,6 +44,12 @@ def convert_name_to_dialect(name: str) -> str | None:
3444 return None
3545
3646
47+ def convert_type_to_dialect (type : Attribute ) -> Attribute | None :
48+ if isinstance (type , smt_bv .BitVectorType ):
49+ return IntegerType (type .width .data )
50+ return None
51+
52+
3753def convert_pdl_to_dialect (pattern : pdl .PatternOp ) -> pdl .PatternOp | None :
3854 new_pattern = pattern .clone ()
3955 for op in new_pattern .walk ():
@@ -44,7 +60,13 @@ def convert_pdl_to_dialect(pattern: pdl.PatternOp) -> pdl.PatternOp | None:
4460 if new_name is None :
4561 return None
4662 op .opName = StringAttr (new_name )
47-
63+ continue
64+ if isinstance (op , pdl .TypeOp ):
65+ if op .constantType is not None :
66+ new_type = convert_type_to_dialect (op .constantType )
67+ if new_type is None :
68+ return None
69+ op .constantType = new_type
4870 return new_pattern
4971
5072
@@ -64,7 +86,6 @@ def main():
6486
6587 for pattern in iterate_pdl_patterns (args .input_file , ctx ):
6688 if new_pattern := convert_pdl_to_dialect (pattern ):
67- print (pattern )
6889 print (new_pattern )
6990
7091
0 commit comments