Skip to content

Commit ba3532f

Browse files
[mxfp8 training] cuda kernel for unpadding token groups
stack-info: PR: #4021, branch: danielvegamyhre/stack/148
1 parent 17fd81e commit ba3532f

File tree

11 files changed

+788
-93
lines changed

11 files changed

+788
-93
lines changed
Lines changed: 258 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,258 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD 3-Clause license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
# this benchmarking script is a modified version of the original script from: https://github.com/drisspg/transformer_nuggets/blob/main/transformer_nuggets/utils/benchmark.py
7+
8+
import argparse
9+
import itertools
10+
import time
11+
from dataclasses import dataclass
12+
from typing import List
13+
14+
import torch
15+
from tabulate import tabulate
16+
from tqdm import tqdm
17+
18+
from benchmarks.utils import profile_fn
19+
from torchao.prototype.moe_training.kernels.mxfp8 import (
20+
_mxfp8_cuda_kernels_available,
21+
fused_unpad_token_groups_cuda,
22+
torch_pad_token_groups,
23+
torch_unpad_token_groups,
24+
)
25+
from torchao.prototype.moe_training.utils import generate_jagged_offs
26+
27+
# All benchmarks target the current CUDA device; the kernels under test are CUDA-only.
device = torch.device("cuda")

# Needed since changing args to function causes recompiles.
# Raised well above the default so sweeping many (dim, num_groups) shapes
# does not silently fall back to eager due to dynamo cache eviction.
torch._dynamo.config.cache_size_limit = 1000
@dataclass(frozen=True)
class ExperimentConfig:
    """One benchmark shape: token count, model dim, group count, and pad alignment."""

    # Total number of tokens across all groups (rows of the input).
    num_tokens: int
    # Hidden dimension (columns of the input).
    dim: int
    # Number of token groups (e.g. experts) the rows are partitioned into.
    num_groups: int
    # Each group is padded up to a multiple of this many rows.
    alignment_size: int
@dataclass(frozen=True)
class ExperimentResult:
    """Timings and derived memory bandwidth for one config."""

    # Host-side average latency of the torch eager unpad, in microseconds.
    torch_eager_time_us: float
    # Host-side average latency of the fused CUDA unpad kernel, in microseconds
    # (inf when the CUDA kernels are not available in this build).
    cuda_time_us: float
    # Achieved memory bandwidth of the eager path, in GB/s.
    torch_mem_bw_gbps: float
    # Achieved memory bandwidth of the CUDA kernel, in GB/s (0.0 when unavailable).
    cuda_mem_bw_gbps: float
@dataclass(frozen=True)
class Experiment:
    """Pairs a benchmark config with its measured result."""

    config: ExperimentConfig
    result: ExperimentResult
def get_configs() -> List[ExperimentConfig]:
    """Return the full cartesian product of benchmark shapes to sweep.

    Various token group sizes and dimensions; a single alignment size (32),
    matching the mxfp8 block size the kernels are built around.
    """
    num_tokens_choices = [16384]
    dim_choices = [1536, 2048, 5120, 7168]
    num_groups_choices = [1, 4, 8, 16]
    alignment_choices = [32]
    return [
        ExperimentConfig(
            num_tokens=tokens,
            dim=hidden_dim,
            num_groups=groups,
            alignment_size=alignment,
        )
        for tokens, hidden_dim, groups, alignment in itertools.product(
            num_tokens_choices, dim_choices, num_groups_choices, alignment_choices
        )
    ]
def benchmark_host_side_in_microseconds(fn, *args, num_iters=100, **kwargs):
    """
    Benchmark using host-side timing, includes buffer allocation overhead.

    Synchronizes the device before starting and after the last launch so the
    wall-clock window covers all GPU work, then returns the mean time per
    iteration in microseconds.
    """
    torch.cuda.synchronize()
    begin = time.perf_counter()
    for _ in range(num_iters):
        fn(*args, **kwargs)
    torch.cuda.synchronize()
    elapsed_seconds = time.perf_counter() - begin
    # seconds per iteration -> microseconds
    return (elapsed_seconds / num_iters) * 1e6
