diff --git a/tests/ops/test_count_and_gather.py b/tests/ops/test_count_and_gather.py new file mode 100644 index 00000000..c97d9d24 --- /dev/null +++ b/tests/ops/test_count_and_gather.py @@ -0,0 +1,130 @@ +import time +import pytest +import torch + +from top.ops import CountAndGatherOp + + +def _count_and_gather_reference( + x: torch.Tensor, + topk_ids: torch.Tensor, + num_expert: int, + rank_ep: int, + tile_m: int, +): + num_seq, hidden_size = x.shape + num_topk = topk_ids.shape[1] + total_num_topk = num_seq * num_topk + + start_expert = rank_ep * num_expert + end_expert = (rank_ep + 1) * num_expert + + seqlens = torch.zeros(num_expert, dtype=torch.int32, device=x.device) + topk_pos = torch.full((total_num_topk,), -1, dtype=torch.int32, device=x.device) + cu_seqlens = torch.zeros(num_expert + 1, dtype=torch.int32, device=x.device) + tiles = torch.zeros(num_expert, dtype=torch.int32, device=x.device) + + topk_ids_flat = topk_ids.reshape(-1) + for idx in range(total_num_topk): + iexpert = topk_ids_flat[idx] + if iexpert >= start_expert and iexpert < end_expert: + seqlens[iexpert - start_expert] += 1 + + for i in range(num_expert): + cu_seqlens[i + 1] = cu_seqlens[i] + seqlens[i] + tiles[i] = (seqlens[i] + tile_m - 1) // tile_m + + total_tokens = cu_seqlens[-1].item() + gate_up_input = torch.zeros(total_tokens, hidden_size, dtype=x.dtype, device=x.device) + + running = torch.zeros_like(seqlens) + for idx in range(total_num_topk): + iexpert = topk_ids_flat[idx] + if iexpert >= start_expert and iexpert < end_expert: + expert_idx = iexpert - start_expert + abs_pos = cu_seqlens[expert_idx] + running[expert_idx] + topk_pos[idx] = abs_pos + gate_up_input[abs_pos] = x[idx // num_topk] + running[expert_idx] += 1 + + return gate_up_input, topk_pos, seqlens, cu_seqlens, tiles + + +class TestCountAndGatherOp: + """Test CountAndGatherOp with pytest.""" + + @pytest.fixture + def test_data(self): + """Create test data for CountAndGatherOp.""" + num_seq = 16 + hidden_size = 64 + 
num_topk = 2 + num_expert = 4 + rank_ep = 0 + + x = torch.randn(num_seq, hidden_size, device="cuda") + topk_ids = torch.randint( + 0, num_expert, (num_seq, num_topk), device="cuda", dtype=torch.int32) + + return x, topk_ids, num_expert, rank_ep, num_seq, hidden_size, num_topk + + def test_basic_functionality(self, test_data): + """Test basic functionality of CountAndGatherOp.""" + x, topk_ids, num_expert, rank_ep, num_seq, hidden_size, num_topk = test_data + + op = CountAndGatherOp(num_expert=num_expert, rank_ep=rank_ep) + gate_up_input, topk_pos, seqlens, cu_seqlens, tiles = op.forward(x, topk_ids) + ref_gate_up_input, ref_topk_pos, ref_seqlens, ref_cu_seqlens, ref_tiles = _count_and_gather_reference( + x=x, + topk_ids=topk_ids, + num_expert=num_expert, + rank_ep=rank_ep, + tile_m=op.config["tile_m"], + ) + + assert gate_up_input.shape[0] == seqlens.sum().item() + assert gate_up_input.shape[1] == hidden_size + assert topk_pos.shape == (num_seq * num_topk,) + assert seqlens.shape == (num_expert,) + assert cu_seqlens.shape == (num_expert + 1,) + assert tiles.shape == (num_expert,) + + for i in range(1, num_expert + 1): + assert cu_seqlens[i].item() == cu_seqlens[i - 1].item() + seqlens[i - 1].item() + + for pos in topk_pos: + assert pos.item() >= -1 + if pos.item() != -1: + assert pos.item() < gate_up_input.shape[0] + + assert torch.equal(seqlens, ref_seqlens) + assert torch.equal(cu_seqlens, ref_cu_seqlens) + assert torch.equal(tiles, ref_tiles) + assert torch.equal(topk_pos, ref_topk_pos) + assert torch.equal(gate_up_input, ref_gate_up_input) + + def test_performance(self, test_data): + """Test performance of CountAndGatherOp.""" + x, topk_ids, num_expert, rank_ep, num_seq, _, _ = test_data + + op = CountAndGatherOp(num_expert=num_expert, rank_ep=rank_ep) + num_iterations = 10 + start_time = time.time() + + for _ in range(num_iterations): + op.forward(x, topk_ids) + torch.cuda.synchronize() + + end_time = time.time() + avg_time = (end_time - start_time) / 
num_iterations + throughput = num_seq / avg_time + + print("Performance:") + print(f"Average time per iteration: {avg_time * 1000:.2f} ms") + print(f"Throughput: {throughput:.2f} sequences/sec") + + assert avg_time < 0.1 + + +if __name__ == "__main__": + pytest.main([__file__, "-vvs"]) diff --git a/tests/ops/test_moe_reduce.py b/tests/ops/test_moe_reduce.py new file mode 100644 index 00000000..a714a253 --- /dev/null +++ b/tests/ops/test_moe_reduce.py @@ -0,0 +1,98 @@ +import pytest +import torch + +from top.ops import MoeReduceOp + + +def _moe_reduce_reference( + x: torch.Tensor, + topk_pos: torch.Tensor, + topk_scale: torch.Tensor, + shared_output: torch.Tensor | None = None, +) -> torch.Tensor: + num_seq, num_topk = topk_pos.shape + num_tokens, hidden_size = x.shape + + x_fp32 = x.to(torch.float32) + topk_scale_fp32 = topk_scale.to(torch.float32) + y_accum = torch.zeros((num_seq, hidden_size), dtype=torch.float32, device=x.device) + + for i in range(num_topk): + pos = topk_pos[:, i] + scale = topk_scale_fp32[:, i].unsqueeze(1) + mask = (pos >= 0) & (pos < num_tokens) + valid_pos = pos[mask] + valid_scale = scale[mask] + valid_seq = torch.where(mask)[0] + + if valid_pos.numel() > 0: + expert_outputs = x_fp32[valid_pos] + weighted_outputs = expert_outputs * valid_scale + y_accum.index_add_(0, valid_seq, weighted_outputs) + + if shared_output is not None: + y_accum += shared_output.to(torch.float32) + + return y_accum.to(x.dtype) + + +@pytest.mark.parametrize( + "num_seq, num_topk, hidden_size", + [ + (16, 2, 64), + (32, 1, 128), + (64, 4, 256), + ], +) +def test_moe_reduce_op_basic(num_seq: int, num_topk: int, hidden_size: int) -> None: + num_tokens = num_seq * num_topk + x = torch.randn(num_tokens, hidden_size, device="cuda", dtype=torch.float16) + topk_pos = torch.randint(0, num_tokens, (num_seq, num_topk), device="cuda", dtype=torch.int32) + topk_scale = torch.randn(num_seq, num_topk, device="cuda", dtype=torch.float32) + + op = MoeReduceOp() + output = 
op.forward(x, topk_pos, topk_scale) + reference = _moe_reduce_reference(x, topk_pos, topk_scale) + + assert output.shape == (num_seq, hidden_size) + assert output.dtype == x.dtype + assert torch.allclose(output, reference, atol=1e-3, rtol=1e-3) + + +def test_moe_reduce_op_with_shared_output() -> None: + num_seq = 16 + num_topk = 2 + hidden_size = 64 + num_tokens = num_seq * num_topk + + x = torch.randn(num_tokens, hidden_size, device="cuda", dtype=torch.float16) + topk_pos = torch.randint(0, num_tokens, (num_seq, num_topk), device="cuda", dtype=torch.int32) + topk_scale = torch.randn(num_seq, num_topk, device="cuda", dtype=torch.float32) + shared_output = torch.randn(num_seq, hidden_size, device="cuda", dtype=torch.float16) + + op = MoeReduceOp() + output = op.forward(x, topk_pos, topk_scale, shared_output) + reference = _moe_reduce_reference(x, topk_pos, topk_scale, shared_output) + + assert output.shape == (num_seq, hidden_size) + assert output.dtype == x.dtype + assert torch.allclose(output, reference, atol=1e-3, rtol=1e-3) + + +def test_moe_reduce_op_invalid_shape() -> None: + num_seq = 16 + num_topk = 2 + hidden_size = 64 + num_tokens = num_seq * num_topk + + x = torch.randn(num_tokens, hidden_size, device="cuda", dtype=torch.float16) + topk_pos = torch.randint(0, num_tokens, (num_seq, num_topk), device="cuda", dtype=torch.int32) + invalid_topk_scale = torch.randn(num_seq, num_topk + 1, device="cuda", dtype=torch.float32) + + op = MoeReduceOp() + with pytest.raises(ValueError): + op.forward(x, topk_pos, invalid_topk_scale) + + +if __name__ == "__main__": + pytest.main([__file__, "-vvs"]) diff --git a/top/functions/fuse_moe.py b/top/functions/fuse_moe.py new file mode 100644 index 00000000..994ea966 --- /dev/null +++ b/top/functions/fuse_moe.py @@ -0,0 +1,135 @@ +from typing import Optional, Tuple +import torch + +from top.kernels.fuse_moe.count_and_gather import CountAndGatherKernel +from top.kernels.fuse_moe.moe_reduce import MoeReduceKernel + + +def 
count_and_gather( + x: torch.Tensor, # [num_seq, hidden_size] + topk_ids: torch.Tensor, # [num_seq, num_topk] + num_expert: int, + rank_ep: int = 0, + tile_m: int = 16, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Count and gather tokens by expert. + + Args: + x: Input token features + topk_ids: Expert assignment for each token + num_expert: Number of experts + rank_ep: Expert parallel rank + tile_m: Tile size for M dimension + + Returns: + gate_up_input: Gathered input for gate and up projection + topk_pos: Position mapping for each token + seqlens: Number of tokens per expert + cu_seqlens: Cumulative sequence lengths + tiles: Number of tiles per expert + """ + # Input validation + if x.ndim != 2: + raise ValueError(f"Expected x to be 2D tensor, got {x.ndim}D") + + if topk_ids.ndim != 2: + raise ValueError(f"Expected topk_ids to be 2D tensor, got {topk_ids.ndim}D") + + if x.shape[0] != topk_ids.shape[0]: + raise ValueError( + f"Mismatched batch size: x has {x.shape[0]}, topk_ids has {topk_ids.shape[0]}") + + num_seq, hidden_size = x.shape + num_topk = topk_ids.shape[1] + kernel = CountAndGatherKernel( + num_seq=num_seq, + hidden_size=hidden_size, + num_topk=num_topk, + num_expert=num_expert, + config={"tile_m": tile_m}, + ) + return kernel.count_and_gather(x, topk_ids, rank_ep) + + +def reduce( + x: torch.Tensor, # [total_tokens, hidden_size] + topk_pos: torch.Tensor, # [num_seq, num_topk] + topk_scale: torch.Tensor, # [num_seq, num_topk] + shared_output: Optional[torch.Tensor] = None, # [num_seq, hidden_size] +) -> torch.Tensor: + """Reduce (scatter-add) tokens back to original positions. 
+ + Args: + x: Expert output tokens + topk_pos: Position mapping from count_and_gather + topk_scale: Scaling factors for each token + shared_output: Optional shared output tensor + + Returns: + output: Reduced output tensor [num_seq, hidden_size] + """ + # Input validation + if x.ndim != 2: + raise ValueError(f"Expected x to be 2D tensor, got {x.ndim}D") + + if topk_pos.ndim != 2: + raise ValueError(f"Expected topk_pos to be 2D tensor, got {topk_pos.ndim}D") + + if topk_scale.ndim != 2: + raise ValueError(f"Expected topk_scale to be 2D tensor, got {topk_scale.ndim}D") + + if topk_pos.shape != topk_scale.shape: + raise ValueError( + f"Mismatched shape between topk_pos and topk_scale: {topk_pos.shape} vs {topk_scale.shape}" + ) + + if shared_output is not None: + if shared_output.ndim != 2: + raise ValueError(f"Expected shared_output to be 2D tensor, got {shared_output.ndim}D") + + if shared_output.shape[0] != topk_pos.shape[0]: + raise ValueError( + f"Mismatched batch size: shared_output has {shared_output.shape[0]}, topk_pos has {topk_pos.shape[0]}" + ) + + if shared_output.shape[1] != x.shape[1]: + raise ValueError( + f"Mismatched hidden size: shared_output has {shared_output.shape[1]}, x has {x.shape[1]}" + ) + + num_seq, num_topk = topk_pos.shape + _, hidden_size = x.shape + kernel = MoeReduceKernel( + num_seq=num_seq, + hidden_size=hidden_size, + num_topk=num_topk, + ) + return kernel.forward(x, topk_pos, topk_scale, shared_output) + + +def fuse_moe_pertensor_fp8( + x: torch.Tensor, + gate_up_weight: torch.Tensor, + down_weight: torch.Tensor, + topk_ids: torch.Tensor, + topk_scale: torch.Tensor, + num_expert: int, + rank_ep: int = 0, +) -> torch.Tensor: + """Fused MoE with per-tensor FP8 quantization. 
+ + Args: + x: Input token features + gate_up_weight: Gate and up projection weights + down_weight: Down projection weights + topk_ids: Expert assignment for each token + topk_scale: Scaling factors for each token + num_expert: Number of experts + rank_ep: Expert parallel rank + + Returns: + output: Fused MoE output + """ + # TODO: Implement fuse_moe_pertensor_fp8 function + # This will be implemented in a separate step + raise NotImplementedError("fuse_moe_pertensor_fp8 function not implemented yet") diff --git a/top/kernels/fuse_moe/__init__.py b/top/kernels/fuse_moe/__init__.py new file mode 100644 index 00000000..6e05b630 --- /dev/null +++ b/top/kernels/fuse_moe/__init__.py @@ -0,0 +1,7 @@ +from .count_and_gather import CountAndGatherKernel +from .moe_reduce import MoeReduceKernel + +__all__ = [ + "CountAndGatherKernel", + "MoeReduceKernel", +] diff --git a/top/kernels/fuse_moe/count_and_gather.py b/top/kernels/fuse_moe/count_and_gather.py new file mode 100644 index 00000000..d17dbfe8 --- /dev/null +++ b/top/kernels/fuse_moe/count_and_gather.py @@ -0,0 +1,199 @@ +from typing import Optional +import tilelang +import tilelang.language as T +import torch + +from top.kernels.kernel import Kernel + + +def _count_seq_and_cuseq_kernel(total_num_topk: int, num_expert: int, start_expert: int, + end_expert: int, tile_m: int): + + @tilelang.jit(out_idx=[1, 2, 3, 4]) + def _count_seq_and_cuseq_fwd_func(block_size: int = 256): + + @T.prim_func + def _count_seq_and_cuseq_main( + topk_ids_flat: T.Tensor[(total_num_topk,), T.int32], + seqlens: T.Tensor[(num_expert,), T.int32], + cu_seqlens: T.Tensor[(num_expert + 1,), T.int32], + tiles: T.Tensor[(num_expert,), T.int32], + topk_pos: T.Tensor[(total_num_topk,), T.int32], + ): + with T.Kernel(1, threads=block_size) as _: + tx = T.get_thread_binding() + seqlens_shm = T.alloc_shared((num_expert,), T.int32) + + for i in T.serial(T.ceildiv(num_expert, block_size)): + iexpert = i * block_size + tx + if iexpert < num_expert: + 
seqlens_shm[iexpert] = 0 + T.sync_threads() + for i in T.serial(T.ceildiv(total_num_topk, block_size)): + idx = i * block_size + tx + if idx < total_num_topk: + iexpert = topk_ids_flat[idx] + if iexpert >= start_expert and iexpert < end_expert: + T.atomic_add(seqlens_shm[iexpert - start_expert], 1) + topk_pos[idx] = -1 + + T.sync_threads() + + if tx == 0: + cu_seqlens[0] = 0 + for i in T.serial(num_expert): + iseq = seqlens_shm[i] + seqlens[i] = iseq + tiles[i] = (iseq + tile_m - 1) // tile_m + cu_seqlens[i + 1] = cu_seqlens[i] + iseq + + return _count_seq_and_cuseq_main + + return _count_seq_and_cuseq_fwd_func + + +def _gather_kernel(num_seq: int, hidden_size: int, num_topk: int, total_num_topk: int, + num_expert: int, start_expert: int, end_expert: int): + + @tilelang.jit(out_idx=[3, 4, 5]) + def _gather_fwd_func(block_size: int = 256): + + @T.prim_func + def _gather_main( + x: T.Tensor[(num_seq, hidden_size), T.float32], + topk_ids_flat: T.Tensor[(total_num_topk,), T.int32], + cu_seqlens: T.Tensor[(num_expert + 1,), T.int32], + seqlens_runtime: T.Tensor[(num_expert,), T.int32], + topk_pos: T.Tensor[(total_num_topk,), T.int32], + gate_up_input_full: T.Tensor[(total_num_topk, hidden_size), T.float32], + ): + with T.Kernel(1, threads=block_size) as _: + tx = T.get_thread_binding() + + for i in T.serial(T.ceildiv(num_expert, block_size)): + iexpert = i * block_size + tx + if iexpert < num_expert: + seqlens_runtime[iexpert] = 0 + T.sync_threads() + for i in T.serial(T.ceildiv(total_num_topk, block_size)): + idx = i * block_size + tx + if idx < total_num_topk: + topk_pos[idx] = -1 + iexpert = topk_ids_flat[idx] + if iexpert >= start_expert and iexpert < end_expert: + expert_idx = iexpert - start_expert + pos_in_expert = T.atomic_add( + seqlens_runtime[expert_idx], 1, return_prev=True) + irow = cu_seqlens[expert_idx] + pos_in_expert + topk_pos[idx] = irow + iseq = idx // num_topk + for ih in T.serial(hidden_size): + gate_up_input_full[irow, ih] = x[iseq, ih] + + return _gather_main + + return 
_gather_fwd_func + + +class CountAndGatherKernel(Kernel): + """Count and gather kernel for MoE.""" + + supported_archs: list[int] = [80, 89, 90] + + def __init__(self, + num_seq: int, + hidden_size: int, + num_topk: int, + num_expert: int, + config: Optional[dict] = None, + tune: bool = False) -> None: + super().__init__() + self.num_seq = num_seq + self.hidden_size = hidden_size + self.num_topk = num_topk + self.num_expert = num_expert + + self.init_config(config, tune) + + @property + def default_config(self) -> dict: + return {"block_size": 256, "tile_m": 16} + + @property + def autotune_configs(self) -> list[dict]: + block_sizes = [256, 512] + tile_ms = [16, 32] + configs = [] + for block_size in block_sizes: + for tile_m in tile_ms: + configs.append({"block_size": block_size, "tile_m": tile_m}) + return configs + + def forward(self, x, topk_ids, rank_ep=0): + """Run the kernel + + Args: + x: Input token features [num_seq, hidden_size] + topk_ids: Expert assignment for each token [num_seq, num_topk] + rank_ep: Expert parallel rank + + Returns: + gate_up_input: Gathered input for gate and up projection + topk_pos: Position mapping for each token + seqlens: Number of tokens per expert + cu_seqlens: Cumulative sequence lengths + tiles: Number of tiles per expert + """ + return self.count_and_gather(x, topk_ids, rank_ep) + + def count_and_gather(self, x, topk_ids, rank_ep=0): + """Count and gather tokens by expert. 
+ + Args: + x: Input token features [num_seq, hidden_size] + topk_ids: Expert assignment for each token [num_seq, num_topk] + rank_ep: Expert parallel rank + + Returns: + gate_up_input: Gathered input for gate and up projection + topk_pos: Position mapping for each token + seqlens: Number of tokens per expert + cu_seqlens: Cumulative sequence lengths + tiles: Number of tiles per expert + """ + num_seq, hidden_size = x.shape + num_topk = topk_ids.shape[1] + total_num_topk = num_seq * num_topk + num_expert = self.num_expert + + start_expert = rank_ep * num_expert + end_expert = (rank_ep + 1) * num_expert + + block_size = self.config["block_size"] + + topk_ids_flat = topk_ids.reshape(-1).contiguous() + + seqlens, cu_seqlens, tiles, _ = _count_seq_and_cuseq_kernel( + total_num_topk=total_num_topk, + num_expert=num_expert, + start_expert=start_expert, + end_expert=end_expert, + tile_m=self.config["tile_m"], + )(block_size)( + topk_ids_flat) + + x_fp32 = x.to(torch.float32) + seqlens_runtime, topk_pos, gate_up_input_full = _gather_kernel( + num_seq=num_seq, + hidden_size=hidden_size, + num_topk=num_topk, + total_num_topk=total_num_topk, + num_expert=num_expert, + start_expert=start_expert, + end_expert=end_expert, + )(block_size)(x_fp32, topk_ids_flat, cu_seqlens) + + total_tokens = cu_seqlens[-1].item() + gate_up_input = gate_up_input_full[:total_tokens].to(x.dtype) + + return gate_up_input, topk_pos, seqlens_runtime, cu_seqlens, tiles diff --git a/top/kernels/fuse_moe/moe_reduce.py b/top/kernels/fuse_moe/moe_reduce.py new file mode 100644 index 00000000..ca785842 --- /dev/null +++ b/top/kernels/fuse_moe/moe_reduce.py @@ -0,0 +1,115 @@ +import tilelang +import tilelang.language as T +import torch + +from top.kernels.kernel import Kernel + + +def _moe_reduce_kernel(total_num_seq: int, num_seq: int, hidden_size: int, num_topk: int): + + @tilelang.jit(out_idx=[4]) + def _reduce_fwd_func(block_size: int = 256): + + @T.prim_func + def _reduce_main( + x_fp32: 
T.Tensor[(total_num_seq, hidden_size), T.float32], + topk_pos: T.Tensor[(num_seq, num_topk), T.int32], + topk_scale_fp32: T.Tensor[(num_seq, num_topk), T.float32], + shared_output_fp32: T.Tensor[(num_seq, hidden_size), T.float32], + output_fp32: T.Tensor[(num_seq, hidden_size), T.float32], + ): + with T.Kernel(num_seq, threads=block_size) as by: + tx = T.get_thread_binding() + for ih_blk in T.serial(T.ceildiv(hidden_size, block_size)): + ih = ih_blk * block_size + tx + if ih < hidden_size: + acc = T.alloc_var(T.float32) + acc = 0.0 + for itopk in T.serial(num_topk): + ipos = topk_pos[by, itopk] + if ipos >= 0 and ipos < total_num_seq: + acc += x_fp32[ipos, ih] * topk_scale_fp32[by, itopk] + acc += shared_output_fp32[by, ih] + output_fp32[by, ih] = acc + + return _reduce_main + + return _reduce_fwd_func + + +class MoeReduceKernel(Kernel): + """Reduce kernel for MoE. + + Performs scatter-add aggregation operation, aggregating expert outputs back to original positions. + """ + + supported_archs: list[int] = [80, 89, 90] + + def __init__(self, + num_seq: int, + hidden_size: int, + num_topk: int, + config: dict | None = None, + tune: bool = False) -> None: + super().__init__() + self.num_seq = num_seq + self.hidden_size = hidden_size + self.num_topk = num_topk + + self.init_config(config, tune) + + @property + def default_config(self) -> dict: + return { + "block_size": 256, + "items_per_16b": 8 # For FP16 + } + + @property + def autotune_configs(self) -> list[dict]: + block_sizes = [256, 512] + configs = [] + for block_size in block_sizes: + configs.append({"block_size": block_size, "items_per_16b": 8}) + return configs + + def forward(self, x, topk_pos, topk_scale, shared_output=None): + """Forward pass of ReduceKernel. 
+ + Args: + x: Input tensor (expert outputs) of shape [num_tokens, hidden_size] + topk_pos: Position mapping tensor of shape [num_seq, num_topk] + topk_scale: Weight tensor of shape [num_seq, num_topk] + shared_output: Optional shared output tensor of shape [num_seq, hidden_size] + + Returns: + output: Aggregated output tensor of shape [num_seq, hidden_size] + """ + num_seq, num_topk = topk_pos.shape + num_tokens, hidden_size = x.shape + block_size = self.config["block_size"] + + if num_topk == 0: + if shared_output is None: + return torch.zeros((num_seq, hidden_size), dtype=x.dtype, device=x.device) + return shared_output.to(dtype=x.dtype) + + x_fp32 = x.to(torch.float32).contiguous() + topk_scale_fp32 = topk_scale.to(torch.float32).contiguous() + topk_pos_i32 = topk_pos.to(torch.int32).contiguous() + + if shared_output is None: + shared_output_fp32 = torch.zeros((num_seq, hidden_size), + dtype=torch.float32, + device=x.device) + else: + shared_output_fp32 = shared_output.to(torch.float32).contiguous() + + output_fp32 = _moe_reduce_kernel( + total_num_seq=num_tokens, + num_seq=num_seq, + hidden_size=hidden_size, + num_topk=num_topk, + )(block_size)(x_fp32, topk_pos_i32, topk_scale_fp32, shared_output_fp32) + + return output_fp32.to(x.dtype) diff --git a/top/ops/__init__.py b/top/ops/__init__.py index 6260df22..31703e4e 100644 --- a/top/ops/__init__.py +++ b/top/ops/__init__.py @@ -14,6 +14,9 @@ from .mha_decode_paged import MultiHeadAttentionDecodePagedWithKVCacheOp from .mhc_pre import ManifoldConstrainedHyperConnectionPreOp from .mhc_post import ManifoldConstrainedHyperConnectionPostOp +from .count_and_gather import CountAndGatherOp +from .moe_reduce import MoeReduceOp +from .fuse_moe_pertensor_fp8 import FuseMoePertensorFp8Op from .op import Op # noqa: F401 __all__ = [ @@ -43,4 +46,7 @@ "GQAWindowSlidingOp", "ManifoldConstrainedHyperConnectionPreOp", "ManifoldConstrainedHyperConnectionPostOp", + "CountAndGatherOp", + "MoeReduceOp", + "FuseMoePertensorFp8Op", ] 
diff --git a/top/ops/count_and_gather.py b/top/ops/count_and_gather.py new file mode 100644 index 00000000..9163de45 --- /dev/null +++ b/top/ops/count_and_gather.py @@ -0,0 +1,60 @@ +from typing import Dict, Optional, Tuple +import torch + +from top.ops.op import Op +from top.kernels.fuse_moe.count_and_gather import CountAndGatherKernel +from top.kernels.kernel import Kernel + +__all__ = ["CountAndGatherOp"] + + +class CountAndGatherOp(Op): + """Count and gather Op for MoE.""" + + def __init__(self, + num_expert: int, + rank_ep: int = 0, + tile_m: int = 16, + kernel_map: Optional[Dict[str, Kernel]] = None, + config: Optional[dict] = None, + tune: bool = False) -> None: + self.num_expert = num_expert + self.rank_ep = rank_ep + self.config = {"tile_m": tile_m} + if config is not None: + self.config.update(config) + + self.dispatch_kernel(kernel_map) + self.kernel = self.kernel_map["CountAndGatherKernel"]( + num_seq=1, + hidden_size=1, + num_topk=1, + num_expert=self.num_expert, + config=self.config, + tune=tune, + ) + + @property + def default_kernel_map(self) -> Dict[str, Kernel]: + return {"CountAndGatherKernel": CountAndGatherKernel} + + def forward( + self, x: torch.Tensor, topk_ids: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Forward pass of CountAndGatherOp. 
+ + Args: + x: Input token features [num_seq, hidden_size] + topk_ids: Expert assignment for each token [num_seq, num_topk] + + Returns: + gate_up_input: Gathered input for gate and up projection + topk_pos: Position mapping for each token + seqlens: Number of tokens per expert + cu_seqlens: Cumulative sequence lengths + tiles: Number of tiles per expert + """ + self.kernel.num_seq = x.shape[0] + self.kernel.hidden_size = x.shape[1] + self.kernel.num_topk = topk_ids.shape[1] + return self.kernel(x, topk_ids, self.rank_ep) diff --git a/top/ops/fuse_moe_pertensor_fp8.py b/top/ops/fuse_moe_pertensor_fp8.py new file mode 100644 index 00000000..01ed3b4c --- /dev/null +++ b/top/ops/fuse_moe_pertensor_fp8.py @@ -0,0 +1,58 @@ +from typing import Optional, Dict +import torch + +from top.ops.op import Op +from top.kernels.kernel import Kernel + + +class FuseMoePertensorFp8Op(Op): + """Fused MoE with per-tensor FP8 quantization Op.""" + + def __init__(self, + num_expert: int, + rank_ep: int = 0, + config: Optional[dict] = None, + tune: bool = False) -> None: + super().__init__() + self.num_expert = num_expert + self.rank_ep = rank_ep + + self.config = {} + self.init_config(config, tune) + + def init_config(self, config: Optional[dict] = None, tune: bool = False) -> None: + """Initialize configuration.""" + if config is not None: + self.config.update(config) + else: + self.config = self.default_config + + print(f"{self.__class__.__name__} initialized with config: {self.config}") + + @property + def default_kernel_map(self) -> Dict[str, Kernel]: + return {} + + @property + def default_config(self) -> dict: + return {} + + @property + def autotune_configs(self) -> list[dict]: + return [] + + def forward(self, x: torch.Tensor, gate_up_weight: torch.Tensor, down_weight: torch.Tensor, + topk_ids: torch.Tensor, topk_scale: torch.Tensor) -> torch.Tensor: + """Forward pass of FuseMoePertensorFp8Op. 
+ + Args: + x: Input token features + gate_up_weight: Gate and up projection weights + down_weight: Down projection weights + topk_ids: Expert assignment for each token + topk_scale: Scaling factors for each token + + Returns: + output: Fused MoE output + """ + raise NotImplementedError("FuseMoePertensorFp8Op not implemented yet") diff --git a/top/ops/moe_reduce.py b/top/ops/moe_reduce.py new file mode 100644 index 00000000..d163fef2 --- /dev/null +++ b/top/ops/moe_reduce.py @@ -0,0 +1,81 @@ +from typing import Dict, Optional +import torch + +from top.ops.op import Op +from top.kernels.kernel import Kernel +from top.kernels.fuse_moe.moe_reduce import MoeReduceKernel + +__all__ = ["MoeReduceOp"] + + +class MoeReduceOp(Op): + """MoE Reduce Op.""" + + def __init__(self, + kernel_map: Optional[Dict[str, Kernel]] = None, + config: Optional[dict] = None, + tune: bool = False) -> None: + self.config = {} if config is None else dict(config) + self.dispatch_kernel(kernel_map) + self.kernel = self.kernel_map["MoeReduceKernel"]( + num_seq=1, + hidden_size=1, + num_topk=1, + config=self.config if self.config else None, + tune=tune, + ) + + @property + def default_kernel_map(self) -> Dict[str, Kernel]: + return {"MoeReduceKernel": MoeReduceKernel} + + def forward(self, + x: torch.Tensor, + topk_pos: torch.Tensor, + topk_scale: torch.Tensor, + shared_output: Optional[torch.Tensor] = None) -> torch.Tensor: + """Forward pass of MoeReduceOp. 
+ + Args: + x: Expert output tokens [total_tokens, hidden_size] + topk_pos: Position mapping from count_and_gather [num_seq, num_topk] + topk_scale: Scaling factors for each token [num_seq, num_topk] + shared_output: Optional shared output tensor [num_seq, hidden_size] + + Returns: + output: Reduced output tensor [num_seq, hidden_size] + """ + if x.ndim != 2: + raise ValueError(f"Expected x to be 2D tensor, got {x.ndim}D") + + if topk_pos.ndim != 2: + raise ValueError(f"Expected topk_pos to be 2D tensor, got {topk_pos.ndim}D") + + if topk_scale.ndim != 2: + raise ValueError(f"Expected topk_scale to be 2D tensor, got {topk_scale.ndim}D") + + if topk_pos.shape != topk_scale.shape: + raise ValueError( + f"Mismatched shape between topk_pos and topk_scale: {topk_pos.shape} vs {topk_scale.shape}" + ) + + if shared_output is not None: + if shared_output.ndim != 2: + raise ValueError( + f"Expected shared_output to be 2D tensor, got {shared_output.ndim}D") + + if shared_output.shape[0] != topk_pos.shape[0]: + raise ValueError( + f"Mismatched batch size: shared_output has {shared_output.shape[0]}, topk_pos has {topk_pos.shape[0]}" + ) + + if shared_output.shape[1] != x.shape[1]: + raise ValueError( + f"Mismatched hidden size: shared_output has {shared_output.shape[1]}, x has {x.shape[1]}" + ) + + num_seq, num_topk = topk_pos.shape + self.kernel.num_seq = num_seq + self.kernel.hidden_size = x.shape[1] + self.kernel.num_topk = num_topk + return self.kernel(x, topk_pos, topk_scale, shared_output)