Skip to content

Commit 5a07c6e

Browse files
jwfromm authored and facebook-github-bot committed
Simplified NVFP4 quantize kernel for Torch API (#152)
Summary: This diff reworks the mslk nvfp4 stacked quantize kernel to hopefully be a bit simpler. As can be seen in gemm_ops.py, the new op minimizes extra artifacts needed for using the torch api for fp4fp4bf16_grouped_mm. This kernel is as performant as the mega kernel and hopefully robust, as shown in the added tests. Reviewed By: jiawenliu64 Differential Revision: D93169309
1 parent 1f190ee commit 5a07c6e

File tree

6 files changed

+547
-393
lines changed

6 files changed

+547
-393
lines changed

bench/gemm/gemm_bench.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -388,9 +388,13 @@ def benchmark_grouped(
388388
tflops = 2 * m[i] * n[i] * k[i] / (ms_runtime / 1e3) / 1e12
389389
gbps = (
390390
(
391-
m[i] * k[i] * quantized_vals[0][0].element_size()
392-
+ n[i] * k[i] * quantized_vals[1][0].element_size()
393-
+ output_multiplier * m[i] * n[i] * output[0].element_size()
391+
quantized_vals[0][i].numel()
392+
* quantized_vals[0][i].element_size()
393+
+ quantized_vals[1][i].numel()
394+
* quantized_vals[1][i].element_size()
395+
+ output_multiplier
396+
* output[i].numel()
397+
* output[i].element_size()
394398
)
395399
/ (ms_runtime / 1e3)
396400
/ 1e9

bench/gemm/gemm_ops.py

Lines changed: 118 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
mega_fp4_pack,
2020
mega_fp4_quantize_kernel,
2121
mega_fp4_unpack,
22+
nvfp4_quantize_stacked,
2223
triton_quantize_mx4_unpack,
2324
triton_quantize_nvfp4,
2425
)
@@ -1237,11 +1238,6 @@ class FP8RowwiseGrouped(GemmOpBase):
12371238
FP8 grouped matmul with rowwise scaling.
12381239
"""
12391240

1240-
@property
1241-
def name(self) -> str:
1242-
prefix = "Cutlass" if torch.version.cuda else "CK"
1243-
return f"{prefix}{self.__class__.__name__}"
1244-
12451241
def preprocess(self, x, w):
12461242
m_values = [i.shape[0] for i in x]
12471243
m_sizes = torch.tensor(m_values).to(dtype=torch.int64, device=x[0].device)
@@ -2429,50 +2425,28 @@ class CutlassNVFP4GroupwiseGrouped(GemmOpBase):
24292425
def preprocess(self, x, w):
24302426
m_values = [i.shape[0] for i in x]
24312427
m_sizes = torch.tensor(m_values).to(dtype=torch.int64, device=x[0].device)
2432-
x = torch.concat(x, dim=0).contiguous()
2428+
x_cat = torch.concat(x, dim=0).contiguous()
24332429

2434-
def get_global_scale(x, w, m_sizes):
2435-
G = len(w)
2436-
w_global_scale = []
2437-
global_scale = []
2430+
G = m_sizes.numel()
24382431

2439-
cumulative_sum = torch.zeros(
2440-
m_sizes.shape[0] + 1, dtype=torch.int64, device=m_sizes.device
2432+
# w_global_scale is static (weights don't change)
2433+
w_global_scale = []
2434+
for i in range(G):
2435+
w_gs = (448.0 * 6.0) / torch.amax(torch.abs(w[i].flatten()), dim=-1).to(
2436+
torch.float32
24412437
)
2442-
cumulative_sum[1:] = torch.cumsum(m_sizes, dim=0)
2443-
2444-
x_global_scale, tensor_idx = calculate_group_max(x, m_sizes=m_sizes)
2445-
2446-
for i in range(G):
2447-
w_global_scale_ = (448.0 * 6.0) / torch.amax(
2448-
torch.abs(w[i].flatten()), dim=-1
2449-
).to(torch.float32)
2450-
2451-
global_scale_ = 1 / (x_global_scale[i] * w_global_scale_)
2452-
2453-
w_global_scale.append(w_global_scale_)
2454-
global_scale.append(global_scale_)
2455-
2456-
return x_global_scale, w_global_scale, global_scale, tensor_idx
2457-
2458-
# Compute global scale for each group
2459-
G = m_sizes.numel()
2460-
x_global_scale, w_global_scale, global_scale, tensor_idx = get_global_scale(
2461-
x, w, m_sizes
2462-
)
2463-
global_scale = torch.stack(global_scale, dim=0).contiguous()
2438+
w_global_scale.append(w_gs)
2439+
w_global_scale = torch.stack(w_global_scale, dim=0).contiguous()
24642440

24652441
wq, w_scale = zip(
24662442
*[triton_quantize_nvfp4(w[i], w_global_scale[i]) for i in range(G)]
24672443
)
24682444
wq = torch.stack(wq, dim=0).contiguous()
24692445
w_scale = torch.stack(w_scale, dim=0).contiguous()
24702446

2471-
return x, wq, w_scale, x_global_scale, global_scale, m_sizes, tensor_idx
2447+
return x_cat, wq, w_scale, w_global_scale, m_sizes
24722448

2473-
def quantize(
2474-
self, x, wq, w_scale, x_global_scale, global_scale, m_sizes, tensor_idx
2475-
):
2449+
def quantize(self, x, wq, w_scale, w_global_scale, m_sizes):
24762450
# alternative methods, may be useful in some scenarios
24772451
"""
24782452
starting_row_after_padding, belong_indices, row_within_tensor = (
@@ -2489,6 +2463,10 @@ def quantize(
24892463
)
24902464
"""
24912465

2466+
x_global_scale, tensor_idx = calculate_group_max(x, m_sizes=m_sizes)
2467+
2468+
global_scale = 1.0 / (x_global_scale * w_global_scale)
2469+
24922470
# we can optionally set optional_tensor_idx to None to run the alternative method
24932471
xq, x_scale, starting_row_after_padding = mega_fp4_quantize_kernel(
24942472
m_sizes, x, x_global_scale, optional_tensor_idx=tensor_idx
@@ -2527,9 +2505,7 @@ def compute(
25272505
)
25282506
return gemm_result
25292507

2530-
def quantize_and_compute(
2531-
self, x, wq, w_scale, x_global_scale, global_scale, m_sizes, tensor_idx
2532-
):
2508+
def quantize_and_compute(self, x, wq, w_scale, w_global_scale, m_sizes):
25332509
(
25342510
xq,
25352511
wq,
@@ -2538,9 +2514,7 @@ def quantize_and_compute(
25382514
m_sizes,
25392515
global_scale,
25402516
starting_row_after_padding,
2541-
) = self.quantize(
2542-
x, wq, w_scale, x_global_scale, global_scale, m_sizes, tensor_idx
2543-
)
2517+
) = self.quantize(x, wq, w_scale, w_global_scale, m_sizes)
25442518
return self.compute(
25452519
xq,
25462520
wq,
@@ -2564,6 +2538,105 @@ def compute_dtype(self) -> ComputeDtype:
25642538
return ComputeDtype.FP4
25652539

25662540

2541+
@register_gemm_op
2542+
class CutlassNVFP4TorchGrouped(GemmOpBase):
2543+
"""
2544+
NVFP4 grouped matmul using per-expert global scales for activation
2545+
quantization (stacked_nvfp4_quantize), with per-expert alpha scales
2546+
applied post-GEMM via the torch offsets API.
2547+
"""
2548+
2549+
def preprocess(self, x, w):
2550+
m_values = [i.shape[0] for i in x]
2551+
m_sizes = torch.tensor(m_values).to(dtype=torch.int64, device=x[0].device)
2552+
x_cat = torch.concat(x, dim=0).contiguous()
2553+
2554+
G = m_sizes.numel()
2555+
N_per_expert = w[0].shape[0]
2556+
K = w[0].shape[1]
2557+
2558+
# Batch-quantize all expert weights in one shot using stacked kernel
2559+
w_cat = torch.cat(w, dim=0).contiguous() # [G*N, K]
2560+
w_m_sizes = torch.full(
2561+
(G,), N_per_expert, dtype=torch.int64, device=w_cat.device
2562+
)
2563+
w_global_scale, _ = calculate_group_max(w_cat, w_m_sizes)
2564+
wq, w_scale_2d = nvfp4_quantize_stacked(w_m_sizes, w_cat, w_global_scale)
2565+
2566+
# Reshape to [G, N, ...] for the GEMM
2567+
wq = wq.view(G, N_per_expert, K // 2)
2568+
padded_N = (N_per_expert + 127) // 128 * 128
2569+
w_scale = w_scale_2d[: G * padded_N].view(G, padded_N, -1)
2570+
2571+
# Precompute offsets for the torch API (cumulative end indices, int32)
2572+
offsets = torch.cumsum(m_sizes, dim=0).to(torch.int32)
2573+
2574+
return x_cat, wq, w_scale, w_global_scale, m_sizes, offsets
2575+
2576+
def quantize(self, x, wq, w_scale, w_global_scale, m_sizes, offsets):
2577+
x_global_scale, _ = calculate_group_max(x, m_sizes=m_sizes)
2578+
# global_scale = 1 / (x_gs * w_gs) per expert
2579+
global_scale = 1.0 / (x_global_scale * w_global_scale)
2580+
2581+
xq, x_scale = nvfp4_quantize_stacked(m_sizes, x, x_global_scale)
2582+
return (
2583+
xq,
2584+
wq,
2585+
x_scale,
2586+
w_scale,
2587+
global_scale,
2588+
offsets,
2589+
)
2590+
2591+
def compute(
2592+
self,
2593+
xq,
2594+
wq,
2595+
x_scale,
2596+
w_scale,
2597+
global_scale,
2598+
offsets,
2599+
):
2600+
return torch.ops.mslk.f4f4bf16_grouped_mm(
2601+
xq,
2602+
wq.transpose(-2, -1),
2603+
x_scale,
2604+
w_scale,
2605+
offsets,
2606+
global_scale=global_scale,
2607+
)
2608+
2609+
def quantize_and_compute(self, x, wq, w_scale, w_global_scale, m_sizes, offsets):
2610+
(
2611+
xq,
2612+
wq,
2613+
x_scale,
2614+
w_scale,
2615+
global_scale,
2616+
offsets,
2617+
) = self.quantize(x, wq, w_scale, w_global_scale, m_sizes, offsets)
2618+
return self.compute(
2619+
xq,
2620+
wq,
2621+
x_scale,
2622+
w_scale,
2623+
global_scale,
2624+
offsets,
2625+
)
2626+
2627+
@property
2628+
def supported_accelerators(self) -> set[Accelerator]:
2629+
return {Accelerator.NVIDIA_SM100, Accelerator.NVIDIA_SM103}
2630+
2631+
@property
2632+
def supported_gemm_types(self) -> set[GemmType]:
2633+
return {GemmType.GROUPED}
2634+
2635+
@property
2636+
def compute_dtype(self) -> ComputeDtype:
2637+
return ComputeDtype.FP4
2638+
2639+
25672640
# Broken with cuda graph
25682641
# @register_gemm_op
25692642
class CutlassNVFP4GroupwiseStackedGroupedPackUnpack(GemmOpBase):
@@ -2761,7 +2834,7 @@ def compute(self, x, w, offs):
27612834
)
27622835

27632836
def quantize_and_compute(self, x, w, offs):
2764-
x, w, offs = self.quantize(x, w)
2837+
x, w, offs = self.quantize(x, w, offs)
27652838
return self.compute(x, w, offs)
27662839

27672840
@property

bench/quantize/quantize_bench.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,7 @@ def benchmark(
149149
k: int,
150150
mem_bw_roofline_gbps: float,
151151
opts: BenchOptions,
152+
num_groups: int = 1,
152153
) -> list[Metrics]:
153154
# Create input tensors.
154155
input = torch.randn(m, k, device="cuda", dtype=torch.bfloat16)
@@ -158,7 +159,7 @@ def benchmark(
158159
# Benchmark each operator.
159160
for quantize_op in quantize_ops:
160161
metrics = Metrics(op=quantize_op.name, M=m, K=k)
161-
args = quantize_op.preprocess(input)
162+
args = quantize_op.preprocess(input, num_groups=num_groups)
162163
quantized = quantize_op.quantize(input, *args)
163164
dequantized = quantize_op.dequantize(*quantized)
164165
metrics.sim = torch.mean(torch.pow(dequantized - input, 2)).item()
@@ -223,6 +224,12 @@ def print_kernels(kernels: Optional[list[str]]) -> None:
223224
is_flag=True,
224225
help="If set, instead of benchmarking cartesian product of M * K, benchmark consecutive MK pairs together.",
225226
)
227+
@click.option(
228+
"--num-groups",
229+
default=1,
230+
type=int,
231+
help="Number of groups (experts) to split M across for grouped/MoE quantize ops.",
232+
)
226233
def invoke_main(
227234
output_dir: str,
228235
num_iters: int,
@@ -236,6 +243,7 @@ def invoke_main(
236243
shapes: Optional[str],
237244
trace: bool,
238245
rep_ms: int,
246+
num_groups: int,
239247
) -> None:
240248
# If kernel filter is provided, parse it. Else, benchmark all kernels.
241249
all_kernels = kernels.strip().split(",") if kernels else None
@@ -271,6 +279,7 @@ def invoke_main(
271279
K,
272280
mem_bw_roofline_gbps,
273281
opts,
282+
num_groups=num_groups,
274283
)
275284
benchmark_results.extend(quantize_measurements)
276285
csv_row = {}

0 commit comments

Comments (0)