diff --git a/benchmarks/routines/flashinfer_benchmark_utils.py b/benchmarks/routines/flashinfer_benchmark_utils.py index b207f5cb43..85fe801916 100644 --- a/benchmarks/routines/flashinfer_benchmark_utils.py +++ b/benchmarks/routines/flashinfer_benchmark_utils.py @@ -134,6 +134,7 @@ "bmm_fp8", "bmm_mxfp8", "mm_fp4", + "mm_mxfp8", ], "moe": [ "trtllm_fp4_block_scale_moe", @@ -296,6 +297,17 @@ def dtype_str_to_torch_dtype(dtype_str): "10.3": ["cudnn"], "12.0": [], }, + "mm_mxfp8": { + "7.5": [], + "8.0": [], + "8.6": [], + "8.9": [], + "9.0": [], + "10.0": ["cutlass"], + "10.3": ["cutlass"], + "11.0": ["cutlass"], + "12.0": [], + }, # Note: mm_fp4 uses support checkers to filter backends, so it is not listed here # MOE "trtllm_fp4_block_scale_moe": { diff --git a/benchmarks/routines/gemm.py b/benchmarks/routines/gemm.py index e9693432da..c4f2488fd6 100644 --- a/benchmarks/routines/gemm.py +++ b/benchmarks/routines/gemm.py @@ -43,6 +43,8 @@ def run_gemm_test(args): return testBmmMxfp8(args) elif args.routine == "mm_fp4": return testMmFp4(args) + elif args.routine == "mm_mxfp8": + return testMmMxfp8(args) else: raise ValueError(f"Unsupported routine: {args.routine}") @@ -147,12 +149,13 @@ def parse_gemm_args(line, parser): action="store_true", help="In mm_fp4, whether to use nvfp4 quantization or mxfp4 quantization, defaults to False.", ) - # TODO: add bmm_mxfp8 ? parser.add_argument( "--autotune", action="store_true", default=False, - help=("Enable autotuner warmup for supported routines (mm_fp4 and bmm_fp8)."), + help=( + "Enable autotuner warmup for supported routines (mm_fp4, bmm_fp8, bmm_mxfp8 and mm_mxfp8)." + ), ) args = parser.parse_args(line) @@ -1233,3 +1236,212 @@ def run_backend( cur_res["case_tag"] = args.case_tag res.append(cur_res) return res + + +def testMmMxfp8(args): + """ + Test mm_mxfp8 API. + + This test: + 1. Generates random input tensors + 2. Quantizes input tensors to MXFP8 + 3. Runs mm_mxfp8 + 4. Runs reference check + 5. 
Measures performance metrics (TFLOPS, TB/sec) + + Args: + args: Parsed command line arguments containing test configuration + + Returns: + dict: List of dictionaries containing performance results + """ + if args.verbose >= 1: + print("[INFO] Running testMmMxfp8") + print(f"[INFO] FlashInfer version: {flashinfer.__version__}") + + device = get_device(args) + if args.generate_repro_command: + print( + f"[INFO] To reproduce this test case, run the following command: {args.repro_command}" + ) + + ## Parse input arguments + backends = args.backends + m = args.m + n = args.n + k = args.k + res_dtype = args.out_dtype + is_cuda_graph_compatible = not args.no_cuda_graph + run_refcheck = args.refcheck + autotune_supported_backends = [ + "cutlass", + ] + res = [] + + backends = filter_backends_by_compute_capability(backends, args.routine, device) + if len(backends) == 0: + print("[ERROR] No backends to test. Exiting.") + return res + + res_dtype = dtype_str_to_torch_dtype(args.out_dtype) + if res_dtype not in [torch.bfloat16, torch.float16]: + raise ValueError( + f"Unsupported res dtype: {res_dtype}. Supported dtypes are bfloat16 and float16." + ) + ## Done parsing input arguments + + if getattr(args, "autotune", False): + backends_to_remove = [] + for cur_backend in backends: + if cur_backend not in autotune_supported_backends: + print(f"[INFO] {cur_backend} backend does not support autotune") + backends_to_remove.append(cur_backend) + for cur_backend in backends_to_remove: + backends.remove(cur_backend) + + if len(backends) == 0: + print("[ERROR] No backends to test. 
Exiting.") + return res + + ## Prepare input tensors + # Use swizzled layout for optimal performance + is_sf_swizzled_layout = True + + input = torch.randn([m, k], device=device, dtype=torch.bfloat16) + input_mxfp8, input_scale = mxfp8_quantize( + input, is_sf_swizzled_layout=is_sf_swizzled_layout + ) + + mat2 = torch.randn([n, k], device=device, dtype=torch.bfloat16) + mat2_mxfp8, mat2_scale = mxfp8_quantize( + mat2, is_sf_swizzled_layout=is_sf_swizzled_layout + ) + + if args.verbose >= 2: + print(f"[VVERBOSE] {input_mxfp8.shape = }") + print(f"[VVERBOSE] {input_mxfp8.dtype = }") + print(f"[VVERBOSE] {mat2_mxfp8.shape = }") + print(f"[VVERBOSE] {mat2_mxfp8.dtype = }") + print(f"[VVERBOSE] {input_scale.shape = }") + print(f"[VVERBOSE] {input_scale.dtype = }") + print(f"[VVERBOSE] {mat2_scale.shape = }") + print(f"[VVERBOSE] {mat2_scale.dtype = }") + + def run_backend(backend, input_mxfp8, mat2_mxfp8, input_scale, mat2_scale): + if backend == "cutlass": + return flashinfer.gemm.mm_mxfp8( + a=input_mxfp8, + b=mat2_mxfp8.t(), # mm_mxfp8 expects b.t() + a_descale=input_scale, + b_descale=mat2_scale, # mm_mxfp8 handles swizzled 1D internally + out_dtype=res_dtype, + backend=backend, + ) + else: + raise ValueError(f"Unsupported backend: {backend}") + + has_reference_output = False + if run_refcheck: + reference_output = torch.mm(input, mat2.t()) + has_reference_output = True + + if getattr(args, "autotune", False): + warmup_iters = ( + args.dry_run_iters if args.dry_run_iters and args.dry_run_iters > 0 else 10 + ) + for cur_backend in backends: + if cur_backend in autotune_supported_backends: + if args.verbose >= 1: + print(f"[INFO] Autotune warmup for mm_mxfp8: {warmup_iters} iters") + with autotune(True): + for _ in range(warmup_iters): + run_backend( + cur_backend, + input_mxfp8, + mat2_mxfp8, + input_scale, + mat2_scale, + ) + + # Storage for timing results and outputs + backend_times = {backend: [] for backend in backends} + outputs = {} + for cur_backend in 
backends: + if run_refcheck: + outputs[cur_backend] = run_backend( + cur_backend, input_mxfp8, mat2_mxfp8, input_scale, mat2_scale + ).detach() + backend_times[cur_backend] = bench_gpu_time( + fn=run_backend, + dry_run_iters=args.dry_run_iters, + repeat_iters=args.num_iters, + sleep_after_run=True, + enable_cupti=args.use_cupti, + use_cuda_graph=is_cuda_graph_compatible, + cold_l2_cache=True, + input_args=(cur_backend, input_mxfp8, mat2_mxfp8, input_scale, mat2_scale), + ) + + # Minimum cosine similarity for swizzled layout + min_cos_sim = 0.98 + + tested_backends = list(outputs.keys()) + tested_outputs = list(outputs.values()) + if len(tested_backends) > 0: + if run_refcheck and has_reference_output: + for i in range(len(tested_backends)): + cos_sim = F.cosine_similarity( + reference_output.reshape(-1), + tested_outputs[i].reshape(-1), + dim=0, + ) + if cos_sim < min_cos_sim: + print( + "[ERROR] Output tensor mismatch between reference " + f"{tested_backends[0]} and backend {tested_backends[i]}" + ) + if not args.allow_output_mismatch: + raise AssertionError( + "[ERROR] Output tensor mismatch between reference " + f"{tested_backends[0]} and backend {tested_backends[i]} " + f"with {cos_sim=} (expected >= {min_cos_sim})" + ) + for backend in backends: + backend_name = backend + ( + "_autotune" + if ( + getattr(args, "autotune", False) + and backend in autotune_supported_backends + ) + else "" + ) + if len(backend_times[backend]) > 0: + median_time = np.median(backend_times[backend]) + std_time = np.std(backend_times[backend]) + problem_flops = 2 * m * n * k + # MXFP8 uses fp8_e4m3fn for data (1 byte) and uint8 for scales + # Scale tensors are much smaller, so approximate as 1 byte per element for simplicity + problem_bytes = ( + m * k * torch.float8_e4m3fn.itemsize + + n * k * torch.float8_e4m3fn.itemsize + + m * n * res_dtype.itemsize + ) + tflops = problem_flops / (10**9 * median_time) # in TFLOPs/sec + tb_per_sec = problem_bytes / (10**9 * median_time) # in 
TB/sec + print_perf_metrics(backend_name, median_time, std_time, tflops, tb_per_sec) + + if args.output_path is not None: + cur_res = defaultdict(str) + cur_res["routine"] = args.routine + cur_res["median_time"] = median_time + cur_res["std_time"] = std_time + cur_res["tflops"] = tflops + cur_res["tb_per_sec"] = tb_per_sec + cur_res["m"] = m + cur_res["n"] = n + cur_res["k"] = k + cur_res["out_dtype"] = res_dtype + cur_res["backend"] = backend_name + cur_res["case_tag"] = args.case_tag + res.append(cur_res) + return res diff --git a/csrc/mxfp8_gemm_cutlass.cu b/csrc/mxfp8_gemm_cutlass.cu new file mode 100644 index 0000000000..29493d9d5f --- /dev/null +++ b/csrc/mxfp8_gemm_cutlass.cu @@ -0,0 +1,296 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include + +#include +#include +#include +#include +#include + +#include "flashinfer/gemm/cutlass_gemm_configs.h" +#include "flashinfer/gemm/mxfp8_gemm_cutlass.h" +#include "flashinfer/gemm/mxfp8_gemm_cutlass_template.h" +#include "tvm_ffi_utils.h" + +using flashinfer::gemm::ClusterShape; +using flashinfer::gemm::CutlassGemmConfig; +using flashinfer::gemm::CutlassMxfp8GemmRunner; +using flashinfer::gemm::CutlassMxfp8GemmRunnerInterface; +using flashinfer::gemm::CutlassTileConfigSM100; +using flashinfer::gemm::EpilogueScheduleType; +using flashinfer::gemm::MainloopScheduleType; +using flashinfer::gemm::MXFP8GemmType; + +namespace flashinfer { +namespace gemm { +template class CutlassMxfp8GemmRunner<__nv_bfloat16, MXFP8GemmType::W8A8_MXFP8_MXFP8>; +template class CutlassMxfp8GemmRunner; +} // namespace gemm +} // namespace flashinfer + +namespace torch_ext { + +namespace { + +CutlassGemmConfig getMxfp8GemmConfig(int64_t m, int64_t n, int64_t k, int64_t tactic) { + auto getCutlassMxfp8GemmConfigs = []() { + CutlassMxfp8GemmRunner<__nv_bfloat16, MXFP8GemmType::W8A8_MXFP8_MXFP8> gemmRunner; + return gemmRunner.getConfigs(); + }; + static std::vector globalConfigs = getCutlassMxfp8GemmConfigs(); + TVM_FFI_ICHECK(tactic >= 0 && tactic < globalConfigs.size()) + << "tactic must be between 0 and " << globalConfigs.size(); + return globalConfigs[tactic]; +} + +template +void runGemm(TensorView out, TensorView mat1, TensorView mat2, TensorView mat1Scale, + TensorView mat2Scale, int64_t m, int64_t n, int64_t k, int64_t batch_count, + CutlassGemmConfig const& gemmConfig, TensorView workspace_buffer) { + CutlassMxfp8GemmRunner gemmRunner; + + int64_t const required_workspace_size = gemmRunner.getWorkspaceSize(m, n, k, batch_count); + int64_t const provided_workspace_size = + workspace_buffer.numel() * get_element_size(workspace_buffer); + + auto runKernel = [&](void* workspace) { + gemmRunner.gemm(out.data_ptr(), mat1.data_ptr(), mat2.data_ptr(), mat1Scale.data_ptr(), + 
mat2Scale.data_ptr(), m, n, k, batch_count, gemmConfig, + reinterpret_cast(workspace), required_workspace_size, + get_stream(mat1.device())); + }; + + if (provided_workspace_size < required_workspace_size) { + Tensor new_workspace = + alloc_tensor({required_workspace_size}, DLDataType{kDLInt, 8, 1}, mat1.device()); + runKernel(new_workspace.data_ptr()); + } else { + runKernel(workspace_buffer.data_ptr()); + } +} + +constexpr auto FLOAT8_E4M3FN = dl_float8_e4m3fn; // float8_e4m3fn +constexpr auto SF_DTYPE = dl_uint8; // uint8_t + +// mat1: [B, M, K], FLOAT8_E4M3FN +// mat2: [B, N, K], FLOAT8_E4M3FN (passed as transposed, TensorView sees [N, K]) +// out: [B, M, N], fp16/bf16 +// mat1Scale/mat2Scale: SF_DTYPE (UE8M0), sfVecSize is always 32 +// B = 1 for GEMM op as a special case +void mxfp8_bmm_impl(TensorView mat1, TensorView mat2, TensorView mat1Scale, TensorView mat2Scale, + TensorView out, TensorView workspace_buffer, int64_t tactic) { + CHECK_INPUT_AND_TYPE(mat1, FLOAT8_E4M3FN); + CHECK_INPUT_AND_TYPE(mat2, FLOAT8_E4M3FN); + + CHECK_INPUT_AND_TYPE(mat1Scale, SF_DTYPE); + CHECK_INPUT_AND_TYPE(mat2Scale, SF_DTYPE); + + int64_t m, n, k, b; + // Scale validation for swizzled (1D) and non-swizzled (2D) layouts. 
+ if (mat1.ndim() == 2) { + TVM_FFI_ICHECK_EQ(mat2.ndim(), 2) << "mat2 must be a matrix"; + // mat2 is passed as b.T, but TensorView reads underlying storage as [N, K] + // mat1 is [M, K] + // Check: mat1.size(1) == mat2.size(1) (both should be K) + TVM_FFI_ICHECK_EQ(mat1.size(1), mat2.size(1)) + << "mat1 and mat2 shapes cannot be multiplied (" << mat1.size(0) << "x" << mat1.size(1) + << " and " << mat2.size(0) << "x" << mat2.size(1) << ")"; + m = mat1.size(0); + n = mat2.size(0); // mat2 is [N, K] in storage + k = mat2.size(1); // mat2 is [N, K] in storage + b = 1; + } else if (mat1.ndim() == 3) { + TVM_FFI_ICHECK_EQ(mat2.ndim(), 3) << "mat2 must be a batch of matrices"; + TVM_FFI_ICHECK_EQ(mat1.size(0), mat2.size(0)) << "mat1 and mat2 must have the same batch size (" + << mat1.size(0) << " and " << mat2.size(0) << ")"; + // mat2 is passed as b.T, but TensorView reads underlying storage as [B, N, K] + // mat1 is [B, M, K] + // Check: mat1.size(2) == mat2.size(2) (both should be K) + TVM_FFI_ICHECK_EQ(mat1.size(2), mat2.size(2)) + << "mat1 and mat2 shapes cannot be multiplied (" << mat1.size(1) << "x" << mat1.size(2) + << " and " << mat2.size(1) << "x" << mat2.size(2) << ")"; + m = mat1.size(1); + n = mat2.size(1); // mat2 is [B, N, K] in storage + k = mat2.size(2); // mat2 is [B, N, K] in storage + b = mat1.size(0); + } else { + TVM_FFI_LOG_AND_THROW(NotImplementedError) << "mat1 must be a matrix or a batch of matrices"; + } + + constexpr int64_t sfVecSize = 32; // MXFP8 block size + auto scale_len = [&](int64_t dim) { return (dim + sfVecSize - 1) / sfVecSize; }; + auto swizzled_len = [&](int64_t rows, int64_t cols) { + auto pad_up = [](int64_t value, int64_t multiple) { + return (value + multiple - 1) / multiple * multiple; + }; + int64_t padded_rows = pad_up(rows, 128); + int64_t padded_cols = pad_up(cols, 4); + return padded_rows * padded_cols; + }; + + if (mat1.ndim() == 2) { + const int64_t k_scales = scale_len(k); + if (mat1Scale.ndim() == 1) { + int64_t 
expected = swizzled_len(m, k_scales); + TVM_FFI_ICHECK_EQ(mat1Scale.size(0), expected) + << "mxfp8_bmm_impl: mat1Scale size mismatch, expected " << expected << ", got " + << mat1Scale.size(0); + } else { + TVM_FFI_ICHECK_EQ(mat1Scale.ndim(), 2) + << "mxfp8_bmm_impl: mat1Scale must be 1D (swizzled) or 2D (non-swizzled), got " + << mat1Scale.ndim(); + TVM_FFI_ICHECK_EQ(mat1Scale.size(0), m) + << "mxfp8_bmm_impl: mat1Scale size mismatch, expected " << m << ", got " + << mat1Scale.size(0); + TVM_FFI_ICHECK_EQ(mat1Scale.size(1), k_scales) + << "mxfp8_bmm_impl: mat1Scale size mismatch, expected " << k_scales << ", got " + << mat1Scale.size(1); + } + + if (mat2Scale.ndim() == 1) { + int64_t expected = swizzled_len(n, k_scales); + TVM_FFI_ICHECK_EQ(mat2Scale.size(0), expected) + << "mxfp8_bmm_impl: mat2Scale size mismatch, expected " << expected << ", got " + << mat2Scale.size(0); + } else { + TVM_FFI_ICHECK_EQ(mat2Scale.ndim(), 2) + << "mxfp8_bmm_impl: mat2Scale must be 1D (swizzled) or 2D (non-swizzled), got " + << mat2Scale.ndim(); + TVM_FFI_ICHECK_EQ(mat2Scale.size(0), n) + << "mxfp8_bmm_impl: mat2Scale size mismatch, expected " << n << ", got " + << mat2Scale.size(0); + TVM_FFI_ICHECK_EQ(mat2Scale.size(1), k_scales) + << "mxfp8_bmm_impl: mat2Scale size mismatch, expected " << k_scales << ", got " + << mat2Scale.size(1); + } + } else { + const int64_t k_scales = scale_len(k); + if (mat1Scale.ndim() == 1) { + int64_t expected = swizzled_len(b * m, k_scales); + TVM_FFI_ICHECK_EQ(mat1Scale.size(0), expected) + << "mxfp8_bmm_impl: mat1Scale size mismatch, expected " << expected << ", got " + << mat1Scale.size(0); + } else if (mat1Scale.ndim() == 2) { + TVM_FFI_ICHECK_EQ(mat1Scale.size(1), k_scales) + << "mxfp8_bmm_impl: mat1Scale size mismatch, expected " << k_scales << ", got " + << mat1Scale.size(1); + TVM_FFI_ICHECK_EQ(mat1Scale.size(0), b * m) + << "mxfp8_bmm_impl: mat1Scale size mismatch, expected " << (b * m) << ", got " + << mat1Scale.size(0); + } else { + 
TVM_FFI_ICHECK_EQ(mat1Scale.ndim(), 3) + << "mxfp8_bmm_impl: mat1Scale must be 1D (swizzled), 2D (flattened), or 3D " + "(batched), got " + << mat1Scale.ndim(); + TVM_FFI_ICHECK_EQ(mat1Scale.size(0), b) + << "mxfp8_bmm_impl: mat1Scale batch size mismatch, expected " << b << ", got " + << mat1Scale.size(0); + TVM_FFI_ICHECK_EQ(mat1Scale.size(1), m) + << "mxfp8_bmm_impl: mat1Scale size mismatch, expected " << m << ", got " + << mat1Scale.size(1); + TVM_FFI_ICHECK_EQ(mat1Scale.size(2), k_scales) + << "mxfp8_bmm_impl: mat1Scale size mismatch, expected " << k_scales << ", got " + << mat1Scale.size(2); + } + + if (mat2Scale.ndim() == 1) { + int64_t expected = swizzled_len(b * n, k_scales); + TVM_FFI_ICHECK_EQ(mat2Scale.size(0), expected) + << "mxfp8_bmm_impl: mat2Scale size mismatch, expected " << expected << ", got " + << mat2Scale.size(0); + } else if (mat2Scale.ndim() == 2) { + TVM_FFI_ICHECK_EQ(mat2Scale.size(1), k_scales) + << "mxfp8_bmm_impl: mat2Scale size mismatch, expected " << k_scales << ", got " + << mat2Scale.size(1); + TVM_FFI_ICHECK_EQ(mat2Scale.size(0), b * n) + << "mxfp8_bmm_impl: mat2Scale size mismatch, expected " << (b * n) << ", got " + << mat2Scale.size(0); + } else { + TVM_FFI_ICHECK_EQ(mat2Scale.ndim(), 3) + << "mxfp8_bmm_impl: mat2Scale must be 1D (swizzled), 2D (flattened), or 3D " + "(batched), got " + << mat2Scale.ndim(); + TVM_FFI_ICHECK_EQ(mat2Scale.size(0), b) + << "mxfp8_bmm_impl: mat2Scale batch size mismatch, expected " << b << ", got " + << mat2Scale.size(0); + TVM_FFI_ICHECK_EQ(mat2Scale.size(1), n) + << "mxfp8_bmm_impl: mat2Scale size mismatch, expected " << n << ", got " + << mat2Scale.size(1); + TVM_FFI_ICHECK_EQ(mat2Scale.size(2), k_scales) + << "mxfp8_bmm_impl: mat2Scale size mismatch, expected " << k_scales << ", got " + << mat2Scale.size(2); + } + } + + // No heuristic for now, we rely on the autotuner to select the best tactic. 
+ if (tactic == -1) { + tactic = 0; + } + auto config = getMxfp8GemmConfig(m, n, k, tactic); + + constexpr int alignment = 32; + TVM_FFI_ICHECK_EQ(k % alignment, 0) + << "Expected k to be divisible by " << alignment << ", but got mat1 shape: (" << mat1.size(0) + << "x" << mat1.size(1) << "), k: " << k << "."; + TVM_FFI_ICHECK_EQ(n % alignment, 0) + << "Expected n to be divisible by " << alignment << ", but got mat2 shape: (" << mat2.size(0) + << "x" << mat2.size(1) << ")."; + + // Validate out dimensions + std::vector out_shape = + mat1.ndim() == 2 ? std::vector{m, n} : std::vector{b, m, n}; + TVM_FFI_ICHECK_EQ(out.ndim(), out_shape.size()) + << "out must have " << out_shape.size() << " dimensions, but got " << out.ndim(); + for (int i = 0; i < out_shape.size(); ++i) { + TVM_FFI_ICHECK_EQ(out.size(i), out_shape[i]) + << "out shape mismatch at dimension " << i << ": expected " << out_shape[i] << ", got " + << out.size(i); + } + + switch (encode_dlpack_dtype(out.dtype())) { + case float16_code: + runGemm(out, mat1, mat2, mat1Scale, mat2Scale, m, n, k, b, config, workspace_buffer); + break; + case bfloat16_code: + runGemm<__nv_bfloat16>(out, mat1, mat2, mat1Scale, mat2Scale, m, n, k, b, config, + workspace_buffer); + break; + default: + TVM_FFI_ICHECK(false) << "out_dtype must be one of fp16/bf16."; + } +} + +} // namespace + +void mxfp8_gemm(TensorView mat1, TensorView mat2, TensorView mat1Scale, TensorView mat2Scale, + TensorView out, TensorView workspace_buffer, int64_t tactic) { + mxfp8_bmm_impl(mat1, mat2, mat1Scale, mat2Scale, out, workspace_buffer, tactic); +} + +int64_t mxfp8_gemm_tactic_num() { + auto getCutlassConfigs = []() { + CutlassMxfp8GemmRunner<__nv_bfloat16, MXFP8GemmType::W8A8_MXFP8_MXFP8> gemmRunner; + return gemmRunner.getConfigs(); + }; + static int64_t totalTactics = getCutlassConfigs().size(); + return totalTactics; +} + +} // namespace torch_ext + +TVM_FFI_DLL_EXPORT_TYPED_FUNC(mxfp8_gemm, torch_ext::mxfp8_gemm); 
+TVM_FFI_DLL_EXPORT_TYPED_FUNC(mxfp8_gemm_tactic_num, torch_ext::mxfp8_gemm_tactic_num); diff --git a/csrc/mxfp8_gemm_cutlass.jinja b/csrc/mxfp8_gemm_cutlass.jinja new file mode 100644 index 0000000000..3e9da03a38 --- /dev/null +++ b/csrc/mxfp8_gemm_cutlass.jinja @@ -0,0 +1,31 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "flashinfer/gemm/mxfp8_gemm_cutlass_template.h" + +namespace flashinfer { +namespace gemm { +INSTANTIATE_MXFP8_GEMM_KERNEL_LAUNCHER({{ type }}, {{ cta_m }}, {{ cta_n }}, {{ cta_k }}, 1, 1, 1, _1SM) +INSTANTIATE_MXFP8_GEMM_KERNEL_LAUNCHER({{ type }}, {{ cta_m }}, {{ cta_n }}, {{ cta_k }}, 1, 2, 1, _1SM) +INSTANTIATE_MXFP8_GEMM_KERNEL_LAUNCHER({{ type }}, {{ cta_m }}, {{ cta_n }}, {{ cta_k }}, 1, 4, 1, _1SM) +INSTANTIATE_MXFP8_GEMM_KERNEL_LAUNCHER({{ type }}, {{ cta_m }}, {{ cta_n }}, {{ cta_k }}, 2, 1, 1, _2SM) +INSTANTIATE_MXFP8_GEMM_KERNEL_LAUNCHER({{ type }}, {{ cta_m }}, {{ cta_n }}, {{ cta_k }}, 2, 2, 1, _2SM) +INSTANTIATE_MXFP8_GEMM_KERNEL_LAUNCHER({{ type }}, {{ cta_m }}, {{ cta_n }}, {{ cta_k }}, 2, 4, 1, _2SM) +INSTANTIATE_MXFP8_GEMM_KERNEL_LAUNCHER({{ type }}, {{ cta_m }}, {{ cta_n }}, {{ cta_k }}, 4, 2, 1, _2SM) +INSTANTIATE_MXFP8_GEMM_KERNEL_LAUNCHER({{ type }}, {{ cta_m }}, {{ cta_n }}, {{ cta_k }}, 4, 4, 1, _2SM) + +} // namespace gemm +} // namespace flashinfer diff --git a/flashinfer/__init__.py b/flashinfer/__init__.py index 
c22b4a0a55..460438de74 100644 --- a/flashinfer/__init__.py +++ b/flashinfer/__init__.py @@ -92,6 +92,7 @@ from .gemm import mm_bf16 as mm_bf16 from .gemm import mm_fp4 as mm_fp4 from .gemm import mm_fp8 as mm_fp8 +from .gemm import mm_mxfp8 as mm_mxfp8 from .gemm import tgv_gemm_sm100 as tgv_gemm_sm100 from .mla import BatchMLAPagedAttentionWrapper as BatchMLAPagedAttentionWrapper from .norm import fused_add_rmsnorm as fused_add_rmsnorm diff --git a/flashinfer/gemm/__init__.py b/flashinfer/gemm/__init__.py index bd30c178dc..9dde43b02d 100644 --- a/flashinfer/gemm/__init__.py +++ b/flashinfer/gemm/__init__.py @@ -5,6 +5,7 @@ from .gemm_base import mm_bf16 as mm_bf16 from .gemm_base import mm_fp4 as mm_fp4 from .gemm_base import mm_fp8 as mm_fp8 +from .gemm_base import mm_mxfp8 as mm_mxfp8 from .gemm_base import tgv_gemm_sm100 as tgv_gemm_sm100 from .gemm_base import group_gemm_mxfp4_nt_groupwise as group_gemm_mxfp4_nt_groupwise from .gemm_base import ( @@ -31,6 +32,7 @@ "mm_bf16", "mm_fp4", "mm_fp8", + "mm_mxfp8", "tgv_gemm_sm100", "group_gemm_mxfp4_nt_groupwise", "batch_deepgemm_fp8_nt_groupwise", diff --git a/flashinfer/gemm/gemm_base.py b/flashinfer/gemm/gemm_base.py index aba97ce2ea..e5d48edddb 100644 --- a/flashinfer/gemm/gemm_base.py +++ b/flashinfer/gemm/gemm_base.py @@ -52,6 +52,7 @@ from ..jit.gemm import gen_gemm_sm120_module_cutlass_fp4 from ..jit.gemm import gen_gemm_sm100_module_cutlass_fp4 from ..jit.gemm import gen_gemm_sm100_module_cutlass_fp8 +from ..jit.gemm import gen_gemm_sm100_module_cutlass_mxfp8 from ..jit.gemm import gen_gemm_sm100_module_cutlass_bf16 from ..jit.gemm import gen_trtllm_gen_gemm_module from ..jit.gemm import gen_tgv_gemm_sm10x_module @@ -2400,6 +2401,382 @@ def mm_fp8( return out +def _create_cutlass_mxfp8_gemm_module(module, op_name: str, tuner_name: str): + """Helper function to create cutlass MXFP8 GEMM module.""" + + def cutlass_mxfp8_gemm_runner(): + class CutlassMxfp8GemmRunner(TunableRunner): + def get_valid_tactics( + 
self, + inputs: List[torch.Tensor], + profile: OptimizationProfile, + ) -> List[int]: + return list(range(module.mxfp8_gemm_tactic_num())) + + def forward( + self, + inputs: List[torch.Tensor], + tactic: int = -1, + do_preparation: bool = False, + **kwargs, + ): + ( + a, + b, + a_descale, + b_descale, + _, + out, + workspace_buffer, + ) = inputs + + # CUTLASS expects b_descale in (N, K/32). + # 2D input is (K/32, N) and must be transposed; 1D swizzled is pass-through. + if b_descale.ndim == 2: + # Input is (K/32, N), transpose to (N, K/32) for CUTLASS + b_descale_processed = b_descale.T + if not b_descale_processed.is_contiguous(): + b_descale_processed = b_descale_processed.contiguous() + else: + # 1D swizzled format - pass as-is, just ensure contiguous + b_descale_processed = b_descale + if not b_descale_processed.is_contiguous(): + b_descale_processed = b_descale_processed.contiguous() + + module.mxfp8_gemm( + a, + b.T, + a_descale, + b_descale_processed, + out, + workspace_buffer, + tactic, + ) + return out + + return CutlassMxfp8GemmRunner() + + return SimpleNamespace( + cutlass_mxfp8_gemm_runner=cutlass_mxfp8_gemm_runner, + ) + + +@functools.cache +def get_gemm_sm100_module_cutlass_mxfp8(): + """Get the SM100/103/110 MXFP8 GEMM module.""" + module = gen_gemm_sm100_module_cutlass_mxfp8().build_and_load() + return _create_cutlass_mxfp8_gemm_module( + module, "flashinfer::cutlass_mxfp8_gemm", "cutlass_mxfp8_gemm" + ) + + +def get_cutlass_mxfp8_gemm_module( + sm_major: int, +): + if sm_major in [10, 11]: + return get_gemm_sm100_module_cutlass_mxfp8() + else: + raise ValueError(f"Unsupported SM major version: {sm_major}") + + +def _check_mm_mxfp8_problem_size( + a: torch.Tensor, + b: torch.Tensor, + a_descale: torch.Tensor, + b_descale: torch.Tensor, + out: Optional[torch.Tensor] = None, + out_dtype: torch.dtype = torch.bfloat16, + backend: Literal["cutlass", "auto"] = "auto", # unused +) -> bool: + # Generic checks + ## pre-check the input tensors and block scale 
tensors + if a.ndim != 2 or b.ndim != 2: + raise ValueError(f"mm_mxfp8 accepts 2d tensors, got {a.shape=} and {b.shape=}") + + # b is passed transposed (shape [k, n]), so verify K matches. + if a.shape[1] != b.shape[0]: + raise ValueError( + f"K dimension mismatch in mm_mxfp8. got {a.shape[1]=}, {b.shape[0]=}" + ) + + # The output may contain NaN/Inf if the dimensions are too small + min_n = 128 + min_k = 128 + if b.shape[1] < min_n or a.shape[1] < min_k: + raise ValueError( + f"MXFP8 requires n >= {min_n} and k >= {min_k} for CUTLASS MXFP8. " + f"got m={a.shape[0]}, n={b.shape[1]}, k={a.shape[1]}." + ) + + # Input dtype as returned by mxfp8_quantize_sm100 + if a.dtype != torch.float8_e4m3fn: + raise ValueError(f"a must be a float8_e4m3fn tensor, got {a.dtype=}") + + if b.dtype != torch.float8_e4m3fn: + raise ValueError(f"b must be a float8_e4m3fn tensor, got {b.dtype=}") + + # Scale dtype as returned by mxfp8_quantize_sm100 + if a_descale.dtype != torch.uint8: + raise ValueError(f"a_descale must be a uint8 tensor, got {a_descale.dtype=}") + + if b_descale.dtype != torch.uint8: + raise ValueError(f"b_descale must be a uint8 tensor, got {b_descale.dtype=}") + + # MXFP8 block size + sf_vec_size = 32 + + if a_descale.ndim == 1: + expected_len = _mxfp8_swizzled_scale_len(a.shape[0], a.shape[1]) + if a_descale.shape[0] != expected_len: + raise ValueError( + "a_descale shape mismatch for swizzled layout. " + f"Expected {(expected_len,)}, got {a_descale.shape}." + ) + elif a_descale.ndim == 2: + if a.shape[1] % sf_vec_size != 0: + raise ValueError( + "a_descale shape mismatch for non-swizzled layout. " + f"a.shape[1] must be divisible by {sf_vec_size}, got {a.shape[1]}." + ) + expected_shape = (a.shape[0], a.shape[1] // sf_vec_size) + if a_descale.shape != expected_shape: + raise ValueError( + "a_descale shape mismatch for non-swizzled layout. " + f"Expected {expected_shape}, got {a_descale.shape}." 
+ ) + else: + raise ValueError( + f"a_descale must be 1D (swizzled) or 2D (non-swizzled), got {a_descale.shape}." + ) + + if b_descale.ndim == 1: + expected_len = _mxfp8_swizzled_scale_len(b.shape[1], b.shape[0]) + if b_descale.shape[0] != expected_len: + raise ValueError( + "b_descale shape mismatch for swizzled layout. " + f"Expected {(expected_len,)}, got {b_descale.shape}." + ) + elif b_descale.ndim == 2: + if b.shape[0] % sf_vec_size != 0: + raise ValueError( + "b_descale shape mismatch for non-swizzled layout. " + f"b.shape[0] must be divisible by {sf_vec_size}, got {b.shape[0]}." + ) + expected_shape = (b.shape[0] // sf_vec_size, b.shape[1]) + if b_descale.shape != expected_shape: + raise ValueError( + "b_descale shape mismatch for non-swizzled layout. " + f"Expected {expected_shape}, got {b_descale.shape}." + ) + else: + raise ValueError( + f"b_descale must be 1D (swizzled) or 2D (non-swizzled), got {b_descale.shape}." + ) + + if out is not None: + expected_shape = (a.shape[0], b.shape[1]) + if out.shape != expected_shape: + raise ValueError( + f"Output shape mismatch. Expected {expected_shape}, got {out.shape}." + ) + if out.device != a.device: + raise ValueError( + f"Output device mismatch. Expected {a.device}, got {out.device}." + ) + if out.dtype != out_dtype: + raise ValueError( + f"Output dtype mismatch. Expected {out_dtype}, got {out.dtype}." 
+ ) + + _validate_mxfp8_output_dtype(out_dtype) + return True + + +@supported_compute_capability([100, 103, 110]) +def _cutlass_gemm_mxfp8_requirement( + a: torch.Tensor, + b: torch.Tensor, + a_descale: torch.Tensor, + b_descale: torch.Tensor, + out: Optional[torch.Tensor] = None, + out_dtype: torch.dtype = torch.bfloat16, + backend: Literal["cutlass", "auto"] = "auto", +): + return True + + +def _heuristic_func_mm_mxfp8( + suitable_backends: List[str], + a: torch.Tensor, + b: torch.Tensor, + a_descale: torch.Tensor, + b_descale: torch.Tensor, + out: Optional[torch.Tensor] = None, + out_dtype: torch.dtype = torch.bfloat16, + backend: Literal["cutlass", "auto"] = "auto", +) -> List[str]: + if "cutlass" in suitable_backends: + return ["cutlass"] + return [] + + +@backend_requirement( + { + "cutlass": _cutlass_gemm_mxfp8_requirement, + }, + common_check=_check_mm_mxfp8_problem_size, + heuristic_func=_heuristic_func_mm_mxfp8, # result stored in mm_mxfp8.suitable_auto_backends +) +@flashinfer_api +def mm_mxfp8( + a: torch.Tensor, + b: torch.Tensor, + a_descale: torch.Tensor, + b_descale: torch.Tensor, + out: Optional[torch.Tensor] = None, + out_dtype: torch.dtype = torch.bfloat16, + backend: Literal["cutlass", "auto"] = "auto", +) -> torch.Tensor: + r"""MM MXFP8 (block size 32) + + Parameters + ---------- + a: torch.Tensor + Input A tensor, shape (m, k), mxfp8 e4m3. + + b: torch.Tensor + Input B tensor, shape (k, n), should be column major, mxfp8 e4m3. + + a_descale: torch.Tensor + Block scale tensor for A. Can be: + - 2D non-swizzled: shape (m, k // 32) + - 1D swizzled: shape (M_padded * K_padded,) where M_padded = round_up(m, 128), K_padded = round_up(k // 32, 4) + dtype: uint8. + + b_descale: torch.Tensor + Block scale tensor for B. Can be: + - 2D non-swizzled: shape (k // 32, n) - transposed format + - 1D swizzled: shape (N_padded * K_padded,) where N_padded = round_up(n, 128), K_padded = round_up(k // 32, 4) + dtype: uint8. 
+ Note: For 2D format, this is the transposed version (typically passed as scale.t()). + For 1D swizzled format, it's flattened from (N_padded, K_padded) layout. + + out: Optional[torch.Tensor] + Out tensor, shape (m, n), bf16 or fp16. If provided, can only be used with the CUTLASS backend. Defaults to ``None``. + + out_dtype: torch.dtype + Output dtype, bf16 or fp16. Defaults to ``torch.bfloat16``. + + backend: Literal["cutlass", "auto"] + The backend to use for the operation. Defaults to ``"auto"``. + ``"auto"`` selects the CUTLASS backend. + + Returns + ------- + out: torch.Tensor + Out tensor, shape (m, n), bf16 or fp16. + + Examples + -------- + >>> import torch + >>> from flashinfer import mxfp8_quantize, mm_mxfp8 + >>> m, n, k = 512, 256, 128 + >>> # Create input tensors - note: weight is [n, k] for typical NN layers + >>> a = torch.randn([m, k], device="cuda", dtype=torch.bfloat16) + >>> weight = torch.randn([n, k], device="cuda", dtype=torch.bfloat16) + >>> + >>> # Option 1: Use swizzled layout (recommended for accuracy) + >>> # Quantize input [m, k] - scales are 1D swizzled for (M, K/32) layout + >>> a_mx, a_sf = mxfp8_quantize(input=a, is_sf_swizzled_layout=True) + >>> # Quantize weight [n, k] - scales are 1D swizzled for (N, K/32) layout + >>> w_mx, w_sf = mxfp8_quantize(input=weight, is_sf_swizzled_layout=True) + >>> # Pass weight.T as [k, n] and 1D swizzled scales directly + >>> out = mm_mxfp8(a_mx, w_mx.t(), a_sf, w_sf, out_dtype=torch.bfloat16) + >>> out.shape + torch.Size([512, 256]) + >>> + >>> # Option 2: Use non-swizzled layout (for compatibility) + >>> a_mx, a_sf = mxfp8_quantize(input=a, is_sf_swizzled_layout=False) + >>> w_mx, w_sf = mxfp8_quantize(input=weight, is_sf_swizzled_layout=False) + >>> # For non-swizzled: reshape to 2D and transpose weight scale to (k//32, n) + >>> a_sf_2d = a_sf.view(m, k // 32) + >>> w_sf_2d = w_sf.view(n, k // 32).t() # Transpose to (k // 32, n) + >>> out = mm_mxfp8(a_mx, w_mx.t(), a_sf_2d, w_sf_2d, 
out_dtype=torch.bfloat16) + >>> out.shape + torch.Size([512, 256]) + """ + + assert a.ndim == 2, f"mm_mxfp8: a must be 2D, got {a.ndim}D with shape {a.shape}" + assert b.ndim == 2, f"mm_mxfp8: b must be 2D, got {b.ndim}D with shape {b.shape}" + assert a.shape[1] == b.shape[0], ( + f"mm_mxfp8: K dimension mismatch: a.shape[1]={a.shape[1]}, b.shape[0]={b.shape[0]}" + ) + + assert a_descale.ndim in (1, 2), ( + f"mm_mxfp8: a_descale must be 1D (swizzled) or 2D (non-swizzled), " + f"got {a_descale.ndim}D with shape {a_descale.shape}, dtype={a_descale.dtype}" + ) + assert b_descale.ndim in (1, 2), ( + f"mm_mxfp8: b_descale must be 1D (swizzled) or 2D (non-swizzled), " + f"got {b_descale.ndim}D with shape {b_descale.shape}, dtype={b_descale.dtype}" + ) + + # NOTE: do NOT reshape swizzled 1D scales to 2D; it breaks the F8_128x4 layout. + + # allocate the output tensor if not provided + if out is None: + out = torch.empty( + (a.shape[0], b.shape[1]), + device=a.device, + dtype=out_dtype, + ) + + workspace_buffer = _get_cache_buf( + "mm_mxfp8_workspace", DEFAULT_WORKSPACE_SIZE, a.device + ) + + if backend == "auto": + backends = mm_mxfp8.suitable_auto_backends + else: + backends = [backend] + + major, _ = get_compute_capability(a.device) + + backend_to_runner_factory = { + "cutlass": lambda: get_cutlass_mxfp8_gemm_module( + major + ).cutlass_mxfp8_gemm_runner(), + } + + runners: List[TunableRunner] = [ + backend_to_runner_factory[cur_backend]() for cur_backend in backends + ] + + tuner = AutoTuner.get() + + tuning_config = _MM_MXFP8_TUNING_CONFIG + + inputs = [ + a, + b, + a_descale, + b_descale, + out_dtype, + out, + workspace_buffer, + ] + + runner, tactic = tuner.choose_one( + custom_op="mxfp8_gemm", + runners=runners, + tuning_config=tuning_config, + inputs=inputs, + ) + + runner(inputs=inputs, tactic=tactic) + return out + + def _get_cudnn_fp4_gemm_graph( a: torch.Tensor, b: torch.Tensor, @@ -2768,6 +3145,13 @@ def _pad_up(x, y): return ((x + y - 1) // y) * y +def 
_mxfp8_swizzled_scale_len(m: int, k: int) -> int: + """Return the 1D swizzled scale length for MXFP8 (F8_128x4 layout).""" + m_padded = _pad_up(m, 128) + num_k_tiles = _pad_up(k, 128) // 128 + return m_padded * num_k_tiles * 4 + + _MM_FP4_TUNING_CONFIG_8x4 = TuningConfig( dynamic_tensor_specs=( DynamicTensorSpec( @@ -2816,6 +3200,34 @@ def _pad_up(x, y): ) +_MM_MXFP8_TUNING_CONFIG = TuningConfig( + dynamic_tensor_specs=( + DynamicTensorSpec( + (0,), # a_tensor_index + (0,), + get_last_power_of_2_num_tokens_buckets, + last_positive_power_of_2, + ), + ), + constraint_specs=( + ConstraintSpec( + 2, # a_descale_tensor_index + 0, + lambda shapes: ( + _mxfp8_swizzled_scale_len(shapes[0][0], shapes[0][1]) + if len(shapes[2]) == 1 + else shapes[0][0] + ), + ), + ConstraintSpec( + 5, # out_tensor_index + 0, + lambda shapes: shapes[0][0], + ), + ), +) + + @backend_requirement( { "cudnn": _cudnn_gemm_fp4_requirement, diff --git a/flashinfer/jit/gemm/__init__.py b/flashinfer/jit/gemm/__init__.py index 7621a04538..c73518f125 100644 --- a/flashinfer/jit/gemm/__init__.py +++ b/flashinfer/jit/gemm/__init__.py @@ -19,6 +19,7 @@ gen_gemm_sm100_module_cutlass_fp4, gen_gemm_sm120_module_cutlass_fp4, gen_gemm_sm100_module_cutlass_fp8, + gen_gemm_sm100_module_cutlass_mxfp8, gen_gemm_sm100_module_cutlass_bf16, gen_gemm_sm100_module, gen_gemm_sm120_module, @@ -35,6 +36,7 @@ "gen_gemm_sm100_module_cutlass_fp4", "gen_gemm_sm120_module_cutlass_fp4", "gen_gemm_sm100_module_cutlass_fp8", + "gen_gemm_sm100_module_cutlass_mxfp8", "gen_gemm_sm100_module_cutlass_bf16", "gen_gemm_sm100_module", "gen_gemm_sm120_module", diff --git a/flashinfer/jit/gemm/core.py b/flashinfer/jit/gemm/core.py index 5d40b510ac..3fbd8ebef1 100644 --- a/flashinfer/jit/gemm/core.py +++ b/flashinfer/jit/gemm/core.py @@ -236,6 +236,53 @@ def gen_gemm_sm100_module_cutlass_bf16() -> JitSpec: ) +def gen_gemm_sm100_module_cutlass_mxfp8() -> JitSpec: + gen_directory = jit_env.FLASHINFER_GEN_SRC_DIR / 
"gen_gemm_sm100_cutlass_mxfp8" + os.makedirs(gen_directory, exist_ok=True) + source_paths = [ + jit_env.FLASHINFER_CSRC_DIR / "mxfp8_gemm_cutlass.cu", + ] + + with open(jit_env.FLASHINFER_CSRC_DIR / "mxfp8_gemm_cutlass.jinja") as f: + kernel_inst_templ = jinja2.Template(f.read()) + dtype_list = ["__nv_bfloat16", "half"] + cta_m_n_k_list = [ + (128, 64, 128), + (128, 256, 128), + (128, 128, 256), + (128, 256, 256), + ] + for cta_m, cta_n, cta_k in cta_m_n_k_list: + for dtype in dtype_list: + dest_path = ( + gen_directory + / f"mxfp8_gemm_cutlass_{dtype}_{cta_m}_{cta_n}_{cta_k}.cu" + ) + source_paths.append(dest_path) + source = kernel_inst_templ.render( + type=dtype, + cta_m=cta_m, + cta_n=cta_n, + cta_k=cta_k, + ) + write_if_different(dest_path, source) + + nvcc_flags = current_compilation_context.get_nvcc_flags_list( + supported_major_versions=[10, 11] + ) + return gen_jit_spec( + "mxfp8_gemm_cutlass", + source_paths, + extra_cuda_cflags=nvcc_flags + + [ + "-DENABLE_BF16", + ], + extra_cflags=[ + "-DFAST_BUILD", + ], + ) + + def gen_gemm_sm100_module() -> JitSpec: gen_directory = jit_env.FLASHINFER_GEN_SRC_DIR / "gen_gemm_sm100" os.makedirs(gen_directory, exist_ok=True) diff --git a/include/flashinfer/gemm/mxfp8_gemm_cutlass.h b/include/flashinfer/gemm/mxfp8_gemm_cutlass.h new file mode 100644 index 0000000000..9815f3a66a --- /dev/null +++ b/include/flashinfer/gemm/mxfp8_gemm_cutlass.h @@ -0,0 +1,89 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights + * reserved. SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef FLASHINFER_MXFP8_GEMM_CUTLASS_H_ +#define FLASHINFER_MXFP8_GEMM_CUTLASS_H_ + +#include + +#include + +#include "flashinfer/gemm/cutlass_gemm_configs.h" + +namespace flashinfer { +namespace gemm { + +/* + This runner supports: + FP8 inputs (A and B) + E8M0 blockwise scaling factor + T output (D) where T = {float, half, __nv_bfloat16} + + Activations, biases and outputs are all assumed to be row-major. + Weights are assumed to be column-major. + Block scaling factors are interleaved. +*/ + +class CutlassMxfp8GemmRunnerInterface { + public: + CutlassMxfp8GemmRunnerInterface() {} + + virtual ~CutlassMxfp8GemmRunnerInterface() {} + + virtual void gemm(void* D, void const* A, void const* B, void const* input_sf, + void const* weight_sf, int m, int n, int k, int batch_count, + CutlassGemmConfig gemmConfig, char* workspace, const size_t workspaceBytes, + cudaStream_t stream) = 0; + + // Returns desired workspace size in bytes. 
+ virtual size_t getWorkspaceSize(int const m, int const n, int const k, int batch_count) = 0; + + virtual std::vector getConfigs() const = 0; +}; + +enum class MXFP8GemmType { + W8A8_MXFP8_MXFP8, +}; + +template +class CutlassMxfp8GemmRunner : public virtual CutlassMxfp8GemmRunnerInterface { + public: + CutlassMxfp8GemmRunner(); + ~CutlassMxfp8GemmRunner(); + + void gemm(void* D, void const* A, void const* B, void const* input_sf, void const* weight_sf, + int m, int n, int k, int batch_count, CutlassGemmConfig gemmConfig, char* workspace, + const size_t workspaceBytes, cudaStream_t stream) override; + + // Returns desired workspace size in bytes. + size_t getWorkspaceSize(int const m, int const n, int const k, int const batch_count) override; + + std::vector getConfigs() const override; + + private: + size_t dispatchToArch(T* D, void const* A, void const* B, void const* input_sf, + void const* weight_sf, int m, int n, int k, int batch_count, + CutlassGemmConfig gemmConfig, char* workspace, const size_t workspaceBytes, + cudaStream_t stream, int* occupancy = nullptr); + + size_t getWorkspaceSizeImpl(int const m, int const n, int const k, int const batch_count); +}; + +} // namespace gemm +} // namespace flashinfer + +#endif // FLASHINFER_MXFP8_GEMM_CUTLASS_H_ diff --git a/include/flashinfer/gemm/mxfp8_gemm_cutlass_template.h b/include/flashinfer/gemm/mxfp8_gemm_cutlass_template.h new file mode 100644 index 0000000000..0b58c8b2f3 --- /dev/null +++ b/include/flashinfer/gemm/mxfp8_gemm_cutlass_template.h @@ -0,0 +1,290 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef FLASHINFER_MXFP8_GEMM_CUTLASS_TEMPLATE_H_ +#define FLASHINFER_MXFP8_GEMM_CUTLASS_TEMPLATE_H_ + +#ifndef _WIN32 +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wstrict-aliasing" +#endif // #ifndef _WIN32 + +#include "cutlass/arch/arch.h" +#include "cutlass/cutlass.h" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/gemm.h" +#include "flashinfer/gemm/cutlass_gemm_configs.h" + +#ifndef _WIN32 +#pragma GCC diagnostic pop +#endif // #ifndef _WIN32 + +#include "flashinfer/gemm/mxfp8_gemm_cutlass.h" +#include "mxfp8_gemm_template_sm100.h" + +namespace flashinfer { +namespace gemm { +using namespace cute; + +template +size_t dispatchMXFP8xMXFP8GemmClusterShapeSm100(T* D, void const* A, void const* B, + void const* input_sf, void const* weight_sf, int m, + int n, int k, int batch_count, + CutlassGemmConfig gemmConfig, char* workspace, + const size_t workspaceBytes, cudaStream_t stream, + int* occupancy = nullptr) { + switch (gemmConfig.cluster_shape) { + case ClusterShape::ClusterShape_1x1x1: + return genericMxfp8GemmKernelLauncher, cute::Int<1>, + cute::Int<1>, _1SM>( + D, A, B, input_sf, weight_sf, m, n, k, batch_count, gemmConfig, workspace, workspaceBytes, + stream, occupancy); + break; + case ClusterShape::ClusterShape_2x1x1: + return genericMxfp8GemmKernelLauncher, cute::Int<1>, + cute::Int<1>, _2SM>( + D, A, B, input_sf, weight_sf, m, n, k, batch_count, gemmConfig, 
workspace, workspaceBytes, + stream, occupancy); + break; + case ClusterShape::ClusterShape_1x2x1: + return genericMxfp8GemmKernelLauncher, cute::Int<2>, + cute::Int<1>, _1SM>( + D, A, B, input_sf, weight_sf, m, n, k, batch_count, gemmConfig, workspace, workspaceBytes, + stream, occupancy); + break; + case ClusterShape::ClusterShape_2x2x1: + return genericMxfp8GemmKernelLauncher, cute::Int<2>, + cute::Int<1>, _2SM>( + D, A, B, input_sf, weight_sf, m, n, k, batch_count, gemmConfig, workspace, workspaceBytes, + stream, occupancy); + break; + case ClusterShape::ClusterShape_1x4x1: + return genericMxfp8GemmKernelLauncher, cute::Int<4>, + cute::Int<1>, _1SM>( + D, A, B, input_sf, weight_sf, m, n, k, batch_count, gemmConfig, workspace, workspaceBytes, + stream, occupancy); + break; + case ClusterShape::ClusterShape_4x2x1: + return genericMxfp8GemmKernelLauncher, cute::Int<2>, + cute::Int<1>, _2SM>( + D, A, B, input_sf, weight_sf, m, n, k, batch_count, gemmConfig, workspace, workspaceBytes, + stream, occupancy); + break; + case ClusterShape::ClusterShape_2x4x1: + return genericMxfp8GemmKernelLauncher, cute::Int<4>, + cute::Int<1>, _2SM>( + D, A, B, input_sf, weight_sf, m, n, k, batch_count, gemmConfig, workspace, workspaceBytes, + stream, occupancy); + break; + case ClusterShape::ClusterShape_4x4x1: + return genericMxfp8GemmKernelLauncher, cute::Int<4>, + cute::Int<1>, _2SM>( + D, A, B, input_sf, weight_sf, m, n, k, batch_count, gemmConfig, workspace, workspaceBytes, + stream, occupancy); + break; + default: + throw std::runtime_error( + "[Error][MXFP8][dispatch_gemm_cluster_shape] Config is invalid for MXFP8 GEMM."); + break; + } +} + +template +size_t dispatchMXFP8xMXFP8GemmCTAShapeSm100(T* D, void const* A, void const* B, + void const* input_sf, void const* weight_sf, int m, + int n, int k, int batch_count, + CutlassGemmConfig gemmConfig, char* workspace, + const size_t workspaceBytes, cudaStream_t stream, + int* occupancy = nullptr) { + // TODO: check if true for 
MXFP8 + // Several constraints: + // Cta N should be one of 64/128/192/256 for MXFP8 on SM100. + // M-mode size should be 128 or 256 for 2 CTA cluster MMA; + // M-mode size should be 128 for 1 CTA cluster OMMA. + // K256 looks to be better than K128 + switch (gemmConfig.tile_config_sm100) { + case CutlassTileConfigSM100::CtaShape128x64x128B: + return dispatchMXFP8xMXFP8GemmClusterShapeSm100, cute::Int<64>, + cute::Int<128>>( + D, A, B, input_sf, weight_sf, m, n, k, batch_count, gemmConfig, workspace, workspaceBytes, + stream, occupancy); + break; + case CutlassTileConfigSM100::CtaShape128x256x128B: + return dispatchMXFP8xMXFP8GemmClusterShapeSm100, cute::Int<256>, + cute::Int<128>>( + D, A, B, input_sf, weight_sf, m, n, k, batch_count, gemmConfig, workspace, workspaceBytes, + stream, occupancy); + break; + case CutlassTileConfigSM100::CtaShape128x128x256B: + return dispatchMXFP8xMXFP8GemmClusterShapeSm100, cute::Int<128>, + cute::Int<256>>( + D, A, B, input_sf, weight_sf, m, n, k, batch_count, gemmConfig, workspace, workspaceBytes, + stream, occupancy); + break; + case CutlassTileConfigSM100::CtaShape128x256x256B: + return dispatchMXFP8xMXFP8GemmClusterShapeSm100, cute::Int<256>, + cute::Int<256>>( + D, A, B, input_sf, weight_sf, m, n, k, batch_count, gemmConfig, workspace, workspaceBytes, + stream, occupancy); + break; + case CutlassTileConfigSM100::Undefined: + throw std::runtime_error("[Error][MXFP8][dispatch_gemm_cta_shape] Gemm config undefined."); + break; + case CutlassTileConfigSM100::ChooseWithHeuristic: + throw std::runtime_error( + "[Error][MXFP8][dispatch_gemm_cta_shape] Gemm config should have already been " + "set by " + "heuristic."); + break; + default: + throw std::runtime_error( + "[Error][MXFP8][dispatch_gemm_cta_shape] Config is invalid for MXFP8 GEMM."); + break; + } +} +template +CutlassMxfp8GemmRunner::CutlassMxfp8GemmRunner() {} + +template +CutlassMxfp8GemmRunner::~CutlassMxfp8GemmRunner() {} + +template +size_t 
CutlassMxfp8GemmRunner::dispatchToArch( + T* D, void const* A, void const* B, void const* input_sf, void const* weight_sf, int m, int n, + int k, int batch_count, CutlassGemmConfig gemmConfig, char* workspace, + const size_t workspaceBytes, cudaStream_t stream, int* occupancy) { + if constexpr (mxfp8GemmType == MXFP8GemmType::W8A8_MXFP8_MXFP8) { + return dispatchMXFP8xMXFP8GemmCTAShapeSm100(D, A, B, input_sf, weight_sf, m, n, k, + batch_count, gemmConfig, workspace, + workspaceBytes, stream, occupancy); + } else { + throw std::runtime_error( + "[Error][CutlassMxfp8GemmRunner][GEMM Dispatch] MXFP8 Gemm type unsupported for " + "CUTLASS MXFP8 GEMM"); + } +} + +template +void CutlassMxfp8GemmRunner::gemm(void* D, void const* A, void const* B, + void const* input_sf, void const* weight_sf, + int m, int n, int k, int batch_count, + CutlassGemmConfig gemmConfig, char* workspace, + const size_t workspaceBytes, + cudaStream_t stream) { + CutlassMxfp8GemmRunner::dispatchToArch( + reinterpret_cast(D), A, B, input_sf, weight_sf, m, n, k, batch_count, gemmConfig, + workspace, workspaceBytes, stream); +} + +template +std::vector CutlassMxfp8GemmRunner::getConfigs() const { + std::vector candidateConfigs; + + std::vector tilesSm100 = { + CutlassTileConfigSM100::CtaShape128x64x128B, + CutlassTileConfigSM100::CtaShape128x256x128B, + CutlassTileConfigSM100::CtaShape128x128x256B, + CutlassTileConfigSM100::CtaShape128x256x256B, + }; + + std::vector clusterShapes = { + ClusterShape::ClusterShape_1x1x1, ClusterShape::ClusterShape_1x2x1, + ClusterShape::ClusterShape_2x1x1, ClusterShape::ClusterShape_2x2x1, + ClusterShape::ClusterShape_1x4x1, ClusterShape::ClusterShape_2x4x1, + ClusterShape::ClusterShape_4x2x1, ClusterShape::ClusterShape_4x4x1, + }; + for (auto const& tile_config : tilesSm100) { + for (auto const& cluster_config : clusterShapes) { + CutlassGemmConfig config(tile_config, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, + cluster_config); + 
candidateConfigs.push_back(config); + } + } + + // There’s no heuristic yet, so for users without autotuning, we provide an ordering based on + // performance sweeps from common workloads. Keep it safe if configs are pruned. + std::vector best_tactics_index = {22, 20, 29, 4, 18}; + std::vector newCandidateConfigs; + newCandidateConfigs.reserve(candidateConfigs.size()); + for (auto const& tactic_index : best_tactics_index) { + if (tactic_index >= 0 && tactic_index < static_cast(candidateConfigs.size())) { + newCandidateConfigs.push_back(candidateConfigs[tactic_index]); + } + } + for (int64_t i = 0; i < static_cast(candidateConfigs.size()); i++) { + if (std::find(best_tactics_index.begin(), best_tactics_index.end(), i) == + best_tactics_index.end()) { + newCandidateConfigs.push_back(candidateConfigs[i]); + } + } + return newCandidateConfigs; +} + +template +size_t CutlassMxfp8GemmRunner::getWorkspaceSizeImpl(int const m, int const n, + int const k, + int const batch_count) { + size_t workspace_size = 0; + auto gemmConfigs = CutlassMxfp8GemmRunner{}.getConfigs(); + for (auto const& gemmConfig : gemmConfigs) { + try { + size_t curr_workspace_size = CutlassMxfp8GemmRunner::dispatchToArch( + nullptr, nullptr, nullptr, nullptr, nullptr, m, n, k, batch_count, gemmConfig, nullptr, 0, + nullptr, nullptr); + workspace_size = std::max(workspace_size, curr_workspace_size); + } catch (std::runtime_error& e) { + // Swallow errors when SMEM exceeds maximum allowed + continue; + } + } + return workspace_size; +} + +template +size_t CutlassMxfp8GemmRunner::getWorkspaceSize(int const m, int const n, + int const k, + int const batch_count) { + // Custom hash function for the MNKB type + using MNK = std::tuple; + + struct MNKHash { + size_t operator()(const MNK& mnk) const { + auto h1 = std::hash{}(std::get<0>(mnk)); + auto h2 = std::hash{}(std::get<1>(mnk)); + auto h3 = std::hash{}(std::get<2>(mnk)); + auto h4 = std::hash{}(std::get<3>(mnk)); + return h1 ^ h2 ^ h3 ^ h4; + } + }; + + 
static std::unordered_map workspace_hashmap; + + size_t workspace_size = 0; + if (workspace_hashmap.find(std::make_tuple(m, n, k, batch_count)) == workspace_hashmap.end()) { + workspace_size = + CutlassMxfp8GemmRunner::getWorkspaceSizeImpl(m, n, k, batch_count); + workspace_hashmap[std::make_tuple(m, n, k, batch_count)] = workspace_size; + } else { + workspace_size = workspace_hashmap[std::make_tuple(m, n, k, batch_count)]; + } + return workspace_size; +} + +} // namespace gemm +} // namespace flashinfer +#endif // FLASHINFER_MXFP8_GEMM_CUTLASS_TEMPLATE_H_ diff --git a/include/flashinfer/gemm/mxfp8_gemm_template_sm100.h b/include/flashinfer/gemm/mxfp8_gemm_template_sm100.h new file mode 100644 index 0000000000..f570f1e936 --- /dev/null +++ b/include/flashinfer/gemm/mxfp8_gemm_template_sm100.h @@ -0,0 +1,289 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef FLASHINFER_MXFP8_GEMM_TEMPLATE_SM100_H_ +#define FLASHINFER_MXFP8_GEMM_TEMPLATE_SM100_H_ + +#ifndef _WIN32 +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wstrict-aliasing" +#endif // #ifndef _WIN32 + +#include "cutlass/arch/arch.h" +#include "cutlass/cutlass.h" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/gemm.h" +#include "flashinfer/arch_condition.h" +#include "flashinfer/cutlass_utils.cuh" + +#ifndef _WIN32 +#pragma GCC diagnostic pop +#endif // #ifndef _WIN32 + +namespace flashinfer { +namespace gemm { +using namespace cute; + +#ifdef ENABLE_BF16 +using SafeBF16 = __nv_bfloat16; +#else +using SafeBF16 = void; +#endif + +struct _1SM {}; + +struct _2SM {}; + +template +struct SMTypeAdapter {}; + +template <> +struct SMTypeAdapter<_1SM> { + static int const Scale = 1; + using AtomThrShape = cute::Shape<_1, _1, _1>; + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized1Sm; + using MainloopSchedule = cutlass::gemm::KernelTmaWarpSpecialized1SmMxf8f6f4Sm100; +}; + +template <> +struct SMTypeAdapter<_2SM> { + static int const Scale = 2; + using AtomThrShape = cute::Shape<_2, _1, _1>; + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized2Sm; + using MainloopSchedule = cutlass::gemm::KernelTmaWarpSpecialized2SmMxf8f6f4Sm100; +}; + +template +constexpr auto always_false = false; + +template +size_t genericMxfp8GemmKernelLauncher(void* D, void const* A, void const* B, void const* input_sf, + void const* weight_sf, int m, int n, int k, int batch_count, + CutlassGemmConfig gemmConfig, char* workspace, + size_t const workspaceBytes, cudaStream_t stream, + int* occupancy); + +#ifdef PLACEHOLDER_KERNELS + +#define INSTANTIATE_MXFP8_GEMM_KERNEL_LAUNCHER(T, CTA_M_, CTA_N_, CTA_K_, CGA_M_, CGA_N_, CGA_K_, \ + XSM_) \ + template <> \ + size_t \ + 
genericMxfp8GemmKernelLauncher, cute::Int, cute::Int, \ + cute::Int, cute::Int, cute::Int, XSM_>( \ + void* D, void const* A, void const* B, void const* input_sf, void const* weight_sf, int m, \ + int n, int k, int batch_count, CutlassGemmConfig gemmConfig, char* workspace, \ + const size_t workspaceBytes, cudaStream_t stream, int* occupancy) { \ + throw std::runtime_error( \ + "MXFP8 gemm kernel is not compiled with support for " \ + "this Architecture."); \ + } + +#else + +#define INSTANTIATE_MXFP8_GEMM_KERNEL_LAUNCHER(T, CTA_M_, CTA_N_, CTA_K_, CGA_M_, CGA_N_, CGA_K_, \ + XSM_) \ + struct \ + DeviceGemmMxfp8GemmSm100_##T##_##CTA_M_##_##CTA_N_##_##CTA_K_##_##CGA_M_##_##CGA_N_##_##CGA_K_##XSM_ { \ + using OutElementType = flashinfer::cutlass_dtype::type; \ + using CTAShape = cute::Shape, cute::Int, cute::Int>; \ + /*using ClusterShape = cute::Shape, cute::Int, cute::Int>;*/ \ + using ClusterShape = cute::Shape; \ + using ElementType = cutlass::float_e4m3_t; \ + using Arch = cutlass::arch::Sm100; \ + /* // Input A */ \ + using ElementA = ElementType; \ + using LayoutA = cutlass::layout::RowMajor; \ + static constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; \ + /* // Input B */ \ + using ElementB = ElementType; \ + using LayoutB = cutlass::layout::ColumnMajor; \ + static constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; \ + /* // Input C */ \ + using ElementC = void; \ + using LayoutC = cutlass::layout::RowMajor; \ + static constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; \ + \ + using SFType = cutlass::float_ue8m0_t; \ + using ElementCompute = float; \ + using ElementAccumulator = float; \ + using OperatorClass = cutlass::arch::OpClassTensorOp; \ + using EpilogueTileType = std::conditional_t, \ + cutlass::epilogue::collective::EpilogueTileAuto>; \ + using EpilogueSchedule = SMTypeAdapter::EpilogueSchedule; \ + using MainloopSchedule = SMTypeAdapter::MainloopSchedule; \ + using MmaTileShape = cute::Shape::Scale>, \ + cute::Int, 
cute::Int>; \ + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< \ + Arch, OperatorClass, MmaTileShape, ClusterShape, EpilogueTileType, ElementAccumulator, \ + ElementCompute, ElementC, LayoutC, AlignmentC, OutElementType, LayoutC, AlignmentC, \ + EpilogueSchedule, \ + cutlass::epilogue::fusion::LinearCombination>::CollectiveOp; \ + \ + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< \ + Arch, cutlass::arch::OpClassBlockScaledTensorOp, cute::tuple, LayoutA, \ + AlignmentA, cute::tuple, LayoutB, AlignmentB, ElementAccumulator, \ + MmaTileShape, ClusterShape, \ + cutlass::gemm::collective::StageCountAutoCarveout( \ + sizeof(typename CollectiveEpilogue::SharedStorage))>, \ + MainloopSchedule>::CollectiveOp; \ + \ + template \ + struct Sm10x11xOnly : Base { \ + using typename Base::Params; \ + CUTLASS_DEVICE \ + void operator()(Params const& params, char* smem_buf) { \ + if constexpr (flashinfer::arch::is_major_v<10> || flashinfer::arch::is_major_v<11>) { \ + this->Base::operator()(params, smem_buf); \ + } else { \ + if (cute::thread0()) { \ + printf("%s : This kernel shall only run on SM10x and SM11x devices.\n", \ + __PRETTY_FUNCTION__); \ + __trap(); \ + } \ + } \ + } \ + }; \ + using GemmKernel = \ + Sm10x11xOnly, \ + CollectiveMainloop, CollectiveEpilogue, \ + cutlass::gemm::PersistentScheduler>>; \ + \ + using Gemm = typename cutlass::gemm::device::GemmUniversalAdapter; \ + }; \ + \ + template \ + typename Gemm::Arguments \ + prepareGemmArgs_##T##_##CTA_M_##_##CTA_N_##_##CTA_K_##_##CGA_M_##_##CGA_N_##_##CGA_K_##XSM_( \ + void* D, void const* A, void const* B, void const* input_sf, void const* weight_sf, \ + int m, int n, int k, int batch_count) { \ + using Sm1xxBlkScaledConfig = \ + typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig; \ + using ElementA = typename Gemm::ElementA; \ + using ElementB = typename Gemm::ElementB; \ + using ElementSFA = cutlass::float_ue8m0_t; \ + 
using ElementSFB = cutlass::float_ue8m0_t; \ + using ElementC = void; \ + using ElementD = typename Gemm::ElementD; \ + using ElementCompute = float; \ + \ + typename Gemm::Arguments operator_args; \ + operator_args.mode = cutlass::gemm::GemmUniversalMode::kGemm; \ + auto& fusion_args = operator_args.epilogue.thread; \ + fusion_args.alpha_ptr = nullptr; /* MXFP8 has no global scale */ \ + \ + operator_args.problem_shape = cute::make_shape(m, n, k, batch_count); \ + \ + operator_args.mainloop.ptr_A = static_cast(A); \ + operator_args.mainloop.ptr_B = static_cast(B); \ + operator_args.mainloop.ptr_SFA = static_cast(input_sf); \ + operator_args.mainloop.ptr_SFB = static_cast(weight_sf); \ + operator_args.epilogue.ptr_C = static_cast(D); \ + operator_args.epilogue.ptr_D = static_cast(D); \ + \ + int const stride_A = batch_count == 1 ? 0 : m * k; \ + int const stride_B = batch_count == 1 ? 0 : n * k; \ + int const stride_C = batch_count == 1 ? 0 : m * n; \ + \ + operator_args.mainloop.dA = \ + cute::make_int_tuple_from(k, stride_A); \ + operator_args.mainloop.dB = \ + cute::make_int_tuple_from(k, stride_B); \ + operator_args.epilogue.dC = \ + cute::make_int_tuple_from(n, stride_C); \ + operator_args.epilogue.dD = operator_args.epilogue.dC; \ + \ + operator_args.mainloop.layout_SFA = \ + Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(operator_args.problem_shape); \ + operator_args.mainloop.layout_SFB = \ + Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(operator_args.problem_shape); \ + \ + if constexpr (!std::is_const_v) { \ + operator_args.scheduler.max_swizzle_size = 1; \ + } \ + if constexpr (!std::is_const_v) { \ + using Enum_t = decltype(operator_args.scheduler.raster_order); \ + operator_args.scheduler.raster_order = Enum_t::Heuristic; \ + } \ + operator_args.hw_info.cluster_shape = dim3(CGA_M_, CGA_N_, CGA_K_); \ + operator_args.hw_info.cluster_shape_fallback = dim3(SMTypeAdapter::Scale, 1, 1); \ + \ + return operator_args; \ + } \ + \ + template <> \ + size_t \ + 
genericMxfp8GemmKernelLauncher, cute::Int, cute::Int, \ + cute::Int, cute::Int, cute::Int, XSM_>( \ + void* D, void const* A, void const* B, void const* input_sf, void const* weight_sf, int m, \ + int n, int k, int batch_count, CutlassGemmConfig gemmConfig, char* workspace, \ + const size_t workspaceBytes, cudaStream_t stream, int* occupancy) { \ + using ElementOutput__ = \ + typename cutlass::platform::conditional::value, \ + cutlass::half_t, T>::type; \ + using ElementOutput_ = typename cutlass::platform::conditional< \ + cutlass::platform::is_same::value, float, ElementOutput__>::type; \ + using ElementOutput = typename cutlass::platform::conditional< \ + cutlass::platform::is_same::value, cutlass::bfloat16_t, \ + ElementOutput_>::type; \ + \ + using Mxfp8GemmOperator = \ + DeviceGemmMxfp8GemmSm100_##T##_##CTA_M_##_##CTA_N_##_##CTA_K_##_##CGA_M_##_##CGA_N_##_##CGA_K_##XSM_:: \ + Gemm; \ + Mxfp8GemmOperator gemm; \ + auto args = \ + prepareGemmArgs_##T##_##CTA_M_##_##CTA_N_##_##CTA_K_##_##CGA_M_##_##CGA_N_##_##CGA_K_##XSM_< \ + Mxfp8GemmOperator>(D, A, B, input_sf, weight_sf, m, n, k, batch_count); \ + /* // Return workspace size */ \ + if (!A && !B && !D) { \ + return gemm.get_workspace_size(args); \ + } \ + if (gemm.get_workspace_size(args) > workspaceBytes) { \ + std::string errMsg("Requested workspace size insufficient. Required " + \ + std::to_string(gemm.get_workspace_size(args)) + ", got " + \ + std::to_string(workspaceBytes)); \ + throw std::runtime_error("[MXFP8 gemm Runner] " + errMsg); \ + } \ + auto can_implement = gemm.can_implement(args); \ + if (can_implement != cutlass::Status::kSuccess) { \ + std::string errMsg = "MXFP8 Gemm cutlass kernel will fail for params. 
Error: " + \ + std::string(cutlassGetStatusString(can_implement)); \ + throw std::runtime_error("[MXFP8 gemm Runner] " + errMsg); \ + } \ + auto initStatus = gemm.initialize(args, workspace, stream); \ + if (initStatus != cutlass::Status::kSuccess) { \ + std::string errMsg = "Failed to initialize cutlass MXFP8 gemm on sm100. Error: " + \ + std::string(cutlassGetStatusString(initStatus)); \ + throw std::runtime_error("[MXFP8 gemm Runner] " + errMsg); \ + } \ + auto runStatus = gemm.run(args, workspace, stream, nullptr, /*enablePDL=*/true); \ + if (runStatus != cutlass::Status::kSuccess) { \ + std::string errMsg = "Failed to run cutlass MXFP8 gemm on sm100. Error: " + \ + std::string(cutlassGetStatusString(runStatus)); \ + throw std::runtime_error("[MXFP8 gemm Runner] " + errMsg); \ + } \ + return gemm.get_workspace_size(args); \ + } + +#endif + +} // namespace gemm +} // namespace flashinfer +#endif // FLASHINFER_MXFP8_GEMM_TEMPLATE_SM100_H_
diff --git a/tests/gemm/test_mm_mxfp8.py b/tests/gemm/test_mm_mxfp8.py
new file mode 100644
index 0000000000..736c988121
--- /dev/null
+++ b/tests/gemm/test_mm_mxfp8.py
@@ -0,0 +1,528 @@
+import pytest
+import torch
+import torch.nn.functional as F
+
+from flashinfer import autotune, mm_mxfp8
+from flashinfer.fp8_quantization import mxfp8_quantize
+from flashinfer.utils import get_compute_capability
+
+
+def _get_min_cosine_sim(
+    is_sf_swizzled_layout: bool, scale: float | None = None
+) -> float:
+    """Return the cosine-similarity threshold for a scale layout and value scale."""
+    if is_sf_swizzled_layout:
+        return 0.98
+
+    # Lower accuracy for non-swizzled layout
+    if scale is not None:
+        if scale < 0.5 or scale > 10.0:
+            # For very small or large scales, we expect lower accuracy
+            return 0.8
+    return 0.84
+
+
+def _assert_cosine_similarity(
+    reference: torch.Tensor,
+    result: torch.Tensor,
+    is_sf_swizzled_layout: bool,
+    *,
+    use_float: bool = False,
+    context: str = "",
+) -> float:
+    """Assert that result matches reference by cosine similarity; return the value."""
+    min_cos_sim = _get_min_cosine_sim(is_sf_swizzled_layout)
+    if use_float:
+        reference = reference.float()
+        result = result.float()
+
+    # Check cosine similarity between reference and result
+    cos_sim = F.cosine_similarity(
+        reference.reshape(-1), result.reshape(-1), dim=0
+    ).item()
+
+    prefix = f"{context} " if context else ""
+    message = (
+        f"{prefix}Cosine similarity {cos_sim:.4f} is too low "
+        f"(expected > {min_cos_sim}, {is_sf_swizzled_layout=})."
+    )
+    assert cos_sim > min_cos_sim, message
+    return cos_sim
+
+
+def _skip_if_unsupported(backend: str = "cutlass"):
+    """Skip the current test if the backend is unsupported on the CUDA device."""
+    if backend == "auto":
+        backend = "cutlass"
+    compute_capability = get_compute_capability(torch.device("cuda"))
+    compute_capability_number = compute_capability[0] * 10 + compute_capability[1]
+    if not mm_mxfp8.is_backend_supported(backend, compute_capability_number):
+        pytest.skip(
+            f"Skipping test because mm_mxfp8 {backend} is not supported on compute "
+            f"capability {compute_capability_number}."
+        )
+
+
+def _run_mm_mxfp8(
+    m,
+    n,
+    k,
+    input_dtype,
+    is_sf_swizzled_layout,
+    out_dtype,
+    backend,
+    auto_tuning,
+    provide_out,
+):
+    """Quantize random inputs, run mm_mxfp8 and compare the output with torch.mm."""
+    _skip_if_unsupported(backend)
+
+    input_tensor = torch.randn([m, k], device="cuda", dtype=input_dtype)
+    mat2 = torch.randn([n, k], device="cuda", dtype=input_dtype)
+
+    input_mxfp8, mat2_mxfp8, input_descale, mat2_descale = _prepare_mxfp8_tensors(
+        input_tensor, mat2, is_sf_swizzled_layout
+    )
+    reference = torch.mm(input_tensor, mat2.T)
+
+    res = torch.empty([m, n], device="cuda", dtype=out_dtype) if provide_out else None
+
+    with autotune(auto_tuning):
+        res = mm_mxfp8(
+            input_mxfp8,
+            mat2_mxfp8.T,  # mm_mxfp8 expects mat2.T (transposed)
+            input_descale,
+            mat2_descale,
+            out=res,
+            out_dtype=out_dtype,
+            backend=backend,
+        )
+
+    assert res.shape == (m, n)
+    assert res.dtype == out_dtype
+    assert res.device.type == "cuda"
+    assert torch.isfinite(res).all(), "Output contains NaN/Inf values"
+
+    _assert_cosine_similarity(reference, res, is_sf_swizzled_layout)
+
+
+def _prepare_descales(input_scale, weight_scale, m, n, k, is_sf_swizzled_layout):
+    """Return scales as-is when swizzled, else as (m, k/32) and transposed (n, k/32)."""
+    if is_sf_swizzled_layout:
+        return input_scale, weight_scale
+    input_descale = input_scale.view(m, k // 32)
+    weight_descale = weight_scale.view(n, k // 32).t()
+    return input_descale, weight_descale
+
+
+def _prepare_mxfp8_tensors(input_bf16, weight_bf16, is_sf_swizzled_layout):
+    """Quantize input and weight to MXFP8 and return tensors plus their descales."""
+    m, k = input_bf16.shape
+    n = weight_bf16.shape[0]
+    input_mxfp8, input_scale = mxfp8_quantize(
+        input_bf16, is_sf_swizzled_layout=is_sf_swizzled_layout
+    )
+    weight_mxfp8, weight_scale = mxfp8_quantize(
+        weight_bf16, is_sf_swizzled_layout=is_sf_swizzled_layout
+    )
+    input_descale, weight_descale = _prepare_descales(
+        input_scale, weight_scale, m, n, k, is_sf_swizzled_layout
+    )
+    return input_mxfp8, weight_mxfp8, input_descale, weight_descale
+
+
+@pytest.mark.parametrize("m", [128, 256, 512, 1024])
+@pytest.mark.parametrize("n", [128, 256, 512, 1024])
+@pytest.mark.parametrize("k", [128, 256, 512, 1024, 2048, 2560, 3200])
+@pytest.mark.parametrize("is_sf_swizzled_layout", [True, False])
+@pytest.mark.parametrize("input_dtype", [torch.bfloat16])
+@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
+@pytest.mark.parametrize("backend", ["cutlass"])
+@pytest.mark.parametrize("auto_tuning", [True, False])
+def test_mm_mxfp8(
+    m, n, k, input_dtype, is_sf_swizzled_layout, out_dtype, backend, auto_tuning
+):
+    """Check mm_mxfp8 against torch.mm over a grid of shapes, layouts and dtypes."""
+    _run_mm_mxfp8(
+        m,
+        n,
+        k,
+        input_dtype,
+        is_sf_swizzled_layout,
+        out_dtype,
+        backend,
+        auto_tuning,
+        provide_out=True,
+    )
+
+
+@pytest.mark.parametrize("m", [128, 256, 1024, 2048, 4096])
+@pytest.mark.parametrize("n", [2688, 5376, 8192, 12288, 16384])
+@pytest.mark.parametrize("k", [4096, 8192])
+@pytest.mark.parametrize("is_sf_swizzled_layout", [True, False])
+@pytest.mark.parametrize("input_dtype", [torch.bfloat16])
+@pytest.mark.parametrize("out_dtype", [torch.bfloat16])
+@pytest.mark.parametrize("backend", ["cutlass", "auto"])
+def test_mm_mxfp8_large_dimensions(
+    m, n, k, input_dtype, is_sf_swizzled_layout, out_dtype, backend
+):
+    """Check mm_mxfp8 on larger N/K shapes with autotuning disabled."""
+    _run_mm_mxfp8(
+        m,
+        n,
+        k,
+        input_dtype,
+        is_sf_swizzled_layout,
+        out_dtype,
+        backend,
+        auto_tuning=False,
+        provide_out=True,
+    )
+
+
+@pytest.mark.parametrize(
+    "m,n,k",
+    [
+        (4, 6144, 4096),
+        (8, 6144, 4096),
+        (16, 6144, 4096),
+        (32, 2688, 1856),
+        (32, 1856, 2688),
+        (32, 2688, 4096),
+        (32, 5376, 4096),
+    ],
+)
+def test_mm_mxfp8_small_m(m, n, k):
+    """Check mm_mxfp8 for small M with swizzled scales on the cutlass backend."""
+    _run_mm_mxfp8(
+        m,
+        n,
+        k,
+        torch.bfloat16,
+        True,  # swizzled scales are the intended fast path
+        torch.bfloat16,
+        "cutlass",
+        auto_tuning=False,
+        provide_out=True,
+    )
+
+
+def test_mm_mxfp8_invalid_input_dtype():
+    """mm_mxfp8 must raise ValueError when given bfloat16 instead of FP8 inputs."""
+    _skip_if_unsupported()
+    m, n, k = 128, 128, 128
+    a = torch.randn([m, k], device="cuda", dtype=torch.bfloat16)
+    b = torch.randn([k, n], device="cuda", dtype=torch.bfloat16)
+    a_scale = torch.empty([m * (k // 32)], device="cuda", dtype=torch.uint8)
+    b_scale = torch.empty([n * (k // 32)], device="cuda", dtype=torch.uint8)
+    with pytest.raises(ValueError, match="float8_e4m3fn"):
+        mm_mxfp8(a, b, a_scale, b_scale, out_dtype=torch.bfloat16, backend="cutlass")
+
+
+def test_mm_mxfp8_invalid_ndim():
+    """mm_mxfp8 must raise ValueError for a 3D input and for 3D descale tensors."""
+    _skip_if_unsupported()
+    m, n, k = 128, 128, 128
+    a = torch.randn([1, m, k], device="cuda", dtype=torch.bfloat16)
+    b = torch.randn([k, n], device="cuda", dtype=torch.bfloat16)
+    a_scale = torch.empty([m * (k // 32)], device="cuda", dtype=torch.uint8)
+    b_scale = torch.empty([n * (k // 32)], device="cuda", dtype=torch.uint8)
+    with pytest.raises(ValueError, match="accepts 2d tensors"):
+        mm_mxfp8(a, b, a_scale, b_scale, out_dtype=torch.bfloat16, backend="cutlass")
+
+    a = torch.randn([m, k], device="cuda", dtype=torch.bfloat16)
+    b = torch.randn([k, n], device="cuda", dtype=torch.bfloat16)
+    a_mx, a_scale = mxfp8_quantize(a, is_sf_swizzled_layout=True)
+    b_mx, b_scale = mxfp8_quantize(b.T.contiguous(), is_sf_swizzled_layout=True)
+    a_descale = a_scale.view(1, -1, 1)
+    b_descale = b_scale.view(1, -1, 1)
+    with pytest.raises(
+        ValueError,
+        match=r"a_descale must be 1D \(swizzled\) or 2D \(non-swizzled\)",
+    ):
+        mm_mxfp8(
+            a_mx,
+            b_mx,
+            a_descale,
+            b_descale,
+            out_dtype=torch.bfloat16,
+            backend="cutlass",
+        )
+
+
+@pytest.mark.parametrize("is_sf_swizzled_layout", [True, False])
+def test_mm_mxfp8_find_minimum_cosine_similarity(is_sf_swizzled_layout):
+    """Sweep value scales and enforce a minimum cosine similarity."""
+    _skip_if_unsupported()
+
+    m, n, k = 256, 4096, 4096
+
+    value_scales = [0.001, 0.01, 0.02, 0.05, 0.1, 0.5, 1.0, 2.0, 5.0, 10.0, 50.0, 100.0]
+
+    results = []
+    for value_scale in value_scales:
+        input_data = (
+            torch.randn([m, k], device="cuda", dtype=torch.bfloat16) * value_scale
+        )
+        mat2 = torch.randn([n, k], device="cuda", dtype=torch.bfloat16) * value_scale
+
+        input_mxfp8, mat2_mxfp8, input_descale, mat2_descale = _prepare_mxfp8_tensors(
+            input_data, mat2, is_sf_swizzled_layout
+        )
+
+        reference = torch.mm(input_data, mat2.T)
+
+        result = mm_mxfp8(
+            input_mxfp8,
+            mat2_mxfp8.T,
+            input_descale,
+            mat2_descale,
+            out_dtype=torch.bfloat16,
+            backend="cutlass",
+        )
+
+        cos_sim = F.cosine_similarity(
+            reference.reshape(-1).float(), result.reshape(-1).float(), dim=0
+        ).item()
+
+        results.append((value_scale, cos_sim))
+
+    print("\n" + "=" * 60)
+    print(f"MXFP8 Cosine Similarity vs Value Scale Summary ({is_sf_swizzled_layout=})")
+    print("=" * 60)
+
+    fail_test: bool = False
+    for scale, sim in results:
+        min_cosine_sim = _get_min_cosine_sim(is_sf_swizzled_layout, scale)
+        # A NaN similarity compares False against any bound, so flag everything
+        # that is not strictly above the bound instead of testing "below".
+        fail = not (sim > min_cosine_sim)
+
+        status = "[OK]" if not fail else "[FAIL]"
+        print(f" {status} Scale={scale:8.3f}: cos_sim={sim:.4f}")
+        fail_test |= fail
+
+    print("=" * 60)
+
+    # Assert minimum acceptable similarity
+    assert not fail_test, "One or more cosine similarities are too low"
+
+
+@pytest.mark.parametrize("m", [256, 512, 1024])  # Skip M=128 (edge case issues)
+@pytest.mark.parametrize("n", [4096, 14336])
+@pytest.mark.parametrize("k", [4096])  # Focus on common hidden_size
+@pytest.mark.parametrize(
+    "input_std,weight_std",
+    [
+        (0.1, 0.02),  # Typical trained model statistics
+        (0.5, 0.1),  # Larger activations
+        (1.0, 1.0),  # Random normal (baseline)
+    ],
+)
+def test_mm_mxfp8_realistic_model_statistics(m, n, k, input_std, weight_std):
+    """Test accuracy for typical activation/weight statistics."""
+    _skip_if_unsupported()
+
+    torch.manual_seed(42)  # Reproducibility
+
+    input_data = torch.randn([m, k], device="cuda", dtype=torch.bfloat16) * input_std
+    mat2 = torch.randn([n, k], device="cuda", dtype=torch.bfloat16) * weight_std
+
+    reference = torch.mm(input_data, mat2.T)
+
+    input_mxfp8, mat2_mxfp8, input_descale, mat2_descale = _prepare_mxfp8_tensors(
+        input_data, mat2, True
+    )
+
+    result = mm_mxfp8(
+        input_mxfp8,
+        mat2_mxfp8.T,
+        input_descale,
+        mat2_descale,
+        out_dtype=torch.bfloat16,
+        backend="cutlass",
+    )
+
+    # Check for NaN/Inf
+    if not torch.isfinite(result).all():
+        pytest.fail(
+            f"Output contains NaN/Inf for M={m}, N={n}, K={k}, "
+            f"input_std={input_std}, weight_std={weight_std}"
+        )
+
+    cos_sim = F.cosine_similarity(
+        reference.reshape(-1).float(), result.reshape(-1).float(), dim=0
+    ).item()
+
+    # Should maintain high accuracy across all realistic value ranges
+    assert cos_sim > 0.95, (
+        f"Accuracy too low for M={m}, N={n}, K={k}, "
+        f"input_std={input_std}, weight_std={weight_std}: cos_sim={cos_sim:.4f}"
+    )
+
+
+def test_mm_mxfp8_llm_full_layer_simulation():
+    """Simulate a transformer layer forward pass with multiple MM calls."""
+    _skip_if_unsupported()
+
+    torch.manual_seed(42)
+    m = 256  # Batch size
+    hidden_size = 4096
+    intermediate_size = 14336
+    qkv_size = 6144
+    gate_up_size = 28672  # gate + up combined
+
+    hidden_states = (
+        torch.randn([m, hidden_size], device="cuda", dtype=torch.bfloat16) * 0.1
+    )
+
+    weights = {
+        "qkv": torch.randn([qkv_size, hidden_size], device="cuda", dtype=torch.bfloat16)
+        * 0.02,
+        "o_proj": torch.randn(
+            [hidden_size, hidden_size], device="cuda", dtype=torch.bfloat16
+        )
+        * 0.02,
+        "gate_up": torch.randn(
+            [gate_up_size, hidden_size], device="cuda", dtype=torch.bfloat16
+        )
+        * 0.02,
+        "down": torch.randn(
+            [hidden_size, intermediate_size], device="cuda", dtype=torch.bfloat16
+        )
+        * 0.02,
+    }
+
+    results = {}
+
+    for name, weight in weights.items():
+        n, k = weight.shape
+
+        if name == "down":
+            layer_input = (
+                torch.randn([m, intermediate_size], device="cuda", dtype=torch.bfloat16)
+                * 0.1
+            )
+        else:
+            layer_input = hidden_states
+
+        reference = torch.mm(layer_input, weight.T)
+
+        input_mxfp8, weight_mxfp8, input_descale, weight_descale = (
+            _prepare_mxfp8_tensors(layer_input, weight, True)
+        )
+
+        result = mm_mxfp8(
+            input_mxfp8,
+            weight_mxfp8.T,
+            input_descale,
+            weight_descale,
+            out_dtype=torch.bfloat16,
+            backend="cutlass",
+        )
+
+        cos_sim = F.cosine_similarity(
+            reference.reshape(-1).float(), result.reshape(-1).float(), dim=0
+        ).item()
+
+        results[name] = cos_sim
+        print(
+            f" {name}: input=[{m}, {layer_input.shape[1]}] @ weight=[{n}, {k}].T -> cos_sim={cos_sim:.6f}"
+        )
+
+    for name, cos_sim in results.items():
+        assert cos_sim > 0.98, f"Layer {name} has low accuracy: cos_sim={cos_sim:.4f}"
+
+    print(
+        f"\n All layers passed with average cos_sim={sum(results.values()) / len(results):.6f}"
+    )
+
+
+def test_mm_mxfp8_scale_contiguity_requirement():
+    """Test behavior with non-contiguous scale tensors."""
+    _skip_if_unsupported()
+
+    m, n, k = 256, 4096, 4096
+
+    input_bf16 = torch.randn([m, k], device="cuda", dtype=torch.bfloat16)
+    weight_bf16 = torch.randn([n, k], device="cuda", dtype=torch.bfloat16)
+
+    input_fp8, input_scale = mxfp8_quantize(input_bf16, is_sf_swizzled_layout=False)
+    weight_fp8, weight_scale = mxfp8_quantize(weight_bf16, is_sf_swizzled_layout=False)
+
+    input_descale = input_scale.view(m, k // 32)
+
+    weight_scale_2d = weight_scale.view(n, k // 32)
+    weight_descale_noncontig = weight_scale_2d.t()  # Non-contiguous!
+
+    assert not weight_descale_noncontig.is_contiguous(), (
+        "Expected non-contiguous tensor"
+    )
+
+    output = mm_mxfp8(
+        input_fp8,
+        weight_fp8.T,
+        input_descale,
+        weight_descale_noncontig,
+        out_dtype=torch.bfloat16,
+        backend="cutlass",
+    )
+    assert torch.isfinite(output).all()
+
+    weight_descale_contig = weight_descale_noncontig.contiguous()
+    assert weight_descale_contig.is_contiguous()
+
+    output = mm_mxfp8(
+        input_fp8,
+        weight_fp8.T,
+        input_descale,
+        weight_descale_contig,
+        out_dtype=torch.bfloat16,
+        backend="cutlass",
+    )
+    assert torch.isfinite(output).all(), "Output with contiguous scale should be valid"
+
+
+@pytest.mark.parametrize("m", [128, 256, 512, 1024, 2048, 4096, 8192, 16384])
+def test_mm_mxfp8_scale_1d_tensor_interpretation(m):
+    """Check that 1D swizzled scales have the expected size."""
+    _skip_if_unsupported()
+
+    n, k = 4096, 4096
+
+    input_bf16 = torch.randn([m, k], device="cuda", dtype=torch.bfloat16) * 0.1
+    weight_bf16 = torch.randn([n, k], device="cuda", dtype=torch.bfloat16) * 0.02
+
+    input_fp8, weight_fp8, input_descale, weight_descale = _prepare_mxfp8_tensors(
+        input_bf16, weight_bf16, True
+    )
+
+    input_scale = input_descale
+    # Verify scale tensor properties
+    assert input_scale.ndim == 1, (
+        f"Swizzled scale should be 1D, got {input_scale.ndim}D"
+    )
+    assert input_scale.is_contiguous(), "Swizzled scale must be contiguous"
+
+    padded_m = ((m + 127) // 128) * 128
+    k_scale_cols = k // 32
+    padded_k_scale = ((k_scale_cols + 3) // 4) * 4
+    expected_input_scale_size = padded_m * padded_k_scale
+
+    assert input_scale.numel() == expected_input_scale_size, (
+        f"Input scale size mismatch: got {input_scale.numel()}, "
+        f"expected {expected_input_scale_size} for M={m}, K={k} "
+        f"(padded_m={padded_m}, padded_k_scale={padded_k_scale})"
+    )
+
+    output = mm_mxfp8(
+        input_fp8,
+        weight_fp8.T,
+        input_descale,
+        weight_descale,
+        out_dtype=torch.bfloat16,
+        backend="cutlass",
+    )
+
+    assert output.shape == (m, n)
+    assert torch.isfinite(output).all()