|
2 | 2 | from __future__ import annotations |
3 | 3 |
|
4 | 4 | import argparse |
| 5 | +import os |
5 | 6 | import subprocess as sp |
6 | 7 | import sys |
7 | 8 | import time |
@@ -176,21 +177,9 @@ def register_all_arguments(arg_parser: argparse.ArgumentParser): |
176 | 177 | default=EnumerationOrder.SIZE, |
177 | 178 | ) |
178 | 179 | arg_parser.add_argument( |
179 | | - "--out-canonicals", |
| 180 | + "--out", |
180 | 181 | type=str, |
181 | | - help="the file in which to write the generated canonical programs", |
182 | | - default="", |
183 | | - ) |
184 | | - arg_parser.add_argument( |
185 | | - "--out-illegals", |
186 | | - type=str, |
187 | | - help="the file in which to write the generated illegal programs", |
188 | | - default="", |
189 | | - ) |
190 | | - arg_parser.add_argument( |
191 | | - "--out-rewrites", |
192 | | - type=str, |
193 | | - help="the file in which to write the generated rewrite rules", |
| 182 | + help="the directory in which to write the generated files", |
194 | 183 | default="", |
195 | 184 | ) |
196 | 185 | arg_parser.add_argument( |
@@ -588,21 +577,26 @@ def main() -> None: |
588 | 577 | for bucket_stat in bucket_stats: |
589 | 578 | print(bucket_stat) |
590 | 579 |
|
591 | | - if args.out_canonicals != "": |
592 | | - with open(args.out_canonicals, "w", encoding="UTF-8") as f: |
| 580 | + if args.out != "": |
| 581 | + os.makedirs(args.out, exist_ok=True) |
| 582 | + with open( |
| 583 | + os.path.join(args.out, "canonicals.mlir"), "w", encoding="UTF-8" |
| 584 | + ) as f: |
593 | 585 | for program in canonicals: |
594 | 586 | f.write(str(program.func)) |
595 | 587 | f.write("\n// -----\n") |
596 | 588 |
|
597 | | - if args.out_rewrites != "": |
598 | 589 | module = ModuleOp([rewrite.to_pdl() for rewrite in rewrites]) |
599 | | - with open(args.out_rewrites, "w", encoding="UTF-8") as f: |
| 590 | + with open( |
| 591 | + os.path.join(args.out, "rewrites.mlir"), "w", encoding="UTF-8" |
| 592 | + ) as f: |
600 | 593 | f.write(str(module)) |
601 | 594 | f.write("\n") |
602 | 595 |
|
603 | | - if args.out_illegals != "": |
604 | 596 | module = ModuleOp([illegal.func.clone() for illegal in illegals]) |
605 | | - with open(args.out_illegals, "w", encoding="UTF-8") as f: |
| 597 | + with open( |
| 598 | + os.path.join(args.out, "illegals.mlir"), "w", encoding="UTF-8" |
| 599 | + ) as f: |
606 | 600 | f.write(str(module)) |
607 | 601 | f.write("\n") |
608 | 602 |
|
|
0 commit comments