Skip to content

Commit 09886e3

Browse files
committed
add fused fp8 moe kernel for low-latency llm inference
1 parent 8d300e1 commit 09886e3

40 files changed

Lines changed: 6106 additions & 584 deletions

CMakeLists.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,9 @@ find_package(CUDAToolkit REQUIRED)
88
find_package(Python3 COMPONENTS Interpreter Development.Module REQUIRED)
99

1010
file(GLOB_RECURSE SOURCES "src/*/*.cu" "src/*/*.cc")
11+
file(GLOB CP_ASYNC_SOURCES "src/fuse_moe/cp_async/*.cu" "src/group_gemm/cp_async/*.cu"
12+
"src/group_gemm/cp_async/*.cc")
13+
list(APPEND SOURCES ${CP_ASYNC_SOURCES})
1114
list(FILTER SOURCES EXCLUDE REGEX ".*test.*")
1215

1316

bench/fused_moe/README.md

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
# FusedMoE Benchmark
2+
3+
This directory contains a per-tensor FP8 FusedMoE benchmark for HPC-Ops,
4+
vLLM Triton, vLLM CUTLASS, and SGLang.
5+
6+
## Requirements
7+
8+
- NVIDIA GPU with FP8 support.
9+
- CUDA, PyTorch, Triton, NumPy, `nvtx`, and `nsys`.
10+
- Built HPC-Ops, vLLM, and SGLang checkouts.
11+
12+
Set checkout roots before running:
13+
14+
```bash
15+
export HPCOPS_ROOT=/path/to/hpc-ops
16+
export VLLM_ROOT=/path/to/vllm
17+
export SGLANG_ROOT=/path/to/sglang
18+
```
19+
20+
## Usage
21+
22+
Run TP mode:
23+
24+
```bash
25+
python3 bench.py \
26+
--tp 8 --ep 1 \
27+
--gpu 0 \
28+
--backends hpcops vllm vllm_cutlass sglang
29+
```
30+
31+
Run EP mode:
32+
33+
```bash
34+
python3 bench.py \
35+
--tp 1 --ep 8 \
36+
--gpu 0 \
37+
--backends hpcops vllm vllm_cutlass sglang
38+
```
39+
40+
Run a smaller smoke test:
41+
42+
```bash
43+
python3 bench.py \
44+
--tp 8 --ep 1 \
45+
--models qwen3-235b \
46+
--bs 16 32 \
47+
--backends hpcops vllm_cutlass \
48+
--gpu 0
49+
```
50+
51+
By default, outputs are written under `./log/<tag>/`. Override this with:
52+
53+
```bash
54+
python3 bench.py --output-dir /path/to/output ...
55+
```
56+
57+
## Defaults
58+
59+
Models:
60+
61+
| Model | Experts | topk | Hidden | Intermediate |
62+
|---|---:|---:|---:|---:|
63+
| `qwen3-235b` | 128 | 8 | 4096 | 1536 |
64+
| `hunyuan-v3` | 192 | 8 | 4096 | 1536 |
65+
| `deepseek-v3` | 256 | 8 | 7168 | 2048 |
66+
67+
Shape semantics:
68+
69+
- `bs` is the kernel-visible sequence count on the measured rank.
70+
- `TP` partitions the intermediate dimension only, so `intermediate_per_rank = intermediate / TP`.
71+
- `EP` partitions experts, so `experts_per_rank = experts / EP`.
72+
- The reported `avg/group` is `bs * topk / experts_per_rank`.
73+
74+
For `TP=8 EP=1`, experts are not partitioned and the benchmark keeps the full expert set
75+
visible to the measured rank:
76+
77+
```text
78+
avg/group = bs * topk / experts
79+
```
80+
81+
For `TP=1 EP=8`, the benchmark measures one EP rank with local experts only. Routing is
82+
sampled within that local expert set, so:
83+
84+
```text
85+
experts_per_rank = experts / 8
86+
avg/group = bs * topk / experts_per_rank
87+
```
88+
89+
The EP batch range is shorter than the TP range to cover the same per-rank operator regime at
90+
comparable `avg/group` values.
91+
92+
Batch sizes:
93+
94+
| Mode | Batch sizes |
95+
|---|---|
96+
| `TP=8 EP=1` | `4 16 32 64 128 256 512 1024 2048 4096 8192 16384` |
97+
| `TP=1 EP=8` | `4 8 16 32 64 128 256 512 1024 2048` |
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
# Copyright (C) 2026 Tencent.
2+
3+
"""Backend registry."""
4+
from __future__ import annotations
5+
6+
from typing import Callable, Dict
7+
8+
from .base import Backend, BenchSpec # re-export for convenience
9+
10+
_REGISTRY: Dict[str, Callable[[], Backend]] = {}
11+
12+
13+
def register(name: str, factory: Callable[[], Backend]) -> None:
14+
if name in _REGISTRY:
15+
raise ValueError(f"backend already registered: {name}")
16+
_REGISTRY[name] = factory
17+
18+
19+
def make(name: str) -> Backend:
20+
if name not in _REGISTRY:
21+
raise KeyError(
22+
f"unknown backend: {name!r} (known: {sorted(_REGISTRY.keys())})")
23+
return _REGISTRY[name]()
24+
25+
26+
def known() -> list[str]:
27+
return sorted(_REGISTRY.keys())
28+
29+
30+
def _import_all():
31+
"""Trigger registration side effects for all known modules."""
32+
from . import hpcops # noqa: F401
33+
from . import vllm # noqa: F401
34+
from . import vllm_cutlass # noqa: F401
35+
from . import sglang # noqa: F401
36+
37+
38+
_import_all()
39+
40+
41+
__all__ = ["Backend", "BenchSpec", "register", "make", "known"]

