I noticed that in the PixArtMSBlock implementation, there is no normalization layer for cross-attention, while normalization layers exist for self-attention and MLP:
self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) # for self-attention
self.attn = AttentionKVCompress(...)
self.cross_attn = MultiHeadCrossAttention(...) # no norm layer before/after
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) # for MLP
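To make the asymmetry concrete, here is a minimal, self-contained sketch of that layout. This is not the actual PixArtMSBlock: it uses plain nn.MultiheadAttention in place of AttentionKVCompress / MultiHeadCrossAttention and omits the adaLN-single shift/scale/gate modulation; it only shows where the norms sit.

import torch
import torch.nn as nn

class MinimalPixArtStyleBlock(nn.Module):
    """Simplified block: pre-norm self-attention, un-normalized cross-attention, pre-norm MLP."""

    def __init__(self, hidden_size: int, num_heads: int):
        super().__init__()
        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.self_attn = nn.MultiheadAttention(hidden_size, num_heads, batch_first=True)
        # No LayerNorm is created for the cross-attention branch.
        self.cross_attn = nn.MultiheadAttention(hidden_size, num_heads, batch_first=True)
        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, 4 * hidden_size),
            nn.GELU(),
            nn.Linear(4 * hidden_size, hidden_size),
        )

    def forward(self, x: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
        # Self-attention: hidden states are normalized first (pre-norm).
        h = self.norm1(x)
        x = x + self.self_attn(h, h, h, need_weights=False)[0]
        # Cross-attention: the residual stream is fed in directly, no norm before or after.
        x = x + self.cross_attn(x, cond, cond, need_weights=False)[0]
        # MLP: normalized again before the feed-forward.
        x = x + self.mlp(self.norm2(x))
        return x

x = torch.randn(2, 16, 64)    # (batch, image tokens, hidden)
y = torch.randn(2, 8, 64)     # (batch, text tokens, hidden)
print(MinimalPixArtStyleBlock(64, 4)(x, y).shape)  # torch.Size([2, 16, 64])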
For comparison, diffusers reproduces this behavior in the cross-attention branch of BasicTransformerBlock.forward (https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py#L541):
# 3. Cross-Attention
if self.attn2 is not None:
    if self.norm_type == "ada_norm":
        norm_hidden_states = self.norm2(hidden_states, timestep)
    elif self.norm_type in ["ada_norm_zero", "layer_norm", "layer_norm_i2vgen"]:
        norm_hidden_states = self.norm2(hidden_states)
    elif self.norm_type == "ada_norm_single":
        # For PixArt norm2 isn't applied here:
        # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103
        norm_hidden_states = hidden_states
    elif self.norm_type == "ada_norm_continuous":
        norm_hidden_states = self.norm2(hidden_states, added_cond_kwargs["pooled_text_emb"])
    else:
        raise ValueError("Incorrect norm")
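As a side note on the diffusers side: a BasicTransformerBlock built with norm_type="ada_norm_single" appears to still instantiate a norm2 module; the branch quoted above simply bypasses it. A rough, untested sketch (the argument values are arbitrary, chosen only so the block constructs):

from diffusers.models.attention import BasicTransformerBlock

# Illustrative sizes, just enough to construct the block.
block = BasicTransformerBlock(
    dim=64,
    num_attention_heads=4,
    attention_head_dim=16,
    cross_attention_dim=64,
    norm_type="ada_norm_single",
)
# norm2 appears to exist as a module (at least in current diffusers), but the
# forward branch quoted above skips it for "ada_norm_single", so cross-attention
# sees the un-normalized hidden states -- matching the original PixArtMSBlock.
print(type(block.norm1).__name__, type(block.norm2).__name__)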
I'm curious about the rationale for this architectural design: why is there no normalization for cross-attention, while the self-attention and MLP layers each get one?
Thanks for this great work!