5 changes: 2 additions & 3 deletions custom_ops/xpu_ops/src/ops/block_attn.cc
@@ -705,9 +705,8 @@ std::vector<paddle::Tensor> BlockAttnKernel(
std::vector<int> lody_vec(dec_batch + 1);
std::vector<int> offset_vec(dec_batch, 0);
std::vector<int> lod_ref_vec(dec_batch + 1, 0);
- using TGEMM = std::conditional_t<std::is_same_v<XPU_XType, XPU_CType>,
-                                  tfloat32,
-                                  int8_wo_t>;
+ using TGEMM = std::
+     conditional_t<std::is_same_v<XPU_XType, XPU_CType>, float, int8_wo_t>;

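🟡 Suggestion: the identical TGEMM definition in block_attn_spliced.cc has not been updated to match

block_attn_spliced.cc:1545-1547 contains exactly the same code, and it still uses tfloat32:

using TGEMM = std::conditional_t<std::is_same_v<XPU_XType, XPU_CType>,
                                 tfloat32,
                                 int8_wo_t>;

That alias is likewise passed as a template argument to speculative_attention_decoder. If the goal of this PR is to improve speculative attention precision across the board, I suggest also changing tfloat32 to float in block_attn_spliced.cc so that the two code paths stay consistent.

Please confirm: was this an oversight, or is that file intentionally left unchanged?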
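
To make the difference between the two definitions concrete, here is a minimal standalone sketch of how the alias resolves before and after this change. It assumes nothing about the XPU SDK: `tfloat32_tag` and `int8_wo_tag` are hypothetical stand-ins for the real `tfloat32` and `int8_wo_t` types, `TGemmBefore`, `TGemmAfter`, and `quant_mode_for` are illustrative names, and `float` / `std::int8_t` stand in for whatever `XPU_XType` and `XPU_CType` are in the kernel.

```cpp
#include <cassert>
#include <cstdint>
#include <type_traits>

// Hypothetical stand-ins for the XPU tag types. The real tfloat32 and
// int8_wo_t come from the XPU SDK headers and are not reproduced here.
struct tfloat32_tag {};
struct int8_wo_tag {};

// Definition still present in block_attn_spliced.cc: when the input type
// and the cache type match, the GEMM type is tfloat32.
template <typename XType, typename CType>
using TGemmBefore = std::conditional_t<std::is_same_v<XType, CType>,
                                       tfloat32_tag,
                                       int8_wo_tag>;

// Definition after this PR in block_attn.cc: the matching-type case now
// selects float; the mismatched case is unchanged.
template <typename XType, typename CType>
using TGemmAfter = std::conditional_t<std::is_same_v<XType, CType>,
                                      float,
                                      int8_wo_tag>;

// Same rule as the quant_mode line in the diff context.
template <typename CType>
constexpr int quant_mode_for() {
  return std::is_same_v<CType, std::int8_t> ? 3 : 0;
}

// Compile-time checks: only the matching-type branch differs.
static_assert(std::is_same_v<TGemmBefore<float, float>, tfloat32_tag>);
static_assert(std::is_same_v<TGemmAfter<float, float>, float>);
static_assert(std::is_same_v<TGemmBefore<float, std::int8_t>,
                             TGemmAfter<float, std::int8_t>>);

// Runtime self-check, callable from any test harness.
inline void check_quant_mode() {
  assert(quant_mode_for<std::int8_t>() == 3);
  assert(quant_mode_for<float>() == 0);
}
```

Only the branch where the two types match changes, so the int8 cache path should behave the same in both files. The un-quantized path is where block_attn.cc and block_attn_spliced.cc would now diverge.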

constexpr int quant_mode = std::is_same_v<XPU_CType, int8_t> ? 3 : 0;
ret = baidu::xpu::xfa::speculative_attention_decoder<XPU_XType,
XPU_CType,