Merged
3 changes: 3 additions & 0 deletions csrc/nv_internal/tensorrt_llm/common/reduceKernelUtils.cuh
@@ -158,6 +158,8 @@ __inline__ __device__ T blockAllReduceMax(T val) {
return val;
}

#ifdef ENABLE_FP8
Collaborator Author (@yzh119)
@claude write a simple python unittest to verify norm module compiles when ENABLE_FP8 is disabled.

claude (Bot), Jan 11, 2026
Claude finished @yzh119's task in 2m 3s.


Task: Write unittest to verify norm compilation without FP8 ✅

Completed Tasks:

  • Read CLAUDE.md to understand testing patterns
  • Examine existing norm tests to understand the pattern
  • Understand how ENABLE_FP8 is controlled
  • Create a test that verifies norm module compiles without FP8
  • Add test to appropriate test file
  • Commit and push changes

Test Created:
Added test_norm_compilation_without_fp8() to tests/utils/test_norm.py. This test:

  • Creates a JIT spec for the norm module without the -DENABLE_FP8 flag
  • Attempts to build and load the module via JIT compilation
  • Verifies compilation succeeds (which would fail before the fix due to PackType being undefined)

How it verifies the fix:
The test directly addresses issue #2271 by compiling the norm module without FP8 support. Before the fix, this would fail with:

error: incomplete type is not allowed

because batchWarpReduceSum in reduceKernelUtils.cuh depends on PackType, which is only defined when ENABLE_FP8 is set.

With the fix (guarding batchWarpReduceSum with #ifdef ENABLE_FP8), the compilation succeeds because the problematic code is excluded when FP8 is disabled.
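The root cause can be sketched in plain C++ (the `PackType` and `batchReduceSum` below are simplified stand-ins for the real CUDA definitions, not flashinfer's actual code): a template definition whose signature names a trait that is only declared under a macro must itself live under the same guard.

```cpp
// Minimal illustration of the guard pattern. ENABLE_FP8 is deliberately
// left undefined here, mirroring a build without FP8 support.

#ifdef ENABLE_FP8
// The trait only exists when FP8 support is compiled in.
template <typename T, int SZ>
struct PackType {
  using type = T;  // the real trait packs SZ lanes of T; simplified here
};

// Guarded: without the surrounding #ifdef, this definition would name
// PackType even in non-FP8 builds, where the trait was never declared,
// and compilation would fail.
template <typename T, int SZ>
typename PackType<T, SZ>::type batchReduceSum(
    typename PackType<T, SZ>::type val) {
  return val;  // reduction body elided
}
#endif  // ENABLE_FP8

// This path is always available and compiles whether or not ENABLE_FP8
// is defined, which is exactly what the fix restores for the norm module.
inline bool compiles_without_fp8() { return true; }
```

Compiling this translation unit without `-DENABLE_FP8` succeeds because the guarded code is excluded entirely; adding the flag back re-enables the FP8 path.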

Commit: f694301


// This function requires PackType which is only available when ENABLE_FP8 is defined
template <typename T, int SZ>
__inline__ __device__ typename PackType<T, SZ>::type batchWarpReduceSum(
    typename PackType<T, SZ>::type val) {
@@ -179,6 +181,7 @@ __inline__ __device__ typename PackType<T, SZ>::type batchWarpReduceSum(
}
return val;
}
#endif // ENABLE_FP8

template <typename T, int NUM>
__inline__ __device__ T warpReduceSumV2(T* val) {
3 changes: 3 additions & 0 deletions include/flashinfer/trtllm/common/reduceKernelUtils.cuh
@@ -158,6 +158,8 @@ __inline__ __device__ T blockAllReduceMax(T val) {
return val;
}

#ifdef ENABLE_FP8
// This function requires PackType which is only available when ENABLE_FP8 is defined
template <typename T, int SZ>
__inline__ __device__ typename PackType<T, SZ>::type batchWarpReduceSum(
    typename PackType<T, SZ>::type val) {
@@ -179,6 +181,7 @@ __inline__ __device__ typename PackType<T, SZ>::type batchWarpReduceSum(
}
return val;
}
#endif // ENABLE_FP8

template <typename T, int NUM>
__inline__ __device__ T warpReduceSumV2(T* val) {
32 changes: 32 additions & 0 deletions tests/utils/test_norm.py
@@ -19,6 +19,8 @@
import torch.nn.functional as F

import flashinfer
from flashinfer.jit import env as jit_env
from flashinfer.jit.core import gen_jit_spec
from flashinfer.utils import device_support_pdl


@@ -337,6 +339,36 @@ def test_layernorm(batch_size, hidden_size, dtype):
torch.testing.assert_close(out, out_ref, rtol=1e-2, atol=1e-2)


def test_norm_compilation_without_fp8():
    """Test that the norm module compiles successfully without the ENABLE_FP8 flag.

    This test verifies the fix for issue #2271, where batchWarpReduceSum in
    reduceKernelUtils.cuh depends on PackType, which is only defined when
    ENABLE_FP8 is set. The fix guards batchWarpReduceSum with #ifdef ENABLE_FP8.
    """
    # Create a JIT spec for the norm module without the ENABLE_FP8 flag
    nvcc_flags = [
        "-DENABLE_BF16",
        # Note: ENABLE_FP8 is intentionally omitted to test compilation without it
    ]
    spec = gen_jit_spec(
        "norm_without_fp8_test",
        [
            jit_env.FLASHINFER_CSRC_DIR / "norm.cu",
            jit_env.FLASHINFER_CSRC_DIR / "flashinfer_norm_binding.cu",
        ],
        extra_cuda_cflags=nvcc_flags,
    )

    # This should compile successfully without errors.
    # If batchWarpReduceSum is not properly guarded, this fails with
    # "error: incomplete type is not allowed" for PackType.
    module = spec.build_and_load()

    # Verify the module loaded successfully
    assert module is not None


if __name__ == "__main__":
    # test_norm(1, 1024, torch.float16, False, True, True)
    test_norm(19, 1024, torch.float16, False, True, False)