Skip to content

Commit c207140

Browse files
committed
Update convert_pdl_to_dialect
1 parent db1a85b commit c207140

File tree

1 file changed

+24
-3
lines changed

1 file changed

+24
-3
lines changed

xdsl_smt/cli/convert_pdl_to_dialect.py

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,11 @@
44
from xdsl.parser import Parser
55
from xdsl.context import Context
66
from xdsl.dialects import pdl
7+
from xdsl.ir import Attribute
78

89
from xdsl_smt.dialects import get_all_dialects
9-
from xdsl.dialects.builtin import StringAttr
10+
from xdsl_smt.dialects import smt_bitvector_dialect as smt_bv
11+
from xdsl.dialects.builtin import StringAttr, IntegerType
1012

1113

1214
def iterate_pdl_patterns(file_path: str, ctx: Context) -> Iterator[pdl.PatternOp]:
@@ -24,6 +26,14 @@ def convert_name_to_dialect(name: str) -> str | None:
2426
return "arith.addi"
2527
case "smt.bv.sub":
2628
return "arith.subi"
29+
case "smt.bv.mul":
30+
return "arith.muli"
31+
case "smt.bv.lshr":
32+
return "arith.shrui"
33+
case "smt.bv.ashr":
34+
return "arith.shri"
35+
case "smt.bv.shl":
36+
return "arith.shli"
2737
case "smt.bv.and":
2838
return "arith.andi"
2939
case "smt.bv.or":
@@ -34,6 +44,12 @@ def convert_name_to_dialect(name: str) -> str | None:
3444
return None
3545

3646

47+
def convert_type_to_dialect(type: Attribute) -> Attribute | None:
48+
if isinstance(type, smt_bv.BitVectorType):
49+
return IntegerType(type.width.data)
50+
return None
51+
52+
3753
def convert_pdl_to_dialect(pattern: pdl.PatternOp) -> pdl.PatternOp | None:
3854
new_pattern = pattern.clone()
3955
for op in new_pattern.walk():
@@ -44,7 +60,13 @@ def convert_pdl_to_dialect(pattern: pdl.PatternOp) -> pdl.PatternOp | None:
4460
if new_name is None:
4561
return None
4662
op.opName = StringAttr(new_name)
47-
63+
continue
64+
if isinstance(op, pdl.TypeOp):
65+
if op.constantType is not None:
66+
new_type = convert_type_to_dialect(op.constantType)
67+
if new_type is None:
68+
return None
69+
op.constantType = new_type
4870
return new_pattern
4971

5072

@@ -64,7 +86,6 @@ def main():
6486

6587
for pattern in iterate_pdl_patterns(args.input_file, ctx):
6688
if new_pattern := convert_pdl_to_dialect(pattern):
67-
print(pattern)
6889
print(new_pattern)
6990

7091

0 commit comments

Comments
 (0)