Skip to content

Commit 0be43f9

Browse files
committed
Improve z3 connection, and allow use-input-ops
1 parent 55cdfd4 commit 0be43f9

File tree

5 files changed

+66
-52
lines changed

5 files changed

+66
-52
lines changed

tests/filecheck/superoptimize.mlir

Lines changed: 0 additions & 12 deletions
This file was deleted.
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
irdl.dialect @arith {
2+
irdl.operation @addi {
3+
%integer = irdl.base "!builtin.integer"
4+
%ovf_none = irdl.is #arith.overflow<none>
5+
%ovf_nsw = irdl.is #arith.overflow<nsw>
6+
%ovf_nuw = irdl.is #arith.overflow<nuw>
7+
%ovf_nsw_nuw = irdl.is #arith.overflow<nsw,nuw>
8+
%ovf = irdl.any_of(%ovf_none, %ovf_nsw, %ovf_nuw, %ovf_nsw_nuw)
9+
10+
irdl.operands(operand0: %integer, operand1: %integer)
11+
irdl.results(result0: %integer)
12+
irdl.attributes {"overflowFlags" = %ovf}
13+
}
14+
}
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
// RUN: superoptimize %s --max-num-ops=2 --dialect=%S/arith.irdl | filecheck %s
2+
3+
func.func @foo(%x: i32) -> i32 {
4+
%c0 = arith.constant 0 : i32
5+
%r = arith.muli %x, %c0 : i32
6+
func.return %r : i32
7+
}
8+
9+
// CHECK: func.func @foo(%arg0 : i32) -> i32 {
10+
// CHECK-NEXT: %0 = arith.constant 0 : i32
11+
// CHECK-NEXT: func.return %0 : i32
12+
// CHECK-NEXT: }

xdsl_smt/cli/superoptimize.py

Lines changed: 39 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,23 @@ def register_all_arguments(arg_parser: argparse.ArgumentParser):
4747
help="The timeout passed to the SMT solver in milliseconds",
4848
default=8000,
4949
)
50+
arg_parser.add_argument(
51+
"--use-input-ops",
52+
help="Reuse the existing operations and values",
53+
action="store_true",
54+
)
55+
arg_parser.add_argument(
56+
"--dialect",
57+
type=str,
58+
help="The IRDL file defining the dialect we want to use for synthesis",
59+
)
60+
arg_parser.add_argument(
61+
"-v",
62+
"--verbose",
63+
dest="verbose",
64+
help="Print debugging information in stderr",
65+
action="store_true",
66+
)
5067

5168

5269
def replace_synth_with_constants(
@@ -72,25 +89,25 @@ def main() -> None:
7289
for dialect_name, dialect_factory in get_all_dialects().items():
7390
ctx.register_dialect(dialect_name, dialect_factory)
7491

92+
with open(args.input_file, "r") as f:
93+
input_program = Parser(ctx, f.read()).parse_module()
94+
7595
current_dir = os.path.dirname(os.path.abspath(__file__))
7696
executable_path = os.path.join(
7797
current_dir, "..", "..", "mlir-fuzz", "build", "bin", "superoptimizer"
7898
)
7999

80-
arith_dialect_path = os.path.join(
81-
current_dir, "..", "..", "mlir-fuzz", "dialects", "arith.mlir"
82-
)
83-
84100
# Start the enumerator
85101
enumerator = sp.Popen(
86102
[
87103
executable_path,
88104
args.input_file,
89-
arith_dialect_path,
105+
args.dialect,
90106
f"--max-num-ops={args.max_num_ops}",
91107
"--pause-between-programs",
92108
"--mlir-print-op-generic",
93109
"--configuration=arith",
110+
f"--use-input-ops={args.use_input_ops}",
94111
],
95112
stdin=sp.PIPE,
96113
stdout=sp.PIPE,
@@ -113,43 +130,26 @@ def main() -> None:
113130
stderr=sp.PIPE,
114131
)
115132
if res.returncode != 0:
116-
print(
117-
f"Error while synthesizing program: {res.stderr.decode('utf-8')}",
118-
file=sys.stderr,
119-
)
133+
if args.verbose:
134+
print("Example failed:", file=sys.stderr)
135+
print(program.decode("utf-8"), file=sys.stderr)
136+
assert enumerator.stdin is not None
137+
enumerator.stdin.write(b"a")
138+
enumerator.stdin.flush()
120139
continue
121140

122-
res_z3 = sp.run(
123-
["z3", "-in", f"-T:{args.timeout}"],
124-
input=res.stdout + b"\n(get-model)",
125-
stdout=sp.PIPE,
126-
stderr=sp.PIPE,
127-
)
141+
resulting_program = Parser(ctx, res.stdout.decode("utf-8")).parse_module()
142+
if resulting_program.is_structurally_equivalent(input_program):
143+
if args.verbose:
144+
print("Synthesized the same program:", file=sys.stderr)
145+
print(resulting_program, file=sys.stderr)
146+
assert enumerator.stdin is not None
147+
enumerator.stdin.write(b"a")
148+
enumerator.stdin.flush()
149+
continue
128150

129-
if "model is not available" not in res_z3.stdout.decode():
130-
values_str: list[str] = re.findall(
131-
r"#([xb][0-9a-f]+)", res_z3.stdout.decode()
132-
)
133-
values: list[IntegerAttr[IntegerType]] = []
134-
for value in values_str:
135-
if value.startswith("x"):
136-
val = int(value[1:], 16)
137-
bitwidth = len(value[1:]) * 4
138-
else:
139-
val = int(value[1:], 2)
140-
bitwidth = len(value[1:])
141-
values.append(IntegerAttr(val, bitwidth))
142-
143-
mlir_program = Parser(ctx, program.decode()).parse_module()
144-
replace_synth_with_constants(mlir_program, values)
145-
146-
print(mlir_program)
147-
exit(0)
148-
149-
# Set a character to the enumerator to continue
150-
assert enumerator.stdin is not None
151-
enumerator.stdin.write(b"a")
152-
enumerator.stdin.flush()
151+
print(resulting_program.ops.first)
152+
exit(0)
153153
except BrokenPipeError as e:
154154
# The enumerator has terminated
155155
pass

0 commit comments

Comments
 (0)