
Commit ed89c8b

gcanlin, muziyuhui666, and hsliuustc0106 authored
[Bugfix] Fix NPU SDPA attention mask shape and semantics (#1031)
Signed-off-by: gcanlin <canlinguosdu@gmail.com>
Co-authored-by: muziyuhui666 <111362884+muziyuhui666@users.noreply.github.com>
Co-authored-by: Hongsheng Liu <liuhongsheng4@huawei.com>
1 parent 70a5de9 commit ed89c8b

2 files changed: +82 −22 lines changed

vllm_omni/diffusion/attention/backends/flash_attn.py

Lines changed: 27 additions & 19 deletions

@@ -10,26 +10,8 @@
     AttentionMetadata,
 )
 
-# Import flash attention functions with fallback chain from utils/fa.py
-# FA3 (fa3_fwd_interface) -> FA3 (flash_attn_interface) -> FA2 (flash_attn)
-from vllm_omni.diffusion.attention.backends.utils.fa import (
-    HAS_FLASH_ATTN,
-    _pad_input,
-    _unpad_input,
-    _upad_input,
-    flash_attn_func,
-    flash_attn_varlen_func,
-)
-
 logger = init_logger(__name__)
 
-if not HAS_FLASH_ATTN:
-    raise ImportError(
-        "FlashAttentionBackend requires Flash Attention. "
-        "Please install one of: fa3-fwd, flash-attention, or flash-attn. "
-        "Otherwise, use SDPA backend by setting DIFFUSION_ATTENTION_BACKEND=TORCH_SDPA"
-    )
-
 
 class FlashAttentionBackend(AttentionBackend):
     accept_output_buffer: bool = True
@@ -74,6 +56,24 @@ def forward_cuda(
         attn_metadata: AttentionMetadata = None,
     ) -> torch.Tensor:
         """CUDA/ROCm flash attention implementation."""
+        # Import flash attention functions with fallback chain from utils/fa.py
+        # FA3 (fa3_fwd_interface) -> FA3 (flash_attn_interface) -> FA2 (flash_attn)
+        from vllm_omni.diffusion.attention.backends.utils.fa import (
+            HAS_FLASH_ATTN,
+            _pad_input,
+            _unpad_input,
+            _upad_input,
+            flash_attn_func,
+            flash_attn_varlen_func,
+        )
+
+        if not HAS_FLASH_ATTN:
+            raise ImportError(
+                "FlashAttentionBackend requires Flash Attention. "
+                "Please install one of: fa3-fwd, flash-attention, or flash-attn. "
+                "Otherwise, use SDPA backend by setting DIFFUSION_ATTENTION_BACKEND=TORCH_SDPA"
+            )
+
         query_length = query.size(1)
         attention_mask = attn_metadata.attn_mask if attn_metadata is not None else None
         # Contains at least one padding token in the sequence
@@ -122,7 +122,15 @@ def forward_npu(
         attn_metadata: AttentionMetadata = None,
     ) -> torch.Tensor:
         """NPU attention implementation using mindiesd."""
-        from mindiesd import attention_forward
+        try:
+            from mindiesd import attention_forward
+        except ImportError:
+            raise ImportError(
+                "FlashAttentionBackend NPU implementation requires MindIE-SD. "
+                "Please install MindIE-SD to enable NPU attention support. "
+                "For installation details, see https://gitcode.com/Ascend/MindIE-SD"
+                "Otherwise, use SDPA backend by setting DIFFUSION_ATTENTION_BACKEND=TORCH_SDPA"
+            )
 
         attention_mask = attn_metadata.attn_mask if attn_metadata else None
         output = attention_forward(
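
The substance of this file's change is that the flash-attention import and the HAS_FLASH_ATTN check move from module import time into forward_cuda, so the backend module can now be imported (for example on NPU-only hosts) even when no flash-attn package is installed. Below is a minimal sketch of that deferred-import pattern, assuming only the public flash_attn package; the function name and error message are illustrative, not code from this repository.

```python
import torch


def run_flash_attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
    """Sketch of a deferred optional-dependency import: the import happens on the
    first call, so merely importing this module never requires flash-attn."""
    try:
        from flash_attn import flash_attn_func  # resolved lazily, not at module load
    except ImportError as exc:
        raise ImportError(
            "flash-attn is not installed; install it or switch to the SDPA "
            "backend, e.g. by setting DIFFUSION_ATTENTION_BACKEND=TORCH_SDPA"
        ) from exc
    # flash_attn_func expects (batch, seq_len, num_heads, head_dim) tensors.
    return flash_attn_func(query, key, value)
```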

vllm_omni/diffusion/attention/backends/sdpa.py

Lines changed: 55 additions & 3 deletions

@@ -13,6 +13,29 @@
 logger = init_logger(__name__)
 
 
+def _maybe_reshape_attn_mask(query: torch.Tensor, key: torch.Tensor, attn_mask: torch.Tensor | None = None):
+    """
+    Reshape Attention Mask
+    [batch_size, seq_len_k] -> [batch_size, 1, seq_len_q, seq_len_k]
+    """
+    # Skip Attention Mask if all values are 1, `None` mask can speedup the computation
+    if attn_mask is not None and torch.all(attn_mask != 0):
+        attn_mask = None
+
+    # Reshape Attention Mask
+    # [batch_size, seq_len_k] -> [batch_size, 1, seq_len_q, seq_len_k]
+    if (
+        attn_mask is not None
+        and attn_mask.ndim == 2
+        and attn_mask.shape[0] == query.shape[0]
+        and attn_mask.shape[1] == key.shape[1]
+    ):
+        B, Sq, Skv = attn_mask.shape[0], query.shape[1], key.shape[1]
+        attn_mask = attn_mask.to(torch.bool)
+        attn_mask = attn_mask.unsqueeze(1).expand(B, Sq, Skv).unsqueeze(1).contiguous()
+    return attn_mask
+
+
 class SDPABackend(AttentionBackend):
     accept_output_buffer: bool = True
 
@@ -47,16 +70,15 @@ def __init__(
         self.causal = causal
         self.softmax_scale = softmax_scale
 
-    def forward(
+    def forward_cuda(
         self,
         query: torch.Tensor,
         key: torch.Tensor,
         value: torch.Tensor,
-        attn_metadata: AttentionMetadata = None,
+        attn_metadata: AttentionMetadata | None = None,
     ) -> torch.Tensor:
         query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
         attention_mask = attn_metadata.attn_mask if attn_metadata else None
-
         output = torch.nn.functional.scaled_dot_product_attention(
             query,
             key,
@@ -68,3 +90,33 @@ def forward(
         )
         out = output.permute(0, 2, 1, 3)
         return out
+
+    def forward_xpu(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        attn_metadata: AttentionMetadata | None = None,
+    ) -> torch.Tensor:
+        return self.forward_cuda(query, key, value, attn_metadata)
+
+    def forward_hip(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        attn_metadata: AttentionMetadata | None = None,
+    ) -> torch.Tensor:
+        return self.forward_cuda(query, key, value, attn_metadata)
+
+    def forward_npu(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        attn_metadata: AttentionMetadata | None = None,
+    ) -> torch.Tensor:
+        if attn_metadata:
+            attention_mask = _maybe_reshape_attn_mask(query, key, attn_metadata.attn_mask)
+            setattr(attn_metadata, "attn_mask", attention_mask)
+        return self.forward_cuda(query, key, value, attn_metadata)
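
The new _maybe_reshape_attn_mask helper is the actual fix for the NPU mask shape and semantics: an all-ones mask is replaced with None so SDPA can skip masking, and any other [batch_size, seq_len_k] padding mask is expanded to the [batch_size, 1, seq_len_q, seq_len_k] boolean layout that torch.nn.functional.scaled_dot_product_attention expects, where True means the query position may attend to that key. The sketch below walks through the same transformation outside the backend, with made-up shapes; it mirrors the helper's logic but is not code from the repository.

```python
import torch
import torch.nn.functional as F

B, H, Sq, Skv, D = 2, 4, 6, 6, 8
query = torch.randn(B, H, Sq, D)
key = torch.randn(B, H, Skv, D)
value = torch.randn(B, H, Skv, D)

# 2-D padding mask in the [batch_size, seq_len_k] layout carried by the metadata:
# 1 = real token, 0 = padding. The second sequence ends with two padded positions.
pad_mask = torch.ones(B, Skv)
pad_mask[1, -2:] = 0

if torch.all(pad_mask != 0):
    # All-ones mask: pass None so SDPA can skip masking entirely.
    attn_mask = None
else:
    # Expand to [batch_size, 1, seq_len_q, seq_len_k] booleans, where True means
    # "this query position may attend to this key position".
    attn_mask = pad_mask.bool().unsqueeze(1).expand(B, Sq, Skv).unsqueeze(1).contiguous()

out = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask)
print(out.shape)  # torch.Size([2, 4, 6, 8])
```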
