Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 81 additions & 0 deletions csrc/flat_prefill_kernel_delta_rule_sm120_extern.inc
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
/*
* Copyright (c) 2025 by FlashInfer team.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
// Extern template declarations to prevent implicit instantiation in the dispatcher.
// Explicit instantiations are in separate generated files for parallel compilation.

#pragma once

#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include "cutlass/arch/arch.h"

namespace flat {

// clang-format off

// Expands MACRO once for each of the 32 (2^5) combinations of five boolean
// arguments, forwarding any trailing arguments via __VA_ARGS__.  Used below to
// declare every boolean-template-parameter variant of the kernel launcher for
// a given compute type.
#define FOR_EACH_BOOL_5(MACRO, ...) \
MACRO(false, false, false, false, false, __VA_ARGS__) \
MACRO(false, false, false, false, true, __VA_ARGS__) \
MACRO(false, false, false, true, false, __VA_ARGS__) \
MACRO(false, false, false, true, true, __VA_ARGS__) \
MACRO(false, false, true, false, false, __VA_ARGS__) \
MACRO(false, false, true, false, true, __VA_ARGS__) \
MACRO(false, false, true, true, false, __VA_ARGS__) \
MACRO(false, false, true, true, true, __VA_ARGS__) \
MACRO(false, true, false, false, false, __VA_ARGS__) \
MACRO(false, true, false, false, true, __VA_ARGS__) \
MACRO(false, true, false, true, false, __VA_ARGS__) \
MACRO(false, true, false, true, true, __VA_ARGS__) \
MACRO(false, true, true, false, false, __VA_ARGS__) \
MACRO(false, true, true, false, true, __VA_ARGS__) \
MACRO(false, true, true, true, false, __VA_ARGS__) \
MACRO(false, true, true, true, true, __VA_ARGS__) \
MACRO(true, false, false, false, false, __VA_ARGS__) \
MACRO(true, false, false, false, true, __VA_ARGS__) \
MACRO(true, false, false, true, false, __VA_ARGS__) \
MACRO(true, false, false, true, true, __VA_ARGS__) \
MACRO(true, false, true, false, false, __VA_ARGS__) \
MACRO(true, false, true, false, true, __VA_ARGS__) \
MACRO(true, false, true, true, false, __VA_ARGS__) \
MACRO(true, false, true, true, true, __VA_ARGS__) \
MACRO(true, true, false, false, false, __VA_ARGS__) \
MACRO(true, true, false, false, true, __VA_ARGS__) \
MACRO(true, true, false, true, false, __VA_ARGS__) \
MACRO(true, true, false, true, true, __VA_ARGS__) \
MACRO(true, true, true, false, false, __VA_ARGS__) \
MACRO(true, true, true, false, true, __VA_ARGS__) \
MACRO(true, true, true, true, false, __VA_ARGS__) \
MACRO(true, true, true, true, true, __VA_ARGS__)

// Declares (without instantiating) one SM120 specialization of
// launch_delta_rule_prefill_kernel_gbai with compute type `ctype` and a float
// state type.  The five leading booleans select the kernel variant
// (presumably: GVA layout, beta gate, alpha gate, initial state,
// checkpointing — confirm against the kernel header that declares the
// template).  The parameter list must exactly match the explicit
// instantiations in the generated gdn_prefill_sm120_kernel_inst files.
#define DECLARE_TEMPLATE_INSTANCE(is_gva, needs_beta, needs_alpha, init_state, enable_ckpt, ctype) \
extern template void launch_delta_rule_prefill_kernel_gbai<is_gva, needs_beta, needs_alpha, init_state, enable_ckpt, cutlass::arch::Sm120, ctype, ctype, float>( \
cudaStream_t, ctype*, float*, ctype const*, ctype const*, ctype const*, \
float const*, float const*, float const*, int64_t const*, uint8_t*, int32_t, int32_t, \
int32_t, int32_t, int32_t, int32_t, int64_t, float, int32_t, float*, int64_t const*, \
int32_t);

// Extern template declarations for half
FOR_EACH_BOOL_5(DECLARE_TEMPLATE_INSTANCE, half)

// Extern template declarations for nv_bfloat16
FOR_EACH_BOOL_5(DECLARE_TEMPLATE_INSTANCE, nv_bfloat16)

#undef DECLARE_TEMPLATE_INSTANCE
#undef FOR_EACH_BOOL_5

// clang-format on

} // namespace flat
25 changes: 20 additions & 5 deletions csrc/gdn_prefill_launcher.cu
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@ void gdn_prefill_launcher(void* output, void* output_state, void* q, void* k, vo
int device_major;
cudaDeviceGetAttribute(&device_major, cudaDevAttrComputeCapabilityMajor, dev_id);

#if defined(FLAT_SM90A_ENABLED)
if (device_major == 9) {
#if defined(FLAT_SM90A_ENABLED)
flat::launch_delta_rule_prefill_kernel<cutlass::arch::Sm90, DType, DType, float>(
stream, static_cast<DType*>(output), static_cast<float*>(output_state),
static_cast<DType const*>(q), static_cast<DType const*>(k), static_cast<DType const*>(v),
Expand All @@ -57,16 +57,31 @@ void gdn_prefill_launcher(void* output, void* output_state, void* q, void* k, vo
static_cast<float*>(state_checkpoints), checkpoint_cu_starts,
static_cast<int32_t>(checkpoint_every_n_tokens));
return true;
#else
FLASHINFER_ERROR("sm_90a is not enabled, delta rule kernel is not built");
return false;
#endif
} else if (device_major == 12) {
#if defined(FLAT_SM120A_ENABLED)
flat::launch_delta_rule_prefill_kernel<cutlass::arch::Sm120, DType, DType, float>(
stream, static_cast<DType*>(output), static_cast<float*>(output_state),
static_cast<DType const*>(q), static_cast<DType const*>(k), static_cast<DType const*>(v),
static_cast<float const*>(input_state), static_cast<float const*>(alpha),
static_cast<float const*>(beta), cu_seqlens, workspace_buffer, num_seqs, num_q_heads,
num_k_heads, num_v_heads, num_o_heads, head_size, packed_seq, scale, sm_count,
static_cast<float*>(state_checkpoints), checkpoint_cu_starts,
static_cast<int32_t>(checkpoint_every_n_tokens));
return true;
#else
FLASHINFER_ERROR("sm_120a is not enabled, delta rule kernel is not built");
return false;
#endif
} else {
std::ostringstream err_msg;
err_msg << "delta rule kernel does not support this device major version: " << device_major;
FLASHINFER_ERROR(err_msg.str());
return false;
}
#else
FLASHINFER_ERROR("sm_90a is not enabled, delta rule kernel is not built");
return false;
#endif
});
}

Expand Down
39 changes: 39 additions & 0 deletions csrc/gdn_prefill_sm120_kernel_inst.jinja
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
/*
* Copyright (c) 2025 by FlashInfer team.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
// Auto-generated file for separate compilation of GDN prefill kernel variants.
// Template parameters: dtype={{ dtype }}, is_gva={{ is_gva }}, needs_beta={{ needs_beta }},
// needs_alpha={{ needs_alpha }}, init_state={{ init_state }},
// enable_checkpointing={{ enable_checkpointing }}

// CUDA type definitions for half and nv_bfloat16
#include <cuda_bf16.h>
#include <cuda_fp16.h>

// Include the header which defines the function template
// The header includes all necessary CUTLASS type definitions
#include "flashinfer/flat/prefill/prefill_kernel_delta_rule_sm120.cuh"

namespace flat {

// Explicit template instantiation for launch_delta_rule_prefill_kernel_gbai
// Parameter types must exactly match the extern template declaration in prefill_kernel_delta_rule_sm120_extern.inc
template void launch_delta_rule_prefill_kernel_gbai<{{ is_gva }}, {{ needs_beta }}, {{ needs_alpha }}, {{ init_state }}, {{ enable_checkpointing }}, cutlass::arch::Sm120, {{ dtype }}, {{ dtype }}, float>(
    cudaStream_t, {{ dtype }}*, float*, {{ dtype }} const*, {{ dtype }} const*, {{ dtype }} const*,
    float const*, float const*, float const*, int64_t const*, uint8_t*, int32_t, int32_t,
    int32_t, int32_t, int32_t, int32_t, int64_t, float, int32_t, float*, int64_t const*,
    int32_t);

} // namespace flat
110 changes: 110 additions & 0 deletions csrc/prefill_kernel_delta_rule_sm120.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
/*
* Copyright (c) 2025 by FlashInfer team.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <cuda_bf16.h>

#include "flashinfer/flat/prefill/prefill_kernel_delta_rule_sm120.cuh"

// Extern template declarations prevent implicit instantiation here.
// Explicit instantiations are in separate generated files for parallel compilation.
#include "flat_prefill_kernel_delta_rule_sm120_extern.inc"

namespace flat {

using namespace cute;

// Runtime dispatcher for the SM120 delta-rule prefill kernel.
//
// Inspects which optional inputs are present at runtime and forwards to the
// matching compile-time specialization of launch_delta_rule_prefill_kernel_gbai.
// The five dispatch booleans are:
//   - is_gva:      more value heads than query heads (GVA layout)
//   - needs_beta:  a beta gate pointer was provided
//   - needs_alpha: an alpha gate pointer was provided
//   - init_state:  an initial recurrent state was provided
//   - enable_ckpt: state checkpointing requested (checkpoint_every_n_tokens > 0)
// All 32 variants are explicitly instantiated in separate generated
// translation units; the extern declarations included above prevent this file
// from instantiating them implicitly, keeping per-file compile times low.
template <typename ArchTag,  // FIXME: hide this
          typename TO, typename TQKV, typename TState>
void launch_delta_rule_prefill_kernel(
    cudaStream_t stream, TO* output, TState* output_state, TQKV const* q, TQKV const* k,
    TQKV const* v, TState const* input_state, float const* alpha, float const* beta,
    int64_t const* cu_seqlens, uint8_t* workspace_buffer, int32_t num_seqs, int32_t num_q_heads,
    int32_t num_k_heads, int32_t num_v_heads, int32_t num_o_heads, int32_t head_size,
    int64_t total_seqlen, float scale, int32_t sm_count, float* state_checkpoints,
    int64_t const* checkpoint_cu_starts, int32_t checkpoint_every_n_tokens) {
  bool is_gva = num_v_heads > num_q_heads;
  bool needs_beta = beta != nullptr;
  bool needs_alpha = alpha != nullptr;
  bool init_state = input_state != nullptr;
  bool enable_ckpt = checkpoint_every_n_tokens > 0;

// Forwards all runtime arguments to one concrete template specialization.
#define LAUNCH(is_gva, needs_beta, needs_alpha, init_state, enable_ckpt)                          \
  launch_delta_rule_prefill_kernel_gbai<is_gva, needs_beta, needs_alpha, init_state, enable_ckpt, \
                                        ArchTag>(                                                 \
      stream, output, output_state, q, k, v, input_state, alpha, beta, cu_seqlens,                \
      workspace_buffer, num_seqs, num_q_heads, num_k_heads, num_v_heads, num_o_heads, head_size,  \
      total_seqlen, scale, sm_count, state_checkpoints, checkpoint_cu_starts,                     \
      checkpoint_every_n_tokens);

// Dispatches the three data-dependent booleans (is_gva / needs_beta /
// needs_alpha) for fixed init_state / enable_ckpt values.
#define DISPATCH_GBAI(init_state, enable_ckpt)       \
  if (is_gva && needs_beta && needs_alpha) {         \
    LAUNCH(true, true, true, init_state, enable_ckpt); \
  } else if (is_gva && needs_beta && !needs_alpha) { \
    LAUNCH(true, true, false, init_state, enable_ckpt); \
  } else if (is_gva && !needs_beta && needs_alpha) { \
    LAUNCH(true, false, true, init_state, enable_ckpt); \
  } else if (is_gva && !needs_beta && !needs_alpha) { \
    LAUNCH(true, false, false, init_state, enable_ckpt); \
  } else if (!is_gva && needs_beta && needs_alpha) { \
    LAUNCH(false, true, true, init_state, enable_ckpt); \
  } else if (!is_gva && needs_beta && !needs_alpha) { \
    LAUNCH(false, true, false, init_state, enable_ckpt); \
  } else if (!is_gva && !needs_beta && needs_alpha) { \
    LAUNCH(false, false, true, init_state, enable_ckpt); \
  } else if (!is_gva && !needs_beta && !needs_alpha) { \
    LAUNCH(false, false, false, init_state, enable_ckpt); \
  } else {                                           \
    throw std::runtime_error("unreachable");         \
  }

  if (enable_ckpt) {
    if (init_state) {
      DISPATCH_GBAI(true, true);
    } else {
      DISPATCH_GBAI(false, true);
    }
  } else {
    if (init_state) {
      DISPATCH_GBAI(true, false);
    } else {
      DISPATCH_GBAI(false, false);
    }
  }

#undef DISPATCH_GBAI
#undef LAUNCH
}

// Explicit instantiations for the outer dispatch function only.
// The inner launch_delta_rule_prefill_kernel_gbai instantiations are in separate files.
template void launch_delta_rule_prefill_kernel<cutlass::arch::Sm120, half, half, float>(
    cudaStream_t stream, half* output, float* state, half const* q, half const* k, half const* v,
    float const* input_state, float const* alpha, float const* beta, int64_t const* cu_seqlens,
    uint8_t* workspace_buffer, int32_t num_seqs, int32_t num_q_heads, int32_t num_k_heads,
    int32_t num_v_heads, int32_t num_o_heads, int32_t head_size, int64_t total_seqlen, float scale,
    int32_t sm_count, float* state_checkpoints, int64_t const* checkpoint_cu_starts,
    int32_t checkpoint_every_n_tokens);

template void
launch_delta_rule_prefill_kernel<cutlass::arch::Sm120, nv_bfloat16, nv_bfloat16, float>(
    cudaStream_t stream, nv_bfloat16* output, float* state, nv_bfloat16 const* q,
    nv_bfloat16 const* k, nv_bfloat16 const* v, float const* input_state, float const* alpha,
    float const* beta, int64_t const* cu_seqlens, uint8_t* workspace_buffer, int32_t num_seqs,
    int32_t num_q_heads, int32_t num_k_heads, int32_t num_v_heads, int32_t num_o_heads,
    int32_t head_size, int64_t total_seqlen, float scale, int32_t sm_count,
    float* state_checkpoints, int64_t const* checkpoint_cu_starts,
    int32_t checkpoint_every_n_tokens);

}  // namespace flat
19 changes: 14 additions & 5 deletions flashinfer/gdn_prefill.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,20 +21,29 @@
import torch

from .api_logging import flashinfer_api
from .jit.gdn import gen_gdn_prefill_sm90_module
from .jit.gdn import gen_gdn_prefill_sm90_module, gen_gdn_prefill_sm120_module
from .utils import (
register_custom_op,
register_fake_op,
get_device_sm_count,
is_sm90a_supported,
is_sm100a_supported,
is_sm120a_supported,
_get_cache_buf,
)
from .gdn_kernels import chunk_gated_delta_rule_sm100, _has_blackwell_prefill


@functools.cache
def get_gdn_prefill_module():
module = gen_gdn_prefill_sm90_module().build_and_load()
def get_gdn_prefill_module(device: torch.device):
if is_sm90a_supported(device):
module = gen_gdn_prefill_sm90_module().build_and_load()
elif is_sm120a_supported(device):
module = gen_gdn_prefill_sm120_module().build_and_load()
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the test, we mention SM120 GDN prefill requires CUDA 13+, but here we only check is_sm120a_supported that checks 12.8 under the hood. Seems like contradicting to each other. What is the minimum recomended cuda version for the sm120 prefill kernels?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am actually not quite sure what is the recommended version. According to https://docs.nvidia.com/cuda/archive/12.8.1/blackwell-compatibility-guide/index.html

CUDA applications built using CUDA Toolkit 12.8 are compatible with Blackwell architecture as long as they are built to include kernels in native cubin (compute capability 10.0) or PTX form or both.

So I think we are safe to relax the requirement to 12.8.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Relaxed to 12.8 in test

else:
raise RuntimeError(
f"GDN prefill kernel requires SM90 or SM120, but device {device} is not supported"
)

@register_custom_op(
"flashinfer::gdn_prefill",
Expand Down Expand Up @@ -183,7 +192,7 @@ def chunk_gated_delta_rule(
- Supports GQA: ``num_q_heads > num_k_heads = num_v_heads``
- Supports GVA: ``num_v_heads > num_q_heads = num_k_heads``
- The final state layout is ``[N, H, V, K]``.
- Requires SM90 (Hopper) or SM100 (Blackwell) architecture.
- Requires SM90 (Hopper), SM100 (Blackwell), or SM120 (Blackwell RTX) architecture.
- SM100 path requires head_size == 128.
- SM100 path requires ``nvidia-cutlass-dsl[cu13]>=4.4.2``
(install via ``pip install flashinfer-python[cu13]``).
Expand Down Expand Up @@ -336,7 +345,7 @@ def chunk_gated_delta_rule(
"gdn_prefill_workspace", workspace_size, device
)

get_gdn_prefill_module().gdn_prefill(
get_gdn_prefill_module(q.device).gdn_prefill(
output,
output_state,
q,
Expand Down
30 changes: 25 additions & 5 deletions flashinfer/jit/gdn.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,25 +24,37 @@
JitSpec,
gen_jit_spec,
sm90a_nvcc_flags,
sm120a_nvcc_flags,
)
from .utils import write_if_different


def gen_gdn_prefill_sm90_module() -> JitSpec:
def _gen_gdn_prefill_module(arch: str) -> JitSpec:
"""Generate JIT module for GDN prefill kernel with separate compilation.

This generates 64 separate kernel instantiation files (2 dtypes x 32 boolean combinations)
plus the original launcher file. The separate files enable parallel compilation by ninja,
significantly reducing build time on multi-core machines.
"""
Comment on lines 33 to 38
Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot Apr 17, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟑 Minor

Docstring count is stale: 32 combinations, not 16.

The docstring says "32 separate kernel instantiation files (2 dtypes Γ— 16 boolean combinations)", but the loop at Line 69 uses itertools.product([False, True], repeat=5) β†’ 32 combinations, and the in-line comment at Line 60 already says "2 dtypes Γ— 32 boolean combinations". So the total is 64 files (2 Γ— 32), which matches the Line 60 comment.

πŸ’‘ Proposed fix
-    This generates 32 separate kernel instantiation files (2 dtypes Γ— 16 boolean combinations)
-    plus the original launcher file. The separate files enable parallel compilation by ninja,
+    This generates 64 separate kernel instantiation files (2 dtypes x 32 boolean combinations)
+    plus the original launcher file. The separate files enable parallel compilation by ninja,
     significantly reducing build time on multi-core machines.
🧰 Tools
πŸͺ› Ruff (0.15.10)

[warning] 35-35: Docstring contains ambiguous Γ— (MULTIPLICATION SIGN). Did you mean x (LATIN SMALL LETTER X)?

(RUF002)

πŸ€– Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/jit/gdn.py` around lines 33 - 38, Update the module docstring in
flashinfer/jit/gdn.py to reflect the correct count: change the phrase "32
separate kernel instantiation files (2 dtypes Γ— 16 boolean combinations)" to "64
separate kernel instantiation files (2 dtypes Γ— 32 boolean combinations)" (and
keep the note about the original launcher file) to match the
itertools.product([False, True], repeat=5) loop used to generate kernel variants
and the inline comment near the loop.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is copied from the SM90 version; we should open a separate PR to fix them all at once.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@guangyunh-nv, understood β€” deferring the docstring fix to a follow-up PR that cleans it up across both the SM90 and SM120 versions makes sense. Would you like me to open a GitHub issue to track this fix so it doesn't get lost?


🧠 Learnings used
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2026-04-14T19:02:21.525Z
Learning: Applies to flashinfer/jit/**/*.py : Each `gen_*_module()` function in `flashinfer/jit/` must: (1) compute unique URI from parameters, (2) create generation directory, (3) optionally render Jinja template, (4) copy source files, (5) return JitSpec

