
[Feature] Add Triton unified attention kernel for deterministic inference#6795

Open
gongweibao wants to merge 2 commits into PaddlePaddle:develop from gongweibao:pr/triton-unified-attn-kernel

Conversation

@gongweibao
Collaborator

Motivation

This PR provides a Triton unified attention kernel for deterministic inference mode. The kernel processes prefix (cached) and extend (new) KV tokens uniformly through unified kv_indices, so the accumulation order is exactly the same whether or not the prefix cache hits, yielding split-invariant, deterministic attention output.

This PR is a pure operator-level implementation; it does not integrate with the inference pipeline (pipeline integration lands in a follow-up PR).
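The motivation rests on the fact that floating-point addition is not associative, so the same attention scores summed under different prefix/extend splits can round differently. A minimal NumPy sketch (not the kernel itself) illustrates why fixing a single accumulation order restores bitwise reproducibility:

```python
import numpy as np

rng = np.random.default_rng(0)
scores = rng.standard_normal(4096).astype(np.float32)

# Two split points that mimic "cache hit" vs "cache miss" partitioning:
# grouping the same values differently changes fp32 rounding.
hit = scores[:1024].sum(dtype=np.float32) + scores[1024:].sum(dtype=np.float32)
miss = scores[:2048].sum(dtype=np.float32) + scores[2048:].sum(dtype=np.float32)
print(hit == miss)  # often False: fp addition is not associative

# A unified accumulation order (one fixed left-to-right pass) is
# reproducible regardless of how the KV tokens were originally split.
unified_a = np.float32(0)
for x in scores:
    unified_a += x
unified_b = np.float32(0)
for x in scores:
    unified_b += x
print(unified_a == unified_b)  # True: same order -> bit-identical
```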

Modifications

New files:

  • fastdeploy/model_executor/layers/attention/triton_ops/__init__.py — package exports
  • fastdeploy/model_executor/layers/attention/triton_ops/unified_extend_attention.py — core implementation:
    • _fwd_kernel_unified: Triton JIT kernel with online softmax + paged KV cache + causal masking
    • Index building utilities: triton_cumsum_with_zero_prefix, build_kv_indices_from_block_tables, build_unified_kv_indices, _scatter_extend_kv_indices_kernel (all CUDA Graph compatible)
    • pre_cache_len_concat_triton: GPU-only Triton replacement for the C++ op
    • Reference implementations (_ref suffix) for correctness validation
  • tests/deterministic/test_unified_extend_attention.py — kernel correctness unit tests
  • tests/deterministic/test_build_triton_indices.py — unit tests for index-building logic
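The index-building step can be sketched in plain NumPy. This is a simplified CPU analog with hypothetical signatures, not the actual Triton utilities, which run on GPU; it only shows the shape of the computation:

```python
import numpy as np

def cumsum_with_zero_prefix(lens):
    """Exclusive prefix sum: [5, 3] -> [0, 5, 8].

    CPU analog of what triton_cumsum_with_zero_prefix computes; used to
    locate each sequence's slice inside a flattened kv_indices buffer.
    """
    out = np.zeros(len(lens) + 1, dtype=np.int64)
    np.cumsum(np.asarray(lens, dtype=np.int64), out=out[1:])
    return out

def kv_indices_from_block_tables(block_table, kv_len, block_size):
    """Expand a paged block table into per-token KV cache slot indices.

    block_table[i] is the physical block id of logical block i; token t
    lives at slot block_table[t // block_size] * block_size + t % block_size.
    """
    t = np.arange(kv_len)
    return block_table[t // block_size] * block_size + t % block_size

# Hypothetical example: one sequence of 5 KV tokens, block_size=4,
# stored in physical blocks 7 and 2.
offsets = cumsum_with_zero_prefix([5, 3])
idx = kv_indices_from_block_tables(np.array([7, 2]), 5, 4)
print(idx)  # [28 29 30 31  8]
```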

Usage or Command

# Run the kernel correctness tests
source /root/paddlejob/workspace/env_run/gongweibao/archfd/fdarchenv/bin/activate
CUDA_VISIBLE_DEVICES=0 python -m pytest tests/deterministic/test_unified_extend_attention.py -v
CUDA_VISIBLE_DEVICES=0 python -m pytest tests/deterministic/test_build_triton_indices.py -v

Accuracy Tests

Test coverage:

  • Kernel correctness: MHA/GQA/MQA; head_dim = 13/64/80/96/128/256; float16/bfloat16
  • Split invariance: cache-miss vs cache-hit outputs are bit-identical
  • Determinism: repeated runs are bitwise identical
  • Production scale: bs=19, seq=4096, mixed lengths
  • Cross-validation: three-way check across naive, SDPA reference, and Triton implementations
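The determinism tests compare outputs bitwise rather than with a tolerance. A small NumPy sketch (not the project's test code) shows the pattern of running a reference attention twice and asserting byte-for-byte equality:

```python
import numpy as np

def ref_attention(q, k, v):
    """Naive float32 reference: softmax(q @ k.T / sqrt(d)) @ v."""
    scale = 1.0 / np.sqrt(q.shape[-1])
    s = (q @ k.T) * scale
    s = s - s.max(axis=-1, keepdims=True)   # numerically stable softmax
    p = np.exp(s)
    p /= p.sum(axis=-1, keepdims=True)
    return p @ v

rng = np.random.default_rng(0)
q = rng.standard_normal((8, 64)).astype(np.float32)
k = rng.standard_normal((16, 64)).astype(np.float32)
v = rng.standard_normal((16, 64)).astype(np.float32)

out1 = ref_attention(q, k, v)
out2 = ref_attention(q, k, v)
# Bitwise comparison, not np.allclose: determinism means identical bytes.
assert out1.tobytes() == out2.tobytes()
```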

Checklist

  • Add at least a tag in the PR title.
  • Format your code; run pre-commit before committing.
  • Add unit tests.
  • Provide accuracy results.
  • If the current PR is submitting to the release branch, make sure the PR has been submitted to the develop branch, then cherry-pick it to the release branch with the [Cherry-Pick] PR tag.

🤖 Generated with Claude Code

…ence

Add a Triton-based unified extend attention kernel that processes both
prefix (cached) and extend (new) KV tokens through a single kernel with
unified kv_indices, ensuring identical accumulation order regardless of
cache hit/miss patterns.

Key components:
- _fwd_kernel_unified: Triton JIT kernel with online softmax, paged KV
  cache support, and causal masking for prefix+extend
- Index building utilities: triton_cumsum_with_zero_prefix,
  build_kv_indices_from_block_tables, build_unified_kv_indices,
  _scatter_extend_kv_indices_kernel (all CUDA Graph compatible)
- pre_cache_len_concat_triton: GPU-only replacement for C++ op
- Reference implementations (_ref variants) for correctness validation
- Comprehensive tests: kernel correctness, split invariance,
  determinism, production-scale, cross-validation

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@paddle-bot

paddle-bot bot commented Mar 11, 2026

Thanks for your contribution!

@CLAassistant

CLA assistant check
Thank you for your submission! We really appreciate it. Like many open source projects, we ask that you sign our Contributor License Agreement before we can accept your contribution.


gongweibao seems not to be a GitHub user. You need a GitHub account to be able to sign the CLA. If you already have a GitHub account, please add the email address used for this commit to your account.
You have signed the CLA already but the status is still pending? Let us recheck it.

Replace triple Python for-loop with paddle.where vectorized mask in
naive_attention and _build_causal_mask. seq4096 test: 2m39s -> 6s.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
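The commit above replaces a triple Python for-loop with a vectorized `paddle.where` mask. A NumPy analog of such a vectorized causal mask (hypothetical shape conventions, sketching the broadcasting idea rather than the actual `_build_causal_mask` code):

```python
import numpy as np

def build_causal_mask(q_len, kv_len):
    """Vectorized causal mask via broadcasting instead of nested loops.

    Query i may attend to KV position j iff j <= i + (kv_len - q_len),
    i.e. the full prefix plus extend tokens up to and including itself.
    """
    i = np.arange(q_len)[:, None]           # (q_len, 1)
    j = np.arange(kv_len)[None, :]          # (1, kv_len)
    allowed = j <= i + (kv_len - q_len)
    return np.where(allowed, 0.0, -np.inf).astype(np.float32)

mask = build_causal_mask(3, 5)
# row 0 attends to kv positions 0..2, row 1 to 0..3, row 2 to 0..4
print(mask)
```

Replacing per-element Python loops with one broadcasted comparison is what turns an O(bs * seq^2) interpreted loop into a handful of array ops, consistent with the reported 2m39s -> 6s speedup.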
@codecov-commenter

Codecov Report

❌ Patch coverage is 43.19066% with 146 lines in your changes missing coverage. Please review.
⚠️ Please upload report for BASE (develop@f0ab8ee). Learn more about missing BASE report.

Files with missing lines Patch % Lines
...s/attention/triton_ops/unified_extend_attention.py 43.19% 141 Missing and 5 partials ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             develop    #6795   +/-   ##
==========================================
  Coverage           ?   72.36%           
==========================================
  Files              ?      396           
  Lines              ?    54956           
  Branches           ?     8635           
==========================================
  Hits               ?    39770           
  Misses             ?    12362           
  Partials           ?     2824           
Flag Coverage Δ
GPU 72.36% <43.19%> (?)

Flags with carried forward coverage won't be shown.

☔ View full report in Codecov by Sentry.


