Skip to content

Commit 47b3580

Browse files
authored
python/torch_mlir/repro.py: Reduce inputs (#103)
1 parent 4b6d89a commit 47b3580

1 file changed

Lines changed: 16 additions & 2 deletions

File tree

python/torch_mlir/repro.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ def _dump_reproducer(
127127

128128
print("---- SNIP ----")
129129
print("import torch")
130-
print("from torch import device") # Used inside fx_g.code
130+
print("from torch import tensor, device") # Used inside fx_g.code
131131
print("import torch_mlir")
132132
print("")
133133

@@ -138,7 +138,13 @@ def _dump_reproducer(
138138
print("model = Model()")
139139
args = ""
140140
for inp in inps:
141-
args += f"torch.ones({inp.shape}, dtype={inp.dtype}), "
141+
if torch.all(inp == 0):
142+
args += f"torch.zeros({inp.shape}, dtype={inp.dtype}), "
143+
elif torch.all(inp == 1):
144+
args += f"torch.ones({inp.shape}, dtype={inp.dtype}), "
145+
else:
146+
torch.set_printoptions(threshold=100000)
147+
args += f"torch.tensor({str(inp)}, dtype={inp.dtype}), "
142148
if dtype is not None:
143149
print(f"model.to({dtype})")
144150
print(f"inps = ({args})")
@@ -148,6 +154,13 @@ def _dump_reproducer(
148154
print("")
149155
print("---- SNIP ----")
150156

157+
def _reduce_inputs(inps, are_inputs_good):
158+
for i in range(len(inps)):
159+
new_inps = inps.copy()
160+
new_inps[i] = torch.zeros(inps[i].shape, dtype=inps[i].dtype)
161+
if are_inputs_good(new_inps):
162+
inps = new_inps
163+
return inps
151164

152165
@torch.no_grad()
153166
def reproduce(
@@ -200,6 +213,7 @@ def module_fails(fx_g, inputs):
200213

201214

202215
def show_reproducer(fx_g: fx.GraphModule, inps: List[torch.Tensor]):
216+
inps = _reduce_inputs(inps, lambda inputs: module_fails(fx_g, inputs))
203217
_dump_reproducer(fx_g, inps, output_type, dtype)
204218

205219
minifier(fx_g, inputs, module_fails, dump_state=show_reproducer)

0 commit comments

Comments
 (0)