Add CK-free fallback for fused QKNorm+RoPE+Cache #34
Wrap fused_qk_norm_rope_cache_quant_shuffle in try-except so that CK-free builds gracefully fall through to the non-fused Triton path (rotary_emb + q/k_norm + reshape_and_cache) instead of crashing.

Key safety measures:
- qkv.clone() backup before fused kernel call, restored on failure (protects against partial in-place writes before exception)
- log-once warning via class attribute to avoid log spam
- q_norm is None guard on middle path preserves original elif invariant
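The backup-and-restore pattern described above can be sketched roughly as follows. This is a minimal illustration, not the code in this PR: `FakeTensor`, `RopeCacheSketch`, `_fused_fallback_warned`, and `broken_kernel` are hypothetical names, and `FakeTensor` only mimics the `clone()` and `copy_()` methods of `torch.Tensor` so the example runs without PyTorch. The middle-path `q_norm is None` guard is not modelled here.

```python
import logging

logger = logging.getLogger("rope_cache_sketch")


class FakeTensor:
    """Tiny stand-in for torch.Tensor exposing only clone() and copy_()."""

    def __init__(self, data):
        self.data = list(data)

    def clone(self):
        return FakeTensor(self.data)

    def copy_(self, other):
        self.data[:] = other.data
        return self


class RopeCacheSketch:
    # Class attribute (hypothetical name): shared by all instances,
    # so the warning is emitted at most once per process.
    _fused_fallback_warned = False

    def __init__(self, fused_kernel):
        self.fused_kernel = fused_kernel  # hypothetical handle to the fused op

    def rope_cache(self, qkv):
        backup = qkv.clone()  # snapshot before any in-place write
        try:
            self.fused_kernel(qkv)
            return "fused"
        except Exception as exc:
            qkv.copy_(backup)  # undo partial in-place writes
            if not RopeCacheSketch._fused_fallback_warned:
                logger.warning("fused kernel failed (%s); using non-fused path", exc)
                RopeCacheSketch._fused_fallback_warned = True
        self._non_fused(qkv)
        return "non_fused"

    def _non_fused(self, qkv):
        # Placeholder for rotary_emb + q/k_norm + reshape_and_cache.
        qkv.data = [x + 1 for x in qkv.data]


def broken_kernel(qkv):
    qkv.data[0] = -999  # simulate a partial in-place write...
    raise RuntimeError("CK not available")  # ...followed by a failure


impl = RopeCacheSketch(broken_kernel)
tensor = FakeTensor([1, 2, 3])
path = impl.rope_cache(tensor)
```

Because the backup is restored before the fallback runs, the non-fused path sees the original input rather than the half-written buffer. The cost is one extra copy of `qkv` on every call that attempts the fused kernel.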
@gyohuangxin @ZhiweiYan-96 @valarLip: requesting review on this CK-free fallback for
What this does: When
Safety measures:
Known limitation: Full E2E test is blocked by a separate AITER-side issue: the ASM attention kernels (
Unit tests (import + mock fallback with 5 assertions) pass cleanly.
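For readers unfamiliar with mock-based fallback tests, a check of this kind might look roughly like the sketch below. It is not the PR's actual test: `rope_cache`, `fused`, `non_fused`, and `_warned` are hypothetical stand-ins, and the real unit tests exercise `PagedAttentionImpl` rather than a free function.

```python
import logging
from unittest import mock

logger = logging.getLogger("fallback_test_sketch")
_warned = False  # module-level stand-in for the log-once class attribute


def rope_cache(fused, non_fused, qkv):
    """Hypothetical stand-in: try the fused op, fall back on any exception."""
    global _warned
    try:
        return fused(qkv)
    except Exception:
        if not _warned:
            logger.warning("fused path failed; falling back")
            _warned = True
        return non_fused(qkv)


# The fused op always raises, as it would in a CK-free build.
fused = mock.Mock(side_effect=RuntimeError("no CK"))
non_fused = mock.Mock(return_value="non_fused")

# Patch the logger so the number of emitted warnings can be counted.
with mock.patch.object(logger, "warning") as warn:
    first = rope_cache(fused, non_fused, [1, 2, 3])
    second = rope_cache(fused, non_fused, [1, 2, 3])
```

The fused mock is attempted on every call, while the warning is expected only on the first failure.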
Moved to upstream: ROCm#279
Summary
- Wrap `fused_qk_norm_rope_cache_quant_shuffle` in `attention_mha.py:rope_cache()` with try-except
- `qkv.clone()` backup before fused kernel call, restored on failure (protects against partial in-place writes)
- `q_norm is None` guard on middle path preserves original `elif` invariant

Related
Test Results
from atom.model_ops.attention_mha import PagedAttentionImpl

Known Limitation
E2E tests are blocked by a separate AITER-side issue:
`module_fmha_v3_varlen_fwd` JIT compilation fails in CK-free builds because the ASM attention kernels still depend on CK-Tile headers (`fmha_fwd.hpp`). Our `rope_cache` fallback works correctly through model load + warmup (logs confirm individual `module_rope_pos_fwd` and `module_cache` loaded successfully).

Shengnan's team is working on removing the CK header dependency from ASM attention kernels in AITER. Once that lands, the full CK-free E2E path (this PR + FMHA fix) will be unblocked.