|
19 | 19 | import torch |
20 | 20 | import torch.nn.functional as F |
21 | 21 | from diffusers.models.attention_processor import Attention |
| 22 | +from diffusers.models.transformers.transformer_wan import WanAttention, _get_added_kv_projections, _get_qkv_projections |
22 | 23 | from diffusers.utils import deprecate, logging |
23 | 24 | from diffusers.utils.import_utils import is_xformers_available |
24 | 25 | from torch import nn |
@@ -535,4 +536,115 @@ def __call__( |
535 | 536 | return hidden_states |
536 | 537 |
|
537 | 538 |
|
class GaudiWanAttnProcessor:
    r"""
    Gaudi (HPU) attention processor for Wan transformer blocks.

    Port of `WanAttnProcessor` from
    https://github.com/huggingface/diffusers/blob/v0.35.1/src/diffusers/models/transformers/transformer_wan.py#L67
    with two Gaudi-specific substitutions:

    * scaled-dot-product attention is routed through Habana's `FusedSDPA` kernel
      (see `_native_attention`);
    * rotary position embeddings are applied with the fused
      `hpex.kernels.apply_rotary_pos_emb` in PAIRWISE mode.
    """

    _attention_backend = None

    def __init__(self, is_training=False):
        """Record the training flag and verify PyTorch 2.0+ is available.

        Args:
            is_training: when True, FusedSDPA's "fast" softmax mode is disabled,
                since it is not supported during training.
        """
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(
                "WanAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to version 2.0 or higher."
            )
        self.is_training = is_training

    def _native_attention(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        dropout_p: float = 0.0,
        is_causal: bool = False,
        scale: Optional[float] = None,
        enable_gqa: bool = False,
    ) -> torch.Tensor:
        """Scaled-dot-product attention via Habana's fused SDPA kernel.

        Tensors arrive as (batch, seq, heads, head_dim); FusedSDPA expects
        (batch, heads, seq, head_dim), so we transpose in and back out.
        `enable_gqa` exists only for signature parity with
        `F.scaled_dot_product_attention` and is not used here.
        """
        from habana_frameworks.torch.hpex.kernels import FusedSDPA

        # The approximate "fast" softmax is inference-only; fall back to "None" while training.
        softmax_mode = "None" if self.is_training else "fast"
        q_bhsd = query.permute(0, 2, 1, 3)
        k_bhsd = key.permute(0, 2, 1, 3)
        v_bhsd = value.permute(0, 2, 1, 3)
        attn_out = FusedSDPA.apply(q_bhsd, k_bhsd, v_bhsd, attn_mask, dropout_p, is_causal, scale, softmax_mode, None)
        return attn_out.permute(0, 2, 1, 3)

    def __call__(
        self,
        attn: "WanAttention",
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> torch.Tensor:
        """Compute Wan (cross-)attention on Gaudi and return the projected output."""
        encoder_hidden_states_img = None
        if attn.add_k_proj is not None:
            # 512 is the (hardcoded, matching upstream) context length of the text
            # encoder; anything before it is image context for the I2V task.
            img_len = encoder_hidden_states.shape[1] - 512
            encoder_hidden_states_img = encoder_hidden_states[:, :img_len]
            encoder_hidden_states = encoder_hidden_states[:, img_len:]

        query, key, value = _get_qkv_projections(attn, hidden_states, encoder_hidden_states)

        query = attn.norm_q(query)
        key = attn.norm_k(key)

        # Split the channel dimension into (heads, head_dim).
        query, key, value = (t.unflatten(2, (attn.heads, -1)) for t in (query, key, value))

        if rotary_emb is not None:
            # Wan's RoPE rotates adjacent channel pairs, i.e. the reference form:
            #   x1, x2 = hidden_states.unflatten(-1, (-1, 2)).unbind(-1)
            #   cos, sin = freqs_cos[..., 0::2], freqs_sin[..., 1::2]
            #   out[..., 0::2] = x1 * cos - x2 * sin
            #   out[..., 1::2] = x1 * sin + x2 * cos
            # which corresponds to the PAIRWISE mode of the HPU fused rotary kernel.
            from habana_frameworks.torch.hpex.kernels import RotaryPosEmbeddingMode, apply_rotary_pos_emb

            freqs_cos, freqs_sin = rotary_emb
            query = apply_rotary_pos_emb(query, freqs_cos, freqs_sin, None, 0, RotaryPosEmbeddingMode.PAIRWISE)
            key = apply_rotary_pos_emb(key, freqs_cos, freqs_sin, None, 0, RotaryPosEmbeddingMode.PAIRWISE)

        # I2V task: attend to the image context with the added K/V projections.
        hidden_states_img = None
        if encoder_hidden_states_img is not None:
            key_img, value_img = _get_added_kv_projections(attn, encoder_hidden_states_img)
            key_img = attn.norm_added_k(key_img).unflatten(2, (attn.heads, -1))
            value_img = value_img.unflatten(2, (attn.heads, -1))

            hidden_states_img = self._native_attention(query, key_img, value_img, None, 0.0, False, None)
            hidden_states_img = hidden_states_img.flatten(2, 3).type_as(query)

        hidden_states = self._native_attention(query, key, value, attention_mask, 0.0, False, None)
        hidden_states = hidden_states.flatten(2, 3).type_as(query)

        if hidden_states_img is not None:
            hidden_states = hidden_states + hidden_states_img

        # Output projection: linear followed by dropout.
        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)
        return hidden_states
| 648 | + |
| 649 | + |
# Type alias covering the attention processor implementations accepted by this module.
# NOTE(review): presumably the Gaudi processors defined above are meant to be usable
# wherever this alias is accepted — confirm whether they should be added to the Union.
AttentionProcessor = Union[AttnProcessor2_0,]
0 commit comments