Skip to content

Commit f760d7f

Browse files
committed
Add convert-to-dialect CLI command
1 parent 10c129c commit f760d7f

File tree

2 files changed

+73
-0
lines changed

2 files changed

+73
-0
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ synthesize-symbolic-rewrites = "xdsl_smt.cli.synthesize_symbolic_rewrites:main"
3434
sanity-checker = "xdsl_smt.cli.sanity_checker:main"
3535
superoptimize = "xdsl_smt.cli.superoptimize:main"
3636
xdsl-smt-run = "xdsl_smt.cli.xdsl_smt_run:main"
37+
convert-pdl-to-dialect = "xdsl_smt.cli.convert_pdl_to_dialect:main"
3738

3839
[build-system]
3940
requires = ["setuptools>=61", "wheel"]
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
import argparse
2+
from typing import Iterator
3+
4+
from xdsl.parser import Parser
5+
from xdsl.context import Context
6+
from xdsl.dialects import pdl
7+
8+
from xdsl_smt.dialects import get_all_dialects
9+
from xdsl.dialects.builtin import StringAttr
10+
11+
12+
def iterate_pdl_patterns(file_path: str, ctx: Context) -> Iterator[pdl.PatternOp]:
13+
with open(file_path, "r") as f:
14+
input_program = Parser(ctx, f.read()).parse_module()
15+
16+
for pattern in input_program.walk():
17+
if isinstance(pattern, pdl.PatternOp):
18+
yield pattern
19+
20+
21+
def convert_name_to_dialect(name: str) -> str | None:
22+
match name:
23+
case "smt.bv.add":
24+
return "arith.addi"
25+
case "smt.bv.sub":
26+
return "arith.subi"
27+
case "smt.bv.and":
28+
return "arith.andi"
29+
case "smt.bv.or":
30+
return "arith.ori"
31+
case "smt.bv.xor":
32+
return "arith.xori"
33+
case _:
34+
return None
35+
36+
37+
def convert_pdl_to_dialect(pattern: pdl.PatternOp) -> pdl.PatternOp | None:
38+
new_pattern = pattern.clone()
39+
for op in new_pattern.walk():
40+
if isinstance(op, pdl.OperationOp):
41+
if op.opName is None:
42+
return None
43+
new_name = convert_name_to_dialect(op.opName.data)
44+
if new_name is None:
45+
return None
46+
op.opName = StringAttr(new_name)
47+
48+
return new_pattern
49+
50+
51+
def register_all_arguments(arg_parser: argparse.ArgumentParser):
52+
arg_parser.add_argument("input_file", type=str, help="path to the input file")
53+
54+
55+
def main():
56+
ctx = Context()
57+
ctx.allow_unregistered = True
58+
for dialect, factory in get_all_dialects().items():
59+
ctx.register_dialect(dialect, factory)
60+
61+
parser = argparse.ArgumentParser(description="Convert PDL to dialect")
62+
register_all_arguments(parser)
63+
args = parser.parse_args()
64+
65+
for pattern in iterate_pdl_patterns(args.input_file, ctx):
66+
if new_pattern := convert_pdl_to_dialect(pattern):
67+
print(pattern)
68+
print(new_pattern)
69+
70+
71+
if __name__ == "__main__":
72+
main()

0 commit comments

Comments
 (0)