Commit acab8aa

Replace naive eager attention with SDPA (#4725)
1 parent 2526aad commit acab8aa

4 files changed, +68 -27 lines


paddlex/inference/models/doc_vlm/modeling/paddleocr_vl/_ernie.py

Lines changed: 8 additions & 1 deletion
```diff
@@ -296,7 +296,7 @@ def forward(self, hidden_states):
            3. Scale by learned weight parameter
            - Maintains original dtype for numerical stability during computation
         """
-        if self.config.fuse_rms_norm:
+        if hidden_states.dtype != paddle.float16 and self.config.fuse_rms_norm:
             return fused_rms_norm_ext(
                 hidden_states, self.weight, self.variance_epsilon
             )[0].astype(self.weight.dtype)
```
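As a reading aid, here is a minimal standalone sketch of the dtype guard above, with a plain-Paddle RMSNorm standing in for the unfused path (the fallback body is not part of this commit and is only illustrative; `fused_rms_norm_ext` and the `fuse_rms_norm` flag come from the diff):

```python
import paddle

def rms_norm(hidden_states, weight, variance_epsilon, fuse_rms_norm, fused_op=None):
    # Fused kernel only for non-float16 inputs, mirroring the new guard.
    if hidden_states.dtype != paddle.float16 and fuse_rms_norm and fused_op is not None:
        return fused_op(hidden_states, weight, variance_epsilon)[0].astype(weight.dtype)
    # Illustrative unfused fallback: accumulate the variance in float32.
    x = hidden_states.astype(paddle.float32)
    variance = x.pow(2).mean(axis=-1, keepdim=True)
    x = x * paddle.rsqrt(variance + variance_epsilon)
    return weight * x.astype(weight.dtype)
```
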
```diff
@@ -854,8 +854,15 @@ def core_attn(
         v = tensor.transpose(x=v, perm=perm)
 
         replicate = self.config.num_attention_heads // self.config.num_key_value_heads
+        is_float16 = k.dtype == paddle.float16
+        if is_float16:
+            k = k.cast(paddle.float32)
+            v = v.cast(paddle.float32)
         k = paddle.repeat_interleave(k, replicate, axis=1)
         v = paddle.repeat_interleave(v, replicate, axis=1)
+        if is_float16:
+            k = k.cast(paddle.float16)
+            v = v.cast(paddle.float16)
 
         scale_qk_coeff = self.config.scale_qk_coeff * self.head_dim**0.5
         product = paddle.matmul(x=q.scale(1.0 / scale_qk_coeff), y=k, transpose_y=True)
```
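Taken out of `core_attn`, the cast-around-`repeat_interleave` pattern looks like the sketch below; the `[B, num_heads, L, head_dim]` layout is assumed from the `axis=1` argument, and the toy shapes are made up for the example:

```python
import paddle

def replicate_kv(k, v, num_attention_heads, num_key_value_heads):
    """Expand grouped KV heads so they match the number of query heads."""
    replicate = num_attention_heads // num_key_value_heads
    is_float16 = k.dtype == paddle.float16
    if is_float16:
        # Do the replication in float32, then cast back, as in the diff above.
        k, v = k.cast(paddle.float32), v.cast(paddle.float32)
    k = paddle.repeat_interleave(k, replicate, axis=1)
    v = paddle.repeat_interleave(v, replicate, axis=1)
    if is_float16:
        k, v = k.cast(paddle.float16), v.cast(paddle.float16)
    return k, v

# 2 KV heads expanded to 8 attention heads: [1, 2, 4, 16] -> [1, 8, 4, 16]
k = paddle.randn([1, 2, 4, 16]).astype(paddle.float16)
v = paddle.randn([1, 2, 4, 16]).astype(paddle.float16)
k, v = replicate_kv(k, v, num_attention_heads=8, num_key_value_heads=2)
print(k.shape, k.dtype)
```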

paddlex/inference/models/doc_vlm/modeling/paddleocr_vl/_siglip.py

Lines changed: 42 additions & 24 deletions
```diff
@@ -100,15 +100,22 @@ def eager_attention_forward(
     dropout: float = 0.0,
     **kwargs,
 ):
-    attn_weights = paddle.matmul(query, key.transpose((0, 1, 3, 2))) * scaling
+    origin_dtype = query.dtype
+
+    attn_weights = paddle.matmul(x=query.scale(scaling), y=key, transpose_y=True)
+    attn_weights = attn_weights.cast(paddle.float32)
+
     if attention_mask is not None:
+        attention_mask = attention_mask.cast(paddle.float32)
         attn_weights = attn_weights + attention_mask
 
-    attn_weights = F.softmax(attn_weights, axis=-1, dtype="float32").astype(query.dtype)
+    attn_weights = F.softmax(attn_weights, axis=-1)
+    attn_weights = attn_weights.cast(origin_dtype)
+
     attn_weights = F.dropout(attn_weights, p=dropout, training=module.training)
 
     attn_output = paddle.matmul(attn_weights, value)
-    attn_output = attn_output.transpose((0, 2, 1, 3)).contiguous()
+    attn_output = attn_output.transpose((0, 2, 1, 3))
 
     return attn_output, attn_weights
 
```
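The rewritten weight computation folds the scaling into the query before the matmul and runs the softmax in float32 before casting back. The scaling change itself is a pure refactor, as this small equivalence check sketches (shapes and values are arbitrary):

```python
import paddle

q = paddle.randn([1, 2, 4, 8])
k = paddle.randn([1, 2, 4, 8])
scaling = 8 ** -0.5

# Old form: matmul first, scale afterwards.
old = paddle.matmul(q, k.transpose((0, 1, 3, 2))) * scaling
# New form: scale the query, let matmul transpose the key.
new = paddle.matmul(x=q.scale(scaling), y=k, transpose_y=True)
print(paddle.allclose(old, new, atol=1e-6))  # bool Tensor holding True
```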

```diff
@@ -138,44 +145,55 @@ def forward(
         cu_seqlens: Optional[List[paddle.Tensor]] = None,
         rope_emb: Optional[Tuple[paddle.Tensor, paddle.Tensor]] = None,  # (cos, sin)
     ):
+        if output_attentions:
+            raise NotImplementedError
+
         B, L, D = hidden_states.shape
 
         q = self.q_proj(hidden_states)
         k = self.k_proj(hidden_states)
         v = self.v_proj(hidden_states)
 
         # [B, L, H, Dh]
-
         q = q.reshape([B, L, self.num_heads, self.head_dim])
         k = k.reshape([B, L, self.num_heads, self.head_dim])
         v = v.reshape([B, L, self.num_heads, self.head_dim])
         if rope_emb is not None:
             cos, sin = rope_emb
             q, k = apply_rotary_pos_emb_vision(q, k, cos, sin)
 
-        # → [B, H, L, Dh]
-        q = q.transpose([0, 2, 1, 3])
-        k = k.transpose([0, 2, 1, 3])
-        v = v.transpose([0, 2, 1, 3])
-
-        attn_output, attn_weights = eager_attention_forward(
-            self,
-            q,
-            k,
-            v,
-            attention_mask,
-            is_causal=self.is_causal,
-            scaling=self.scale,
-            dropout=0.0 if not self.training else self.dropout,
-        )
-        attn_output = attn_output.reshape([B, L, D]).contiguous()
+        if q.dtype == paddle.float32:
+            # → [B, H, L, Dh]
+            q = q.transpose([0, 2, 1, 3])
+            k = k.transpose([0, 2, 1, 3])
+            v = v.transpose([0, 2, 1, 3])
+
+            attn_output, _ = eager_attention_forward(
+                self,
+                q,
+                k,
+                v,
+                attention_mask,
+                is_causal=self.is_causal,
+                scaling=self.scale,
+                dropout=0.0 if not self.training else self.dropout,
+            )
+            attn_output = attn_output.reshape([B, L, D])
+        else:
+            attn_output = paddle.nn.functional.scaled_dot_product_attention(
+                q,
+                k,
+                v,
+                attention_mask,
+                dropout_p=self.dropout,
+                is_causal=self.is_causal,
+                training=self.training,
+            )
+            attn_output = attn_output.reshape([B, L, D])
 
         attn_output = self.out_proj(attn_output)
 
-        if not output_attentions:
-            attn_weights = None
-
-        return attn_output, attn_weights
+        return attn_output, None
 
 
 class SiglipVisionEmbeddings(nn.Layer):
```
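On the non-float32 path the layer now calls Paddle's fused SDPA directly. Below is a minimal standalone call, assuming the `[B, L, num_heads, head_dim]` layout that `paddle.nn.functional.scaled_dot_product_attention` expects (which is why the q/k/v transpose is only performed on the eager branch); it runs only where the fused kernel is available, e.g. a GPU with float16/bfloat16 support:

```python
import paddle
import paddle.nn.functional as F

B, L, H, Dh = 1, 16, 8, 64
q = paddle.randn([B, L, H, Dh]).astype(paddle.float16)
k = paddle.randn([B, L, H, Dh]).astype(paddle.float16)
v = paddle.randn([B, L, H, Dh]).astype(paddle.float16)

# No mask, no dropout at inference time; output stays in [B, L, H, Dh].
out = F.scaled_dot_product_attention(
    q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False, training=False
)
out = out.reshape([B, L, H * Dh])  # ready for the output projection
print(out.shape)
```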

paddlex/inference/models/doc_vlm/predictor.py

Lines changed: 7 additions & 2 deletions
```diff
@@ -29,7 +29,7 @@
 from ....utils.deps import require_genai_client_plugin
 from ....utils.device import TemporaryDeviceChanger
 from ...common.batch_sampler import DocVLMBatchSampler
-from ...utils.misc import is_bfloat16_available
+from ...utils.misc import is_bfloat16_available, is_float16_available
 from ..base import BasePredictor
 from .result import DocVLMResult
 
```

```diff
@@ -54,7 +54,12 @@ def __init__(self, *args, **kwargs):
 
         if self._use_local_model:
             self.device = kwargs.get("device", None)
-            self.dtype = "bfloat16" if is_bfloat16_available(self.device) else "float32"
+            if is_bfloat16_available(self.device):
+                self.dtype = "bfloat16"
+            elif is_float16_available(self.device):
+                self.dtype = "float16"
+            else:
+                self.dtype = "float32"
 
         self.infer, self.processor = self._build(**kwargs)
 
```

paddlex/inference/utils/misc.py

Lines changed: 11 additions & 0 deletions
```diff
@@ -32,3 +32,14 @@ def is_bfloat16_available(device):
     return (
         "npu" in get_device_type() or paddle.amp.is_bfloat16_supported()
     ) and device_type in ("gpu", "npu", "xpu", "mlu")
+
+
+def is_float16_available(device):
+    import paddle.amp
+
+    if device is None:
+        device = get_default_device()
+    device_type, _ = parse_device(device)
+    return (
+        "npu" in get_device_type() or paddle.amp.is_float16_supported()
+    ) and device_type in ("gpu", "npu", "xpu", "mlu", "dcu")
```
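A sketch of how the new helper pairs with `is_bfloat16_available` for dtype selection, mirroring the cascade added to `predictor.py` above (`pick_dtype` is just an illustrative name, not part of the commit):

```python
from paddlex.inference.utils.misc import is_bfloat16_available, is_float16_available

def pick_dtype(device=None):
    # Prefer bfloat16, fall back to float16, then float32, same order as the predictor.
    if is_bfloat16_available(device):
        return "bfloat16"
    if is_float16_available(device):
        return "float16"
    return "float32"

print(pick_dtype("gpu:0"))  # e.g. "bfloat16" on recent GPUs, "float16" on older ones
```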
