Merged
@@ -11,7 +11,7 @@
"kpack": 2
},
"medium_M": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
@@ -23,7 +23,7 @@
},
"large_M": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
@@ -0,0 +1,35 @@
{
"small_M": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 4,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"medium_M": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"large_M": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
}
}
@@ -30,6 +30,6 @@
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
"kpack": 1
}
}
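
The JSON files above hold tuned launch parameters for the fused MoE GEMM, bucketed by the size of the token dimension M (small_M, medium_M, large_M). The selection logic is not part of this diff; moe-gemm.py resolves it through try_get_optimal_moe_config further down. The sketch below only illustrates the idea of the lookup, assuming a simple threshold on M; the thresholds, the file name, and the helper name are hypothetical, not taken from the repository.

import json

def pick_moe_config(path: str, M: int) -> dict:
    # Load tuned kernel parameters from a JSON file shaped like the ones above.
    with open(path) as f:
        buckets = json.load(f)
    # Hypothetical M thresholds; the real bucket choice is made by
    # try_get_optimal_moe_config in moe-gemm.py.
    if M <= 64:
        return buckets["small_M"]
    if M <= 1024:
        return buckets["medium_M"]
    return buckets["large_M"]

# Example (hypothetical file name); the chosen dict is splatted into the
# Triton launch as **config, as the kernel call in moe-gemm.py does below.
# config = pick_moe_config("moe_gemm_config.json", M=256)
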
124 changes: 95 additions & 29 deletions python/perf-kernels/fused_moe/moe-gemm.py
@@ -34,8 +34,10 @@
class MetaData():
use_fp8_w8a8 = False
use_int8_w8a16 = False
use_int8_w8a8 = False

def __init__(self, topk_weights, topk_ids, sorted_token_ids, expert_ids, num_tokens_post_padded, config):
def __init__(self, top_k, topk_weights, topk_ids, sorted_token_ids, expert_ids, num_tokens_post_padded, config):
self.top_k = top_k
self.topk_weights = topk_weights
self.topk_ids = topk_ids
self.sorted_token_ids = sorted_token_ids
@@ -54,10 +56,15 @@ def set_use_int8_w8a16(self, b_descale):
self.b_descale = b_descale
self.a_descale = None

def set_use_int8_w8a8(self, a_descale, b_descale):
self.use_int8_w8a8 = True
self.a_descale = a_descale
self.b_descale = b_descale

def check_args(self, a, b, o):
assert a.shape[-1] == b.shape[-1] and b.shape[1] == o.shape[-1]

assert not (self.use_fp8_w8a8 and self.use_int8_w8a16)
assert not (self.use_fp8_w8a8 and self.use_int8_w8a16 and self.use_int8_w8a8)
if self.use_fp8_w8a8:
assert self.fp8_type in supported_fp8, f"fp8 type {self.fp8_type} not supported"

@@ -89,6 +96,7 @@ def moe_gemm_kernel(
MUL_ROUTED_WEIGHT: tl.constexpr,
use_fp8_w8a8: tl.constexpr,
use_int8_w8a16: tl.constexpr,
use_int8_w8a8: tl.constexpr,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
@@ -146,7 +154,7 @@ def moe_gemm_kernel(
b_scale_ptrs = B_scale + off_experts * stride_bse + offs_bn[None, :] * stride_bsn
b_scale = tl.load(b_scale_ptrs)

if use_fp8_w8a8:
if use_fp8_w8a8 or use_int8_w8a8:
a_scale = tl.load(A_scale)
b_scale = tl.load(B_scale + off_experts)

@@ -163,7 +171,7 @@

if use_int8_w8a16:
accumulator = tl.dot(a, b.to(a.dtype), acc=accumulator)
elif use_fp8_w8a8:
elif use_fp8_w8a8 or use_int8_w8a8:
accumulator += tl.dot(a, b)
else:
accumulator = tl.dot(a, b, acc=accumulator)
@@ -177,7 +185,7 @@

if use_int8_w8a16:
accumulator = (accumulator * b_scale).to(Out.dtype.element_ty)
elif use_fp8_w8a8:
elif use_fp8_w8a8 or use_int8_w8a8:
accumulator = (accumulator * a_scale * b_scale).to(Out.dtype.element_ty)
else:
accumulator = accumulator.to(Out.dtype.element_ty)
@@ -278,11 +286,13 @@ def moe_align_block_size(topk_ids: torch.Tensor, block_size: int,


def get_config_dtype_str(dtype: torch.dtype, use_int8_w8a16: Optional[bool] = False,
use_fp8_w8a8: Optional[bool] = False):
use_int8_w8a8: Optional[bool] = False, use_fp8_w8a8: Optional[bool] = False):
if use_fp8_w8a8:
return "fp8_w8a8"
elif use_int8_w8a16:
return "int8_w8a16"
elif use_int8_w8a8:
return "int8_w8a8"
elif dtype == torch.float:
# avoiding cases where kernel fails when float32 MoE
# use fp16/bfloat16 configs
@@ -360,19 +370,19 @@ def moe_gemm(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, metadata: MetaDa
# TODO shard M dim
metadata.check_args(a, b, c)

topk_ids, num_tokens_post_padded, topk_weights, sorted_token_ids, expert_ids, config = metadata.topk_ids, metadata.num_tokens_post_padded, metadata.topk_weights, metadata.sorted_token_ids, metadata.expert_ids, metadata.config
num_tokens_post_padded, topk_weights, sorted_token_ids, expert_ids, config = metadata.num_tokens_post_padded, metadata.topk_weights, metadata.sorted_token_ids, metadata.expert_ids, metadata.config

use_fp8_w8a8, use_int8_w8a16 = metadata.use_fp8_w8a8, metadata.use_int8_w8a16
use_fp8_w8a8, use_int8_w8a16, use_int8_w8a8 = metadata.use_fp8_w8a8, metadata.use_int8_w8a16, metadata.use_int8_w8a8
a_descale, b_descale = None, None
stride_bse = None
stride_bsn = None
if use_fp8_w8a8 or use_int8_w8a16:
if use_fp8_w8a8 or use_int8_w8a16 or use_int8_w8a8:
a_descale, b_descale = metadata.a_descale, metadata.b_descale
if use_int8_w8a16:
stride_bse = b_descale.stride(0)
stride_bsn = b_descale.stride(1)

_, top_k = topk_ids.shape
top_k = metadata.top_k

EM = num_tokens_post_padded.item()
_, N, K = b.shape
@@ -384,7 +394,7 @@
b_descale, a.stride(0), a.stride(1), b.stride(0), b.stride(1), b.stride(2), c.stride(1),
c.stride(2), stride_bse, stride_bsn, top_k, topk_weights, sorted_token_ids, expert_ids, EM, N,
K, EVEN_K, MUL_ROUTED_WEIGHT=topk_weights is not None, use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a16=use_int8_w8a16, **config)
use_int8_w8a16=use_int8_w8a16, use_int8_w8a8=use_int8_w8a8, **config)
return c


@@ -410,8 +420,9 @@ def quantize_tensor(tensor: torch.Tensor, dtype, dim=()) -> tuple[torch.Tensor,
return tensor_quantized, scale, 1 / scale


def quantize_input(a, b, use_fp8_w8a8: tl.constexpr, use_int8_w8a16: tl.constexpr, metatdata: MetaData, fp8_type=None):
assert not (use_fp8_w8a8 and use_int8_w8a16)
def quantize_input(a, b, use_fp8_w8a8: tl.constexpr, use_int8_w8a16: tl.constexpr, use_int8_w8a8: tl.constexpr,
metatdata: MetaData, fp8_type=None):
assert not (use_fp8_w8a8 and use_int8_w8a16 and use_int8_w8a8)
assert not (use_fp8_w8a8 and fp8_type is None)

if use_fp8_w8a8:
@@ -420,14 +431,20 @@ def quantize_input(a, b, use_fp8_w8a8: tl.constexpr, use_int8_w8a16: tl.constexp
metatdata.set_use_fp8_w8a8(a_descale, b_descale, fp8_type)
return a_quantized, b_quantized

if use_int8_w8a8:
a_quantized, _, a_descale = quantize_tensor(a, dtype=torch.int8)
b_quantized, _, b_descale = quantize_tensor(b, dim=(0, ), dtype=torch.int8)
metatdata.set_use_int8_w8a8(a_descale, b_descale)
return a_quantized, b_quantized

if use_int8_w8a16:
b_quantized, _, b_descale = quantize_tensor(b, dim=(0, 1), dtype=torch.int8)
metatdata.set_use_int8_w8a16(b_descale)
return a, b_quantized


def input_helper(M: int, N: int, K: int, top_k: int, E: int, routed_weight: bool, use_fp8_w8a8: bool,
use_int8_w8a16: bool, fp8_type, dtype):
use_int8_w8a16: bool, use_int8_w8a8: bool, fp8_type, dtype):
a = torch.randn((M, K), dtype=dtype, device='cuda')
b = torch.randn((E, N, K), dtype=dtype, device='cuda')
c = torch.zeros((M, top_k, N), dtype=dtype, device='cuda')
@@ -437,7 +454,8 @@ def input_helper(M: int, N: int, K: int, top_k: int, E: int, routed_weight: bool
softmax_vals = torch.softmax(values, dim=1)
topk_weights, topk_ids = torch.topk(softmax_vals, k=top_k, dim=1)

config_dtype = get_config_dtype_str(use_fp8_w8a8=use_fp8_w8a8, use_int8_w8a16=use_int8_w8a16, dtype=dtype)
config_dtype = get_config_dtype_str(use_fp8_w8a8=use_fp8_w8a8, use_int8_w8a16=use_int8_w8a16,
use_int8_w8a8=use_int8_w8a8, dtype=dtype)
get_config_func = functools.partial(
try_get_optimal_moe_config,
E,
@@ -446,11 +464,11 @@ def input_helper(M: int, N: int, K: int, top_k: int, E: int, routed_weight: bool
config = get_config_func(M)
sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(topk_ids, config['BLOCK_SIZE_M'], E)

metadata = MetaData(topk_weights if routed_weight else None, topk_ids, sorted_token_ids, expert_ids,
metadata = MetaData(top_k, topk_weights if routed_weight else None, topk_ids, sorted_token_ids, expert_ids,
num_tokens_post_padded, config)

if use_fp8_w8a8 or use_int8_w8a16:
a, b = quantize_input(a, b, use_fp8_w8a8, use_int8_w8a16, metadata, fp8_type)
if use_fp8_w8a8 or use_int8_w8a16 or use_int8_w8a8:
a, b = quantize_input(a, b, use_fp8_w8a8, use_int8_w8a16, use_int8_w8a8, metadata, fp8_type)

return a, b, c, metadata

@@ -471,7 +489,7 @@ def input_helper(M: int, N: int, K: int, top_k: int, E: int, routed_weight: bool
def test_correctness(M: int, N: int, K: int, top_k: int, E: int, routed_weight: bool, dtype=torch.float16):
torch.manual_seed(20)
a, b, c, metadata = input_helper(M, N, K, top_k, E, routed_weight=routed_weight, use_fp8_w8a8=False,
use_int8_w8a16=False, fp8_type=None, dtype=dtype)
use_int8_w8a16=False, use_int8_w8a8=False, fp8_type=None, dtype=dtype)

tri_out = moe_gemm(a, b, c, metadata)

@@ -508,7 +526,7 @@ def test_correctness_fp8(M: int, N: int, K: int, top_k: int, E: int, routed_weig
dtype=torch.float16):
torch.manual_seed(20)
a, b, c, metadata = input_helper(M, N, K, top_k, E, routed_weight=routed_weight, use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a16=False, fp8_type=fp8_type, dtype=dtype)
use_int8_w8a16=False, fp8_type=fp8_type, use_int8_w8a8=False, dtype=dtype)

tri_out = moe_gemm(a, b, c, metadata)

@@ -545,11 +563,11 @@ def test_correctness_fp8(M: int, N: int, K: int, top_k: int, E: int, routed_weig
])
@pytest.mark.parametrize('routed_weight', [True, False])
@pytest.mark.parametrize('use_int8_w8a16', [True])
def test_correctness_int8(M: int, N: int, K: int, top_k: int, E: int, routed_weight: bool, use_int8_w8a16,
dtype=torch.float16):
def test_correctness_int8_w8a16(M: int, N: int, K: int, top_k: int, E: int, routed_weight: bool, use_int8_w8a16,
dtype=torch.float16):
torch.manual_seed(20)
a, b, c, metadata = input_helper(M, N, K, top_k, E, routed_weight=routed_weight, use_fp8_w8a8=False,
use_int8_w8a16=use_int8_w8a16, fp8_type=None, dtype=dtype)
use_int8_w8a16=use_int8_w8a16, use_int8_w8a8=False, fp8_type=None, dtype=dtype)

tri_out = moe_gemm(a, b, c, metadata)

@@ -560,7 +578,7 @@ def test_correctness_int8(M: int, N: int, K: int, top_k: int, E: int, routed_wei
a_expanded = a.unsqueeze(1).repeat(1, top_k, 1)
# (M, top_k, N, K)
b_indexed = b[topk_ids]
ref_out = torch.einsum("mek,menk->men", a_expanded.to(torch.float32), b_indexed.to(torch.float32))
ref_out = torch.einsum("mek,menk->men", a_expanded.float(), b_indexed.float())
if routed_weight:
ref_out *= topk_weights.unsqueeze(-1)

@@ -571,6 +589,46 @@ def test_correctness_int8_w8a16(M: int, N: int, K: int, top_k: int, E: int, routed_wei
torch.testing.assert_close(tri_out, ref_out, atol=1e-2, rtol=1e-2)


@pytest.mark.parametrize("M, N, K, top_k, E", [
(64, 14336, 4096, 2, 8),
(16, 14336, 1, 2, 4),
(1, 14336, 128, 2, 4),
(16, 14336, 128, 1, 4),
(16, 14336, 128, 1, 1),
(64, 7186, 128, 2, 8),
(64, 3584, 128, 2, 8),
(64, 1792, 128, 2, 8),
(64, 64, 128, 2, 8),
])
@pytest.mark.parametrize('routed_weight', [True, False])
@pytest.mark.parametrize('use_int8_w8a8', [True])
def test_correctness_int8_w8a8(M: int, N: int, K: int, top_k: int, E: int, routed_weight: bool, use_int8_w8a8,
dtype=torch.float16):
torch.manual_seed(20)
a, b, c, metadata = input_helper(M, N, K, top_k, E, routed_weight=routed_weight, use_fp8_w8a8=False,
use_int8_w8a16=False, use_int8_w8a8=use_int8_w8a8, fp8_type=None, dtype=dtype)

tri_out = moe_gemm(a, b, c, metadata)

topk_ids = metadata.topk_ids
topk_weights = metadata.topk_weights
ref_out = torch.empty_like(c)
# Repeat a -> (M, top_k, K)
a_expanded = a.unsqueeze(1).repeat(1, top_k, 1)
# (M, top_k, N, K)
b_indexed = b[topk_ids]
ref_out = torch.einsum("mek,menk->men", a_expanded.float(), b_indexed.float())
if routed_weight:
ref_out *= topk_weights.unsqueeze(-1)

ref_out = ref_out * metadata.b_descale[topk_ids].unsqueeze(-1)
ref_out = ref_out * metadata.a_descale
ref_out = ref_out.to(dtype)

# Validate correctness
torch.testing.assert_close(tri_out, ref_out, atol=1e-2, rtol=1e-2)


def get_configs():
configs = [
{"M": 64, "N": 256, "K": 128, "E": 8, "top_k": 2},
@@ -606,8 +664,10 @@ def model_benchmark_configs(args):

E = 8
top_k = 2
# The first moe layer
moe_configs.append((model_name, M, N1, K1, E, top_k))
moe_configs.append((model_name, M, N2, K2, E, top_k))
# The second moe layer
moe_configs.append((model_name, M * top_k, N2, K2, E, 1))

return moe_configs

@@ -616,6 +676,7 @@ def run_benchmark(custom, args):
routed_weight = args.routed_weight
use_int8_w8a16 = args.int8_w8a16
use_fp8_w8a8 = args.fp8_w8a8
use_int8_w8a8 = args.int8_w8a8
dtype = arg_to_torch_dtype[args.dtype]
fp8_type = arg_to_torch_dtype[args.fp8_type]

@@ -640,14 +701,15 @@ def run_benchmark(custom, args):
styles=[('red', '-'), ('blue', '-'),
('yellow', '-')], ylabel='ms / TFLOPS / GB/s', plot_name='moe-gemm-benchmark', args={
'dtype': dtype, 'routed_weight': routed_weight, 'use_fp8_w8a8': use_fp8_w8a8, 'use_int8_w8a16':
use_int8_w8a16, 'fp8_type': fp8_type
use_int8_w8a16, 'use_int8_w8a8': use_int8_w8a8, 'fp8_type': fp8_type
})

@triton.testing.perf_report([benchmark])
def bench_moe_gemm(M, N, K, E, top_k, dtype, routed_weight, metric, use_fp8_w8a8, use_int8_w8a16, fp8_type,
model=None):
def bench_moe_gemm(M, N, K, E, top_k, dtype, routed_weight, metric, use_fp8_w8a8, use_int8_w8a16, use_int8_w8a8,
fp8_type, model=None):
a, b, c, metadata = input_helper(M, N, K, top_k, E, routed_weight=routed_weight, use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a16=use_int8_w8a16, fp8_type=fp8_type, dtype=dtype)
use_int8_w8a16=use_int8_w8a16, use_int8_w8a8=use_int8_w8a8, fp8_type=fp8_type,
dtype=dtype)

# (M, K) * (top_k, N, K) -> (M, top_k, N). 2 for multiplication and accumulation
flops = 2.0 * M * top_k * K * N
@@ -658,6 +720,9 @@ def bench_moe_gemm(M, N, K, E, top_k, dtype, routed_weight, metric, use_fp8_w8a8
if use_fp8_w8a8:
a_bytes = b_bytes = torch.tensor([], dtype=fp8_type).element_size()
c_bytes = torch.tensor([], dtype=dtype).element_size()
if use_int8_w8a8:
a_bytes = b_bytes = torch.tensor([], dtype=torch.int8).element_size()
c_bytes = torch.tensor([], dtype=torch.int8).element_size()
elif use_int8_w8a16:
b_bytes = torch.tensor([], dtype=torch.int8).element_size()
a_bytes = c_bytes = torch.tensor([], dtype=dtype).element_size()
@@ -705,6 +770,7 @@ def parse_args():
parser.add_argument("-top_k", type=int, default=0, help="top_k experts per token")
parser.add_argument("-routed_weight", action='store_true', default=False)
parser.add_argument("-int8_w8a16", action='store_true', default=False)
parser.add_argument("-int8_w8a8", action='store_true', default=False)
parser.add_argument("-fp8_w8a8", action='store_true', default=False)
parser.add_argument("-dtype", default='fp16')
parser.add_argument("-fp8_type", default='e5m2fnuz')