Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 130 additions & 0 deletions tests/ops/test_count_and_gather.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
import time
import pytest
import torch

from top.ops import CountAndGatherOp


def _count_and_gather_reference(
x: torch.Tensor,
topk_ids: torch.Tensor,
num_expert: int,
rank_ep: int,
tile_m: int,
):
num_seq, hidden_size = x.shape
num_topk = topk_ids.shape[1]
total_num_topk = num_seq * num_topk

start_expert = rank_ep * num_expert
end_expert = (rank_ep + 1) * num_expert

seqlens = torch.zeros(num_expert, dtype=torch.int32, device=x.device)
topk_pos = torch.full((total_num_topk,), -1, dtype=torch.int32, device=x.device)
cu_seqlens = torch.zeros(num_expert + 1, dtype=torch.int32, device=x.device)
tiles = torch.zeros(num_expert, dtype=torch.int32, device=x.device)

topk_ids_flat = topk_ids.reshape(-1)
for idx in range(total_num_topk):
iexpert = topk_ids_flat[idx]
if iexpert >= start_expert and iexpert < end_expert:
seqlens[iexpert - start_expert] += 1

for i in range(num_expert):
cu_seqlens[i + 1] = cu_seqlens[i] + seqlens[i]
tiles[i] = (seqlens[i] + tile_m - 1) // tile_m

total_tokens = cu_seqlens[-1].item()
gate_up_input = torch.zeros(total_tokens, hidden_size, dtype=x.dtype, device=x.device)

