diff --git a/videox_fun/models/wan_transformer3d.py b/videox_fun/models/wan_transformer3d.py
index b67091a..e98a5ce 100755
--- a/videox_fun/models/wan_transformer3d.py
+++ b/videox_fun/models/wan_transformer3d.py
@@ -37,6 +37,8 @@ except ModuleNotFoundError:
     FLASH_ATTN_2_AVAILABLE = False
 
+from einops import rearrange
+
 
 def flash_attention(
     q,
@@ -349,7 +351,8 @@ def __init__(self,
                  num_heads,
                  window_size=(-1, -1),
                  qk_norm=True,
-                 eps=1e-6):
+                 eps=1e-6,
+                 bidx=0):
         assert dim % num_heads == 0
         super().__init__()
         self.dim = dim
@@ -358,6 +361,7 @@ def __init__(self,
         self.window_size = window_size
         self.qk_norm = qk_norm
         self.eps = eps
+        self.bidx = bidx
 
         # layers
         self.q = nn.Linear(dim, dim)
@@ -385,13 +389,59 @@ def qkv_fn(x):
             return q, k, v
 
         q, k, v = qkv_fn(x)
+        f, h, w = grid_sizes.tolist()[0]
+        q = rope_apply(q, grid_sizes, freqs).to(dtype)
+        k = rope_apply(k, grid_sizes, freqs).to(dtype)
+        v = v.to(dtype)
+
+        qs = torch.tensor_split(q.to(torch.bfloat16), 6, 2)
+        ks = torch.tensor_split(k.to(torch.bfloat16), 6, 2)
+        vs = torch.tensor_split(v.to(torch.bfloat16), 6, 2)
+
+        new_querys = []
+        new_keys = []
+        new_values = []
+        for index, mode in enumerate(
+            [
+                "bs (f h w) hn hd -> bs (h w f) hn hd",
+                "bs (f h w) hn hd -> bs (w h f) hn hd",
+                "bs (f h w) hn hd -> bs (h f w) hn hd",
+                "bs (f h w) hn hd -> bs (w f h) hn hd",
+                "bs (f h w) hn hd -> bs (f h w) hn hd",
+                "bs (f h w) hn hd -> bs (f w h) hn hd",
+            ]
+        ):
+
+            new_querys.append(rearrange(qs[index], mode, f=f, h=h, w=w))
+            new_keys.append(rearrange(ks[index], mode, f=f, h=h, w=w))
+            new_values.append(rearrange(vs[index], mode, f=f, h=h, w=w))
+        q = torch.cat(new_querys, dim=2)
+        k = torch.cat(new_keys, dim=2)
+        v = torch.cat(new_values, dim=2)
 
         x = attention(
-            q=rope_apply(q, grid_sizes, freqs).to(dtype),
-            k=rope_apply(k, grid_sizes, freqs).to(dtype),
-            v=v.to(dtype),
+            q=q,
+            k=k,
+            v=v,
             k_lens=seq_lens,
-            window_size=self.window_size)
+            window_size=self.window_size
+        )
+
+        hidden_states = torch.tensor_split(x, 6, 2)
+        new_hidden_states = []
+        for index, mode in enumerate(
+            [
+                "bs (h w f) hn hd -> bs (f h w) hn hd",
+                "bs (w h f) hn hd -> bs (f h w) hn hd",
+                "bs (h f w) hn hd -> bs (f h w) hn hd",
+                "bs (w f h) hn hd -> bs (f h w) hn hd",
+                "bs (f h w) hn hd -> bs (f h w) hn hd",
+                "bs (f w h) hn hd -> bs (f h w) hn hd",
+            ]
+        ):
+            new_hidden_states.append(rearrange(hidden_states[index], mode, f=f, h=h, w=w))
+        x = torch.cat(new_hidden_states, dim=2)
+        x = x.to(dtype)
 
         # output
@@ -504,7 +554,9 @@ def __init__(self,
                  window_size=(-1, -1),
                  qk_norm=True,
                  cross_attn_norm=False,
-                 eps=1e-6):
+                 eps=1e-6,
+                 bidx=0,
+                 swa=False):
         super().__init__()
         self.dim = dim
         self.ffn_dim = ffn_dim
@@ -513,11 +565,13 @@ def __init__(self,
         self.qk_norm = qk_norm
         self.cross_attn_norm = cross_attn_norm
         self.eps = eps
+        if (bidx + 1)%5!=0 and swa:
+            window_size = (4096, 4096)
 
         # layers
         self.norm1 = WanLayerNorm(dim, eps)
         self.self_attn = WanSelfAttention(dim, num_heads, window_size, qk_norm,
-                                          eps)
+                                          eps, bidx=bidx)
         self.norm3 = WanLayerNorm(
             dim, eps,
             elementwise_affine=True) if cross_attn_norm else nn.Identity()
@@ -654,6 +708,7 @@ def __init__(
         eps=1e-6,
         in_channels=16,
         hidden_size=2048,
+        swa=False,
     ):
         r"""
         Initialize the diffusion model backbone.
@@ -726,7 +781,7 @@ def __init__(
         cross_attn_type = 't2v_cross_attn' if model_type == 't2v' else 'i2v_cross_attn'
         self.blocks = nn.ModuleList([
             WanAttentionBlock(cross_attn_type, dim, ffn_dim, num_heads,
-                              window_size, qk_norm, cross_attn_norm, eps)
+                              window_size, qk_norm, cross_attn_norm, eps, bidx=_, swa=swa)
             for _ in range(num_layers)
         ])
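
Note on the patch above: the self-attention change splits the attention heads into six groups and applies a different permutation of the flattened (frame, height, width) token axis to each group before the attention call, then inverts the permutation afterwards, so each head group attends along a different spatio-temporal ordering. The swa flag additionally sets window_size to (4096, 4096) in every block whose (bidx + 1) is not a multiple of 5, so four out of every five blocks use a 4096-token sliding window at the attention call while every fifth block keeps full attention.

The following is a minimal, self-contained sketch of the head-group permutation only; the helper permute_head_groups, the pattern-list names, and the toy tensor sizes are illustrative and not part of the repository.

import torch
from einops import rearrange

# One einops pattern per head group; the fifth group keeps the original
# (f h w) ordering. These mirror the patterns added in WanSelfAttention.forward.
FORWARD_PATTERNS = [
    "bs (f h w) hn hd -> bs (h w f) hn hd",
    "bs (f h w) hn hd -> bs (w h f) hn hd",
    "bs (f h w) hn hd -> bs (h f w) hn hd",
    "bs (f h w) hn hd -> bs (w f h) hn hd",
    "bs (f h w) hn hd -> bs (f h w) hn hd",
    "bs (f h w) hn hd -> bs (f w h) hn hd",
]
# Swapping the two sides of each pattern gives the inverse rearrange,
# which is exactly the second pattern list in the patch.
INVERSE_PATTERNS = [" -> ".join(reversed(p.split(" -> "))) for p in FORWARD_PATTERNS]


def permute_head_groups(x, patterns, f, h, w):
    """Split the head dimension (dim 2 of [bs, seq, hn, hd]) into
    len(patterns) groups and rearrange each group's token axis with
    its own pattern, as the patch does for q, k and v."""
    groups = torch.tensor_split(x, len(patterns), dim=2)
    return torch.cat(
        [rearrange(g, p, f=f, h=h, w=w) for g, p in zip(groups, patterns)],
        dim=2,
    )


# Toy round-trip check (hypothetical sizes): applying the forward patterns and
# then the inverse patterns restores the original tensor exactly.
bs, f, h, w, hn, hd = 1, 4, 6, 8, 12, 16
q = torch.randn(bs, f * h * w, hn, hd)
q_perm = permute_head_groups(q, FORWARD_PATTERNS, f, h, w)
# ... attention over the permuted q/k/v would run here ...
q_back = permute_head_groups(q_perm, INVERSE_PATTERNS, f, h, w)
assert torch.equal(q_back, q)

Because the inverse permutation is applied to the attention output before the output projection, the rest of the block still sees tokens in the canonical (f h w) order; only the attention pattern of each head group changes.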