feat(moe-a2a): Update nvlink onesided all-to-all #3139
trevor-m wants to merge 4 commits into flashinfer-ai:main
Conversation
Float32 accumulators, FP8 combine, and related API updates:
- vectorized_combine_impl: float32 acc[TOP_K] instead of uint8 byte accumulation; two-pass load-then-upcast (descending j) avoids WAR hazards for InT < float
- InT template param: recv buffer element type (defaults to T); enables FP8 in-place combine (InT=fp8_e4m3, T=bf16) without a separate conversion kernel
- stride_per_token: byte distance between tokens in the recv buffer; equals size_per_token normally, differs for FP8 in-place (BF16-stride workspace)
- vec_convert: generic + SM89 PTX f16x2→e4m3x2 + SM100 PTX bf16x2→e4m3x2
- vectorized_quant: SrcT→FP8 with ThreadingPolicy::sync() for WAR safety
- moeA2APrepareCombineKernel: LOW_PRECISION=true path quantizes payload to FP8
- moeA2ACombineKernel: dispatches FP8→BF16 or same-type path based on T
- MoeA2ACombineParams: add use_low_precision flag
- Python API: add use_low_precision=False to moe_a2a_combine / MoeAlltoAll.combine
- Tests: add FP8 combine correctness test (SM>=89, tolerates FP8 quantization error)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
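The "two-pass load-then-upcast (descending j)" trick in the first bullet can be illustrated outside CUDA. A minimal pure-Python sketch, with float16→float32 standing in for the fp8→bf16 widening and all names invented for illustration: when a buffer is widened in place, output slot j overlaps input slots 2j and 2j+1, so iterating j from high to low guarantees every overwritten input was already read.

```python
import struct

def upcast_in_place(buf: bytearray, n: int) -> None:
    """Widen n float16 values (2 bytes each) at the start of buf into
    float32 values (4 bytes each) filling buf, writing back-to-front.

    Output slot j occupies bytes [4j, 4j+4), which overlaps input slots
    2j and 2j+1. Iterating j = n-1 .. 0 means any input slot we clobber
    (index >= 2j >= j) was already read in an earlier iteration, so the
    in-place widening is WAR-hazard free.
    """
    for j in range(n - 1, -1, -1):                     # descending j
        (val,) = struct.unpack_from("<e", buf, 2 * j)  # load narrow element
        struct.pack_into("<f", buf, 4 * j, val)        # store wide element

vals = [1.0, -2.5, 0.25, 8.0]          # exactly representable in float16
buf = bytearray(4 * len(vals))         # sized for the wide output
for j, v in enumerate(vals):
    struct.pack_into("<e", buf, 2 * j, v)
upcast_in_place(buf, len(vals))
out = [struct.unpack_from("<f", buf, 4 * j)[0] for j in range(len(vals))]
```

Iterating ascending instead would overwrite input slot 1 while writing output slot 0, before slot 1 had been read.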
Port PDL support from TRT-LLM b802a3818:
- Add getEnvEnablePDL() (TRTLLM_ENABLE_PDL env var, default on for SM>=90)
- Add launchWithPdlWhenEnabled() helper using cudaLaunchKernelEx
- Replace all <<<>>> launches with launchWithPdlWhenEnabled in dispatch, combine, and sanitize kernels
- Add cudaGridDependencySynchronize() / cudaTriggerProgrammaticLaunchCompletion() guards (#if __CUDA_ARCH__ >= 900) in all five kernel entry points
- Reorder moeA2ACombineKernel params: stride_per_token now before local_num_tokens to match the TRT-LLM signature

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
📝 Walkthrough

Adds programmatic dependent launch (PDL) support and FP8/low-precision handling to MoE all-to-all kernels and APIs: introduces a PDL-aware host launch helper, threads …

Changes
Sequence Diagram
sequenceDiagram
participant Host as Host Code
participant Helper as launchWithPdlWhenEnabled
participant CUDA as CUDA Runtime
participant Kernel as Device Kernel
participant Sync as Synchronization
Host->>Helper: call launchWithPdlWhenEnabled(name, enable_pdl, kernelFn, grid, block, shm, stream, args...)
Helper->>CUDA: build cudaLaunchConfig_t (grid, block, shm, stream)
alt enable_pdl == true
Helper->>CUDA: attach cudaLaunchAttributeProgrammaticStreamSerialization(true)
else
Helper->>CUDA: attach cudaLaunchAttributeProgrammaticStreamSerialization(false)
end
Helper->>CUDA: cudaLaunchKernelEx(config, kernelFn, args...)
CUDA->>Kernel: execute kernel on device
Kernel->>Kernel: optional cudaGridDependencySynchronize() (SM>=900)
Kernel->>Sync: signal completion
CUDA->>Helper: return cudaError_t
Helper->>Host: FLASHINFER_CHECK(error)
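The flow in the diagram above can be sketched as host code. A hedged sketch assuming the CUDA 12 extended-launch APIs, not the PR's actual implementation (the real signature and error handling in FlashInfer may differ):

```cuda
#include <cuda_runtime.h>

// Illustrative launchWithPdlWhenEnabled-style helper matching the sequence
// above; names and parameter order are assumptions, not the PR's code.
template <typename Kernel, typename... Args>
cudaError_t launchWithPdlWhenEnabled(Kernel kernel, dim3 grid, dim3 block,
                                     size_t shmem_bytes, cudaStream_t stream,
                                     bool enable_pdl, Args... args) {
  cudaLaunchConfig_t config{};
  config.gridDim = grid;
  config.blockDim = block;
  config.dynamicSmemBytes = shmem_bytes;
  config.stream = stream;

  // PDL is expressed as a launch attribute; when disabled, the attribute
  // is attached with value 0 and the launch behaves like a normal launch.
  cudaLaunchAttribute attr{};
  attr.id = cudaLaunchAttributeProgrammaticStreamSerialization;
  attr.val.programmaticStreamSerializationAllowed = enable_pdl ? 1 : 0;
  config.attrs = &attr;
  config.numAttrs = 1;

  return cudaLaunchKernelEx(&config, kernel, args...);
}
```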
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes

Possibly related PRs
🚥 Pre-merge checks: ✅ 4 passed | ❌ 1 failed (1 warning)
Code Review
This pull request implements support for Programmatic Dependent Launch (PDL) and FP8 quantization within the MoE All-to-All kernels. It introduces a launchWithPdlWhenEnabled utility to manage kernel launches with PDL attributes and updates the moe_a2a_combine logic to handle FP8 data with BF16 accumulation. Critical feedback was provided regarding the incorrect placement of cudaTriggerProgrammaticLaunchCompletion() in several kernels, including moeA2APrepareDispatchKernel, moeA2APrepareCombineKernel, and moeA2ASanitizeExpertIdsKernel. Calling this function at the start of the kernels creates race conditions by signaling completion before the work is actually performed.
```cuda
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
  cudaGridDependencySynchronize();
  cudaTriggerProgrammaticLaunchCompletion();
#endif
```
The placement of cudaTriggerProgrammaticLaunchCompletion() at the beginning of moeA2APrepareDispatchKernel is incorrect. This function signals to subsequent kernels in the stream (that use PDL) that the current kernel has completed the work they depend on. By calling it at the start, you are signaling completion before the send_counters and local_token_counter have actually been zeroed. This creates a race condition where moeA2ADispatchKernel might start executing before the initialization is finished.
To fix this, cudaTriggerProgrammaticLaunchCompletion() should be called at the end of the kernel, and you must ensure that every block in the grid calls it exactly once, even if they return early due to bounds checks.
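The fix described above can be sketched as follows; the kernel body and parameter names are illustrative assumptions, not the PR's actual code. The key points are that the trigger moves to the end, every block reaches it on every path, and the zeroed counters are fenced before signaling:

```cuda
// Sketch of the corrected trigger placement (illustrative names).
__global__ void moeA2APrepareDispatchKernelFixed(int* send_counters,
                                                 int* local_token_counter,
                                                 int num_counters) {
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
  cudaGridDependencySynchronize();  // wait on the producer kernel's writes
#endif

  int const idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx < num_counters) {
    send_counters[idx] = 0;  // state the dispatch kernel will read
  }
  if (idx == 0) {
    *local_token_counter = 0;
  }

  // Bounds checks above guard the writes instead of returning early, so
  // every block reaches the trigger exactly once, as the spec requires.
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
  __threadfence();  // publish the zeroed counters before signaling
  cudaTriggerProgrammaticLaunchCompletion();
#endif
}
```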
```cuda
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
  cudaGridDependencySynchronize();
  cudaTriggerProgrammaticLaunchCompletion();
#endif
```
Similar to the dispatch prepare kernel, cudaTriggerProgrammaticLaunchCompletion() is called here before any data has been copied or quantized. The subsequent moeA2ACombineKernel depends on the data being present in the recv_buffer. Triggering completion at the start of this kernel allows the combine kernel to potentially read uninitialized or partially updated data.
Please move the trigger call to the end of the function and restructure the early returns (lines 934, 939, 946) to ensure that cudaTriggerProgrammaticLaunchCompletion() is still called by every block in the grid as required by the CUDA specification.
```cuda
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
  cudaGridDependencySynchronize();
  cudaTriggerProgrammaticLaunchCompletion();
#endif
```
Actionable comments posted: 2
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
csrc/trtllm_moe_alltoall.cu (1)
319-334: ⚠️ Potential issue | 🔴 Critical — Allocate BF16 output whenever the combine kernel will run the FP8 path.

Line 320 only switches to BF16 when useLowPrecision is true, but line 334 can still set params.dtype = kFP8 for an FP8 payload. The CUDA combine kernel treats kFP8 as "FP8 recv buffer → BF16 output", so an FP8 payload with useLowPrecision=false allocates an FP8-sized output and then writes BF16 elements into it.

🐛 Proposed fix

```diff
-  // For FP8 combine, recv buffers hold FP8 but output is BF16 (upcast during accumulation).
-  DLDataType outputDtype = useLowPrecision ? dl_bfloat16 : payload.dtype();
+  auto const payloadDtypeCode = encode_dlpack_dtype(payload.dtype());
+  bool const fp8Payload = payloadDtypeCode == float8_e4m3fn_code;
+  bool const effectiveLowPrecision = useLowPrecision || fp8Payload;
+
+  // For FP8 combine, recv buffers hold FP8 but output is BF16 (upcast during accumulation).
+  DLDataType outputDtype = effectiveLowPrecision ? dl_bfloat16 : payload.dtype();
   Tensor output = alloc_tensor({localNumTokens, elementsPerToken}, outputDtype, payload.device());
@@
-  params.use_low_precision = useLowPrecision;
+  params.use_low_precision = effectiveLowPrecision;
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@csrc/trtllm_moe_alltoall.cu` around lines 319-334: The output tensor is currently allocated as BF16 only when useLowPrecision is true, but the combine kernel treats kFP8 payloads as "FP8 recv → BF16 output", so when params.dtype is set to kFP8 you must allocate BF16 output too; change the allocation logic for outputDtype (used by alloc_tensor and output.data_ptr()) to pick BF16 when either useLowPrecision is true OR the payload dtype indicates FP8 (check payload.dtype(), or equivalently toNvDataType(payload.dtype()) == kFP8), so params.dtype can remain toNvDataType(payload.dtype()) while the output buffer has the correct BF16 shape.

flashinfer/comm/trtllm_moe_alltoall.py (1)
320-345: ⚠️ Potential issue | 🟠 Major — Add an SM89+ runtime check for the use_low_precision=True path.

The tests skip FP8 combine on SM<89, but the public APIs (moe_a2a_combine at lines 320-346 and MoeAlltoAll.combine at lines 659-698) allow callers to pass use_low_precision=True without checking device capability. Add a runtime validation before forwarding to the custom op so unsupported GPUs fail predictably. Follow the pattern in flashinfer/norm/__init__.py:703-704 (check get_compute_capability and raise RuntimeError if SM<89 when FP8 is requested).

🤖 Prompt for AI Agents
moe_a2a_combineat line 320–346 andMoeAlltoAll.combineat lines 659–698) allow callers to passuse_low_precision=Truewithout checking device capability. Add a runtime validation before forwarding to the custom op so unsupported GPUs fail predictably. Follow the pattern inflashinfer/norm/__init__.py:703–704(checkget_compute_capabilityand raiseRuntimeErrorif SM<89 when FP8 is requested).🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/comm/trtllm_moe_alltoall.py` around lines 320 - 345, Add a runtime GPU capability check when the user requests FP8: in moe_a2a_combine (before calling get_moe_alltoall_module().moe_a2a_combine) and likewise in the MoeAlltoAll.combine method, call get_compute_capability() and if use_low_precision is True and the returned compute capability is less than 89 raise a RuntimeError with a clear message that SM<89 does not support low-precision FP8 combine; this mirrors the pattern used in flashinfer/norm/__init__.py and ensures unsupported devices fail predictably instead of invoking the custom op.
📒 Files selected for processing (7)
- csrc/nv_internal/cpp/common/envUtils.cpp
- csrc/nv_internal/tensorrt_llm/common/envUtils.h
- csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu
- csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.h
- csrc/trtllm_moe_alltoall.cu
- flashinfer/comm/trtllm_moe_alltoall.py
- tests/comm/test_trtllm_moe_alltoall.py
```cuda
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
  cudaGridDependencySynchronize();
  cudaTriggerProgrammaticLaunchCompletion();
#endif
```
🧩 Analysis chain

🏁 Scripts executed: located moeAlltoAllKernels.cu, inspected the four trigger sites (around lines 294-297, 415-417, 923-926, and 1200-1203) and the surrounding kernel bodies, traced what each kernel writes after its trigger versus what the next kernel reads, and reviewed the launchWithPdlWhenEnabled definition in envUtils.h. (Full command transcripts and outputs elided from this page.)
Move PDL completion triggers to after all required state is published.

These kernels trigger cudaTriggerProgrammaticLaunchCompletion() before the data they produce is ready, creating race windows for dependent PDL-enabled kernels. Specifically:
- moeA2APrepareDispatchKernel (line 295): triggers before zeroing send_counters and local_token_counter (lines 297-301), which moeA2ADispatchKernel reads and increments.
- moeA2ADispatchKernel (line 416): triggers before writing recv_counters and completion_flags (lines 425-440), which moeA2APrepareCombineKernel depends on reading.
- moeA2APrepareCombineKernel (line 925): triggers before incrementing flag_val (line 928) and staging payload data, which moeA2ACombineKernel requires.
- moeA2ASanitizeExpertIdsKernel (line 1202): triggers before its consumer has guaranteed visibility of previously written state.

Move each trigger to the end of the kernel, after __threadfence_system() (or an equivalent barrier) and after all required writes are complete.

Also applies to: 415-417, 923-926, 1200-1203
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In
`@csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu`
around lines 294 - 297, The PDL completion calls
(cudaGridDependencySynchronize() and cudaTriggerProgrammaticLaunchCompletion())
are being invoked before kernels publish the state their consumers read; move
these calls to the end of each affected kernel so they occur after all required
writes and an appropriate system-wide visibility fence. Concretely, in
moeA2APrepareDispatchKernel, moeA2ADispatchKernel, moeA2APrepareCombineKernel,
and moeA2ASanitizeExpertIdsKernel remove the early calls around lines
~295/415/923/1200 and instead place cudaGridDependencySynchronize() followed by
__threadfence_system() (or ensure an existing __threadfence_system() runs first)
and then cudaTriggerProgrammaticLaunchCompletion() as the final operations in
each kernel, after zeroing/incrementing send_counters/local_token_counter, after
writing recv_counters/completion_flags, and after incrementing flag_val/staging
payloads respectively so all produced state is published before signaling
completion.
```python
@pytest.mark.skipif(
    torch.cuda.get_device_capability(0) < (8, 9),
    reason="FP8 (e4m3) requires SM>=89 (Ada Lovelace or newer)",
)
```
🧩 Analysis chain

🏁 Scripts executed: searched the repository for compute-capability helper APIs, inspected the skipif patterns used by other tests under tests/comm/, and confirmed that flashinfer/utils.py exposes get_compute_capability returning a (major, minor) tuple, while no SM89/FP8-specific helper exists. (Full command transcripts and outputs elided from this page.)
Use flashinfer.utils.get_compute_capability() to gate this FP8 test.

Replace the raw torch.cuda.get_device_capability(0) check with get_compute_capability() from flashinfer.utils, following the repository's test gating pattern. Compare the returned (major, minor) tuple against (8, 9):

Suggested fix

```python
from flashinfer.utils import get_compute_capability

@pytest.mark.skipif(
    get_compute_capability(torch.device("cuda")) < (8, 9),
    reason="FP8 (e4m3) requires SM>=89 (Ada Lovelace or newer)",
)
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tests/comm/test_trtllm_moe_alltoall.py` around lines 600 - 603, Replace the
raw torch.cuda.get_device_capability(0) check in the pytest.mark.skipif
decorator with flashinfer.utils.get_compute_capability: import
get_compute_capability and call get_compute_capability(torch.device("cuda")),
convert the returned (major, minor) tuple into a numeric capability via major *
10 + minor and use >= 89 to gate the FP8 test; update the pytest.mark.skipif
condition that currently references torch.cuda.get_device_capability to use this
computed value (refer to get_compute_capability and the pytest.mark.skipif
decorator in the test).
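As an aside, the tuple form gates capabilities directly via lexicographic comparison, without converting to a numeric value. A small self-contained illustration (the capability list is arbitrary example data):

```python
# Tuple comparison orders lexicographically, so cap >= (8, 9) admits
# SM89, SM90, and SM100 while rejecting SM75/SM80/SM86.
caps = [(7, 5), (8, 0), (8, 6), (8, 9), (9, 0), (10, 0)]
fp8_ok = [c for c in caps if c >= (8, 9)]
```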
```cpp
bool getEnvEnablePDL() {
  static bool const kEnablePDL = []() -> bool {
    auto const val = getIntEnv("TRTLLM_ENABLE_PDL");
```
Does FlashInfer have no env var for enabling PDL? TRTLLM_ENABLE_PDL sounds pretty weird in FlashInfer
Oops, I had made a change to use a bool flag enable_pdl to match other Flashinfer APIs instead of the env var, but I forgot to push it. Please check again, thanks for reviewing!
Actionable comments posted: 2
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
flashinfer/comm/trtllm_moe_alltoall.py (1)
119-131: ⚠️ Potential issue | 🟡 Minor — Document use_low_precision in the custom-op wrapper docstring.

The signature now accepts use_low_precision, but the Args block jumps from payload_in_workspace to enable_pdl. Add the new parameter description to keep the generated API docs accurate.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/comm/trtllm_moe_alltoall.py` around lines 119-131: The docstring Args block for the custom-op wrapper is missing a description for the new use_low_precision parameter; update the docstring of the wrapper function that now accepts use_low_precision by inserting a one-line entry describing it (e.g., a boolean enabling the reduced-precision FP8 combine path) between payload_in_workspace and enable_pdl so the generated API docs match the function signature.

csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu (1)
553-574: ⚠️ Potential issue | 🟡 Minor — Commit the clang-format output for the launcher changes.
Pre-commit is failing because clang-format rewrote argument wrapping and general C++ formatting in this file. Please run formatting and commit the result.
Also applies to: 1102-1117, 1184-1195, 1228-1230
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu` around lines 553 - 574, The diff shows formatting changes to kernel launcher calls (SWITCH_TOP_K, moeA2ADispatchKernel, launchWithPdlWhenEnabled) that clang-format applied but were not committed; run your repository's clang-format configuration over this file (and the other affected ranges around the SWITCH_TOP_K blocks at the noted locations), stage the resulting formatted file(s), and commit so pre-commit no longer fails—ensure the argument wrapping for the launchWithPdlWhenEnabled calls and the surrounding SWITCH_TOP_K blocks (using WarpPolicy/BlockPolicy and TOP_K) match the clang-formatted output before pushing.
♻️ Duplicate comments (1)
csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu (1)
292-309: ⚠️ Potential issue | 🔴 Critical — Move PDL completion triggers after publishing the state consumed downstream.

The early cudaTriggerProgrammaticLaunchCompletion() calls still let dependent launches start before required data is visible: prepare-dispatch triggers before zeroing counters/incrementing flag_val, dispatch triggers before publishing recv_counters/completion flags, prepare-combine triggers before flag_val/payload staging, and sanitize triggers before invalid expert IDs are written. Move each trigger after the relevant writes and a system-scope fence. This was already raised on an earlier revision and still appears present.

Also applies to: 320-491, 918-971, 1199-1220
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu` around lines 292 - 309, In moeA2APrepareDispatchKernel the cudaTriggerProgrammaticLaunchCompletion() (and similar triggers in the other kernels called out) fires before the kernel publishes state (zeroing send_counters/local_token_counter and incrementing *flag_val_ptr), which can allow dependent PDL launches to observe stale data; fix by moving the cudaTriggerProgrammaticLaunchCompletion() call to after all writes that consumers depend on (e.g., after send_counters[idx]=0, *local_token_counter=0 and *flag_val_ptr++ for moeA2APrepareDispatchKernel) and insert a system-scoped fence (e.g., asm volatile("membar.sys" ::: "memory") or the appropriate CUDA API) immediately before the trigger so the published state is visible to dependent programmatic launches; apply the same pattern to the other kernels referenced (prepare-dispatch, dispatch, prepare-combine, sanitize) so triggers occur only after their consumer-visible writes.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@csrc/nv_internal/tensorrt_llm/common/envUtils.h`:
- Around line 108-125: Run clang-format and commit the formatted version of the
helper function launchWithPdlWhenEnabled in envUtils.h: apply your project's
clang-format configuration to reformat the include order, signature/wrapping and
function body (the cudaLaunchConfig_t setup, attrs array, and cudaLaunchKernelEx
call), then stage and commit the changed file so the pre-commit CI hook stops
failing.
In
`@csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu`:
- Around line 874-895: The ThreadingPolicy::sync() inside the for-loop in
vectorized_quant_impl is divergent and can deadlock because some threads exit
the loop early; replace the variable-length loop with a fixed iteration count
computed from num_elements and stride so every thread executes the same number
of iterations, compute e = ThreadingPolicy::offset()*VEC_SIZE + iter*stride each
iteration, then guard loads/stores with an if (e < num_elements) so only valid
lanes touch src/dst; keep ThreadingPolicy::sync() as a per-iteration barrier
(not conditional) and leave vec_t, in_vec.load, vec_convert and out_vec.store
unchanged.
---
Outside diff comments:
In
`@csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu`:
- Around line 553-574: The diff shows formatting changes to kernel launcher
calls (SWITCH_TOP_K, moeA2ADispatchKernel, launchWithPdlWhenEnabled) that
clang-format applied but were not committed; run your repository's clang-format
configuration over this file (and the other affected ranges around the
SWITCH_TOP_K blocks at the noted locations), stage the resulting formatted
file(s), and commit so pre-commit no longer fails—ensure the argument wrapping
for the launchWithPdlWhenEnabled calls and the surrounding SWITCH_TOP_K blocks
(using WarpPolicy/BlockPolicy and TOP_K) match the clang-formatted output before
pushing.
In `@flashinfer/comm/trtllm_moe_alltoall.py`:
- Around line 119-131: The docstring Args block for the custom-op wrapper is
missing the new parameter description for use_low_precision; update the
docstring in trtllm_moe_alltoall.py (the wrapper function that now accepts
use_low_precision) by inserting a one-line entry describing use_low_precision
(e.g., boolean to enable reduced-precision computation) between
payload_in_workspace and enable_pdl so the generated API docs include it and
match the function signature.
---
Duplicate comments:
In
`@csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu`:
- Around line 292-309: In moeA2APrepareDispatchKernel the
cudaTriggerProgrammaticLaunchCompletion() (and similar triggers in the other
kernels called out) fires before the kernel publishes state (zeroing
send_counters/local_token_counter and incrementing *flag_val_ptr), which can
allow dependent PDL launches to observe stale data; fix by moving the
cudaTriggerProgrammaticLaunchCompletion() call to after all writes that
consumers depend on (e.g., after send_counters[idx]=0, *local_token_counter=0
and *flag_val_ptr++ for moeA2APrepareDispatchKernel) and insert a system-scoped
fence (e.g., asm volatile("membar.sys" ::: "memory") or the appropriate CUDA
API) immediately before the trigger so the published state is visible to
dependent programmatic launches; apply the same pattern to the other kernels
referenced (prepare-dispatch, dispatch, prepare-combine, sanitize) so triggers
occur only after their consumer-visible writes.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: a2a02583-0804-4b03-9257-1bc026ecdfbd
📒 Files selected for processing (6)
- csrc/nv_internal/tensorrt_llm/common/envUtils.h
- csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu
- csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.h
- csrc/trtllm_moe_alltoall.cu
- flashinfer/comm/trtllm_moe_alltoall.py
- tests/comm/test_trtllm_moe_alltoall.py
🚧 Files skipped from review as they are similar to previous changes (3)
- csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.h
- tests/comm/test_trtllm_moe_alltoall.py
- csrc/trtllm_moe_alltoall.cu
```cpp
template <typename KernelFn, typename... Args>
inline void launchWithPdlWhenEnabled(char const* name, bool enable_pdl, KernelFn kernelFn,
    dim3 grid, dim3 block, size_t dynamicShmSize, cudaStream_t stream, Args&&... args)
{
    cudaLaunchConfig_t kernelConfig;
    kernelConfig.gridDim = grid;
    kernelConfig.blockDim = block;
    kernelConfig.dynamicSmemBytes = dynamicShmSize;
    kernelConfig.stream = stream;
    cudaLaunchAttribute attrs[1];
    attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
    attrs[0].val.programmaticStreamSerializationAllowed = enable_pdl;
    kernelConfig.attrs = attrs;
    kernelConfig.numAttrs = 1;
    cudaError_t e = cudaLaunchKernelEx(&kernelConfig, kernelFn, std::forward<Args>(args)...);
    FLASHINFER_CHECK(e == cudaSuccess, "cudaLaunchKernelEx (", name, ") failed: ",
        cudaGetErrorString(e));
}
```
Commit the clang-format output for this helper.
Pre-commit is failing because clang-format rewrote this file’s include order/signature/body wrapping. Please run formatting and commit the result so CI can pass.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@csrc/nv_internal/tensorrt_llm/common/envUtils.h` around lines 108 - 125, Run
clang-format and commit the formatted version of the helper function
launchWithPdlWhenEnabled in envUtils.h: apply your project's clang-format
configuration to reformat the include order, signature/wrapping and function
body (the cudaLaunchConfig_t setup, attrs array, and cudaLaunchKernelEx call),
then stage and commit the changed file so the pre-commit CI hook stops failing.
```cpp
template <int VEC_SIZE, typename ThreadingPolicy, typename SrcT, typename DstT>
__device__ void vectorized_quant_impl(DstT* dst, SrcT const* src, int num_elements) {
  using flashinfer::vec_t;

  int const stride = ThreadingPolicy::stride() * VEC_SIZE;

  for (int e = ThreadingPolicy::offset() * VEC_SIZE; e < num_elements; e += stride) {
    vec_t<SrcT, VEC_SIZE> in_vec;
    in_vec.load(src + e);

    // Sync to ensure all threads have loaded their input vectors before any thread starts
    // writing output. This avoids write-after-read hazards in the FP8 in-place case where the
    // output of this kernel is read by the next iteration as input. Without this sync, some
    // threads might start writing their output (DstT) before other threads have loaded their
    // input (SrcT), causing the load to read partially updated data.
    ThreadingPolicy::sync();

    vec_t<DstT, VEC_SIZE> out_vec;
    vec_convert(out_vec, in_vec);
    out_vec.store(dst + e);
  }
}
```
Remove divergent sync from loop in vectorized_quant_impl.
Line 889: The `ThreadingPolicy::sync()` call inside the loop guarded by `e < num_elements` can deadlock. Threads that exit early (when `e >= num_elements`) never reach the sync, while others wait indefinitely. This divergence is unsafe for both `__syncwarp()` (WarpPolicy) and `__syncthreads()` (BlockPolicy).
Fix: Use fixed iteration count with inner guards
```diff
 template <int VEC_SIZE, typename ThreadingPolicy, typename SrcT, typename DstT>
 __device__ void vectorized_quant_impl(DstT* dst, SrcT const* src, int num_elements) {
   using flashinfer::vec_t;
   int const stride = ThreadingPolicy::stride() * VEC_SIZE;
+  int const iterations = ceilDiv(num_elements, stride);
-  for (int e = ThreadingPolicy::offset() * VEC_SIZE; e < num_elements; e += stride) {
+  for (int iter = 0; iter < iterations; ++iter) {
+    int const e = ThreadingPolicy::offset() * VEC_SIZE + iter * stride;
+    bool const valid = e < num_elements;
     vec_t<SrcT, VEC_SIZE> in_vec;
-    in_vec.load(src + e);
+    if (valid) {
+      in_vec.load(src + e);
+    }
     // Sync to ensure all threads have loaded their input vectors before any thread starts
     // writing output. This avoids write-after-read hazards in the FP8 in-place case where the
     // output of this kernel is read by the next iteration as input. Without this sync, some
     // threads might start writing their output (DstT) before other threads have loaded their
     // input (SrcT), causing the load to read partially updated data.
     ThreadingPolicy::sync();
-    vec_t<DstT, VEC_SIZE> out_vec;
-    vec_convert(out_vec, in_vec);
-    out_vec.store(dst + e);
+    if (valid) {
+      vec_t<DstT, VEC_SIZE> out_vec;
+      vec_convert(out_vec, in_vec);
+      out_vec.store(dst + e);
+    }
   }
 }
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In
`@csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu`
around lines 874 - 895, The ThreadingPolicy::sync() inside the for-loop in
vectorized_quant_impl is divergent and can deadlock because some threads exit
the loop early; replace the variable-length loop with a fixed iteration count
computed from num_elements and stride so every thread executes the same number
of iterations, compute e = ThreadingPolicy::offset()*VEC_SIZE + iter*stride each
iteration, then guard loads/stores with an if (e < num_elements) so only valid
lanes touch src/dst; keep ThreadingPolicy::sync() as a per-iteration barrier
(not conditional) and leave vec_t, in_vec.load, vec_convert and out_vec.store
unchanged.
Force-pushed 3dca42c to bfcc10f (Compare)
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
flashinfer/comm/trtllm_moe_alltoall.py (1)
Lines 297-310: ⚠️ Potential issue | 🟠 Major: Gate explicit `enable_pdl=True` with device support in all three public wrappers.
These functions only auto-detect when `enable_pdl` is `None`, allowing a caller to pass `True` and force PDL on unsupported devices. Implement the tri-state pattern: only `False` forces PDL off; `None` and `True` should enable it only when the device supports it.
Proposed fix:
```diff
-    if enable_pdl is None:
-        enable_pdl = device_support_pdl(token_selected_experts.device)
+    enable_pdl = (
+        device_support_pdl(token_selected_experts.device)
+        if enable_pdl is not False
+        else False
+    )
```
Apply the same pattern to `moe_a2a_combine` (line 347) and `moe_a2a_sanitize_expert_ids` (line 374).
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/comm/trtllm_moe_alltoall.py` around lines 297 - 310, The wrappers currently only auto-detect PDL when enable_pdl is None, allowing callers to force True on unsupported devices; change the tri-state handling so only enable_pdl == False forces PDL off, otherwise consult device_support_pdl(...) to decide: replace the current "if enable_pdl is None: enable_pdl = device_support_pdl(token_selected_experts.device)" logic in the public wrappers moe_a2a_dispatch (the block calling get_moe_alltoall_module().moe_a2a_dispatch), moe_a2a_combine, and moe_a2a_sanitize_expert_ids with something like: if enable_pdl is False: keep False else set enable_pdl = device_support_pdl(token_selected_experts.device) so that both None and True are gated by device_support_pdl(token_selected_experts.device).
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@csrc/trtllm_moe_alltoall.cu`:
- Line 200: The code assigns enablePdl directly to params.enable_pdl
(params.enable_pdl = enablePdl) without validating device PDL support; add a
device-capability check before setting params.enable_pdl (or gate explicit True
in the Python wrapper) so that PDL is only enabled on supported hardware: query
the device compute capability (or reuse the existing device_support_pdl logic)
and set params.enable_pdl = true only when enablePdl is true AND
device_support_pdl(device) is true (otherwise set false); update the C++ FFI
boundary where enablePdl is consumed (the assignment to params.enable_pdl) to
perform this check, or mirror the Python pattern enable_pdl =
device_support_pdl(input.device) if enable_pdl is not False else False.
---
Outside diff comments:
In `@flashinfer/comm/trtllm_moe_alltoall.py`:
- Around line 297-310: The wrappers currently only auto-detect PDL when
enable_pdl is None, allowing callers to force True on unsupported devices;
change the tri-state handling so only enable_pdl == False forces PDL off,
otherwise consult device_support_pdl(...) to decide: replace the current "if
enable_pdl is None: enable_pdl =
device_support_pdl(token_selected_experts.device)" logic in the public wrappers
moe_a2a_dispatch (the block calling get_moe_alltoall_module().moe_a2a_dispatch),
moe_a2a_combine, and moe_a2a_sanitize_expert_ids with something like: if
enable_pdl is False: keep False else set enable_pdl =
device_support_pdl(token_selected_experts.device) so that both None and True are
gated by device_support_pdl(token_selected_experts.device).
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: d5a9e84a-34e1-4de1-8931-4570434f2f89
📒 Files selected for processing (6)
- csrc/nv_internal/tensorrt_llm/common/envUtils.h
- csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu
- csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.h
- csrc/trtllm_moe_alltoall.cu
- flashinfer/comm/trtllm_moe_alltoall.py
- tests/comm/test_trtllm_moe_alltoall.py
🚧 Files skipped from review as they are similar to previous changes (3)
- csrc/nv_internal/tensorrt_llm/common/envUtils.h
- tests/comm/test_trtllm_moe_alltoall.py
- csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu
```cpp
tl_throughput::MoeA2ADispatchParams params{};
params.one_block_per_token = tensorrt_llm::common::getEnvMoeA2AOneBlockPerToken();
params.enable_pdl = enablePdl;
```
Add device capability validation at the C++ FFI boundary.
The C++ code directly assigns enablePdl to params.enable_pdl without verifying device PDL support. If a caller explicitly passes enable_pdl=true on hardware with compute capability < 9, the CUDA launch will fail. Unlike flashinfer/quantization/kernels/nvfp4_quantize.py which gates even explicit True via device_support_pdl(), the Python wrapper here only auto-detects when enable_pdl is None.
Add a device capability check in C++ before using the PDL launch attribute, or gate explicit True in the Python wrapper to match the pattern in nvfp4_quantize.py (line 1246): enable_pdl = device_support_pdl(input.device) if enable_pdl is not False else False.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@csrc/trtllm_moe_alltoall.cu` at line 200, The code assigns enablePdl directly
to params.enable_pdl (params.enable_pdl = enablePdl) without validating device
PDL support; add a device-capability check before setting params.enable_pdl (or
gate explicit True in the Python wrapper) so that PDL is only enabled on
supported hardware: query the device compute capability (or reuse the existing
device_support_pdl logic) and set params.enable_pdl = true only when enablePdl
is true AND device_support_pdl(device) is true (otherwise set false); update the
C++ FFI boundary where enablePdl is consumed (the assignment to
params.enable_pdl) to perform this check, or mirror the Python pattern
enable_pdl = device_support_pdl(input.device) if enable_pdl is not False else
False.
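If the check is done on the C++ side instead, the guard might look like the sketch below. This is an assumption, not the PR's actual code: the function name is hypothetical, and the SM 9.0+ threshold follows the PDL requirement stated elsewhere in this PR.

```cuda
#include <cuda_runtime.h>

// Returns true only when the caller requested PDL AND the current device's
// compute capability major version is >= 9 (Hopper and later), where
// programmatic dependent launch is available.
bool gated_enable_pdl(bool requested) {
  if (!requested) return false;
  int device = 0;
  if (cudaGetDevice(&device) != cudaSuccess) return false;
  int major = 0;
  if (cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device) !=
      cudaSuccess) {
    return false;
  }
  return major >= 9;
}

// Usage at the FFI boundary, replacing the direct assignment:
//   params.enable_pdl = gated_enable_pdl(enablePdl);
```

Doing the check once at the boundary keeps every downstream kernel launch free of capability logic.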
/bot run
pls address failed pre-commit check
📌 Description
Ports NVIDIA/TensorRT-LLM#10591 and NVIDIA/TensorRT-LLM#11844 to Flashinfer.
Microbenchmarks comparing FlashInfer's and TRT-LLM's implementations of the one-sided MNNVL all-to-all revealed a 15% performance regression in FlashInfer's combine. This is resolved by incorporating the latest TRT-LLM updates: (1) PDL support, and (2) low-precision combine, which is not enabled by default but brings other changes that improve performance. Unit tests were added for both.
Benchmarking
8 nodes x 4 GPUs, GB200 NVL72. All values in microseconds. DeepSeek config (hidden=7168, top_k=8, global experts=256) with EP32 (local experts=8). NVFP4 dispatch, BF16 combine. CUDA graph enabled.
🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- Installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- Installed the hooks with `pre-commit install`.
- Ran `pre-commit run --all-files` and fixed any reported issues.
🧪 Tests
- New or existing tests cover these changes (`unittest`, etc.).
Reviewer Notes
Summary by CodeRabbit
New Features
Tests