Skip to content

Commit 69b8876

Browse files
[mxfp8 training] cuda kernel for unpadding token groups
stack-info: PR: #4021, branch: danielvegamyhre/stack/148
1 parent 17fd81e commit 69b8876

File tree

10 files changed

+760
-61
lines changed

10 files changed

+760
-61
lines changed
Lines changed: 258 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,258 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD 3-Clause license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
# this benchmarking script is a modified version of the original script from: https://github.com/drisspg/transformer_nuggets/blob/main/transformer_nuggets/utils/benchmark.py
7+
8+
import argparse
9+
import itertools
10+
import time
11+
from dataclasses import dataclass
12+
from typing import List
13+
14+
import torch
15+
from tabulate import tabulate
16+
from tqdm import tqdm
17+
18+
from benchmarks.utils import profile_fn
19+
from torchao.prototype.moe_training.kernels.mxfp8 import (
20+
_mxfp8_cuda_kernels_available,
21+
fused_unpad_token_groups_cuda,
22+
torch_pad_token_groups,
23+
torch_unpad_token_groups,
24+
)
25+
from torchao.prototype.moe_training.utils import generate_jagged_offs
26+
27+
# All benchmarks run on the current CUDA device.
device = torch.device("cuda")

# Needed since changing args to function causes recompiles
torch._dynamo.config.cache_size_limit = 1000
31+
32+
33+
@dataclass(frozen=True)
class ExperimentConfig:
    """One benchmark configuration for the token-group unpad kernels."""

    # Total number of input tokens (rows of the activation matrix).
    num_tokens: int
    # Hidden dimension (columns of the activation matrix).
    dim: int
    # Number of token groups (e.g. experts in an MoE layer).
    num_groups: int
    # Each group's padded length is rounded up to a multiple of this value.
    alignment_size: int
40+
41+
@dataclass(frozen=True)
class ExperimentResult:
    """Timings and derived memory bandwidth for one configuration."""

    # Host-side latency of the eager torch implementation, microseconds.
    torch_eager_time_us: float
    # Host-side latency of the fused CUDA kernel, microseconds
    # (inf when the CUDA kernel is unavailable on this machine).
    cuda_time_us: float
    # Effective memory bandwidth of the torch implementation, GB/s.
    torch_mem_bw_gbps: float
    # Effective memory bandwidth of the CUDA kernel, GB/s (0.0 if unavailable).
    cuda_mem_bw_gbps: float
47+
48+
49+
@dataclass(frozen=True)
class Experiment:
    """Pairs a configuration with the result measured for it."""

    config: ExperimentConfig
    result: ExperimentResult
53+
54+
55+
def get_configs() -> List[ExperimentConfig]:
    """Build the full benchmark grid over token counts, dims, group counts,
    and alignment sizes (cartesian product of the lists below)."""
    # Various token group sizes and dimensions
    num_tokens_list = [16384]
    dim_list = [1536, 2048, 5120, 7168]
    num_groups_list = [1, 4, 8, 16]
    alignment_size_list = [32]

    return [
        ExperimentConfig(
            num_tokens=num_tokens,
            dim=dim,
            num_groups=num_groups,
            alignment_size=alignment_size,
        )
        for num_tokens, dim, num_groups, alignment_size in itertools.product(
            num_tokens_list, dim_list, num_groups_list, alignment_size_list
        )
    ]
75+
76+
77+
def benchmark_host_side_in_microseconds(fn, *args, num_iters=100, **kwargs):
    """
    Benchmark `fn(*args, **kwargs)` with host-side wall-clock timing.

    Synchronizes the CUDA device before and after the timed loop so queued GPU
    work is accounted for; the measurement therefore includes any per-call
    host overhead such as output-buffer allocation. Returns the mean latency
    per iteration in microseconds.
    """
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    for _ in range(num_iters):
        fn(*args, **kwargs)
    torch.cuda.synchronize()
    t1 = time.perf_counter()
    # Mean seconds per iteration, converted to microseconds.
    return (t1 - t0) / num_iters * 1e6
