AOTInductor Minifier
====================

If you encounter an error while using AOT Inductor APIs such as
``torch._inductor.aoti_compile_and_package``, ``torch._inductor.aoti_load_package``,
or while running the model loaded by ``aoti_load_package`` on some inputs, you can use the AOTInductor Minifier
to create a minimal ``nn.Module`` that reproduces the error. To enable it, set ``from torch._inductor import config; config.aot_inductor.dump_aoti_minifier = True``.


At a high level, there are two steps in using the minifier:

- Set ``from torch._inductor import config; config.aot_inductor.dump_aoti_minifier = True``, or set the environment variable ``DUMP_AOTI_MINIFIER=1`` (a minimal sketch follows this list). Running the script that errors will then produce a ``minifier_launcher.py`` script. The output directory is configurable by setting ``torch._dynamo.config.base_dir`` to a valid directory name.

- Run the ``minifier_launcher.py`` script. If the minifier runs successfully, it generates runnable Python code in ``repro.py`` which reproduces the exact error.
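
For the first step, here is a minimal sketch of enabling the minifier; the output directory below is a made-up example.

.. code-block:: py

    import torch
    import torch._dynamo.config
    from torch._inductor import config

    # Enable the minifier dump on error; setting the environment
    # variable DUMP_AOTI_MINIFIER=1 is equivalent.
    config.aot_inductor.dump_aoti_minifier = True

    # Optional: redirect where minifier_launcher.py is written.
    # "/tmp/minifier_out" is only an example directory.
    torch._dynamo.config.base_dir = "/tmp/minifier_out"

    # ... then run the AOT Inductor workflow that errors, e.g. a script
    # calling torch._inductor.aoti_compile_and_package, and look for the
    # "Writing minified repro to" line in the logs.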

Here is sample code that generates an error because we injected an error on relu with
``torch._inductor.config.triton.inject_relu_bug_TESTING_ONLY = "compile_error"``.


.. code-block:: py

    import torch
    from torch._inductor import config as inductor_config

    class Model(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.fc1 = torch.nn.Linear(10, 16)
            self.relu = torch.nn.ReLU()
            self.sigmoid = torch.nn.Sigmoid()

        def forward(self, x):
            x = self.fc1(x)
            x = self.relu(x)
            x = self.sigmoid(x)
            return x


    inductor_config.aot_inductor.dump_aoti_minifier = True
    torch._inductor.config.triton.inject_relu_bug_TESTING_ONLY = "compile_error"

    with torch.no_grad():
        model = Model().to("cuda")
        example_inputs = (torch.randn(8, 10).to("cuda"),)
        ep = torch.export.export(model, example_inputs)
        package_path = torch._inductor.aoti_compile_and_package(ep, example_inputs)
        compiled_model = torch._inductor.aoti_load_package(package_path)
        result = compiled_model(*example_inputs)


The code above generates the following error:

::

    RuntimeError: Failed to import /tmp/torchinductor_shangdiy/fr/cfrlf4smkwe4lub4i4cahkrb3qiczhf7hliqqwpewbw3aplj5g3s.py
    SyntaxError: invalid syntax (cfrlf4smkwe4lub4i4cahkrb3qiczhf7hliqqwpewbw3aplj5g3s.py, line 29)

This is because we injected an error on relu, so the generated Triton kernel looks like the following. Note that we have ``compile error!``
instead of ``relu``, so we get a ``SyntaxError``.

.. code-block::

    @triton.jit
    def triton_poi_fused_addmm_relu_sigmoid_0(in_out_ptr0, in_ptr0, xnumel, XBLOCK : tl.constexpr):
        xnumel = 128
        xoffset = tl.program_id(0) * XBLOCK
        xindex = xoffset + tl.arange(0, XBLOCK)[:]
        xmask = xindex < xnumel
        x2 = xindex
        x0 = xindex % 16
        tmp0 = tl.load(in_out_ptr0 + (x2), xmask)
        tmp1 = tl.load(in_ptr0 + (x0), xmask, eviction_policy='evict_last')
        tmp2 = tmp0 + tmp1
        tmp3 = compile error!
        tmp4 = tl.sigmoid(tmp3)
        tl.store(in_out_ptr0 + (x2), tmp4, xmask)


Since we have ``torch._inductor.config.aot_inductor.dump_aoti_minifier=True``, we also see an additional line indicating where ``minifier_launcher.py``
has been written. The output directory is configurable by setting
``torch._dynamo.config.base_dir`` to a valid directory name.

::

    W1031 16:21:08.612000 2861654 pytorch/torch/_dynamo/debug_utils.py:279] Writing minified repro to:
    W1031 16:21:08.612000 2861654 pytorch/torch/_dynamo/debug_utils.py:279] /data/users/shangdiy/pytorch/torch_compile_debug/run_2024_10_31_16_21_08_602433-pid_2861654/minifier/minifier_launcher.py


The ``minifier_launcher.py`` file has the following code. The ``exported_program`` contains the inputs to ``torch._inductor.aoti_compile_and_package``.
The ``command='minify'`` parameter means the script will run the minifier to create a minimal graph module that reproduces the error. Alternatively, you can
use ``command='run'`` to just compile, load, and run the loaded model without running the minifier (see the sketch after the listing below).

.. code-block:: py

    import torch
    import torch._inductor.inductor_prims

    import torch._dynamo.config
    import torch._inductor.config
    import torch._functorch.config
    import torch.fx.experimental._config

    torch._inductor.config.triton.inject_relu_bug_TESTING_ONLY = 'compile_error'
    torch._inductor.config.aot_inductor.dump_aoti_minifier = True

    isolate_fails_code_str = None

    # torch version: 2.6.0a0+gitcd9c6e9
    # torch cuda version: 12.0
    # torch git version: cd9c6e9408dd79175712223895eed36dbdc84f84

    # CUDA Info:
    # nvcc: NVIDIA (R) Cuda compiler driver
    # Copyright (c) 2005-2023 NVIDIA Corporation
    # Built on Fri_Jan__6_16:45:21_PST_2023
    # Cuda compilation tools, release 12.0, V12.0.140
    # Build cuda_12.0.r12.0/compiler.32267302_0

    # GPU Hardware Info:
    # NVIDIA PG509-210 : 8

    exported_program = torch.export.load('/data/users/shangdiy/pytorch/torch_compile_debug/run_2024_11_06_13_52_35_711642-pid_3567062/minifier/checkpoints/exported_program.pt2')
    # print(exported_program.graph)
    config_patches={}
    if __name__ == '__main__':
        from torch._dynamo.repro.aoti import run_repro
        with torch.no_grad():
            run_repro(exported_program, config_patches=config_patches, accuracy=False, command='minify', save_dir='/data/users/shangdiy/pytorch/torch_compile_debug/run_2024_11_06_13_52_35_711642-pid_3567062/minifier/checkpoints', check_str=None)

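If you only want to compile, load, and run the exported program without minifying, you can switch the ``command`` in the generated call. Here is a sketch of the edited call, keeping everything else the launcher generated:

.. code-block:: py

    # Same run_repro call as in minifier_launcher.py, but with
    # command='run': compile, load, and run the exported program
    # without running the minifier.
    run_repro(
        exported_program,
        config_patches=config_patches,
        accuracy=False,
        command='run',
        save_dir='...',  # keep the save_dir generated by the launcher
        check_str=None,
    )
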
If we keep the ``command='minify'`` option and run the script, we get the following output:

::

    ...
    W1031 16:48:08.938000 3598491 torch/_dynamo/repro/aoti.py:89] Writing checkpoint with 3 nodes to /data/users/shangdiy/pytorch/torch_compile_debug/run_2024_10_31_16_48_02_720863-pid_3598491/minifier/checkpoints/3.py
    W1031 16:48:08.975000 3598491 torch/_dynamo/repro/aoti.py:101] Copying repro file for convenience to /data/users/shangdiy/pytorch/repro.py
    Wrote minimal repro out to repro.py


The resulting ``repro.py`` looks like this. The exported program now contains only the relu node. The minifier successfully reduced the graph to the op
that raises the error.

.. code-block:: py

    import torch
    from torch import tensor, device
    import torch.fx as fx
    from torch._dynamo.testing import rand_strided
    from math import inf
    import torch._inductor.inductor_prims

    import torch._dynamo.config
    import torch._inductor.config
    import torch._functorch.config
    import torch.fx.experimental._config

    torch._inductor.config.generate_intermediate_hooks = True
    torch._inductor.config.triton.inject_relu_bug_TESTING_ONLY = 'compile_error'
    torch._inductor.config.aot_inductor.dump_aoti_minifier = True

    isolate_fails_code_str = None

    # torch version: 2.6.0a0+gitcd9c6e9
    # torch cuda version: 12.0
    # torch git version: cd9c6e9408dd79175712223895eed36dbdc84f84

    # CUDA Info:
    # nvcc: NVIDIA (R) Cuda compiler driver
    # Copyright (c) 2005-2023 NVIDIA Corporation
    # Built on Fri_Jan__6_16:45:21_PST_2023
    # Cuda compilation tools, release 12.0, V12.0.140
    # Build cuda_12.0.r12.0/compiler.32267302_0

    # GPU Hardware Info:
    # NVIDIA PG509-210 : 8

    from torch.nn import *
    class Repro(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()

        def forward(self, linear):
            relu = torch.ops.aten.relu.default(linear); linear = None
            return (relu,)

    def load_args(reader):
        buf0 = reader.storage('a4e748c3a3d0d4a78cde43e33ad0f9dd41d96e90', 512, device=device(type='cuda', index=0))
        reader.tensor(buf0, (8, 16), is_leaf=True)  # linear
    load_args._version = 0
    mod = Repro()
    if __name__ == '__main__':
        from torch._dynamo.repro.aoti import run_repro, repro_load_args
        config_patches={}
        with torch.no_grad():
            args = repro_load_args(load_args, save_dir='/data/users/shangdiy/pytorch/torch_compile_debug/run_2024_11_06_14_19_09_678890-pid_561538/minifier/checkpoints')
            exported_program = torch.export.export(mod, args)
            run_repro(exported_program, config_patches=config_patches, accuracy=False, command='run', save_dir='/data/users/shangdiy/pytorch/torch_compile_debug/run_2024_11_06_14_19_09_678890-pid_561538/minifier/checkpoints', check_str=None)
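
Running ``python repro.py`` takes this minimal exported program through the same compile, load, and run path, so it should fail with the same injected error; after removing the injected bug (or applying a real fix), the same script can be used to verify that the error is gone.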