Skip to content

Commit 26ae5c7

Browse files
[mxfp8 training] cuda kernel for unpadding token groups
stack-info: PR: #4021, branch: danielvegamyhre/stack/148
1 parent 17fd81e commit 26ae5c7

File tree

14 files changed

+903
-174
lines changed

14 files changed

+903
-174
lines changed

benchmarks/prototype/moe_training/mxfp8/bench_pad_token_groups.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,9 @@ def torch_eager_with_offsets():
102102
group_offsets = generate_jagged_offs(
103103
num_groups, num_tokens, multiple_of=1, device=device
104104
)
105-
return torch_pad_token_groups(inputs, group_offsets, alignment_size)
105+
return torch_pad_token_groups(
106+
inputs, group_offsets, alignment_size
107+
) # Returns 3 values
106108

107109
def warmup(fn):
108110
for _ in range(5):
@@ -152,8 +154,8 @@ def cuda_with_offsets():
152154
group_offsets = generate_jagged_offs(
153155
num_groups, num_tokens, multiple_of=1, device=device
154156
)
155-
torch_padded_tokens, torch_padded_offsets = torch_pad_token_groups(
156-
inputs, group_offsets, alignment_size
157+
torch_padded_tokens, torch_padded_start_offsets, torch_padded_offsets = (
158+
torch_pad_token_groups(inputs, group_offsets, alignment_size)
157159
)
158160

159161
bytes_per_el = torch.finfo(torch.bfloat16).bits / 8
Lines changed: 252 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,252 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD 3-Clause license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
# this benchmarking script is a modified version of the original script from: https://github.com/drisspg/transformer_nuggets/blob/main/transformer_nuggets/utils/benchmark.py
7+
8+
import argparse
9+
import itertools
10+
import time
11+
from dataclasses import dataclass
12+
from typing import List
13+
14+
import torch
15+
from tabulate import tabulate
16+
from tqdm import tqdm
17+
18+
from benchmarks.utils import profile_fn
19+
from torchao.prototype.moe_training.kernels.mxfp8 import (
20+
_mxfp8_cuda_kernels_available,
21+
fused_unpad_token_groups_cuda,
22+
torch_pad_token_groups,
23+
torch_unpad_token_groups,
24+
)
25+
from torchao.prototype.moe_training.utils import generate_jagged_offs
26+
27+
device = torch.device("cuda")
28+
29+
# Needed since changing args to function causes recompiles
30+
torch._dynamo.config.cache_size_limit = 1000
31+
32+
33+
@dataclass(frozen=True)
class ExperimentConfig:
    """One point in the benchmark sweep (see get_configs)."""

    # Total number of input tokens across all groups.
    num_tokens: int
    # Hidden dimension (row width) of each token.
    dim: int
    # Number of token groups (presumably one per expert — inferred from
    # the MoE context; confirm against generate_jagged_offs usage).
    num_groups: int
    # Row-count multiple each group is padded to before unpadding.
    alignment_size: int
41+
@dataclass(frozen=True)
class ExperimentResult:
    """Measured timings and derived bandwidths for one config."""

    # Host-side mean latency of the torch eager unpad path, microseconds.
    torch_eager_time_us: float
    # Host-side mean latency of the CUDA kernel; inf when kernels unavailable.
    cuda_time_us: float
    # Achieved memory bandwidth of the eager path, GB/s.
    torch_mem_bw_gbps: float
    # Achieved memory bandwidth of the CUDA kernel, GB/s; 0.0 when unavailable.
    cuda_mem_bw_gbps: float
49+
@dataclass(frozen=True)
class Experiment:
    """Pairs a benchmark configuration with its measured result."""

    # The swept configuration this result belongs to.
    config: ExperimentConfig
    # Timings/bandwidths measured by run_experiment for that config.
    result: ExperimentResult
55+
def get_configs() -> List[ExperimentConfig]:
    """Build the full benchmark sweep.

    Returns one frozen ExperimentConfig per element of the cartesian
    product of token counts, hidden dims, group counts, and alignment
    sizes.
    """
    # Various token group sizes and dimensions
    token_counts = [16384]
    dims = [1536, 2048, 5120, 7168]
    group_counts = [1, 4, 8, 16]
    alignments = [32]
    return [
        ExperimentConfig(
            num_tokens=tokens,
            dim=dim,
            num_groups=groups,
            alignment_size=align,
        )
        for tokens, dim, groups, align in itertools.product(
            token_counts, dims, group_counts, alignments
        )
    ]
75+
76+
77+
def benchmark_host_side_in_microseconds(fn, *args, num_iters=100, **kwargs):
    """Benchmark using host-side timing, includes buffer allocation overhead.

    Synchronizes the GPU before and after the timed region so queued work
    is fully drained, then returns the mean wall-clock latency of one call
    to ``fn(*args, **kwargs)`` in microseconds.
    """
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    iteration = 0
    while iteration < num_iters:
        fn(*args, **kwargs)
        iteration += 1
    torch.cuda.synchronize()
    elapsed_s = time.perf_counter() - t0
    return elapsed_s / num_iters * 1e6  # seconds -> microseconds
88+
89+
90+
def run_experiment(
    config: ExperimentConfig, args: argparse.Namespace
) -> ExperimentResult:
    """Benchmark torch-eager vs. CUDA token-group unpadding for one config.

    Pads a random bf16 token tensor first, then times the unpad path with
    both implementations and derives achieved memory bandwidth from the
    bytes moved. When the CUDA kernels are unavailable, cuda_time_us is
    inf and cuda_mem_bw_gbps is 0.0.

    Args:
        config: sweep point (token count, dim, group count, alignment).
        args: parsed CLI flags; only ``args.profile`` is read here.

    Returns:
        ExperimentResult with timings (us) and bandwidths (GB/s).
    """
    num_tokens, dim, num_groups, alignment_size = (
        config.num_tokens,
        config.dim,
        config.num_groups,
        config.alignment_size,
    )

    # Create inputs and pad them first
    inputs = torch.randn(num_tokens, dim, dtype=torch.bfloat16, device=device)
    group_offsets = generate_jagged_offs(
        num_groups, num_tokens, multiple_of=1, device=device
    )

    # Pad the inputs to get padded tensors for unpad benchmark
    padded_inputs, padded_group_start_offsets, padded_group_end_offsets = (
        torch_pad_token_groups(inputs, group_offsets, alignment_size)
    )

    def torch_eager_with_offsets():
        return torch_unpad_token_groups(
            padded_inputs,
            group_offsets,
            padded_group_start_offsets,
            num_tokens,
            alignment_size,
        )

    def warmup(fn):
        # A few untimed iterations so lazy init / allocator warmup does not
        # pollute the measurement.
        for _ in range(5):
            fn()

    # bench torch eager (includes buffer allocation overhead)
    warmup(torch_eager_with_offsets)
    torch_eager_time_us = benchmark_host_side_in_microseconds(torch_eager_with_offsets)
    if args.profile:
        profile_fn(
            torch_unpad_token_groups,
            padded_inputs,
            group_offsets,
            padded_group_start_offsets,
            # BUGFIX: num_tokens was missing here, so alignment_size bound
            # to the num_tokens parameter during profiling. This now matches
            # the torch_eager_with_offsets closure and the CUDA branch below.
            num_tokens,
            alignment_size,
            profile_name="torch_unpad_token_groups_eager",
        )

    # bench CUDA kernel if available
    if _mxfp8_cuda_kernels_available:

        def cuda_with_offsets():
            return fused_unpad_token_groups_cuda(
                padded_inputs,
                group_offsets,
                padded_group_start_offsets,
                num_tokens,
                alignment_size,
            )

        warmup(cuda_with_offsets)
        cuda_time_us = benchmark_host_side_in_microseconds(cuda_with_offsets)
        if args.profile:
            profile_fn(
                fused_unpad_token_groups_cuda,
                padded_inputs,
                group_offsets,
                padded_group_start_offsets,
                num_tokens,
                alignment_size,
                profile_name="fused_unpad_token_groups_cuda",
            )
    else:
        cuda_time_us = float("inf")  # Not available

    # mem bw calculations
    bytes_per_el = torch.finfo(torch.bfloat16).bits / 8

    read_bytes = (
        padded_inputs.numel() * bytes_per_el  # Read padded input tokens
        + group_offsets.numel() * 4  # Read group offsets (int32)
        + padded_group_start_offsets.numel() * 4  # Read padded start offsets (int32)
    )

    write_bytes = inputs.numel() * bytes_per_el  # Write unpadded data

    total_bytes = read_bytes + write_bytes

    torch_mem_bw_gbps = (total_bytes / 1e9) / (torch_eager_time_us / 1e6)

    if _mxfp8_cuda_kernels_available and cuda_time_us != float("inf"):
        cuda_mem_bw_gbps = (total_bytes / 1e9) / (cuda_time_us / 1e6)
    else:
        cuda_mem_bw_gbps = 0.0

    return ExperimentResult(
        torch_eager_time_us=torch_eager_time_us,
        cuda_time_us=cuda_time_us,
        torch_mem_bw_gbps=torch_mem_bw_gbps,
        cuda_mem_bw_gbps=cuda_mem_bw_gbps,
    )
192+
193+
194+
def print_results(experiments: List[Experiment]):
    """Render all benchmark results as an aligned table on stdout."""
    headers = [
        "num_tokens",
        "dim",
        "num_groups",
        "torch_us",
        "cuda_us",
        "torch_mem_bw_gbps",
        "cuda_mem_bw_gbps",
        "cuda_vs_torch",
    ]
    rows = []
    for exp in experiments:
        cfg = exp.config
        res = exp.result
        cuda_us = res.cuda_time_us

        # Speedup is only meaningful when the CUDA kernel actually ran.
        if cuda_us != float("inf") and cuda_us > 0:
            cuda_vs_torch = f"{res.torch_eager_time_us / cuda_us:.2f}x"
        else:
            cuda_vs_torch = "N/A"

        if res.cuda_mem_bw_gbps > 0:
            cuda_bw_str = f"{res.cuda_mem_bw_gbps:.2f}"
        else:
            cuda_bw_str = "N/A"

        rows.append(
            [
                cfg.num_tokens,
                cfg.dim,
                cfg.num_groups,
                res.torch_eager_time_us,
                res.cuda_time_us,
                f"{res.torch_mem_bw_gbps:.2f}",
                cuda_bw_str,
                cuda_vs_torch,
            ]
        )
    print(tabulate(rows, headers=headers))
232+
233+
234+
def main(args: argparse.Namespace):
    """Run every benchmark configuration and print a summary table."""
    # Fixed seed so random inputs and jagged offsets are reproducible.
    torch.random.manual_seed(123)
    results = [
        Experiment(config=cfg, result=run_experiment(cfg, args))
        for cfg in tqdm(get_configs())
    ]
    # Use Tabulate to print results
    print_results(results)
244+
245+
246+
if __name__ == "__main__":
    # CLI entry point; `--profile` additionally captures profiler traces.
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument(
        "--profile",
        action="store_true",
        help="Enable profiling with PyTorch profiler",
    )
    cli_args = arg_parser.parse_args()
    main(cli_args)

setup.py

Lines changed: 32 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -787,38 +787,38 @@ def get_extensions():
787787

788788
# Only build the cutlass_90a extension if sm90a is in the architecture flags
789789
# and if torch version >= 2.10
790-
if (
791-
cutlass_90a_sources is not None
792-
and len(cutlass_90a_sources) > 0
793-
and build_for_sm90a
794-
and _torch_version_at_least("2.10.0")
795-
):
796-
cutlass_90a_extra_compile_args = copy.deepcopy(extra_compile_args)
797-
# Only use sm90a architecture for these sources, ignoring other flags
798-
cutlass_90a_extra_compile_args["nvcc"].extend(
799-
[
800-
"-DUSE_CUDA",
801-
"-gencode=arch=compute_90a,code=sm_90a",
802-
"-DTORCH_TARGET_VERSION=0x020a000000000000",
803-
]
804-
)
805-
# Add compile flags for stable ABI support (requires torch >= 2.10)
806-
cutlass_90a_extra_compile_args["cxx"].extend(
807-
[
808-
"-DUSE_CUDA",
809-
"-DTORCH_TARGET_VERSION=0x020a000000000000",
810-
]
811-
)
812-
# stable ABI cutlass_90a module
813-
ext_modules.append(
814-
extension(
815-
"torchao._C_cutlass_90a",
816-
cutlass_90a_sources,
817-
py_limited_api=True,
818-
extra_compile_args=cutlass_90a_extra_compile_args,
819-
extra_link_args=extra_link_args,
820-
)
821-
)
790+
# if (
791+
# cutlass_90a_sources is not None
792+
# and len(cutlass_90a_sources) > 0
793+
# and build_for_sm90a
794+
# and _torch_version_at_least("2.10.0")
795+
# ):
796+
# cutlass_90a_extra_compile_args = copy.deepcopy(extra_compile_args)
797+
# # Only use sm90a architecture for these sources, ignoring other flags
798+
# cutlass_90a_extra_compile_args["nvcc"].extend(
799+
# [
800+
# "-DUSE_CUDA",
801+
# "-gencode=arch=compute_90a,code=sm_90a",
802+
# "-DTORCH_TARGET_VERSION=0x020a000000000000",
803+
# ]
804+
# )
805+
# # Add compile flags for stable ABI support (requires torch >= 2.10)
806+
# cutlass_90a_extra_compile_args["cxx"].extend(
807+
# [
808+
# "-DUSE_CUDA",
809+
# "-DTORCH_TARGET_VERSION=0x020a000000000000",
810+
# ]
811+
# )
812+
# # stable ABI cutlass_90a module
813+
# ext_modules.append(
814+
# extension(
815+
# "torchao._C_cutlass_90a",
816+
# cutlass_90a_sources,
817+
# py_limited_api=True,
818+
# extra_compile_args=cutlass_90a_extra_compile_args,
819+
# extra_link_args=extra_link_args,
820+
# )
821+
# )
822822

823823
# Build CMakeLists from /torchao/csrc/cpu - additional options become available : TORCHAO_BUILD_CPU_AARCH64, TORCHAO_BUILD_KLEIDIAI, TORCHAO_BUILD_MPS_OPS, TORCHAO_PARALLEL_BACKEND
824824
if build_macos_arm_auto or os.getenv("BUILD_TORCHAO_EXPERIMENTAL") == "1":

test/prototype/moe_training/reference_moe.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,12 @@ def generate_permute_indices(
147147
torch.int32
148148
)
149149

150+
# Ensure m_sizes sums to exactly max_len (the actual data size after permutation)
151+
current_sum = m_sizes.sum().item()
152+
if current_sum != max_len:
153+
# Add the difference to the last expert
154+
m_sizes[-1] = m_sizes[-1] + (max_len - current_sum)
155+
150156
m_offsets = torch.cumsum(m_sizes, 0)
151157
write_offsets = m_offsets - m_sizes
152158

@@ -176,8 +182,8 @@ def generate_permute_indices(
176182
# Utils from torchtitan/models/moe/utils.py
177183
# =============================================================================
178184

179-
TOKEN_GROUP_ALIGN_SIZE_M = 8
180-
ValidTokenGroupAlignmentSize = Literal[8, 16, 32]
185+
TOKEN_GROUP_ALIGN_SIZE_M = 1
186+
ValidTokenGroupAlignmentSize = Literal[1, 16, 32]
181187

182188

183189
def set_token_group_alignment_size_m(

0 commit comments

Comments (0)