Skip to content

Commit ea908c0

Browse files
author
tanqingshan (A)
committed
add ut for cumsum_group_list
Signed-off-by: tanqingshan (A) <[email protected]>
1 parent 0528d70 commit ea908c0

File tree

1 file changed

+165
-0
lines changed

1 file changed

+165
-0
lines changed
Lines changed: 165 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,165 @@
1+
import pytest
2+
import torch
3+
4+
from vllm_ascend.ops.fused_moe.moe_mlp import cumsum_group_list
5+
6+
7+
# Test configuration: Cover all supported type conversion combinations
8+
@pytest.mark.parametrize(
    "src_type, dst_type, input_tensor, kwargs, expected_output",
    [
        # 1. Same source and destination type (0 -> 0): pass-through.
        (
            0,
            0,
            torch.tensor([1, 3, 5, 7]),
            {},
            torch.tensor([1, 3, 5, 7]),
        ),
        # 2. Same source and destination type (1 -> 1): pass-through.
        (
            1,
            1,
            torch.tensor([2, 4, 6]),
            {},
            torch.tensor([2, 4, 6]),
        ),
        # 3. Same source and destination type (2 -> 2): pass-through.
        (
            2,
            2,
            torch.tensor([[0, 2], [2, 3], [5, 1]]),
            {},
            torch.tensor([[0, 2], [2, 3], [5, 1]]),
        ),
        # 4. 1 -> 0: per-group counts converted to a running cumsum.
        (
            1,
            0,
            torch.tensor([2, 1, 3, 4]),
            {},
            torch.tensor([2, 3, 6, 10]),
        ),
        # 5. 0 -> 1: cumsum converted back to per-group counts (difference).
        (
            0,
            1,
            torch.tensor([2, 3, 6, 10]),
            {},
            torch.tensor([2, 1, 3, 4]),
        ),
        # 6. 2 -> 0: (expert_id, token_count) pairs scattered over expert_num
        #    slots; untouched slots keep the `active_num` fill value.
        (
            2,
            0,
            torch.tensor([[0, 2], [2, 3], [5, 1]]),
            {
                "active_num": 0,
                "expert_num": 6
            },
            torch.tensor([2, 0, 3, 0, 0, 1]),
        ),
        # 7. 2 -> 0: sentinel fill value (-1) marks unused expert slots.
        (
            2,
            0,
            torch.tensor([[1, 5], [3, 2], [4, 4]]),
            {
                "active_num": -1,
                "expert_num": 5
            },
            torch.tensor([-1, 5, -1, 2, 4]),
        ),
        # 8. 2 -> 0: single expert.
        (
            2,
            0,
            torch.tensor([[0, 10]]),
            {
                "active_num": 5,
                "expert_num": 1
            },
            torch.tensor([10]),
        ),
    ],
)
def test_cumsum_group_list_valid_cases(src_type, dst_type, input_tensor,
                                       kwargs, expected_output):
    """Valid type conversions produce the expected group list.

    Covers all implemented (src, dst) combinations: identity conversions,
    1 -> 0 (counts to cumsum), 0 -> 1 (cumsum to counts) and 2 -> 0
    (expert/count pairs to a dense per-expert list).
    """
    result = cumsum_group_list(input_tensor, src_type, dst_type, **kwargs)
    # Verify result shape and values.
    assert result.shape == expected_output.shape
    # All expectations are exact integer tensors: use torch.equal rather
    # than allclose so a tolerance cannot mask an off-by-one error.
    assert torch.equal(result, expected_output)
93+
94+
95+
def test_cumsum_group_list_invalid_src_type():
    """An out-of-range source list type raises a descriptive ValueError."""
    bad_src_type = 3
    with pytest.raises(ValueError) as err:
        cumsum_group_list(torch.tensor([1, 2, 3]),
                          src_list_type=bad_src_type,
                          dst_list_type=0)
    message = str(err.value)
    assert "group_list_type should be in [0, 1, 2], but received 3" in message
102+
103+
104+
def test_cumsum_group_list_unimplemented_conversion():
    """Unimplemented (src, dst) conversions raise NotImplementedError."""
    vec = torch.tensor([1, 2, 3])

    # 0 -> 2 is unimplemented and should report the offending pair.
    with pytest.raises(NotImplementedError) as err:
        cumsum_group_list(vec, src_list_type=0, dst_list_type=2)
    assert ("Conversion from src_list_type=0 to dst_list_type=2 "
            "is not implemented yet") in str(err.value)

    # 1 -> 2 is unimplemented.
    with pytest.raises(NotImplementedError):
        cumsum_group_list(vec, src_list_type=1, dst_list_type=2)

    # 2 -> 1 is unimplemented (type-2 input is a 2-D pair tensor).
    pairs = torch.tensor([[0, 1], [2, 3]])
    with pytest.raises(NotImplementedError):
        cumsum_group_list(pairs, src_list_type=2, dst_list_type=1)
121+
122+
123+
def test_cumsum_group_list_edge_cases():
    """Boundary inputs: empty tensors and single-element tensors."""
    # 1 -> 0 on an empty tensor is a no-op.
    empty = torch.tensor([], dtype=torch.int64)
    assert torch.equal(
        cumsum_group_list(empty, src_list_type=1, dst_list_type=0), empty)

    # 0 -> 1 on a single element returns it unchanged.
    single = torch.tensor([5])
    assert torch.equal(
        cumsum_group_list(single, src_list_type=0, dst_list_type=1),
        torch.tensor([5]))

    # 2 -> 0 with no (expert, count) pairs: every expert slot gets the
    # active_num fill value.
    no_pairs = torch.tensor([], dtype=torch.int64).reshape(0, 2)
    dense = cumsum_group_list(no_pairs,
                              src_list_type=2,
                              dst_list_type=0,
                              active_num=0,
                              expert_num=3)
    assert torch.equal(dense, torch.tensor([0, 0, 0]))
143+
144+
145+
def test_cumsum_group_list_dtype_device_consistency():
    """Output dtype and device follow the input tensor."""
    # Run on GPU when available, otherwise CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    input_tensor = torch.tensor([[1, 2], [3, 4]],
                                dtype=torch.float32,
                                device=device)
    result = cumsum_group_list(
        input_tensor,
        src_list_type=2,
        dst_list_type=0,
        active_num=0.0,
        expert_num=4,
    )
    assert result.dtype == torch.float32
    # Compare device *types*, not device objects: torch.device("cuda") is
    # not equal to torch.device("cuda:0"), so a direct equality check would
    # spuriously fail whenever the result carries a device index.
    assert result.device.type == device.type

    # 0 -> 1 preserves an integer dtype.
    input_int = torch.tensor([2, 4, 6], dtype=torch.int64)
    result_int = cumsum_group_list(input_int, src_list_type=0, dst_list_type=1)
    assert result_int.dtype == torch.int64

0 commit comments

Comments
 (0)