feat: Add DiT-oriented kernels where Qk (Bmm1) type can be reinterpreted into Int8 or BFloat16 #2711
xrq-phys wants to merge 1 commit into flashinfer-ai:main
📝 Walkthrough
This change extends FMHA kernel selection and launch paths to support Q/K reinterpretation and SageAttention state: it adds new key fields to the runner/kernel caches, threads Sage scaling-factor pointers and per-block counts through the launchers/runner/params, updates kernel metadata/params (Vx path), and adds tests/docs.
Sequence Diagram(s)
sequenceDiagram
participant API as Python API
participant Launcher as FMHA Launcher
participant Cache as RunnerCache
participant Runner as TllmGenFmhaRunner
participant Factory as KernelFactory
participant GPU as GPU Kernel
API->>Launcher: call (Q,K,V, options, sage_sfs?, num_elts_sage?)
Launcher->>Cache: get(q_dtype, kv_dtype, o_dtype, qk_reinterpret_type, num_elts_sage...)
Cache-->>Launcher: runner (cached or newly created)
Launcher->>Runner: populate RunnerParams (ptrSageAttnSfs*, counts, shapes, strides)
Runner->>Factory: select kernel (includes dtypeQkReinterpret & Sage counts)
Factory-->>Runner: kernel metadata / function pointer
Runner->>GPU: cuLaunchKernel / KernelParamsVx or KernelParams
GPU-->>Runner: (optional) reduction / post-process
Runner-->>Launcher: completion (output buffer)
Launcher-->>API: return output
Estimated code review effort: 🎯 5 (Critical) | ⏱️ ~120 minutes
🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 failed (1 warning)
Summary of Changes (Gemini Code Assist)
This pull request significantly expands FlashInfer's capabilities by integrating specialized attention kernels tailored for Diffusion Transformer (DiT) models within the TensorRT-LLM framework. The changes enable more flexible mixed-precision computations, including configurations with BFloat16, Int8, and E4m3 data types for query-key products and value tensors, alongside support for SageAttention scaling. This enhancement allows FlashInfer to leverage highly optimized kernels for a broader range of advanced generative AI models.
Code Review
This pull request introduces support for DiT-oriented TRTLLM kernels, including variants with mixed-precision and SageAttention. The changes are extensive, involving updates to kernel selection logic, data type handling, and parameter passing to accommodate the new kernel specializations. A compatibility layer (fmhaKernelMetaAdapter.h) has been added to unify the metadata of existing and new kernels, which is a good approach for this transitional period. The addition of new tests for the DiT kernels is also a great inclusion.
My review includes a few suggestions to improve maintainability and robustness:
- Improving the hash function in the kernel cache to reduce potential collisions.
- Adding a clarifying comment for a potentially misleading variable name.
- Refactoring a duplicated helper lambda into a common utility file.
Actionable comments posted: 6
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
include/flashinfer/trtllm/fmha/fmhaRunner.cuh (1)
28-53: ⚠️ Potential issue | 🔴 Critical
Allow DATA_TYPE_INT8 for the new SageAttention Q/K path.
The new ragged DiT coverage exercises int8 query/key inputs (tests/attention/test_trtllm_ragged_dit.py:181-193), but this constructor still rejects any mDtypeQ outside E4M3/FP16/BF16. That blocks the new int8 variant before kernel lookup.
🐛 Suggested fix
   FLASHINFER_CHECK(
-      mDtypeQ == DATA_TYPE_E4M3 || mDtypeQ == DATA_TYPE_FP16 || mDtypeQ == DATA_TYPE_BF16,
+      mDtypeQ == DATA_TYPE_E4M3 || mDtypeQ == DATA_TYPE_FP16 ||
+      mDtypeQ == DATA_TYPE_BF16 || mDtypeQ == DATA_TYPE_INT8,
       "Unsupported Q data type: " + std::string(toStr(mDtypeQ)));
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@include/flashinfer/trtllm/fmha/fmhaRunner.cuh` around lines 28 - 53, The constructor TllmGenFmhaRunner currently rejects DATA_TYPE_INT8 for query/key types, blocking the new SageAttention int8 path; update the validation checks (the FLASHINFER_CHECK calls that inspect mDtypeQ and mDtypeKv) to include DATA_TYPE_INT8 as an allowed type so int8 Q/K inputs pass validation before calling getTllmFmhaKernels (leave the output-type check unchanged unless tests require int8 output).
🧹 Nitpick comments (2)
tests/attention/test_trtllm_ragged_dit.py (1)
57-58: Drop the extra CUDA-availability skip.
This suite already assumes CUDA-capable runners, so this branch only hides misconfigured test environments instead of surfacing them.
Based on learnings, tests in the repository assume CUDA is available and do not require torch.cuda.is_available() guards in pytest fixtures. Ensure test files under tests/ follow this convention and avoid adding CPU-only guards in fixtures unless explicitly handling a non-CUDA environment.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/attention/test_trtllm_ragged_dit.py` around lines 57 - 58, Remove the redundant CUDA availability guard by deleting the conditional block that checks "if not torch.cuda.is_available()" and calls "pytest.skip(...)" in the test file; tests under tests/ should assume CUDA is present, so remove the "if not torch.cuda.is_available(): pytest.skip('CUDA not available.')" branch (search for that exact conditional) to allow failures to surface in misconfigured environments.
csrc/trtllm_fmha_kernel_launcher.cu (1)
73-81: Consider using a better hash combining strategy.
The small bit shifts (1-7 bits) combined with XOR may lead to hash collisions when multiple fields have similar values. While the practical impact is minimal given the limited number of unique kernel configurations cached, a more robust approach would use multiplicative hash combining.
♻️ Suggested improvement using hash_combine pattern
 struct KeyHash {
   std::size_t operator()(const Key& k) const {
-    return std::hash<int>()(static_cast<int>(std::get<0>(k))) ^
-           (std::hash<int>()(static_cast<int>(std::get<1>(k))) << 1) ^
-           (std::hash<int>()(static_cast<int>(std::get<2>(k))) << 2) ^
-           (std::hash<int>()(static_cast<int>(std::get<3>(k))) << 3) ^
-           (std::hash<int>()(std::get<4>(k)) << 4) ^ (std::hash<int>()(std::get<5>(k)) << 5) ^
-           (std::hash<int>()(std::get<6>(k)) << 6) ^ (std::hash<int>()(std::get<7>(k)) << 7);
+    std::size_t seed = 0;
+    auto hash_combine = [&seed](auto val) {
+      seed ^= std::hash<int>()(static_cast<int>(val)) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
+    };
+    hash_combine(std::get<0>(k));
+    hash_combine(std::get<1>(k));
+    hash_combine(std::get<2>(k));
+    hash_combine(std::get<3>(k));
+    hash_combine(std::get<4>(k));
+    hash_combine(std::get<5>(k));
+    hash_combine(std::get<6>(k));
+    hash_combine(std::get<7>(k));
+    return seed;
   }
 };
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@csrc/trtllm_fmha_kernel_launcher.cu` around lines 73 - 81, The KeyHash::operator() uses small fixed bit shifts and XOR which can produce collisions; replace it with a robust hash_combine pattern: start with a seed (std::size_t) and for each element of Key (use std::get<0>(k) ... std::get<7>(k)) mix in std::hash<int>()(value) using a multiplicative constant (e.g. 0x9e3779b97f4a7c15ULL) and seed ^= h + constant + (seed<<6) + (seed>>2) or equivalent combine logic; update KeyHash to iterate the 8 fields and combine each hash into the seed, then return seed.
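To make the collision concern concrete, here is a minimal, self-contained C++ sketch. It assumes a hypothetical `CacheKey` (a `std::array<int, 8>`) standing in for the real 8-field tuple key, plus hypothetical helpers `xorShiftHash`, `hashCombine`, and `CombinedKeyHash`; only the combine formula itself follows the suggestion above.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <functional>
#include <unordered_map>

// Hypothetical stand-in for the runner-cache key. The real key in
// csrc/trtllm_fmha_kernel_launcher.cu is an 8-element std::tuple of
// enums and ints; std::array<int, 8> keeps this sketch self-contained.
using CacheKey = std::array<int, 8>;

// Current scheme: field i is shifted left by i bits and XOR-ed in.
// Raw field values stand in for std::hash<int> so the result is reproducible.
inline std::size_t xorShiftHash(const CacheKey& k) {
  std::size_t h = 0;
  for (std::size_t i = 0; i < k.size(); ++i) {
    h ^= static_cast<std::size_t>(static_cast<unsigned>(k[i])) << i;
  }
  return h;
}

// Boost-style hash_combine: each step mixes the running seed into the next
// field, so a small value in one field cannot simply cancel a shifted value
// from another field.
inline void hashCombine(std::size_t& seed, std::size_t value) {
  seed ^= value + static_cast<std::size_t>(0x9e3779b97f4a7c15ULL) + (seed << 6) + (seed >> 2);
}

// Hasher usable as the Hash parameter of std::unordered_map.
struct CombinedKeyHash {
  std::size_t operator()(const CacheKey& k) const {
    std::size_t seed = 0;
    for (int field : k) {
      hashCombine(seed, std::hash<int>{}(field));
    }
    return seed;
  }
};
```

With raw values as per-field hashes, (0, 1, 0, 0, 0, 0, 0, 0) and (2, 0, 0, 0, 0, 0, 0, 0) collide under the shifted XOR because 1 << 1 == 2. Since `std::unordered_map` still compares full keys on a collision, the cost is extra bucket probing rather than returning the wrong runner, which matches the "minimal practical impact" assessment.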
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@csrc/trtllm_fmha_kernel_launcher.cu`:
- Around line 552-563: The optional SageAttention scale-factor tensors
(sage_attn_sfs_q, sage_attn_sfs_k, sage_attn_sfs_p, sage_attn_sfs_v) are being
cast to float* without dtype checks; add the same TVM_FFI_ICHECK_EQ(...dtype(),
dl_float32) validations used for attention_sinks and lse before performing the
static_cast, and only set sage_attn_sfs_*_ptr to nullptr if the optional has no
value—this ensures each tensor's dtype is dl_float32 prior to casting in
trtllm_fmha_kernel_launcher.cu.
In `@flashinfer/prefill.py`:
- Around line 3446-3452: The SageAttention tensor tuple (sage_attn_sfs) and
corresponding block sizes (num_elts_per_sage_attn_blk) must be validated before
handing raw pointers to the C++ runner: ensure each tensor in sage_attn_sfs that
has a non-zero entry in num_elts_per_sage_attn_blk is non-None, is on the
expected device (use tensor.device or tensor.get_device()), has dtype
torch.float32, and is contiguous (or call .contiguous() before taking the
pointer); for None entries require the matching block size be zero. Update the
code paths that forward these values (the code working with sage_attn_sfs and
num_elts_per_sage_attn_blk around the SageAttention prep and the later block at
lines ~3573-3608) to perform these checks and raise/handle a clear error if
validation fails, then pass the tensor.data_ptr() only after validation.
In `@include/flashinfer/trtllm/fmha/fmhaKernels.cuh`:
- Around line 949-978: The hash currently maps blockSize==0 and blockSize==1 to
the same encoded value (computeLog2BlockSize returns 0), so update hashID to
mark “disabled” separately: keep computeLog2BlockSize as-is for nonzero sizes,
but when numEltsPerSageAttnBlkQ/K/P/V == 0 set reserved bits 28..31 as per-field
disabled flags (e.g., set bit 28 if numEltsPerSageAttnBlkQ==0, bit 29 for
numEltsPerSageAttnBlkK, bit 30 for P, bit 31 for V) before composing the final
return; reference the hashID function and the numEltsPerSageAttnBlkQ/K/P/V
parameters and ensure you OR the corresponding (1ULL << 28..31) flags into the
returned uint64_t so disabled is distinguishable from block size 1.
In `@include/flashinfer/trtllm/fmha/kernelParamsVx.h`:
- Around line 718-726: The O TMA descriptor and related metadata are being built
using Q-side values (kernelMeta.mDataTypeQ, numEltsInClampedHeadDimQ,
kernelMeta.mTileSizeQ) causing mismatch when O differs; update the O-path to use
O-side metadata: use the O data type (kernelMeta.mDataTypeO) when calling
buildNdTmaDescriptor, compute tileShapeO using O-specific sizes (e.g.,
numEltsInClampedHeadDimO and kernelMeta.mTileSizeO or equivalent O head-dim/
tile-size fields), and ensure mNumHiddenEltsO is computed from the O head
dimension (mHeadDimV) not Q; keep references to makeTmaShapeStrideO, tileShapeO,
params.tmaO_, buildNdTmaDescriptor, kernelMeta.mDataTypeQ/mDataTypeO,
mHeadDimV/mHeadDimQk, and mNumHiddenEltsO to locate the changes.
- Around line 784-814: The code computes params.mChunkedAttentionSizeLog2 when
isSlidingOrChunkedCausalMask(...) and options.mChunkedAttentionSize is set, but
then unconditionally resets params.mChunkedAttentionSizeLog2 to 0 at the end;
remove that reset so the computed value persists. Locate the assignment
params.mChunkedAttentionSizeLog2 = 0 (near the end of the block) and delete it
(or make it conditional only when chunked attention is disabled), ensuring the
earlier computation that uses options.mChunkedAttentionSize and
isSlidingOrChunkedCausalMask retains its result.
📒 Files selected for processing (10)
csrc/trtllm_fmha_kernel_launcher.cu
docs/api/attention.rst
flashinfer/artifacts.py
flashinfer/prefill.py
include/flashinfer/trtllm/fmha/fmhaKernelMetaAdapter.h
include/flashinfer/trtllm/fmha/fmhaKernels.cuh
include/flashinfer/trtllm/fmha/fmhaRunner.cuh
include/flashinfer/trtllm/fmha/fmhaRunnerParams.h
include/flashinfer/trtllm/fmha/kernelParamsVx.h
tests/attention/test_trtllm_ragged_dit.py
Branch force-pushed from 7ce24c0 to 11ac55e.
🧹 Nitpick comments (1)
csrc/trtllm_fmha_kernel_launcher.cu (1)
570-570: Consider renaming kv_data_type for clarity.
This variable is now derived from value.dtype() only, making the name kv_data_type potentially misleading in the DiT case where K and V have different types. Consider renaming to v_data_type or adding a clarifying comment.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@csrc/trtllm_fmha_kernel_launcher.cu` at line 570, The variable kv_data_type is misleading because it is derived only from value.dtype() and may not represent K and V separately (e.g., DiT where K and V differ); rename kv_data_type to v_data_type (or alternatively add a clarifying comment where dl_dtype_to_tllm_data_type(value.dtype()) is assigned) and update any subsequent uses to reference v_data_type so the code accurately reflects that this type is for V only.
📒 Files selected for processing (1)
csrc/trtllm_fmha_kernel_launcher.cu
Actionable comments posted: 1
♻️ Duplicate comments (1)
include/flashinfer/trtllm/fmha/fmhaKernels.cuh (1)
952-963: ⚠️ Potential issue | 🔴 Critical
Fix the SageAttention block-size encoding in the factory key.
__builtin_ctz(blockSize + 1) is not log2(blockSize) + 1: for 2/4/8/... it returns 0, so those sizes alias the disabled case in the cache key. That can reuse a TllmGenFmhaKernel built for the wrong SageAttention shape.
🔧 Suggested fix
 auto const computeLog2BlockSizePlus1 = [](int blockSize) -> int {
   if (blockSize <= 0) {
     return 0;
   }
   FLASHINFER_CHECK((blockSize & (blockSize - 1)) == 0,
                    "SageAttention block size must be a power of 2.");
-  return __builtin_ctz(static_cast<unsigned int>(blockSize) + 1);
+  return __builtin_ctz(static_cast<unsigned int>(blockSize)) + 1;
 };
#!/bin/bash
python - <<'PY'
def ctz(n: int) -> int:
    return (n & -n).bit_length() - 1

def current(block: int) -> int:
    if block <= 0:
        return 0
    return ctz(block + 1)

def expected(block: int) -> int:
    if block <= 0:
        return 0
    return ctz(block) + 1

for b in [0, 1, 2, 4, 8, 16, 32, 64]:
    print(f"blockSize={b:>2} current={current(b)} expected={expected(b)}")
PY
Also applies to: 975-981
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@include/flashinfer/trtllm/fmha/fmhaKernels.cuh` around lines 952 - 963, The factory key encodes SageAttention block-size incorrectly inside hashID: the lambda computeLog2BlockSizePlus1 uses __builtin_ctz(blockSize + 1) which maps powers of two to zero and collides with the disabled case; update computeLog2BlockSizePlus1 (and the other identical occurrence in the same function) to return 0 for blockSize <= 0, otherwise return __builtin_ctz(static_cast<unsigned int>(blockSize)) + 1 so that powers-of-two produce log2(blockSize)+1; keep the existing FLASHINFER_CHECK for power-of-two validation and apply the same change to the duplicate block-size encoding sites referenced in this function.
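A C++ sketch of the corrected encoding, checked against the buggy variant. `encodeLog2BlockSizePlus1`, `buggyLog2BlockSizePlus1`, and `packSageBlockSizes` are hypothetical names, and the 4-bit slot layout is an assumption rather than the real `hashID` bit layout; `__builtin_ctz` requires GCC or Clang.

```cpp
#include <cassert>
#include <cstdint>
#include <stdexcept>

// 0 means "SageAttention disabled for this tensor"; otherwise the code is
// log2(blockSize) + 1, so block size 1 encodes to 1 and never aliases "disabled".
inline int encodeLog2BlockSizePlus1(int blockSize) {
  if (blockSize <= 0) {
    return 0;
  }
  if ((blockSize & (blockSize - 1)) != 0) {
    throw std::invalid_argument("SageAttention block size must be a power of 2.");
  }
  // For a power of two, the count of trailing zeros equals log2.
  return __builtin_ctz(static_cast<unsigned int>(blockSize)) + 1;
}

// The variant flagged in the review: blockSize + 1 is odd for every power of
// two >= 2, so ctz returns 0 and those sizes collide with "disabled".
inline int buggyLog2BlockSizePlus1(int blockSize) {
  if (blockSize <= 0) {
    return 0;
  }
  return __builtin_ctz(static_cast<unsigned int>(blockSize) + 1);
}

// Hypothetical packing of the Q, K, P, V codes into consecutive 4-bit slots.
inline std::uint32_t packSageBlockSizes(int q, int k, int p, int v) {
  int const sizes[4] = {q, k, p, v};
  std::uint32_t packed = 0;
  for (int i = 0; i < 4; ++i) {
    auto const code = static_cast<std::uint32_t>(encodeLog2BlockSizePlus1(sizes[i]));
    assert(code < 16 && "block size too large for a 4-bit slot");
    packed |= code << (4 * i);
  }
  return packed;
}
```

Because 0 is reserved for "disabled" and every valid size maps to log2 + 1 ≥ 1, this encoding also addresses the earlier concern that block sizes 0 and 1 shared a value, without needing separate disabled flags.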
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@include/flashinfer/trtllm/fmha/fmhaKernels.cuh`:
- Around line 312-320: The code launches Vx kernels via cuLaunchKernelEx before
verifying reduction mode; move the rejection check so you call
FLASHINFER_CHECK(!isGmemReductionWithSeparateKernel(static_cast<MultiCtasKvMode>(kernelMeta.mMultiCtasKvMode)),
...) inside the kernelMeta.isKernelVx() branch before creating kernelParams or
calling KernelParamsVx::setKernelParams and cuLaunchKernelEx. Ensure the check
uses the same kernelMeta and MultiCtasKvMode cast so unsupported Vx reduction
modes are rejected early and the kernel is never launched.
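The ordering fix amounts to "reject, then build, then launch". A minimal C++ sketch, where `MultiCtasKvModeSketch`, the local `isGmemReductionWithSeparateKernel`, and `launchVxKernel` are hypothetical stand-ins and the two callbacks stand in for `KernelParamsVx::setKernelParams` and `cuLaunchKernelEx`:

```cpp
#include <cassert>
#include <functional>
#include <stdexcept>

// Hypothetical mirror of the multi-CTA KV modes; only the property queried by
// the review comment (reduction in a separate kernel) matters here.
enum class MultiCtasKvModeSketch { Disabled, GmemReduction, GmemReductionWithSeparateKernel };

inline bool isGmemReductionWithSeparateKernel(MultiCtasKvModeSketch mode) {
  return mode == MultiCtasKvModeSketch::GmemReductionWithSeparateKernel;
}

// Validate what the Vx path cannot handle before building params or launching,
// so a rejected configuration never reaches the GPU.
inline void launchVxKernel(MultiCtasKvModeSketch mode,
                           const std::function<void()>& buildParams,
                           const std::function<void()>& launch) {
  if (isGmemReductionWithSeparateKernel(mode)) {
    throw std::runtime_error("separate-kernel reduction is not supported on the Vx path");
  }
  buildParams();  // stands in for KernelParamsVx::setKernelParams
  launch();       // stands in for cuLaunchKernelEx
}
```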
📒 Files selected for processing (1)
include/flashinfer/trtllm/fmha/fmhaKernels.cuh
/bot run
Branch force-pushed from 789ece9 to 10007b3.
/bot run
Branch force-pushed from 10007b3 to a006ba7.
/bot run
Branch force-pushed from a006ba7 to 1c14fba.
/bot run
… fix (#3089)
The branch has 2 commits:
1. Update trtllm-gen FMHA cubins to fix context SWA page-skip: updates artifacts.py path + checksum.
2. Sync trtllm FMHA headers with latest trtllm-gen (from PR #2711): cherry-picks header changes to match the new cubin MetaInfo struct.
Summary by CodeRabbit
Bug Fixes
- Fixed sparse-attention truncation so sequence-length top-K is applied correctly when sparse-attention is enabled.
Improvements
- Standardized sparse-attention parameter naming and selection logic to make behavior more consistent across launches and kernel choices.
- Skip incompatible kernel variants during runtime kernel loading to avoid incorrect selections.
Chores
- Updated FMHA runtime artifact paths and their checksums for validation and downloads.
Co-authored-by: Claude Opus 4.7 <noreply@anthropic.com>
Branch force-pushed from 1c14fba to 6983f46.
/bot run
Branch force-pushed from 237c9d8 to b2582c3.
/bot run
/bot run
Head branch was pushed to by a user without write access
Branch force-pushed from b2582c3 to 6aa45d6.
/bot run
CI blocked by #3184
Signed-off-by: Ruqing Xu <7891482+xrq-phys@users.noreply.github.com>
Branch force-pushed from 6aa45d6 to be7a5a6.


📌 Description
This PR adds support for DiT-oriented TRTLLM kernels with 3 variants:
To integrate, the following changes are made to FlashInfer:
- A refreshed `kernelMetaInfo.h` file (originally a separate `kernelMetaInfoVx.h` file) with separated `dtypeK` and `dtypeV` traits.
- Updates to `FmhaKernels` to support the new DiT kernels.
- `trtllm_ragged_attention_launcher` is updated as the entry point to these kernels.
- A compatibility layer was originally added as `fmhaKernelMetaAdapter.h`; API unification has since taken place, and this PR was refreshed with the compatibility layers removed.
🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
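- I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- I have installed the hooks with `pre-commit install`.
- I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.
🧪 Tests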
- Tests have been added or updated as needed.
- All tests are passing (`unittest`, etc.).
Reviewer Notes
Summary by CodeRabbit
New Features
& paged attention; query/key reinterpretation for mixed dtypes and INT8 input handling; expanded kernel parameterization for improved FMHA paths.
Tests
Documentation
Chores