 logger = init_logger(__name__)
 
 
+def _maybe_reshape_attn_mask(
+    query: torch.Tensor, key: torch.Tensor, attn_mask: torch.Tensor | None = None
+) -> torch.Tensor | None:
+    """
+    Reshape a 2-D attention mask from
+    [batch_size, seq_len_k] to [batch_size, 1, seq_len_q, seq_len_k].
+    """
+    # Drop the mask entirely when every position is attended; passing
+    # `attn_mask=None` lets SDPA select a faster fused kernel.
+    if attn_mask is not None and torch.all(attn_mask != 0):
+        attn_mask = None
+
+    # Broadcast a [batch_size, seq_len_k] padding mask to
+    # [batch_size, 1, seq_len_q, seq_len_k] so it matches SDPA's layout.
+    if (
+        attn_mask is not None
+        and attn_mask.ndim == 2
+        and attn_mask.shape[0] == query.shape[0]
+        and attn_mask.shape[1] == key.shape[1]
+    ):
+        B, Sq, Skv = attn_mask.shape[0], query.shape[1], key.shape[1]
+        attn_mask = attn_mask.to(torch.bool)
+        attn_mask = attn_mask.unsqueeze(1).expand(B, Sq, Skv).unsqueeze(1).contiguous()
+    return attn_mask
+
+
 class SDPABackend(AttentionBackend):
     accept_output_buffer: bool = True
 
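To illustrate the fast path in `_maybe_reshape_attn_mask`: an all-ones mask is dropped so SDPA can pick a fused kernel, while a real padding mask is broadcast to 4-D. A minimal sketch with made-up shapes, assuming the helper above is in scope:

```python
import torch

q = torch.randn(2, 4, 8, 64)  # [batch, seq_len_q, heads, head_dim]
k = torch.randn(2, 6, 8, 64)  # [batch, seq_len_k, heads, head_dim]

# Nothing is masked -> the helper returns None, enabling SDPA's fast path.
assert _maybe_reshape_attn_mask(q, k, torch.ones(2, 6)) is None

# One padded key position -> the mask is kept and broadcast to 4-D.
mask = torch.ones(2, 6)
mask[0, -1] = 0
assert _maybe_reshape_attn_mask(q, k, mask).shape == (2, 1, 4, 6)
```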
@@ -47,16 +70,15 @@ def __init__(
         self.causal = causal
         self.softmax_scale = softmax_scale
 
-    def forward(
+    def forward_cuda(
         self,
         query: torch.Tensor,
         key: torch.Tensor,
         value: torch.Tensor,
-        attn_metadata: AttentionMetadata = None,
+        attn_metadata: AttentionMetadata | None = None,
     ) -> torch.Tensor:
         query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
         attention_mask = attn_metadata.attn_mask if attn_metadata else None
-
         output = torch.nn.functional.scaled_dot_product_attention(
             query,
             key,
@@ -68,3 +90,33 @@ def forward(
         )
         out = output.permute(0, 2, 1, 3)
         return out
+
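For context on the permutes in `forward_cuda`: the backend's interface passes tensors as `[batch, seq, heads, head_dim]`, while `torch.nn.functional.scaled_dot_product_attention` expects `[batch, heads, seq, head_dim]`, so the call is wrapped in a permute each way. A shape-only sketch:

```python
import torch
import torch.nn.functional as F

B, S, H, D = 2, 16, 8, 64
q = k = v = torch.randn(B, S, H, D)  # backend layout: [B, S, H, D]

# Permute to SDPA's [B, H, S, D], attend, then permute back.
q2, k2, v2 = (x.permute(0, 2, 1, 3) for x in (q, k, v))
out = F.scaled_dot_product_attention(q2, k2, v2)
assert out.permute(0, 2, 1, 3).shape == (B, S, H, D)
```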
+    def forward_xpu(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        attn_metadata: AttentionMetadata | None = None,
+    ) -> torch.Tensor:
+        return self.forward_cuda(query, key, value, attn_metadata)
+
+    def forward_hip(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        attn_metadata: AttentionMetadata | None = None,
+    ) -> torch.Tensor:
+        return self.forward_cuda(query, key, value, attn_metadata)
+
+    def forward_npu(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        attn_metadata: AttentionMetadata | None = None,
+    ) -> torch.Tensor:
+        if attn_metadata is not None:
+            attn_metadata.attn_mask = _maybe_reshape_attn_mask(
+                query, key, attn_metadata.attn_mask
+            )
+        return self.forward_cuda(query, key, value, attn_metadata)
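And a sketch of what `forward_npu` does to the metadata before delegating to the shared `forward_cuda` path. The `types.SimpleNamespace` here is a hypothetical stand-in for the project's `AttentionMetadata`:

```python
import types
import torch

q = torch.randn(2, 4, 8, 64)
k = torch.randn(2, 6, 8, 64)

mask = torch.ones(2, 6)
mask[1, -2:] = 0  # two padded key positions in the second sequence

# Stand-in for AttentionMetadata; only the `attn_mask` field is needed here.
meta = types.SimpleNamespace(attn_mask=mask)
meta.attn_mask = _maybe_reshape_attn_mask(q, k, meta.attn_mask)
assert meta.attn_mask.shape == (2, 1, 4, 6)  # what forward_cuda will receive
```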