@@ -100,15 +100,22 @@ def eager_attention_forward(
     dropout: float = 0.0,
     **kwargs,
 ):
-    attn_weights = paddle.matmul(query, key.transpose((0, 1, 3, 2))) * scaling
+    origin_dtype = query.dtype
+
+    attn_weights = paddle.matmul(x=query.scale(scaling), y=key, transpose_y=True)
+    attn_weights = attn_weights.cast(paddle.float32)
+
     if attention_mask is not None:
+        attention_mask = attention_mask.cast(paddle.float32)
         attn_weights = attn_weights + attention_mask

-    attn_weights = F.softmax(attn_weights, axis=-1, dtype="float32").astype(query.dtype)
+    attn_weights = F.softmax(attn_weights, axis=-1)
+    attn_weights = attn_weights.cast(origin_dtype)
+
     attn_weights = F.dropout(attn_weights, p=dropout, training=module.training)

     attn_output = paddle.matmul(attn_weights, value)
-    attn_output = attn_output.transpose((0, 2, 1, 3)).contiguous()
+    attn_output = attn_output.transpose((0, 2, 1, 3))

     return attn_output, attn_weights

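For reference, a minimal standalone sketch of the float32-upcast pattern this hunk introduces; the helper name `upcast_softmax_attention` and the toy shapes are illustrative, not part of the commit. The upcast matters for float16/bfloat16 activations, where forming and normalizing the logits in the storage dtype can overflow or lose precision.

```python
import paddle
import paddle.nn.functional as F


def upcast_softmax_attention(query, key, value, scaling, attention_mask=None):
    # query/key/value: [B, H, L, Dh]. Scores are built and normalized in
    # float32, then cast back to the input dtype, mirroring the hunk above.
    origin_dtype = query.dtype
    attn_weights = paddle.matmul(x=query.scale(scaling), y=key, transpose_y=True)
    attn_weights = attn_weights.cast(paddle.float32)
    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask.cast(paddle.float32)
    attn_weights = F.softmax(attn_weights, axis=-1).cast(origin_dtype)
    return paddle.matmul(attn_weights, value)


# Toy check (float32 on CPU; with half-precision inputs the softmax still
# runs in float32 while the output keeps the input dtype).
q = paddle.randn([1, 2, 4, 8])
k = paddle.randn([1, 2, 4, 8])
v = paddle.randn([1, 2, 4, 8])
out = upcast_softmax_attention(q, k, v, scaling=8 ** -0.5)
print(out.shape, out.dtype)  # [1, 2, 4, 8] paddle.float32
```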
@@ -138,44 +145,55 @@ def forward(
         cu_seqlens: Optional[List[paddle.Tensor]] = None,
         rope_emb: Optional[Tuple[paddle.Tensor, paddle.Tensor]] = None,  # (cos, sin)
     ):
+        if output_attentions:
+            raise NotImplementedError
+
         B, L, D = hidden_states.shape

         q = self.q_proj(hidden_states)
         k = self.k_proj(hidden_states)
         v = self.v_proj(hidden_states)

         # [B, L, H, Dh]
-
         q = q.reshape([B, L, self.num_heads, self.head_dim])
         k = k.reshape([B, L, self.num_heads, self.head_dim])
         v = v.reshape([B, L, self.num_heads, self.head_dim])
         if rope_emb is not None:
             cos, sin = rope_emb
             q, k = apply_rotary_pos_emb_vision(q, k, cos, sin)

-        # → [B, H, L, Dh]
-        q = q.transpose([0, 2, 1, 3])
-        k = k.transpose([0, 2, 1, 3])
-        v = v.transpose([0, 2, 1, 3])
-
-        attn_output, attn_weights = eager_attention_forward(
-            self,
-            q,
-            k,
-            v,
-            attention_mask,
-            is_causal=self.is_causal,
-            scaling=self.scale,
-            dropout=0.0 if not self.training else self.dropout,
-        )
-        attn_output = attn_output.reshape([B, L, D]).contiguous()
+        if q.dtype == paddle.float32:
+            # → [B, H, L, Dh]
+            q = q.transpose([0, 2, 1, 3])
+            k = k.transpose([0, 2, 1, 3])
+            v = v.transpose([0, 2, 1, 3])
+
+            attn_output, _ = eager_attention_forward(
+                self,
+                q,
+                k,
+                v,
+                attention_mask,
+                is_causal=self.is_causal,
+                scaling=self.scale,
+                dropout=0.0 if not self.training else self.dropout,
+            )
+            attn_output = attn_output.reshape([B, L, D])
+        else:
+            attn_output = paddle.nn.functional.scaled_dot_product_attention(
+                q,
+                k,
+                v,
+                attention_mask,
+                dropout_p=self.dropout,
+                is_causal=self.is_causal,
+                training=self.training,
+            )
+            attn_output = attn_output.reshape([B, L, D])

         attn_output = self.out_proj(attn_output)

-        if not output_attentions:
-            attn_weights = None
-
-        return attn_output, attn_weights
+        return attn_output, None


 class SiglipVisionEmbeddings(nn.Layer):
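Taken together, the new forward path dispatches on dtype: float32 activations go through the eager kernel, while half-precision activations are routed to `paddle.nn.functional.scaled_dot_product_attention`, which consumes `[B, L, H, Dh]` tensors directly (no head transpose). Below is a rough sketch of that dispatch with mask and dropout omitted for brevity; the function name `attend` and the shapes are illustrative, and the fused branch assumes a device/dtype combination where Paddle's SDPA kernel is available.

```python
import paddle
import paddle.nn.functional as F


def attend(q, k, v, training=False):
    # q, k, v: [B, L, H, Dh], as produced by the reshape in forward().
    B, L, H, Dh = q.shape
    if q.dtype == paddle.float32:
        # Eager path: heads in front, scaled matmul + softmax, heads back.
        qt, kt, vt = (x.transpose([0, 2, 1, 3]) for x in (q, k, v))
        scores = paddle.matmul(qt, kt, transpose_y=True) * (Dh ** -0.5)
        probs = F.softmax(scores, axis=-1)
        out = paddle.matmul(probs, vt).transpose([0, 2, 1, 3])
    else:
        # Fused path: Paddle's SDPA keeps the [B, L, H, Dh] layout.
        out = F.scaled_dot_product_attention(
            q, k, v, None, dropout_p=0.0, is_causal=False, training=training
        )
    return out.reshape([B, L, H * Dh])


# Float32 input exercises the eager branch on any device.
x = paddle.randn([1, 4, 2, 8])
print(attend(x, x, x).shape)  # [1, 4, 16]
```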