Commit 48459a2

feat(benchmark/gemm): add base, dense, and deepseek GEMM benchmarks (#226)
1 parent f8df714 commit 48459a2

File tree

7 files changed: +1123 -1 lines changed

primus/cli/benchmark_cli.py

Lines changed: 65 additions & 0 deletions

@@ -0,0 +1,65 @@

```python
###############################################################################
# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#
# See LICENSE for license information.
###############################################################################


def run(args, extra_args):
    """
    Execute the benchmark command.
    This can internally call Megatron / TorchTitan hooks, or profile.py scripts.
    """

    suite = args.suite
    print(f"[Primus:Benchmark] suite={suite} args={args}")

    from primus.tools.utils import finalize_distributed, init_distributed

    init_distributed()

    if suite == "gemm":
        from primus.tools.benchmark.gemm_bench import run_gemm_benchmark

        run_gemm_benchmark(args)
    elif suite == "gemm-dense":
        from primus.tools.benchmark.dense_gemm_bench import run_gemm_benchmark

        run_gemm_benchmark(args)
    elif suite == "gemm-deepseek":
        from primus.tools.benchmark.deepseek_dense_gemm_bench import run_gemm_benchmark

        run_gemm_benchmark(args)

    finalize_distributed()


def register_subcommand(subparsers):
    """
    primus-cli benchmark <suite> [suite-specific-args]
    suites: gemm | gemm-dense | gemm-deepseek
    """
    parser = subparsers.add_parser("benchmark", help="Run performance benchmarks (GEMM / Attention / RCCL).")
    suite_parsers = parser.add_subparsers(dest="suite", required=True)

    # ---------- GEMM ----------
    gemm = suite_parsers.add_parser("gemm", help="GEMM microbench.")
    from primus.tools.benchmark import gemm_bench

    gemm_bench.add_gemm_parser(gemm)

    # ---------- DENSE-GEMM ----------
    dense_gemm = suite_parsers.add_parser("gemm-dense", help="GEMM-DENSE microbench.")
    from primus.tools.benchmark import dense_gemm_bench

    dense_gemm_bench.add_gemm_parser(dense_gemm)

    # ---------- DEEPSEEK-GEMM ----------
    deepseek_gemm = suite_parsers.add_parser("gemm-deepseek", help="DEEPSEEK-GEMM microbench.")
    from primus.tools.benchmark import deepseek_dense_gemm_bench

    deepseek_dense_gemm_bench.add_gemm_parser(deepseek_gemm)

    parser.set_defaults(func=run)
    return parser
```
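
A minimal sketch of how this wiring can be exercised from Python (not part of the diff; the flag values are illustrative, and a GPU plus the distributed helpers used by `run()` are assumed):

```python
# Sketch only: build the CLI parser, register the benchmark subcommand added above,
# and dispatch a gemm-deepseek run, equivalent to:
#   primus-cli benchmark gemm-deepseek --model Deepseek_V3 --duration 1
import argparse

from primus.cli import benchmark_cli

parser = argparse.ArgumentParser(prog="primus")
subparsers = parser.add_subparsers(dest="command", required=True)
benchmark_cli.register_subcommand(subparsers)

args, extra_args = parser.parse_known_args(
    ["benchmark", "gemm-deepseek", "--model", "Deepseek_V3", "--duration", "1"]
)
args.func(args, extra_args)  # calls run(), which dispatches to run_gemm_benchmark
```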

primus/cli/main.py

Lines changed: 2 additions & 1 deletion

```diff
@@ -23,10 +23,11 @@ def main():
     parser = argparse.ArgumentParser(prog="primus", description="Primus Unified CLI for Training & Utilities")
     subparsers = parser.add_subparsers(dest="command", required=True)

-    from primus.cli import train_cli
+    from primus.cli import benchmark_cli, train_cli

     # Register train subcommand (only implemented one for now)
     train_cli.register_subcommand(subparsers)
+    benchmark_cli.register_subcommand(subparsers)

     args, unknown_args = parser.parse_known_args()
```

primus/tools/benchmark/deepseek_dense_gemm_bench.py

Lines changed: 258 additions & 0 deletions

@@ -0,0 +1,258 @@
```python
###############################################################################
# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#
# See LICENSE for license information.
###############################################################################

import argparse
import itertools
from datetime import datetime
from typing import List, Tuple

import torch
from tqdm import tqdm

from primus.tools.benchmark.gemm_bench import profile_gemm
from primus.tools.report import write_table_simple
from primus.tools.utils import gather_records, is_rank_0

MODEL_CONFIGS = {
    "Deepseek_V2_Lite": {
        "seqlen": 4096,
        "hidden_size": 2048,
        "intermediate_size": 10944,
        "kv_lora_rank": 512,
        "moe_intermediate_size": 1408,
        "num_attention_heads": 16,
        "num_experts_per_tok": 6,
        "n_routed_experts": 64,
        "n_shared_experts": 2,
        "q_lora_rank": None,
        "qk_nope_head_dim": 128,
        "qk_rope_head_dim": 64,
        "v_head_dim": 128,
        "vocab_size": 102400,
    },
    "Deepseek_V2": {
        "seqlen": 4096,
        "hidden_size": 5120,
        "intermediate_size": 12288,
        "kv_lora_rank": 512,
        "moe_intermediate_size": 1536,
        "num_attention_heads": 128,
        "num_experts_per_tok": 6,
        "n_routed_experts": 160,
        "n_shared_experts": 2,
        "q_lora_rank": 1536,
        "qk_nope_head_dim": 128,
        "qk_rope_head_dim": 64,
        "v_head_dim": 128,
        "vocab_size": 102400,
    },
    "Deepseek_V3": {
        "seqlen": 4096,
        "hidden_size": 7168,
        "intermediate_size": 18432,
        "kv_lora_rank": 512,
        "moe_intermediate_size": 2048,
        "num_attention_heads": 128,
        "num_experts_per_tok": 8,
        "n_routed_experts": 256,
        "n_shared_experts": 1,
        "q_lora_rank": 1536,
        "qk_nope_head_dim": 128,
        "qk_rope_head_dim": 64,
        "v_head_dim": 128,
        "vocab_size": 129280,
    },
}


def add_gemm_parser(parser: argparse.ArgumentParser):
    parser.add_argument("--model", default=None, help="Model name (Deepseek_V2, Deepseek_V3, etc.)")
    parser.add_argument("--seqlen", type=int, default=4096)
    parser.add_argument("--hidden-size", type=int, default=4096)
    parser.add_argument("--intermediate-size", type=int, default=12288)
    parser.add_argument("--kv-lora-rank", type=int, default=512)
    parser.add_argument("--moe-intermediate-size", type=int, default=1536)
    parser.add_argument("--num-attention-heads", type=int, default=64)
    parser.add_argument("--num-experts-per-tok", type=int, default=6)
    parser.add_argument("--n-routed-experts", type=int, default=128)
    parser.add_argument("--n-shared-experts", type=int, default=2)
    parser.add_argument("--q-lora-rank", type=int, default=None)
    parser.add_argument("--qk-nope-head-dim", type=int, default=128)
    parser.add_argument("--qk-rope-head-dim", type=int, default=64)
    parser.add_argument("--v-head-dim", type=int, default=128)
    parser.add_argument("--vocab-size", type=int, default=128256)
    parser.add_argument("--dtype", choices=["bf16", "fp16"], default="bf16")
    parser.add_argument("--mbs", type=int, default=1)
    parser.add_argument("--duration", type=int, default=3, help="Benchmark duration per shape (sec)")
    parser.add_argument("--output-file", default="./gemm-deepseek_report.md")
    parser.add_argument("--append", action="store_true", help="Append to existing report")
    return parser


def profile_fwd(m, n, k, dtype, duration):
    return profile_gemm(m, n, k, dtype, False, True, duration)


def profile_wgrad(m, n, k, dtype, duration):
    return profile_gemm(n, k, m, dtype, True, False, duration)


def profile_dgrad(m, n, k, dtype, duration):
    return profile_gemm(m, k, n, dtype, False, False, duration)
```
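The three wrappers above permute (m, n, k) so that a single profile_gemm signature covers the forward, weight-gradient, and data-gradient GEMMs of one linear layer. A small sketch of the correspondence this implies for Y = X @ W (the convention is inferred from the shape permutations, not stated in the diff):

```python
# Sketch of the assumed shape convention for Y = X @ W, with X:[m, k], W:[k, n], Y:[m, n].
import torch

m, n, k = 8, 6, 4
X, W, dY = torch.randn(m, k), torch.randn(k, n), torch.randn(m, n)

Y = X @ W           # fwd:   GEMM with dims (m, n, k)
dW_t = dY.t() @ X   # wgrad: GEMM with dims (n, k, m), i.e. dW transposed
dX = dY @ W.t()     # dgrad: GEMM with dims (m, k, n)

assert Y.shape == (m, n) and dW_t.shape == (n, k) and dX.shape == (m, k)
```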
```python
def build_preamble(args, shapes: List[Tuple[str, List[int]]]) -> str:
    lines = [
        "# DeepSeek GEMM Benchmark Report",
        "",
        f"- Model: {args.model or 'Custom'}",
        f"- Date: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}",
        f"- Duration per shape: {args.duration}s",
        "",
        "## Configuration",
        f"- seqlen: {args.seqlen}",
        f"- hidden_size: {args.hidden_size}",
        f"- intermediate_size: {args.intermediate_size}",
        f"- kv_lora_rank: {args.kv_lora_rank}",
        f"- moe_intermediate_size: {args.moe_intermediate_size}",
        f"- num_attention_heads: {args.num_attention_heads}",
        f"- num_experts_per_tok: {args.num_experts_per_tok}",
        f"- n_routed_experts: {args.n_routed_experts}",
        f"- n_shared_experts: {args.n_shared_experts}",
        f"- q_lora_rank: {args.q_lora_rank}",
        f"- dtype: {args.dtype}",
        "",
        "## GEMM Shapes (M, N, K)",
    ]
    for name, (m, n, k) in shapes:
        lines.append(f"- {name}: ({m}, {n}, {k})")
    lines += ["", "## Phases", "- fwd", "- wgrad", "- dgrad", ""]
    return "\n".join(lines)


def run_gemm_benchmark(args):
    if args.model:
        model_lower_map = {k.lower(): k for k in MODEL_CONFIGS.keys()}
        model_key = args.model.lower()

        if model_key not in model_lower_map:
            raise ValueError(
                f"[ERROR] Unknown model '{args.model}'. Supported models: {', '.join(MODEL_CONFIGS.keys())}"
            )

        true_key = model_lower_map[model_key]
        cfg = MODEL_CONFIGS[true_key]
        args.model = true_key  # normalize the model name
        for k, v in cfg.items():
            setattr(args, k, v)
    else:
        print("[INFO] No model specified. Using CLI-provided parameters.")

    dtype_map = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}
    dtype = dtype_map[args.dtype]

    q_head_dim = args.qk_nope_head_dim + args.qk_rope_head_dim
    shape_defs = []

    # q-proj
    if args.q_lora_rank is None:
        shape_defs.append(("attn_q", [args.seqlen, args.num_attention_heads * q_head_dim, args.hidden_size]))
    else:
        shape_defs.append(("attn_q_down", [args.seqlen, args.q_lora_rank, args.hidden_size]))
        shape_defs.append(
            ("attn_q_up", [args.seqlen, args.num_attention_heads * q_head_dim, args.q_lora_rank])
        )

    # kv projections
    shape_defs += [
        ("attn_kv_down", [args.seqlen, args.kv_lora_rank + args.qk_rope_head_dim, args.hidden_size]),
        (
            "attn_kv_up",
            [
                args.seqlen,
                args.num_attention_heads * (args.qk_nope_head_dim + args.v_head_dim),
                args.kv_lora_rank,
            ],
        ),
        ("attn_out", [args.seqlen, args.hidden_size, args.num_attention_heads * args.v_head_dim]),
        ("router", [args.seqlen, args.n_routed_experts, args.hidden_size]),
    ]

    # shared experts
    if args.n_shared_experts > 0:
        shape_defs.append(("shared_gateup", [args.seqlen, args.intermediate_size * 2, args.hidden_size]))
        shape_defs.append(("shared_down", [args.seqlen, args.hidden_size, args.intermediate_size]))

    # routed experts (balance)
    balance_seq = int(args.seqlen * args.num_experts_per_tok // args.n_routed_experts)
    shape_defs.append(("moe_gateup", [balance_seq, args.moe_intermediate_size * 2, args.hidden_size]))
    shape_defs.append(("moe_down", [balance_seq, args.hidden_size, args.moe_intermediate_size]))

    # vocab
    shape_defs.append(("vocab", [args.seqlen, args.vocab_size, args.hidden_size]))

    func_defs = [
        ("fwd", profile_fwd),
        ("wgrad", profile_wgrad),
        ("dgrad", profile_dgrad),
    ]

    record = {}
    for (phase, shape), (tag, func) in tqdm(
        itertools.product(shape_defs, func_defs),
        total=len(shape_defs) * len(func_defs),
        desc=f"[DeepSeek GEMM] {args.model or 'Custom'}",
    ):
        m, n, k = [args.mbs * shape[0], shape[1], shape[2]]

        res = func(m, n, k, dtype, args.duration)
        summary = (
            f"{res['tflops']:.2f}TF/s / "
            f"{res['bandwidth_gbps']:.2f}GB/s / "
            f"{res['avg_time_ms']:.6f}s / "
            f"AI={res['arith_intensity']:.2f}"
        )
        record[f"{phase}_{tag}"] = summary

    gathered = gather_records(record)
    if is_rank_0():
        all_keys = set().union(*(r.keys() for r in gathered))
        header = ["host", "world", "rank"] + sorted(
            [k for k in all_keys if k not in {"host", "rank", "world"}]
        )

        rows = [[r.get(col, "") for col in header] for r in gathered]

        preamble = build_preamble(args, shape_defs)

        append = getattr(args, "append", False)

        write_table_simple(
            header=header,
            rows=rows,
            output_file=args.output_file or f"benchmark_gemm_dense_{args.model}.md",
            append=append,
            preamble=preamble if not append else None,
        )

    print(f"[✔] DeepSeek GEMM benchmark finished. Results saved to {args.output_file}")
```
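As a concrete check of the shape construction above, a short sketch works through the Deepseek_V3 entry of MODEL_CONFIGS with mbs=1 (arithmetic only, not measured results):

```python
# Worked example for MODEL_CONFIGS["Deepseek_V3"] with mbs=1 (arithmetic only).
seqlen, hidden = 4096, 7168
q_head_dim = 128 + 64                          # qk_nope_head_dim + qk_rope_head_dim = 192
attn_q_up = (seqlen, 128 * q_head_dim, 1536)   # (4096, 24576, 1536)
attn_kv_down = (seqlen, 512 + 64, hidden)      # (4096, 576, 7168)
balance_seq = seqlen * 8 // 256                # 4096 tokens * 8 experts / 256 experts = 128
moe_gateup = (balance_seq, 2 * 2048, hidden)   # (128, 4096, 7168)
vocab = (seqlen, 129280, hidden)               # (4096, 129280, 7168)
```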
```python
def build_gemm_dense_parser() -> argparse.ArgumentParser:
    """
    Build a standalone parser for local execution.
    """
    parser = argparse.ArgumentParser(description="DEEPSEEK-GEMM benchmark")
    add_gemm_parser(parser)
    return parser


if __name__ == "__main__":
    parser = build_gemm_dense_parser()
    args = parser.parse_args()
    run_gemm_benchmark(args)
```
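
The suite can also be driven without the primus-cli entry point; a minimal sketch, assuming the module path used in benchmark_cli.py and that the distributed helpers have been initialized for the current process:

```python
# Sketch only: run the DeepSeek GEMM sweep programmatically.
# Module path mirrors the import in primus/cli/benchmark_cli.py; flag values are illustrative.
from primus.tools.benchmark.deepseek_dense_gemm_bench import (
    build_gemm_dense_parser,
    run_gemm_benchmark,
)
from primus.tools.utils import finalize_distributed, init_distributed

init_distributed()  # benchmark_cli.run() does this before dispatching
args = build_gemm_dense_parser().parse_args(
    ["--model", "Deepseek_V2_Lite", "--dtype", "bf16", "--duration", "1"]
)
run_gemm_benchmark(args)
finalize_distributed()
```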