bench/fused_moe/backends/base.py

Lines changed: 208 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,208 @@
1+
# Copyright (C) 2026 Tencent.
2+
3+
"""Shared base classes and helpers for all FusedMoE backends."""
4+
from __future__ import annotations
5+
6+
from abc import ABC, abstractmethod
7+
from dataclasses import dataclass, asdict
8+
from typing import Callable, Optional
9+
10+
import torch
11+
12+
13+
# ---------------------------------------------------------------------------
14+
# Per-cell shape spec
15+
# ---------------------------------------------------------------------------
16+
@dataclass(frozen=True)
17+
class BenchSpec:
18+
"""Per-cell input shape that backends consume.
19+
20+
Fields match the kernel-visible shape (post TP/EP simulation), not the
21+
full-model shape. The driver derives these from the user's --tp/--ep
22+
args and stores them in the JSONL output for traceability.
23+
"""
24+
num_seq: int # batch_size
25+
hidden: int # K of Gate-Up; N of Down
26+
intermediate_per_rank: int # N of one of {Gate, Up}; K of Down
27+
num_expert_local: int # experts visible to this rank
28+
num_expert_total: int # for sampling topk_ids; equals local under EP-rank-0 sim
29+
num_topk: int
30+
31+
model: str = ""
32+
tp: int = 1
33+
ep: int = 1
34+
35+
36+
# ---------------------------------------------------------------------------
37+
# Backend ABC
38+
# ---------------------------------------------------------------------------
39+
class Backend(ABC):
40+
"""Abstract backend. Subclass for each registered benchmark backend."""
41+
42+
name: str # registry key, e.g. "hpcops"
43+
44+
@abstractmethod
45+
def setup(self, spec: BenchSpec) -> Callable[[], None]:
46+
"""Build tensors and return the timed call_fn."""
47+
raise NotImplementedError
48+
49+
def cleanup(self) -> None:
50+
torch.cuda.empty_cache()
51+
52+
def extra_metadata(self) -> dict:
53+
return {}
54+
55+
56+
# ---------------------------------------------------------------------------
57+
# Shared tensor builders
58+
# ---------------------------------------------------------------------------
59+
DTYPE_FP8 = torch.float8_e4m3fn
60+
DTYPE_HALF = torch.half
61+
62+
63+
def build_fp8_weights(
64+
num_expert_local: int,
65+
intermediate_per_rank: int,
66+
hidden: int,
67+
*,
68+
seed: int = 0,
69+
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
70+
"""Build w1 (gate+up fused along N) and w2 (down) in fp8 + per-expert
71+
per-tensor scales.
72+
73+
Layouts match Triton / CUTLASS / sglang convention:
74+
w1: [E, 2N, K] (fp8) w1_scale: [E, 1, 1] (fp32)
75+
w2: [E, K, N] (fp8) w2_scale: [E, 1, 1] (fp32)
76+
77+
Returns the same 4 tensors regardless of backend; HPC reshapes/views
78+
these to its own per-expert layout in its own backend module.
79+
"""
80+
from vllm import _custom_ops as ops
81+
82+
g = torch.Generator(device="cuda").manual_seed(seed)
83+
E, N, K = num_expert_local, intermediate_per_rank, hidden
84+
85+
w1_half = torch.randn(
86+
(E, 2 * N, K), dtype=torch.float, device="cuda", generator=g,
87+
).to(DTYPE_HALF)
88+
w2_half = torch.randn(
89+
(E, K, N), dtype=torch.float, device="cuda", generator=g,
90+
).to(DTYPE_HALF)
91+
92+
w1_fp8 = torch.empty_like(w1_half, dtype=DTYPE_FP8)
93+
w2_fp8 = torch.empty_like(w2_half, dtype=DTYPE_FP8)
94+
w1_scale = torch.empty((E, 1, 1), device="cuda", dtype=torch.float32)
95+
w2_scale = torch.empty((E, 1, 1), device="cuda", dtype=torch.float32)
96+
for e in range(E):
97+
w1_fp8[e], s1 = ops.scaled_fp8_quant(w1_half[e])
98+
w2_fp8[e], s2 = ops.scaled_fp8_quant(w2_half[e])
99+
w1_scale[e, 0, 0] = s1
100+
w2_scale[e, 0, 0] = s2
101+
return w1_fp8, w2_fp8, w1_scale, w2_scale
102+
103+
104+
def build_routing(
105+
num_seq: int,
106+
num_expert_total: int,
107+
num_topk: int,
108+
*,
109+
seed: int = 0,
110+
) -> tuple[torch.Tensor, torch.Tensor]:
111+
"""Sample uniform `topk_ids` and a normalized `topk_weights`.
112+
113+
Returns:
114+
topk_ids : (num_seq, num_topk) int32, sorted along topk axis
115+
topk_w : (num_seq, num_topk) float32, softmax-normalized
116+
"""
117+
g = torch.Generator(device="cuda").manual_seed(seed)
118+
topk_ids = torch.stack([
119+
torch.sort(
120+
torch.randperm(
121+
num_expert_total, dtype=torch.int32, device="cuda",
122+
generator=g,
123+
)[:num_topk]
124+
).values
125+
for _ in range(num_seq)
126+
])
127+
topk_w = torch.softmax(
128+
torch.randn((num_seq, num_topk), dtype=torch.float32, device="cuda",
129+
generator=g),
130+
dim=-1,
131+
)
132+
return topk_ids, topk_w
133+
134+
135+
def build_activation(
136+
num_seq: int, hidden: int, *, seed: int = 0,
137+
) -> torch.Tensor:
138+
"""Build a half activation tensor."""
139+
g = torch.Generator(device="cuda").manual_seed(seed)
140+
return torch.randn(
141+
(num_seq, hidden), dtype=DTYPE_HALF, device="cuda", generator=g,
142+
) / 10
143+
144+
145+
A_SCALE_VALUE = 1e-2
146+
147+
148+
def build_a_scale() -> torch.Tensor:
149+
return torch.full((), A_SCALE_VALUE, device="cuda", dtype=torch.float32)
150+
151+
152+
# ---------------------------------------------------------------------------
153+
# Method C timing harness
154+
# ---------------------------------------------------------------------------
155+
def run_method_c(call_fn: Callable[[], None], *, n_timed: int = 52):
156+
"""Run warmup, graph capture, replay warmup, and timed graph replays."""
157+
import nvtx # imported lazily so backends that error during setup don't fail on import
158+
159+
for _ in range(3):
160+
call_fn()
161+
torch.cuda.synchronize()
162+
163+
graph = torch.cuda.CUDAGraph()
164+
with torch.cuda.graph(graph):
165+
call_fn()
166+
167+
for _ in range(3):
168+
graph.replay()
169+
torch.cuda.synchronize()
170+
171+
torch.cuda.cudart().cudaProfilerStart()
172+
for _ in range(n_timed):
173+
with nvtx.annotate("step"):
174+
graph.replay()
175+
torch.cuda.synchronize()
176+
torch.cuda.cudart().cudaProfilerStop()
177+
178+
179+
# ---------------------------------------------------------------------------
180+
# Spec serialization (worker stdin <-> driver)
181+
# ---------------------------------------------------------------------------
182+
def spec_to_argv(spec: BenchSpec) -> list[str]:
183+
"""Serialize a BenchSpec to argv (used by bench.py to invoke worker.py)."""
184+
return [
185+
"--num-seq", str(spec.num_seq),
186+
"--hidden", str(spec.hidden),
187+
"--intermediate-per-rank", str(spec.intermediate_per_rank),
188+
"--num-expert-local", str(spec.num_expert_local),
189+
"--num-expert-total", str(spec.num_expert_total),
190+
"--num-topk", str(spec.num_topk),
191+
"--model", spec.model,
192+
"--tp", str(spec.tp),
193+
"--ep", str(spec.ep),
194+
]
195+
196+
197+
def spec_from_args(args) -> BenchSpec:
198+
return BenchSpec(
199+
num_seq=args.num_seq,
200+
hidden=args.hidden,
201+
intermediate_per_rank=args.intermediate_per_rank,
202+
num_expert_local=args.num_expert_local,
203+
num_expert_total=args.num_expert_total,
204+
num_topk=args.num_topk,
205+
model=args.model,
206+
tp=args.tp,
207+
ep=args.ep,
208+
)

0 commit comments

Comments
 (0)