88+
89+
90+
def run_experiment(
    config: ExperimentConfig, args: argparse.Namespace
) -> ExperimentResult:
    """Benchmark torch-eager vs. fused-CUDA token-group unpadding for one config.

    Builds random bf16 activations, pads them per group, then times the two
    unpad implementations and derives effective memory bandwidth for each.
    With --profile, additionally captures PyTorch profiler traces.
    """
    num_tokens, dim, num_groups, alignment_size = (
        config.num_tokens,
        config.dim,
        config.num_groups,
        config.alignment_size,
    )

    # Create inputs and pad them first
    inputs = torch.randn(num_tokens, dim, dtype=torch.bfloat16, device=device)
    # multiple_of=1 produces fully jagged group boundaries, so padding is
    # actually required for most groups.
    group_offsets = generate_jagged_offs(
        num_groups, num_tokens, multiple_of=1, device=device
    )

    # Pad the inputs to get padded tensors for unpad benchmark
    padded_inputs, padded_group_end_offsets = torch_pad_token_groups(
        inputs, group_offsets, alignment_size
    )

    # Compute padded group start offsets.
    # group_offsets holds cumulative end indices, so diff (with a prepended 0)
    # recovers per-group sizes.
    group_sizes = torch.diff(
        group_offsets,
        prepend=torch.zeros(1, dtype=group_offsets.dtype, device=group_offsets.device),
    )
    # Round each group size up to the next multiple of alignment_size.
    padded_sizes = (
        (group_sizes + alignment_size - 1) // alignment_size
    ) * alignment_size
    padded_group_start_offsets = padded_group_end_offsets - padded_sizes

    def torch_eager_with_offsets():
        # Closure so the benchmark harness times only the unpad call itself.
        return torch_unpad_token_groups(
            padded_inputs, group_offsets, padded_group_start_offsets, alignment_size
        )

    def warmup(fn):
        # A few untimed iterations to absorb one-time setup costs.
        for _ in range(5):
            fn()

    # bench torch eager (includes buffer allocation overhead)
    warmup(torch_eager_with_offsets)
    torch_eager_time_us = benchmark_host_side_in_microseconds(torch_eager_with_offsets)
    if args.profile:
        profile_fn(
            torch_unpad_token_groups,
            padded_inputs,
            group_offsets,
            padded_group_start_offsets,
            alignment_size,
            profile_name="torch_unpad_token_groups_eager",
        )

    # bench CUDA kernel if available
    if _mxfp8_cuda_kernels_available:

        def cuda_with_offsets():
            return fused_unpad_token_groups_cuda(
                padded_inputs,
                group_offsets,
                padded_group_start_offsets,
                num_tokens,
                alignment_size,
            )

        warmup(cuda_with_offsets)
        cuda_time_us = benchmark_host_side_in_microseconds(cuda_with_offsets)
        if args.profile:
            profile_fn(
                fused_unpad_token_groups_cuda,
                padded_inputs,
                group_offsets,
                padded_group_start_offsets,
                num_tokens,
                alignment_size,
                profile_name="fused_unpad_token_groups_cuda",
            )
    else:
        cuda_time_us = float("inf")  # Not available

    # mem bw calculations
    bytes_per_el = torch.finfo(torch.bfloat16).bits / 8

    # NOTE(review): the offset tensors are assumed to be int32 (4 bytes each)
    # here — confirm against generate_jagged_offs / torch_pad_token_groups.
    read_bytes = (
        padded_inputs.numel() * bytes_per_el  # Read padded input tokens
        + group_offsets.numel() * 4  # Read group offsets (int32)
        + padded_group_start_offsets.numel() * 4  # Read padded start offsets (int32)
    )

    write_bytes = (
        inputs.numel() * bytes_per_el  # Write unpadded data
    )

    total_bytes = read_bytes + write_bytes

    # GB/s = (bytes / 1e9) / seconds
    torch_mem_bw_gbps = (total_bytes / 1e9) / (torch_eager_time_us / 1e6)

    if _mxfp8_cuda_kernels_available and cuda_time_us != float("inf"):
        cuda_mem_bw_gbps = (total_bytes / 1e9) / (cuda_time_us / 1e6)
    else:
        cuda_mem_bw_gbps = 0.0

    return ExperimentResult(
        torch_eager_time_us=torch_eager_time_us,
        cuda_time_us=cuda_time_us,
        torch_mem_bw_gbps=torch_mem_bw_gbps,
        cuda_mem_bw_gbps=cuda_mem_bw_gbps,
    )
