Skip to content

Commit bceb1a5

Browse files
committed
Add --out option to generate output directory
1 parent 0aa7dc5 commit bceb1a5

File tree

1 file changed

+14
-20
lines changed

1 file changed

+14
-20
lines changed

xdsl_smt/cli/synthesize_rewrites.py

Lines changed: 14 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from __future__ import annotations
33

44
import argparse
5+
import os
56
import subprocess as sp
67
import sys
78
import time
@@ -176,21 +177,9 @@ def register_all_arguments(arg_parser: argparse.ArgumentParser):
176177
default=EnumerationOrder.SIZE,
177178
)
178179
arg_parser.add_argument(
179-
"--out-canonicals",
180+
"--out",
180181
type=str,
181-
help="the file in which to write the generated canonical programs",
182-
default="",
183-
)
184-
arg_parser.add_argument(
185-
"--out-illegals",
186-
type=str,
187-
help="the file in which to write the generated illegal programs",
188-
default="",
189-
)
190-
arg_parser.add_argument(
191-
"--out-rewrites",
192-
type=str,
193-
help="the file in which to write the generated rewrite rules",
182+
help="the directory in which to write the generated files",
194183
default="",
195184
)
196185
arg_parser.add_argument(
@@ -588,21 +577,26 @@ def main() -> None:
588577
for bucket_stat in bucket_stats:
589578
print(bucket_stat)
590579

591-
if args.out_canonicals != "":
592-
with open(args.out_canonicals, "w", encoding="UTF-8") as f:
580+
if args.out != "":
581+
os.makedirs(args.out, exist_ok=True)
582+
with open(
583+
os.path.join(args.out, "canonicals.mlir"), "w", encoding="UTF-8"
584+
) as f:
593585
for program in canonicals:
594586
f.write(str(program.func))
595587
f.write("\n// -----\n")
596588

597-
if args.out_rewrites != "":
598589
module = ModuleOp([rewrite.to_pdl() for rewrite in rewrites])
599-
with open(args.out_rewrites, "w", encoding="UTF-8") as f:
590+
with open(
591+
os.path.join(args.out, "rewrites.mlir"), "w", encoding="UTF-8"
592+
) as f:
600593
f.write(str(module))
601594
f.write("\n")
602595

603-
if args.out_illegals != "":
604596
module = ModuleOp([illegal.func.clone() for illegal in illegals])
605-
with open(args.out_illegals, "w", encoding="UTF-8") as f:
597+
with open(
598+
os.path.join(args.out, "illegals.mlir"), "w", encoding="UTF-8"
599+
) as f:
606600
f.write(str(module))
607601
f.write("\n")
608602

0 commit comments

Comments
 (0)