|
| 1 | +# Copyright (C) 2026 Tencent. |
| 2 | + |
| 3 | +"""Shared base classes and helpers for all FusedMoE backends.""" |
| 4 | +from __future__ import annotations |
| 5 | + |
| 6 | +from abc import ABC, abstractmethod |
| 7 | +from dataclasses import dataclass, asdict |
| 8 | +from typing import Callable, Optional |
| 9 | + |
| 10 | +import torch |
| 11 | + |
| 12 | + |
| 13 | +# --------------------------------------------------------------------------- |
| 14 | +# Per-cell shape spec |
| 15 | +# --------------------------------------------------------------------------- |
@dataclass(frozen=True)
class BenchSpec:
    """Immutable description of one benchmark cell's input shape.

    The values describe the shape the kernel actually sees (after TP/EP
    simulation), not the full-model shape. The driver computes them from the
    user's --tp/--ep arguments and records them in the JSONL output so that
    every result can be traced back to its configuration.

    Attributes:
        num_seq: Batch size.
        hidden: K dimension of the Gate-Up GEMM; N dimension of the Down GEMM.
        intermediate_per_rank: N dimension of one of {Gate, Up}; K dimension
            of the Down GEMM.
        num_expert_local: Number of experts visible to this rank.
        num_expert_total: Expert count used when sampling topk_ids; equals
            ``num_expert_local`` under the EP-rank-0 simulation.
        num_topk: Number of experts selected per row of the batch.
        model: Model label (defaults to the empty string).
        tp: Value of the --tp argument the shape was derived from.
        ep: Value of the --ep argument the shape was derived from.
    """

    # Kernel-visible shape (required).
    num_seq: int
    hidden: int
    intermediate_per_rank: int
    num_expert_local: int
    num_expert_total: int
    num_topk: int

    # Provenance of the shape (optional).
    model: str = ""
    tp: int = 1
    ep: int = 1
| 34 | + |
| 35 | + |
| 36 | +# --------------------------------------------------------------------------- |
| 37 | +# Backend ABC |
| 38 | +# --------------------------------------------------------------------------- |
class Backend(ABC):
    """Interface implemented by every registered benchmark backend.

    Concrete backends subclass this, set ``name`` and implement ``setup``.
    ``cleanup`` and ``extra_metadata`` have default implementations that
    subclasses may override.
    """

    # Registry key identifying the backend, e.g. "hpcops".
    name: str

    @abstractmethod
    def setup(self, spec: BenchSpec) -> Callable[[], None]:
        """Build the tensors for ``spec`` and return the callable to be timed."""
        raise NotImplementedError

    def cleanup(self) -> None:
        """Release cached CUDA memory through ``torch.cuda.empty_cache()``."""
        torch.cuda.empty_cache()

    def extra_metadata(self) -> dict:
        """Return backend-specific metadata; the base class has none (``{}``)."""
        return dict()
