Skip to content

Commit f58c16f

Browse files
committed
Merge branch 'main' into fix_da8w4
2 parents 2a9ff0e + 3d02561 commit f58c16f

File tree

27 files changed

+1372
-428
lines changed

27 files changed

+1372
-428
lines changed

benchmarks/float8/float8_inference_roofline.py

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ def get_gemm_times(
112112

113113
bf16_time_s = get_gpu_kernel_gemm_time_s(torch.mm, x_bf16, w_bf16)
114114

115-
if recipe_name in ("mxfp4_cutlass", "nvfp4"):
115+
if recipe_name in ("mxfp4_cutlass", "nvfp4", "nvfp4_static"):
116116
d1, d2, d3 = torch.float4_e2m1fn_x2, torch.float4_e2m1fn_x2, torch.bfloat16
117117
A = torch.randint(0, 255, (M, K // 2), device=device, dtype=torch.uint8).view(
118118
d1
@@ -151,7 +151,7 @@ def get_gemm_times(
151151
scale_b = torch.ones(N, K // 32, device=device, dtype=torch.float8_e8m0fnu)
152152
scale_a = to_blocked(scale_a)
153153
scale_b = to_blocked(scale_b)
154-
elif recipe_name == "nvfp4":
154+
elif recipe_name in ("nvfp4", "nvfp4_static"):
155155
scale_a = torch.ones(M, K // 16, device=device, dtype=torch.float8_e4m3fn)
156156
scale_b = torch.ones(N, K // 16, device=device, dtype=torch.float8_e4m3fn)
157157
scale_a = to_blocked(scale_a)
@@ -177,7 +177,7 @@ def do_matmul(A, B):
177177
swizzle_b=SwizzleType.SWIZZLE_32_4_4,
178178
output_dtype=d3,
179179
)
180-
if recipe_name == "nvfp4":
180+
if recipe_name in ("nvfp4", "nvfp4_static"):
181181
return torch._scaled_mm(
182182
A, B, scale_a, scale_b, out_dtype=d3, use_fast_accum=False
183183
)
@@ -795,12 +795,29 @@ def run(
795795
)
796796
elif recipe_name == "nvfp4":
797797
config = NVFP4DynamicActivationNVFP4WeightConfig(
798-
use_dynamic_per_tensor_scale=False,
798+
use_dynamic_per_tensor_scale=True,
799+
)
800+
elif recipe_name == "nvfp4_static":
801+
config_calib = NVFP4DynamicActivationNVFP4WeightConfig(
802+
step="prepare",
803+
)
804+
config = NVFP4DynamicActivationNVFP4WeightConfig(
805+
step="convert",
799806
)
800807
else:
801808
assert False, "unsupported"
802809

803810
m_fp8_dyn = copy.deepcopy(m_orig)
811+
812+
if recipe_name == "nvfp4_static":
813+
# calibrate with sample data
814+
# this benchmark is performance-only, so a toy datum is fine
815+
quantize_(m_fp8_dyn, config_calib)
816+
toy_datum = torch.randn(
817+
M_val, K_val, dtype=torch.bfloat16, device="cuda"
818+
)
819+
m_fp8_dyn(toy_datum)
820+
804821
if op_name == "linear":
805822
quantize_(m_fp8_dyn, config)
806823
elif op_name == "conv2d":

benchmarks/mx_formats/cast_bench.py

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,7 @@ def run(
118118
"dim1_mxfp8_floor",
119119
"dim1_mxfp8_rceil",
120120
"dim1_mxfp8_triton_floor",
121+
"dim1_mxfp8_triton_rceil",
121122
"dim1_mxfp8_cuda_floor",
122123
"dim1_mxfp8_cuda_rceil",
123124
)
@@ -350,12 +351,41 @@ def run(
350351
bps = (bytes_r + bytes_w) / (time_us / 1e6)
351352

352353
elif mode == "dim1_mxfp8_triton_floor":
353-
y_d1, s_d1 = triton_to_mxfp8_dim1(x, inner_block_size=BLOCK_SIZE)
354+
y_d1, s_d1 = triton_to_mxfp8_dim1(
355+
x, inner_block_size=BLOCK_SIZE, scaling_mode="floor"
356+
)
354357

355358
for _ in range(2):
356-
__ = triton_to_mxfp8_dim1(x, inner_block_size=BLOCK_SIZE)
359+
__ = triton_to_mxfp8_dim1(
360+
x, inner_block_size=BLOCK_SIZE, scaling_mode="floor"
361+
)
357362
time_us = benchmark_cuda_function_in_microseconds(
358-
lambda x, b: triton_to_mxfp8_dim1(x, inner_block_size=BLOCK_SIZE),
363+
lambda x, b: triton_to_mxfp8_dim1(
364+
x, inner_block_size=BLOCK_SIZE, scaling_mode="floor"
365+
),
366+
x,
367+
BLOCK_SIZE,
368+
)
369+
370+
assert y_d1.dtype == torch.float8_e4m3fn
371+
assert s_d1.dtype == torch.float8_e8m0fnu
372+
bytes_r = x.numel() * bytes_per_el_bf16
373+
bytes_w = (y_d1.numel() + s_d1.numel()) * bytes_per_el_fp8
374+
bps = (bytes_r + bytes_w) / (time_us / 1e6)
375+
376+
elif mode == "dim1_mxfp8_triton_rceil":
377+
y_d1, s_d1 = triton_to_mxfp8_dim1(
378+
x, inner_block_size=BLOCK_SIZE, scaling_mode="rceil"
379+
)
380+
381+
for _ in range(2):
382+
__ = triton_to_mxfp8_dim1(
383+
x, inner_block_size=BLOCK_SIZE, scaling_mode="rceil"
384+
)
385+
time_us = benchmark_cuda_function_in_microseconds(
386+
lambda x, b: triton_to_mxfp8_dim1(
387+
x, inner_block_size=BLOCK_SIZE, scaling_mode="rceil"
388+
),
359389
x,
360390
BLOCK_SIZE,
361391
)

benchmarks/prototype/moe_training/mxfp8/bench_pad_token_groups.py

Lines changed: 10 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,14 @@
77

88
import argparse
99
import itertools
10-
import time
1110
from dataclasses import dataclass
1211
from typing import List
1312

1413
import torch
1514
from tabulate import tabulate
1615
from tqdm import tqdm
1716

18-
from benchmarks.utils import profile_fn
17+
from benchmarks.utils import benchmark_cuda_function_in_microseconds, profile_fn
1918
from torchao.prototype.moe_training.kernels.mxfp8 import (
2019
_mxfp8_cuda_kernels_available,
2120
fused_pad_token_groups_cuda,
@@ -73,19 +72,6 @@ def get_configs() -> List[ExperimentConfig]:
7372
return configs
7473

7574

76-
def benchmark_host_side_in_microseconds(fn, *args, num_iters=100, **kwargs):
77-
"""
78-
Benchmark using host-side timing, includes buffer allocation overhead.
79-
"""
80-
torch.cuda.synchronize()
81-
start = time.perf_counter()
82-
for _ in range(num_iters):
83-
fn(*args, **kwargs)
84-
torch.cuda.synchronize()
85-
end = time.perf_counter()
86-
return ((end - start) / num_iters) * 1e6 # Convert to microseconds
87-
88-
8975
def run_experiment(
9076
config: ExperimentConfig, args: argparse.Namespace
9177
) -> ExperimentResult:
@@ -102,15 +88,19 @@ def torch_eager_with_offsets():
10288
group_offsets = generate_jagged_offs(
10389
num_groups, num_tokens, multiple_of=1, device=device
10490
)
105-
return torch_pad_token_groups(inputs, group_offsets, alignment_size)
91+
return torch_pad_token_groups(
92+
inputs, group_offsets, alignment_size
93+
) # Returns 3 values
10694

10795
def warmup(fn):
10896
for _ in range(5):
10997
fn()
11098

11199
# bench torch eager (includes buffer allocation overhead)
112100
warmup(torch_eager_with_offsets)
113-
torch_eager_time_us = benchmark_host_side_in_microseconds(torch_eager_with_offsets)
101+
torch_eager_time_us = benchmark_cuda_function_in_microseconds(
102+
torch_eager_with_offsets
103+
)
114104
if args.profile:
115105
group_offsets = generate_jagged_offs(
116106
num_groups, num_tokens, multiple_of=1, device=device
@@ -133,7 +123,7 @@ def cuda_with_offsets():
133123
return fused_pad_token_groups_cuda(inputs, group_offsets, alignment_size)
134124

135125
warmup(cuda_with_offsets)
136-
cuda_time_us = benchmark_host_side_in_microseconds(cuda_with_offsets)
126+
cuda_time_us = benchmark_cuda_function_in_microseconds(cuda_with_offsets)
137127
if args.profile:
138128
group_offsets = generate_jagged_offs(
139129
num_groups, num_tokens, multiple_of=1, device=device
@@ -152,8 +142,8 @@ def cuda_with_offsets():
152142
group_offsets = generate_jagged_offs(
153143
num_groups, num_tokens, multiple_of=1, device=device
154144
)
155-
torch_padded_tokens, torch_padded_offsets = torch_pad_token_groups(
156-
inputs, group_offsets, alignment_size
145+
torch_padded_tokens, torch_padded_start_offsets, torch_padded_offsets = (
146+
torch_pad_token_groups(inputs, group_offsets, alignment_size)
157147
)
158148

159149
bytes_per_el = torch.finfo(torch.bfloat16).bits / 8

0 commit comments

Comments (0)