🐛 [Bug] Decomposing attention leads to shape errors (due to view op) in FLUX model #3333
Bug Description
After merging PR #3296, I see the following error:
ValueError: Cannot view a tensor with shape torch.Size([s6, s2 + 4096, 24, 128]) and strides (3072*s2 + 12582912, 128, 128*s2 + 524288, 1) as a tensor with shape (s1, (s6*(s2 + 4096)//s1), 3072)!
While executing %view_52 : [num_users=1] = call_function[target=torch.ops.aten.view.default](args = (%transpose_10, [%sym_size_int_63, -1, 3072]), kwargs = {})
Original traceback:
  File "/work/TensorRT/examples/dynamo/run_2.py", line 48, in forward
    return self.module.forward(
  File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/diffusers/models/transformers/transformer_flux.py", line 438, in forward
    hidden_states = block(
  File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/diffusers/models/transformers/transformer_flux.py", line 119, in forward
    attn_output = self.attn(
  File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/diffusers/models/attention_processor.py", line 490, in forward
    return self.processor(
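For reference, here is a minimal eager-mode sketch of the same class of failure (my own illustration, not taken from the trace; the 24 heads / 128 head_dim match FLUX, the sequence length is made up). Once the decomposed attention transposes the heads and sequence dimensions, the output is no longer contiguous, so flattening it with aten.view fails because the requested shape cannot be expressed over the existing strides, while reshape (which may copy) succeeds. In eager mode this surfaces as a RuntimeError rather than the ValueError raised under fake tensors:

import torch

attn = torch.randn(2, 24, 4352, 128)   # (batch, heads, seq, head_dim), contiguous
out = attn.transpose(1, 2)              # (batch, seq, heads, head_dim), non-contiguous

try:
    out.view(2, -1, 24 * 128)           # mirrors the failing %view_52 node
except RuntimeError as e:
    print("view failed:", e)

print(out.reshape(2, -1, 24 * 128).shape)  # torch.Size([2, 4352, 3072]); reshape copies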
To Reproduce
Here's the full script:
# %%
# Imports and Model Definition
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
import torch
import torch_tensorrt
from transformers import AutoModelForCausalLM, AutoTokenizer
from diffusers import FluxPipeline, FluxTransformer2DModel
from utils import export_llm, generate
from torch.export import Dim
from typing import Optional, Dict, Any
import logging
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
handler = logging.StreamHandler()
handler.setLevel(logging.DEBUG)
logger.addHandler(handler)
import time
from contextlib import contextmanager
@contextmanager
def timer(logger, name: str):
    logger.info(f"{name} section Start...")
    start = time.time()
    yield
    end = time.time()
    logger.info(f"{name} section End...")
    logger.info(f"{name} section elapsed time: {end - start} seconds")
class MyModule(torch.nn.Module):
    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor = None,
        pooled_projections: torch.Tensor = None,
        timestep: torch.LongTensor = None,
        img_ids: torch.Tensor = None,
        txt_ids: torch.Tensor = None,
        guidance: torch.Tensor = None,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
        return_dict: bool = False,
        **kwargs,
    ):
        return self.module.forward(
            hidden_states,
            encoder_hidden_states,
            pooled_projections,
            timestep,
            img_ids,
            txt_ids,
        )
def wrap_pipeline_transformer_call(instance, prompt, max_sequence_length):
    from unittest.mock import patch

    # `instance` is the pipeline object that owns the transformer.
    # Use patch.object to wrap the transformer's forward so its inputs can be intercepted.
    with patch.object(
        instance.transformer, "forward", wraps=instance.transformer.forward
    ) as mock_transformer_call:
        # one denoising step is enough to intercept the inputs
        image = instance(
            prompt,
            guidance_scale=0.0,
            num_inference_steps=1,
            max_sequence_length=max_sequence_length,
            generator=torch.Generator("cpu").manual_seed(0),
        ).images[0]

        # Access the call arguments of the first (or a specific) call
        if mock_transformer_call.call_args_list:
            args, kwargs = mock_transformer_call.call_args_list[0]
            # Store the intercepted inputs in a tuple
            intercepted_inputs = (args, kwargs)
            # print("Intercepted args:", args)
            # print("Intercepted kwargs:", kwargs)
            return (args, kwargs)
        else:
            print("No calls were made to self.transformer.forward")
            return (None, None)
if __name__ == "__main__":
    # config
    dryrun = False

    # parameter setting
    batch_size = 2
    max_seq_len = 256
    prompt = ["A cat holding a sign that says hello world" for _ in range(batch_size)]
    cuda_device = "cuda:0"
    device = "cuda:0"

    with torch.no_grad():
        pipe = FluxPipeline.from_pretrained(
            "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.float16
        )
        pipe.to(device)

        example_inputs = (
            torch.randn((batch_size, 4096, 64), dtype=torch.float16).to(device),
            torch.randn((batch_size, 256, 4096), dtype=torch.float16).to(device),
            torch.randn((batch_size, 768), dtype=torch.float16).to(device),
            torch.tensor([1.0, 1.0], dtype=torch.float16).to(device),
            torch.randn((batch_size, 4096, 3), dtype=torch.float16).to(device),
            torch.randn((batch_size, 256, 3), dtype=torch.float16).to(device),
        )

        BATCH = Dim("batch", min=1, max=batch_size)
        SEQ_LEN = Dim("seq_len", min=1, max=max_seq_len)
        dynamic_shapes = (
            {0: BATCH},
            {0: BATCH, 1: SEQ_LEN},
            {0: BATCH},
            {0: BATCH},
            {0: BATCH},
            {0: BATCH, 1: SEQ_LEN},
        )

        free, total = torch.cuda.mem_get_info(cuda_device)
        print(f"1 Free mem: {free}, Total mem: {total}")
        # breakpoint()

        with timer(logger=logger, name="ep_gen"):
            model = MyModule(pipe.transformer).eval().half()  # .to(device)
            logger.info("Directly use _export because torch.export.export doesn't work")
            # This API is used to express the constraint violation guards as asserts in the graph.
            from torch.export._trace import _export

            ep = _export(
                model,
                args=example_inputs,
                dynamic_shapes=dynamic_shapes,
                strict=False,
                allow_complex_guards_as_runtime_asserts=True,
            )

        free, total = torch.cuda.mem_get_info(cuda_device)
        print(f"2 Free mem: {free}, Total mem: {total}")
        # breakpoint()

        logger.info(f"Generating TRT engine now, dryrun={dryrun}...")
        # print("Generating TRT engine now...")
        # TODO: if some inputs are non-tensor, do we still need to provide them?
        with timer(logger, "trt_gen"):
            with torch_tensorrt.logging.debug():
                trt_start = time.time()
                trt_model = torch_tensorrt.dynamo.compile(
                    ep,
                    inputs=list(example_inputs),
                    enabled_precisions={torch.float32},
                    truncate_double=True,
                    device=torch.device(cuda_device),
                    disable_tf32=True,
                    use_explicit_typing=True,
                    dryrun=dryrun,
                    debug=True,
                    use_fp32_acc=True,
                )
                trt_end = time.time()

        free, total = torch.cuda.mem_get_info(cuda_device)
        print(f"3 Free mem: {free}, Total mem: {total}")
        breakpoint()

        del pipe
        del ep
        del model
        free, total = torch.cuda.mem_get_info(cuda_device)
        print(f"4 Free mem: {free}, Total mem: {total}")
        breakpoint()

        import gc

        gc.collect()
        torch.cuda.empty_cache()

        example_inputs_cuda = [input.cuda() for input in example_inputs]
        with timer(logger, "trt_save"):
            try:
                breakpoint()
                trt_ep = torch.export.export(
                    trt_model,
                    args=example_inputs_cuda,
                    dynamic_shapes=dynamic_shapes,
                    strict=False,
                )
                torch.export.save(trt_ep, "trt.ep")
            except Exception as e:
                import traceback

                # Capture the full traceback
                tb = traceback.format_exc()
                logger.warning("An error occurred. Here's the traceback:")
                # print(tb)
                logger.warning(tb)
                breakpoint()
                torch_tensorrt.save(trt_model, "trt.ep")
Expected behavior
Environment
Build information about Torch-TensorRT can be found by turning on debug messages
- Torch-TensorRT Version (e.g. 1.0.0):
- PyTorch Version (e.g. 1.0):
- CPU Architecture:
- OS (e.g., Linux):
- How you installed PyTorch (conda, pip, libtorch, source):
- Build command you used (if compiling from source):
- Are you using local sources or building from archives:
- Python version:
- CUDA version:
- GPU models and configuration:
- Any other relevant information: