Skip to content

Commit 9b2e58f

Browse files
committed
[CK Tile] Fix FMHA LSE calculation and potential division by zero
This commit addresses numerical stability issues in the BlockFmhaPipelineQRKSVS pipeline when bias has -inf masking values: 1. Explicitly handle the case where the accumulated exponential sum (l) is zero. In this case, the LSE is now correctly set to negative infinity, preventing log(0) errors. 2. Extend the zero-check protection in the normalization step to cover the ELEMENTWISE_BIAS case, preventing potential division by zero.
1 parent 004784e commit 9b2e58f

File tree

1 file changed

+21
-9
lines changed

1 file changed

+21
-9
lines changed

include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -655,26 +655,35 @@ struct BlockFmhaPipelineQRKSVS
655655
constexpr auto lse_spans = decltype(lse)::get_distributed_spans();
656656
sweep_tile_span(lse_spans[number<0>{}], [&, m_ = m, l_ = l](auto idx0) {
657657
constexpr auto i_idx = make_tuple(idx0);
658-
#if CK_TILE_FMHA_FWD_FAST_EXP2
659-
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
660-
BiasEnum == BlockAttentionBiasEnum::ALIBI)
658+
// In the masked biased case, the entire row can be suppressed and the accumulated
659+
// softmax denominator becomes zero; treat it as log(0) = -inf to avoid NaNs.
660+
if(l_[i_idx] == 0.0f)
661661
{
662-
lse(i_idx) = m_[i_idx] / C_LOG2E + log(l_[i_idx]);
662+
lse(i_idx) = -numeric<LSEDataType>::infinity();
663663
}
664664
else
665665
{
666-
if constexpr(kHasLogitsSoftCap)
666+
#if CK_TILE_FMHA_FWD_FAST_EXP2
667+
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
668+
BiasEnum == BlockAttentionBiasEnum::ALIBI)
667669
{
668670
lse(i_idx) = m_[i_idx] / C_LOG2E + log(l_[i_idx]);
669671
}
670672
else
671673
{
672-
lse(i_idx) = m_[i_idx] * scale_s / C_LOG2E + log(l_[i_idx]);
674+
if constexpr(kHasLogitsSoftCap)
675+
{
676+
lse(i_idx) = m_[i_idx] / C_LOG2E + log(l_[i_idx]);
677+
}
678+
else
679+
{
680+
lse(i_idx) = m_[i_idx] * scale_s / C_LOG2E + log(l_[i_idx]);
681+
}
673682
}
674-
}
675683
#else
676-
lse(i_idx) = m_[i_idx] + log(l_[i_idx]);
684+
lse(i_idx) = m_[i_idx] + log(l_[i_idx]);
677685
#endif
686+
}
678687
});
679688

680689
store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse));
@@ -686,7 +695,10 @@ struct BlockFmhaPipelineQRKSVS
686695
sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
687696
constexpr auto i_idx = make_tuple(idx0);
688697
const auto tmp = [&]() {
689-
if constexpr(FmhaMask::IsMasking)
698+
// When bias carries -inf masks the denominator can be zero; guard the normalization
699+
// so we do not divide by zero after a fully masked row.
700+
if constexpr(FmhaMask::IsMasking ||
701+
BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
690702
{
691703
return l[i_idx] == 0.f ? 0.f : 1 / l[i_idx];
692704
}

0 commit comments

Comments
 (0)