def run_experiment(
    config: ExperimentConfig, args: argparse.Namespace
) -> ExperimentResult:
    """Benchmark torch-eager vs fused-CUDA token-group unpadding for one config.

    Builds a padded input from random bf16 data, times both unpad
    implementations with host-side timing, and derives effective memory
    bandwidth from the bytes each implementation must read and write.
    """
    num_tokens, dim, num_groups, alignment_size = (
        config.num_tokens,
        config.dim,
        config.num_groups,
        config.alignment_size,
    )

    # Create inputs and pad them first
    inputs = torch.randn(num_tokens, dim, dtype=torch.bfloat16, device=device)
    group_offsets = generate_jagged_offs(
        num_groups, num_tokens, multiple_of=1, device=device
    )

    # Pad the inputs to get padded tensors for unpad benchmark
    padded_inputs, padded_group_end_offsets = torch_pad_token_groups(
        inputs, group_offsets, alignment_size
    )

    # Compute padded group start offsets: per-group sizes come from diff over
    # the (exclusive-start) end offsets, rounded up to the alignment multiple.
    group_sizes = torch.diff(
        group_offsets,
        prepend=torch.zeros(1, dtype=group_offsets.dtype, device=group_offsets.device),
    )
    padded_sizes = (
        (group_sizes + alignment_size - 1) // alignment_size
    ) * alignment_size
    padded_group_start_offsets = padded_group_end_offsets - padded_sizes

    def torch_eager_with_offsets():
        # Closure so the benchmark loop re-runs allocation + unpad each iter.
        return torch_unpad_token_groups(
            padded_inputs, group_offsets, padded_group_start_offsets, alignment_size
        )

    def warmup(fn):
        # A few untimed runs to absorb compile/allocator warmup effects.
        for _ in range(5):
            fn()

    # bench torch eager (includes buffer allocation overhead)
    warmup(torch_eager_with_offsets)
    torch_eager_time_us = benchmark_host_side_in_microseconds(torch_eager_with_offsets)
    if args.profile:
        profile_fn(
            torch_unpad_token_groups,
            padded_inputs,
            group_offsets,
            padded_group_start_offsets,
            alignment_size,
            profile_name="torch_unpad_token_groups_eager",
        )

    # bench CUDA kernel if available
    if _mxfp8_cuda_kernels_available:

        def cuda_with_offsets():
            return fused_unpad_token_groups_cuda(
                padded_inputs,
                group_offsets,
                padded_group_start_offsets,
                num_tokens,
                alignment_size,
            )

        warmup(cuda_with_offsets)
        cuda_time_us = benchmark_host_side_in_microseconds(cuda_with_offsets)
        if args.profile:
            profile_fn(
                fused_unpad_token_groups_cuda,
                padded_inputs,
                group_offsets,
                padded_group_start_offsets,
                num_tokens,
                alignment_size,
                profile_name="fused_unpad_token_groups_cuda",
            )
    else:
        cuda_time_us = float("inf")  # Not available

    # mem bw calculations
    bytes_per_el = torch.finfo(torch.bfloat16).bits / 8  # bf16 -> 2.0 bytes

    read_bytes = (
        padded_inputs.numel() * bytes_per_el  # Read padded input tokens
        # NOTE(review): the * 4 assumes the offset tensors are int32 — if
        # generate_jagged_offs returns int64 the bandwidth figures are
        # slightly off (offset bytes are negligible vs. token data anyway).
        + group_offsets.numel() * 4  # Read group offsets (int32)
        + padded_group_start_offsets.numel() * 4  # Read padded start offsets (int32)
    )

    write_bytes = (
        inputs.numel() * bytes_per_el  # Write unpadded data
    )

    total_bytes = read_bytes + write_bytes

    torch_mem_bw_gbps = (total_bytes / 1e9) / (torch_eager_time_us / 1e6)

    if _mxfp8_cuda_kernels_available and cuda_time_us != float("inf"):
        cuda_mem_bw_gbps = (total_bytes / 1e9) / (cuda_time_us / 1e6)
    else:
        cuda_mem_bw_gbps = 0.0

    return ExperimentResult(
        torch_eager_time_us=torch_eager_time_us,
        cuda_time_us=cuda_time_us,
        torch_mem_bw_gbps=torch_mem_bw_gbps,
        cuda_mem_bw_gbps=cuda_mem_bw_gbps,
    )
def print_results(experiments: List[Experiment]):
    """Print one table row per experiment, including a CUDA-vs-torch speedup column."""
    headers = [
        "num_tokens",
        "dim",
        "num_groups",
        "torch_us",
        "cuda_us",
        "torch_mem_bw_gbps",
        "cuda_mem_bw_gbps",
        "cuda_vs_torch",
    ]
    rows = []
    for exp in experiments:
        cfg = exp.config
        res = exp.result
        cuda_time = res.cuda_time_us

        # Speedup only makes sense for a real (finite, positive) CUDA timing.
        if cuda_time != float("inf") and cuda_time > 0:
            cuda_vs_torch = f"{res.torch_eager_time_us / cuda_time:.2f}x"
        else:
            cuda_vs_torch = "N/A"

        if res.cuda_mem_bw_gbps > 0:
            cuda_bw_str = f"{res.cuda_mem_bw_gbps:.2f}"
        else:
            cuda_bw_str = "N/A"

        rows.append(
            [
                cfg.num_tokens,
                cfg.dim,
                cfg.num_groups,
                res.torch_eager_time_us,
                res.cuda_time_us,
                f"{res.torch_mem_bw_gbps:.2f}",
                cuda_bw_str,
                cuda_vs_torch,
            ]
        )
    print(tabulate(rows, headers=headers))
