-
Notifications
You must be signed in to change notification settings - Fork 665
[UT]Ut for function cumsum_group_list in v11.0-dev (ref #5023) #5037
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
The head ref may contain hidden characters: "ut-for-bugfix\u2014dev"
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,165 @@ | ||
| import pytest | ||
| import torch | ||
|
|
||
| from vllm_ascend.ops.moe.moe_mlp import cumsum_group_list | ||
|
|
||
|
|
||
| # Test configuration: Cover all supported type conversion combinations | ||
| @pytest.mark.parametrize( | ||
| "src_type, dst_type, input_tensor, kwargs, expected_output", | ||
| [ | ||
| # 1. Same source and destination type (0→0) | ||
| ( | ||
| 0, | ||
| 0, | ||
| torch.tensor([1, 3, 5, 7]), | ||
| {}, | ||
| torch.tensor([1, 3, 5, 7]), | ||
| ), | ||
| # 2. Same source and destination type (1→1) | ||
| ( | ||
| 1, | ||
| 1, | ||
| torch.tensor([2, 4, 6]), | ||
| {}, | ||
| torch.tensor([2, 4, 6]), | ||
| ), | ||
| # 3. Same source and destination type (2→2) | ||
| ( | ||
| 2, | ||
| 2, | ||
| torch.tensor([[0, 2], [2, 3], [5, 1]]), | ||
| {}, | ||
| torch.tensor([[0, 2], [2, 3], [5, 1]]), | ||
| ), | ||
| # 4. 1→0 (cumsum conversion) | ||
| ( | ||
| 1, | ||
| 0, | ||
| torch.tensor([2, 1, 3, 4]), | ||
| {}, | ||
| torch.tensor([2, 3, 6, 10]), | ||
| ), | ||
| # 5. 0→1 (difference conversion) | ||
| ( | ||
| 0, | ||
| 1, | ||
| torch.tensor([2, 3, 6, 10]), | ||
| {}, | ||
| torch.tensor([2, 1, 3, 4]), | ||
| ), | ||
| # 6. 2→0 (expert-token mapping conversion) - Basic scenario | ||
| ( | ||
| 2, | ||
| 0, | ||
| torch.tensor([[0, 2], [2, 3], [5, 1]]), | ||
| { | ||
| "active_num": 0, | ||
| "expert_num": 6 | ||
| }, | ||
| torch.tensor([2, 0, 3, 0, 0, 1]), | ||
| ), | ||
| # 7. 2→0 - Edge scenario (no expert interval) | ||
| ( | ||
| 2, | ||
| 0, | ||
| torch.tensor([[1, 5], [3, 2], [4, 4]]), | ||
| { | ||
| "active_num": -1, | ||
| "expert_num": 5 | ||
| }, | ||
| torch.tensor([-1, 5, -1, 2, 4]), | ||
| ), | ||
| # 8. 2→0 - Single expert | ||
| ( | ||
| 2, | ||
| 0, | ||
| torch.tensor([[0, 10]]), | ||
| { | ||
| "active_num": 5, | ||
| "expert_num": 1 | ||
| }, | ||
| torch.tensor([10]), | ||
| ), | ||
| ], | ||
| ) | ||
| def test_cumsum_group_list_valid_cases(src_type, dst_type, input_tensor, | ||
| kwargs, expected_output): | ||
| """Test scenarios with valid type conversions""" | ||
| result = cumsum_group_list(input_tensor, src_type, dst_type, **kwargs) | ||
| # Verify result shape and values | ||
| assert result.shape == expected_output.shape | ||
| assert torch.allclose(result, expected_output) | ||
|
|
||
|
|
||
def test_cumsum_group_list_invalid_src_type():
    """A source type outside {0, 1, 2} must be rejected with ValueError."""
    tensor_in = torch.tensor([1, 2, 3])
    expected_msg = "group_list_type should be in [0, 1, 2], but received 3"
    with pytest.raises(ValueError) as exc:
        cumsum_group_list(tensor_in, src_list_type=3, dst_list_type=0)
    assert expected_msg in str(exc.value)
|
|
||
|
|
||
def test_cumsum_group_list_unimplemented_conversion():
    """Conversions with no implementation raise NotImplementedError."""
    vector_in = torch.tensor([1, 2, 3])

    # 0 -> 2: also check the error message content.
    with pytest.raises(NotImplementedError) as exc:
        cumsum_group_list(vector_in, src_list_type=0, dst_list_type=2)
    assert ("Conversion from src_list_type=0 to dst_list_type=2 "
            "is not implemented yet") in str(exc.value)

    # 1 -> 2 is likewise unimplemented.
    with pytest.raises(NotImplementedError):
        cumsum_group_list(vector_in, src_list_type=1, dst_list_type=2)

    # 2 -> 1 is likewise unimplemented (2 expects a 2-D pair tensor).
    matrix_in = torch.tensor([[0, 1], [2, 3]])
    with pytest.raises(NotImplementedError):
        cumsum_group_list(matrix_in, src_list_type=2, dst_list_type=1)
|
|
||
|
|
||
def test_cumsum_group_list_edge_cases():
    """Boundary inputs: empty tensors and a single-element list."""
    # 1 -> 0 on an empty count list stays empty.
    no_groups = torch.tensor([], dtype=torch.int64)
    assert torch.equal(
        cumsum_group_list(no_groups, src_list_type=1, dst_list_type=0),
        no_groups)

    # 0 -> 1 with a single element returns it unchanged.
    one_group = torch.tensor([5])
    assert torch.equal(
        cumsum_group_list(one_group, src_list_type=0, dst_list_type=1),
        torch.tensor([5]))

    # 2 -> 0 with zero (expert, count) rows fills every expert slot
    # with active_num (0 here).
    empty_pairs = torch.tensor([], dtype=torch.int64).reshape(0, 2)
    converted = cumsum_group_list(empty_pairs,
                                  src_list_type=2,
                                  dst_list_type=0,
                                  active_num=0,
                                  expert_num=3)
    assert torch.equal(converted, torch.tensor([0, 0, 0]))
|
|
||
|
|
||
def test_cumsum_group_list_dtype_device_consistency():
    """Output must keep the dtype and device of the input tensor."""
    # Exercise the accelerator path when one is available, otherwise CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    input_tensor = torch.tensor([[1, 2], [3, 4]],
                                dtype=torch.float32,
                                device=device)
    result = cumsum_group_list(
        input_tensor,
        src_list_type=2,
        dst_list_type=0,
        active_num=0.0,
        expert_num=4,
    )
    assert result.dtype == torch.float32
    # Compare device *types*, not devices: torch.device("cuda") !=
    # torch.device("cuda:0"), so an exact equality check spuriously fails
    # whenever the result carries an indexed accelerator device.
    assert result.device.type == device.type

    # int64 dtype must survive a 0 -> 1 conversion as well.
    input_int = torch.tensor([2, 4, 6], dtype=torch.int64)
    result_int = cumsum_group_list(input_int, src_list_type=0, dst_list_type=1)
    assert result_int.dtype == torch.int64
| Original file line number | Diff line number | Diff line change | ||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -41,7 +41,7 @@ def cumsum_group_list(group_list: torch.Tensor, | |||||||||||
| return group_list.cumsum(dim=0) | ||||||||||||
| if src_list_type == 0 and dst_list_type == 1: | ||||||||||||
| group_diff = torch.diff(group_list) | ||||||||||||
| new_group = torch.cat([group_diff[0].unsqueeze(0), group_diff], dim=0) | ||||||||||||
| new_group = torch.cat([group_list[0].unsqueeze(0), group_diff], dim=0) | ||||||||||||
| return new_group | ||||||||||||
|
Comment on lines
43
to
45
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The current implementation for
Suggested change
|
||||||||||||
| if src_list_type == 2 and dst_list_type == 0: | ||||||||||||
| experts = pad(group_list[:, 0], (1, 0)) | ||||||||||||
|
|
||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The edge case tests are missing a crucial scenario: converting an empty tensor from
src_list_type=0todst_list_type=1. The logic being modified in this PR has a bug that would cause a crash in this case. Adding this test case is important to ensure the function is robust and to prevent future regressions.