| 54 | + |
| 55 | + |
| 56 | +# --------------------------------------------------------------------------- |
| 57 | +# Shared tensor builders |
| 58 | +# --------------------------------------------------------------------------- |
# dtype of the quantized weight buffers allocated by build_fp8_weights.
DTYPE_FP8 = torch.float8_e4m3fn
# dtype of the pre-quantization weights and of the activation tensor.
DTYPE_HALF = torch.half
| 61 | + |
| 62 | + |
def build_fp8_weights(
    num_expert_local: int,
    intermediate_per_rank: int,
    hidden: int,
    *,
    seed: int = 0,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Create random fp8 expert weights plus one quantization scale per expert.

    With E = num_expert_local, N = intermediate_per_rank and K = hidden, the
    layouts follow the Triton / CUTLASS / sglang convention:
        w1: [E, 2N, K] (fp8), gate and up fused along N    w1_scale: [E, 1, 1] (fp32)
        w2: [E, K, N]  (fp8), down projection              w2_scale: [E, 1, 1] (fp32)

    Every backend receives these same four tensors; HPC reshapes/views them
    into its own per-expert layout inside its backend module.

    Args:
        num_expert_local: Number of experts E.
        intermediate_per_rank: Intermediate size N.
        hidden: Hidden size K.
        seed: Seed of the CUDA generator used for both weight tensors.

    Returns:
        ``(w1_fp8, w2_fp8, w1_scale, w2_scale)``, all on the CUDA device.
    """
    from vllm import _custom_ops as ops

    n_experts = num_expert_local
    rng = torch.Generator(device="cuda").manual_seed(seed)

    def random_half(shape):
        # Sampled in fp32 and then narrowed to half; w1 is drawn before w2
        # from the same generator, so the call order below matters.
        sampled = torch.randn(
            shape, dtype=torch.float, device="cuda", generator=rng,
        )
        return sampled.to(DTYPE_HALF)

    w1_half = random_half((n_experts, 2 * intermediate_per_rank, hidden))
    w2_half = random_half((n_experts, hidden, intermediate_per_rank))

    w1_fp8 = torch.empty_like(w1_half, dtype=DTYPE_FP8)
    w2_fp8 = torch.empty_like(w2_half, dtype=DTYPE_FP8)

    scale_shape = (n_experts, 1, 1)
    w1_scale = torch.empty(scale_shape, device="cuda", dtype=torch.float32)
    w2_scale = torch.empty(scale_shape, device="cuda", dtype=torch.float32)

    # Quantize expert by expert so each expert gets the scale returned for
    # its own weight matrix.
    for idx in range(n_experts):
        w1_fp8[idx], w1_scale[idx, 0, 0] = ops.scaled_fp8_quant(w1_half[idx])
        w2_fp8[idx], w2_scale[idx, 0, 0] = ops.scaled_fp8_quant(w2_half[idx])

    return w1_fp8, w2_fp8, w1_scale, w2_scale
| 102 | + |
| 103 | + |
def build_routing(
    num_seq: int,
    num_expert_total: int,
    num_topk: int,
    *,
    seed: int = 0,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Sample uniform `topk_ids` and a normalized `topk_weights`.

    Each row of `topk_ids` is the sorted prefix of an independent random
    permutation of ``range(num_expert_total)``, so the ids within a row are
    distinct. The weights are a softmax over standard-normal samples.

    Args:
        num_seq: Number of rows (batch size); may be 0.
        num_expert_total: Size of the expert id range to sample from.
        num_topk: Experts selected per row; must not exceed
            ``num_expert_total``.
        seed: Seed of the CUDA generator shared by both samplers.

    Returns:
        topk_ids : (num_seq, num_topk) int32, sorted along topk axis
        topk_w   : (num_seq, num_topk) float32, softmax-normalized

    Raises:
        ValueError: If ``num_seq`` is negative or ``num_topk`` is outside
            ``[0, num_expert_total]``.
    """
    # Validate before touching CUDA. Slicing a permutation with an oversized
    # num_topk would silently yield fewer id columns than weight columns.
    if num_seq < 0:
        raise ValueError(f"num_seq must be non-negative, got {num_seq}")
    if not 0 <= num_topk <= num_expert_total:
        raise ValueError(
            f"num_topk must be in [0, num_expert_total={num_expert_total}], "
            f"got {num_topk}"
        )

    g = torch.Generator(device="cuda").manual_seed(seed)

    if num_seq == 0:
        # torch.stack rejects an empty list, so build the empty result directly.
        topk_ids = torch.empty((0, num_topk), dtype=torch.int32, device="cuda")
    else:
        topk_ids = torch.stack([
            torch.sort(
                torch.randperm(
                    num_expert_total, dtype=torch.int32, device="cuda",
                    generator=g,
                )[:num_topk]
            ).values
            for _ in range(num_seq)
        ])

    # Drawn after the ids from the same generator, matching the id/weight
    # sampling order callers have always observed for a given seed.
    topk_w = torch.softmax(
        torch.randn(
            (num_seq, num_topk), dtype=torch.float32, device="cuda",
            generator=g,
        ),
        dim=-1,
    )
    return topk_ids, topk_w
| 133 | + |
| 134 | + |
def build_activation(
    num_seq: int, hidden: int, *, seed: int = 0,
) -> torch.Tensor:
    """Return a random half-precision activation of shape (num_seq, hidden).

    The tensor is sampled on the CUDA device from a standard normal using a
    generator seeded with ``seed``, then divided by 10.
    """
    rng = torch.Generator(device="cuda")
    rng.manual_seed(seed)
    sample = torch.randn(
        (num_seq, hidden), dtype=DTYPE_HALF, device="cuda", generator=rng,
    )
    return sample / 10
| 143 | + |
| 144 | + |
# Fixed scale value wrapped into a tensor by build_a_scale().
# NOTE(review): presumably the activation quantization scale -- confirm
# against the backends that consume it.
A_SCALE_VALUE = 1e-2


def build_a_scale() -> torch.Tensor:
    """Return ``A_SCALE_VALUE`` as a 0-dim float32 tensor on the CUDA device."""
    return torch.full(
        size=(),
        fill_value=A_SCALE_VALUE,
        dtype=torch.float32,
        device="cuda",
    )
| 150 | + |
| 151 | + |
| 152 | +# --------------------------------------------------------------------------- |
| 153 | +# Method C timing harness |
| 154 | +# --------------------------------------------------------------------------- |
def run_method_c(
    call_fn: Callable[[], None],
    *,
    n_timed: int = 52,
    n_warmup: int = 3,
) -> None:
    """Run warmup, graph capture, replay warmup, and timed graph replays.

    Sequence:
      1. ``n_warmup`` eager calls of ``call_fn``, then a device sync.
      2. One call of ``call_fn`` captured into a CUDA graph.
      3. ``n_warmup`` graph replays, then a device sync.
      4. ``n_timed`` graph replays, each inside an NVTX range named "step",
         bracketed by cudaProfilerStart / cudaProfilerStop.

    Args:
        call_fn: Zero-argument callable to benchmark.
        n_timed: Number of profiled graph replays.
        n_warmup: Number of eager calls before capture and of graph replays
            before profiling. NOTE(review): 0 skips warmup entirely, which
            may break capture for code that initializes lazily on first call.

    Raises:
        Whatever ``call_fn``, graph capture or replay raises. The profiler is
        stopped before an exception from the timed section propagates.
    """
    import nvtx  # imported lazily so backends that error during setup don't fail on import

    for _ in range(n_warmup):
        call_fn()
    torch.cuda.synchronize()

    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        call_fn()

    for _ in range(n_warmup):
        graph.replay()
    torch.cuda.synchronize()

    cudart = torch.cuda.cudart()
    cudart.cudaProfilerStart()
    try:
        for _ in range(n_timed):
            with nvtx.annotate("step"):
                graph.replay()
        torch.cuda.synchronize()
    finally:
        # Always pair Start with Stop so a failing replay does not leave the
        # profiler capture range open.
        cudart.cudaProfilerStop()
| 177 | + |
| 178 | + |
| 179 | +# --------------------------------------------------------------------------- |
| 180 | +# Spec serialization (driver argv <-> worker parsed args) |
| 181 | +# --------------------------------------------------------------------------- |
def spec_to_argv(spec: BenchSpec) -> list[str]:
    """Serialize a BenchSpec to argv (used by bench.py to invoke worker.py).

    Each field becomes a ``--flag value`` pair, where the flag is the field
    name with underscores replaced by dashes. Integer fields are converted
    with ``str``; ``model`` is passed through as-is.
    """
    def flag(field_name):
        return "--" + field_name.replace("_", "-")

    shape_fields = (
        "num_seq",
        "hidden",
        "intermediate_per_rank",
        "num_expert_local",
        "num_expert_total",
        "num_topk",
    )

    argv = []
    for field_name in shape_fields:
        argv += [flag(field_name), str(getattr(spec, field_name))]
    # model is already a string and is forwarded unchanged.
    argv += [flag("model"), spec.model]
    for field_name in ("tp", "ep"):
        argv += [flag(field_name), str(getattr(spec, field_name))]
    return argv
| 195 | + |
| 196 | + |
def spec_from_args(args) -> BenchSpec:
    """Build a BenchSpec from an object carrying one attribute per field.

    ``args`` must expose ``num_seq``, ``hidden``, ``intermediate_per_rank``,
    ``num_expert_local``, ``num_expert_total``, ``num_topk``, ``model``,
    ``tp`` and ``ep``; each is copied into the field of the same name.
    """
    field_names = (
        "num_seq",
        "hidden",
        "intermediate_per_rank",
        "num_expert_local",
        "num_expert_total",
        "num_topk",
        "model",
        "tp",
        "ep",
    )
    return BenchSpec(**{name: getattr(args, name) for name in field_names})
0 commit comments