@@ -127,7 +127,7 @@ def _dump_reproducer(
127127
128128 print ("---- SNIP ----" )
129129 print ("import torch" )
130- print ("from torch import device" ) # Used inside fx_g.code
130+ print ("from torch import tensor, device" ) # Used inside fx_g.code
131131 print ("import torch_mlir" )
132132 print ("" )
133133
@@ -138,7 +138,13 @@ def _dump_reproducer(
138138 print ("model = Model()" )
139139 args = ""
140140 for inp in inps :
141- args += f"torch.ones({ inp .shape } , dtype={ inp .dtype } ), "
141+ if torch .all (inp == 0 ):
142+ args += f"torch.zeros({ inp .shape } , dtype={ inp .dtype } ), "
143+ elif torch .all (inp == 1 ):
144+ args += f"torch.ones({ inp .shape } , dtype={ inp .dtype } ), "
145+ else :
146+ torch .set_printoptions (threshold = 100000 )
147+ args += f"torch.tensor({ str (inp )} , dtype={ inp .dtype } ), "
142148 if dtype is not None :
143149 print (f"model.to({ dtype } )" )
144150 print (f"inps = ({ args } )" )
@@ -148,6 +154,13 @@ def _dump_reproducer(
148154 print ("" )
149155 print ("---- SNIP ----" )
150156
157+ def _reduce_inputs (inps , are_inputs_good ):
158+ for i in range (len (inps )):
159+ new_inps = inps .copy ()
160+ new_inps [i ] = torch .zeros (inps [i ].shape , dtype = inps [i ].dtype )
161+ if are_inputs_good (new_inps ):
162+ inps = new_inps
163+ return inps
151164
152165@torch .no_grad ()
153166def reproduce (
@@ -200,6 +213,7 @@ def module_fails(fx_g, inputs):
200213
201214
202215 def show_reproducer (fx_g : fx .GraphModule , inps : List [torch .Tensor ]):
216+ inps = _reduce_inputs (inps , lambda inputs : module_fails (fx_g , inputs ))
203217 _dump_reproducer (fx_g , inps , output_type , dtype )
204218
205219 minifier (fx_g , inputs , module_fails , dump_state = show_reproducer )
0 commit comments