
🐛 [Bug] Decomposing attention leads to shape errors (due to view op) in FLUX model #3333

@peri044

Bug Description

After merging PR #3296, I see the following error:

ValueError: Cannot view a tensor with shape torch.Size([s6, s2 + 4096, 24, 128]) and strides (3072*s2 + 12582912, 128, 128*s2 + 524288, 1) as a tensor with shape (s1, (s6*(s2 + 4096)//s1), 3072)!

While executing %view_52 : [num_users=1] = call_function[target=torch.ops.aten.view.default](args = (%transpose_10, [%sym_size_int_63, -1, 3072]), kwargs = {})
Original traceback:
File "/work/TensorRT/examples/dynamo/run_2.py", line 48, in forward
    return self.module.forward(
  File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/diffusers/models/transformers/transformer_flux.py", line 438, in forward
    hidden_states = block(
  File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/diffusers/models/transformers/transformer_flux.py", line 119, in forward
    attn_output = self.attn(
  File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/diffusers/models/attention_processor.py", line 490, in forward
    return self.processor(
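
For context, this looks like the standard aten.view restriction on non-contiguous tensors: the strides in the message indicate that %transpose_10 is a strided (transposed) view, and view cannot merge its trailing axes back into (batch, seq, heads * head_dim) without a copy. A minimal eager-mode sketch with concrete stand-ins for the symbolic sizes (s6 -> 2, s2 -> 0), purely for illustration:

import torch

# (batch, heads, seq, head_dim), contiguous
x = torch.randn(2, 24, 4096, 128)

# transpose() swaps strides without copying, giving a non-contiguous view
# shaped (batch, seq, heads, head_dim) with strides like those in the error.
t = x.transpose(1, 2)
print(t.shape, t.stride())  # (2, 4096, 24, 128), (12582912, 128, 524288, 1)

try:
    # view() cannot express this layout as (batch, seq, heads * head_dim)
    # because the heads/head_dim axes no longer form one contiguous block.
    t.view(2, -1, 24 * 128)
except RuntimeError as e:
    print("view failed:", e)

# reshape() (or .contiguous().view()) copies when necessary and succeeds.
y = t.reshape(2, -1, 24 * 128)
print(y.shape)  # torch.Size([2, 4096, 3072])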

To Reproduce

Here's the full script:

# %%
# Imports and Model Definition
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
import torch
import torch_tensorrt
from diffusers import FluxPipeline
from torch.export import Dim
from typing import Optional, Dict, Any
import logging
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
handler = logging.StreamHandler()
handler.setLevel(logging.DEBUG)
logger.addHandler(handler)

import time
from contextlib import contextmanager

@contextmanager
def timer(logger, name: str):
    logger.info(f"{name} section Start...")
    start = time.time()
    try:
        yield
    finally:
        end = time.time()
        logger.info(f"{name} section End...")
        logger.info(f"{name} section elapsed time: {end - start} seconds")

class MyModule(torch.nn.Module):
    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self,
                hidden_states: torch.Tensor,
                encoder_hidden_states: Optional[torch.Tensor] = None,
                pooled_projections: Optional[torch.Tensor] = None,
                timestep: Optional[torch.LongTensor] = None,
                img_ids: Optional[torch.Tensor] = None,
                txt_ids: Optional[torch.Tensor] = None,
                guidance: Optional[torch.Tensor] = None,
                joint_attention_kwargs: Optional[Dict[str, Any]] = None,
                return_dict: bool = False, **kwargs):

        return self.module.forward(
            hidden_states,
            encoder_hidden_states,
            pooled_projections,
            timestep,
            img_ids,
            txt_ids,
        )

def wrap_pipeline_transformer_call(instance, prompt, max_sequence_length):
    from unittest.mock import patch

    # Patch the transformer's forward method (wrapping the real one) so we can
    # intercept the inputs that the pipeline passes to it.
    with patch.object(instance.transformer, 'forward', wraps=instance.transformer.forward) as mock_transformer_call:
        # A single inference step is enough to intercept the inputs.
        image = instance(
            prompt,
            guidance_scale=0.0,
            num_inference_steps=1,
            max_sequence_length=max_sequence_length,
            generator=torch.Generator("cpu").manual_seed(0)
        ).images[0]

        # Return the call arguments of the first call, if any.
        if mock_transformer_call.call_args_list:
            args, kwargs = mock_transformer_call.call_args_list[0]
            return (args, kwargs)
        else:
            print("No calls were made to instance.transformer.forward")
            return (None, None)


if __name__ == "__main__":

    # config
    dryrun = False

    # parameter setting
    batch_size = 2
    max_seq_len = 256
    prompt = ["A cat holding a sign that says hello world" for _ in range(batch_size)]
    cuda_device = "cuda:0"
    device="cuda:0"
    with torch.no_grad():
        pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", 
                                            torch_dtype=torch.float16)
        pipe.to(device)
        
        example_inputs = (torch.randn((batch_size, 4096, 64), dtype=torch.float16).to(device),
                  torch.randn((batch_size, 256, 4096), dtype=torch.float16).to(device),
                  torch.randn((batch_size, 768), dtype=torch.float16).to(device),
                  torch.tensor([1., 1.], dtype=torch.float16).to(device),
                  torch.randn((batch_size, 4096, 3), dtype=torch.float16).to(device),
                  torch.randn((batch_size, 256, 3), dtype=torch.float16).to(device),
        )
        BATCH = Dim("batch", min=1, max=batch_size)
        SEQ_LEN = Dim("seq_len", min=1, max=max_seq_len)
        dynamic_shapes = ({0 : BATCH}, 
                        {0 : BATCH, 1 : SEQ_LEN},
                        {0 : BATCH},
                        {0 : BATCH},
                        {0 : BATCH},
                        {0 : BATCH, 1 : SEQ_LEN},
                        )
        free, total = torch.cuda.mem_get_info(cuda_device)
        print(f"1 Free mem: {free}, Total mem: {total}")
        # breakpoint()
        with timer(logger=logger, name="ep_gen"):
                model = MyModule(pipe.transformer).eval().half()#.to(device)
                logger.info("Directly use _export because torch.export.export doesn't work")
                # This API is used to express the constraint violation guards as asserts in the graph.
                from torch.export._trace import _export
                ep = _export(
                    model,
                    args=example_inputs, 
                    dynamic_shapes=dynamic_shapes,
                    strict=False,
                    allow_complex_guards_as_runtime_asserts=True,
                )
        free, total = torch.cuda.mem_get_info(cuda_device)
        print(f"2 Free mem: {free}, Total mem: {total}")
        # breakpoint()
        logger.info(f"Generating TRT engine now, dryrun={dryrun}...")
        # print("Generating TRT engine now...")
        # TODO: if some inputs are non-tensors, do we still need to provide them?
        with timer(logger, "trt_gen"):
            with torch_tensorrt.logging.debug():
                trt_start = time.time()
                trt_model = torch_tensorrt.dynamo.compile(
                                ep,
                                inputs=list(example_inputs),
                                enabled_precisions={torch.float32},
                                truncate_double=True,
                                device=torch.device(cuda_device),
                                disable_tf32=True,
                                use_explicit_typing=True,
                                dryrun=dryrun,
                                debug=True,
                                use_fp32_acc=True,
                            )
                trt_end = time.time()
        free, total = torch.cuda.mem_get_info(cuda_device)
        print(f"3 Free mem: {free}, Total mem: {total}")
        breakpoint()
        del pipe
        del ep
        del model

        free, total = torch.cuda.mem_get_info(cuda_device)
        print(f"4 Free mem: {free}, Total mem: {total}")
        breakpoint()
        import gc
        gc.collect()
        torch.cuda.empty_cache()
        example_inputs_cuda = [inp.cuda() for inp in example_inputs]
        with timer(logger, "trt_save"):
            try:
                breakpoint()
                trt_ep = torch.export.export(trt_model, args=example_inputs_cuda,
                                    dynamic_shapes=dynamic_shapes, strict=False)
                torch.export.save(trt_ep, "trt.ep")
            except Exception as e:
                import traceback
                # Capture the full traceback
                tb = traceback.format_exc()
                logger.warning("An error occurred. Here's the traceback:")
                # print(tb)
                logger.warning(tb)
                breakpoint()
                torch_tensorrt.save(trt_model, "trt.ep")
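
For triage, the view nodes can be listed directly from the exported program before handing it to the TensorRT compiler (a rough sketch; it assumes the `ep` produced above and only inspects the pre-lowering graph):

# Assumes `ep` is the ExportedProgram created by _export(...) above.
for node in ep.graph.nodes:
    if node.op == "call_function" and node.target == torch.ops.aten.view.default:
        # Prints each aten.view call with its args (e.g. the %view_52 reported
        # above, if it is already present before lowering).
        print(node.format_node())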

Expected behavior

Environment

Build information about Torch-TensorRT can be found by turning on debug messages

  • Torch-TensorRT Version (e.g. 1.0.0):
  • PyTorch Version (e.g. 1.0):
  • CPU Architecture:
  • OS (e.g., Linux):
  • How you installed PyTorch (conda, pip, libtorch, source):
  • Build command you used (if compiling from source):
  • Are you using local sources or building from archives:
  • Python version:
  • CUDA version:
  • GPU models and configuration:
  • Any other relevant information:

Additional context