running = torch.zeros_like(seqlens)
for idx in range(total_num_topk):
iexpert = topk_ids_flat[idx]
if iexpert >= start_expert and iexpert < end_expert:
expert_idx = iexpert - start_expert
abs_pos = cu_seqlens[expert_idx] + running[expert_idx]
topk_pos[idx] = abs_pos
gate_up_input[abs_pos] = x[idx // num_topk]
running[expert_idx] += 1

return gate_up_input, topk_pos, seqlens, cu_seqlens, tiles


class TestCountAndGatherOp:
    """Test CountAndGatherOp with pytest."""

    # Latency budget for the perf test, in seconds per forward call.
    # A named constant (instead of an inline magic number) keeps the
    # threshold discoverable and easy to tune per environment.
    PERF_THRESHOLD_SECONDS = 0.1

    @pytest.fixture
    def test_data(self):
        """Create test data for CountAndGatherOp."""
        num_seq = 16
        hidden_size = 64
        num_topk = 2
        num_expert = 4
        rank_ep = 0

        x = torch.randn(num_seq, hidden_size, device="cuda")
        topk_ids = torch.randint(
            0, num_expert, (num_seq, num_topk), device="cuda", dtype=torch.int32)

        return x, topk_ids, num_expert, rank_ep, num_seq, hidden_size, num_topk

    def test_basic_functionality(self, test_data):
        """Test basic functionality of CountAndGatherOp against the reference."""
        x, topk_ids, num_expert, rank_ep, num_seq, hidden_size, num_topk = test_data

        op = CountAndGatherOp(num_expert=num_expert, rank_ep=rank_ep)
        gate_up_input, topk_pos, seqlens, cu_seqlens, tiles = op.forward(x, topk_ids)
        ref_gate_up_input, ref_topk_pos, ref_seqlens, ref_cu_seqlens, ref_tiles = _count_and_gather_reference(
            x=x,
            topk_ids=topk_ids,
            num_expert=num_expert,
            rank_ep=rank_ep,
            tile_m=op.config["tile_m"],
        )

        # Shape checks.
        assert gate_up_input.shape[0] == seqlens.sum().item()
        assert gate_up_input.shape[1] == hidden_size
        assert topk_pos.shape == (num_seq * num_topk,)
        assert seqlens.shape == (num_expert,)
        assert cu_seqlens.shape == (num_expert + 1,)
        assert tiles.shape == (num_expert,)

        # cu_seqlens must be the prefix sum of seqlens.
        for i in range(1, num_expert + 1):
            assert cu_seqlens[i].item() == cu_seqlens[i - 1].item() + seqlens[i - 1].item()

        # Every mapped position is either -1 (not routed to this rank) or in range.
        for pos in topk_pos:
            assert pos.item() >= -1
            if pos.item() != -1:
                assert pos.item() < gate_up_input.shape[0]

        # Exact agreement with the reference implementation.
        assert torch.equal(seqlens, ref_seqlens)
        assert torch.equal(cu_seqlens, ref_cu_seqlens)
        assert torch.equal(tiles, ref_tiles)
        assert torch.equal(topk_pos, ref_topk_pos)
        assert torch.equal(gate_up_input, ref_gate_up_input)

    def test_performance(self, test_data):
        """Test performance of CountAndGatherOp."""
        x, topk_ids, num_expert, rank_ep, num_seq, _, _ = test_data

        op = CountAndGatherOp(num_expert=num_expert, rank_ep=rank_ep)

        # Warm up once so JIT compilation / first-call setup is not billed
        # to the timed loop, and synchronize before starting the clock.
        op.forward(x, topk_ids)
        torch.cuda.synchronize()

        num_iterations = 10
        start_time = time.time()
        for _ in range(num_iterations):
            op.forward(x, topk_ids)
        # Single sync after the loop so we time the whole batch of launches.
        torch.cuda.synchronize()
        end_time = time.time()

        avg_time = (end_time - start_time) / num_iterations
        throughput = num_seq / avg_time

        print("Performance:")
        print(f"Average time per iteration: {avg_time * 1000:.2f} ms")
        print(f"Throughput: {throughput:.2f} sequences/sec")

        assert avg_time < self.PERF_THRESHOLD_SECONDS
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The performance assertion assert avg_time < 0.1 uses a magic number 0.1. Hardcoded thresholds can make tests brittle and prone to failure across different environments or as code evolves. It's generally better to use more robust performance checks, such as relative thresholds or named constants to improve maintainability.

Suggested change
assert avg_time < 0.1
PERFORMANCE_THRESHOLD_MS = 100 # Example: 100ms
assert avg_time * 1000 < PERFORMANCE_THRESHOLD_MS



if __name__ == "__main__":
    # Allow running this test module directly: executes pytest on this file
    # with very-verbose output and capturing disabled.
    pytest.main([__file__, "-vvs"])
98 changes: 98 additions & 0 deletions tests/ops/test_moe_reduce.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
import pytest
import torch

from top.ops import MoeReduceOp


def _moe_reduce_reference(
x: torch.Tensor,
topk_pos: torch.Tensor,
topk_scale: torch.Tensor,
shared_output: torch.Tensor | None = None,
) -> torch.Tensor:
num_seq, num_topk = topk_pos.shape
num_tokens, hidden_size = x.shape

x_fp32 = x.to(torch.float32)
topk_scale_fp32 = topk_scale.to(torch.float32)
y_accum = torch.zeros((num_seq, hidden_size), dtype=torch.float32, device=x.device)

for i in range(num_topk):
pos = topk_pos[:, i]
scale = topk_scale_fp32[:, i].unsqueeze(1)
mask = (pos >= 0) & (pos < num_tokens)
valid_pos = pos[mask]
valid_scale = scale[mask]
valid_seq = torch.where(mask)[0]

if valid_pos.numel() > 0:
expert_outputs = x_fp32[valid_pos]
weighted_outputs = expert_outputs * valid_scale
y_accum.index_add_(0, valid_seq, weighted_outputs)

if shared_output is not None:
y_accum += shared_output.to(torch.float32)

return y_accum.to(x.dtype)


@pytest.mark.parametrize(
    "num_seq, num_topk, hidden_size",
    [
        (16, 2, 64),
        (32, 1, 128),
        (64, 4, 256),
    ],
)
def test_moe_reduce_op_basic(num_seq: int, num_topk: int, hidden_size: int) -> None:
    """MoeReduceOp output matches the reference within fp16 tolerance."""
    num_tokens = num_seq * num_topk

    x = torch.randn(num_tokens, hidden_size, device="cuda", dtype=torch.float16)
    topk_pos = torch.randint(
        0, num_tokens, (num_seq, num_topk), device="cuda", dtype=torch.int32)
    topk_scale = torch.randn(num_seq, num_topk, device="cuda", dtype=torch.float32)

    result = MoeReduceOp().forward(x, topk_pos, topk_scale)
    expected = _moe_reduce_reference(x, topk_pos, topk_scale)

    assert result.shape == (num_seq, hidden_size)
    assert result.dtype == x.dtype
    assert torch.allclose(result, expected, atol=1e-3, rtol=1e-3)


def test_moe_reduce_op_with_shared_output() -> None:
    """MoeReduceOp adds the shared-expert output when one is provided."""
    num_seq, num_topk, hidden_size = 16, 2, 64
    num_tokens = num_seq * num_topk

    x = torch.randn(num_tokens, hidden_size, device="cuda", dtype=torch.float16)
    topk_pos = torch.randint(
        0, num_tokens, (num_seq, num_topk), device="cuda", dtype=torch.int32)
    topk_scale = torch.randn(num_seq, num_topk, device="cuda", dtype=torch.float32)
    shared = torch.randn(num_seq, hidden_size, device="cuda", dtype=torch.float16)

    result = MoeReduceOp().forward(x, topk_pos, topk_scale, shared)
    expected = _moe_reduce_reference(x, topk_pos, topk_scale, shared)

    assert result.shape == (num_seq, hidden_size)
    assert result.dtype == x.dtype
    assert torch.allclose(result, expected, atol=1e-3, rtol=1e-3)


def test_moe_reduce_op_invalid_shape() -> None:
    """A topk_scale whose shape disagrees with topk_pos raises ValueError."""
    num_seq, num_topk, hidden_size = 16, 2, 64
    num_tokens = num_seq * num_topk

    x = torch.randn(num_tokens, hidden_size, device="cuda", dtype=torch.float16)
    topk_pos = torch.randint(
        0, num_tokens, (num_seq, num_topk), device="cuda", dtype=torch.int32)
    bad_scale = torch.randn(num_seq, num_topk + 1, device="cuda", dtype=torch.float32)

    with pytest.raises(ValueError):
        MoeReduceOp().forward(x, topk_pos, bad_scale)


if __name__ == "__main__":
    # Allow running this test module directly: executes pytest on this file
    # with very-verbose output and capturing disabled.
    pytest.main([__file__, "-vvs"])
135 changes: 135 additions & 0 deletions top/functions/fuse_moe.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
from functools import lru_cache
from typing import Optional, Tuple

import torch

from top.kernels.fuse_moe.count_and_gather import CountAndGatherKernel
from top.kernels.fuse_moe.moe_reduce import MoeReduceKernel


@lru_cache(maxsize=128)
def _get_count_and_gather_kernel(
    num_seq: int,
    hidden_size: int,
    num_topk: int,
    num_expert: int,
    tile_m: int,
) -> CountAndGatherKernel:
    """Return a cached CountAndGatherKernel for the given problem shape.

    Kernel construction may trigger JIT compilation, so instances are cached
    per (shape, config) key and reused across calls instead of being rebuilt
    on every invocation of count_and_gather. The cache is bounded so an
    unbounded variety of shapes cannot grow it without limit.
    """
    return CountAndGatherKernel(
        num_seq=num_seq,
        hidden_size=hidden_size,
        num_topk=num_topk,
        num_expert=num_expert,
        config={"tile_m": tile_m},
    )


def count_and_gather(
    x: torch.Tensor,  # [num_seq, hidden_size]
    topk_ids: torch.Tensor,  # [num_seq, num_topk]
    num_expert: int,
    rank_ep: int = 0,
    tile_m: int = 16,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Count and gather tokens by expert.

    Args:
        x: Input token features
        topk_ids: Expert assignment for each token
        num_expert: Number of experts
        rank_ep: Expert parallel rank
        tile_m: Tile size for M dimension

    Returns:
        gate_up_input: Gathered input for gate and up projection
        topk_pos: Position mapping for each token
        seqlens: Number of tokens per expert
        cu_seqlens: Cumulative sequence lengths
        tiles: Number of tiles per expert

    Raises:
        ValueError: If x or topk_ids is not 2D, or their batch sizes differ.
    """
    # Input validation
    if x.ndim != 2:
        raise ValueError(f"Expected x to be 2D tensor, got {x.ndim}D")

    if topk_ids.ndim != 2:
        raise ValueError(f"Expected topk_ids to be 2D tensor, got {topk_ids.ndim}D")

    if x.shape[0] != topk_ids.shape[0]:
        raise ValueError(
            f"Mismatched batch size: x has {x.shape[0]}, topk_ids has {topk_ids.shape[0]}")

    num_seq, hidden_size = x.shape
    num_topk = topk_ids.shape[1]
    # Reuse a cached kernel to avoid per-call construction/recompilation
    # overhead; rank_ep is a runtime argument, not part of the kernel shape.
    kernel = _get_count_and_gather_kernel(
        num_seq, hidden_size, num_topk, num_expert, tile_m)
    return kernel.count_and_gather(x, topk_ids, rank_ep)
Comment on lines +44 to +51
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The count_and_gather function instantiates a new CountAndGatherKernel on every call. This is highly inefficient as it leads to repeated kernel initialization and potential JIT recompilation overhead, especially if this function is called frequently. Kernels, especially JIT-compiled ones, should ideally be instantiated once and reused to avoid this overhead. Consider refactoring to accept a pre-initialized kernel or Op instance, or to manage the kernel's lifecycle more effectively.



@lru_cache(maxsize=128)
def _get_moe_reduce_kernel(
    num_seq: int,
    hidden_size: int,
    num_topk: int,
) -> MoeReduceKernel:
    """Return a cached MoeReduceKernel for the given problem shape.

    Kernel construction may trigger JIT compilation, so instances are cached
    per shape and reused across calls instead of being rebuilt on every
    invocation of reduce. The cache is bounded to cap memory growth.
    """
    return MoeReduceKernel(
        num_seq=num_seq,
        hidden_size=hidden_size,
        num_topk=num_topk,
    )


def reduce(
    x: torch.Tensor,  # [total_tokens, hidden_size]
    topk_pos: torch.Tensor,  # [num_seq, num_topk]
    topk_scale: torch.Tensor,  # [num_seq, num_topk]
    shared_output: Optional[torch.Tensor] = None,  # [num_seq, hidden_size]
) -> torch.Tensor:
    """Reduce (scatter-add) tokens back to original positions.

    Args:
        x: Expert output tokens
        topk_pos: Position mapping from count_and_gather
        topk_scale: Scaling factors for each token
        shared_output: Optional shared output tensor

    Returns:
        output: Reduced output tensor [num_seq, hidden_size]

    Raises:
        ValueError: On non-2D inputs or mismatched shapes between the
            arguments (see individual checks below).
    """
    # Input validation
    if x.ndim != 2:
        raise ValueError(f"Expected x to be 2D tensor, got {x.ndim}D")

    if topk_pos.ndim != 2:
        raise ValueError(f"Expected topk_pos to be 2D tensor, got {topk_pos.ndim}D")

    if topk_scale.ndim != 2:
        raise ValueError(f"Expected topk_scale to be 2D tensor, got {topk_scale.ndim}D")

    if topk_pos.shape != topk_scale.shape:
        raise ValueError(
            f"Mismatched shape between topk_pos and topk_scale: {topk_pos.shape} vs {topk_scale.shape}"
        )

    if shared_output is not None:
        if shared_output.ndim != 2:
            raise ValueError(f"Expected shared_output to be 2D tensor, got {shared_output.ndim}D")

        if shared_output.shape[0] != topk_pos.shape[0]:
            raise ValueError(
                f"Mismatched batch size: shared_output has {shared_output.shape[0]}, topk_pos has {topk_pos.shape[0]}"
            )

        if shared_output.shape[1] != x.shape[1]:
            raise ValueError(
                f"Mismatched hidden size: shared_output has {shared_output.shape[1]}, x has {x.shape[1]}"
            )

    num_seq, num_topk = topk_pos.shape
    _, hidden_size = x.shape
    # Reuse a cached kernel to avoid per-call construction/recompilation.
    kernel = _get_moe_reduce_kernel(num_seq, hidden_size, num_topk)
    return kernel.forward(x, topk_pos, topk_scale, shared_output)
Comment on lines +102 to +107
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

Similar to count_and_gather, the reduce function instantiates a new MoeReduceKernel on every call. This will cause repeated kernel initialization and potential JIT recompilation, which is inefficient. For optimal performance, kernels should be instantiated once and reused across multiple calls. Consider passing a pre-initialized kernel or Op instance to this function.



def fuse_moe_pertensor_fp8(
    x: torch.Tensor,
    gate_up_weight: torch.Tensor,
    down_weight: torch.Tensor,
    topk_ids: torch.Tensor,
    topk_scale: torch.Tensor,
    num_expert: int,
    rank_ep: int = 0,
) -> torch.Tensor:
    """Fused MoE with per-tensor FP8 quantization.

    Args:
        x: Input token features
        gate_up_weight: Gate and up projection weights
        down_weight: Down projection weights
        topk_ids: Expert assignment for each token
        topk_scale: Scaling factors for each token
        num_expert: Number of experts
        rank_ep: Expert parallel rank

    Returns:
        output: Fused MoE output

    Raises:
        NotImplementedError: Always; this entry point is a placeholder.
    """
    # TODO: Implement fuse_moe_pertensor_fp8 function
    # This will be implemented in a separate step
    raise NotImplementedError("fuse_moe_pertensor_fp8 function not implemented yet")
7 changes: 7 additions & 0 deletions top/kernels/fuse_moe/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from .count_and_gather import CountAndGatherKernel
from .moe_reduce import MoeReduceKernel

__all__ = [
"CountAndGatherKernel",
"MoeReduceKernel",
]
Loading
Loading