Learnt from: bkryu
Repo: flashinfer-ai/flashinfer PR: 2838
File: flashinfer/quantization/kernels/nvfp4_quantize.py:967-976
Timestamp: 2026-03-23T18:58:22.437Z
Learning: In `flashinfer/quantization/kernels/nvfp4_quantize.py` (flashinfer-ai/flashinfer), the TMA dispatch predicate `m.bit_length() - 1 + k.bit_length() - 1 >= _TMA_LOG2_MK_THRESHOLD` (i.e., floor(log2(M)) + floor(log2(K)) >= 25) is intentional. It is a deliberate approximation of the `M*K >= 2^25` threshold β€” not a bug. The maintainer acknowledged this and will add a clarifying comment in a follow-up commit. Do not flag this as incorrect or suggest replacing it with `m * k >= (1 << _TMA_LOG2_MK_THRESHOLD)`.

Learnt from: TomerBN-Nvidia
Repo: flashinfer-ai/flashinfer PR: 3024
File: csrc/fused_moe/noAuxTcKernels.cu:351-369
Timestamp: 2026-04-12T12:18:22.194Z
Learning: In `csrc/fused_moe/noAuxTcKernels.cu` (flashinfer-ai/flashinfer PR `#3024`), the `routing_replay_out` validation in `NoAuxTc` intentionally does NOT check `replay.sizes()[0] >= num_tokens`. This is by design: with CUDA graphs, the buffer is pre-allocated at maximum batch size and reused across steps with varying `num_tokens`; the kernel only writes to indices `[0, num_tokens)` so a larger buffer is always safe. The same policy applies to `csrc/trtllm_fused_moe_kernel_launcher.cu` (documented at line ~1795). Do not flag the missing lower-bound dim0 check as a bug.