def main(args: argparse.Namespace):
    """Run every benchmark configuration and print the summary table."""
    # Fixed seed so the random inputs (and jagged offsets) are reproducible.
    torch.random.manual_seed(123)
    completed = []
    for cfg in tqdm(get_configs()):
        outcome = run_experiment(cfg, args)
        completed.append(Experiment(config=cfg, result=outcome))

    # Use Tabulate to print results
    print_results(completed)
if __name__ == "__main__":
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument(
        "--profile", action="store_true", help="Enable profiling with PyTorch profiler"
    )
    main(arg_parser.parse_args())

setup.py

Lines changed: 32 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -787,38 +787,38 @@ def get_extensions():
787787

788788
# Only build the cutlass_90a extension if sm90a is in the architecture flags
789789
# and if torch version >= 2.10
790-
if (
791-
cutlass_90a_sources is not None
792-
and len(cutlass_90a_sources) > 0
793-
and build_for_sm90a
794-
and _torch_version_at_least("2.10.0")
795-
):
796-
cutlass_90a_extra_compile_args = copy.deepcopy(extra_compile_args)
797-
# Only use sm90a architecture for these sources, ignoring other flags
798-
cutlass_90a_extra_compile_args["nvcc"].extend(
799-
[
800-
"-DUSE_CUDA",
801-
"-gencode=arch=compute_90a,code=sm_90a",
802-
"-DTORCH_TARGET_VERSION=0x020a000000000000",
803-
]
804-
)
805-
# Add compile flags for stable ABI support (requires torch >= 2.10)
806-
cutlass_90a_extra_compile_args["cxx"].extend(
807-
[
808-
"-DUSE_CUDA",
809-
"-DTORCH_TARGET_VERSION=0x020a000000000000",
810-
]
811-
)
812-
# stable ABI cutlass_90a module
813-
ext_modules.append(
814-
extension(
815-
"torchao._C_cutlass_90a",
816-
cutlass_90a_sources,
817-
py_limited_api=True,
818-
extra_compile_args=cutlass_90a_extra_compile_args,
819-
extra_link_args=extra_link_args,
820-
)
821-
)
790+
# if (
791+
# cutlass_90a_sources is not None
792+
# and len(cutlass_90a_sources) > 0
793+
# and build_for_sm90a
794+
# and _torch_version_at_least("2.10.0")
795+
# ):
796+
# cutlass_90a_extra_compile_args = copy.deepcopy(extra_compile_args)
797+
# # Only use sm90a architecture for these sources, ignoring other flags
798+
# cutlass_90a_extra_compile_args["nvcc"].extend(
799+
# [
800+
# "-DUSE_CUDA",
801+
# "-gencode=arch=compute_90a,code=sm_90a",
802+
# "-DTORCH_TARGET_VERSION=0x020a000000000000",
803+
# ]
804+
# )
805+
# # Add compile flags for stable ABI support (requires torch >= 2.10)
806+
# cutlass_90a_extra_compile_args["cxx"].extend(
807+
# [
808+
# "-DUSE_CUDA",
809+
# "-DTORCH_TARGET_VERSION=0x020a000000000000",
810+
# ]
811+
# )
812+
# # stable ABI cutlass_90a module
813+
# ext_modules.append(
814+
# extension(
815+
# "torchao._C_cutlass_90a",
816+
# cutlass_90a_sources,
817+
# py_limited_api=True,
818+
# extra_compile_args=cutlass_90a_extra_compile_args,
819+
# extra_link_args=extra_link_args,
820+
# )
821+
# )
822822

823823
# Build CMakeLists from /torchao/csrc/cpu - additional options become available : TORCHAO_BUILD_CPU_AARCH64, TORCHAO_BUILD_KLEIDIAI, TORCHAO_BUILD_MPS_OPS, TORCHAO_PARALLEL_BACKEND
824824
if build_macos_arm_auto or os.getenv("BUILD_TORCHAO_EXPERIMENTAL") == "1":

0 commit comments

Comments
 (0)