Skip to content

error in using torchao and torch compile on rtx 4090 #1775

Open
@zhangvia

Description

@zhangvia

Reproduction script (it loads the attached serialized_inputs.pt, which you can download):

# Reproduction script: wraps the FLUX.1-dev transformer with torch.compile and
# torchao autoquant, then calls it 50 times on inputs loaded from disk.
# FluxTransformer2DModel is imported explicitly; the script uses it below and
# would otherwise stop with a NameError before reaching the reported error.
from diffusers import DiffusionPipeline, FluxTransformer2DModel
import torch
from torchao.quantization import autoquant
from tqdm import tqdm

# NOTE(review): the path below looks like a local checkout of the model's
# "transformer" sub-directory -- confirm; when loading from the Hub the
# repo id and subfolder="transformer" are normally passed separately.
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev/transformer/", torch_dtype=torch.bfloat16
).to("cuda")
transformer.to(memory_format=torch.channels_last)
# Compile first, then hand the compiled module to autoquant.
transformer = torch.compile(transformer, mode="max-autotune", fullgraph=True)
transformer = autoquant(transformer, error_on_unseen=False)


# weights_only=True restricts unpickling to tensors and plain containers.
example_inputs = torch.load("serialized_inputs.pt", weights_only=True)
# Every loaded value is moved to the GPU before the non-tensor kwargs are added.
example_inputs = {k: v.to("cuda") for k, v in example_inputs.items()}
example_inputs.update({"joint_attention_kwargs": None, "return_dict": False})

for _ in tqdm(range(50)):
    with torch.no_grad():
        transformer(**example_inputs)

env:
torch:2.6.0+cu124
torchao:0.8.0+cu124
diffusers:0.32.0
gpu:rtx4090

error traceback:

SingleProcess AUTOTUNE benchmarking takes 2.6167 seconds and 4.9302 seconds precompiling for 11 choices
>>time: 0.199ms for <class 'torchao.quantization.autoquant.AQInt8DynamicallyQuantizedLinearWeight'> matmul, to_beat: 0.505ms
>>time: 0.246ms for <class 'torchao.quantization.autoquant.AQInt8DynamicallyQuantizedLinearWeight'>, to_beat: 2.240ms
>>time: 0.239ms for <class 'torchao.quantization.autoquant.AQInt8DynamicallyQuantizedLinearWeight'> interpolated, breakeven constant: 6.40
best_cls=<class 'torchao.quantization.autoquant.AQInt8DynamicallyQuantizedLinearWeight'>

activation_shapes: torch.Size([1, 3072]), times_seen: 1
weight_shape: torch.Size([6144, 3072]), dtype: torch.bfloat16, bias_shape: torch.Size([6144])
>>time: 0.068ms for <class 'torchao.quantization.autoquant.AQDefaultLinearWeight'>, to_beat: infms
>>time: 0.038ms for <class 'torchao.quantization.autoquant.AQInt8WeightOnlyQuantizedLinearWeight'>, to_beat: 0.068ms
>>time: 0.038ms for <class 'torchao.quantization.autoquant.AQInt8WeightOnlyQuantizedLinearWeight2'>, to_beat: 0.038ms
>>time: 0.039ms for <class 'torchao.quantization.autoquant.AQInt8DynamicallyQuantizedLinearWeight'> matmul, to_beat: 0.038ms
best_cls=<class 'torchao.quantization.autoquant.AQInt8WeightOnlyQuantizedLinearWeight2'>

activation_shapes: torch.Size([4096, 3072]), times_seen: 1
weight_shape: torch.Size([64, 3072]), dtype: torch.bfloat16, bias_shape: torch.Size([64])
AUTOTUNE addmm(4096x64, 4096x3072, 3072x64)
  triton_mm_843 0.0440 ms 100.0% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=128, BLOCK_M=64, BLOCK_N=32, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=5, num_warps=4
  triton_mm_851 0.0440 ms 100.0% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=128, BLOCK_M=64, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=4, num_warps=4
  triton_mm_840 0.0461 ms 95.6% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=128, BLOCK_M=32, BLOCK_N=32, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=2, num_warps=4
  triton_mm_856 0.0471 ms 93.5% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=64, BLOCK_M=128, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=5, num_warps=8
  bias_addmm 0.0481 ms 91.5%
  addmm 0.0502 ms 87.8%
  triton_mm_846 0.0512 ms 86.0% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=64, BLOCK_M=64, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=8
  triton_mm_849 0.0513 ms 85.9% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=4, num_warps=8
  triton_mm_850 0.0543 ms 81.1% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=64, BLOCK_M=64, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
  triton_mm_842 0.0553 ms 79.6% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=32, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=5, num_warps=8
