|
| 1 | +from steptronoss.utils.npu_patch import apply_npu_patch |
| 2 | + |
| 3 | +apply_npu_patch() |
| 4 | + |
| 5 | +import argparse |
| 6 | +import time |
| 7 | +from pathlib import Path |
| 8 | + |
| 9 | +import torch |
| 10 | + |
# Imported here, next to its only use; the original reached the module through
# __import__("sys") twice, which hides the dependency from readers and linters.
import sys

# Directory that contains this script.
REPO_ROOT = Path(__file__).resolve().parent
# Put the parent of the script directory at the front of sys.path unless it is
# already listed.
# NOTE(review): this runs after the `steptronoss` import at the top of the
# file, so it cannot influence that import — confirm the ordering is intended.
if str(REPO_ROOT.parent) not in sys.path:
    sys.path.insert(0, str(REPO_ROOT.parent))
| 14 | + |
| 15 | + |
# Benchmark configurations that main() iterates over. Each entry provides the
# keys read by _build_inputs() and _bench_one():
#   name        label printed in the report header
#   group_size  number of groups (length of the batch_sizes tensor)
#   batch_size  number of rows assigned to every group
#   k, n        inner and output dimensions of the per-group matmul
#   dtype       key understood by _dtype_from_name()
#   warmup      untimed iterations run before each measurement
#   iters       timed iterations per measurement
#   trans_b     whether the weights are stored as (group, n, k)
DEFAULT_PARAM_SETS = [
    dict(
        name="moe_like_large",
        group_size=36,
        batch_size=3256,
        k=4096,
        n=2560,
        dtype="bf16",
        warmup=20,
        iters=20,
        trans_b=True,
    ),
]
| 29 | + |
| 30 | + |
| 31 | +def _dtype_from_name(name: str) -> torch.dtype: |
| 32 | + table = { |
| 33 | + "bf16": torch.bfloat16, |
| 34 | + "fp16": torch.float16, |
| 35 | + "fp32": torch.float32, |
| 36 | + } |
| 37 | + if name not in table: |
| 38 | + raise ValueError(f"Unsupported dtype: {name}") |
| 39 | + return table[name] |
| 40 | + |
| 41 | + |
def _sync():
    """Call ``torch.npu.synchronize()``.

    The timing helpers in this file invoke this right before reading the
    clock, after their warmup and measured loops.
    """
    torch.npu.synchronize()
| 44 | + |
| 45 | + |
def _build_inputs(params: dict[str, object], device: torch.device):
    """Create random operands for one grouped-GEMM benchmark configuration.

    Args:
        params: Configuration mapping. The keys ``"dtype"``, ``"group_size"``,
            ``"batch_size"``, ``"k"``, ``"n"`` and ``"trans_b"`` are read.
        device: Device on which every tensor is allocated.

    Returns:
        A tuple ``(a, b, batch_sizes)``:

        * ``a`` has shape ``(group_size * batch_size, k)``.
        * ``b`` has shape ``(group_size, n, k)`` when ``trans_b`` is true,
          otherwise ``(group_size, k, n)``.
        * ``batch_sizes`` is an int64 tensor of length ``group_size`` filled
          with ``batch_size``.

        ``a`` and ``b`` are created with ``requires_grad=True``.

    Raises:
        ValueError: Propagated from ``_dtype_from_name`` for an unknown dtype.
    """
    dtype = _dtype_from_name(params["dtype"])
    group_size = int(params["group_size"])
    batch_size = int(params["batch_size"])
    k = int(params["k"])
    n = int(params["n"])
    trans_b = bool(params["trans_b"])

    batch_sizes = torch.full((group_size,), batch_size, device=device, dtype=torch.int64)
    # Every group has the same size, so the row count is known on the host;
    # summing the device tensor and calling .item() would force a needless
    # device-to-host synchronisation.
    total_m = group_size * batch_size

    a = torch.randn(total_m, k, device=device, dtype=dtype, requires_grad=True)
    b_shape = (group_size, n, k) if trans_b else (group_size, k, n)
    b = torch.randn(*b_shape, device=device, dtype=dtype, requires_grad=True)
    return a, b, batch_sizes