198+
199+
200+
def print_results(experiments: List[Experiment]):
    """Print a tabulated summary of all experiments, one row per config.

    The CUDA columns show "N/A" when the fused kernel was unavailable
    (cuda_time_us == inf / cuda_mem_bw_gbps == 0).
    """
    headers = [
        "num_tokens",
        "dim",
        "num_groups",
        "torch_us",
        "cuda_us",
        "torch_mem_bw_gbps",
        "cuda_mem_bw_gbps",
        "cuda_vs_torch",
    ]

    def _row(exp):
        cfg, res = exp.config, exp.result
        t_cuda = res.cuda_time_us
        if t_cuda != float("inf") and t_cuda > 0:
            speedup = f"{res.torch_eager_time_us / t_cuda:.2f}x"
        else:
            speedup = "N/A"
        if res.cuda_mem_bw_gbps > 0:
            cuda_bw = f"{res.cuda_mem_bw_gbps:.2f}"
        else:
            cuda_bw = "N/A"
        return [
            cfg.num_tokens,
            cfg.dim,
            cfg.num_groups,
            res.torch_eager_time_us,
            res.cuda_time_us,
            f"{res.torch_mem_bw_gbps:.2f}",
            cuda_bw,
            speedup,
        ]

    print(tabulate([_row(exp) for exp in experiments], headers=headers))
238+
239+
240+
def main(args: argparse.Namespace):
    """Run every benchmark configuration and print the summary table."""
    # Fixed seed so input tensors are reproducible across runs.
    torch.random.manual_seed(123)
    results = [
        Experiment(config=config, result=run_experiment(config, args))
        for config in tqdm(get_configs())
    ]

    # Use Tabulate to print results
    print_results(results)
250+
251+
252+
if __name__ == "__main__":
    # CLI entry point; --profile additionally captures PyTorch profiler traces.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--profile", action="store_true", help="Enable profiling with PyTorch profiler"
    )
    args = parser.parse_args()
    main(args)

