Description
Sys env:
OS: Ubuntu 22.04
PyTorch: 2.4.0+cu121
sana == 0.0.1
diffusers == 0.34.0.dev0
Reproduce:
Try the demo test code:
import torch
from diffusers import SanaPAGPipeline

pipe = SanaPAGPipeline.from_pretrained(
    # "Efficient-Large-Model/Sana_1600M_512px_diffusers",
    "Efficient-Large-Model/SANA1.5_1.6B_1024px_diffusers",
    torch_dtype=torch.bfloat16,
    pag_applied_layers="transformer_blocks.8",
)
pipe.to("cuda")
pipe.text_encoder.to(torch.bfloat16)
pipe.vae.to(torch.bfloat16)

prompt = 'a cyberpunk cat with a neon sign that says "Sana"'
image = pipe(
    prompt=prompt,
    guidance_scale=5.0,
    pag_scale=2.0,
    num_inference_steps=20,
    generator=torch.Generator(device="cuda").manual_seed(42),
)[0]
image[0].save('sana.png')
During inference, the data passes through SanaLinearAttnProcessor2_0.
Issue Description:
Lines 6042 and 6043 first transpose a contiguous tensor and then perform a type cast. The type cast copies the data from the old-dtype tensor into a new one, but if you check the new tensor's strides, you will see that it is still non-contiguous:
hidden_states = hidden_states.flatten(1, 2).transpose(1, 2)
hidden_states = hidden_states.to(original_dtype)
print("Contiguity after type casting: ", hidden_states.is_contiguous()) # False
hidden_states = attn.to_out[0](hidden_states)
hidden_states = attn.to_out[1](hidden_states)
The problem is that the type-casting copy only converts the dtype while keeping the input tensor's strides, so the output is still non-contiguous. That badly strided tensor is then consumed immediately by the two to_out layers, so the inefficiency propagates downstream.
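A minimal standalone repro of the stride behavior, independent of diffusers (the tensor shape is an illustrative assumption):

import torch

x = torch.randn(2, 2240, 1024, dtype=torch.float32, device="cuda")  # e.g. (B, C, N)
y = x.transpose(1, 2)     # a view with permuted strides, no data copy
print(y.is_contiguous())  # False
z = y.to(torch.bfloat16)  # copies the data, but preserves y's strides
print(z.is_contiguous())  # False: the new tensor keeps the bad layout
print(z.stride())         # (2293760, 1, 1024), not the contiguous (2293760, 2240, 1)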
How to Fix:
Make hidden_states.to(original_dtype) perform the contiguous copy and the type cast in a single operation. One possible approach:
@torch.compile
def transpose_cast_kernel(input_tensor: torch.Tensor) -> torch.Tensor:
    """
    torch.compile'd helper that casts a 3D tensor to bfloat16 (the pipeline's
    original_dtype here) and transposes dims 1 and 2 into a contiguous tensor.
    """
    converted = input_tensor.to(torch.bfloat16)
    transposed = torch.transpose(converted, 1, 2).contiguous()
    return transposed
Use this fused operation to create the new tensor (note that the explicit transpose now lives inside the kernel, so it is dropped from the first line):

hidden_states = hidden_states.flatten(1, 2)
hidden_states = transpose_cast_kernel(hidden_states)
# hidden_states.is_contiguous() == True
hidden_states = attn.to_out[0](hidden_states)
hidden_states = attn.to_out[1](hidden_states)
Or, your expert team could do even better.
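For instance, a minimal sketch of an alternative (my assumption, not benchmarked here) is to use the memory_format keyword of Tensor.to so that the cast itself produces a contiguous tensor:

hidden_states = hidden_states.flatten(1, 2).transpose(1, 2)
hidden_states = hidden_states.to(original_dtype, memory_format=torch.contiguous_format)
# cast and contiguous copy happen in one pass; is_contiguous() == True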
Measurement:
With the change above, SanaLinearAttnProcessor2_0.__call__ enjoys a 1.06x speedup on an RTX 3090.
PAGCFGSanaLinearAttnProcessor2_0 and PAGIdentitySanaLinearAttnProcessor2_0 share the same logic and lose performance in the same way.
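For anyone who wants to reproduce the measurement, here is a minimal micro-benchmark sketch (the shapes, the float32-to-bfloat16 cast, and the memory_format variant are illustrative assumptions, not the exact pipeline configuration):

import time
import torch

def bench(prepare, x, proj, iters=100):
    # Warm up, then time the transpose/cast followed by the output projection,
    # mirroring what the attention processor does before attn.to_out.
    for _ in range(10):
        proj(prepare(x))
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    for _ in range(iters):
        proj(prepare(x))
    torch.cuda.synchronize()
    return (time.perf_counter() - t0) / iters

x = torch.randn(2, 2240, 1024, device="cuda", dtype=torch.float32)  # illustrative (B, C, N)
proj = torch.nn.Linear(2240, 2240, device="cuda", dtype=torch.bfloat16)

baseline = lambda t: t.transpose(1, 2).to(torch.bfloat16)  # keeps the bad strides
fused = lambda t: t.transpose(1, 2).to(
    torch.bfloat16, memory_format=torch.contiguous_format
)  # contiguous result

print("baseline per-iter seconds:", bench(baseline, x, proj))
print("fused per-iter seconds:   ", bench(fused, x, proj))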