Commit 83e36a6

yushangdi authored and pytorchmergebot committed
AOTI Minifier (pytorch#139351)
See documentation at https://docs-preview.pytorch.org/pytorch/pytorch/139351/torch.compiler_aot_inductor_minifier.html.

Add a minifier for AOTI.

Test Plan: python test/inductor/test_minifier.py

Pull Request resolved: pytorch#139351
Approved by: https://github.com/desertfire
1 parent 8d070d2 commit 83e36a6

7 files changed: +855 -6 lines changed


docs/source/torch.compiler_aot_inductor.rst (+12)

@@ -185,3 +185,15 @@ display results akin to the following:
 0.4883
 0.4703
 [ CUDAFloatType{2,1} ]
+
+
+Troubleshooting
+---------------------------
+Below are some useful tools for debugging AOT Inductor.
+
+.. toctree::
+   :caption: Debugging Tools
+   :maxdepth: 1
+
+   logging
+   torch.compiler_aot_inductor_minifier

docs/source/torch.compiler_aot_inductor_minifier.rst (new file, +213)

AOTInductor Minifier
===========================

If you encounter an error while using AOT Inductor APIs such as
``torch._inductor.aoti_compile_and_package``, ``torch._inductor.aoti_load_package``,
or while running the model loaded by ``aoti_load_package`` on some inputs, you can use the AOTInductor Minifier
to create a minimal nn.Module that reproduces the error by setting ``from torch._inductor import config; config.aot_inductor.dump_aoti_minifier = True``.

At a high level, there are two steps in using the minifier:

- Set ``from torch._inductor import config; config.aot_inductor.dump_aoti_minifier = True``, or set the environment variable ``DUMP_AOTI_MINIFIER=1``; a configuration sketch follows this list. Running the script that errors will then produce a ``minifier_launcher.py`` script. The output directory is configurable by setting ``torch._dynamo.config.base_dir`` to a valid directory name.

- Run the ``minifier_launcher.py`` script. If the minifier runs successfully, it generates runnable Python code in ``repro.py`` which reproduces the exact error.
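
Enabling the dump, and optionally redirecting where it is written, might look like the following (a minimal sketch; the ``/tmp/aoti_debug`` path is an arbitrary example):

.. code-block:: py

    import torch
    import torch._dynamo.config
    from torch._inductor import config as inductor_config

    # Step 1: write minifier_launcher.py whenever an AOTI error occurs.
    # (Setting the environment variable DUMP_AOTI_MINIFIER=1 is the
    # equivalent way to enable this from the shell.)
    inductor_config.aot_inductor.dump_aoti_minifier = True

    # Optional: choose where the torch_compile_debug output tree goes.
    torch._dynamo.config.base_dir = "/tmp/aoti_debug"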

Here is sample code which will generate an error because we injected an error on relu with
``torch._inductor.config.triton.inject_relu_bug_TESTING_ONLY = "compile_error"``.

.. code-block:: py

    import torch
    from torch._inductor import config as inductor_config

    class Model(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.fc1 = torch.nn.Linear(10, 16)
            self.relu = torch.nn.ReLU()
            self.sigmoid = torch.nn.Sigmoid()

        def forward(self, x):
            x = self.fc1(x)
            x = self.relu(x)
            x = self.sigmoid(x)
            return x

    inductor_config.aot_inductor.dump_aoti_minifier = True
    torch._inductor.config.triton.inject_relu_bug_TESTING_ONLY = "compile_error"

    with torch.no_grad():
        model = Model().to("cuda")
        example_inputs = (torch.randn(8, 10).to("cuda"),)
        ep = torch.export.export(model, example_inputs)
        package_path = torch._inductor.aoti_compile_and_package(ep, example_inputs)
        compiled_model = torch._inductor.aoti_load_package(package_path)
        result = compiled_model(*example_inputs)

The code above generates the following error:

::

    RuntimeError: Failed to import /tmp/torchinductor_shangdiy/fr/cfrlf4smkwe4lub4i4cahkrb3qiczhf7hliqqwpewbw3aplj5g3s.py
    SyntaxError: invalid syntax (cfrlf4smkwe4lub4i4cahkrb3qiczhf7hliqqwpewbw3aplj5g3s.py, line 29)

This is because we injected an error on relu, so the generated triton kernel looks like below. Note that we have ``compile error!``
instead of ``relu``, so we get a ``SyntaxError``.

.. code-block::

    @triton.jit
    def triton_poi_fused_addmm_relu_sigmoid_0(in_out_ptr0, in_ptr0, xnumel, XBLOCK : tl.constexpr):
        xnumel = 128
        xoffset = tl.program_id(0) * XBLOCK
        xindex = xoffset + tl.arange(0, XBLOCK)[:]
        xmask = xindex < xnumel
        x2 = xindex
        x0 = xindex % 16
        tmp0 = tl.load(in_out_ptr0 + (x2), xmask)
        tmp1 = tl.load(in_ptr0 + (x0), xmask, eviction_policy='evict_last')
        tmp2 = tmp0 + tmp1
        tmp3 = compile error!
        tmp4 = tl.sigmoid(tmp3)
        tl.store(in_out_ptr0 + (x2), tmp4, xmask)

Since we have ``torch._inductor.config.aot_inductor.dump_aoti_minifier=True``, we also see an additional line indicating where ``minifier_launcher.py`` has
been written to. The output directory is configurable by setting
``torch._dynamo.config.base_dir`` to a valid directory name.

::

    W1031 16:21:08.612000 2861654 pytorch/torch/_dynamo/debug_utils.py:279] Writing minified repro to:
    W1031 16:21:08.612000 2861654 pytorch/torch/_dynamo/debug_utils.py:279] /data/users/shangdiy/pytorch/torch_compile_debug/run_2024_10_31_16_21_08_602433-pid_2861654/minifier/minifier_launcher.py

The ``minifier_launcher.py`` file has the following code. The ``exported_program`` contains the inputs to ``torch._inductor.aoti_compile_and_package``.
The ``command='minify'`` parameter means the script will run the minifier to create a minimal graph module that reproduces the error. Alternatively, you can
set ``command='run'`` to just compile, load, and run the loaded model (without running the minifier).

.. code-block:: py

    import torch
    import torch._inductor.inductor_prims

    import torch._dynamo.config
    import torch._inductor.config
    import torch._functorch.config
    import torch.fx.experimental._config

    torch._inductor.config.triton.inject_relu_bug_TESTING_ONLY = 'compile_error'
    torch._inductor.config.aot_inductor.dump_aoti_minifier = True

    isolate_fails_code_str = None

    # torch version: 2.6.0a0+gitcd9c6e9
    # torch cuda version: 12.0
    # torch git version: cd9c6e9408dd79175712223895eed36dbdc84f84

    # CUDA Info:
    # nvcc: NVIDIA (R) Cuda compiler driver
    # Copyright (c) 2005-2023 NVIDIA Corporation
    # Built on Fri_Jan__6_16:45:21_PST_2023
    # Cuda compilation tools, release 12.0, V12.0.140
    # Build cuda_12.0.r12.0/compiler.32267302_0

    # GPU Hardware Info:
    # NVIDIA PG509-210 : 8

    exported_program = torch.export.load('/data/users/shangdiy/pytorch/torch_compile_debug/run_2024_11_06_13_52_35_711642-pid_3567062/minifier/checkpoints/exported_program.pt2')
    # print(exported_program.graph)
    config_patches={}
    if __name__ == '__main__':
        from torch._dynamo.repro.aoti import run_repro
        with torch.no_grad():
            run_repro(exported_program, config_patches=config_patches, accuracy=False, command='minify', save_dir='/data/users/shangdiy/pytorch/torch_compile_debug/run_2024_11_06_13_52_35_711642-pid_3567062/minifier/checkpoints', check_str=None)
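
To skip minification and just compile, load, and run the exported program, one could edit the generated call as follows (a sketch; only ``command`` changes, everything else stays as generated):

.. code-block:: py

    run_repro(
        exported_program,
        config_patches=config_patches,
        accuracy=False,
        command='run',  # changed from 'minify': compile + load + run only
        save_dir='/data/users/shangdiy/pytorch/torch_compile_debug/run_2024_11_06_13_52_35_711642-pid_3567062/minifier/checkpoints',
        check_str=None,
    )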

Suppose we keep the ``command='minify'`` option and run the script; we would then get the following output:

::

    ...
    W1031 16:48:08.938000 3598491 torch/_dynamo/repro/aoti.py:89] Writing checkpoint with 3 nodes to /data/users/shangdiy/pytorch/torch_compile_debug/run_2024_10_31_16_48_02_720863-pid_3598491/minifier/checkpoints/3.py
    W1031 16:48:08.975000 3598491 torch/_dynamo/repro/aoti.py:101] Copying repro file for convenience to /data/users/shangdiy/pytorch/repro.py
    Wrote minimal repro out to repro.py

The resulting ``repro.py`` looks like this. The exported program now contains only the relu node. The minifier successfully reduced the graph to the op that raises the error.

.. code-block:: py

    import torch
    from torch import tensor, device
    import torch.fx as fx
    from torch._dynamo.testing import rand_strided
    from math import inf
    import torch._inductor.inductor_prims

    import torch._dynamo.config
    import torch._inductor.config
    import torch._functorch.config
    import torch.fx.experimental._config

    torch._inductor.config.generate_intermediate_hooks = True
    torch._inductor.config.triton.inject_relu_bug_TESTING_ONLY = 'compile_error'
    torch._inductor.config.aot_inductor.dump_aoti_minifier = True

    isolate_fails_code_str = None

    # torch version: 2.6.0a0+gitcd9c6e9
    # torch cuda version: 12.0
    # torch git version: cd9c6e9408dd79175712223895eed36dbdc84f84

    # CUDA Info:
    # nvcc: NVIDIA (R) Cuda compiler driver
    # Copyright (c) 2005-2023 NVIDIA Corporation
    # Built on Fri_Jan__6_16:45:21_PST_2023
    # Cuda compilation tools, release 12.0, V12.0.140
    # Build cuda_12.0.r12.0/compiler.32267302_0

    # GPU Hardware Info:
    # NVIDIA PG509-210 : 8

    from torch.nn import *
    class Repro(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()

        def forward(self, linear):
            relu = torch.ops.aten.relu.default(linear); linear = None
            return (relu,)

    def load_args(reader):
        buf0 = reader.storage('a4e748c3a3d0d4a78cde43e33ad0f9dd41d96e90', 512, device=device(type='cuda', index=0))
        reader.tensor(buf0, (8, 16), is_leaf=True)  # linear
    load_args._version = 0
    mod = Repro()
    if __name__ == '__main__':
        from torch._dynamo.repro.aoti import run_repro, repro_load_args
        config_patches={}
        with torch.no_grad():
            args = repro_load_args(load_args, save_dir='/data/users/shangdiy/pytorch/torch_compile_debug/run_2024_11_06_14_19_09_678890-pid_561538/minifier/checkpoints')
            exported_program = torch.export.export(mod, args)
            run_repro(exported_program, config_patches=config_patches, accuracy=False, command='run', save_dir='/data/users/shangdiy/pytorch/torch_compile_debug/run_2024_11_06_14_19_09_678890-pid_561538/minifier/checkpoints', check_str=None)
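
Because ``repro.py`` uses ``command='run'``, executing it compiles and runs just this one-op module and hits the same injected error. For illustration, here is a hand-written equivalent using only the APIs shown above (a minimal sketch; the ``(8, 16)`` input shape is taken from ``load_args``):

.. code-block:: py

    import torch
    import torch._inductor.config

    # Re-inject the same testing-only bug as in the original script.
    torch._inductor.config.triton.inject_relu_bug_TESTING_ONLY = 'compile_error'

    class Repro(torch.nn.Module):
        def forward(self, linear):
            # The single op the minifier isolated.
            return (torch.ops.aten.relu.default(linear),)

    with torch.no_grad():
        example_inputs = (torch.randn(8, 16, device='cuda'),)
        ep = torch.export.export(Repro(), example_inputs)
        # Compiling this one-op program reproduces the SyntaxError from
        # the injected relu bug.
        torch._inductor.aoti_compile_and_package(ep, example_inputs)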

test/inductor/test_minifier.py (+72)

@@ -170,6 +170,78 @@ def inner(x):
             minifier_args=["--offload-to-disk"],
         )
 
+    # Test that compile errors in AOTInductor can be repro'd (both CPU and CUDA)
+    def _test_aoti(self, device, expected_error):
+        # NB: The program is intentionally quite simple, just enough to
+        # trigger one minification step, no more (dedicated minifier tests
+        # should exercise minifier only)
+        run_code = f"""\
+class Model(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.fc1 = torch.nn.Linear(10, 16)
+        self.relu = torch.nn.ReLU()
+        self.sigmoid = torch.nn.Sigmoid()
+
+    def forward(self, x):
+        x = self.fc1(x)
+        x = self.relu(x)
+        x = self.sigmoid(x)
+        return x
+with torch.no_grad():
+    model = Model().to("{device}")
+    example_inputs = (torch.randn(8, 10).to("{device}"),)
+    ep = torch.export.export(
+        model, example_inputs
+    )
+    torch._inductor.aoti_compile_and_package(
+        ep, example_inputs
+    )
+"""
+        return self._run_full_test(run_code, None, expected_error, isolate=True)
+
+    @unittest.skipIf(IS_JETSON, "Fails on Jetson")
+    @inductor_config.patch(
+        {
+            "cpp.inject_relu_bug_TESTING_ONLY": "compile_error",
+            "aot_inductor.dump_aoti_minifier": True,
+        }
+    )
+    def test_aoti_cpu_compile_error(self):
+        res = self._test_aoti("cpu", "CppCompileError")
+        self.assertExpectedInline(
+            res.repro_module(),
+            """\
+class Repro(torch.nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+
+    def forward(self, linear):
+        relu = torch.ops.aten.relu.default(linear); linear = None
+        return (relu,)""",
+        )
+
+    @requires_gpu
+    @inductor_config.patch(
+        {
+            "triton.inject_relu_bug_TESTING_ONLY": "compile_error",
+            "aot_inductor.dump_aoti_minifier": True,
+        }
+    )
+    def test_aoti_gpu_compile_error(self):
+        res = self._test_aoti(GPU_TYPE, "SyntaxError")
+        self.assertExpectedInline(
+            res.repro_module(),
+            """\
+class Repro(torch.nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+
+    def forward(self, linear):
+        relu = torch.ops.aten.relu.default(linear); linear = None
+        return (relu,)""",
+        )
+
 
 if __name__ == "__main__":
     from torch._dynamo.test_case import run_tests
