From a2d7b40cd33ef01a6ec029aaba92895d94c13798 Mon Sep 17 00:00:00 2001 From: ZhengQiHang Date: Tue, 24 Feb 2026 07:22:29 +0000 Subject: [PATCH 1/2] [Feat] Add CountAndGather and MoeReduce operations with corresponding tests and kernels --- tests/ops/test_count_and_gather.py | 129 ++++++++++++++ tests/ops/test_reduce.py | 98 +++++++++++ top/kernels/fuse_moe/__init__.py | 7 + top/kernels/fuse_moe/count_and_gather.py | 206 +++++++++++++++++++++++ top/kernels/fuse_moe/reduce.py | 119 +++++++++++++ top/ops/__init__.py | 6 + top/ops/count_and_gather.py | 58 +++++++ top/ops/fuse_moe_pertensor_fp8.py | 58 +++++++ top/ops/moe_reduce.py | 76 +++++++++ 9 files changed, 757 insertions(+) create mode 100644 tests/ops/test_count_and_gather.py create mode 100644 tests/ops/test_reduce.py create mode 100644 top/kernels/fuse_moe/__init__.py create mode 100644 top/kernels/fuse_moe/count_and_gather.py create mode 100644 top/kernels/fuse_moe/reduce.py create mode 100644 top/ops/count_and_gather.py create mode 100644 top/ops/fuse_moe_pertensor_fp8.py create mode 100644 top/ops/moe_reduce.py diff --git a/tests/ops/test_count_and_gather.py b/tests/ops/test_count_and_gather.py new file mode 100644 index 00000000..dc067d46 --- /dev/null +++ b/tests/ops/test_count_and_gather.py @@ -0,0 +1,129 @@ +import time +import pytest +import torch + +from top.ops import CountAndGatherOp + + +def _count_and_gather_reference( + x: torch.Tensor, + topk_ids: torch.Tensor, + num_expert: int, + rank_ep: int, + tile_m: int, +): + num_seq, hidden_size = x.shape + num_topk = topk_ids.shape[1] + total_num_topk = num_seq * num_topk + + start_expert = rank_ep * num_expert + end_expert = (rank_ep + 1) * num_expert + + seqlens = torch.zeros(num_expert, dtype=torch.int32, device=x.device) + topk_pos = torch.full((total_num_topk,), -1, dtype=torch.int32, device=x.device) + cu_seqlens = torch.zeros(num_expert + 1, dtype=torch.int32, device=x.device) + tiles = torch.zeros(num_expert, dtype=torch.int32, device=x.device) + + topk_ids_flat = topk_ids.reshape(-1) + for idx in range(total_num_topk): + iexpert = topk_ids_flat[idx] + if iexpert >= start_expert and iexpert < end_expert: + seqlens[iexpert - start_expert] += 1 + + for i in range(num_expert): + cu_seqlens[i + 1] = cu_seqlens[i] + seqlens[i] + tiles[i] = (seqlens[i] + tile_m - 1) // tile_m + + total_tokens = cu_seqlens[-1].item() + gate_up_input = torch.zeros(total_tokens, hidden_size, dtype=x.dtype, device=x.device) + + running = torch.zeros_like(seqlens) + for idx in range(total_num_topk): + iexpert = topk_ids_flat[idx] + if iexpert >= start_expert and iexpert < end_expert: + expert_idx = iexpert - start_expert + abs_pos = cu_seqlens[expert_idx] + running[expert_idx] + topk_pos[idx] = abs_pos + gate_up_input[abs_pos] = x[idx // num_topk] + running[expert_idx] += 1 + + return gate_up_input, topk_pos, seqlens, cu_seqlens, tiles + + +class TestCountAndGatherOp: + """Test CountAndGatherOp with pytest.""" + + @pytest.fixture + def test_data(self): + """Create test data for CountAndGatherOp.""" + num_seq = 16 + hidden_size = 64 + num_topk = 2 + num_expert = 4 + rank_ep = 0 + + x = torch.randn(num_seq, hidden_size, device="cuda") + topk_ids = torch.randint(0, num_expert, (num_seq, num_topk), device="cuda", dtype=torch.int32) + + return x, topk_ids, num_expert, rank_ep, num_seq, hidden_size, num_topk + + def test_basic_functionality(self, test_data): + """Test basic functionality of CountAndGatherOp.""" + x, topk_ids, num_expert, rank_ep, num_seq, hidden_size, num_topk = test_data + + op = CountAndGatherOp(num_expert=num_expert, rank_ep=rank_ep) + gate_up_input, topk_pos, seqlens, cu_seqlens, tiles = op.forward(x, topk_ids) + ref_gate_up_input, ref_topk_pos, ref_seqlens, ref_cu_seqlens, ref_tiles = _count_and_gather_reference( + x=x, + topk_ids=topk_ids, + num_expert=num_expert, + rank_ep=rank_ep, + tile_m=op.config["tile_m"], + ) + + assert gate_up_input.shape[0] == seqlens.sum().item() + assert gate_up_input.shape[1] == hidden_size + assert topk_pos.shape == (num_seq * num_topk,) + assert seqlens.shape == (num_expert,) + assert cu_seqlens.shape == (num_expert + 1,) + assert tiles.shape == (num_expert,) + + for i in range(1, num_expert + 1): + assert cu_seqlens[i].item() == cu_seqlens[i - 1].item() + seqlens[i - 1].item() + + for pos in topk_pos: + assert pos.item() >= -1 + if pos.item() != -1: + assert pos.item() < gate_up_input.shape[0] + + assert torch.equal(seqlens, ref_seqlens) + assert torch.equal(cu_seqlens, ref_cu_seqlens) + assert torch.equal(tiles, ref_tiles) + assert torch.equal(topk_pos, ref_topk_pos) + assert torch.equal(gate_up_input, ref_gate_up_input) + + def test_performance(self, test_data): + """Test performance of CountAndGatherOp.""" + x, topk_ids, num_expert, rank_ep, num_seq, _, _ = test_data + + op = CountAndGatherOp(num_expert=num_expert, rank_ep=rank_ep) + num_iterations = 10 + start_time = time.time() + + for _ in range(num_iterations): + op.forward(x, topk_ids) + torch.cuda.synchronize() + + end_time = time.time() + avg_time = (end_time - start_time) / num_iterations + throughput = num_seq / avg_time + + print("Performance:") + print(f"Average time per iteration: {avg_time * 1000:.2f} ms") + print(f"Throughput: {throughput:.2f} sequences/sec") + + assert avg_time < 0.1 + + +if __name__ == "__main__": + pytest.main([__file__, "-vvs"]) diff --git a/tests/ops/test_reduce.py b/tests/ops/test_reduce.py new file mode 100644 index 00000000..27cc1393 --- /dev/null +++ b/tests/ops/test_reduce.py @@ -0,0 +1,98 @@ +import pytest +import torch + +from top.ops import MoeReduceOp + + +def _reduce_reference( + x: torch.Tensor, + topk_pos: torch.Tensor, + topk_scale: torch.Tensor, + shared_output: torch.Tensor | None = None, +) -> torch.Tensor: + num_seq, num_topk = topk_pos.shape + num_tokens, hidden_size = x.shape + + x_fp32 = x.to(torch.float32) + topk_scale_fp32 = topk_scale.to(torch.float32) + y_accum = torch.zeros((num_seq, hidden_size), dtype=torch.float32, device=x.device) + + for i in range(num_topk): + pos = topk_pos[:, i] + scale = topk_scale_fp32[:, i].unsqueeze(1) + mask = (pos >= 0) & (pos < num_tokens) + valid_pos = pos[mask] + valid_scale = scale[mask] + valid_seq = torch.where(mask)[0] + + if valid_pos.numel() > 0: + expert_outputs = x_fp32[valid_pos] + weighted_outputs = expert_outputs * valid_scale + y_accum.index_add_(0, valid_seq, weighted_outputs) + + if shared_output is not None: + y_accum += shared_output.to(torch.float32) + + return y_accum.to(x.dtype) + + +@pytest.mark.parametrize( + "num_seq, num_topk, hidden_size", + [ + (16, 2, 64), + (32, 1, 128), + (64, 4, 256), + ], +) +def test_reduce_op_basic(num_seq: int, num_topk: int, hidden_size: int) -> None: + num_tokens = num_seq * num_topk + x = torch.randn(num_tokens, hidden_size, device="cuda", dtype=torch.float16) + topk_pos = torch.randint(0, num_tokens, (num_seq, num_topk), device="cuda", dtype=torch.int32) + topk_scale = torch.randn(num_seq, num_topk, device="cuda", dtype=torch.float32) + + op = MoeReduceOp() + output = op.forward(x, topk_pos, topk_scale) + reference = _reduce_reference(x, topk_pos, topk_scale) + + assert output.shape == (num_seq, hidden_size) + assert output.dtype == x.dtype + assert torch.allclose(output, reference, atol=1e-3, rtol=1e-3) + + +def test_reduce_op_with_shared_output() -> None: + num_seq = 16 + num_topk = 2 + hidden_size = 64 + num_tokens = num_seq * num_topk + + x = torch.randn(num_tokens, hidden_size, device="cuda", dtype=torch.float16) + topk_pos = torch.randint(0, num_tokens, (num_seq, num_topk), device="cuda", dtype=torch.int32) + topk_scale = torch.randn(num_seq, num_topk, device="cuda", dtype=torch.float32) + shared_output = torch.randn(num_seq, hidden_size, device="cuda", dtype=torch.float16) + + op = MoeReduceOp() + output = op.forward(x, topk_pos, topk_scale, shared_output) + reference = _reduce_reference(x, topk_pos, topk_scale, shared_output) + + assert output.shape == (num_seq, hidden_size) + assert output.dtype == x.dtype + assert torch.allclose(output, reference, atol=1e-3, rtol=1e-3) + + +def test_reduce_op_invalid_shape() -> None: + num_seq = 16 + num_topk = 2 + hidden_size = 64 + num_tokens = num_seq * num_topk + + x = torch.randn(num_tokens, hidden_size, device="cuda", dtype=torch.float16) + topk_pos = torch.randint(0, num_tokens, (num_seq, num_topk), device="cuda", dtype=torch.int32) + invalid_topk_scale = torch.randn(num_seq, num_topk + 1, device="cuda", dtype=torch.float32) + + op = MoeReduceOp() + with pytest.raises(ValueError): + op.forward(x, topk_pos, invalid_topk_scale) + + +if __name__ == "__main__": + pytest.main([__file__, "-vvs"]) diff --git a/top/kernels/fuse_moe/__init__.py b/top/kernels/fuse_moe/__init__.py new file mode 100644 index 00000000..12cbf2da --- /dev/null +++ b/top/kernels/fuse_moe/__init__.py @@ -0,0 +1,7 @@ +from .count_and_gather import CountAndGatherKernel +from .reduce import ReduceKernel + +__all__ = [ + "CountAndGatherKernel", + "ReduceKernel", +] diff --git a/top/kernels/fuse_moe/count_and_gather.py b/top/kernels/fuse_moe/count_and_gather.py new file mode 100644 index 00000000..ade1db8c --- /dev/null +++ b/top/kernels/fuse_moe/count_and_gather.py @@ -0,0 +1,206 @@ +from typing import Optional +import tilelang +import tilelang.language as T +import torch + +from top.kernels.kernel import Kernel + + +def _count_seq_and_cuseq_kernel(total_num_topk: int, num_expert: int, start_expert: int, + end_expert: int, tile_m: int): + + @tilelang.jit(out_idx=[1, 2, 3, 4]) + def _count_seq_and_cuseq_fwd_func(block_size: int = 256): + + @T.prim_func + def _count_seq_and_cuseq_main( + topk_ids_flat: T.Tensor[(total_num_topk,), T.int32], + seqlens: T.Tensor[(num_expert,), T.int32], + cu_seqlens: T.Tensor[(num_expert + 1,), T.int32], + tiles: T.Tensor[(num_expert,), T.int32], + topk_pos: T.Tensor[(total_num_topk,), T.int32], + ): + with T.Kernel(1, threads=block_size) as bx: + tx = T.get_thread_binding() + seqlens_shm = T.alloc_shared((num_expert,), T.int32) + + for i in T.serial(T.ceildiv(num_expert, block_size)): + iexpert = i * block_size + tx + if iexpert < num_expert: + seqlens_shm[iexpert] = 0 + + for i in T.serial(T.ceildiv(total_num_topk, block_size)): + idx = i * block_size + tx + if idx < total_num_topk: + iexpert = topk_ids_flat[idx] + if iexpert >= start_expert and iexpert < end_expert: + T.atomic_add(seqlens_shm[iexpert - start_expert], 1) + topk_pos[idx] = -1 + + T.sync_threads() + + if tx == 0: + cu_seqlens[0] = 0 + for i in T.serial(num_expert): + iseq = seqlens_shm[i] + seqlens[i] = iseq + tiles[i] = (iseq + tile_m - 1) // tile_m + cu_seqlens[i + 1] = cu_seqlens[i] + iseq + + return _count_seq_and_cuseq_main + + return _count_seq_and_cuseq_fwd_func + + +def _gather_kernel(num_seq: int, hidden_size: int, num_topk: int, total_num_topk: int, + num_expert: int, start_expert: int, end_expert: int): + + @tilelang.jit(out_idx=[3, 4, 5]) + def _gather_fwd_func(block_size: int = 256): + + @T.prim_func + def _gather_main( + x: T.Tensor[(num_seq, hidden_size), T.float32], + topk_ids_flat: T.Tensor[(total_num_topk,), T.int32], + cu_seqlens: T.Tensor[(num_expert + 1,), T.int32], + seqlens_runtime: T.Tensor[(num_expert,), T.int32], + topk_pos: T.Tensor[(total_num_topk,), T.int32], + gate_up_input_full: T.Tensor[(total_num_topk, hidden_size), T.float32], + ): + with T.Kernel(1, threads=block_size) as bx: + tx = T.get_thread_binding() + + for i in T.serial(T.ceildiv(num_expert, block_size)): + iexpert = i * block_size + tx + if iexpert < num_expert: + seqlens_runtime[iexpert] = 0 + + for i in T.serial(T.ceildiv(total_num_topk, block_size)): + idx = i * block_size + tx + if idx < total_num_topk: + topk_pos[idx] = -1 + iexpert = topk_ids_flat[idx] + if iexpert >= start_expert and iexpert < end_expert: + expert_idx = iexpert - start_expert + pos_in_expert = T.atomic_add(seqlens_runtime[expert_idx], + 1, + return_prev=True) + irow = cu_seqlens[expert_idx] + pos_in_expert + topk_pos[idx] = irow + iseq = idx // num_topk + for ih in T.serial(hidden_size): + gate_up_input_full[irow, ih] = x[iseq, ih] + + return _gather_main + + return _gather_fwd_func + + +class CountAndGatherKernel(Kernel): + """Count and gather kernel for MoE.""" + + supported_archs: list[int] = [80, 89, 90] + + def __init__(self, + num_seq: int, + hidden_size: int, + num_topk: int, + num_expert: int, + config: Optional[dict] = None, + tune: bool = False) -> None: + super().__init__() + self.num_seq = num_seq + self.hidden_size = hidden_size + self.num_topk = num_topk + self.num_expert = num_expert + + self.init_config(config, tune) + + @property + def default_config(self) -> dict: + return { + "block_size": 256, + "tile_m": 16 + } + + @property + def autotune_configs(self) -> list[dict]: + block_sizes = [256, 512] + tile_ms = [16, 32] + configs = [] + for block_size in block_sizes: + for tile_m in tile_ms: + configs.append({ + "block_size": block_size, + "tile_m": tile_m + }) + return configs + + def forward(self, x, topk_ids, rank_ep=0): + """Run the kernel + + Args: + x: Input token features [num_seq, hidden_size] + topk_ids: Expert assignment for each token [num_seq, num_topk] + rank_ep: Expert parallel rank + + Returns: + gate_up_input: Gathered input for gate and up projection + topk_pos: Position mapping for each token + seqlens: Number of tokens per expert + cu_seqlens: Cumulative sequence lengths + tiles: Number of tiles per expert + """ + return self.count_and_gather(x, topk_ids, rank_ep) + + def count_and_gather(self, x, topk_ids, rank_ep=0): + """Count and gather tokens by expert. + + Args: + x: Input token features [num_seq, hidden_size] + topk_ids: Expert assignment for each token [num_seq, num_topk] + rank_ep: Expert parallel rank + + Returns: + gate_up_input: Gathered input for gate and up projection + topk_pos: Position mapping for each token + seqlens: Number of tokens per expert + cu_seqlens: Cumulative sequence lengths + tiles: Number of tiles per expert + """ + num_seq, hidden_size = x.shape + num_topk = topk_ids.shape[1] + total_num_topk = num_seq * num_topk + num_expert = self.num_expert + + start_expert = rank_ep * num_expert + end_expert = (rank_ep + 1) * num_expert + + block_size = self.config["block_size"] + + topk_ids_flat = topk_ids.reshape(-1).contiguous() + + seqlens, cu_seqlens, tiles, _ = _count_seq_and_cuseq_kernel( + total_num_topk=total_num_topk, + num_expert=num_expert, + start_expert=start_expert, + end_expert=end_expert, + tile_m=self.config["tile_m"], + )(block_size)(topk_ids_flat) + + x_fp32 = x.to(torch.float32) + seqlens_runtime, topk_pos, gate_up_input_full = _gather_kernel( + num_seq=num_seq, + hidden_size=hidden_size, + num_topk=num_topk, + total_num_topk=total_num_topk, + num_expert=num_expert, + start_expert=start_expert, + end_expert=end_expert, + )(block_size)(x_fp32, topk_ids_flat, cu_seqlens) + + total_tokens = cu_seqlens[-1].item() + gate_up_input = gate_up_input_full[:total_tokens].to(x.dtype) + + return gate_up_input, topk_pos, seqlens_runtime, cu_seqlens, tiles + diff --git a/top/kernels/fuse_moe/reduce.py b/top/kernels/fuse_moe/reduce.py new file mode 100644 index 00000000..6912eca2 --- /dev/null +++ b/top/kernels/fuse_moe/reduce.py @@ -0,0 +1,119 @@ +import tilelang +import tilelang.language as T +import torch + +from top.kernels.kernel import Kernel + + +def _reduce_kernel(total_num_seq: int, num_seq: int, hidden_size: int, num_topk: int): + + @tilelang.jit(out_idx=[4]) + def _reduce_fwd_func(block_size: int = 256): + + @T.prim_func + def _reduce_main( + x_fp32: T.Tensor[(total_num_seq, hidden_size), T.float32], + topk_pos: T.Tensor[(num_seq, num_topk), T.int32], + topk_scale_fp32: T.Tensor[(num_seq, num_topk), T.float32], + shared_output_fp32: T.Tensor[(num_seq, hidden_size), T.float32], + output_fp32: T.Tensor[(num_seq, hidden_size), T.float32], + ): + with T.Kernel(num_seq, threads=block_size) as by: + tx = T.get_thread_binding() + for ih_blk in T.serial(T.ceildiv(hidden_size, block_size)): + ih = ih_blk * block_size + tx + if ih < hidden_size: + acc = T.alloc_var(T.float32) + acc = 0.0 + for itopk in T.serial(num_topk): + ipos = topk_pos[by, itopk] + if ipos >= 0 and ipos < total_num_seq: + acc += x_fp32[ipos, ih] * topk_scale_fp32[by, itopk] + acc += shared_output_fp32[by, ih] + output_fp32[by, ih] = acc + + return _reduce_main + + return _reduce_fwd_func + + +class ReduceKernel(Kernel): + """Reduce kernel for MoE. + + Performs scatter-add aggregation operation, aggregating expert outputs back to original positions. + """ + + supported_archs: list[int] = [80, 89, 90] + + def __init__(self, + num_seq: int, + hidden_size: int, + num_topk: int, + config: dict = None, + tune: bool = False) -> None: + super().__init__() + self.num_seq = num_seq + self.hidden_size = hidden_size + self.num_topk = num_topk + + self.init_config(config, tune) + + @property + def default_config(self) -> dict: + return { + "block_size": 256, + "items_per_16b": 8 # For FP16 + } + + @property + def autotune_configs(self) -> list[dict]: + block_sizes = [256, 512] + configs = [] + for block_size in block_sizes: + configs.append({ + "block_size": block_size, + "items_per_16b": 8 + }) + return configs + + def forward(self, x, topk_pos, topk_scale, shared_output=None): + """Forward pass of ReduceKernel. + + Args: + x: Input tensor (expert outputs) of shape [num_tokens, hidden_size] + topk_pos: Position mapping tensor of shape [num_seq, num_topk] + topk_scale: Weight tensor of shape [num_seq, num_topk] + shared_output: Optional shared output tensor of shape [num_seq, hidden_size] + + Returns: + output: Aggregated output tensor of shape [num_seq, hidden_size] + """ + num_seq, num_topk = topk_pos.shape + num_tokens, hidden_size = x.shape + block_size = self.config["block_size"] + + if num_topk == 0: + if shared_output is None: + return torch.zeros((num_seq, hidden_size), dtype=x.dtype, device=x.device) + return shared_output.to(dtype=x.dtype) + + x_fp32 = x.to(torch.float32).contiguous() + topk_scale_fp32 = topk_scale.to(torch.float32).contiguous() + topk_pos_i32 = topk_pos.to(torch.int32).contiguous() + + if shared_output is None: + shared_output_fp32 = torch.zeros((num_seq, hidden_size), + dtype=torch.float32, + device=x.device) + else: + shared_output_fp32 = shared_output.to(torch.float32).contiguous() + + output_fp32 = _reduce_kernel( + total_num_seq=num_tokens, + num_seq=num_seq, + hidden_size=hidden_size, + num_topk=num_topk, + )(block_size)(x_fp32, topk_pos_i32, topk_scale_fp32, shared_output_fp32) + + return output_fp32.to(x.dtype) + \ No newline at end of file diff --git a/top/ops/__init__.py b/top/ops/__init__.py index 6260df22..31703e4e 100644 --- a/top/ops/__init__.py +++ b/top/ops/__init__.py @@ -14,6 +14,9 @@ from .mha_decode_paged import MultiHeadAttentionDecodePagedWithKVCacheOp from .mhc_pre import ManifoldConstrainedHyperConnectionPreOp from .mhc_post import ManifoldConstrainedHyperConnectionPostOp +from .count_and_gather import CountAndGatherOp +from .moe_reduce import MoeReduceOp +from .fuse_moe_pertensor_fp8 import FuseMoePertensorFp8Op from .op import Op # noqa: F401 __all__ = [ @@ -43,4 +46,7 @@ "GQAWindowSlidingOp", "ManifoldConstrainedHyperConnectionPreOp", "ManifoldConstrainedHyperConnectionPostOp", + "CountAndGatherOp", + "MoeReduceOp", + "FuseMoePertensorFp8Op", ] diff --git a/top/ops/count_and_gather.py b/top/ops/count_and_gather.py new file mode 100644 index 00000000..c7e8c785 --- /dev/null +++ b/top/ops/count_and_gather.py @@ -0,0 +1,58 @@ +from typing import Dict, Optional, Tuple +import torch + +from top.ops.op import Op +from top.kernels.fuse_moe.count_and_gather import CountAndGatherKernel +from top.kernels.kernel import Kernel + +__all__ = ["CountAndGatherOp"] + + +class CountAndGatherOp(Op): + """Count and gather Op for MoE.""" + + def __init__(self, + num_expert: int, + rank_ep: int = 0, + tile_m: int = 16, + kernel_map: Optional[Dict[str, Kernel]] = None, + config: Optional[dict] = None, + tune: bool = False) -> None: + self.num_expert = num_expert + self.rank_ep = rank_ep + self.config = {"tile_m": tile_m} + if config is not None: + self.config.update(config) + + self.dispatch_kernel(kernel_map) + self.kernel = self.kernel_map["CountAndGatherKernel"]( + num_seq=1, + hidden_size=1, + num_topk=1, + num_expert=self.num_expert, + config=self.config, + tune=tune, + ) + + @property + def default_kernel_map(self) -> Dict[str, Kernel]: + return {"CountAndGatherKernel": CountAndGatherKernel} + + def forward(self, x: torch.Tensor, topk_ids: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Forward pass of CountAndGatherOp. + + Args: + x: Input token features [num_seq, hidden_size] + topk_ids: Expert assignment for each token [num_seq, num_topk] + + Returns: + gate_up_input: Gathered input for gate and up projection + topk_pos: Position mapping for each token + seqlens: Number of tokens per expert + cu_seqlens: Cumulative sequence lengths + tiles: Number of tiles per expert + """ + self.kernel.num_seq = x.shape[0] + self.kernel.hidden_size = x.shape[1] + self.kernel.num_topk = topk_ids.shape[1] + return self.kernel(x, topk_ids, self.rank_ep) diff --git a/top/ops/fuse_moe_pertensor_fp8.py b/top/ops/fuse_moe_pertensor_fp8.py new file mode 100644 index 00000000..01ed3b4c --- /dev/null +++ b/top/ops/fuse_moe_pertensor_fp8.py @@ -0,0 +1,58 @@ +from typing import Optional, Dict +import torch + +from top.ops.op import Op +from top.kernels.kernel import Kernel + + +class FuseMoePertensorFp8Op(Op): + """Fused MoE with per-tensor FP8 quantization Op.""" + + def __init__(self, + num_expert: int, + rank_ep: int = 0, + config: Optional[dict] = None, + tune: bool = False) -> None: + super().__init__() + self.num_expert = num_expert + self.rank_ep = rank_ep + + self.config = {} + self.init_config(config, tune) + + def init_config(self, config: Optional[dict] = None, tune: bool = False) -> None: + """Initialize configuration.""" + if config is not None: + self.config.update(config) + else: + self.config = self.default_config + + print(f"{self.__class__.__name__} initialized with config: {self.config}") + + @property + def default_kernel_map(self) -> Dict[str, Kernel]: + return {} + + @property + def default_config(self) -> dict: + return {} + + @property + def autotune_configs(self) -> list[dict]: + return [] + + def forward(self, x: torch.Tensor, gate_up_weight: torch.Tensor, down_weight: torch.Tensor, + topk_ids: torch.Tensor, topk_scale: torch.Tensor) -> torch.Tensor: + """Forward pass of FuseMoePertensorFp8Op. + + Args: + x: Input token features + gate_up_weight: Gate and up projection weights + down_weight: Down projection weights + topk_ids: Expert assignment for each token + topk_scale: Scaling factors for each token + + Returns: + output: Fused MoE output + """ + raise NotImplementedError("FuseMoePertensorFp8Op not implemented yet") diff --git a/top/ops/moe_reduce.py b/top/ops/moe_reduce.py new file mode 100644 index 00000000..f0835091 --- /dev/null +++ b/top/ops/moe_reduce.py @@ -0,0 +1,76 @@ +from typing import Dict, Optional +import torch + +from top.ops.op import Op +from top.kernels.kernel import Kernel +from top.kernels.fuse_moe.reduce import ReduceKernel + +__all__ = ["MoeReduceOp"] + + +class MoeReduceOp(Op): + """MoE Reduce Op.""" + + def __init__(self, + kernel_map: Optional[Dict[str, Kernel]] = None, + config: Optional[dict] = None, + tune: bool = False) -> None: + self.config = {} if config is None else dict(config) + self.dispatch_kernel(kernel_map) + self.kernel = self.kernel_map["ReduceKernel"]( + num_seq=1, + hidden_size=1, + num_topk=1, + config=self.config if self.config else None, + tune=tune, + ) + + @property + def default_kernel_map(self) -> Dict[str, Kernel]: + return {"ReduceKernel": ReduceKernel} + + def forward(self, x: torch.Tensor, topk_pos: torch.Tensor, topk_scale: torch.Tensor, shared_output: Optional[torch.Tensor] = None) -> torch.Tensor: + """Forward pass of MoeReduceOp. + + Args: + x: Expert output tokens [total_tokens, hidden_size] + topk_pos: Position mapping from count_and_gather [num_seq, num_topk] + topk_scale: Scaling factors for each token [num_seq, num_topk] + shared_output: Optional shared output tensor [num_seq, hidden_size] + + Returns: + output: Reduced output tensor [num_seq, hidden_size] + """ + if x.ndim != 2: + raise ValueError(f"Expected x to be 2D tensor, got {x.ndim}D") + + if topk_pos.ndim != 2: + raise ValueError(f"Expected topk_pos to be 2D tensor, got {topk_pos.ndim}D") + + if topk_scale.ndim != 2: + raise ValueError(f"Expected topk_scale to be 2D tensor, got {topk_scale.ndim}D") + + if topk_pos.shape != topk_scale.shape: + raise ValueError( + f"Mismatched shape between topk_pos and topk_scale: {topk_pos.shape} vs {topk_scale.shape}" + ) + + if shared_output is not None: + if shared_output.ndim != 2: + raise ValueError(f"Expected shared_output to be 2D tensor, got {shared_output.ndim}D") + + if shared_output.shape[0] != topk_pos.shape[0]: + raise ValueError( + f"Mismatched batch size: shared_output has {shared_output.shape[0]}, topk_pos has {topk_pos.shape[0]}" + ) + + if shared_output.shape[1] != x.shape[1]: + raise ValueError( + f"Mismatched hidden size: shared_output has {shared_output.shape[1]}, x has {x.shape[1]}" + ) + + num_seq, num_topk = topk_pos.shape + self.kernel.num_seq = num_seq + self.kernel.hidden_size = x.shape[1] + self.kernel.num_topk = num_topk + return self.kernel(x, topk_pos, topk_scale, shared_output) From 4a792484124fb6d4a063122fb2978396ded8be11 Mon Sep 17 00:00:00 2001 From: ZhengQiHang Date: Tue, 24 Feb 2026 07:39:44 +0000 Subject: [PATCH 2/2] [Feat] Implement CountAndGather and MoeReduce operations with tests and kernels --- tests/ops/test_count_and_gather.py | 3 +- .../{test_reduce.py => test_moe_reduce.py} | 12 +- top/functions/fuse_moe.py | 135 ++++++++++++++++++ top/kernels/fuse_moe/__init__.py | 4 +- top/kernels/fuse_moe/count_and_gather.py | 35 ++--- .../fuse_moe/{reduce.py => moe_reduce.py} | 30 ++-- top/ops/count_and_gather.py | 4 +- top/ops/moe_reduce.py | 15 +- 8 files changed, 185 insertions(+), 53 deletions(-) rename tests/ops/{test_reduce.py => test_moe_reduce.py} (89%) create mode 100644 top/functions/fuse_moe.py rename top/kernels/fuse_moe/{reduce.py => moe_reduce.py} (92%) diff --git a/tests/ops/test_count_and_gather.py b/tests/ops/test_count_and_gather.py index dc067d46..c97d9d24 100644 --- a/tests/ops/test_count_and_gather.py +++ b/tests/ops/test_count_and_gather.py @@ -63,7 +63,8 @@ def test_data(self): rank_ep = 0 x = torch.randn(num_seq, hidden_size, device="cuda") - topk_ids = torch.randint(0, num_expert, (num_seq, num_topk), device="cuda", dtype=torch.int32) + topk_ids = torch.randint( + 0, num_expert, (num_seq, num_topk), device="cuda", dtype=torch.int32) return x, topk_ids, num_expert, rank_ep, num_seq, hidden_size, num_topk diff --git a/tests/ops/test_reduce.py b/tests/ops/test_moe_reduce.py similarity index 89% rename from tests/ops/test_reduce.py rename to tests/ops/test_moe_reduce.py index 27cc1393..a714a253 100644 --- a/tests/ops/test_reduce.py +++ b/tests/ops/test_moe_reduce.py @@ -4,7 +4,7 @@ from top.ops import MoeReduceOp -def _reduce_reference( +def _moe_reduce_reference( x: torch.Tensor, topk_pos: torch.Tensor, topk_scale: torch.Tensor, @@ -44,7 +44,7 @@ def _reduce_reference( (64, 4, 256), ], ) -def test_reduce_op_basic(num_seq: int, num_topk: int, hidden_size: int) -> None: +def test_moe_reduce_op_basic(num_seq: int, num_topk: int, hidden_size: int) -> None: num_tokens = num_seq * num_topk x = torch.randn(num_tokens, hidden_size, device="cuda", dtype=torch.float16) topk_pos = torch.randint(0, num_tokens, (num_seq, num_topk), device="cuda", dtype=torch.int32) @@ -52,14 +52,14 @@ def test_reduce_op_basic(num_seq: int, num_topk: int, hidden_size: int) -> None: op = MoeReduceOp() output = op.forward(x, topk_pos, topk_scale) - reference = _reduce_reference(x, topk_pos, topk_scale) + reference = _moe_reduce_reference(x, topk_pos, topk_scale) assert output.shape == (num_seq, hidden_size) assert output.dtype == x.dtype assert torch.allclose(output, reference, atol=1e-3, rtol=1e-3) -def test_reduce_op_with_shared_output() -> None: +def test_moe_reduce_op_with_shared_output() -> None: num_seq = 16 num_topk = 2 hidden_size = 64 @@ -72,14 +72,14 @@ def test_reduce_op_with_shared_output() -> None: op = MoeReduceOp() output = op.forward(x, topk_pos, topk_scale, shared_output) - reference = _reduce_reference(x, topk_pos, topk_scale, shared_output) + reference = _moe_reduce_reference(x, topk_pos, topk_scale, shared_output) assert output.shape == (num_seq, hidden_size) assert output.dtype == x.dtype assert torch.allclose(output, reference, atol=1e-3, rtol=1e-3) -def test_reduce_op_invalid_shape() -> None: +def test_moe_reduce_op_invalid_shape() -> None: num_seq = 16 num_topk = 2 hidden_size = 64 diff --git a/top/functions/fuse_moe.py b/top/functions/fuse_moe.py new file mode 100644 index 00000000..994ea966 --- /dev/null +++ b/top/functions/fuse_moe.py @@ -0,0 +1,135 @@ +from typing import Optional, Tuple +import torch + +from top.kernels.fuse_moe.count_and_gather import CountAndGatherKernel +from top.kernels.fuse_moe.moe_reduce import MoeReduceKernel + + +def count_and_gather( + x: torch.Tensor, # [num_seq, hidden_size] + topk_ids: torch.Tensor, # [num_seq, num_topk] + num_expert: int, + rank_ep: int = 0, + tile_m: int = 16, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Count and gather tokens by expert. + + Args: + x: Input token features + topk_ids: Expert assignment for each token + num_expert: Number of experts + rank_ep: Expert parallel rank + tile_m: Tile size for M dimension + + Returns: + gate_up_input: Gathered input for gate and up projection + topk_pos: Position mapping for each token + seqlens: Number of tokens per expert + cu_seqlens: Cumulative sequence lengths + tiles: Number of tiles per expert + """ + # Input validation + if x.ndim != 2: + raise ValueError(f"Expected x to be 2D tensor, got {x.ndim}D") + + if topk_ids.ndim != 2: + raise ValueError(f"Expected topk_ids to be 2D tensor, got {topk_ids.ndim}D") + + if x.shape[0] != topk_ids.shape[0]: + raise ValueError( + f"Mismatched batch size: x has {x.shape[0]}, topk_ids has {topk_ids.shape[0]}") + + num_seq, hidden_size = x.shape + num_topk = topk_ids.shape[1] + kernel = CountAndGatherKernel( + num_seq=num_seq, + hidden_size=hidden_size, + num_topk=num_topk, + num_expert=num_expert, + config={"tile_m": tile_m}, + ) + return kernel.count_and_gather(x, topk_ids, rank_ep) + + +def reduce( + x: torch.Tensor, # [total_tokens, hidden_size] + topk_pos: torch.Tensor, # [num_seq, num_topk] + topk_scale: torch.Tensor, # [num_seq, num_topk] + shared_output: Optional[torch.Tensor] = None, # [num_seq, hidden_size] +) -> torch.Tensor: + """Reduce (scatter-add) tokens back to original positions. + + Args: + x: Expert output tokens + topk_pos: Position mapping from count_and_gather + topk_scale: Scaling factors for each token + shared_output: Optional shared output tensor + + Returns: + output: Reduced output tensor [num_seq, hidden_size] + """ + # Input validation + if x.ndim != 2: + raise ValueError(f"Expected x to be 2D tensor, got {x.ndim}D") + + if topk_pos.ndim != 2: + raise ValueError(f"Expected topk_pos to be 2D tensor, got {topk_pos.ndim}D") + + if topk_scale.ndim != 2: + raise ValueError(f"Expected topk_scale to be 2D tensor, got {topk_scale.ndim}D") + + if topk_pos.shape != topk_scale.shape: + raise ValueError( + f"Mismatched shape between topk_pos and topk_scale: {topk_pos.shape} vs {topk_scale.shape}" + ) + + if shared_output is not None: + if shared_output.ndim != 2: + raise ValueError(f"Expected shared_output to be 2D tensor, got {shared_output.ndim}D") + + if shared_output.shape[0] != topk_pos.shape[0]: + raise ValueError( + f"Mismatched batch size: shared_output has {shared_output.shape[0]}, topk_pos has {topk_pos.shape[0]}" + ) + + if shared_output.shape[1] != x.shape[1]: + raise ValueError( + f"Mismatched hidden size: shared_output has {shared_output.shape[1]}, x has {x.shape[1]}" + ) + + num_seq, num_topk = topk_pos.shape + _, hidden_size = x.shape + kernel = MoeReduceKernel( + num_seq=num_seq, + hidden_size=hidden_size, + num_topk=num_topk, + ) + return kernel.forward(x, topk_pos, topk_scale, shared_output) + + +def fuse_moe_pertensor_fp8( + x: torch.Tensor, + gate_up_weight: torch.Tensor, + down_weight: torch.Tensor, + topk_ids: torch.Tensor, + topk_scale: torch.Tensor, + num_expert: int, + rank_ep: int = 0, +) -> torch.Tensor: + """Fused MoE with per-tensor FP8 quantization. + + Args: + x: Input token features + gate_up_weight: Gate and up projection weights + down_weight: Down projection weights + topk_ids: Expert assignment for each token + topk_scale: Scaling factors for each token + num_expert: Number of experts + rank_ep: Expert parallel rank + + Returns: + output: Fused MoE output + """ + # TODO: Implement fuse_moe_pertensor_fp8 function + # This will be implemented in a separate step + raise NotImplementedError("fuse_moe_pertensor_fp8 function not implemented yet") diff --git a/top/kernels/fuse_moe/__init__.py b/top/kernels/fuse_moe/__init__.py index 12cbf2da..6e05b630 100644 --- a/top/kernels/fuse_moe/__init__.py +++ b/top/kernels/fuse_moe/__init__.py @@ -1,7 +1,7 @@ from .count_and_gather import CountAndGatherKernel -from .reduce import ReduceKernel +from .moe_reduce import MoeReduceKernel __all__ = [ "CountAndGatherKernel", - "ReduceKernel", + "MoeReduceKernel", ] diff --git a/top/kernels/fuse_moe/count_and_gather.py b/top/kernels/fuse_moe/count_and_gather.py index ade1db8c..d17dbfe8 100644 --- a/top/kernels/fuse_moe/count_and_gather.py +++ b/top/kernels/fuse_moe/count_and_gather.py @@ -20,7 +20,7 @@ def _count_seq_and_cuseq_main( tiles: T.Tensor[(num_expert,), T.int32], topk_pos: T.Tensor[(total_num_topk,), T.int32], ): - with T.Kernel(1, threads=block_size) as bx: + with T.Kernel(1, threads=block_size) as _: tx = T.get_thread_binding() seqlens_shm = T.alloc_shared((num_expert,), T.int32) @@ -67,7 +67,7 @@ def _gather_main( topk_pos: T.Tensor[(total_num_topk,), T.int32], gate_up_input_full: T.Tensor[(total_num_topk, hidden_size), T.float32], ): - with T.Kernel(1, threads=block_size) as bx: + with T.Kernel(1, threads=block_size) as _: tx = T.get_thread_binding() for i in T.serial(T.ceildiv(num_expert, block_size)): @@ -82,9 +82,8 @@ def _gather_main( iexpert = topk_ids_flat[idx] if iexpert >= start_expert and iexpert < end_expert: expert_idx = iexpert - start_expert - pos_in_expert = T.atomic_add(seqlens_runtime[expert_idx], - 1, - return_prev=True) + pos_in_expert = T.atomic_add( + seqlens_runtime[expert_idx], 1, return_prev=True) irow = cu_seqlens[expert_idx] + pos_in_expert topk_pos[idx] = irow iseq = idx // num_topk @@ -118,10 +117,7 @@ def __init__(self, @property def default_config(self) -> dict: - return { - "block_size": 256, - "tile_m": 16 - } + return {"block_size": 256, "tile_m": 16} @property def autotune_configs(self) -> list[dict]: @@ -130,20 +126,17 @@ def autotune_configs(self) -> list[dict]: configs = [] for block_size in block_sizes: for tile_m in tile_ms: - configs.append({ - "block_size": block_size, - "tile_m": tile_m - }) + configs.append({"block_size": block_size, "tile_m": tile_m}) return configs def forward(self, x, topk_ids, rank_ep=0): """Run the kernel - + Args: x: Input token features [num_seq, hidden_size] topk_ids: Expert assignment for each token [num_seq, num_topk] rank_ep: Expert parallel rank - + Returns: gate_up_input: Gathered input for gate and up projection topk_pos: Position mapping for each token @@ -155,12 +148,12 @@ def forward(self, x, topk_ids, rank_ep=0): def count_and_gather(self, x, topk_ids, rank_ep=0): """Count and gather tokens by expert. - + Args: x: Input token features [num_seq, hidden_size] topk_ids: Expert assignment for each token [num_seq, num_topk] rank_ep: Expert parallel rank - + Returns: gate_up_input: Gathered input for gate and up projection topk_pos: Position mapping for each token @@ -172,10 +165,10 @@ def count_and_gather(self, x, topk_ids, rank_ep=0): num_topk = topk_ids.shape[1] total_num_topk = num_seq * num_topk num_expert = self.num_expert - + start_expert = rank_ep * num_expert end_expert = (rank_ep + 1) * num_expert - + block_size = self.config["block_size"] topk_ids_flat = topk_ids.reshape(-1).contiguous() @@ -186,7 +179,8 @@ def count_and_gather(self, x, topk_ids, rank_ep=0): start_expert=start_expert, end_expert=end_expert, tile_m=self.config["tile_m"], - )(block_size)(topk_ids_flat) + )(block_size)( + topk_ids_flat) x_fp32 = x.to(torch.float32) seqlens_runtime, topk_pos, gate_up_input_full = _gather_kernel( @@ -203,4 +197,3 @@ def count_and_gather(self, x, topk_ids, rank_ep=0): gate_up_input = gate_up_input_full[:total_tokens].to(x.dtype) return gate_up_input, topk_pos, seqlens_runtime, cu_seqlens, tiles - diff --git a/top/kernels/fuse_moe/reduce.py b/top/kernels/fuse_moe/moe_reduce.py similarity index 92% rename from top/kernels/fuse_moe/reduce.py rename to top/kernels/fuse_moe/moe_reduce.py index 6912eca2..ca785842 100644 --- a/top/kernels/fuse_moe/reduce.py +++ b/top/kernels/fuse_moe/moe_reduce.py @@ -5,7 +5,7 @@ from top.kernels.kernel import Kernel -def _reduce_kernel(total_num_seq: int, num_seq: int, hidden_size: int, num_topk: int): +def _moe_reduce_kernel(total_num_seq: int, num_seq: int, hidden_size: int, num_topk: int): @tilelang.jit(out_idx=[4]) def _reduce_fwd_func(block_size: int = 256): @@ -37,14 +37,14 @@ def _reduce_main( return _reduce_fwd_func -class ReduceKernel(Kernel): +class MoeReduceKernel(Kernel): """Reduce kernel for MoE. - + Performs scatter-add aggregation operation, aggregating expert outputs back to original positions. """ - + supported_archs: list[int] = [80, 89, 90] - + def __init__(self, num_seq: int, hidden_size: int, @@ -55,36 +55,33 @@ def __init__(self, self.num_seq = num_seq self.hidden_size = hidden_size self.num_topk = num_topk - + self.init_config(config, tune) - + @property def default_config(self) -> dict: return { "block_size": 256, "items_per_16b": 8 # For FP16 } - + @property def autotune_configs(self) -> list[dict]: block_sizes = [256, 512] configs = [] for block_size in block_sizes: - configs.append({ - "block_size": block_size, - "items_per_16b": 8 - }) + configs.append({"block_size": block_size, "items_per_16b": 8}) return configs - + def forward(self, x, topk_pos, topk_scale, shared_output=None): """Forward pass of ReduceKernel. - + Args: x: Input tensor (expert outputs) of shape [num_tokens, hidden_size] topk_pos: Position mapping tensor of shape [num_seq, num_topk] topk_scale: Weight tensor of shape [num_seq, num_topk] shared_output: Optional shared output tensor of shape [num_seq, hidden_size] - + Returns: output: Aggregated output tensor of shape [num_seq, hidden_size] """ @@ -108,7 +105,7 @@ def forward(self, x, topk_pos, topk_scale, shared_output=None): else: shared_output_fp32 = shared_output.to(torch.float32).contiguous() - output_fp32 = _reduce_kernel( + output_fp32 = _moe_reduce_kernel( total_num_seq=num_tokens, num_seq=num_seq, hidden_size=hidden_size, @@ -116,4 +113,3 @@ def forward(self, x, topk_pos, topk_scale, shared_output=None): )(block_size)(x_fp32, topk_pos_i32, topk_scale_fp32, shared_output_fp32) return output_fp32.to(x.dtype) - \ No newline at end of file diff --git a/top/ops/count_and_gather.py b/top/ops/count_and_gather.py index c7e8c785..9163de45 100644 --- a/top/ops/count_and_gather.py +++ b/top/ops/count_and_gather.py @@ -38,7 +38,9 @@ def __init__(self, def default_kernel_map(self) -> Dict[str, Kernel]: return {"CountAndGatherKernel": CountAndGatherKernel} - def forward(self, x: torch.Tensor, topk_ids: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + def forward( + self, x: torch.Tensor, topk_ids: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """Forward pass of CountAndGatherOp. Args: diff --git a/top/ops/moe_reduce.py b/top/ops/moe_reduce.py index f0835091..d163fef2 100644 --- a/top/ops/moe_reduce.py +++ b/top/ops/moe_reduce.py @@ -3,7 +3,7 @@ from top.ops.op import Op from top.kernels.kernel import Kernel -from top.kernels.fuse_moe.reduce import ReduceKernel +from top.kernels.fuse_moe.moe_reduce import MoeReduceKernel __all__ = ["MoeReduceOp"] @@ -17,7 +17,7 @@ def __init__(self, tune: bool = False) -> None: self.config = {} if config is None else dict(config) self.dispatch_kernel(kernel_map) - self.kernel = self.kernel_map["ReduceKernel"]( + self.kernel = self.kernel_map["MoeReduceKernel"]( num_seq=1, hidden_size=1, num_topk=1, @@ -27,9 +27,13 @@ def __init__(self, @property def default_kernel_map(self) -> Dict[str, Kernel]: - return {"ReduceKernel": ReduceKernel} + return {"MoeReduceKernel": MoeReduceKernel} - def forward(self, x: torch.Tensor, topk_pos: torch.Tensor, topk_scale: torch.Tensor, shared_output: Optional[torch.Tensor] = None) -> torch.Tensor: + def forward(self, + x: torch.Tensor, + topk_pos: torch.Tensor, + topk_scale: torch.Tensor, + shared_output: Optional[torch.Tensor] = None) -> torch.Tensor: """Forward pass of MoeReduceOp. Args: @@ -57,7 +61,8 @@ def forward(self, x: torch.Tensor, topk_pos: torch.Tensor, topk_scale: torch.Ten if shared_output is not None: if shared_output.ndim != 2: - raise ValueError(f"Expected shared_output to be 2D tensor, got {shared_output.ndim}D") + raise ValueError( + f"Expected shared_output to be 2D tensor, got {shared_output.ndim}D") if shared_output.shape[0] != topk_pos.shape[0]: raise ValueError(