
Implementing NATTEN for 4K image generation with Flux as mentioned in GNA paper #264

@xuanmingShang

Description


Hello,

I'm working on implementing the 4K image generation approach using NATTEN as described in the GNA paper. The paper demonstrates accelerating Flux for high-resolution image generation.

Environment
Hardware: A800 GPU
NATTEN version: 0.20.0
CUDA: 12.6
PyTorch: 2.7.0

Implementation Approach

I'm modifying the attention processor of URAE (which enables Flux to generate 4K images) to use na1d for efficient attention computation. For image generation, each image token still needs to attend to the text tokens, the "self-cross attention" pattern mentioned in issue #82.

My implementation leverages the additional_keys/additional_values parameters of na1d so that each split can still attend to the other modality's tokens. I'm basing my approach on the URAE attention processor implementation here: URAE/attention_processor.py#L82
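
For reference, the path I'm replacing in the original URAE processor is (roughly, as I understand it, and paraphrased rather than copied verbatim) a single joint attention over the concatenated text+image sequence, with query/key/value in [batch, heads, seq_len, head_dim] layout:

import torch.nn.functional as F

# Full-attention baseline (my paraphrase of the URAE/Flux processor):
# every token, text and image alike, attends to the whole concatenated sequence.
hidden_states = F.scaled_dot_product_attention(query, key, value)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)

My NATTEN version splits the joint sequence at the 512 text tokens: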

import torch
from natten import na1d  # NATTEN 0.20-style import; the path may differ in other versions

# query/key/value arrive as [batch, heads, seq_len, head_dim]; NATTEN expects the
# heads-last layout [batch, seq_len, heads, head_dim]. Tokens 0..511 are text, the rest image.
q, k, v = (t.transpose(1, 2) for t in (query, key, value))

# Image tokens: strided neighborhood attention, with the text tokens attached as additional (global) KV.
hidden_states = na1d(
    q[:, 512:], k[:, 512:], v[:, 512:],
    kernel_size=80, stride=16,
    additional_keys=k[:, :512], additional_values=v[:, :512],
    backend="cutlass-fna", attention_kwargs={"backend": "cutlass-fmha"},
)
# Text tokens: kernel_size=512 spans the whole text sequence, with the image tokens as additional KV.
text_hidden_states = na1d(
    q[:, :512], k[:, :512], v[:, :512],
    kernel_size=512,
    additional_keys=k[:, 512:], additional_values=v[:, 512:],
    backend="cutlass-fna", attention_kwargs={"backend": "cutlass-fmha"},
)
# Text tokens first, then image tokens; merge heads back into the channel dimension.
hidden_states = torch.cat([text_hidden_states, hidden_states], dim=1)
hidden_states = hidden_states.reshape(batch_size, -1, attn.heads * head_dim)

When running inference with this implementation, I'm getting noisy/corrupted images rather than proper high-resolution outputs. I've experimented with different parameter configurations but haven't been able to achieve the results described in the paper.
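
To try to isolate the problem, I've also been using a reduced sanity check along these lines (a self-contained sketch with made-up shapes; the equivalence assumption is mine: with the default stride and kernel_size equal to the full sequence length, na1d should reduce to ordinary self-attention, so its output should match torch SDPA):

import torch
import torch.nn.functional as F
from natten import na1d  # NATTEN 0.20-style import; the path may differ in other versions

batch, seq_len, heads, head_dim = 1, 256, 8, 64
q = torch.randn(batch, seq_len, heads, head_dim, device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)

# NATTEN layout: [batch, seq_len, heads, head_dim]; a kernel covering the whole
# sequence should make neighborhood attention equivalent to full self-attention.
na_out = na1d(q, k, v, kernel_size=seq_len)

# Reference: SDPA expects [batch, heads, seq_len, head_dim].
ref = F.scaled_dot_product_attention(
    q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
).transpose(1, 2)

print((na_out - ref).abs().max())  # expect roughly 1e-3 or smaller in fp16

If this check passes, the heads-last layout and the basic call should be fine, which would point to how I split and recombine the text and image tokens, or to the kernel_size/stride choice, as the likely culprit.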

Questions

  • Is there any official implementation or example of using NATTEN for 4K image generation with Flux, as mentioned in the GNA paper?
  • Are there specific considerations or configurations needed when using na1d with additional keys/values for text-image cross-attention?
  • Could you identify any issues in my implementation that might be causing the noisy outputs?

I appreciate any guidance or pointers to relevant resources. Thank you for your time!
