Description
Hello,
I'm working on implementing the 4K image generation approach using NATTEN as described in the GNA paper. The paper demonstrates accelerating Flux for high-resolution image generation.
Environment
- Hardware: A800 GPU
- NATTEN version: 0.20.0
- CUDA: 12.6
- PyTorch: 2.7.0
Implementation Approach
I'm modifying the attention processor of URAE (which enables Flux to generate 4K images) to use na1d for efficient attention computation. In this setting, each image token also needs to attend to the text tokens, a pattern that might be described as "self-cross attention", as mentioned in issue #82.
My implementation uses the additional_keys/additional_values parameters of na1d to enable this cross-domain interaction. I'm basing my approach on the URAE attention processor implementation found here: URAE/attention_processor.py#L82
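For context, here is a minimal, self-contained setup sketch for the snippet below. The concrete sizes (512 text tokens, 4096 image tokens, 24 heads, head_dim of 128) and the random tensors are placeholders for illustration only, and the import path is my assumption for NATTEN 0.20; the na1d call itself, with additional_keys/additional_values, is exactly what my processor uses.

```python
import torch

# Import path assumed for NATTEN 0.20; adjust if na1d lives elsewhere
# in your install (e.g. natten.functional).
from natten import na1d

# Placeholder sizes for illustration only (not taken from the real pipeline).
batch_size, heads, head_dim = 1, 24, 128
text_len, image_len = 512, 4096   # 512 text tokens come first, then image tokens
seq_len = text_len + image_len

# The URAE/Flux attention processor hands me [batch, heads, seq, head_dim];
# as far as I understand, na1d expects a heads-last [batch, seq, heads, head_dim]
# layout, which is why every call below starts with transpose(1, 2).
query = torch.randn(batch_size, heads, seq_len, head_dim, device="cuda", dtype=torch.float16)
key = torch.randn_like(query)
value = torch.randn_like(query)
```

With tensors shaped like that, the relevant part of my attention processor looks like this: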
```python
# Image tokens (everything after the first 512 text tokens): strided 1-D
# neighborhood attention, with the text tokens passed in as additional keys/values.
hidden_states = na1d(
    query.transpose(1, 2)[:, 512:], key.transpose(1, 2)[:, 512:], value.transpose(1, 2)[:, 512:],
    kernel_size=80, stride=16,
    additional_keys=key.transpose(1, 2)[:, :512], additional_values=value.transpose(1, 2)[:, :512],
    backend="cutlass-fna",
    attention_kwargs={"backend": "cutlass-fmha"},
)

# Text tokens: kernel_size=512 covers the full text sequence; the image tokens
# are passed in as additional keys/values.
text_hidden_states = na1d(
    query.transpose(1, 2)[:, :512], key.transpose(1, 2)[:, :512], value.transpose(1, 2)[:, :512],
    kernel_size=512,
    additional_keys=key.transpose(1, 2)[:, 512:], additional_values=value.transpose(1, 2)[:, 512:],
    backend="cutlass-fna",
    attention_kwargs={"backend": "cutlass-fmha"},
)

# Concatenate text tokens back in front of the image tokens, then merge heads.
hidden_states = torch.cat([text_hidden_states, hidden_states], dim=1)
hidden_states = hidden_states.reshape(batch_size, -1, attn.heads * head_dim)
```
When running inference with this implementation, I'm getting noisy/corrupted images rather than proper high-resolution outputs. I've experimented with different parameter configurations but haven't been able to achieve the results described in the paper.
Questions
- Is there any official implementation or example of using NATTEN for 4K image generation with Flux, as mentioned in the GNA paper?
- Are there specific considerations or configurations needed when using na1d with additional keys/values for text-image cross-attention?
- Could you identify any issues in my implementation that might be causing the noisy outputs?
I appreciate any guidance or pointers to relevant resources. Thank you for your time!