Skip to content

Commit d66655b

Browse files
committed
Add synthesize-lowering CLI tool
1 parent ea48a64 commit d66655b

File tree

3 files changed

+150
-1
lines changed

3 files changed

+150
-1
lines changed

mlir-fuzz

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ xdsl-translate = "xdsl_smt.cli.xdsl_translate:main"
2828
verify-pdl = "xdsl_smt.cli.verify_pdl:main"
2929
verifier = "xdsl_smt.cli.transfer_function_verifier:main"
3030
cpp-translate = "xdsl_smt.cli.cpp_translate:main"
31+
synthesize-lowering = "xdsl_smt.cli.synthesize_lowering:main"
3132
synthesize-rewrites = "xdsl_smt.cli.synthesize_rewrites:main"
3233
synthesize-symbolic-rewrites = "xdsl_smt.cli.synthesize_symbolic_rewrites:main"
3334
sanity-checker = "xdsl_smt.cli.sanity_checker:main"
Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
import argparse
2+
import subprocess as sp
3+
4+
from xdsl.context import Context
5+
from xdsl.ir import Operation
6+
from xdsl.parser import Parser
7+
from xdsl.dialects.builtin import ModuleOp
8+
9+
from xdsl_smt.superoptimization.program_enumeration import enumerate_programs
10+
from xdsl_smt.dialects import get_all_dialects
11+
12+
13+
def register_all_arguments(arg_parser: argparse.ArgumentParser):
14+
arg_parser.add_argument(
15+
"--input-dialect", type=str, help="path to the input dialect", required=True
16+
)
17+
arg_parser.add_argument(
18+
"--input-configuration",
19+
type=str,
20+
help="path to the input configuration",
21+
required=True,
22+
)
23+
arg_parser.add_argument(
24+
"--output-dialect", type=str, help="path to the output dialect", required=True
25+
)
26+
arg_parser.add_argument(
27+
"--output-configuration",
28+
type=str,
29+
help="path to the output configuration",
30+
required=True,
31+
)
32+
arg_parser.add_argument(
33+
"--max-num-ops",
34+
type=int,
35+
help="maximum number of operations in the MLIR programs that are generated",
36+
required=True,
37+
)
38+
arg_parser.add_argument(
39+
"--timeout",
40+
type=int,
41+
help="The timeout passed to the SMT solver in milliseconds",
42+
default=8000,
43+
)
44+
45+
46+
def get_input_operations(
47+
arg_parser: argparse.Namespace, ctx: Context
48+
) -> list[ModuleOp]:
49+
"""
50+
Get a module for each op in the input dialect.
51+
Some ops may appear multiple times with different types and attributes.
52+
"""
53+
op_list = list[ModuleOp]()
54+
for program in enumerate_programs(
55+
max_num_args=99,
56+
num_ops=1,
57+
bv_widths="8",
58+
building_blocks=None,
59+
illegals=[],
60+
dialect_path=arg_parser.input_dialect,
61+
configuration=arg_parser.input_configuration,
62+
additional_options=["--exact-size", "--constant-kind=none"],
63+
):
64+
module = Parser(ctx, program).parse_module()
65+
66+
# Only consider programs that do not reuse values.
67+
# This is because we do not care about `add(%x, %x)`, just `add(%x, %y)`.
68+
def should_add(op: Operation) -> bool:
69+
for op in module.walk():
70+
for operand in op.operands:
71+
if operand.has_more_than_one_use():
72+
return False
73+
return True
74+
75+
if should_add(module):
76+
op_list.append(module)
77+
return op_list
78+
79+
80+
def try_synthesize_lowering_for_module(
81+
module: ModuleOp, args: argparse.Namespace, size: int, ctx: Context
82+
) -> ModuleOp | None:
83+
"""
84+
Try to synthesize a lowering for the given module (containing a single function), given the maximum
85+
number of operations to output.
86+
"""
87+
with open("/tmp/input-synthesize.mlir", "w") as f:
88+
f.write(str(module))
89+
res = sp.run(
90+
[
91+
"superoptimize",
92+
"/tmp/input-synthesize.mlir",
93+
f"--max-num-ops={size}",
94+
f"--dialect={args.output_dialect}",
95+
f"--configuration={args.output_configuration}",
96+
],
97+
capture_output=True,
98+
text=True,
99+
)
100+
if res.returncode != 0:
101+
return None
102+
return Parser(ctx, res.stdout).parse_module()
103+
104+
105+
def main():
106+
ctx = Context()
107+
ctx.allow_unregistered = True
108+
# Register all dialects
109+
for dialect_name, dialect_factory in get_all_dialects().items():
110+
ctx.register_dialect(dialect_name, dialect_factory)
111+
112+
arg_parser = argparse.ArgumentParser()
113+
register_all_arguments(arg_parser)
114+
args = arg_parser.parse_args()
115+
116+
op_list = []
117+
failed_synthesis: list[ModuleOp] = get_input_operations(args, ctx)
118+
119+
for i in range(1, args.max_num_ops + 1):
120+
op_list = failed_synthesis
121+
failed_synthesis = []
122+
print(
123+
f"Trying to synthesize lowerings with up to {i} operations. {len(op_list)} remaining."
124+
)
125+
for op in op_list:
126+
synthesized = try_synthesize_lowering_for_module(op, args, i, ctx)
127+
if synthesized is None:
128+
failed_synthesis.append(op)
129+
print(f"Failed to synthesize lowering with {i} operations for:")
130+
print(op)
131+
print("\n\n")
132+
continue
133+
134+
print("Successfully synthesized lowering for:")
135+
print(op)
136+
print("The synthesized lowering is:")
137+
print(synthesized)
138+
print("\n\n")
139+
print(f"{len(failed_synthesis)} remaining after size {i}.")
140+
141+
print(len(failed_synthesis), "operations could not be lowered:")
142+
for op in failed_synthesis:
143+
print(op)
144+
print("\n\n")
145+
146+
147+
if __name__ == "__main__":
148+
main()

0 commit comments

Comments
 (0)