SingleProcess AUTOTUNE benchmarking takes 4.0817 seconds and 4.6254 seconds precompiling for 20 choices
>>time: 0.048ms for <class 'torchao.quantization.autoquant.AQDefaultLinearWeight'>, to_beat: infms
AUTOTUNE mixed_mm(4096x3072, 3072x64)
  triton_mm_874 0.0439 ms 100.0% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=256, BLOCK_M=16, BLOCK_N=64, B_PROLOGUE_CAST_TYPE='tl.bfloat16', EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
  triton_mm_861 0.0440 ms 99.8% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=128, BLOCK_M=64, BLOCK_N=32, B_PROLOGUE_CAST_TYPE='tl.bfloat16', EVEN_K=True, GROUP_M=8, num_stages=5, num_warps=4
  triton_mm_864 0.0451 ms 97.5% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=128, BLOCK_M=64, BLOCK_N=64, B_PROLOGUE_CAST_TYPE='tl.bfloat16', EVEN_K=True, GROUP_M=8, num_stages=5, num_warps=4
  triton_mm_868 0.0451 ms 97.5% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=128, BLOCK_M=64, BLOCK_N=64, B_PROLOGUE_CAST_TYPE='tl.bfloat16', EVEN_K=True, GROUP_M=8, num_stages=4, num_warps=4
  triton_mm_875 0.0451 ms 97.5% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=256, BLOCK_M=16, BLOCK_N=64, B_PROLOGUE_CAST_TYPE='tl.bfloat16', EVEN_K=True, GROUP_M=8, num_stages=5, num_warps=4
  triton_mm_858 0.0492 ms 89.4% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=128, BLOCK_M=32, BLOCK_N=32, B_PROLOGUE_CAST_TYPE='tl.bfloat16', EVEN_K=True, GROUP_M=8, num_stages=2, num_warps=4
  fallback_mixed_mm 0.0512 ms 85.8%
  triton_mm_873 0.0522 ms 84.1% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=64, BLOCK_M=128, BLOCK_N=64, B_PROLOGUE_CAST_TYPE='tl.bfloat16', EVEN_K=True, GROUP_M=8, num_stages=5, num_warps=8
  triton_mm_867 0.0563 ms 78.0% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=64, BLOCK_M=64, BLOCK_N=64, B_PROLOGUE_CAST_TYPE='tl.bfloat16', EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
  triton_mm_863 0.0584 ms 75.3% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=64, BLOCK_M=64, BLOCK_N=64, B_PROLOGUE_CAST_TYPE='tl.bfloat16', EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=8
SingleProcess AUTOTUNE benchmarking takes 4.4230 seconds and 6.0413 seconds precompiling for 20 choices
>>time: 0.048ms for <class 'torchao.quantization.autoquant.AQInt8WeightOnlyQuantizedLinearWeight'>, to_beat: 0.048ms
AUTOTUNE int_mm(4096x3072, 3072x64, 4096x64)
  triton_mm_885 0.0287 ms 100.0% ACC_TYPE='tl.int32', ALLOW_TF32=True, BLOCK_K=128, BLOCK_M=128, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=8
  triton_mm_886 0.0338 ms 84.8% ACC_TYPE='tl.int32', ALLOW_TF32=True, BLOCK_K=128, BLOCK_M=256, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=8
  triton_mm_884 0.0440 ms 65.1% ACC_TYPE='tl.int32', ALLOW_TF32=True, BLOCK_K=64, BLOCK_M=64, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=8
  triton_mm_881 0.0492 ms 58.3% ACC_TYPE='tl.int32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=32, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=5, num_warps=8
  triton_mm_880 0.0532 ms 53.8% ACC_TYPE='tl.int32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=4, num_warps=8
  triton_mm_882 0.0543 ms 52.8% ACC_TYPE='tl.int32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=32, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=5, num_warps=8
  triton_mm_879 0.0604 ms 47.5% ACC_TYPE='tl.int32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=4, num_warps=8
  triton_mm_878 0.0666 ms 43.1% ACC_TYPE='tl.int32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
  triton_mm_877 0.0788 ms 36.4% ACC_TYPE='tl.int32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
  triton_mm_876 0.0963 ms 29.8% ACC_TYPE='tl.int32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=2, num_warps=4
SingleProcess AUTOTUNE benchmarking takes 2.1406 seconds and 3.6622 seconds precompiling for 11 choices
>>time: 0.033ms for <class 'torchao.quantization.autoquant.AQInt8DynamicallyQuantizedLinearWeight'> matmul, to_beat: 0.048ms
>>time: 0.074ms for <class 'torchao.quantization.autoquant.AQInt8DynamicallyQuantizedLinearWeight'>, to_beat: 0.135ms
>>time: 0.068ms for <class 'torchao.quantization.autoquant.AQInt8DynamicallyQuantizedLinearWeight'> interpolated, breakeven constant: 0.38
best_cls=<class 'torchao.quantization.autoquant.AQDefaultLinearWeight'>

  0%|                                                                                                                                                            | 0/50 [09:59<?, ?it/s]
