
[Performance] Issue on *SanaLinearAttnProcessor2_0 family. 1.06X speedup can be reached with a simple change. #11499


Description


Sys env:

OS: Ubuntu 22.04
PyTorch: 2.4.0+cu121
sana == 0.0.1
diffusers == 0.34.0.dev0

Reproduce:

Run the following demo code:

import torch
from diffusers import SanaPAGPipeline

pipe = SanaPAGPipeline.from_pretrained(
    # "Efficient-Large-Model/Sana_1600M_512px_diffusers",
    "Efficient-Large-Model/SANA1.5_1.6B_1024px_diffusers",
    torch_dtype=torch.bfloat16,
    pag_applied_layers="transformer_blocks.8",
)
pipe.to("cuda")

pipe.text_encoder.to(torch.bfloat16)
pipe.vae.to(torch.bfloat16)

prompt = 'a cyberpunk cat with a neon sign that says "Sana"'
image = pipe(
    prompt=prompt,
    guidance_scale=5.0,
    pag_scale=2.0,
    num_inference_steps=20,
    generator=torch.Generator(device="cuda").manual_seed(42),
)[0]
image[0].save('sana.png')

During inference, the data flows through SanaLinearAttnProcessor2_0.

Issue Description:

Lines 6042 and 6043 first transpose a contiguous tensor and then cast its dtype. The cast copies the data from the old-dtype tensor into a new one, but if you check the new tensor's contiguity (or print its stride()), you will see:

        hidden_states = hidden_states.flatten(1, 2).transpose(1, 2)
        hidden_states = hidden_states.to(original_dtype)
        print("Contiguity after type casting: ", hidden_states.is_contiguous()) # False

        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)

The problem is that the casting copy only converts the dtype while keeping the input tensor's strides, so the result is still non-contiguous. That badly-strided tensor is then consumed immediately by the two to_out layers, and the inefficiency propagates.
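A quick standalone check (not from the Diffusers source; the shapes are made up for illustration) confirms that .to(dtype) allocates new storage but keeps the transposed strides:

import torch

x = torch.randn(2, 4, 8, dtype=torch.float32)  # contiguous, strides (32, 8, 1)
t = x.transpose(1, 2)                          # a view, strides (32, 1, 8)
y = t.to(torch.bfloat16)                       # new storage, same bad strides
print(t.is_contiguous(), y.is_contiguous())    # False False
print(y.stride())                              # (32, 1, 8), not the contiguous (32, 4, 1)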

How to Fix:

Let hidden_states.to(original_dtype) make the tensor contiguous and cast the dtype at the same time.
One possible approach:

@torch.compile
def transpose_cast_kernel(input_tensor: torch.Tensor) -> torch.Tensor:
    """
    torch-compiled kernel that casts a 3D tensor to bfloat16 (the demo's
    original_dtype) and transposes dims 1 and 2 into a contiguous layout
    """
    converted = input_tensor.to(torch.bfloat16)
    transposed = torch.transpose(converted, 1, 2).contiguous()
    return transposed

Use this fused operation to create the new tensor; torch.compile can combine the cast, the transpose, and the contiguous copy into a single kernel, so the data is read and written only once:

        hidden_states = hidden_states.flatten(1, 2)
        hidden_states = transpose_cast_kernel(hidden_states)
        # hidden_states.is_contiguous() -> True

        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)

Or, your expert team could do even better.
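For instance, a plain-PyTorch sketch (assuming the memory_format keyword of Tensor.to; not benchmarked here) can request a contiguous output directly in the cast, fusing the layout change and the dtype conversion into one copy:

        hidden_states = hidden_states.flatten(1, 2).transpose(1, 2)
        # memory_format=torch.contiguous_format makes the cast write into a
        # contiguous buffer, so no separate .contiguous() pass is needed
        hidden_states = hidden_states.to(original_dtype, memory_format=torch.contiguous_format)
        # hidden_states.is_contiguous() -> True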

Measurement:

By adopting the change above, SanaLinearAttnProcessor2_0.__call__ achieves a 1.06X speedup on an RTX 3090.
PAGCFGSanaLinearAttnProcessor2_0 and PAGIdentitySanaLinearAttnProcessor2_0 share the same logic and lose performance in the same way.
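For anyone reproducing the numbers, a minimal timing harness along these lines (the shapes and the to_out stand-in are assumptions, not the author's benchmark) compares the two paths:

import torch

B, C, N = 2, 2240, 4096  # hypothetical batch / channels / tokens

@torch.compile
def transpose_cast_kernel(input_tensor: torch.Tensor) -> torch.Tensor:
    converted = input_tensor.to(torch.bfloat16)
    return torch.transpose(converted, 1, 2).contiguous()

to_out = torch.nn.Linear(C, C, device="cuda", dtype=torch.bfloat16)  # stand-in for attn.to_out[0]

def bench(fn, x, iters=100):
    for _ in range(10):  # warmup; also triggers compilation
        fn(x)
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn(x)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters  # ms per call

x = torch.randn(B, C, N, device="cuda", dtype=torch.float32)
baseline = lambda t: to_out(t.transpose(1, 2).to(torch.bfloat16))  # non-contiguous input to the linear
fused = lambda t: to_out(transpose_cast_kernel(t))                 # contiguous input to the linear
print("baseline:", bench(baseline, x), "ms")
print("fused:   ", bench(fused, x), "ms")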
