Reproduction script (you can download the serialized_inputs.pt file):
from diffusers import FluxTransformer2DModel
import torch
from torchao.quantization import autoquant
from tqdm import tqdm

transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="transformer", torch_dtype=torch.bfloat16
).to("cuda")
transformer.to(memory_format=torch.channels_last)
transformer = torch.compile(transformer, mode="max-autotune", fullgraph=True)
transformer = autoquant(transformer, error_on_unseen=False)

example_inputs = torch.load("serialized_inputs.pt", weights_only=True)
example_inputs = {k: v.to("cuda") for k, v in example_inputs.items()}
example_inputs.update({"joint_attention_kwargs": None, "return_dict": False})

for i in tqdm(range(50)):
    with torch.no_grad():
        transformer(**example_inputs)
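
For reference, serialized_inputs.pt is just a dict of tensors produced with torch.save. If the file is unavailable, a minimal sketch for generating a compatible one is below; the argument names follow FluxTransformer2DModel.forward in diffusers 0.32, but every shape here is an assumption inferred from the autotune log (a 4096-token latent sequence), not the exact original inputs:

# Hypothetical sketch: build dummy inputs for FluxTransformer2DModel and serialize them.
# All shapes are assumptions; adjust to your resolution and prompt length.
import torch

example_inputs = {
    "hidden_states": torch.randn(1, 4096, 64, dtype=torch.bfloat16),           # packed image latents
    "encoder_hidden_states": torch.randn(1, 512, 4096, dtype=torch.bfloat16),  # T5 text embeddings
    "pooled_projections": torch.randn(1, 768, dtype=torch.bfloat16),           # pooled CLIP embedding
    "timestep": torch.tensor([1.0], dtype=torch.bfloat16),
    "img_ids": torch.zeros(4096, 3, dtype=torch.float32),                      # latent position ids
    "txt_ids": torch.zeros(512, 3, dtype=torch.float32),                       # text position ids
    "guidance": torch.tensor([3.5], dtype=torch.bfloat16),                     # FLUX.1-dev is guidance-distilled
}
torch.save(example_inputs, "serialized_inputs.pt")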
Environment:
torch: 2.6.0+cu124
torchao: 0.8.0+cu124
diffusers: 0.32.0
GPU: RTX 4090

Error traceback:
SingleProcess AUTOTUNE benchmarking takes 2.6167 seconds and 4.9302 seconds precompiling for 11 choices
>>time: 0.199ms for <class 'torchao.quantization.autoquant.AQInt8DynamicallyQuantizedLinearWeight'> matmul, to_beat: 0.505ms
>>time: 0.246ms for <class 'torchao.quantization.autoquant.AQInt8DynamicallyQuantizedLinearWeight'>, to_beat: 2.240ms
>>time: 0.239ms for <class 'torchao.quantization.autoquant.AQInt8DynamicallyQuantizedLinearWeight'> interpolated, breakeven constant: 6.40
best_cls=<class 'torchao.quantization.autoquant.AQInt8DynamicallyQuantizedLinearWeight'>
activation_shapes: torch.Size([1, 3072]), times_seen: 1
weight_shape: torch.Size([6144, 3072]), dtype: torch.bfloat16, bias_shape: torch.Size([6144])
>>time: 0.068ms for <class 'torchao.quantization.autoquant.AQDefaultLinearWeight'>, to_beat: infms
>>time: 0.038ms for <class 'torchao.quantization.autoquant.AQInt8WeightOnlyQuantizedLinearWeight'>, to_beat: 0.068ms
>>time: 0.038ms for <class 'torchao.quantization.autoquant.AQInt8WeightOnlyQuantizedLinearWeight2'>, to_beat: 0.038ms
>>time: 0.039ms for <class 'torchao.quantization.autoquant.AQInt8DynamicallyQuantizedLinearWeight'> matmul, to_beat: 0.038ms
best_cls=<class 'torchao.quantization.autoquant.AQInt8WeightOnlyQuantizedLinearWeight2'>
activation_shapes: torch.Size([4096, 3072]), times_seen: 1
weight_shape: torch.Size([64, 3072]), dtype: torch.bfloat16, bias_shape: torch.Size([64])
AUTOTUNE addmm(4096x64, 4096x3072, 3072x64)
triton_mm_843 0.0440 ms 100.0% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=128, BLOCK_M=64, BLOCK_N=32, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=5, num_warps=4
triton_mm_851 0.0440 ms 100.0% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=128, BLOCK_M=64, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=4, num_warps=4
triton_mm_840 0.0461 ms 95.6% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=128, BLOCK_M=32, BLOCK_N=32, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=2, num_warps=4
triton_mm_856 0.0471 ms 93.5% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=64, BLOCK_M=128, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=5, num_warps=8
bias_addmm 0.0481 ms 91.5%
addmm 0.0502 ms 87.8%
triton_mm_846 0.0512 ms 86.0% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=64, BLOCK_M=64, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=8
triton_mm_849 0.0513 ms 85.9% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=4, num_warps=8
triton_mm_850 0.0543 ms 81.1% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=64, BLOCK_M=64, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
triton_mm_842 0.0553 ms 79.6% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=32, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=5, num_warps=8
SingleProcess AUTOTUNE benchmarking takes 4.0817 seconds and 4.6254 seconds precompiling for 20 choices
>>time: 0.048ms for <class 'torchao.quantization.autoquant.AQDefaultLinearWeight'>, to_beat: infms
AUTOTUNE mixed_mm(4096x3072, 3072x64)
triton_mm_874 0.0439 ms 100.0% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=256, BLOCK_M=16, BLOCK_N=64, B_PROLOGUE_CAST_TYPE='tl.bfloat16', EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
triton_mm_861 0.0440 ms 99.8% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=128, BLOCK_M=64, BLOCK_N=32, B_PROLOGUE_CAST_TYPE='tl.bfloat16', EVEN_K=True, GROUP_M=8, num_stages=5, num_warps=4
triton_mm_864 0.0451 ms 97.5% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=128, BLOCK_M=64, BLOCK_N=64, B_PROLOGUE_CAST_TYPE='tl.bfloat16', EVEN_K=True, GROUP_M=8, num_stages=5, num_warps=4
triton_mm_868 0.0451 ms 97.5% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=128, BLOCK_M=64, BLOCK_N=64, B_PROLOGUE_CAST_TYPE='tl.bfloat16', EVEN_K=True, GROUP_M=8, num_stages=4, num_warps=4
triton_mm_875 0.0451 ms 97.5% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=256, BLOCK_M=16, BLOCK_N=64, B_PROLOGUE_CAST_TYPE='tl.bfloat16', EVEN_K=True, GROUP_M=8, num_stages=5, num_warps=4
triton_mm_858 0.0492 ms 89.4% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=128, BLOCK_M=32, BLOCK_N=32, B_PROLOGUE_CAST_TYPE='tl.bfloat16', EVEN_K=True, GROUP_M=8, num_stages=2, num_warps=4
fallback_mixed_mm 0.0512 ms 85.8%
triton_mm_873 0.0522 ms 84.1% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=64, BLOCK_M=128, BLOCK_N=64, B_PROLOGUE_CAST_TYPE='tl.bfloat16', EVEN_K=True, GROUP_M=8, num_stages=5, num_warps=8
triton_mm_867 0.0563 ms 78.0% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=64, BLOCK_M=64, BLOCK_N=64, B_PROLOGUE_CAST_TYPE='tl.bfloat16', EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
triton_mm_863 0.0584 ms 75.3% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=64, BLOCK_M=64, BLOCK_N=64, B_PROLOGUE_CAST_TYPE='tl.bfloat16', EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=8
SingleProcess AUTOTUNE benchmarking takes 4.4230 seconds and 6.0413 seconds precompiling for 20 choices
>>time: 0.048ms for <class 'torchao.quantization.autoquant.AQInt8WeightOnlyQuantizedLinearWeight'>, to_beat: 0.048ms
AUTOTUNE int_mm(4096x3072, 3072x64, 4096x64)
triton_mm_885 0.0287 ms 100.0% ACC_TYPE='tl.int32', ALLOW_TF32=True, BLOCK_K=128, BLOCK_M=128, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=8
triton_mm_886 0.0338 ms 84.8% ACC_TYPE='tl.int32', ALLOW_TF32=True, BLOCK_K=128, BLOCK_M=256, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=8
triton_mm_884 0.0440 ms 65.1% ACC_TYPE='tl.int32', ALLOW_TF32=True, BLOCK_K=64, BLOCK_M=64, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=8
triton_mm_881 0.0492 ms 58.3% ACC_TYPE='tl.int32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=32, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=5, num_warps=8
triton_mm_880 0.0532 ms 53.8% ACC_TYPE='tl.int32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=4, num_warps=8
triton_mm_882 0.0543 ms 52.8% ACC_TYPE='tl.int32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=32, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=5, num_warps=8
triton_mm_879 0.0604 ms 47.5% ACC_TYPE='tl.int32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=4, num_warps=8
triton_mm_878 0.0666 ms 43.1% ACC_TYPE='tl.int32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
triton_mm_877 0.0788 ms 36.4% ACC_TYPE='tl.int32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
triton_mm_876 0.0963 ms 29.8% ACC_TYPE='tl.int32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=2, num_warps=4
SingleProcess AUTOTUNE benchmarking takes 2.1406 seconds and 3.6622 seconds precompiling for 11 choices
>>time: 0.033ms for <class 'torchao.quantization.autoquant.AQInt8DynamicallyQuantizedLinearWeight'> matmul, to_beat: 0.048ms
>>time: 0.074ms for <class 'torchao.quantization.autoquant.AQInt8DynamicallyQuantizedLinearWeight'>, to_beat: 0.135ms
>>time: 0.068ms for <class 'torchao.quantization.autoquant.AQInt8DynamicallyQuantizedLinearWeight'> interpolated, breakeven constant: 0.38
best_cls=<class 'torchao.quantization.autoquant.AQDefaultLinearWeight'>
0%| | 0/50 [09:59<?, ?it/s]
Traceback (most recent call last):
File "/media/242hdd/research/arc_repose/quant.py", line 62, in <module>
transformer(**example_inputs)
File "/media/74nvme/software/miniconda3/envs/repose/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/media/74nvme/software/miniconda3/envs/repose/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1845, in _call_impl
return inner()
File "/media/74nvme/software/miniconda3/envs/repose/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1793, in inner
result = forward_call(*args, **kwargs)
File "/media/74nvme/software/miniconda3/envs/repose/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 574, in _fn
return fn(*args, **kwargs)
File "/media/74nvme/software/miniconda3/envs/repose/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/media/74nvme/software/miniconda3/envs/repose/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
return forward_call(*args, **kwargs)
File "/media/74nvme/software/miniconda3/envs/repose/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 1380, in __call__
return self._torchdynamo_orig_callable(
File "/media/74nvme/software/miniconda3/envs/repose/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 547, in __call__
return _compile(
File "/media/74nvme/software/miniconda3/envs/repose/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 1036, in _compile
raise InternalTorchDynamoError(
File "/media/74nvme/software/miniconda3/envs/repose/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 986, in _compile
guarded_code = compile_inner(code, one_graph, hooks, transform)
File "/media/74nvme/software/miniconda3/envs/repose/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 715, in compile_inner
return _compile_inner(code, one_graph, hooks, transform)
File "/media/74nvme/software/miniconda3/envs/repose/lib/python3.10/site-packages/torch/_utils_internal.py", line 95, in wrapper_function
return function(*args, **kwargs)
File "/media/74nvme/software/miniconda3/envs/repose/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 750, in _compile_inner
out_code = transform_code_object(code, transform)
File "/media/74nvme/software/miniconda3/envs/repose/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py", line 1361, in transform_code_object
transformations(instructions, code_options)
File "/media/74nvme/software/miniconda3/envs/repose/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 231, in _fn
return fn(*args, **kwargs)
File "/media/74nvme/software/miniconda3/envs/repose/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 662, in transform
tracer.run()
File "/media/74nvme/software/miniconda3/envs/repose/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2868, in run
super().run()
File "/media/74nvme/software/miniconda3/envs/repose/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1052, in run
while self.step():
File "/media/74nvme/software/miniconda3/envs/repose/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 962, in step
self.dispatch_table[inst.opcode](self, inst)
File "/media/74nvme/software/miniconda3/envs/repose/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 3048, in RETURN_VALUE
self._return(inst)
File "/media/74nvme/software/miniconda3/envs/repose/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 3033, in _return
self.output.compile_subgraph(
File "/media/74nvme/software/miniconda3/envs/repose/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 1136, in compile_subgraph
self.compile_and_call_fx_graph(
File "/media/74nvme/software/miniconda3/envs/repose/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 1329, in compile_and_call_fx_graph
fx.GraphModule(root, self.graph),
File "/media/74nvme/software/miniconda3/envs/repose/lib/python3.10/site-packages/torch/fx/graph_module.py", line 511, in __init__
self.graph = graph
File "/media/74nvme/software/miniconda3/envs/repose/lib/python3.10/site-packages/torch/nn/modules/module.py", line 2029, in __setattr__
super().__setattr__(name, value)
File "/media/74nvme/software/miniconda3/envs/repose/lib/python3.10/site-packages/torch/fx/graph_module.py", line 558, in graph
self.recompile()
File "/media/74nvme/software/miniconda3/envs/repose/lib/python3.10/site-packages/torch/fx/graph_module.py", line 808, in recompile
cls.forward = _forward_from_src(self._code, python_code.globals, co_fields)
File "/media/74nvme/software/miniconda3/envs/repose/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 190, in fx_forward_from_src_skip_result
result = original_forward_from_src(src, globals, co_fields)
File "/media/74nvme/software/miniconda3/envs/repose/lib/python3.10/site-packages/torch/fx/graph_module.py", line 92, in _forward_from_src
return _method_from_src(
File "/media/74nvme/software/miniconda3/envs/repose/lib/python3.10/site-packages/torch/fx/graph_module.py", line 102, in _method_from_src
_exec_with_source(src, globals_copy, co_fields)
File "/media/74nvme/software/miniconda3/envs/repose/lib/python3.10/site-packages/torch/fx/graph_module.py", line 88, in _exec_with_source
exec(compile(src, key, "exec"), globals)
torch._dynamo.exc.InternalTorchDynamoError: SystemError: excessive stack use: stack is 7008 deep
Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
You can suppress this exception and fall back to eager by setting:
import torch._dynamo
torch._dynamo.config.suppress_errors = True
I encounter the above error on the RTX 4090, but, strangely, the same code runs successfully on an L20 GPU.
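
As a temporary mitigation (not a fix), two knobs may be worth trying; both are untested here. suppress_errors follows the hint printed in the traceback and falls back to eager for frames that fail to compile, and fullgraph=False lets Dynamo insert graph breaks, which keeps each generated FX forward smaller and may stay under CPython's "excessive stack use" limit:

# Untested workaround sketches for the "excessive stack use" failure.
import torch
import torch._dynamo

# Option 1: fall back to eager on Dynamo errors instead of raising
# (this is the suggestion printed at the end of the traceback).
torch._dynamo.config.suppress_errors = True

# Option 2: allow graph breaks so each compiled FX graph, and therefore the
# generated forward() source that gets exec'd, stays smaller:
# transformer = torch.compile(transformer, mode="max-autotune", fullgraph=False)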