diff --git a/python/perf-kernels/fused_moe/configs/device_name=AMD_Instinct_MI300X,dtype=int8_w8a16.json b/python/perf-kernels/fused_moe/configs/device_name=AMD_Instinct_MI300X,dtype=int8_w8a16.json
index 816202fc88bc..ab634ae7d3ab 100644
--- a/python/perf-kernels/fused_moe/configs/device_name=AMD_Instinct_MI300X,dtype=int8_w8a16.json
+++ b/python/perf-kernels/fused_moe/configs/device_name=AMD_Instinct_MI300X,dtype=int8_w8a16.json
@@ -11,7 +11,7 @@
         "kpack": 2
     },
     "medium_M": {
-        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_M": 128,
         "BLOCK_SIZE_N": 128,
         "BLOCK_SIZE_K": 64,
         "GROUP_SIZE_M": 1,
@@ -23,7 +23,7 @@
     },
     "large_M": {
         "BLOCK_SIZE_M": 128,
-        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_N": 128,
         "BLOCK_SIZE_K": 128,
         "GROUP_SIZE_M": 1,
         "num_warps": 8,
diff --git a/python/perf-kernels/fused_moe/configs/device_name=AMD_Instinct_MI300X,dtype=int8_w8a8.json b/python/perf-kernels/fused_moe/configs/device_name=AMD_Instinct_MI300X,dtype=int8_w8a8.json
new file mode 100644
index 000000000000..643964ca95b6
--- /dev/null
+++ b/python/perf-kernels/fused_moe/configs/device_name=AMD_Instinct_MI300X,dtype=int8_w8a8.json
@@ -0,0 +1,35 @@
+{
+    "small_M": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 4,
+        "num_warps": 8,
+        "num_stages": 2,
+        "waves_per_eu": 0,
+        "matrix_instr_nonkdim": 16,
+        "kpack": 2
+    },
+    "medium_M": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 8,
+        "num_stages": 2,
+        "waves_per_eu": 0,
+        "matrix_instr_nonkdim": 16,
+        "kpack": 2
+    },
+    "large_M": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 8,
+        "num_stages": 2,
+        "waves_per_eu": 0,
+        "matrix_instr_nonkdim": 16,
+        "kpack": 2
+    }
+}
diff --git a/python/perf-kernels/fused_moe/configs/device_name=AMD_Instinct_MI300X.json b/python/perf-kernels/fused_moe/configs/device_name=AMD_Instinct_MI300X.json
index 5955cc32c481..2870d7e6d180 100644
--- a/python/perf-kernels/fused_moe/configs/device_name=AMD_Instinct_MI300X.json
+++ b/python/perf-kernels/fused_moe/configs/device_name=AMD_Instinct_MI300X.json
@@ -30,6 +30,6 @@
         "num_stages": 2,
         "waves_per_eu": 0,
         "matrix_instr_nonkdim": 16,
-        "kpack": 2
+        "kpack": 1
     }
 }
diff --git a/python/perf-kernels/fused_moe/moe-gemm.py b/python/perf-kernels/fused_moe/moe-gemm.py
index 1ce6aa8d12a4..1e95e3201539 100644
--- a/python/perf-kernels/fused_moe/moe-gemm.py
+++ b/python/perf-kernels/fused_moe/moe-gemm.py
@@ -34,8 +34,10 @@ class MetaData():
 
     use_fp8_w8a8 = False
     use_int8_w8a16 = False
+    use_int8_w8a8 = False
 
-    def __init__(self, topk_weights, topk_ids, sorted_token_ids, expert_ids, num_tokens_post_padded, config):
+    def __init__(self, top_k, topk_weights, topk_ids, sorted_token_ids, expert_ids, num_tokens_post_padded, config):
+        self.top_k = top_k
         self.topk_weights = topk_weights
         self.topk_ids = topk_ids
         self.sorted_token_ids = sorted_token_ids
@@ -54,10 +56,15 @@ def set_use_int8_w8a16(self, b_descale):
         self.b_descale = b_descale
         self.a_descale = None
 
+    def set_use_int8_w8a8(self, a_descale, b_descale):
+        self.use_int8_w8a8 = True
+        self.a_descale = a_descale
+        self.b_descale = b_descale
+
     def check_args(self, a, b, o):
         assert a.shape[-1] == b.shape[-1] and b.shape[1] == o.shape[-1]
-        assert not (self.use_fp8_w8a8 and self.use_int8_w8a16)
+        assert sum((self.use_fp8_w8a8, self.use_int8_w8a16, self.use_int8_w8a8)) <= 1  # at most one scheme
 
         if self.use_fp8_w8a8:
             assert self.fp8_type in supported_fp8, f"fp8 type {self.fp8_type} not supported"
 
@@ -89,6 +96,7 @@ def moe_gemm_kernel(
    MUL_ROUTED_WEIGHT: tl.constexpr,
    use_fp8_w8a8: tl.constexpr,
    use_int8_w8a16: tl.constexpr,
+    use_int8_w8a8: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
@@ -146,7 +154,7 @@ def moe_gemm_kernel(
         b_scale_ptrs = B_scale + off_experts * stride_bse + offs_bn[None, :] * stride_bsn
         b_scale = tl.load(b_scale_ptrs)
 
-    if use_fp8_w8a8:
+    if use_fp8_w8a8 or use_int8_w8a8:
         a_scale = tl.load(A_scale)
         b_scale = tl.load(B_scale + off_experts)
 
@@ -163,7 +171,7 @@ def moe_gemm_kernel(
 
         if use_int8_w8a16:
             accumulator = tl.dot(a, b.to(a.dtype), acc=accumulator)
-        elif use_fp8_w8a8:
+        elif use_fp8_w8a8 or use_int8_w8a8:
             accumulator += tl.dot(a, b)
         else:
             accumulator = tl.dot(a, b, acc=accumulator)
@@ -177,7 +185,7 @@ def moe_gemm_kernel(
 
     if use_int8_w8a16:
         accumulator = (accumulator * b_scale).to(Out.dtype.element_ty)
-    elif use_fp8_w8a8:
+    elif use_fp8_w8a8 or use_int8_w8a8:
         accumulator = (accumulator * a_scale * b_scale).to(Out.dtype.element_ty)
     else:
         accumulator = accumulator.to(Out.dtype.element_ty)
@@ -278,11 +286,13 @@ def moe_align_block_size(topk_ids: torch.Tensor, block_size: int,
 
 
 def get_config_dtype_str(dtype: torch.dtype, use_int8_w8a16: Optional[bool] = False,
-                         use_fp8_w8a8: Optional[bool] = False):
+                         use_int8_w8a8: Optional[bool] = False, use_fp8_w8a8: Optional[bool] = False):
     if use_fp8_w8a8:
         return "fp8_w8a8"
     elif use_int8_w8a16:
         return "int8_w8a16"
+    elif use_int8_w8a8:
+        return "int8_w8a8"
     elif dtype == torch.float:
         # avoiding cases where kernel fails when float32 MoE
         # use fp16/bfloat16 configs
@@ -360,19 +370,19 @@ def moe_gemm(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, metadata: MetaDa
 
     # TODO shard M dim
     metadata.check_args(a, b, c)
-    topk_ids, num_tokens_post_padded, topk_weights, sorted_token_ids, expert_ids, config = metadata.topk_ids, metadata.num_tokens_post_padded, metadata.topk_weights, metadata.sorted_token_ids, metadata.expert_ids, metadata.config
+    num_tokens_post_padded, topk_weights, sorted_token_ids, expert_ids, config = metadata.num_tokens_post_padded, metadata.topk_weights, metadata.sorted_token_ids, metadata.expert_ids, metadata.config
 
-    use_fp8_w8a8, use_int8_w8a16 = metadata.use_fp8_w8a8, metadata.use_int8_w8a16
+    use_fp8_w8a8, use_int8_w8a16, use_int8_w8a8 = metadata.use_fp8_w8a8, metadata.use_int8_w8a16, metadata.use_int8_w8a8
     a_descale, b_descale = None, None
     stride_bse = None
     stride_bsn = None
-    if use_fp8_w8a8 or use_int8_w8a16:
+    if use_fp8_w8a8 or use_int8_w8a16 or use_int8_w8a8:
         a_descale, b_descale = metadata.a_descale, metadata.b_descale
         if use_int8_w8a16:
             stride_bse = b_descale.stride(0)
             stride_bsn = b_descale.stride(1)
 
-    _, top_k = topk_ids.shape
+    top_k = metadata.top_k
     EM = num_tokens_post_padded.item()
     _, N, K = b.shape
 
@@ -384,7 +394,7 @@ def moe_gemm(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, metadata: MetaDa
         b_descale, a.stride(0), a.stride(1), b.stride(0), b.stride(1), b.stride(2), c.stride(1), c.stride(2),
         stride_bse, stride_bsn, top_k, topk_weights, sorted_token_ids, expert_ids, EM, N, K, EVEN_K,
         MUL_ROUTED_WEIGHT=topk_weights is not None, use_fp8_w8a8=use_fp8_w8a8,
-        use_int8_w8a16=use_int8_w8a16, **config)
+        use_int8_w8a16=use_int8_w8a16, use_int8_w8a8=use_int8_w8a8, **config)
 
     return c
 
@@ -410,8 +420,9 @@ def quantize_tensor(tensor: torch.Tensor, dtype, dim=()) -> tuple[torch.Tensor,
     return tensor_quantized, scale, 1 / scale
 
 
-def quantize_input(a, b, use_fp8_w8a8: tl.constexpr, use_int8_w8a16: tl.constexpr, metatdata: MetaData, fp8_type=None):
-    assert not (use_fp8_w8a8 and use_int8_w8a16)
+def quantize_input(a, b, use_fp8_w8a8: tl.constexpr, use_int8_w8a16: tl.constexpr, use_int8_w8a8: tl.constexpr,
+                   metatdata: MetaData, fp8_type=None):
+    assert sum((use_fp8_w8a8, use_int8_w8a16, use_int8_w8a8)) <= 1  # at most one scheme
     assert not (use_fp8_w8a8 and fp8_type is None)
 
     if use_fp8_w8a8:
@@ -420,6 +431,12 @@ def quantize_input(a, b, use_fp8_w8a8: tl.constexpr, use_int8_w8a16: tl.constexp
         metatdata.set_use_fp8_w8a8(a_descale, b_descale, fp8_type)
         return a_quantized, b_quantized
 
+    if use_int8_w8a8:
+        a_quantized, _, a_descale = quantize_tensor(a, dtype=torch.int8)
+        b_quantized, _, b_descale = quantize_tensor(b, dim=(0, ), dtype=torch.int8)
+        metatdata.set_use_int8_w8a8(a_descale, b_descale)
+        return a_quantized, b_quantized
+
     if use_int8_w8a16:
         b_quantized, _, b_descale = quantize_tensor(b, dim=(0, 1), dtype=torch.int8)
         metatdata.set_use_int8_w8a16(b_descale)
@@ -427,7 +444,7 @@ def quantize_input(a, b, use_fp8_w8a8: tl.constexpr, use_int8_w8a16: tl.constexp
 
 
 def input_helper(M: int, N: int, K: int, top_k: int, E: int, routed_weight: bool, use_fp8_w8a8: bool,
-                 use_int8_w8a16: bool, fp8_type, dtype):
+                 use_int8_w8a16: bool, use_int8_w8a8: bool, fp8_type, dtype):
     a = torch.randn((M, K), dtype=dtype, device='cuda')
     b = torch.randn((E, N, K), dtype=dtype, device='cuda')
     c = torch.zeros((M, top_k, N), dtype=dtype, device='cuda')
@@ -437,7 +454,8 @@ def input_helper(M: int, N: int, K: int, top_k: int, E: int, routed_weight: bool
     softmax_vals = torch.softmax(values, dim=1)
     topk_weights, topk_ids = torch.topk(softmax_vals, k=top_k, dim=1)
 
-    config_dtype = get_config_dtype_str(use_fp8_w8a8=use_fp8_w8a8, use_int8_w8a16=use_int8_w8a16, dtype=dtype)
+    config_dtype = get_config_dtype_str(use_fp8_w8a8=use_fp8_w8a8, use_int8_w8a16=use_int8_w8a16,
+                                        use_int8_w8a8=use_int8_w8a8, dtype=dtype)
     get_config_func = functools.partial(
         try_get_optimal_moe_config,
         E,
@@ -446,11 +464,11 @@ def input_helper(M: int, N: int, K: int, top_k: int, E: int, routed_weight: bool
     config = get_config_func(M)
     sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(topk_ids, config['BLOCK_SIZE_M'], E)
 
-    metadata = MetaData(topk_weights if routed_weight else None, topk_ids, sorted_token_ids, expert_ids,
+    metadata = MetaData(top_k, topk_weights if routed_weight else None, topk_ids, sorted_token_ids, expert_ids,
                         num_tokens_post_padded, config)
 
-    if use_fp8_w8a8 or use_int8_w8a16:
-        a, b = quantize_input(a, b, use_fp8_w8a8, use_int8_w8a16, metadata, fp8_type)
+    if use_fp8_w8a8 or use_int8_w8a16 or use_int8_w8a8:
+        a, b = quantize_input(a, b, use_fp8_w8a8, use_int8_w8a16, use_int8_w8a8, metadata, fp8_type)
 
     return a, b, c, metadata
 
@@ -471,7 +489,7 @@ def input_helper(M: int, N: int, K: int, top_k: int, E: int, routed_weight: bool
 def test_correctness(M: int, N: int, K: int, top_k: int, E: int, routed_weight: bool, dtype=torch.float16):
     torch.manual_seed(20)
     a, b, c, metadata = input_helper(M, N, K, top_k, E, routed_weight=routed_weight, use_fp8_w8a8=False,
-                                     use_int8_w8a16=False, fp8_type=None, dtype=dtype)
+                                     use_int8_w8a16=False, use_int8_w8a8=False, fp8_type=None, dtype=dtype)
 
     tri_out = moe_gemm(a, b, c, metadata)
 
@@ -508,7 +526,7 @@ def test_correctness_fp8(M: int, N: int, K: int, top_k: int, E: int, routed_weig
                          dtype=torch.float16):
     torch.manual_seed(20)
     a, b, c, metadata = input_helper(M, N, K, top_k, E, routed_weight=routed_weight, use_fp8_w8a8=use_fp8_w8a8,
-                                     use_int8_w8a16=False, fp8_type=fp8_type, dtype=dtype)
+                                     use_int8_w8a16=False, use_int8_w8a8=False, fp8_type=fp8_type, dtype=dtype)
 
     tri_out = moe_gemm(a, b, c, metadata)
@@ -545,11 +563,11 @@ def test_correctness_fp8(M: int, N: int, K: int, top_k: int, E: int, routed_weig
 ])
 @pytest.mark.parametrize('routed_weight', [True, False])
 @pytest.mark.parametrize('use_int8_w8a16', [True])
-def test_correctness_int8(M: int, N: int, K: int, top_k: int, E: int, routed_weight: bool, use_int8_w8a16,
-                          dtype=torch.float16):
+def test_correctness_int8_w8a16(M: int, N: int, K: int, top_k: int, E: int, routed_weight: bool, use_int8_w8a16,
+                                dtype=torch.float16):
     torch.manual_seed(20)
     a, b, c, metadata = input_helper(M, N, K, top_k, E, routed_weight=routed_weight, use_fp8_w8a8=False,
-                                     use_int8_w8a16=use_int8_w8a16, fp8_type=None, dtype=dtype)
+                                     use_int8_w8a16=use_int8_w8a16, use_int8_w8a8=False, fp8_type=None, dtype=dtype)
 
     tri_out = moe_gemm(a, b, c, metadata)
@@ -560,7 +578,7 @@ def test_correctness_int8(M: int, N: int, K: int, top_k: int, E: int, routed_wei
     a_expanded = a.unsqueeze(1).repeat(1, top_k, 1)
     # (M, top_k, N, K)
     b_indexed = b[topk_ids]
-    ref_out = torch.einsum("mek,menk->men", a_expanded.to(torch.float32), b_indexed.to(torch.float32))
+    ref_out = torch.einsum("mek,menk->men", a_expanded.float(), b_indexed.float())
     if routed_weight:
         ref_out *= topk_weights.unsqueeze(-1)
@@ -571,6 +589,46 @@
     torch.testing.assert_close(tri_out, ref_out, atol=1e-2, rtol=1e-2)
 
 
+@pytest.mark.parametrize("M, N, K, top_k, E", [
+    (64, 14336, 4096, 2, 8),
+    (16, 14336, 1, 2, 4),
+    (1, 14336, 128, 2, 4),
+    (16, 14336, 128, 1, 4),
+    (16, 14336, 128, 1, 1),
+    (64, 7186, 128, 2, 8),
+    (64, 3584, 128, 2, 8),
+    (64, 1792, 128, 2, 8),
+    (64, 64, 128, 2, 8),
+])
+@pytest.mark.parametrize('routed_weight', [True, False])
+@pytest.mark.parametrize('use_int8_w8a8', [True])
+def test_correctness_int8_w8a8(M: int, N: int, K: int, top_k: int, E: int, routed_weight: bool, use_int8_w8a8,
+                               dtype=torch.float16):
+    torch.manual_seed(20)
+    a, b, c, metadata = input_helper(M, N, K, top_k, E, routed_weight=routed_weight, use_fp8_w8a8=False,
+                                     use_int8_w8a16=False, use_int8_w8a8=use_int8_w8a8, fp8_type=None, dtype=dtype)
+
+    tri_out = moe_gemm(a, b, c, metadata)
+
+    topk_ids = metadata.topk_ids
+    topk_weights = metadata.topk_weights
+    # Repeat a -> (M, top_k, K)
+    a_expanded = a.unsqueeze(1).repeat(1, top_k, 1)
+    # (M, top_k, N, K)
+    b_indexed = b[topk_ids]
+    ref_out = torch.einsum("mek,menk->men", a_expanded.float(), b_indexed.float())
+    if routed_weight:
+        ref_out *= topk_weights.unsqueeze(-1)
+
+    # Dequantize: per-expert weight descale, then per-tensor activation descale.
+    ref_out = ref_out * metadata.b_descale[topk_ids].unsqueeze(-1)
+    ref_out = ref_out * metadata.a_descale
+    ref_out = ref_out.to(dtype)
+
+    # Validate correctness
+    torch.testing.assert_close(tri_out, ref_out, atol=1e-2, rtol=1e-2)
+
+
 def get_configs():
     configs = [
         {"M": 64, "N": 256, "K": 128, "E": 8, "top_k": 2},
@@ -606,8 +664,10 @@ def model_benchmark_configs(args):
         E = 8
         top_k = 2
 
+        # First MoE layer: every token is routed to top_k experts.
         moe_configs.append((model_name, M, N1, K1, E, top_k))
-        moe_configs.append((model_name, M, N2, K2, E, top_k))
+        # Second MoE layer: tokens are already expanded by top_k, each goes to a single expert.
+        moe_configs.append((model_name, M * top_k, N2, K2, E, 1))
 
     return moe_configs
@@ -616,6 +676,7 @@ def run_benchmark(custom, args):
     routed_weight = args.routed_weight
     use_int8_w8a16 = args.int8_w8a16
     use_fp8_w8a8 = args.fp8_w8a8
+    use_int8_w8a8 = args.int8_w8a8
     dtype = arg_to_torch_dtype[args.dtype]
     fp8_type = arg_to_torch_dtype[args.fp8_type]
@@ -640,14 +701,15 @@
             styles=[('red', '-'), ('blue', '-'), ('yellow', '-')],
             ylabel='ms / TFLOPS / GB/s',
             plot_name='moe-gemm-benchmark',
             args={
                 'dtype': dtype, 'routed_weight': routed_weight, 'use_fp8_w8a8': use_fp8_w8a8, 'use_int8_w8a16':
-                    use_int8_w8a16, 'fp8_type': fp8_type
+                    use_int8_w8a16, 'use_int8_w8a8': use_int8_w8a8, 'fp8_type': fp8_type
             })
 
     @triton.testing.perf_report([benchmark])
-    def bench_moe_gemm(M, N, K, E, top_k, dtype, routed_weight, metric, use_fp8_w8a8, use_int8_w8a16, fp8_type,
-                       model=None):
+    def bench_moe_gemm(M, N, K, E, top_k, dtype, routed_weight, metric, use_fp8_w8a8, use_int8_w8a16, use_int8_w8a8,
+                       fp8_type, model=None):
         a, b, c, metadata = input_helper(M, N, K, top_k, E, routed_weight=routed_weight, use_fp8_w8a8=use_fp8_w8a8,
-                                         use_int8_w8a16=use_int8_w8a16, fp8_type=fp8_type, dtype=dtype)
+                                         use_int8_w8a16=use_int8_w8a16, use_int8_w8a8=use_int8_w8a8, fp8_type=fp8_type,
+                                         dtype=dtype)
 
         # (M, K) * (top_k, N, K) -> (M, top_k, N). 2 for multiplication and accumulation
         flops = 2.0 * M * top_k * K * N
@@ -658,6 +720,9 @@ def bench_moe_gemm(M, N, K, E, top_k, dtype, routed_weight, metric, use_fp8_w8a8
         if use_fp8_w8a8:
             a_bytes = b_bytes = torch.tensor([], dtype=fp8_type).element_size()
             c_bytes = torch.tensor([], dtype=dtype).element_size()
+        elif use_int8_w8a8:
+            a_bytes = b_bytes = torch.tensor([], dtype=torch.int8).element_size()
+            c_bytes = torch.tensor([], dtype=dtype).element_size()  # output stays in dtype
         elif use_int8_w8a16:
             b_bytes = torch.tensor([], dtype=torch.int8).element_size()
             a_bytes = c_bytes = torch.tensor([], dtype=dtype).element_size()
@@ -705,6 +770,7 @@ def parse_args():
     parser.add_argument("-top_k", type=int, default=0, help="top_k experts per token")
     parser.add_argument("-routed_weight", action='store_true', default=False)
    parser.add_argument("-int8_w8a16", action='store_true', default=False)
+    parser.add_argument("-int8_w8a8", action='store_true', default=False)
     parser.add_argument("-fp8_w8a8", action='store_true', default=False)
     parser.add_argument("-dtype", default='fp16')
     parser.add_argument("-fp8_type", default='e5m2fnuz')
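Reviewer note on the new int8 w8a8 path: a is quantized per tensor, b per expert (quantize_tensor(b, dim=(0, ), ...)), and the kernel rescales the int32 accumulator by a_scale * b_scale before the store. The standalone sketch below illustrates that math for a single expert. It is a minimal sketch, not the kernel: it assumes quantize_tensor performs symmetric absmax quantization (consistent with the (quantized, scale, 1 / scale) tuple it returns), and every name in it is illustrative only.

    import torch


    def absmax_quantize_int8(t: torch.Tensor, keep_dims=()):
        # Reduce |t| over every dim not listed in keep_dims: keep_dims=(0, ) on a
        # (E, N, K) weight yields one scale per expert, keep_dims=() one per tensor.
        reduce_dims = tuple(d for d in range(t.dim()) if d not in keep_dims)
        amax = t.abs().amax(dim=reduce_dims, keepdim=True).clamp(min=1e-8)
        scale = 127.0 / amax
        q = (t * scale).round().clamp(-128, 127).to(torch.int8)
        return q, 1.0 / scale  # (quantized values, descale)


    E, N, K, M = 4, 32, 64, 8
    a = torch.randn(M, K)
    b = torch.randn(E, N, K)

    a_q, a_descale = absmax_quantize_int8(a)                   # per-tensor descale
    b_q, b_descale = absmax_quantize_int8(b, keep_dims=(0, ))  # (E, 1, 1) descales

    # The kernel accumulates int8 products in int32 and rescales in the epilogue
    # (accumulator * a_scale * b_scale); fp32 emulates that exactly at this size.
    e = 0  # pick one expert
    acc = a_q.float() @ b_q[e].float().T  # (M, N), integer-valued
    out = acc * a_descale * b_descale[e].squeeze()
    ref = a @ b[e].T
    print(f"max abs error: {(out - ref).abs().max():.3f} vs ref scale ~{ref.abs().mean():.2f}")

The dequantized result matches the fp32 reference up to int8 rounding error, which is the same tolerance regime the new test checks with atol=1e-2, rtol=1e-2 on the kernel output.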
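To exercise the new path, something along these lines should work (illustrative commands; the -int8_w8a8 flag and test name come from this diff, while the benchmark's model/shape selection flags are defined elsewhere in the file):

    pytest python/perf-kernels/fused_moe/moe-gemm.py -k test_correctness_int8_w8a8
    python python/perf-kernels/fused_moe/moe-gemm.py -int8_w8a8 -routed_weight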