Traceback (most recent call last):
  File "/media/242hdd/research/arc_repose/quant.py", line 62, in <module>
    transformer(**example_inputs)
  File "/media/74nvme/software/miniconda3/envs/repose/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/media/74nvme/software/miniconda3/envs/repose/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1845, in _call_impl
    return inner()
  File "/media/74nvme/software/miniconda3/envs/repose/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1793, in inner
    result = forward_call(*args, **kwargs)
  File "/media/74nvme/software/miniconda3/envs/repose/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 574, in _fn
    return fn(*args, **kwargs)
  File "/media/74nvme/software/miniconda3/envs/repose/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/media/74nvme/software/miniconda3/envs/repose/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
  File "/media/74nvme/software/miniconda3/envs/repose/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 1380, in __call__
    return self._torchdynamo_orig_callable(
  File "/media/74nvme/software/miniconda3/envs/repose/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 547, in __call__
    return _compile(
  File "/media/74nvme/software/miniconda3/envs/repose/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 1036, in _compile
    raise InternalTorchDynamoError(
  File "/media/74nvme/software/miniconda3/envs/repose/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 986, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/media/74nvme/software/miniconda3/envs/repose/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 715, in compile_inner
    return _compile_inner(code, one_graph, hooks, transform)
  File "/media/74nvme/software/miniconda3/envs/repose/lib/python3.10/site-packages/torch/_utils_internal.py", line 95, in wrapper_function
    return function(*args, **kwargs)
  File "/media/74nvme/software/miniconda3/envs/repose/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 750, in _compile_inner
    out_code = transform_code_object(code, transform)
  File "/media/74nvme/software/miniconda3/envs/repose/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py", line 1361, in transform_code_object
    transformations(instructions, code_options)
  File "/media/74nvme/software/miniconda3/envs/repose/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 231, in _fn
    return fn(*args, **kwargs)
  File "/media/74nvme/software/miniconda3/envs/repose/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 662, in transform
    tracer.run()
  File "/media/74nvme/software/miniconda3/envs/repose/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2868, in run
    super().run()
  File "/media/74nvme/software/miniconda3/envs/repose/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1052, in run
    while self.step():
  File "/media/74nvme/software/miniconda3/envs/repose/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 962, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/media/74nvme/software/miniconda3/envs/repose/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 3048, in RETURN_VALUE
    self._return(inst)
  File "/media/74nvme/software/miniconda3/envs/repose/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 3033, in _return
    self.output.compile_subgraph(
  File "/media/74nvme/software/miniconda3/envs/repose/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 1136, in compile_subgraph
    self.compile_and_call_fx_graph(
  File "/media/74nvme/software/miniconda3/envs/repose/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 1329, in compile_and_call_fx_graph
    fx.GraphModule(root, self.graph),
  File "/media/74nvme/software/miniconda3/envs/repose/lib/python3.10/site-packages/torch/fx/graph_module.py", line 511, in __init__
    self.graph = graph
  File "/media/74nvme/software/miniconda3/envs/repose/lib/python3.10/site-packages/torch/nn/modules/module.py", line 2029, in __setattr__
    super().__setattr__(name, value)
  File "/media/74nvme/software/miniconda3/envs/repose/lib/python3.10/site-packages/torch/fx/graph_module.py", line 558, in graph
    self.recompile()
  File "/media/74nvme/software/miniconda3/envs/repose/lib/python3.10/site-packages/torch/fx/graph_module.py", line 808, in recompile
    cls.forward = _forward_from_src(self._code, python_code.globals, co_fields)
  File "/media/74nvme/software/miniconda3/envs/repose/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 190, in fx_forward_from_src_skip_result
    result = original_forward_from_src(src, globals, co_fields)
  File "/media/74nvme/software/miniconda3/envs/repose/lib/python3.10/site-packages/torch/fx/graph_module.py", line 92, in _forward_from_src
    return _method_from_src(
  File "/media/74nvme/software/miniconda3/envs/repose/lib/python3.10/site-packages/torch/fx/graph_module.py", line 102, in _method_from_src
    _exec_with_source(src, globals_copy, co_fields)
  File "/media/74nvme/software/miniconda3/envs/repose/lib/python3.10/site-packages/torch/fx/graph_module.py", line 88, in _exec_with_source
    exec(compile(src, key, "exec"), globals)
torch._dynamo.exc.InternalTorchDynamoError: SystemError: excessive stack use: stack is 7008 deep

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True

I encounter the above error on an RTX 4090, but, oddly, the same code runs successfully on an NVIDIA L20.

Activity

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Metadata

Metadata

Assignees

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions