|
| 1 | +import argparse |
| 2 | +from typing import Iterator |
| 3 | + |
| 4 | +from xdsl.parser import Parser |
| 5 | +from xdsl.context import Context |
| 6 | +from xdsl.dialects import pdl |
| 7 | + |
| 8 | +from xdsl_smt.dialects import get_all_dialects |
| 9 | +from xdsl.dialects.builtin import StringAttr |
| 10 | + |
| 11 | + |
| 12 | +def iterate_pdl_patterns(file_path: str, ctx: Context) -> Iterator[pdl.PatternOp]: |
| 13 | + with open(file_path, "r") as f: |
| 14 | + input_program = Parser(ctx, f.read()).parse_module() |
| 15 | + |
| 16 | + for pattern in input_program.walk(): |
| 17 | + if isinstance(pattern, pdl.PatternOp): |
| 18 | + yield pattern |
| 19 | + |
| 20 | + |
| 21 | +def convert_name_to_dialect(name: str) -> str | None: |
| 22 | + match name: |
| 23 | + case "smt.bv.add": |
| 24 | + return "arith.addi" |
| 25 | + case "smt.bv.sub": |
| 26 | + return "arith.subi" |
| 27 | + case "smt.bv.and": |
| 28 | + return "arith.andi" |
| 29 | + case "smt.bv.or": |
| 30 | + return "arith.ori" |
| 31 | + case "smt.bv.xor": |
| 32 | + return "arith.xori" |
| 33 | + case _: |
| 34 | + return None |
| 35 | + |
| 36 | + |
| 37 | +def convert_pdl_to_dialect(pattern: pdl.PatternOp) -> pdl.PatternOp | None: |
| 38 | + new_pattern = pattern.clone() |
| 39 | + for op in new_pattern.walk(): |
| 40 | + if isinstance(op, pdl.OperationOp): |
| 41 | + if op.opName is None: |
| 42 | + return None |
| 43 | + new_name = convert_name_to_dialect(op.opName.data) |
| 44 | + if new_name is None: |
| 45 | + return None |
| 46 | + op.opName = StringAttr(new_name) |
| 47 | + |
| 48 | + return new_pattern |
| 49 | + |
| 50 | + |
| 51 | +def register_all_arguments(arg_parser: argparse.ArgumentParser): |
| 52 | + arg_parser.add_argument("input_file", type=str, help="path to the input file") |
| 53 | + |
| 54 | + |
| 55 | +def main(): |
| 56 | + ctx = Context() |
| 57 | + ctx.allow_unregistered = True |
| 58 | + for dialect, factory in get_all_dialects().items(): |
| 59 | + ctx.register_dialect(dialect, factory) |
| 60 | + |
| 61 | + parser = argparse.ArgumentParser(description="Convert PDL to dialect") |
| 62 | + register_all_arguments(parser) |
| 63 | + args = parser.parse_args() |
| 64 | + |
| 65 | + for pattern in iterate_pdl_patterns(args.input_file, ctx): |
| 66 | + if new_pattern := convert_pdl_to_dialect(pattern): |
| 67 | + print(pattern) |
| 68 | + print(new_pattern) |
| 69 | + |
| 70 | + |
| 71 | +if __name__ == "__main__": |
| 72 | + main() |
0 commit comments