-
Notifications
You must be signed in to change notification settings - Fork 23
[Feat] Add CountAndGather and MoeReduce operations for FusedMoe #180
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,130 @@ | ||
| import time | ||
| import pytest | ||
| import torch | ||
|
|
||
| from top.ops import CountAndGatherOp | ||
|
|
||
|
|
||
| def _count_and_gather_reference( | ||
| x: torch.Tensor, | ||
| topk_ids: torch.Tensor, | ||
| num_expert: int, | ||
| rank_ep: int, | ||
| tile_m: int, | ||
| ): | ||
| num_seq, hidden_size = x.shape | ||
| num_topk = topk_ids.shape[1] | ||
| total_num_topk = num_seq * num_topk | ||
|
|
||
| start_expert = rank_ep * num_expert | ||
| end_expert = (rank_ep + 1) * num_expert | ||
|
|
||
| seqlens = torch.zeros(num_expert, dtype=torch.int32, device=x.device) | ||
| topk_pos = torch.full((total_num_topk,), -1, dtype=torch.int32, device=x.device) | ||
| cu_seqlens = torch.zeros(num_expert + 1, dtype=torch.int32, device=x.device) | ||
| tiles = torch.zeros(num_expert, dtype=torch.int32, device=x.device) | ||
|
|
||
| topk_ids_flat = topk_ids.reshape(-1) | ||
| for idx in range(total_num_topk): | ||
| iexpert = topk_ids_flat[idx] | ||
| if iexpert >= start_expert and iexpert < end_expert: | ||
| seqlens[iexpert - start_expert] += 1 | ||
|
|
||
| for i in range(num_expert): | ||
| cu_seqlens[i + 1] = cu_seqlens[i] + seqlens[i] | ||
| tiles[i] = (seqlens[i] + tile_m - 1) // tile_m | ||
|
|
||
| total_tokens = cu_seqlens[-1].item() | ||
| gate_up_input = torch.zeros(total_tokens, hidden_size, dtype=x.dtype, device=x.device) | ||
|
|
||
| running = torch.zeros_like(seqlens) | ||
| for idx in range(total_num_topk): | ||
| iexpert = topk_ids_flat[idx] | ||
| if iexpert >= start_expert and iexpert < end_expert: | ||
| expert_idx = iexpert - start_expert | ||
| abs_pos = cu_seqlens[expert_idx] + running[expert_idx] | ||
| topk_pos[idx] = abs_pos | ||
| gate_up_input[abs_pos] = x[idx // num_topk] | ||
| running[expert_idx] += 1 | ||
|
|
||
| return gate_up_input, topk_pos, seqlens, cu_seqlens, tiles | ||
|
|
||
|
|
||
class TestCountAndGatherOp:
    """Pytest suite for CountAndGatherOp.

    Verifies shape/prefix-sum invariants of the op's outputs, compares them
    exactly against the pure-PyTorch reference implementation, and runs a
    coarse performance smoke test.
    """

    # Performance budget (seconds) for one forward call.  Named rather than
    # written inline so the threshold is documented and easy to tune for
    # slower CI environments (review feedback: avoid the magic number 0.1).
    MAX_AVG_TIME_S = 0.1

    @pytest.fixture
    def test_data(self):
        """Create random CUDA inputs shared by the tests."""
        num_seq = 16
        hidden_size = 64
        num_topk = 2
        num_expert = 4
        rank_ep = 0

        x = torch.randn(num_seq, hidden_size, device="cuda")
        topk_ids = torch.randint(
            0, num_expert, (num_seq, num_topk), device="cuda", dtype=torch.int32)

        return x, topk_ids, num_expert, rank_ep, num_seq, hidden_size, num_topk

    def test_basic_functionality(self, test_data):
        """Outputs must satisfy shape invariants and match the reference."""
        x, topk_ids, num_expert, rank_ep, num_seq, hidden_size, num_topk = test_data

        op = CountAndGatherOp(num_expert=num_expert, rank_ep=rank_ep)
        gate_up_input, topk_pos, seqlens, cu_seqlens, tiles = op.forward(x, topk_ids)
        ref_gate_up_input, ref_topk_pos, ref_seqlens, ref_cu_seqlens, ref_tiles = _count_and_gather_reference(
            x=x,
            topk_ids=topk_ids,
            num_expert=num_expert,
            rank_ep=rank_ep,
            tile_m=op.config["tile_m"],
        )

        # Shape invariants.
        assert gate_up_input.shape[0] == seqlens.sum().item()
        assert gate_up_input.shape[1] == hidden_size
        assert topk_pos.shape == (num_seq * num_topk,)
        assert seqlens.shape == (num_expert,)
        assert cu_seqlens.shape == (num_expert + 1,)
        assert tiles.shape == (num_expert,)

        # cu_seqlens must be the exclusive prefix sum of seqlens.
        for i in range(1, num_expert + 1):
            assert cu_seqlens[i].item() == cu_seqlens[i - 1].item() + seqlens[i - 1].item()

        # Every position is either -1 (expert owned by another rank) or a
        # valid row index into the gathered buffer.
        for pos in topk_pos:
            assert pos.item() >= -1
            if pos.item() != -1:
                assert pos.item() < gate_up_input.shape[0]

        # Exact agreement with the reference implementation.
        assert torch.equal(seqlens, ref_seqlens)
        assert torch.equal(cu_seqlens, ref_cu_seqlens)
        assert torch.equal(tiles, ref_tiles)
        assert torch.equal(topk_pos, ref_topk_pos)
        assert torch.equal(gate_up_input, ref_gate_up_input)

    def test_performance(self, test_data):
        """Coarse performance smoke test for CountAndGatherOp.forward."""
        x, topk_ids, num_expert, rank_ep, num_seq, _, _ = test_data

        op = CountAndGatherOp(num_expert=num_expert, rank_ep=rank_ep)
        num_iterations = 10

        # Warm up once (kernel compilation / autotuning caches) and drain any
        # previously queued GPU work so the timer measures only steady-state
        # iterations.
        op.forward(x, topk_ids)
        torch.cuda.synchronize()

        start_time = time.time()
        for _ in range(num_iterations):
            op.forward(x, topk_ids)
        # Kernel launches are asynchronous; synchronize before reading the
        # clock so all timed work has actually finished.
        torch.cuda.synchronize()
        end_time = time.time()

        avg_time = (end_time - start_time) / num_iterations
        throughput = num_seq / avg_time

        print("Performance:")
        print(f"Average time per iteration: {avg_time * 1000:.2f} ms")
        print(f"Throughput: {throughput:.2f} sequences/sec")

        assert avg_time < self.MAX_AVG_TIME_S
|
|
||
# Allow running this test module directly (outside a plain pytest invocation).
if __name__ == "__main__":
    pytest.main([__file__, "-vvs"])
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,98 @@ | ||
| import pytest | ||
| import torch | ||
|
|
||
| from top.ops import MoeReduceOp | ||
|
|
||
|
|
||
| def _moe_reduce_reference( | ||
| x: torch.Tensor, | ||
| topk_pos: torch.Tensor, | ||
| topk_scale: torch.Tensor, | ||
| shared_output: torch.Tensor | None = None, | ||
| ) -> torch.Tensor: | ||
| num_seq, num_topk = topk_pos.shape | ||
| num_tokens, hidden_size = x.shape | ||
|
|
||
| x_fp32 = x.to(torch.float32) | ||
| topk_scale_fp32 = topk_scale.to(torch.float32) | ||
| y_accum = torch.zeros((num_seq, hidden_size), dtype=torch.float32, device=x.device) | ||
|
|
||
| for i in range(num_topk): | ||
| pos = topk_pos[:, i] | ||
| scale = topk_scale_fp32[:, i].unsqueeze(1) | ||
| mask = (pos >= 0) & (pos < num_tokens) | ||
| valid_pos = pos[mask] | ||
| valid_scale = scale[mask] | ||
| valid_seq = torch.where(mask)[0] | ||
|
|
||
| if valid_pos.numel() > 0: | ||
| expert_outputs = x_fp32[valid_pos] | ||
| weighted_outputs = expert_outputs * valid_scale | ||
| y_accum.index_add_(0, valid_seq, weighted_outputs) | ||
|
|
||
| if shared_output is not None: | ||
| y_accum += shared_output.to(torch.float32) | ||
|
|
||
| return y_accum.to(x.dtype) | ||
|
|
||
|
|
||
@pytest.mark.parametrize(
    "num_seq, num_topk, hidden_size",
    [
        (16, 2, 64),
        (32, 1, 128),
        (64, 4, 256),
    ],
)
def test_moe_reduce_op_basic(num_seq: int, num_topk: int, hidden_size: int) -> None:
    """MoeReduceOp must match the fp32 reference across several problem sizes."""
    num_tokens = num_seq * num_topk

    x = torch.randn(num_tokens, hidden_size, device="cuda", dtype=torch.float16)
    topk_pos = torch.randint(0, num_tokens, (num_seq, num_topk), device="cuda", dtype=torch.int32)
    topk_scale = torch.randn(num_seq, num_topk, device="cuda", dtype=torch.float32)

    actual = MoeReduceOp().forward(x, topk_pos, topk_scale)
    expected = _moe_reduce_reference(x, topk_pos, topk_scale)

    assert actual.shape == (num_seq, hidden_size)
    assert actual.dtype == x.dtype
    assert torch.allclose(actual, expected, atol=1e-3, rtol=1e-3)
|
|
||
|
|
||
def test_moe_reduce_op_with_shared_output() -> None:
    """MoeReduceOp must add the shared-expert output when it is provided."""
    num_seq, num_topk, hidden_size = 16, 2, 64
    num_tokens = num_seq * num_topk

    x = torch.randn(num_tokens, hidden_size, device="cuda", dtype=torch.float16)
    topk_pos = torch.randint(0, num_tokens, (num_seq, num_topk), device="cuda", dtype=torch.int32)
    topk_scale = torch.randn(num_seq, num_topk, device="cuda", dtype=torch.float32)
    shared_output = torch.randn(num_seq, hidden_size, device="cuda", dtype=torch.float16)

    actual = MoeReduceOp().forward(x, topk_pos, topk_scale, shared_output)
    expected = _moe_reduce_reference(x, topk_pos, topk_scale, shared_output)

    assert actual.shape == (num_seq, hidden_size)
    assert actual.dtype == x.dtype
    assert torch.allclose(actual, expected, atol=1e-3, rtol=1e-3)
|
|
||
|
|
||
def test_moe_reduce_op_invalid_shape() -> None:
    """A topk_scale whose shape disagrees with topk_pos must be rejected."""
    num_seq, num_topk, hidden_size = 16, 2, 64
    num_tokens = num_seq * num_topk

    x = torch.randn(num_tokens, hidden_size, device="cuda", dtype=torch.float16)
    topk_pos = torch.randint(0, num_tokens, (num_seq, num_topk), device="cuda", dtype=torch.int32)
    # Deliberately one column too wide relative to topk_pos.
    invalid_topk_scale = torch.randn(num_seq, num_topk + 1, device="cuda", dtype=torch.float32)

    with pytest.raises(ValueError):
        MoeReduceOp().forward(x, topk_pos, invalid_topk_scale)
|
|
||
|
|
||
# Allow running this test module directly (outside a plain pytest invocation).
if __name__ == "__main__":
    pytest.main([__file__, "-vvs"])
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,135 @@ | ||
| from typing import Optional, Tuple | ||
| import torch | ||
|
|
||
| from top.kernels.fuse_moe.count_and_gather import CountAndGatherKernel | ||
| from top.kernels.fuse_moe.moe_reduce import MoeReduceKernel | ||
|
|
||
|
|
||
def count_and_gather(
    x: torch.Tensor,  # [num_seq, hidden_size]
    topk_ids: torch.Tensor,  # [num_seq, num_topk]
    num_expert: int,
    rank_ep: int = 0,
    tile_m: int = 16,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Count and gather tokens by expert.

    Args:
        x: Input token features
        topk_ids: Expert assignment for each token
        num_expert: Number of experts
        rank_ep: Expert parallel rank
        tile_m: Tile size for M dimension

    Returns:
        gate_up_input: Gathered input for gate and up projection
        topk_pos: Position mapping for each token
        seqlens: Number of tokens per expert
        cu_seqlens: Cumulative sequence lengths
        tiles: Number of tiles per expert

    Raises:
        ValueError: If a tensor has the wrong rank, the batch sizes of ``x``
            and ``topk_ids`` disagree, or a scalar parameter is out of range.
    """
    # Input validation
    if x.ndim != 2:
        raise ValueError(f"Expected x to be 2D tensor, got {x.ndim}D")

    if topk_ids.ndim != 2:
        raise ValueError(f"Expected topk_ids to be 2D tensor, got {topk_ids.ndim}D")

    if x.shape[0] != topk_ids.shape[0]:
        raise ValueError(
            f"Mismatched batch size: x has {x.shape[0]}, topk_ids has {topk_ids.shape[0]}")

    # Scalar-parameter validation: fail here with a clear message instead of
    # passing a degenerate configuration down into the kernel.
    if num_expert < 1:
        raise ValueError(f"Expected num_expert >= 1, got {num_expert}")

    if rank_ep < 0:
        raise ValueError(f"Expected rank_ep >= 0, got {rank_ep}")

    if tile_m < 1:
        raise ValueError(f"Expected tile_m >= 1, got {tile_m}")

    num_seq, hidden_size = x.shape
    num_topk = topk_ids.shape[1]
    kernel = CountAndGatherKernel(
        num_seq=num_seq,
        hidden_size=hidden_size,
        num_topk=num_topk,
        num_expert=num_expert,
        config={"tile_m": tile_m},
    )
    return kernel.count_and_gather(x, topk_ids, rank_ep)
|
Comment on lines
+44
to
+51
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The
zhen8838 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
|
|
||
def reduce(
    x: torch.Tensor,  # [total_tokens, hidden_size]
    topk_pos: torch.Tensor,  # [num_seq, num_topk]
    topk_scale: torch.Tensor,  # [num_seq, num_topk]
    shared_output: Optional[torch.Tensor] = None,  # [num_seq, hidden_size]
) -> torch.Tensor:
    """Reduce (scatter-add) expert-output tokens back to per-sequence rows.

    Args:
        x: Expert output tokens
        topk_pos: Position mapping produced by count_and_gather
        topk_scale: Scaling factors for each token
        shared_output: Optional shared-expert output added to the result

    Returns:
        output: Reduced output tensor [num_seq, hidden_size]

    Raises:
        ValueError: If tensor ranks or shapes are inconsistent.
    """
    # Rank checks, in declaration order so the first offending argument wins.
    for name, tensor in (("x", x), ("topk_pos", topk_pos), ("topk_scale", topk_scale)):
        if tensor.ndim != 2:
            raise ValueError(f"Expected {name} to be 2D tensor, got {tensor.ndim}D")

    if topk_pos.shape != topk_scale.shape:
        raise ValueError(
            f"Mismatched shape between topk_pos and topk_scale: {topk_pos.shape} vs {topk_scale.shape}"
        )

    # shared_output, when given, must align with both topk_pos (batch) and x
    # (hidden size).
    if shared_output is not None:
        if shared_output.ndim != 2:
            raise ValueError(f"Expected shared_output to be 2D tensor, got {shared_output.ndim}D")

        if shared_output.shape[0] != topk_pos.shape[0]:
            raise ValueError(
                f"Mismatched batch size: shared_output has {shared_output.shape[0]}, topk_pos has {topk_pos.shape[0]}"
            )

        if shared_output.shape[1] != x.shape[1]:
            raise ValueError(
                f"Mismatched hidden size: shared_output has {shared_output.shape[1]}, x has {x.shape[1]}"
            )

    num_seq, num_topk = topk_pos.shape
    hidden_size = x.shape[1]
    kernel = MoeReduceKernel(
        num_seq=num_seq,
        hidden_size=hidden_size,
        num_topk=num_topk,
    )
    return kernel.forward(x, topk_pos, topk_scale, shared_output)
|
Comment on lines
+102
to
+107
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Similar to |
||
|
|
||
|
|
||
def fuse_moe_pertensor_fp8(
    x: torch.Tensor,
    gate_up_weight: torch.Tensor,
    down_weight: torch.Tensor,
    topk_ids: torch.Tensor,
    topk_scale: torch.Tensor,
    num_expert: int,
    rank_ep: int = 0,
) -> torch.Tensor:
    """Fused MoE forward pass with per-tensor FP8 quantization.

    Args:
        x: Input token features
        gate_up_weight: Gate and up projection weights
        down_weight: Down projection weights
        topk_ids: Expert assignment for each token
        topk_scale: Scaling factors for each token
        num_expert: Number of experts
        rank_ep: Expert parallel rank

    Returns:
        output: Fused MoE output

    Raises:
        NotImplementedError: Always — this entry point is a placeholder that
            will be filled in by a follow-up change.
    """
    raise NotImplementedError("fuse_moe_pertensor_fp8 function not implemented yet")
zhen8838 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,7 @@ | ||
# Public kernel exports for the fused-MoE kernel package.
from .count_and_gather import CountAndGatherKernel
from .moe_reduce import MoeReduceKernel

# Explicit public API of this package.
__all__ = [
    "CountAndGatherKernel",
    "MoeReduceKernel",
]
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The performance assertion
`assert avg_time < 0.1` uses a magic number, `0.1`. Hardcoded thresholds can make tests brittle and prone to failure across different environments or as the code evolves. It is generally better to use more robust performance checks, such as relative thresholds or named constants, to improve maintainability.