@@ -47,6 +47,23 @@ def register_all_arguments(arg_parser: argparse.ArgumentParser):
4747 help = "The timeout passed to the SMT solver in milliseconds" ,
4848 default = 8000 ,
4949 )
50+ arg_parser .add_argument (
51+ "--use-input-ops" ,
52+ help = "Reuse the existing operations and values" ,
53+ action = "store_true" ,
54+ )
55+ arg_parser .add_argument (
56+ "--dialect" ,
57+ type = str ,
58+ help = "The IRDL file defining the dialect we want to use for synthesis" ,
59+ )
60+ arg_parser .add_argument (
61+ "-v" ,
62+ "--verbose" ,
63+ dest = "verbose" ,
64+ help = "Print debugging information in stderr" ,
65+ action = "store_true" ,
66+ )
5067
5168
5269def replace_synth_with_constants (
@@ -72,25 +89,25 @@ def main() -> None:
7289 for dialect_name , dialect_factory in get_all_dialects ().items ():
7390 ctx .register_dialect (dialect_name , dialect_factory )
7491
92+ with open (args .input_file , "r" ) as f :
93+ input_program = Parser (ctx , f .read ()).parse_module ()
94+
7595 current_dir = os .path .dirname (os .path .abspath (__file__ ))
7696 executable_path = os .path .join (
7797 current_dir , ".." , ".." , "mlir-fuzz" , "build" , "bin" , "superoptimizer"
7898 )
7999
80- arith_dialect_path = os .path .join (
81- current_dir , ".." , ".." , "mlir-fuzz" , "dialects" , "arith.mlir"
82- )
83-
84100 # Start the enumerator
85101 enumerator = sp .Popen (
86102 [
87103 executable_path ,
88104 args .input_file ,
89- arith_dialect_path ,
105+ args . dialect ,
90106 f"--max-num-ops={ args .max_num_ops } " ,
91107 "--pause-between-programs" ,
92108 "--mlir-print-op-generic" ,
93109 "--configuration=arith" ,
110+ f"--use-input-ops={ args .use_input_ops } " ,
94111 ],
95112 stdin = sp .PIPE ,
96113 stdout = sp .PIPE ,
@@ -113,43 +130,26 @@ def main() -> None:
113130 stderr = sp .PIPE ,
114131 )
115132 if res .returncode != 0 :
116- print (
117- f"Error while synthesizing program: { res .stderr .decode ('utf-8' )} " ,
118- file = sys .stderr ,
119- )
133+ if args .verbose :
134+ print ("Example failed:" , file = sys .stderr )
135+ print (program .decode ("utf-8" ), file = sys .stderr )
136+ assert enumerator .stdin is not None
137+ enumerator .stdin .write (b"a" )
138+ enumerator .stdin .flush ()
120139 continue
121140
122- res_z3 = sp .run (
123- ["z3" , "-in" , f"-T:{ args .timeout } " ],
124- input = res .stdout + b"\n (get-model)" ,
125- stdout = sp .PIPE ,
126- stderr = sp .PIPE ,
127- )
141+ resulting_program = Parser (ctx , res .stdout .decode ("utf-8" )).parse_module ()
142+ if resulting_program .is_structurally_equivalent (input_program ):
143+ if args .verbose :
144+ print ("Synthesized the same program:" , file = sys .stderr )
145+ print (resulting_program , file = sys .stderr )
146+ assert enumerator .stdin is not None
147+ enumerator .stdin .write (b"a" )
148+ enumerator .stdin .flush ()
149+ continue
128150
129- if "model is not available" not in res_z3 .stdout .decode ():
130- values_str : list [str ] = re .findall (
131- r"#([xb][0-9a-f]+)" , res_z3 .stdout .decode ()
132- )
133- values : list [IntegerAttr [IntegerType ]] = []
134- for value in values_str :
135- if value .startswith ("x" ):
136- val = int (value [1 :], 16 )
137- bitwidth = len (value [1 :]) * 4
138- else :
139- val = int (value [1 :], 2 )
140- bitwidth = len (value [1 :])
141- values .append (IntegerAttr (val , bitwidth ))
142-
143- mlir_program = Parser (ctx , program .decode ()).parse_module ()
144- replace_synth_with_constants (mlir_program , values )
145-
146- print (mlir_program )
147- exit (0 )
148-
149- # Set a character to the enumerator to continue
150- assert enumerator .stdin is not None
151- enumerator .stdin .write (b"a" )
152- enumerator .stdin .flush ()
151+ print (resulting_program .ops .first )
152+ exit (0 )
153153 except BrokenPipeError as e :
154154 # The enumerator has terminated
155155 pass
0 commit comments