Skip to content

Commit 21b0ecf

Browse files
donpromax and lvdong authored
feat: support ascend npu sft (#43)
* feat: support ascend npu sft * feat(npu): add native flash-attn, alltoall dispatcher and grouped-gemm optimizations * feat: raise exception when npu not available * feat(ascend): simplify npu playground code * docs(ascend): clarify native NPU alternative registration and patch scope --------- Co-authored-by: lvdong <lvdong@stepfun.com>
1 parent 471c676 commit 21b0ecf

20 files changed

Lines changed: 2627 additions & 8 deletions

benchmarks/benchmark_dispatcher_npu.py

Lines changed: 463 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 224 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,224 @@
1+
from steptronoss.utils.npu_patch import apply_npu_patch
2+
3+
apply_npu_patch()
4+
5+
import argparse
6+
import time
7+
from pathlib import Path
8+
9+
import torch
10+
11+
import sys

# Directory that contains this script.  Note that the entry added to
# ``sys.path`` below is this directory's *parent*, not ``REPO_ROOT`` itself.
REPO_ROOT = Path(__file__).resolve().parent
_parent_dir = str(REPO_ROOT.parent)
# NOTE(review): this runs after the ``steptronoss`` import at the top of the
# file, so it cannot help resolve that import -- confirm whether the path
# bootstrap was meant to come first.
if _parent_dir not in sys.path:
    sys.path.insert(0, _parent_dir)
14+
15+
16+
# Benchmark cases executed by ``main``.  Fields, as consumed by the helpers in
# this file:
#   name        label echoed in the report header
#   group_size  number of groups (one weight matrix per group)
#   batch_size  rows of the flat left-hand operand assigned to every group
#   k, n        inner (reduction) and output feature dimensions
#   dtype       short dtype label resolved by ``_dtype_from_name``
#   warmup      untimed iterations run before each measurement
#   iters       timed iterations averaged for each measurement
#   trans_b     True when weights are laid out as (group, n, k)
_MOE_LIKE_LARGE = {
    "name": "moe_like_large",
    "group_size": 36,
    "batch_size": 3256,
    "k": 4096,
    "n": 2560,
    "dtype": "bf16",
    "warmup": 20,
    "iters": 20,
    "trans_b": True,
}

DEFAULT_PARAM_SETS = [_MOE_LIKE_LARGE]
29+
30+
31+
def _dtype_from_name(name: str) -> torch.dtype:
    """Translate a short dtype label into the matching ``torch.dtype``.

    Recognised labels are ``"bf16"``, ``"fp16"`` and ``"fp32"``.

    Raises:
        ValueError: If ``name`` is not one of the recognised labels.
    """
    known = {
        "bf16": torch.bfloat16,
        "fp16": torch.float16,
        "fp32": torch.float32,
    }
    resolved = known.get(name)
    if resolved is None:
        raise ValueError(f"Unsupported dtype: {name}")
    return resolved
40+
41+
42+
def _sync():
    """Call ``torch.npu.synchronize()`` so host timers bracket finished device work."""
    npu_backend = torch.npu
    npu_backend.synchronize()
44+
45+
46+
def _build_inputs(params: dict[str, object], device: torch.device):
    """Create random operands for one grouped-GEMM benchmark case.

    Args:
        params: Benchmark settings; reads ``dtype``, ``group_size``,
            ``batch_size``, ``k``, ``n`` and ``trans_b``.
        device: Device on which every tensor is allocated.

    Returns:
        ``(a, b, batch_sizes)``.  ``a`` is the flat ``(total_m, k)`` left
        operand, ``b`` holds one weight matrix per group -- shaped
        ``(group_size, n, k)`` when ``trans_b`` is set and
        ``(group_size, k, n)`` otherwise -- and ``batch_sizes`` is an int64
        vector holding ``batch_size`` once per group.  ``a`` and ``b`` are
        created with ``requires_grad=True``.
    """
    elem_dtype = _dtype_from_name(params["dtype"])
    num_groups = int(params["group_size"])
    rows_per_group = int(params["batch_size"])
    inner_dim = int(params["k"])
    out_dim = int(params["n"])
    weights_transposed = bool(params["trans_b"])

    batch_sizes = torch.full((num_groups,), rows_per_group, device=device, dtype=torch.int64)
    total_rows = int(batch_sizes.sum().item())

    a = torch.randn(total_rows, inner_dim, device=device, dtype=elem_dtype, requires_grad=True)
    weight_shape = (
        (num_groups, out_dim, inner_dim)
        if weights_transposed
        else (num_groups, inner_dim, out_dim)
    )
    b = torch.randn(*weight_shape, device=device, dtype=elem_dtype, requires_grad=True)
    return a, b, batch_sizes
63+
64+
65+
def _run_baseline(
    mat_a_flat: torch.Tensor, mat_b: torch.Tensor, batch_sizes: torch.Tensor, trans_b: bool
) -> torch.Tensor:
    """Reference grouped GEMM: one dense matmul per group, concatenated by rows.

    Args:
        mat_a_flat: Flat left operand; consecutive row slices belong to
            consecutive groups.
        mat_b: Per-group weights, indexed by group along dim 0.
        batch_sizes: Number of rows of ``mat_a_flat`` owned by each group.
        trans_b: When true, each ``mat_b[i]`` is transposed before the matmul.

    Returns:
        The per-group products stacked along dim 0, or an empty ``(0, n)``
        tensor when ``batch_sizes`` has no entries.
    """
    group_rows = batch_sizes.tolist()
    chunks = []
    row_begin = 0
    for group_idx, rows in enumerate(group_rows):
        row_end = row_begin + rows
        weight = mat_b[group_idx]
        if trans_b:
            weight = weight.t()
        chunks.append(mat_a_flat[row_begin:row_end] @ weight)
        row_begin = row_end

    if not chunks:
        out_cols = mat_b.shape[1] if trans_b else mat_b.shape[2]
        return mat_a_flat.new_zeros((0, out_cols))
    return torch.cat(chunks, dim=0)
78+
79+
80+
def _run_npu_gmm_v2(
    mat_a_flat: torch.Tensor, mat_b: torch.Tensor, batch_sizes: torch.Tensor, trans_b: bool
) -> torch.Tensor:
    """Grouped GEMM through ``mindspeed.ops.gmm.npu_gmm_v2``.

    Args:
        mat_a_flat: Flat left operand.
        mat_b: Per-group weights; the last two dims are swapped first when
            ``trans_b`` is true.
        batch_sizes: Per-group row counts, passed to the kernel as
            ``group_list`` after being moved to ``mat_a_flat``'s device (if
            not already on an NPU) and cast to int64.
        trans_b: Whether ``mat_b`` must be transposed before the call.

    Returns:
        The kernel output, or an empty ``(0, n)`` tensor when ``mat_a_flat``
        has no rows.

    Raises:
        ImportError: If ``npu_gmm_v2`` cannot be imported for any reason.
    """
    # Imported lazily so the module loads even where mindspeed is missing.
    try:
        from mindspeed.ops.gmm import npu_gmm_v2
    except Exception as exc:
        raise ImportError("from mindspeed.ops.gmm import npu_gmm_v2 failed.") from exc

    if mat_a_flat.shape[0] == 0:
        out_cols = mat_b.shape[1] if trans_b else mat_b.shape[2]
        return mat_a_flat.new_zeros((0, out_cols))

    if trans_b:
        weight = mat_b.transpose(-1, -2)
    else:
        weight = mat_b

    already_on_npu = batch_sizes.device.type == "npu"
    group_list = batch_sizes if already_on_npu else batch_sizes.to(device=mat_a_flat.device)
    group_list = group_list.to(dtype=torch.int64)
    return npu_gmm_v2(mat_a_flat, weight, bias=None, group_list=group_list, group_type=0)
96+
97+
98+
def _time_forward(fn, warmup: int, iters: int) -> tuple[float, torch.Tensor]:
    """Measure the mean forward-only latency of ``fn`` with autograd disabled.

    Args:
        fn: Zero-argument callable producing the output tensor.
        warmup: Untimed calls made before the measurement.
        iters: Timed calls; the elapsed time is divided by this count.

    Returns:
        ``(mean_ms_per_call, last_output)``.
    """
    result = None
    for _ in range(warmup):
        with torch.no_grad():
            result = fn()
    # Drain warmup work so it is not charged to the timed section.
    _sync()

    t_begin = time.perf_counter()
    for _ in range(iters):
        with torch.no_grad():
            result = fn()
    # Wait for the device before reading the clock again.
    _sync()
    elapsed_ms = (time.perf_counter() - t_begin) * 1000.0
    return elapsed_ms / iters, result
111+
112+
113+
def _time_forward_backward(
    fn, a: torch.Tensor, b: torch.Tensor, warmup: int, iters: int
) -> tuple[float, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Measure the mean latency of one forward pass plus ``out.sum().backward()``.

    Args:
        fn: Zero-argument callable producing the output tensor from ``a``/``b``.
        a: Leaf tensor whose ``.grad`` is read and reset every iteration.
        b: Leaf tensor whose ``.grad`` is read and reset every iteration.
        warmup: Untimed forward+backward rounds run before the measurement.
        iters: Timed rounds; the elapsed time is divided by this count.

    Returns:
        ``(mean_ms_per_round, out, grad_a, grad_b)`` where the tensors come
        from the final timed round.
    """
    out = None
    for _ in range(warmup):
        out = fn()
        out.sum().backward()
        a.grad = None
        b.grad = None
    _sync()

    start = time.perf_counter()
    for _ in range(iters):
        out = fn()
        out.sum().backward()
        # Keep a reference to the gradients instead of cloning them: the
        # ``.grad`` attributes are cleared right below, so the next backward
        # allocates fresh tensors and never writes into the ones held here.
        # Cloning inside this loop would add device copies to the measured time.
        grad_a = a.grad
        grad_b = b.grad
        a.grad = None
        b.grad = None
    _sync()
    total_ms = (time.perf_counter() - start) * 1000.0 / iters
    # Detaching/cloning happens only after the clock has been read.
    return total_ms, out.detach().clone(), grad_a.detach(), grad_b.detach()
135+
136+
137+
def _max_abs_diff(x: torch.Tensor, y: torch.Tensor) -> float:
    """Return the largest element-wise ``|x - y|``, computed after casting both to float32."""
    delta = x.to(torch.float32) - y.to(torch.float32)
    worst = torch.max(torch.abs(delta))
    return float(worst.item())
139+
140+
141+
def _check_close(x: torch.Tensor, y: torch.Tensor, rtol: float, atol: float) -> bool:
    """Report whether ``torch.testing.assert_close`` accepts ``x`` against ``y``.

    Any exception raised by the comparison is reported as ``False`` rather
    than propagated, so the caller can print a plain pass/fail flag.
    """
    matched = True
    try:
        torch.testing.assert_close(x, y, rtol=rtol, atol=atol)
    except Exception:
        matched = False
    return matched
147+
148+
149+
def _bench_one(params: dict[str, object], rtol: float, atol: float):
    """Benchmark one parameter set and print a timing / accuracy report.

    The per-group matmul baseline and the ``npu_gmm_v2`` kernel are each
    timed forward-only and forward+backward.  The reported backward time is
    derived as ``total_ms - fw_ms``.  Outputs and gradients of the kernel are
    then compared against the baseline.

    Args:
        params: Benchmark settings (see ``DEFAULT_PARAM_SETS``).
        rtol: Relative tolerance used for the closeness checks.
        atol: Absolute tolerance used for the closeness checks.

    Raises:
        RuntimeError: If ``torch.npu`` is missing or reports no device.
    """
    npu_ready = hasattr(torch, "npu") and torch.npu.is_available()
    if not npu_ready:
        raise RuntimeError("NPU is not available.")

    device = torch.device("npu")
    warmup = int(params["warmup"])
    iters = int(params["iters"])
    trans_b = bool(params["trans_b"])

    a_base, b_base, batch_sizes = _build_inputs(params, device)
    # Separate leaf copies so each backend accumulates its own gradients.
    a_npu = a_base.detach().clone().requires_grad_(True)
    b_npu = b_base.detach().clone().requires_grad_(True)

    def baseline_step():
        return _run_baseline(a_base, b_base, batch_sizes, trans_b)

    def npu_step():
        return _run_npu_gmm_v2(a_npu, b_npu, batch_sizes, trans_b)

    fw_ms_base, ref_out = _time_forward(baseline_step, warmup=warmup, iters=iters)
    total_ms_base, ref_out_bw, ref_da, ref_db = _time_forward_backward(
        baseline_step, a=a_base, b=b_base, warmup=warmup, iters=iters
    )

    fw_ms_npu, out_npu = _time_forward(npu_step, warmup=warmup, iters=iters)
    total_ms_npu, out_npu_bw, da_npu, db_npu = _time_forward_backward(
        npu_step, a=a_npu, b=b_npu, warmup=warmup, iters=iters
    )

    # Backward cost is not measured directly; it is total minus forward-only.
    bw_ms_base = total_ms_base - fw_ms_base
    bw_ms_npu = total_ms_npu - fw_ms_npu

    print(
        f"[npu_grouped_gemm] name={params['name']} group={params['group_size']} "
        f"batch={params['batch_size']} k={params['k']} n={params['n']} "
        f"dtype={params['dtype']} trans_b={trans_b}"
    )
    print("backend, fw_ms, bw_ms, total_ms")
    print(f"baseline, {fw_ms_base:.3f}, {bw_ms_base:.3f}, {total_ms_base:.3f}")
    print(f"npu_gmm_v2, {fw_ms_npu:.3f}, {bw_ms_npu:.3f}, {total_ms_npu:.3f}")
    print(
        "speedup_vs_baseline, "
        f"fw={fw_ms_base / fw_ms_npu:.2f}x, "
        f"bw={bw_ms_base / bw_ms_npu:.2f}x, "
        f"total={total_ms_base / total_ms_npu:.2f}x"
    )

    print("metric, close, max_abs_diff")

    fw_close = _check_close(out_npu, ref_out, rtol, atol)
    fw_diff = _max_abs_diff(out_npu, ref_out)
    print(f"forward, {fw_close}, {fw_diff:.6f}")

    fw_bw_close = _check_close(out_npu_bw, ref_out_bw, rtol, atol)
    fw_bw_diff = _max_abs_diff(out_npu_bw, ref_out_bw)
    print(
        f"forward_bw_run, {fw_bw_close}, {fw_bw_diff:.6f}"
    )

    da_close = _check_close(da_npu, ref_da, rtol, atol)
    da_diff = _max_abs_diff(da_npu, ref_da)
    print(f"grad_a, {da_close}, {da_diff:.6f}")

    db_close = _check_close(db_npu, ref_db, rtol, atol)
    db_diff = _max_abs_diff(db_npu, ref_db)
    print(f"grad_b, {db_close}, {db_diff:.6f}")
209+
210+
211+
def main() -> int:
    """Parse tolerances from the command line and run every default benchmark.

    Command-line options:
        --rtol: Relative tolerance for the accuracy checks (default ``1e-2``).
        --atol: Absolute tolerance for the accuracy checks (default ``1e-2``).

    Returns:
        ``0`` once all parameter sets have been benchmarked.

    Raises:
        RuntimeError: If ``torch.npu`` is missing or reports no device.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--rtol", type=float, default=1e-2)
    parser.add_argument("--atol", type=float, default=1e-2)
    args = parser.parse_args()

    # Check availability before the first ``torch.npu`` call so that hosts
    # without NPU support get the intended error instead of an AttributeError
    # or a backend failure from ``set_device``.
    if not hasattr(torch, "npu") or not torch.npu.is_available():
        raise RuntimeError("NPU is not available.")

    torch.npu.set_device(0)
    for params in DEFAULT_PARAM_SETS:
        _bench_one(params, rtol=args.rtol, atol=args.atol)
    return 0


if __name__ == "__main__":
    raise SystemExit(main())

0 commit comments

Comments
 (0)