
feat(moe-a2a): Update nvlink onesided all-to-all#3139

Open
trevor-m wants to merge 4 commits into flashinfer-ai:main from trevor-m:feat/b802a3818-float32-combine-accumulators

Conversation

Contributor

@trevor-m trevor-m commented Apr 21, 2026

📌 Description

Ports NVIDIA/TensorRT-LLM#10591 and NVIDIA/TensorRT-LLM#11844 to FlashInfer.

Microbenchmarking FlashInfer against TRT-LLM's implementation of the one-sided MNNVL all-to-all revealed a 15% performance regression in FlashInfer's combine. This is resolved by incorporating the latest updates from TRT-LLM: (1) PDL support and (2) low-precision combine, which is not enabled by default but brings other changes that improve performance. Unit tests are added for both.

Benchmarking

8 nodes x 4 GPUs, GB200 NVL72. All values in microseconds. DeepSeek config (hidden=7168, top_k=8, global experts=256) with EP32 (local experts=8). NVFP4 dispatch, BF16 combine. CUDA graph enabled.

| BS | TRT-LLM Dispatch | FI main Dispatch | FI new Dispatch | TRT-LLM Combine | FI main Combine | FI new Combine |
| ---: | ---: | ---: | ---: | ---: | ---: | ---: |
| 1 | 26.02 | 22.62 | 23.26 | 31.04 | 36.38 | 29.47 |
| 2 | 26.30 | 23.49 | 23.74 | 30.27 | 36.77 | 29.09 |
| 4 | 26.50 | 22.91 | 23.55 | 31.46 | 37.15 | 29.28 |
| 8 | 27.71 | 23.90 | 24.13 | 31.84 | 37.34 | 29.82 |
| 16 | 26.37 | 24.54 | 24.13 | 32.38 | 38.37 | 31.07 |
| 32 | 27.68 | 26.34 | 26.30 | 33.15 | 38.95 | 32.35 |
| 64 | 33.92 | 31.78 | 32.29 | 36.90 | 41.70 | 36.19 |
| 128 | 47.07 | 44.35 | 43.97 | 46.88 | 49.44 | 44.99 |
| 256 | 68.77 | 66.34 | 66.98 | 74.18 | 74.21 | 73.38 |
| 512 | 115.01 | 113.18 | 113.34 | 132.42 | 131.04 | 130.94 |
| 1024 | 200.19 | 199.14 | 199.04 | 241.50 | 245.66 | 241.89 |
| 2048 | 379.10 | 380.38 | 377.31 | 457.57 | 454.85 | 452.77 |

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • New Features

    • FP8 (float8) support for MoE all-to-all combine with FP8→BF16 conversion and low-precision combine mode
    • Programmatic Dependent Launch (PDL) support for safer kernel synchronization; exposed via new enable_pdl flag across ops and params
    • Public APIs extended with use_low_precision and enable_pdl options (C++ and Python wrappers)
  • Tests

    • Added FP8-specific tests and parameterized multi-rank combine tests covering enable_pdl scenarios

trevor-m and others added 2 commits April 20, 2026 16:41
Float32 accumulators, FP8 combine, and related API updates:

- vectorized_combine_impl: float32 acc[TOP_K] instead of uint8 byte accumulation;
  two-pass load-then-upcast (descending j) avoids WAR hazards for InT<float
- InT template param: recv buffer element type (defaults to T); enables FP8 in-place
  combine (InT=fp8_e4m3, T=bf16) without separate conversion kernel
- stride_per_token: byte distance between tokens in recv buffer; equals
  size_per_token normally, differs for FP8 in-place (BF16-stride workspace)
- vec_convert: generic + SM89 PTX f16x2→e4m3x2 + SM100 PTX bf16x2→e4m3x2
- vectorized_quant: SrcT→FP8 with ThreadingPolicy::sync() for WAR safety
- moeA2APrepareCombineKernel: LOW_PRECISION=true path quantizes payload to FP8
- moeA2ACombineKernel: dispatches FP8→BF16 or same-type path based on T
- MoeA2ACombineParams: add use_low_precision flag
- Python API: add use_low_precision=False to moe_a2a_combine / MoeAlltoAll.combine
- Tests: add FP8 combine correctness test (SM>=89, tolerates FP8 quantization error)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
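The descending-j load-then-upcast trick noted above can be illustrated with a plain-Python stand-in (an assumption-laden sketch: the real code is a CUDA kernel, and `upcast_in_place` with a uint8-to-float32 widening is purely illustrative). Widening elements in place means each wide store can clobber narrow inputs that have not been read yet; iterating `j` in descending order keeps every store at or above the remaining unread inputs, which avoids the write-after-read (WAR) hazard:

```python
import struct

def upcast_in_place(buf: bytearray, n: int) -> None:
    """Expand n packed uint8 values at the front of buf into n float32 values
    occupying the same buffer. Descending j mirrors the kernel's loop order:
    the wide store at 4*j never overwrites a narrow input at index < j."""
    assert len(buf) >= 4 * n
    for j in range(n - 1, -1, -1):                          # descending j
        value = float(buf[j])                               # load narrow input
        buf[4 * j : 4 * j + 4] = struct.pack("<f", value)   # store wide output

buf = bytearray(16)
buf[0:4] = bytes([10, 20, 30, 40])   # four packed 1-byte inputs
upcast_in_place(buf, 4)
print([struct.unpack("<f", buf[4 * j : 4 * j + 4])[0] for j in range(4)])
# → [10.0, 20.0, 30.0, 40.0]
```

Running the same loop in ascending order would overwrite `buf[1]`..`buf[3]` with the first wide store before they were read.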
Port PDL support from TRT-LLM b802a3818:
- Add getEnvEnablePDL() (TRTLLM_ENABLE_PDL env var, default on SM>=90)
- Add launchWithPdlWhenEnabled() helper using cudaLaunchKernelEx
- Replace all <<<>>> launches with launchWithPdlWhenEnabled in dispatch,
  combine, and sanitize kernels
- Add cudaGridDependencySynchronize() / cudaTriggerProgrammaticLaunchCompletion()
  guards (#if __CUDA_ARCH__ >= 900) in all five kernel entry points
- Reorder moeA2ACombineKernel params: stride_per_token now before local_num_tokens
  to match TRT-LLM signature

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
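The env gating described above (an explicit `TRTLLM_ENABLE_PDL` setting wins; otherwise PDL defaults to on for SM>=90) can be sketched in Python. The real `getEnvEnablePDL()` is C++, and its exact parsing of the variable is an assumption here:

```python
import os

def env_enable_pdl(sm_major: int, env=os.environ) -> bool:
    """Sketch of getEnvEnablePDL(): explicit TRTLLM_ENABLE_PDL overrides;
    otherwise PDL is enabled by default on SM>=90 (Hopper and newer)."""
    raw = env.get("TRTLLM_ENABLE_PDL")
    if raw is not None:
        return raw.strip().lower() in ("1", "true", "on")
    return sm_major >= 9

print(env_enable_pdl(9, env={}))                          # Hopper default: True
print(env_enable_pdl(9, env={"TRTLLM_ENABLE_PDL": "0"}))  # explicit off: False
print(env_enable_pdl(8, env={}))                          # pre-Hopper default: False
```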
@coderabbitai
Contributor

coderabbitai Bot commented Apr 21, 2026

📝 Walkthrough

Walkthrough

Adds programmatic dependent launch (PDL) support and FP8/low-precision handling to MoE all-to-all kernels and APIs: introduces a PDL-aware host launch helper, threads enable_pdl/use_low_precision flags through C++/Python layers, updates kernels to use PDL-aware launches and FP8 conversion paths, and adds tests.

Changes

| Cohort / File(s) | Summary |
| --- | --- |
| **PDL Launch Infrastructure**<br>`csrc/nv_internal/tensorrt_llm/common/envUtils.h` | Added `launchWithPdlWhenEnabled` inline template that builds `cudaLaunchConfig_t`, conditionally sets `cudaLaunchAttributeProgrammaticStreamSerialization` from `enable_pdl`, calls `cudaLaunchKernelEx`, and checks errors. |
| **MoE Kernel Implementation**<br>`csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu` | Added FP8 dtype handling and conversion utilities, refactored prepare/dispatch/combine paths to support FP8 and cross-type vectorized combine, added an `enable_pdl` kernel parameter and in-kernel PDL synchronization (SM>=900), and replaced direct `<<<>>>` launches with `launchWithPdlWhenEnabled(...)`. |
| **MoE Kernel API**<br>`csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.h` | Extended `MoeA2ADispatchParams` and `MoeA2ACombineParams` with `bool enable_pdl`; added `bool use_low_precision` to `MoeA2ACombineParams`; updated the `moe_a2a_sanitize_expert_ids_launch` signature to accept `enable_pdl`. |
| **C++ Binding Layer**<br>`csrc/trtllm_moe_alltoall.cu` | Threaded new flags into exported ops: added `enablePdl` to dispatch/sanitize ops and `useLowPrecision` + `enablePdl` to the combine op; added FP8 DL→NV dtype mapping and selects BF16 output when `useLowPrecision` is true. |
| **Python API**<br>`flashinfer/comm/trtllm_moe_alltoall.py` | Extended custom-op/wrapper signatures to accept optional `enable_pdl` (auto-detected when `None`) for dispatch/combine/sanitize; added `use_low_precision` to `moe_a2a_combine` and `MoeAlltoAll.combine`. |
| **Tests**<br>`tests/comm/test_trtllm_moe_alltoall.py` | Parameterized tests to pass `enable_pdl` through dispatch/sanitize/combine; added a `use_low_precision` test path and a new `test_moe_combine_fp8` gated by compute capability. |

Sequence Diagram

```mermaid
sequenceDiagram
    participant Host as Host Code
    participant Helper as launchWithPdlWhenEnabled
    participant CUDA as CUDA Runtime
    participant Kernel as Device Kernel
    participant Sync as Synchronization

    Host->>Helper: call launchWithPdlWhenEnabled(name, enable_pdl, kernelFn, grid, block, shm, stream, args...)
    Helper->>CUDA: build cudaLaunchConfig_t (grid, block, shm, stream)
    alt enable_pdl == true
        Helper->>CUDA: attach cudaLaunchAttributeProgrammaticStreamSerialization(true)
    else
        Helper->>CUDA: attach cudaLaunchAttributeProgrammaticStreamSerialization(false)
    end
    Helper->>CUDA: cudaLaunchKernelEx(config, kernelFn, args...)
    CUDA->>Kernel: execute kernel on device
    Kernel->>Kernel: optional cudaGridDependencySynchronize() (SM>=900)
    Kernel->>Sync: signal completion
    CUDA->>Helper: return cudaError_t
    Helper->>Host: FLASHINFER_CHECK(error)
```

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Possibly related PRs

Suggested labels

run-ci, op: moe, op: moe-routing

Suggested reviewers

  • yzh119
  • aleozlx
  • bkryu
  • nv-yunzheq
  • sricketts
  • samuellees
  • cyx-6
  • saltyminty
  • yongwww
  • jimmyzho
  • nvmbreughe
  • kahyunnam

Poem

🐰 Hops of bytes from rank to rank,

FP8 whispers, BF16's bright flank,
PDL sparks the kernel's spree,
Streams align and launches agree,
I nibble bugs and leave a patchy thank-you!

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
| --- | --- | --- | --- |
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 24.39%, below the required threshold of 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |

✅ Passed checks (4 passed)

| Check name | Status | Explanation |
| --- | --- | --- |
| Title check | ✅ Passed | The title clearly summarizes the main change: updating the NVLink one-sided all-to-all implementation with PDL support and low-precision combine. |
| Description check | ✅ Passed | The PR description covers the main objectives (porting TRT-LLM updates), includes benchmarking results demonstrating performance improvements, and mentions that tests have been added. However, the pre-commit checklist items are not explicitly marked as completed, and test details are minimal. |
| Linked Issues check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Out of Scope Changes check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |

Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request implements support for Programmatic Dependent Launch (PDL) and FP8 quantization within the MoE All-to-All kernels. It introduces a launchWithPdlWhenEnabled utility to manage kernel launches with PDL attributes and updates the moe_a2a_combine logic to handle FP8 data with BF16 accumulation. Critical feedback was provided regarding the incorrect placement of cudaTriggerProgrammaticLaunchCompletion() in several kernels, including moeA2APrepareDispatchKernel, moeA2APrepareCombineKernel, and moeA2ASanitizeExpertIdsKernel. Calling this function at the start of the kernels creates race conditions by signaling completion before the work is actually performed.

Comment on lines +294 to +297
```cpp
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
  cudaGridDependencySynchronize();
  cudaTriggerProgrammaticLaunchCompletion();
#endif
```

high

The placement of cudaTriggerProgrammaticLaunchCompletion() at the beginning of moeA2APrepareDispatchKernel is incorrect. This function signals to subsequent kernels in the stream (that use PDL) that the current kernel has completed the work they depend on. By calling it at the start, you are signaling completion before the send_counters and local_token_counter have actually been zeroed. This creates a race condition where moeA2ADispatchKernel might start executing before the initialization is finished.

To fix this, cudaTriggerProgrammaticLaunchCompletion() should be called at the end of the kernel, and you must ensure that every block in the grid calls it exactly once, even if they return early due to bounds checks.

Comment on lines +923 to +926
```cpp
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
  cudaGridDependencySynchronize();
  cudaTriggerProgrammaticLaunchCompletion();
#endif
```

high

Similar to the dispatch prepare kernel, cudaTriggerProgrammaticLaunchCompletion() is called here before any data has been copied or quantized. The subsequent moeA2ACombineKernel depends on the data being present in the recv_buffer. Triggering completion at the start of this kernel allows the combine kernel to potentially read uninitialized or partially updated data.

Please move the trigger call to the end of the function and restructure the early returns (lines 934, 939, 946) to ensure that cudaTriggerProgrammaticLaunchCompletion() is still called by every block in the grid as required by the CUDA specification.
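The control-flow restructuring requested here — every block still reaches the trigger exactly once, even when a bounds check makes it skip its work — can be shown with a small Python stand-in (the function and variable names are illustrative, not the kernel's real API):

```python
# An early `return` on a bounds check would skip the completion trigger for
# that block; guarding the body instead lets every block trigger exactly once.
def run_block(block_id: int, num_valid: int, work_log: list, triggers: list) -> None:
    if block_id < num_valid:       # bounds check guards the body...
        work_log.append(block_id)  # ...instead of causing an early return
    triggers.append(block_id)      # trigger runs last, for every block

work_log, triggers = [], []
for bid in range(4):               # "grid" of 4 blocks, only 2 in bounds
    run_block(bid, num_valid=2, work_log=work_log, triggers=triggers)
print(work_log, triggers)          # → [0, 1] [0, 1, 2, 3]
```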

Comment on lines +1200 to +1203
```cpp
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
  cudaGridDependencySynchronize();
  cudaTriggerProgrammaticLaunchCompletion();
#endif
```

high

The cudaTriggerProgrammaticLaunchCompletion() call is misplaced at the start of moeA2ASanitizeExpertIdsKernel. If any subsequent operation depends on the expert IDs being sanitized, it may start prematurely. Move this call to the end of the kernel, ensuring all blocks reach it.


@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 2

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
csrc/trtllm_moe_alltoall.cu (1)

319-334: ⚠️ Potential issue | 🔴 Critical

Allocate BF16 output whenever the combine kernel will run the FP8 path.

Line 320 only switches to BF16 when useLowPrecision is true, but line 334 can still set params.dtype = kFP8 for an FP8 payload. The CUDA combine kernel treats kFP8 as “FP8 recv buffer → BF16 output”, so an FP8 payload with useLowPrecision=false allocates an FP8-sized output and then writes BF16 elements into it.

🐛 Proposed fix
```diff
-  // For FP8 combine, recv buffers hold FP8 but output is BF16 (upcast during accumulation).
-  DLDataType outputDtype = useLowPrecision ? dl_bfloat16 : payload.dtype();
+  auto const payloadDtypeCode = encode_dlpack_dtype(payload.dtype());
+  bool const fp8Payload = payloadDtypeCode == float8_e4m3fn_code;
+  bool const effectiveLowPrecision = useLowPrecision || fp8Payload;
+
+  // For FP8 combine, recv buffers hold FP8 but output is BF16 (upcast during accumulation).
+  DLDataType outputDtype = effectiveLowPrecision ? dl_bfloat16 : payload.dtype();
   Tensor output = alloc_tensor({localNumTokens, elementsPerToken}, outputDtype, payload.device());
@@
-  params.use_low_precision = useLowPrecision;
+  params.use_low_precision = effectiveLowPrecision;
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@csrc/trtllm_moe_alltoall.cu` around lines 319 - 334, The output tensor is
currently allocated as BF16 only when useLowPrecision is true but the combine
kernel treats kFP8 payloads as "FP8 recv → BF16 output", so when params.dtype is
set to kFP8 you must allocate BF16 output too; change the allocation logic for
outputDtype (used by alloc_tensor and output.data_ptr()) to pick BF16 when
either useLowPrecision is true OR the payload dtype indicates FP8 (check
payload.dtype() or equivalently toNvDataType(payload.dtype()) == kFP8) so
params.dtype can remain as toNvDataType(payload.dtype()) while the output buffer
has the correct BF16 shape.
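The rule this fix encodes — allocate a BF16 output whenever the kernel will run the FP8 path, i.e. when low precision is requested or the payload is already FP8 — reduces to a small selection function (a Python sketch; the dtype strings stand in for the real DLPack/NV dtype codes):

```python
def combine_output_dtype(payload_dtype: str, use_low_precision: bool) -> str:
    """The combine kernel upcasts any FP8 recv buffer to BF16 during
    accumulation, so the output buffer must be BF16 in that case too."""
    fp8_payload = payload_dtype == "float8_e4m3fn"
    return "bfloat16" if (use_low_precision or fp8_payload) else payload_dtype

print(combine_output_dtype("bfloat16", False))       # stays bfloat16
print(combine_output_dtype("float8_e4m3fn", False))  # FP8 payload forces bfloat16
print(combine_output_dtype("float8_e4m3fn", True))   # low precision forces bfloat16
```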
flashinfer/comm/trtllm_moe_alltoall.py (1)

320-345: ⚠️ Potential issue | 🟠 Major

Add SM89+ runtime check for use_low_precision=True path.

The tests skip FP8 combine on SM<89, but the public APIs (moe_a2a_combine at line 320–346 and MoeAlltoAll.combine at lines 659–698) allow callers to pass use_low_precision=True without checking device capability. Add a runtime validation before forwarding to the custom op so unsupported GPUs fail predictably. Follow the pattern in flashinfer/norm/__init__.py:703–704 (check get_compute_capability and raise RuntimeError if SM<89 when FP8 is requested).

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/comm/trtllm_moe_alltoall.py` around lines 320 - 345, Add a runtime
GPU capability check when the user requests FP8: in moe_a2a_combine (before
calling get_moe_alltoall_module().moe_a2a_combine) and likewise in the
MoeAlltoAll.combine method, call get_compute_capability() and if
use_low_precision is True and the returned compute capability is less than 89
raise a RuntimeError with a clear message that SM<89 does not support
low-precision FP8 combine; this mirrors the pattern used in
flashinfer/norm/__init__.py and ensures unsupported devices fail predictably
instead of invoking the custom op.
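The suggested runtime guard might look like the following sketch (the helper name and message are illustrative; in the real code the capability tuple would come from `flashinfer.utils.get_compute_capability`):

```python
def check_low_precision_supported(capability: tuple, use_low_precision: bool) -> None:
    """Raise early when FP8 combine is requested on a GPU that cannot run it,
    instead of failing deep inside the custom op."""
    major, minor = capability
    if use_low_precision and major * 10 + minor < 89:
        raise RuntimeError(
            "use_low_precision=True requires FP8 support (SM>=89, Ada or newer); "
            f"got SM{major}{minor}"
        )

check_low_precision_supported((9, 0), True)      # OK on Hopper
try:
    check_low_precision_supported((8, 0), True)  # raises on SM80
except RuntimeError as e:
    print("raised:", e)
```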
ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 83a0c88e-165b-407e-8741-9d94c340bbac

📥 Commits

Reviewing files that changed from the base of the PR and between 9e3d8b9 and 752476d.

📒 Files selected for processing (7)
  • csrc/nv_internal/cpp/common/envUtils.cpp
  • csrc/nv_internal/tensorrt_llm/common/envUtils.h
  • csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu
  • csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.h
  • csrc/trtllm_moe_alltoall.cu
  • flashinfer/comm/trtllm_moe_alltoall.py
  • tests/comm/test_trtllm_moe_alltoall.py

Comment on lines +294 to +297
```cpp
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
  cudaGridDependencySynchronize();
  cudaTriggerProgrammaticLaunchCompletion();
#endif
```

⚠️ Potential issue | 🔴 Critical


Move PDL completion triggers to after all required state is published.

These kernels trigger cudaTriggerProgrammaticLaunchCompletion() before the data they produce is ready, creating race windows for dependent PDL-enabled kernels. Specifically:

  • moeA2APrepareDispatchKernel (line 295): Triggers before zeroing send_counters and local_token_counter (lines 297–301), which moeA2ADispatchKernel reads and increments.
  • moeA2ADispatchKernel (line 416): Triggers before writing recv_counters and completion_flags (lines 425–440), which moeA2APrepareCombineKernel depends on reading.
  • moeA2APrepareCombineKernel (line 925): Triggers before incrementing flag_val (line 928) and staging payload data, which moeA2ACombineKernel requires.
  • moeA2ASanitizeExpertIdsKernel (line 1202): Triggers before its consumer has guaranteed visibility of previously written state.

Move each trigger to the end of the kernel, after __threadfence_system() or equivalent barrier and all required writes are complete.

Also applies to: 415–417, 923–926, 1200–1203

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In
`@csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu`
around lines 294 - 297, The PDL completion calls
(cudaGridDependencySynchronize() and cudaTriggerProgrammaticLaunchCompletion())
are being invoked before kernels publish the state their consumers read; move
these calls to the end of each affected kernel so they occur after all required
writes and an appropriate system-wide visibility fence. Concretely, in
moeA2APrepareDispatchKernel, moeA2ADispatchKernel, moeA2APrepareCombineKernel,
and moeA2ASanitizeExpertIdsKernel remove the early calls around lines
~295/415/923/1200 and instead place cudaGridDependencySynchronize() followed by
__threadfence_system() (or ensure an existing __threadfence_system() runs first)
and then cudaTriggerProgrammaticLaunchCompletion() as the final operations in
each kernel, after zeroing/incrementing send_counters/local_token_counter, after
writing recv_counters/completion_flags, and after incrementing flag_val/staging
payloads respectively so all produced state is published before signaling
completion.

Comment on lines +600 to +603
```python
@pytest.mark.skipif(
    torch.cuda.get_device_capability(0) < (8, 9),
    reason="FP8 (e4m3) requires SM>=89 (Ada Lovelace or newer)",
)
```

⚠️ Potential issue | 🟡 Minor



Use flashinfer.utils.get_compute_capability() to gate this FP8 test.

Replace the raw torch.cuda.get_device_capability(0) check with get_compute_capability() from flashinfer.utils, following the repository's test gating pattern, and compare the returned (major, minor) tuple against (8, 9):

Suggested fix
from flashinfer.utils import get_compute_capability

@pytest.mark.skipif(
    get_compute_capability(torch.device("cuda")) < (8, 9),
    reason="FP8 (e4m3) requires SM>=89 (Ada Lovelace or newer)",
)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/comm/test_trtllm_moe_alltoall.py` around lines 600 - 603, Replace the
raw torch.cuda.get_device_capability(0) check in the pytest.mark.skipif
decorator with flashinfer.utils.get_compute_capability: import
get_compute_capability and call get_compute_capability(torch.device("cuda")),
convert the returned (major, minor) tuple into a numeric capability via major *
10 + minor and use >= 89 to gate the FP8 test; update the pytest.mark.skipif
condition that currently references torch.cuda.get_device_capability to use this
computed value (refer to get_compute_capability and the pytest.mark.skipif
decorator in the test).
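For reference, a minimal sketch of the tuple-to-numeric conversion the prompt describes, assuming get_compute_capability returns a (major, minor) tuple; capability_as_int is a hypothetical helper name, not a FlashInfer API:

```python
def capability_as_int(cap: tuple) -> int:
    # Hypothetical helper: fold (major, minor) into the numeric SM form
    # (major * 10 + minor) used elsewhere in this review thread.
    major, minor = cap
    return major * 10 + minor

# SM89 (Ada) passes the FP8 (e4m3) gate, SM80 (Ampere) does not.
assert capability_as_int((8, 9)) >= 89
assert capability_as_int((8, 0)) < 89
```

Either form of the check works here; tuple comparison against (8, 9) and the numeric >= 89 test agree for current compute capabilities.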


bool getEnvEnablePDL() {
  static bool const kEnablePDL = []() -> bool {
    auto const val = getIntEnv("TRTLLM_ENABLE_PDL");

Does FlashInfer have no env var for enabling PDL? TRTLLM_ENABLE_PDL sounds pretty weird in FlashInfer

Contributor Author

@trevor-m trevor-m Apr 21, 2026


Oops, I had made a change to use a bool flag enable_pdl to match other Flashinfer APIs instead of the env var, but I forgot to push it. Please check again, thanks for reviewing!

Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 2

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
flashinfer/comm/trtllm_moe_alltoall.py (1)

119-131: ⚠️ Potential issue | 🟡 Minor

Document use_low_precision in the custom-op wrapper docstring.

The signature now accepts use_low_precision, but the Args block jumps from payload_in_workspace to enable_pdl. Add the new parameter description to keep generated API docs accurate.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/comm/trtllm_moe_alltoall.py` around lines 119 - 131, The docstring
Args block for the custom-op wrapper is missing the new parameter description
for use_low_precision; update the docstring in trtllm_moe_alltoall.py (the
wrapper function that now accepts use_low_precision) by inserting a one-line
entry describing use_low_precision (e.g., boolean to enable reduced-precision
computation) between payload_in_workspace and enable_pdl so the generated API
docs include it and match the function signature.
csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu (1)

553-574: ⚠️ Potential issue | 🟡 Minor

Commit the clang-format output for the launcher changes.

Pre-commit is failing because clang-format rewrote argument wrapping and general C++ formatting in this file. Please run formatting and commit the result.

Also applies to: 1102-1117, 1184-1195, 1228-1230

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In
`@csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu`
around lines 553 - 574, The diff shows formatting changes to kernel launcher
calls (SWITCH_TOP_K, moeA2ADispatchKernel, launchWithPdlWhenEnabled) that
clang-format applied but were not committed; run your repository's clang-format
configuration over this file (and the other affected ranges around the
SWITCH_TOP_K blocks at the noted locations), stage the resulting formatted
file(s), and commit so pre-commit no longer fails—ensure the argument wrapping
for the launchWithPdlWhenEnabled calls and the surrounding SWITCH_TOP_K blocks
(using WarpPolicy/BlockPolicy and TOP_K) match the clang-formatted output before
pushing.
♻️ Duplicate comments (1)
csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu (1)

292-309: ⚠️ Potential issue | 🔴 Critical

Move PDL completion triggers after publishing the state consumed downstream.

The early cudaTriggerProgrammaticLaunchCompletion() calls still let dependent launches start before required data is visible: prepare-dispatch triggers before zeroing counters/incrementing flag_val, dispatch triggers before publishing recv_counters/completion flags, prepare-combine triggers before flag_val/payload staging, and sanitize triggers before invalid expert IDs are written. Move each trigger after the relevant writes and a system-scope fence. This was already raised on an earlier revision and still appears present.

Also applies to: 320-491, 918-971, 1199-1220

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In
`@csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu`
around lines 292 - 309, In moeA2APrepareDispatchKernel the
cudaTriggerProgrammaticLaunchCompletion() (and similar triggers in the other
kernels called out) fires before the kernel publishes state (zeroing
send_counters/local_token_counter and incrementing *flag_val_ptr), which can
allow dependent PDL launches to observe stale data; fix by moving the
cudaTriggerProgrammaticLaunchCompletion() call to after all writes that
consumers depend on (e.g., after send_counters[idx]=0, *local_token_counter=0
and *flag_val_ptr++ for moeA2APrepareDispatchKernel) and insert a system-scoped
fence (e.g., asm volatile("membar.sys" ::: "memory") or the appropriate CUDA
API) immediately before the trigger so the published state is visible to
dependent programmatic launches; apply the same pattern to the other kernels
referenced (prepare-dispatch, dispatch, prepare-combine, sanitize) so triggers
occur only after their consumer-visible writes.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@csrc/nv_internal/tensorrt_llm/common/envUtils.h`:
- Around line 108-125: Run clang-format and commit the formatted version of the
helper function launchWithPdlWhenEnabled in envUtils.h: apply your project's
clang-format configuration to reformat the include order, signature/wrapping and
function body (the cudaLaunchConfig_t setup, attrs array, and cudaLaunchKernelEx
call), then stage and commit the changed file so the pre-commit CI hook stops
failing.

In
`@csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu`:
- Around line 874-895: The ThreadingPolicy::sync() inside the for-loop in
vectorized_quant_impl is divergent and can deadlock because some threads exit
the loop early; replace the variable-length loop with a fixed iteration count
computed from num_elements and stride so every thread executes the same number
of iterations, compute e = ThreadingPolicy::offset()*VEC_SIZE + iter*stride each
iteration, then guard loads/stores with an if (e < num_elements) so only valid
lanes touch src/dst; keep ThreadingPolicy::sync() as a per-iteration barrier
(not conditional) and leave vec_t, in_vec.load, vec_convert and out_vec.store
unchanged.

---

Outside diff comments:
In
`@csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu`:
- Around line 553-574: The diff shows formatting changes to kernel launcher
calls (SWITCH_TOP_K, moeA2ADispatchKernel, launchWithPdlWhenEnabled) that
clang-format applied but were not committed; run your repository's clang-format
configuration over this file (and the other affected ranges around the
SWITCH_TOP_K blocks at the noted locations), stage the resulting formatted
file(s), and commit so pre-commit no longer fails—ensure the argument wrapping
for the launchWithPdlWhenEnabled calls and the surrounding SWITCH_TOP_K blocks
(using WarpPolicy/BlockPolicy and TOP_K) match the clang-formatted output before
pushing.

In `@flashinfer/comm/trtllm_moe_alltoall.py`:
- Around line 119-131: The docstring Args block for the custom-op wrapper is
missing the new parameter description for use_low_precision; update the
docstring in trtllm_moe_alltoall.py (the wrapper function that now accepts
use_low_precision) by inserting a one-line entry describing use_low_precision
(e.g., boolean to enable reduced-precision computation) between
payload_in_workspace and enable_pdl so the generated API docs include it and
match the function signature.

---

Duplicate comments:
In
`@csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu`:
- Around line 292-309: In moeA2APrepareDispatchKernel the
cudaTriggerProgrammaticLaunchCompletion() (and similar triggers in the other
kernels called out) fires before the kernel publishes state (zeroing
send_counters/local_token_counter and incrementing *flag_val_ptr), which can
allow dependent PDL launches to observe stale data; fix by moving the
cudaTriggerProgrammaticLaunchCompletion() call to after all writes that
consumers depend on (e.g., after send_counters[idx]=0, *local_token_counter=0
and *flag_val_ptr++ for moeA2APrepareDispatchKernel) and insert a system-scoped
fence (e.g., asm volatile("membar.sys" ::: "memory") or the appropriate CUDA
API) immediately before the trigger so the published state is visible to
dependent programmatic launches; apply the same pattern to the other kernels
referenced (prepare-dispatch, dispatch, prepare-combine, sanitize) so triggers
occur only after their consumer-visible writes.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: a2a02583-0804-4b03-9257-1bc026ecdfbd

📥 Commits

Reviewing files that changed from the base of the PR and between 752476d and 3dca42c.

📒 Files selected for processing (6)
  • csrc/nv_internal/tensorrt_llm/common/envUtils.h
  • csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu
  • csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.h
  • csrc/trtllm_moe_alltoall.cu
  • flashinfer/comm/trtllm_moe_alltoall.py
  • tests/comm/test_trtllm_moe_alltoall.py
🚧 Files skipped from review as they are similar to previous changes (3)
  • csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.h
  • tests/comm/test_trtllm_moe_alltoall.py
  • csrc/trtllm_moe_alltoall.cu

Comment on lines +108 to +125
template <typename KernelFn, typename... Args>
inline void launchWithPdlWhenEnabled(char const* name, bool enable_pdl, KernelFn kernelFn,
                                     dim3 grid, dim3 block, size_t dynamicShmSize,
                                     cudaStream_t stream, Args&&... args)
{
  cudaLaunchConfig_t kernelConfig;
  kernelConfig.gridDim = grid;
  kernelConfig.blockDim = block;
  kernelConfig.dynamicSmemBytes = dynamicShmSize;
  kernelConfig.stream = stream;
  cudaLaunchAttribute attrs[1];
  attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
  attrs[0].val.programmaticStreamSerializationAllowed = enable_pdl;
  kernelConfig.attrs = attrs;
  kernelConfig.numAttrs = 1;
  cudaError_t e = cudaLaunchKernelEx(&kernelConfig, kernelFn, std::forward<Args>(args)...);
  FLASHINFER_CHECK(e == cudaSuccess, "cudaLaunchKernelEx (", name,
                   ") failed: ", cudaGetErrorString(e));
}
Contributor


⚠️ Potential issue | 🟡 Minor

Commit the clang-format output for this helper.

Pre-commit is failing because clang-format rewrote this file’s include order/signature/body wrapping. Please run formatting and commit the result so CI can pass.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@csrc/nv_internal/tensorrt_llm/common/envUtils.h` around lines 108 - 125, Run
clang-format and commit the formatted version of the helper function
launchWithPdlWhenEnabled in envUtils.h: apply your project's clang-format
configuration to reformat the include order, signature/wrapping and function
body (the cudaLaunchConfig_t setup, attrs array, and cudaLaunchKernelEx call),
then stage and commit the changed file so the pre-commit CI hook stops failing.

Comment on lines +874 to +895
template <int VEC_SIZE, typename ThreadingPolicy, typename SrcT, typename DstT>
__device__ void vectorized_quant_impl(DstT* dst, SrcT const* src, int num_elements) {
  using flashinfer::vec_t;

  int const stride = ThreadingPolicy::stride() * VEC_SIZE;

  for (int e = ThreadingPolicy::offset() * VEC_SIZE; e < num_elements; e += stride) {
    vec_t<SrcT, VEC_SIZE> in_vec;
    in_vec.load(src + e);

    // Sync to ensure all threads have loaded their input vectors before any thread starts
    // writing output. This avoids write-after-read hazards in the FP8 in-place case where the
    // output of this kernel is read by the next iteration as input. Without this sync, some
    // threads might start writing their output (DstT) before other threads have loaded their
    // input (SrcT), causing the load to read partially updated data.
    ThreadingPolicy::sync();

    vec_t<DstT, VEC_SIZE> out_vec;
    vec_convert(out_vec, in_vec);
    out_vec.store(dst + e);
  }
}
Contributor


⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

🏁 Script executed:

find . -name "moeAlltoAllKernels.cu" -type f

Repository: flashinfer-ai/flashinfer

Length of output: 150


🏁 Script executed:

cd csrc && find . -path "*nv_internal/tensorrt_llm*" -name "*.cu" | head -20

Repository: flashinfer-ai/flashinfer

Length of output: 1972


🏁 Script executed:

rg "vectorized_quant_impl" --type cu --type cuh -A 30 | head -100

Repository: flashinfer-ai/flashinfer

Length of output: 94


🏁 Script executed:

rg "class ThreadingPolicy\|struct ThreadingPolicy" --type cu --type cuh -A 10

Repository: flashinfer-ai/flashinfer

Length of output: 94


🏁 Script executed:

rg "ThreadingPolicy.*sync\|sync.*ThreadingPolicy" --type cu --type cuh -B 2 -A 2

Repository: flashinfer-ai/flashinfer

Length of output: 94


🏁 Script executed:

wc -l csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu

Repository: flashinfer-ai/flashinfer

Length of output: 153


🏁 Script executed:

sed -n '870,900p' csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu

Repository: flashinfer-ai/flashinfer

Length of output: 1472


🏁 Script executed:

rg "ThreadingPolicy" csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu -B 5 -A 10 | head -100

Repository: flashinfer-ai/flashinfer

Length of output: 4615


🏁 Script executed:

rg "struct.*Threading\|class.*Threading" csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/ -A 15

Repository: flashinfer-ai/flashinfer

Length of output: 50


🏁 Script executed:

rg "ThreadingPolicy::sync\|__syncthreads" csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu -B 3 -A 1

Repository: flashinfer-ai/flashinfer

Length of output: 50


🏁 Script executed:

sed -n '800,880p' csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu | head -100

Repository: flashinfer-ai/flashinfer

Length of output: 4062


🏁 Script executed:

rg "for.*offset\(\)" csrc/nv_internal/tensorrt_llm/kernels/ -B 2 -A 8 | head -80

Repository: flashinfer-ai/flashinfer

Length of output: 5730


🏁 Script executed:

sed -n '860,920p' csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu

Repository: flashinfer-ai/flashinfer

Length of output: 3065


🏁 Script executed:

rg "vectorized_quant<" csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu -B 5 -A 5

Repository: flashinfer-ai/flashinfer

Length of output: 1363


🏁 Script executed:

rg "struct.*ThreadingPolicy\|class.*ThreadingPolicy" csrc/nv_internal/tensorrt_llm/ -A 8

Repository: flashinfer-ai/flashinfer

Length of output: 50


🏁 Script executed:

sed -n '1,100p' csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu | grep -A 10 "ThreadingPolicy\|sync"

Repository: flashinfer-ai/flashinfer

Length of output: 50


🏁 Script executed:

rg "ThreadingPolicy" csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu -B 15 | head -80

Repository: flashinfer-ai/flashinfer

Length of output: 2724


🏁 Script executed:

sed -n '1,700p' csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu | tail -100

Repository: flashinfer-ai/flashinfer

Length of output: 4247


🏁 Script executed:

rg "moeA2APrepareCombineKernel" csrc/nv_internal/tensorrt_llm/ -A 3 | head -30

Repository: flashinfer-ai/flashinfer

Length of output: 1722


🏁 Script executed:

sed -n '1,50p' csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu

Repository: flashinfer-ai/flashinfer

Length of output: 1631


🏁 Script executed:

rg "BlockPolicy\|WarpPolicy" csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu

Repository: flashinfer-ai/flashinfer

Length of output: 50


🏁 Script executed:

grep -n "struct WarpPolicy\|struct BlockPolicy" csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu

Repository: flashinfer-ai/flashinfer

Length of output: 116


🏁 Script executed:

sed -n '300,450p' csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu

Repository: flashinfer-ai/flashinfer

Length of output: 5947


Remove divergent sync from loop in vectorized_quant_impl.

Line 889: The ThreadingPolicy::sync() call inside the loop guarded by e < num_elements can deadlock. Threads that exit the loop early (when e >= num_elements) never reach the sync, while the remaining threads wait at the barrier indefinitely. This divergence is unsafe for both __syncwarp() (WarpPolicy) and __syncthreads() (BlockPolicy).

Fix: Use fixed iteration count with inner guards
 template <int VEC_SIZE, typename ThreadingPolicy, typename SrcT, typename DstT>
 __device__ void vectorized_quant_impl(DstT* dst, SrcT const* src, int num_elements) {
   using flashinfer::vec_t;
 
   int const stride = ThreadingPolicy::stride() * VEC_SIZE;
+  int const iterations = ceilDiv(num_elements, stride);
 
-  for (int e = ThreadingPolicy::offset() * VEC_SIZE; e < num_elements; e += stride) {
+  for (int iter = 0; iter < iterations; ++iter) {
+    int const e = ThreadingPolicy::offset() * VEC_SIZE + iter * stride;
+    bool const valid = e < num_elements;
     vec_t<SrcT, VEC_SIZE> in_vec;
-    in_vec.load(src + e);
+    if (valid) {
+      in_vec.load(src + e);
+    }
 
     // Sync to ensure all threads have loaded their input vectors before any thread starts writing output.
     // This avoids write-after-read hazards in the FP8 in-place case where the output of this kernel is
     // read by the next iteration as input. Without this sync, some threads might start writing their
     // output (DstT) before other threads have loaded their input (SrcT), causing the load to read partially
     // updated data.
     ThreadingPolicy::sync();
 
-    vec_t<DstT, VEC_SIZE> out_vec;
-    vec_convert(out_vec, in_vec);
-    out_vec.store(dst + e);
+    if (valid) {
+      vec_t<DstT, VEC_SIZE> out_vec;
+      vec_convert(out_vec, in_vec);
+      out_vec.store(dst + e);
+    }
   }
 }
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In
`@csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu`
around lines 874 - 895, The ThreadingPolicy::sync() inside the for-loop in
vectorized_quant_impl is divergent and can deadlock because some threads exit
the loop early; replace the variable-length loop with a fixed iteration count
computed from num_elements and stride so every thread executes the same number
of iterations, compute e = ThreadingPolicy::offset()*VEC_SIZE + iter*stride each
iteration, then guard loads/stores with an if (e < num_elements) so only valid
lanes touch src/dst; keep ThreadingPolicy::sync() as a per-iteration barrier
(not conditional) and leave vec_t, in_vec.load, vec_convert and out_vec.store
unchanged.
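The hazard and the fix can be modeled in plain Python (a sketch, not the CUDA code; original_iteration_counts and fixed_iteration_counts are hypothetical names): with the original loop, threads reach the barrier a different number of times, while the fixed-count rewrite gives every thread the same trip count.

```python
def original_iteration_counts(num_elements, num_threads, vec_size):
    # Models the original loop: each "thread" iterates while e < num_elements,
    # so threads with larger offsets may take fewer trips (fewer barrier visits).
    stride = num_threads * vec_size
    counts = []
    for tid in range(num_threads):
        n, e = 0, tid * vec_size
        while e < num_elements:
            n += 1
            e += stride
        counts.append(n)
    return counts

def fixed_iteration_counts(num_elements, num_threads, vec_size):
    # Models the suggested fix: every thread runs ceil(num_elements / stride)
    # iterations, guarding loads/stores with e < num_elements inside the body.
    stride = num_threads * vec_size
    iterations = -(-num_elements // stride)  # ceil division
    return [iterations] * num_threads

# 100 elements, 8 threads, VEC_SIZE=4 -> stride 32.
assert len(set(original_iteration_counts(100, 8, 4))) > 1  # divergent trip counts
assert fixed_iteration_counts(100, 8, 4) == [4] * 8        # uniform trip count
```

With divergent trip counts, a barrier inside the loop body is reached a different number of times per thread, which is exactly the deadlock condition the comment describes.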

@trevor-m trevor-m force-pushed the feat/b802a3818-float32-combine-accumulators branch from 3dca42c to bfcc10f Compare April 22, 2026 18:41
Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
flashinfer/comm/trtllm_moe_alltoall.py (1)

297-310: ⚠️ Potential issue | 🟠 Major

Gate explicit enable_pdl=True with device support in all three public wrappers.

These functions only auto-detect when enable_pdl is None, allowing a caller to pass True and force PDL on unsupported devices. Implement the tri-state pattern: only False forces PDL off; None and True should enable it only when the device supports it.

Proposed fix
-    if enable_pdl is None:
-        enable_pdl = device_support_pdl(token_selected_experts.device)
+    enable_pdl = (
+        device_support_pdl(token_selected_experts.device)
+        if enable_pdl is not False
+        else False
+    )

Apply the same pattern to moe_a2a_combine (line 347) and moe_a2a_sanitize_expert_ids (line 374).

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/comm/trtllm_moe_alltoall.py` around lines 297 - 310, The wrappers
currently only auto-detect PDL when enable_pdl is None, allowing callers to
force True on unsupported devices; change the tri-state handling so only
enable_pdl == False forces PDL off, otherwise consult device_support_pdl(...) to
decide: replace the current "if enable_pdl is None: enable_pdl =
device_support_pdl(token_selected_experts.device)" logic in the public wrappers
moe_a2a_dispatch (the block calling get_moe_alltoall_module().moe_a2a_dispatch),
moe_a2a_combine, and moe_a2a_sanitize_expert_ids with something like: if
enable_pdl is False: keep False else set enable_pdl =
device_support_pdl(token_selected_experts.device) so that both None and True are
gated by device_support_pdl(token_selected_experts.device).
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@csrc/trtllm_moe_alltoall.cu`:
- Line 200: The code assigns enablePdl directly to params.enable_pdl
(params.enable_pdl = enablePdl) without validating device PDL support; add a
device-capability check before setting params.enable_pdl (or gate explicit True
in the Python wrapper) so that PDL is only enabled on supported hardware: query
the device compute capability (or reuse the existing device_support_pdl logic)
and set params.enable_pdl = true only when enablePdl is true AND
device_support_pdl(device) is true (otherwise set false); update the C++ FFI
boundary where enablePdl is consumed (the assignment to params.enable_pdl) to
perform this check, or mirror the Python pattern enable_pdl =
device_support_pdl(input.device) if enable_pdl is not False else False.

---

Outside diff comments:
In `@flashinfer/comm/trtllm_moe_alltoall.py`:
- Around line 297-310: The wrappers currently only auto-detect PDL when
enable_pdl is None, allowing callers to force True on unsupported devices;
change the tri-state handling so only enable_pdl == False forces PDL off,
otherwise consult device_support_pdl(...) to decide: replace the current "if
enable_pdl is None: enable_pdl =
device_support_pdl(token_selected_experts.device)" logic in the public wrappers
moe_a2a_dispatch (the block calling get_moe_alltoall_module().moe_a2a_dispatch),
moe_a2a_combine, and moe_a2a_sanitize_expert_ids with something like: if
enable_pdl is False: keep False else set enable_pdl =
device_support_pdl(token_selected_experts.device) so that both None and True are
gated by device_support_pdl(token_selected_experts.device).
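A minimal sketch of the tri-state resolution suggested above (resolve_enable_pdl is a hypothetical helper; device support is passed in as a plain boolean rather than queried from a device):

```python
def resolve_enable_pdl(enable_pdl, device_supports_pdl: bool) -> bool:
    # Tri-state pattern: only an explicit False forces PDL off;
    # both None (auto-detect) and True are gated by device support.
    return device_supports_pdl if enable_pdl is not False else False

assert resolve_enable_pdl(None, True) is True
assert resolve_enable_pdl(True, False) is False   # explicit True still gated
assert resolve_enable_pdl(False, True) is False   # explicit False always wins
```

This matches the pattern cited from nvfp4_quantize.py, where even an explicit True is validated against device_support_pdl before being forwarded to the kernel.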

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: d5a9e84a-34e1-4de1-8931-4570434f2f89

📥 Commits

Reviewing files that changed from the base of the PR and between 3dca42c and bfcc10f.

📒 Files selected for processing (6)
  • csrc/nv_internal/tensorrt_llm/common/envUtils.h
  • csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu
  • csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.h
  • csrc/trtllm_moe_alltoall.cu
  • flashinfer/comm/trtllm_moe_alltoall.py
  • tests/comm/test_trtllm_moe_alltoall.py
🚧 Files skipped from review as they are similar to previous changes (3)
  • csrc/nv_internal/tensorrt_llm/common/envUtils.h
  • tests/comm/test_trtllm_moe_alltoall.py
  • csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu


tl_throughput::MoeA2ADispatchParams params{};
params.one_block_per_token = tensorrt_llm::common::getEnvMoeA2AOneBlockPerToken();
params.enable_pdl = enablePdl;
Contributor


⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Search for existing CUDA compute-capability / PDL support helpers.
rg -n -C3 'device_support_pdl|ProgrammaticStreamSerialization|cudaDeviceGetAttribute|cudaDevAttrComputeCapability|cudaGetDeviceProperties|get_compute_capability|compute capability|enable_pdl|enablePdl' --glob '!**/build/**' --glob '!**/dist/**'

Repository: flashinfer-ai/flashinfer

Length of output: 50382


🏁 Script executed:

#!/bin/bash
# Look for PDL-related device support checks or compute capability guards in C++
rg -n 'compute.capability|sm_\d+|CC_|cuda.*Capability|supports.*pdl|pdl.*support' --type cpp --type cuda

# Also check how trtllm_moe_alltoall passes enablePdl downstream
echo "=== Searching for where enablePdl is used in trtllm_moe_alltoall.cu ==="
rg -n -A5 -B2 'enablePdl' csrc/trtllm_moe_alltoall.cu | head -100

Repository: flashinfer-ai/flashinfer

Length of output: 2933


🏁 Script executed:

#!/bin/bash
# Find where MoeA2ADispatchParams is defined and used
echo "=== Searching for MoeA2ADispatchParams definition ==="
rg -n 'struct MoeA2ADispatchParams|class MoeA2ADispatchParams' --type-list | head -20

# Find references to tl_throughput in the codebase
echo "=== Searching for tl_throughput includes and usage ==="
rg -n 'tl_throughput' csrc/trtllm_moe_alltoall.cu | head -30

# Search for actual cudaLaunchKernelEx or cudaLaunchAttributeProgrammaticStreamSerialization in moe files
echo "=== Searching for CUDA launch with PDL in MoE-related files ==="
rg -n 'cudaLaunchKernelEx|cudaLaunchAttributeProgrammaticStreamSerialization|cudaLaunchAttribute' csrc/trtllm*.cu | head -40

Repository: flashinfer-ai/flashinfer

Length of output: 1616


🏁 Script executed:

#!/bin/bash
# Look at the actual usage in moe_alltoall to see if there's any device checks
echo "=== Full context around MoeA2ADispatchParams usage ==="
sed -n '195,240p' csrc/trtllm_moe_alltoall.cu

echo ""
echo "=== Search for any device/capability checks in trtllm_moe_alltoall.cu ==="
rg -n 'cudaGetDevice|getDevice|device_support|compute_capability|capability|SM|Attribute' csrc/trtllm_moe_alltoall.cu

echo ""
echo "=== Check what Python-side wrapper does with enablePdl ==="
rg -n 'enablePdl|enable_pdl' flashinfer/flashinfer_alltoall.py

Repository: flashinfer-ai/flashinfer

Length of output: 2841


🏁 Script executed:

#!/bin/bash
# Search for Python wrappers that call into this C++ function
echo "=== Looking for Python files calling moeA2ADispatchOp ==="
find . -name "*.py" -type f | xargs grep -l 'moeA2ADispatchOp\|moe_a2a' 2>/dev/null | head -10

# Check the file structure to understand where Python bindings are
echo ""
echo "=== Search for Python binding files ==="
find . -name "*alltoall*.py" -type f

# Look at imports and comments in trtllm_moe_alltoall.cu for context
echo ""
echo "=== Top of trtllm_moe_alltoall.cu for includes and comments ==="
head -50 csrc/trtllm_moe_alltoall.cu

Repository: flashinfer-ai/flashinfer

Length of output: 2296


🏁 Script executed:

#!/bin/bash
# Check the Python binding to see how enablePdl is used
echo "=== Python wrapper for moeA2ADispatchOp ==="
head -100 flashinfer/comm/trtllm_moe_alltoall.py

echo ""
echo "=== Search for enablePdl handling in Python ==="
rg -n 'enablePdl|enable_pdl|device_support' flashinfer/comm/trtllm_moe_alltoall.py

echo ""
echo "=== Look for device support checking in flashinfer ==="
rg -n 'device_support_pdl|supports.*pdl|compute.*capability' flashinfer/ --type py

Repository: flashinfer-ai/flashinfer

Length of output: 26205


🏁 Script executed:

#!/bin/bash
# Check the implementation of device_support_pdl
echo "=== device_support_pdl implementation ==="
sed -n '687,710p' flashinfer/utils.py

echo ""
echo "=== Check if any C++ code also has device capability checks before PDL usage ==="
rg -B5 -A5 'cudaLaunchAttributeProgrammaticStreamSerialization' csrc/fmhaReduction.cu | head -40

Repository: flashinfer-ai/flashinfer

Length of output: 1153


🏁 Script executed:

```shell
#!/bin/bash
# Check how device_support_pdl is used in the Python wrapper more carefully
echo "=== Check enable_pdl parameter handling in trtllm_moe_alltoall.py ==="
sed -n '270,320p' flashinfer/comm/trtllm_moe_alltoall.py

echo ""
echo "=== Compare with nvfp4_quantize pattern ==="
sed -n '1234,1250p' flashinfer/quantization/kernels/nvfp4_quantize.py

echo ""
echo "=== Check if MoeA2ADispatchParams is used with any validation in downstream code ==="
rg -n 'enable_pdl.*false|enable_pdl.*False|enable_pdl.*0' csrc/ --type cpp
```

Repository: flashinfer-ai/flashinfer

Length of output: 3616


Add device capability validation at the C++ FFI boundary.

The C++ code assigns `enablePdl` directly to `params.enable_pdl` without verifying that the device supports PDL. If a caller explicitly passes `enable_pdl=True` on hardware with compute capability < 9, the CUDA launch will fail. Unlike `flashinfer/quantization/kernels/nvfp4_quantize.py`, which gates even an explicit `True` through `device_support_pdl()`, the Python wrapper here only auto-detects when `enable_pdl` is `None`.

Add a device capability check in C++ before setting the PDL launch attribute, or gate an explicit `True` in the Python wrapper to match the pattern in `nvfp4_quantize.py` (line 1246): `enable_pdl = device_support_pdl(input.device) if enable_pdl is not False else False`.
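A minimal sketch of the wrapper-side gating, with `device_support_pdl` stubbed out (the real helper lives in `flashinfer/utils.py` and queries the device's compute capability; PDL requires SM 9.0+). The name `resolve_enable_pdl` and the `device_major` parameter are illustrative, not flashinfer API:

```python
def device_support_pdl(device_major: int) -> bool:
    # Stub: the real helper inspects the CUDA device's compute capability.
    # PDL (programmatic dependent launch) requires SM 9.0 (Hopper) or newer.
    return device_major >= 9


def resolve_enable_pdl(enable_pdl, device_major: int) -> bool:
    # Gate even an explicit True through the capability check, so a caller
    # passing enable_pdl=True on pre-Hopper hardware cannot trigger a
    # failing CUDA launch. Mirrors the nvfp4_quantize.py pattern:
    #   enable_pdl = device_support_pdl(dev) if enable_pdl is not False else False
    if enable_pdl is False:
        return False
    return device_support_pdl(device_major)


# Auto-detect (None) and explicit True both respect hardware support:
assert resolve_enable_pdl(None, 9) is True    # Hopper: auto-enabled
assert resolve_enable_pdl(True, 8) is False   # pre-Hopper: forced off
assert resolve_enable_pdl(False, 9) is False  # explicit opt-out wins
```

The key design point is that `False` is the only value that bypasses the capability query, so the default path and an over-eager `True` are both safe.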

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@csrc/trtllm_moe_alltoall.cu` at line 200, the code assigns enablePdl directly
to params.enable_pdl (params.enable_pdl = enablePdl) without validating device
PDL support. Add a device-capability check before setting params.enable_pdl (or
gate an explicit True in the Python wrapper) so that PDL is only enabled on
supported hardware: query the device compute capability (or reuse the existing
device_support_pdl logic) and set params.enable_pdl = true only when enablePdl
is true AND device_support_pdl(device) is true; otherwise set it to false.
Update the C++ FFI boundary where enablePdl is consumed (the assignment to
params.enable_pdl) to perform this check, or mirror the Python pattern
enable_pdl = device_support_pdl(input.device) if enable_pdl is not False else
False.

@aleozlx aleozlx added the run-ci label Apr 24, 2026
@aleozlx
Copy link
Copy Markdown
Collaborator

aleozlx commented Apr 24, 2026

/bot run

@flashinfer-bot
Copy link
Copy Markdown
Collaborator

GitLab MR !599 has been created, and the CI pipeline #49348350 is currently running. I'll report back once the pipeline job completes.

@aleozlx
Copy link
Copy Markdown
Collaborator

aleozlx commented Apr 24, 2026

pls address failed pre-commit check