| 63 | + |
| 64 | + |
| 65 | +def _run_baseline( |
| 66 | + mat_a_flat: torch.Tensor, mat_b: torch.Tensor, batch_sizes: torch.Tensor, trans_b: bool |
| 67 | +) -> torch.Tensor: |
| 68 | + batch_sizes_list = batch_sizes.tolist() |
| 69 | + outputs = [] |
| 70 | + start = 0 |
| 71 | + for i, size in enumerate(batch_sizes_list): |
| 72 | + rhs = mat_b[i].t() if trans_b else mat_b[i] |
| 73 | + outputs.append(mat_a_flat[start : start + size] @ rhs) |
| 74 | + start += size |
| 75 | + if outputs: |
| 76 | + return torch.cat(outputs, dim=0) |
| 77 | + return mat_a_flat.new_zeros((0, mat_b.shape[1] if trans_b else mat_b.shape[2])) |
| 78 | + |
| 79 | + |
| 80 | +def _run_npu_gmm_v2( |
| 81 | + mat_a_flat: torch.Tensor, mat_b: torch.Tensor, batch_sizes: torch.Tensor, trans_b: bool |
| 82 | +) -> torch.Tensor: |
| 83 | + try: |
| 84 | + from mindspeed.ops.gmm import npu_gmm_v2 |
| 85 | + except Exception as exc: |
| 86 | + raise ImportError("from mindspeed.ops.gmm import npu_gmm_v2 failed.") from exc |
| 87 | + |
| 88 | + if mat_a_flat.shape[0] == 0: |
| 89 | + return mat_a_flat.new_zeros((0, mat_b.shape[1] if trans_b else mat_b.shape[2])) |
| 90 | + |
| 91 | + weight = mat_b.transpose(-1, -2) if trans_b else mat_b |
| 92 | + if batch_sizes.device.type != "npu": |
| 93 | + batch_sizes = batch_sizes.to(device=mat_a_flat.device) |
| 94 | + batch_sizes = batch_sizes.to(dtype=torch.int64) |
| 95 | + return npu_gmm_v2(mat_a_flat, weight, bias=None, group_list=batch_sizes, group_type=0) |
| 96 | + |
| 97 | + |
| 98 | +def _time_forward(fn, warmup: int, iters: int) -> tuple[float, torch.Tensor]: |
| 99 | + out = None |
| 100 | + for _ in range(warmup): |
| 101 | + with torch.no_grad(): |
| 102 | + out = fn() |
| 103 | + _sync() |
| 104 | + |
| 105 | + start = time.perf_counter() |
| 106 | + for _ in range(iters): |
| 107 | + with torch.no_grad(): |
| 108 | + out = fn() |
| 109 | + _sync() |
| 110 | + return (time.perf_counter() - start) * 1000.0 / iters, out |
| 111 | + |
| 112 | + |
| 113 | +def _time_forward_backward( |
| 114 | + fn, a: torch.Tensor, b: torch.Tensor, warmup: int, iters: int |
| 115 | +) -> tuple[float, torch.Tensor, torch.Tensor, torch.Tensor]: |
| 116 | + out = None |
| 117 | + for _ in range(warmup): |
| 118 | + out = fn() |
| 119 | + out.sum().backward() |
| 120 | + a.grad = None |
| 121 | + b.grad = None |
| 122 | + _sync() |
| 123 | + |
| 124 | + start = time.perf_counter() |
| 125 | + for _ in range(iters): |
| 126 | + out = fn() |
| 127 | + out.sum().backward() |
| 128 | + grad_a = a.grad.detach().clone() |
| 129 | + grad_b = b.grad.detach().clone() |
| 130 | + a.grad = None |
| 131 | + b.grad = None |
| 132 | + _sync() |
| 133 | + total_ms = (time.perf_counter() - start) * 1000.0 / iters |
| 134 | + return total_ms, out.detach().clone(), grad_a, grad_b |
| 135 | + |
| 136 | + |
| 137 | +def _max_abs_diff(x: torch.Tensor, y: torch.Tensor) -> float: |
| 138 | + return float((x.float() - y.float()).abs().max().item()) |
| 139 | + |
| 140 | + |
| 141 | +def _check_close(x: torch.Tensor, y: torch.Tensor, rtol: float, atol: float) -> bool: |
| 142 | + try: |
| 143 | + torch.testing.assert_close(x, y, rtol=rtol, atol=atol) |
| 144 | + return True |
| 145 | + except Exception: |
| 146 | + return False |
| 147 | + |
| 148 | + |
def _bench_one(params: dict[str, object], rtol: float, atol: float):
    """Benchmark one configuration and print timings and accuracy.

    The per-group matmul baseline and ``npu_gmm_v2`` are each timed forward
    only and forward+backward on identical inputs. Printed output: a header
    describing the configuration, a timing table, the speedups relative to
    the baseline, and a closeness report for the forward outputs and the two
    gradients.

    Args:
        params: Configuration mapping (see ``DEFAULT_PARAM_SETS``).
        rtol: Relative tolerance used for the closeness checks.
        atol: Absolute tolerance used for the closeness checks.

    Raises:
        RuntimeError: If ``torch.npu`` is missing or reports no device.
    """
    npu_ready = hasattr(torch, "npu") and torch.npu.is_available()
    if not npu_ready:
        raise RuntimeError("NPU is not available.")

    device = torch.device("npu")
    warmup = int(params["warmup"])
    iters = int(params["iters"])
    trans_b = bool(params["trans_b"])

    a_base, b_base, batch_sizes = _build_inputs(params, device)
    # Independent leaf copies so each backend accumulates its own gradients.
    a_npu = a_base.detach().clone().requires_grad_(True)
    b_npu = b_base.detach().clone().requires_grad_(True)

    def baseline_step():
        return _run_baseline(a_base, b_base, batch_sizes, trans_b)

    def npu_step():
        return _run_npu_gmm_v2(a_npu, b_npu, batch_sizes, trans_b)

    fw_ms_base, ref_out = _time_forward(baseline_step, warmup=warmup, iters=iters)
    total_ms_base, ref_out_bw, ref_da, ref_db = _time_forward_backward(
        baseline_step, a=a_base, b=b_base, warmup=warmup, iters=iters
    )

    fw_ms_npu, out_npu = _time_forward(npu_step, warmup=warmup, iters=iters)
    total_ms_npu, out_npu_bw, da_npu, db_npu = _time_forward_backward(
        npu_step, a=a_npu, b=b_npu, warmup=warmup, iters=iters
    )

    # Backward time is derived: (forward + backward) minus forward only.
    bw_ms_base = total_ms_base - fw_ms_base
    bw_ms_npu = total_ms_npu - fw_ms_npu

    header = (
        f"[npu_grouped_gemm] name={params['name']} group={params['group_size']} "
        f"batch={params['batch_size']} k={params['k']} n={params['n']} "
        f"dtype={params['dtype']} trans_b={trans_b}"
    )
    print(header)
    print("backend, fw_ms, bw_ms, total_ms")
    print(f"baseline, {fw_ms_base:.3f}, {bw_ms_base:.3f}, {total_ms_base:.3f}")
    print(f"npu_gmm_v2, {fw_ms_npu:.3f}, {bw_ms_npu:.3f}, {total_ms_npu:.3f}")
    print(
        "speedup_vs_baseline, "
        f"fw={fw_ms_base / fw_ms_npu:.2f}x, "
        f"bw={bw_ms_base / bw_ms_npu:.2f}x, "
        f"total={total_ms_base / total_ms_npu:.2f}x"
    )

    print("metric, close, max_abs_diff")
    comparisons = (
        ("forward", out_npu, ref_out),
        ("forward_bw_run", out_npu_bw, ref_out_bw),
        ("grad_a", da_npu, ref_da),
        ("grad_b", db_npu, ref_db),
    )
    for label, actual, expected in comparisons:
        is_close = _check_close(actual, expected, rtol, atol)
        worst = _max_abs_diff(actual, expected)
        print(f"{label}, {is_close}, {worst:.6f}")
| 209 | + |
| 210 | + |
def main() -> int:
    """Parse the tolerance flags and benchmark every default configuration.

    Accepts ``--rtol`` and ``--atol`` (floats, both defaulting to ``1e-2``),
    selects NPU device 0, then runs ``_bench_one`` for each entry of
    ``DEFAULT_PARAM_SETS``.

    Returns:
        ``0``, used as the process exit status.
    """
    cli = argparse.ArgumentParser()
    for flag in ("--rtol", "--atol"):
        cli.add_argument(flag, type=float, default=1e-2)
    opts = cli.parse_args()

    torch.npu.set_device(0)
    for config in DEFAULT_PARAM_SETS:
        _bench_one(config, rtol=opts.rtol, atol=opts.atol)
    return 0


if __name__ == "__main__":
    raise SystemExit(main())
0 commit comments