test/prototype/moe_training/test_kernels.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,14 @@
2323
from torchao.prototype.moe_training.kernels.mxfp8 import (
2424
_mxfp8_cuda_kernels_available,
2525
fused_pad_token_groups_cuda,
26+
fused_unpad_token_groups_cuda,
2627
mx_block_rearrange_2d_M_groups_cuda,
2728
mxfp8_quantize_cuda_3d,
2829
torch_pad_token_groups,
2930
torch_to_blocked_2d_K_groups,
3031
torch_to_blocked_2d_M_groups,
3132
torch_to_blocked_per_group_3d,
33+
torch_unpad_token_groups,
3234
triton_mx_block_rearrange_2d_K_groups,
3335
triton_mx_block_rearrange_2d_M_groups,
3436
triton_mx_block_rearrange_per_group_3d,
@@ -452,3 +454,67 @@ def test_cuda_fused_pad_token_groups(
452454
assert torch.equal(ref_padded_offsets, kernel_padded_offsets), (
453455
"Padded group offsets do not match"
454456
)
457+
458+
459+
@pytest.mark.skipif(
    not _mxfp8_cuda_kernels_available,
    reason="CUDA kernel requires sm_100 and CUDA 12.8+",
)
@skip_if_rocm("ROCm enablement in progress")
@pytest.mark.parametrize("num_tokens", [128, 157, 4096])
@pytest.mark.parametrize("dim", [7168])
@pytest.mark.parametrize("num_groups", [1, 2, 4, 8])
@pytest.mark.parametrize("alignment_size", [32])
@pytest.mark.parametrize("dtype", [torch.bfloat16])
def test_cuda_fused_unpad_token_groups(
    num_tokens: int, dim: int, num_groups: int, alignment_size: int, dtype: torch.dtype
):
    """Test fused_unpad_token_groups_cuda kernel for removing padding from token groups."""
    device = "cuda"

    # Create input activations
    inputs = torch.randn(num_tokens, dim, dtype=dtype, device=device)

    # Generate group offsets (end indices for each group); multiple_of=1
    # makes the group boundaries jagged so padding is genuinely exercised.
    group_offsets = generate_jagged_offs(
        num_groups, num_tokens, multiple_of=1, device=device
    )

    # First pad the tokens to create padded inputs
    padded_tokens, padded_group_end_offsets = torch_pad_token_groups(
        inputs, group_offsets, alignment_size
    )

    # Compute padded group start offsets.
    # group_offsets is cumulative, so diff with a prepended zero gives sizes.
    group_sizes = torch.diff(
        group_offsets,
        prepend=torch.zeros(1, dtype=group_offsets.dtype, device=group_offsets.device),
    )
    # Round each group size up to the next multiple of alignment_size.
    padded_sizes = (
        (group_sizes + alignment_size - 1) // alignment_size
    ) * alignment_size
    padded_group_start_offsets = padded_group_end_offsets - padded_sizes

    # Get reference output using torch implementation
    ref_unpadded_tokens = torch_unpad_token_groups(
        padded_tokens, group_offsets, padded_group_start_offsets, alignment_size
    )

    # Run CUDA kernel
    kernel_unpadded_tokens = fused_unpad_token_groups_cuda(
        padded_tokens,
        group_offsets,
        padded_group_start_offsets,
        num_tokens,
        alignment_size,
    )

    # Verify outputs match
    assert torch.allclose(
        ref_unpadded_tokens, kernel_unpadded_tokens, rtol=0, atol=1e-5
    ), "Unpadded tokens do not match"

    # Verify that unpad correctly reverses pad operation
    assert torch.allclose(inputs, kernel_unpadded_tokens, rtol=0, atol=1e-5), (
        "Unpadded tokens should match original inputs"
    )

test/prototype/moe_training/test_mxfp8_grouped_mm.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,7 @@ def test_emulate_mxfp8_grouped_gemm_2d_2d(M, N, num_experts):
126126
@pytest.mark.parametrize("num_experts", (1, 8))
127127
@pytest.mark.parametrize("wgrad_with_hp", (True, False))
128128
@pytest.mark.parametrize("use_compile", (False, True))
129+
@pytest.mark.parametrize("pad_token_groups_for_grouped_mm", (False, True))
129130
@pytest.mark.parametrize(
130131
"kernel_preference", (KernelPreference.AUTO, KernelPreference.EMULATED)
131132
)
@@ -141,12 +142,21 @@ def test_mxfp8_grouped_gemm_with_dq_fwd_bwd(
141142
use_compile,
142143
kernel_preference,
143144
scale_mode,
145+
pad_token_groups_for_grouped_mm,
144146
):
145147
# MXFP8 hardware path requires SM100
146148
if kernel_preference != KernelPreference.EMULATED and not is_sm_version(10, 0):
147149
pytest.skip(
148150
f"Skipping MXFP8 hardware mode tests, only supported on compute capability 10.0 and found {torch.cuda.get_device_capability()}"
149151
)
152+
if (
153+
kernel_preference == KernelPreference.EMULATED
154+
and use_compile
155+
and pad_token_groups_for_grouped_mm
156+
):
157+
pytest.skip(
158+
f"torch native dynamic per group pad/unpad functions do not work with torch.compile yet: https://github.com/pytorch/pytorch/issues/176770"
159+
)
150160

151161
block_size = 32
152162
x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda", requires_grad=True)
@@ -158,7 +168,9 @@ def test_mxfp8_grouped_gemm_with_dq_fwd_bwd(
158168
device="cuda",
159169
)
160170
w_t = w.transpose(-2, -1).requires_grad_(True)
161-
offs = generate_jagged_offs(num_experts, M, multiple_of=block_size)
171+
172+
multiple_of = 1 if pad_token_groups_for_grouped_mm else 32
173+
offs = generate_jagged_offs(num_experts, M, multiple_of=multiple_of)
162174
x_ref, w_t_ref, offs_ref = (
163175
x.clone().detach().requires_grad_(True),
164176
w_t.clone().detach().requires_grad_(True),
@@ -179,6 +191,7 @@ def test_mxfp8_grouped_gemm_with_dq_fwd_bwd(
179191
kernel_preference=kernel_preference,
180192
wgrad_with_hp=wgrad_with_hp,
181193
scale_calculation_mode=scale_mode,
194+
pad_token_groups_for_grouped_mm=pad_token_groups_for_grouped_mm,
182195
)
183196
ref_out = torch._grouped_mm(x_ref, w_t_ref, offs=offs_ref, out_dtype=torch.bfloat16)
184197
sqnr = compute_error(ref_out, out)

0 commit comments

Comments
 (0)