Skip to content

Commit 2173ea9

Browse files
[mxfp8 training] cuda kernel for unpadding token groups
stack-info: PR: #4021, branch: danielvegamyhre/stack/148
1 parent 17fd81e commit 2173ea9

File tree

13 files changed

+864
-159
lines changed

13 files changed

+864
-159
lines changed

benchmarks/prototype/moe_training/mxfp8/bench_pad_token_groups.py

Lines changed: 10 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,14 @@
77

88
import argparse
99
import itertools
10-
import time
1110
from dataclasses import dataclass
1211
from typing import List
1312

1413
import torch
1514
from tabulate import tabulate
1615
from tqdm import tqdm
1716

18-
from benchmarks.utils import profile_fn
17+
from benchmarks.utils import benchmark_cuda_function_in_microseconds, profile_fn
1918
from torchao.prototype.moe_training.kernels.mxfp8 import (
2019
_mxfp8_cuda_kernels_available,
2120
fused_pad_token_groups_cuda,
@@ -73,19 +72,6 @@ def get_configs() -> List[ExperimentConfig]:
7372
return configs
7473

7574

76-
def benchmark_host_side_in_microseconds(fn, *args, num_iters=100, **kwargs):
77-
"""
78-
Benchmark using host-side timing, includes buffer allocation overhead.
79-
"""
80-
torch.cuda.synchronize()
81-
start = time.perf_counter()
82-
for _ in range(num_iters):
83-
fn(*args, **kwargs)
84-
torch.cuda.synchronize()
85-
end = time.perf_counter()
86-
return ((end - start) / num_iters) * 1e6 # Convert to microseconds
87-
88-
8975
def run_experiment(
9076
config: ExperimentConfig, args: argparse.Namespace
9177
) -> ExperimentResult:
@@ -102,15 +88,19 @@ def torch_eager_with_offsets():
10288
group_offsets = generate_jagged_offs(
10389
num_groups, num_tokens, multiple_of=1, device=device
10490
)
105-
return torch_pad_token_groups(inputs, group_offsets, alignment_size)
91+
return torch_pad_token_groups(
92+
inputs, group_offsets, alignment_size
93+
) # Returns 3 values
10694

10795
def warmup(fn):
10896
for _ in range(5):
10997
fn()
11098

11199
# bench torch eager (includes buffer allocation overhead)
112100
warmup(torch_eager_with_offsets)
113-
torch_eager_time_us = benchmark_host_side_in_microseconds(torch_eager_with_offsets)
101+
torch_eager_time_us = benchmark_cuda_function_in_microseconds(
102+
torch_eager_with_offsets
103+
)
114104
if args.profile:
115105
group_offsets = generate_jagged_offs(
116106
num_groups, num_tokens, multiple_of=1, device=device
@@ -133,7 +123,7 @@ def cuda_with_offsets():
133123
return fused_pad_token_groups_cuda(inputs, group_offsets, alignment_size)
134124

135125
warmup(cuda_with_offsets)
136-
cuda_time_us = benchmark_host_side_in_microseconds(cuda_with_offsets)
126+
cuda_time_us = benchmark_cuda_function_in_microseconds(cuda_with_offsets)
137127
if args.profile:
138128
group_offsets = generate_jagged_offs(
139129
num_groups, num_tokens, multiple_of=1, device=device
@@ -152,8 +142,8 @@ def cuda_with_offsets():
152142
group_offsets = generate_jagged_offs(
153143
num_groups, num_tokens, multiple_of=1, device=device
154144
)
155-
torch_padded_tokens, torch_padded_offsets = torch_pad_token_groups(
156-
inputs, group_offsets, alignment_size
145+
torch_padded_tokens, torch_padded_start_offsets, torch_padded_offsets = (
146+
torch_pad_token_groups(inputs, group_offsets, alignment_size)
157147
)
158148

159149
bytes_per_el = torch.finfo(torch.bfloat16).bits / 8
Lines changed: 240 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,240 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD 3-Clause license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
# this benchmarking script is a modified version of the original script from: https://github.com/drisspg/transformer_nuggets/blob/main/transformer_nuggets/utils/benchmark.py
7+
8+
import argparse
9+
import itertools
10+
from dataclasses import dataclass
11+
from typing import List
12+
13+
import torch
14+
from tabulate import tabulate
15+
from tqdm import tqdm
16+
17+
from benchmarks.utils import benchmark_cuda_function_in_microseconds, profile_fn
18+
from torchao.prototype.moe_training.kernels.mxfp8 import (
19+
_mxfp8_cuda_kernels_available,
20+
fused_unpad_token_groups_cuda,
21+
torch_pad_token_groups,
22+
torch_unpad_token_groups,
23+
)
24+
from torchao.prototype.moe_training.utils import generate_jagged_offs
25+
26+
device = torch.device("cuda")
27+
28+
# Needed since changing args to function causes recompiles
29+
torch._dynamo.config.cache_size_limit = 1000
30+
31+
32+
@dataclass(frozen=True)
class ExperimentConfig:
    """Parameters for one unpad-token-groups benchmark run."""

    num_tokens: int  # total number of tokens across all groups
    dim: int  # hidden dimension per token
    num_groups: int  # number of token groups (experts)
    alignment_size: int  # row multiple each group was padded to
38+
39+
40+
@dataclass(frozen=True)
class ExperimentResult:
    """Timings and derived bandwidth numbers for a single config."""

    torch_eager_time_us: float  # torch eager latency, microseconds
    cuda_time_us: float  # fused CUDA kernel latency; inf when kernels unavailable
    torch_mem_bw_gbps: float  # achieved memory bandwidth of the torch eager path
    cuda_mem_bw_gbps: float  # achieved memory bandwidth of the CUDA path; 0.0 when unavailable
46+
47+
48+
@dataclass(frozen=True)
class Experiment:
    """Pairs a benchmark configuration with its measured result."""

    config: ExperimentConfig
    result: ExperimentResult
52+
53+
54+
def get_configs() -> List[ExperimentConfig]:
    """Build the benchmark sweep: one config per (tokens, dim, groups, alignment) combo.

    Sweeps hidden dimension and group count at a fixed token count and
    alignment size.
    """
    num_tokens_choices = [16384]
    dim_choices = [1536, 2048, 5120, 7168]
    group_count_choices = [1, 4, 8, 16]
    alignment_choices = [32]

    return [
        ExperimentConfig(
            num_tokens=num_tokens,
            dim=dim,
            num_groups=num_groups,
            alignment_size=alignment_size,
        )
        for num_tokens, dim, num_groups, alignment_size in itertools.product(
            num_tokens_choices, dim_choices, group_count_choices, alignment_choices
        )
    ]
74+
75+
76+
def run_experiment(
    config: ExperimentConfig, args: argparse.Namespace
) -> ExperimentResult:
    """Benchmark token-group unpadding for a single ExperimentConfig.

    Builds a random bf16 token tensor, pads it once with
    torch_pad_token_groups, then times the torch eager unpad path and —
    when compiled kernels are present — the fused CUDA unpad kernel on the
    same padded data. Memory bandwidth is derived from a simple
    read+write byte model.

    Returns:
        ExperimentResult with latencies (us) and bandwidths (GB/s); the
        CUDA fields are inf / 0.0 when the kernels are unavailable.
    """
    num_tokens, dim, num_groups, alignment_size = (
        config.num_tokens,
        config.dim,
        config.num_groups,
        config.alignment_size,
    )

    # Create inputs and pad them first
    inputs = torch.randn(num_tokens, dim, dtype=torch.bfloat16, device=device)
    group_offsets = generate_jagged_offs(
        num_groups, num_tokens, multiple_of=1, device=device
    )

    # Pad the inputs to get padded tensors for unpad benchmark
    padded_inputs, padded_group_start_offsets, padded_group_end_offsets = (
        torch_pad_token_groups(inputs, group_offsets, alignment_size)
    )

    def torch_eager_with_offsets():
        return torch_unpad_token_groups(
            padded_inputs,
            group_offsets,
            padded_group_start_offsets,
            num_tokens,
            alignment_size,
        )

    def warmup(fn):
        for _ in range(5):
            fn()

    # bench torch eager (includes buffer allocation overhead)
    warmup(torch_eager_with_offsets)
    torch_eager_time_us = benchmark_cuda_function_in_microseconds(
        torch_eager_with_offsets
    )
    if args.profile:
        # Fix: the original call omitted `num_tokens`, so the profiled
        # invocation passed 4 positional args where the benchmarked closure
        # (and the CUDA profile call below) pass 5 — `alignment_size` landed
        # in the `num_tokens` slot and `--profile` runs failed.
        profile_fn(
            torch_unpad_token_groups,
            padded_inputs,
            group_offsets,
            padded_group_start_offsets,
            num_tokens,
            alignment_size,
            profile_name="torch_unpad_token_groups_eager",
        )

    # bench CUDA kernel if available
    if _mxfp8_cuda_kernels_available:

        def cuda_with_offsets():
            return fused_unpad_token_groups_cuda(
                padded_inputs,
                group_offsets,
                padded_group_start_offsets,
                num_tokens,
                alignment_size,
            )

        warmup(cuda_with_offsets)
        cuda_time_us = benchmark_cuda_function_in_microseconds(cuda_with_offsets)
        if args.profile:
            profile_fn(
                fused_unpad_token_groups_cuda,
                padded_inputs,
                group_offsets,
                padded_group_start_offsets,
                num_tokens,
                alignment_size,
                profile_name="fused_unpad_token_groups_cuda",
            )
    else:
        cuda_time_us = float("inf")  # Not available

    # mem bw calculations
    bytes_per_el = torch.finfo(torch.bfloat16).bits / 8

    read_bytes = (
        padded_inputs.numel() * bytes_per_el  # Read padded input tokens
        + group_offsets.numel() * 4  # Read group offsets (int32)
        + padded_group_start_offsets.numel() * 4  # Read padded start offsets (int32)
    )
    write_bytes = inputs.numel() * bytes_per_el  # Write unpadded data
    total_bytes = read_bytes + write_bytes

    torch_mem_bw_gbps = (total_bytes / 1e9) / (torch_eager_time_us / 1e6)

    if _mxfp8_cuda_kernels_available and cuda_time_us != float("inf"):
        cuda_mem_bw_gbps = (total_bytes / 1e9) / (cuda_time_us / 1e6)
    else:
        cuda_mem_bw_gbps = 0.0

    return ExperimentResult(
        torch_eager_time_us=torch_eager_time_us,
        cuda_time_us=cuda_time_us,
        torch_mem_bw_gbps=torch_mem_bw_gbps,
        cuda_mem_bw_gbps=cuda_mem_bw_gbps,
    )
180+
181+
182+
def print_results(experiments: List[Experiment]):
    """Render benchmark results as a table, one row per experiment.

    Shows "N/A" for the CUDA columns when the fused kernel was not
    available (inf latency / zero bandwidth).
    """
    headers = [
        "num_tokens",
        "dim",
        "num_groups",
        "torch_us",
        "cuda_us",
        "torch_mem_bw_gbps",
        "cuda_mem_bw_gbps",
        "cuda_vs_torch",
    ]

    rows = []
    for exp in experiments:
        cfg = exp.config
        res = exp.result
        cuda_us = res.cuda_time_us

        # Speedup only makes sense for a finite, positive CUDA time.
        if cuda_us != float("inf") and cuda_us > 0:
            speedup = f"{res.torch_eager_time_us / cuda_us:.2f}x"
        else:
            speedup = "N/A"

        if res.cuda_mem_bw_gbps > 0:
            cuda_bw = f"{res.cuda_mem_bw_gbps:.2f}"
        else:
            cuda_bw = "N/A"

        rows.append(
            [
                cfg.num_tokens,
                cfg.dim,
                cfg.num_groups,
                res.torch_eager_time_us,
                res.cuda_time_us,
                f"{res.torch_mem_bw_gbps:.2f}",
                cuda_bw,
                speedup,
            ]
        )

    print(tabulate(rows, headers=headers))
220+
221+
222+
def main(args: argparse.Namespace):
    """Run every configured experiment and print the summary table."""
    torch.random.manual_seed(123)  # deterministic inputs across runs

    results = []
    for config in tqdm(get_configs()):
        results.append(Experiment(config=config, result=run_experiment(config, args)))

    # Use Tabulate to print results
    print_results(results)
232+
233+
234+
if __name__ == "__main__":
    # Script entry point: parse CLI flags and launch the benchmark sweep.
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument(
        "--profile", action="store_true", help="Enable profiling with PyTorch profiler"
    )
    main(arg_parser.parse_args())

test/prototype/moe_training/reference_moe.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,12 @@ def generate_permute_indices(
147147
torch.int32
148148
)
149149

150+
# Ensure m_sizes sums to exactly max_len (the actual data size after permutation)
151+
current_sum = m_sizes.sum().item()
152+
if current_sum != max_len:
153+
# Add the difference to the last expert
154+
m_sizes[-1] = m_sizes[-1] + (max_len - current_sum)
155+
150156
m_offsets = torch.cumsum(m_sizes, 0)
151157
write_offsets = m_offsets - m_sizes
152158

@@ -176,8 +182,8 @@ def generate_permute_indices(
176182
# Utils from torchtitan/models/moe/utils.py
177183
# =============================================================================
178184

179-
TOKEN_GROUP_ALIGN_SIZE_M = 8
180-
ValidTokenGroupAlignmentSize = Literal[8, 16, 32]
185+
TOKEN_GROUP_ALIGN_SIZE_M = 1
186+
ValidTokenGroupAlignmentSize = Literal[1, 16, 32]
181187

182188

183189
def set_token_group_alignment_size_m(

0 commit comments

Comments
 (0)