Learnt from: bkryu
Repo: flashinfer-ai/flashinfer PR: 2904
File: flashinfer/quantization/kernels/mxfp8_quantize.py:114-116
Timestamp: 2026-03-27T20:33:11.994Z
Learning: In `flashinfer/quantization/kernels/mxfp8_quantize.py` (flashinfer-ai/flashinfer), `_compute_optimal_warps_for_k` must receive `sf_blocks_per_warp` as an explicit parameter (not use the global `SF_BLOCKS_PER_WARP=16` constant). The `MXFP8QuantizeSwizzledKernel` constructor calls it with `self._sf_blocks_per_warp`, which is `SF_BLOCKS_PER_WARP=16` when `use_2t_per_sf=True` and `SF_BLOCKS_PER_WARP_SMALL=8` when `use_2t_per_sf=False`. Using the wrong constant causes fractional `rows_per_block` (e.g., K=3072 4T/SF: 30 warps β†’ 960 threads β†’ 2.5 truncated to 2 β†’ write race from excess threads overlapping the next block's first row). MXFP4 and NVFP4 are unaffected because they use 1 thread per SF block with no multi-thread variant.

Learnt from: bkryu
Repo: flashinfer-ai/flashinfer PR: 3066
File: flashinfer/fused_moe/cute_dsl/blackwell_geforce/moe_dynamic_kernel.py:343-380
Timestamp: 2026-04-14T19:10:27.074Z
Learning: In `flashinfer/fused_moe/cute_dsl/blackwell_geforce/moe_dynamic_kernel.py` (flashinfer-ai/flashinfer PR `#3066`), `MoEDynamicKernel._setup_attributes()` intentionally omits the full SMEM post-check loop present in `MoEStaticKernel`. The `_compute_stages` output is already conservatively clamped (`max(1, min(ab_stage, 4))`) and further reduced by the k_tile_cnt divisibility check, yielding ab_stage=1 or 2 in all tested configurations β€” well within SM120's 232KB SMEM budget even with the extra staged sB_up/sSFB_up pair. A proper `_shared_storage_size_bytes()` for the dynamic kernel's different struct layout would be needed for a full post-check; the maintainer deferred this to a follow-up. Do not re-flag the missing post-check as a bug.

Learnt from: bkryu
Repo: flashinfer-ai/flashinfer PR: 2904
File: flashinfer/quantization/kernels/mxfp8_quantize.py:384-385
Timestamp: 2026-03-27T20:51:45.564Z
Learning: In `flashinfer/quantization/kernels/mxfp8_quantize.py` (`MXFP8QuantizeSwizzledKernel`, small-K path), the padding-column zeroing in the swizzled small-K path requires a thread-stride loop, not a simple predicated write. Because `sf_col_idx = local_tidx // _threads_per_sf` is bounded by `[0, num_sf_blocks_per_row)`, a bare `if sf_col_idx >= num_sf_blocks_per_row` guard is unreachable. The correct pattern (matching MXFP4/NVFP4 swizzled kernels) is:
- Padding rows: loop starting at `sf_col_idx`, striding by `num_sf_blocks_per_row`, up to `padded_sf_cols`.
- Real rows: loop starting at `num_sf_blocks_per_row + sf_col_idx`, striding by `num_sf_blocks_per_row`, guarded by `const_expr(self.num_sf_blocks_per_row != self.padded_sf_cols)` so it is eliminated at compile time when `K/32` is a multiple of 4 (no column padding needed).

Learnt from: kahyunnam
Repo: flashinfer-ai/flashinfer PR: 2965
File: tests/norm/test_fused_rmsnorm_silu.py:138-141
Timestamp: 2026-04-03T21:06:16.453Z
Learning: In `tests/norm/test_fused_rmsnorm_silu.py` (flashinfer-ai/flashinfer PR `#2965`), the full `ALL_LUT_SHAPES` test matrix (8 hidden sizes Γ— 5 token counts, up to 399,360 tokens) across bf16, FP8, and NVFP4 is intentionally kept as the default CI parametrization. The maintainer confirmed the tests are fast and do not need to be split into a smoke subset vs. a slow marker. Do not flag this test matrix as too large for CI.

Learnt from: ishovkun
Repo: flashinfer-ai/flashinfer PR: 2591
File: flashinfer/aot.py:588-599
Timestamp: 2026-02-19T21:59:36.542Z
Learning: When reviewing changes to conditional blocks (e.g., `if has_sm90:` β†’ `if has_sm90 or has_sm100:`), distinguish between code the PR author wrote versus pre-existing code that happens to be in the modified block. Do not ask the PR author to fix potential issues in pre-existing code unless it's directly related to their changes.

Learnt from: bkryu
Repo: flashinfer-ai/flashinfer PR: 3026
File: include/flashinfer/gemm/fp4_gemm_template_sm120.h:267-270
Timestamp: 2026-04-09T21:51:00.268Z
Learning: In flashinfer-ai/flashinfer, `include/flashinfer/gemm/fp4_gemm_template_sm120.h` is gated by `#define FLASHINFER_ENABLE_SM120` and is only included from `fp4_gemm_cutlass_template_sm120.h`, which is compiled exclusively for SM120/SM121 targets. Adding a runtime `Sm12xOnly` architecture guard inside this file is redundant β€” there is no code path that instantiates these kernels on non-SM12x hardware. Do not suggest adding such guards to this file.

Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2026-04-14T19:02:21.525Z
Learning: Applies to flashinfer/jit/**/*.py : Use SHA256 hashing for source files and include operation type, parameters, compilation flags, and CUDA architecture in URI computation for cache invalidation

Learnt from: bkryu
Repo: flashinfer-ai/flashinfer PR: 3080
File: flashinfer/fused_moe/cute_dsl/b12x_moe.py:48-49
Timestamp: 2026-04-16T01:51:16.398Z
Learning: In flashinfer-ai/flashinfer, only use `backend_requirement` when an API dispatches across multiple backends. For single-backend, architecture-gated APIs that exclusively target a specific compute capability (e.g., SM120/SM121, such as `b12x_fused_moe` / `B12xMoEWrapper` in `flashinfer/fused_moe/cute_dsl/b12x_moe.py`), prefer and keep `supported_compute_capability([120, 121])` instead of suggesting a replacement with `backend_requirement`.

Learnt from: bkryu
Repo: flashinfer-ai/flashinfer PR: 3066
File: flashinfer/fused_moe/cute_dsl/fused_moe.py:206-220
Timestamp: 2026-04-14T19:11:17.176Z
Learning: In `flashinfer/fused_moe/cute_dsl/fused_moe.py` (flashinfer-ai/flashinfer PR `#3066`), the SM120/SM121 dispatch paths (`_moe_core_impl`, `CuteDslMoEWrapper.run`, and `cute_dsl_fused_moe_nvfp4`) intentionally do NOT forward `local_expert_offset` to `launch_sm120_moe`. Expert Parallelism (EP) is unsupported on SM120: the dynamic kernel (`MoEDynamicKernel`) lacks `global_to_local_expert` remapping, and EP tests are gated to SM100-only via `sm100_only`. Passing `local_expert_offset` without kernel-side support would silently produce incorrect results. Do not flag the missing `local_expert_offset` propagation in SM120 call sites as a bug.

Learnt from: ishovkun
Repo: flashinfer-ai/flashinfer PR: 2773
File: include/flashinfer/mamba/invoke_selective_state_update_mtp.cuh:27-32
Timestamp: 2026-03-12T21:29:16.342Z
Learning: In `include/flashinfer/mamba/invoke_selective_state_update_mtp.cuh` (flashinfer-ai/flashinfer), the `static_assert` inside the `PHILOX_ROUNDS > 0` block that restricts stochastic rounding to fp16 state (`std::is_same_v<state_t, half>`) is intentionally kept in the CUDA header close to the implementation rather than being guarded by a pre-JIT Python-side runtime check. The maintainer prefers this colocation for easier auditability. Do not suggest moving or duplicating this constraint to the Python layer.

Learnt from: kahyunnam
Repo: flashinfer-ai/flashinfer PR: 2965
File: include/flashinfer/norm/ln_silu_headers.cuh:258-270
Timestamp: 2026-04-03T20:17:43.361Z
Learning: In `include/flashinfer/norm/ln_silu_headers.cuh`, the pre-SM80 `#else` branch inside `struct Converter<float2, nv_bfloat162>::convert` (the union-based fallback) is intentionally dead code. `fused_rmsnorm_silu` requires SM80+ at runtime, so the `#if __CUDA_ARCH__ >= 800` path (using `__float22bfloat162_rn`) is the only path that ever compiles. Do not flag the union member aliasing issue in the `#else` branch as a bug.

Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2026-04-14T19:02:21.525Z
Learning: Applies to flashinfer/jit/**/*.py : Use `functools.cache` decorator on JIT module generator functions to implement Python-level module caching

Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2026-04-14T19:02:21.525Z
Learning: Applies to {include/flashinfer/**/*.cuh,csrc/**/*.cu} : For performance-critical hot paths, leave comments with justification for special algorithmic choices and mention alternative approaches considered

Learnt from: ishovkun
Repo: flashinfer-ai/flashinfer PR: 2962
File: include/flashinfer/mamba/kernel_selective_state_update_mtp_simple.cuh:232-262
Timestamp: 2026-04-02T18:45:38.854Z
Learning: In `include/flashinfer/mamba/kernel_selective_state_update_mtp_simple.cuh` (flashinfer-ai/flashinfer PR `#2962`), the per-step `state_dst_slots` precompute has three mutually exclusive branches:
1. `dst_state_batch_indices` present β†’ always write unless index == pad_slot_id (caller controls slots via pad_slot_id; no update_state gating needed).
2. `intermediate_states` present β†’ always cache every step (no update_state gating needed).
3. Neither β†’ only write at last step when params.update_state is true.
`intermediate_states_buffer` and `dst_state_batch_indices` are enforced mutually exclusive by a Python-side ValueError in `flashinfer/mamba/selective_state_update.py`. Do not flag the absence of `update_state` gating in branches 1 and 2 as a bug.

Learnt from: ishovkun
Repo: flashinfer-ai/flashinfer PR: 2865
File: include/flashinfer/mamba/kernel_selective_state_update_mtp_vertical.cuh:343-366
Timestamp: 2026-03-23T21:04:23.630Z
Learning: FlashInfer Mamba SSU MTP kernels: When applying the z gate in epilogues, index z using its own strides (z_stride_batch/z_stride_mtp), not the output strides. Verified in include/flashinfer/mamba/kernel_selective_state_update_mtp_vertical.cuh (role_epilogue) and include/flashinfer/mamba/kernel_selective_state_update_mtp_horizontal.cuh (inline epilogue).

uri = "gdn_prefill_sm90"
assert arch in ["sm90", "sm120"], (
"GDN prefill kernel is only supported on sm_90a and sm_120a"
)

if arch == "sm90":
arch_specific_flags = sm90a_nvcc_flags + ["-DFLAT_SM90A_ENABLED"]
elif arch == "sm120":
arch_specific_flags = sm120a_nvcc_flags + ["-DFLAT_SM120A_ENABLED"]

uri = f"gdn_prefill_{arch}"
gen_directory = jit_env.FLASHINFER_GEN_SRC_DIR / uri
os.makedirs(gen_directory, exist_ok=True)

source_paths = []

# Load kernel instantiation template
with open(jit_env.FLASHINFER_CSRC_DIR / "gdn_prefill_sm90_kernel_inst.jinja") as f:
with open(
jit_env.FLASHINFER_CSRC_DIR / f"gdn_prefill_{arch}_kernel_inst.jinja"
) as f:
kernel_inst_templ = jinja2.Template(f.read())

# Generate 64 separate instance files (2 dtypes x 32 boolean combinations)
Expand Down Expand Up @@ -74,7 +86,7 @@ def gen_gdn_prefill_sm90_module() -> JitSpec:
# Headers are now in include/flashinfer/flat/ and accessible via standard include paths
for filename in [
"gdn_prefill_launcher.cu",
"prefill_kernel_delta_rule_sm90.cu",
f"prefill_kernel_delta_rule_{arch}.cu",
]:
src_path = jit_env.FLASHINFER_CSRC_DIR / filename
dest_path = gen_directory / src_path.name
Expand All @@ -84,5 +96,13 @@ def gen_gdn_prefill_sm90_module() -> JitSpec:
return gen_jit_spec(
uri,
source_paths,
extra_cuda_cflags=sm90a_nvcc_flags + ["-DFLAT_SM90A_ENABLED", "-std=c++20"],
extra_cuda_cflags=arch_specific_flags + ["-std=c++20"],
)


def gen_gdn_prefill_sm90_module() -> JitSpec:
    """Generate the JIT spec for the SM90 (Hopper) GDN prefill module."""
    return _gen_gdn_prefill_module("sm90")


def gen_gdn_prefill_sm120_module() -> JitSpec:
    """Generate the JIT spec for the SM120 (Blackwell RTX) GDN prefill module."""
    return _gen_gdn_prefill_module("sm120")
Loading
Loading