diff --git a/cogvideox/models/attention_processor.py b/cogvideox/models/attention_processor.py
new file mode 100644
index 0000000..6e6afc3
--- /dev/null
+++ b/cogvideox/models/attention_processor.py
@@ -0,0 +1,110 @@
+from typing import Optional
+import torch
+from flash_attn import flash_attn_func
+from einops import rearrange
+
+from diffusers.models.attention import Attention
+from diffusers.models.embeddings import apply_rotary_emb
+
+
+class CogVideoXSWAAttnProcessor2_0:
+    r"""
+    Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
+    query and key vectors, but does not include spatial normalization.
+    """
+
+    def __init__(self, window_size=1024):
+        self.window_size = window_size
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        image_rotary_emb: Optional[torch.Tensor] = None,
+        num_frames: int = None,
+        height: int = None,
+        width: int = None,
+    ) -> torch.Tensor:
+        text_seq_length = encoder_hidden_states.size(1)
+
+        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+
+        batch_size, sequence_length, _ = (
+            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        )
+
+        if attention_mask is not None:
+            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+        query = attn.to_q(hidden_states)
+        key = attn.to_k(hidden_states)
+        value = attn.to_v(hidden_states)
+
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim)  # .transpose(1, 2)
+
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
+        # Apply RoPE if needed
+        if image_rotary_emb is not None:
+
+            query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
+            if not attn.is_cross_attention:
+                key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
+
+        query = query.transpose(1, 2).to(value)
+        key = key.transpose(1, 2).to(value)
+
+        interval = max((query.size(1) - text_seq_length) // (self.window_size - 256), 1)
+        cross_key = torch.cat([key[:, :text_seq_length], key[:, text_seq_length::interval]], dim=1)
+        cross_val = torch.cat([value[:, :text_seq_length], value[:, text_seq_length::interval]], dim=1)
+        cross_hidden_states = flash_attn_func(query, cross_key, cross_val, dropout_p=0.0, causal=False)
+        query_txt = query[:, :text_seq_length]
+        key_txt = key[:, :text_seq_length]
+        value_txt = value[:, :text_seq_length]
+        querys = torch.tensor_split(query[:, text_seq_length:], 6, 2)
+        keys = torch.tensor_split(key[:, text_seq_length:], 6, 2)
+        values = torch.tensor_split(value[:, text_seq_length:], 6, 2)
+        new_querys = [querys[0]]
+        new_keys = [keys[0]]
+        new_values = [values[0]]
+        for index, mode in enumerate(["bs (f h w) hn hd -> bs (f w h) hn hd", "bs (f h w) hn hd -> bs (h f w) hn hd", "bs (f h w) hn hd -> bs (h w f) hn hd",
+                                      "bs (f h w) hn hd -> bs (w f h) hn hd", "bs (f h w) hn hd -> bs (w h f) hn hd"]):
+            new_querys.append(rearrange(querys[index + 1], mode, f=num_frames, h=height, w=width))
+            new_keys.append(rearrange(keys[index + 1], mode, f=num_frames, h=height, w=width))
+            new_values.append(rearrange(values[index + 1], mode, f=num_frames, h=height, w=width))
+        query = torch.cat([query_txt, torch.cat(new_querys, dim=2)], dim=1)
+        key = torch.cat([key_txt, torch.cat(new_keys, dim=2)], dim=1)
+        value = torch.cat([value_txt, torch.cat(new_values, dim=2)], dim=1)
+
+        hidden_states = flash_attn_func(query, key, value, dropout_p=0.0, causal=False, window_size=(self.window_size, self.window_size))
+        hidden_states_txt = hidden_states[:, :text_seq_length]
+        hidden_states = torch.tensor_split(hidden_states[:, text_seq_length:], 6, 2)
+        new_hidden_states = [hidden_states[0]]
+        for index, mode in enumerate(["bs (f w h) hn hd -> bs (f h w) hn hd", "bs (h f w) hn hd -> bs (f h w) hn hd", "bs (h w f) hn hd -> bs (f h w) hn hd",
+                                      "bs (w f h) hn hd -> bs (f h w) hn hd", "bs (w h f) hn hd -> bs (f h w) hn hd"]):
+            new_hidden_states.append(rearrange(hidden_states[index + 1], mode, f=num_frames, h=height, w=width))
+        hidden_states = torch.cat([hidden_states_txt, torch.cat(new_hidden_states, dim=2)], dim=1) + cross_hidden_states
+
+
+        hidden_states = hidden_states.reshape(batch_size, -1, attn.heads * head_dim)
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        encoder_hidden_states, hidden_states = hidden_states.split(
+            [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
+        )
+        return hidden_states, encoder_hidden_states
diff --git a/cogvideox/models/transformer3d.py b/cogvideox/models/transformer3d.py
index 5aa7323..21c1405 100644
--- a/cogvideox/models/transformer3d.py
+++ b/cogvideox/models/transformer3d.py
@@ -31,6 +31,8 @@
 from diffusers.models.modeling_utils import ModelMixin
 from diffusers.models.normalization import AdaLayerNorm, CogVideoXLayerNormZero
 
+from cogvideox.models.attention_processor import CogVideoXSWAAttnProcessor2_0
+
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
@@ -87,6 +89,8 @@ def __init__(
         ff_inner_dim: Optional[int] = None,
         ff_bias: bool = True,
         attention_out_bias: bool = True,
+        swa: bool = False,
+        window_size: int = 1024,
     ):
         super().__init__()
 
@@ -101,7 +105,7 @@ def __init__(
             eps=1e-6,
             bias=attention_bias,
             out_bias=attention_out_bias,
-            processor=CogVideoXAttnProcessor2_0(),
+            processor=CogVideoXSWAAttnProcessor2_0(window_size) if swa else CogVideoXAttnProcessor2_0(),
         )
 
         # 2. Feed Forward
@@ -122,6 +126,9 @@ def forward(
         encoder_hidden_states: torch.Tensor,
         temb: torch.Tensor,
         image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        num_frames: int = None,
+        height: int = None,
+        width: int = None,
     ) -> torch.Tensor:
         text_seq_length = encoder_hidden_states.size(1)
 
@@ -135,6 +142,9 @@ def forward(
             hidden_states=norm_hidden_states,
             encoder_hidden_states=norm_encoder_hidden_states,
             image_rotary_emb=image_rotary_emb,
+            num_frames = num_frames,
+            height = height,
+            width = width,
         )
 
         hidden_states = hidden_states + gate_msa * attn_hidden_states
@@ -238,6 +248,8 @@ def __init__(
         spatial_interpolation_scale: float = 1.875,
         temporal_interpolation_scale: float = 1.0,
         use_rotary_positional_embeddings: bool = False,
+        swa: bool = False,
+        window_size: int = 1024,
     ):
         super().__init__()
         inner_dim = num_attention_heads * attention_head_dim
@@ -285,6 +297,8 @@ def __init__(
                     attention_bias=attention_bias,
                     norm_elementwise_affine=norm_elementwise_affine,
                     norm_eps=norm_eps,
+                    swa=swa,
+                    window_size=window_size,
                 )
                 for _ in range(num_layers)
             ]
@@ -417,6 +431,7 @@ def forward(
         return_dict: bool = True,
     ):
         batch_size, num_frames, channels, height, width = hidden_states.shape
+        p = self.config.patch_size
 
         # 1. Time embedding
         timesteps = timestep
@@ -469,6 +484,9 @@ def custom_forward(*inputs):
                     encoder_hidden_states,
                     emb,
                     image_rotary_emb,
+                    num_frames,
+                    height // p,
+                    width // p,
                     **ckpt_kwargs,
                 )
             else:
@@ -477,6 +495,9 @@ def custom_forward(*inputs):
                     hidden_states=hidden_states,
                     encoder_hidden_states=encoder_hidden_states,
                     temb=emb,
                     image_rotary_emb=image_rotary_emb,
+                    num_frames=num_frames,
+                    height=height // p,
+                    width=width // p,
                 )
         if not self.config.use_rotary_positional_embeddings:
@@ -493,7 +514,6 @@ def custom_forward(*inputs):
         hidden_states = self.proj_out(hidden_states)
 
         # 6. Unpatchify
-        p = self.config.patch_size
         output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, channels, p, p)
         output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
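Usage sketch (illustrative, not part of the diff above): how the new `swa` and `window_size` config flags might be enabled when loading the transformer. It assumes the class exported by cogvideox/models/transformer3d.py is CogVideoXTransformer3DModel, that flash-attn is installed, that diffusers' usual from_pretrained config-override behavior applies, and that the checkpoint path is a placeholder.

import torch

from cogvideox.models.transformer3d import CogVideoXTransformer3DModel

# swa=True routes self-attention through CogVideoXSWAAttnProcessor2_0;
# window_size is the flash_attn sliding-window width on each side of a query token.
transformer = CogVideoXTransformer3DModel.from_pretrained(
    "path/to/CogVideoX-checkpoint",  # placeholder path
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
    swa=True,
    window_size=1024,
).to("cuda")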