Commit e3f751e

Merge branch 'main' of github.com:ishovkun/flashinfer-dev
2 parents: 5b5756d + 3265bd5

13 files changed: 2283 additions & 513 deletions

File tree

.github/workflows/pr-test.yml

Lines changed: 396 additions & 35 deletions
Large diffs are not rendered by default.

benchmarks/README.md

Lines changed: 48 additions & 0 deletions
@@ -29,6 +29,8 @@ Currently supports testing attention, gemm, fused MOE, normalization, and quanti
  - `trtllm_fp8_block_scale_moe` - MOE with FP8 quantized weights and block-wise scaling.
  - `trtllm_fp8_per_tensor_scale_moe` - MOE with FP8 quantized weights and per-tensor scaling.
  - `cutlass_fused_moe` - CUTLASS fused MoE (base/fp8/nvfp4 variants with optional TP/EP)
+- MOE Communication:
+  - `moe_a2a_dispatch_combine` - MoE All-to-All dispatch + combine benchmark for multi-GPU expert-parallel inference. Requires `mpirun` for multi-GPU execution. Supports optional quantization (FP8, NVFP4, FP8 block-scale) and real MoE kernel computation.
- Norm:
  - `rmsnorm` - Root Mean Square Layer Normalization.
  - `rmsnorm_quant` - RMSNorm with FP8 quantized output.
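
The communication pattern this new routine measures can be sketched with plain `torch.distributed` collectives. The snippet below is conceptual only, not the FlashInfer `moe_a2a` kernel: real MoE routing exchanges variable-sized, expert-indexed token slices chosen by top-k routing, while this sketch sends equal slices to every rank, and `fake_expert` is a hypothetical stand-in for the expert FFN (mirroring the benchmark's fake-computation mode).

```python
# Conceptual sketch of one MoE All-to-All dispatch + combine round using plain
# torch.distributed collectives. NOT the FlashInfer moe_a2a kernel.
import torch
import torch.distributed as dist

def fake_expert(x: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for the local expert FFN.
    return x * 2.0

def dispatch_combine(hidden: torch.Tensor) -> torch.Tensor:
    # hidden: [num_tokens, hidden_size]; assumes an initialized process group
    # and num_tokens divisible by the world size.
    recv = torch.empty_like(hidden)
    # Dispatch: scatter token slices to the ranks that own their experts.
    dist.all_to_all_single(recv, hidden)
    # Expert computation on the tokens this rank received.
    out = fake_expert(recv)
    # Combine: route results back to the ranks that own the original tokens.
    combined = torch.empty_like(out)
    dist.all_to_all_single(combined, out)
    return combined
```

The benchmark times this round trip, optionally with payload quantization and, under `--real_math`, real expert kernels in place of the fake computation.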
@@ -238,6 +240,50 @@ Notes:
- FP8 MOE kernels require integer values for group parameters, while FP4 MOE kernels accept optional values.
- CUTLASS fused MoE (`cutlass_fused_moe`) ignores `--routing_method`, `--n_group`, and `--topk_group`; it computes routing via softmax+top-k internally from the provided logits.

+### MoE Communication Flags (moe_a2a_dispatch_combine)
+The `moe_a2a_dispatch_combine` routine benchmarks MoE All-to-All communication for multi-GPU expert-parallel inference. It must be launched with `mpirun`.
+
+| Flag | Description |
+|--------------------------|-------------------------------------------------------------------------------------------------------------|
+| `--num_tokens` | Number of tokens per rank (local batch size) |
+| `--hidden_size` | Hidden dimension size |
+| `--num_experts` | Total number of experts across all ranks |
+| `--top_k` | Number of experts each token is routed to |
+| `--input_dtype` | Data type of the hidden-states payload: `bfloat16` (default) or `float16` |
+| `--quant_dtype` | Quantization format: `fp8` (per-tensor), `nvfp4` (block-scale FP4), or `fp8_block_scale` (block-scale FP8) |
+| `--real_math` | Run actual MoE kernels instead of fake computation. Requires `--intermediate_size`, and `--quant_dtype` must be `nvfp4` or `fp8_block_scale` |
+| `--intermediate_size` | Intermediate FFN size. Required if `--real_math` is set |
+| `--max_num_tokens` | Max tokens per rank for workspace allocation. Defaults to `--num_tokens` |
+| `--validate` | Run correctness validation before benchmarking, using a deterministic fake MoE |
+| `--per_phase_timing` | Enable per-phase timing (dispatch/combine/moe_kernel). Adds slight overhead from CUDA events |
+| `--nvtx` | Enable NVTX markers for Nsight Systems profiling |
+
+**Launch Examples:**
+```bash
+# Basic (no quantization)
+mpirun -np 8 python benchmarks/flashinfer_benchmark.py \
+    --routine moe_a2a_dispatch_combine \
+    --num_tokens 1024 --hidden_size 7168 --num_experts 256 --top_k 8
+
+# With FP8 quantization
+mpirun -np 8 python benchmarks/flashinfer_benchmark.py \
+    --routine moe_a2a_dispatch_combine \
+    --num_tokens 1024 --hidden_size 7168 --num_experts 256 --top_k 8 \
+    --quant_dtype fp8
+
+# With NVFP4 quantization and real MoE kernels
+mpirun -np 8 python benchmarks/flashinfer_benchmark.py \
+    --routine moe_a2a_dispatch_combine \
+    --num_tokens 1024 --hidden_size 7168 --num_experts 256 --top_k 8 \
+    --quant_dtype nvfp4 --real_math --intermediate_size 18432
+
+# With validation and per-phase timing
+mpirun -np 8 python benchmarks/flashinfer_benchmark.py \
+    --routine moe_a2a_dispatch_combine \
+    --num_tokens 1024 --hidden_size 7168 --num_experts 256 --top_k 8 \
+    --validate --per_phase_timing
+```
+
### Norm Flags
| Flag | Description |
|--------------------------|-------------------------------------------------------------------------------------------------------------|
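
The two instrumentation flags above (`--per_phase_timing` and `--nvtx`) map onto standard CUDA mechanisms. The sketch below shows how per-phase timing with CUDA events and NVTX ranges for Nsight Systems are typically wired up in PyTorch; the `dispatch`/`moe_kernel`/`combine` callables are hypothetical placeholders, not the benchmark's internals.

```python
# Sketch of the mechanisms behind --per_phase_timing (CUDA events) and --nvtx
# (NVTX ranges), using standard PyTorch APIs. The three phase callables are
# hypothetical placeholders.
import torch

def timed_phases(dispatch, moe_kernel, combine):
    # One event per phase boundary; enable_timing is what makes
    # elapsed_time() legal and is the source of the "slight overhead".
    evts = [torch.cuda.Event(enable_timing=True) for _ in range(4)]

    evts[0].record()
    torch.cuda.nvtx.range_push("dispatch")   # shows up as a range in Nsight
    dispatch()
    torch.cuda.nvtx.range_pop()
    evts[1].record()

    torch.cuda.nvtx.range_push("moe_kernel")
    moe_kernel()
    torch.cuda.nvtx.range_pop()
    evts[2].record()

    torch.cuda.nvtx.range_push("combine")
    combine()
    torch.cuda.nvtx.range_pop()
    evts[3].record()

    torch.cuda.synchronize()  # event timestamps are async; sync before reading
    names = ("dispatch", "moe_kernel", "combine")
    return {n: evts[i].elapsed_time(evts[i + 1]) for i, n in enumerate(names)}
```

Event pairs report GPU time per phase in milliseconds, which is why enabling them perturbs the run slightly; NVTX ranges only annotate the profiler timeline and add no measurement of their own.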
@@ -301,6 +347,7 @@ Legend:
| **trtllm_fp8_block_scale_moe** | | | | | | trtllm | trtllm | |
| **trtllm_fp8_per_tensor_scale_moe** | | | | | | trtllm | trtllm | |
| **cutlass_fused_moe** | | | | | | cutlass | cutlass | |
+| **moe_a2a_dispatch_combine** | | | | | | moe_a2a | moe_a2a | |
| **rmsnorm** | cuda | cuda | cuda | cuda | cuda | cuda | cuda | cuda |
| **rmsnorm_quant** | cuda | cuda | cuda | cuda | cuda | cuda | cuda | cuda |
| **fused_add_rmsnorm_quant** | cuda | cuda | cuda | cuda | cuda | cuda | cuda | cuda |
@@ -324,3 +371,4 @@ Backend Legend:
- trtllm-native: TensorRT-LLM (out-of-wrapper)
- cuda: FlashInfer CUDA kernels
- cute-dsl: FlashInfer CuTe-DSL kernels (Blackwell SM10.0+)
+- moe_a2a: MoE All-to-All communication (requires mpirun, Blackwell SM10.0+ with MNNVL)
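
Because the `moe_a2a` routine is launched with `mpirun` rather than `torchrun`, each process must discover its rank and bind a GPU itself before any communication. A minimal sketch, assuming Open MPI's environment variables (other MPI implementations expose different names):

```python
# Per-process GPU binding under mpirun. The OMPI_* variables are Open MPI
# conventions (an assumption about the launcher), not a FlashInfer API.
import os
import torch

rank = int(os.environ.get("OMPI_COMM_WORLD_RANK", "0"))
local_rank = int(os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK", str(rank)))

# One GPU per rank: bind before allocating tensors or creating process groups.
torch.cuda.set_device(local_rank % torch.cuda.device_count())
```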
