[UT]Ut for function cumsum_group_list in main (ref #5025) #5036
Conversation
Signed-off-by: tanqingshan (A) <[email protected]>
Code Review
This pull request introduces a bug fix for the cumsum_group_list function and adds a comprehensive set of unit tests. The fix correctly handles the conversion from a cumulative sum list to a difference list. The new tests cover various valid conversions, invalid inputs, unimplemented paths, and edge cases.
My review identifies a critical unhandled edge case where an empty tensor input for the 0 -> 1 conversion causes an IndexError. I've provided a suggestion to fix this and another to add a corresponding test case. Additionally, I've suggested improvements to the new tests to make them more robust by asserting exception messages. Overall, this is a valuable contribution that will be even stronger with these changes.
```diff
 if src_list_type == 0 and dst_list_type == 1:
     group_diff = torch.diff(group_list)
-    new_group = torch.cat([group_diff[0].unsqueeze(0), group_diff], dim=0)
+    new_group = torch.cat([group_list[0].unsqueeze(0), group_diff], dim=0)
     return new_group
```
The current implementation for converting from src_list_type=0 to dst_list_type=1 does not handle empty input tensors. If group_list is an empty tensor, group_list[0] on line 49 will raise an IndexError. You should add a check to handle empty tensors gracefully, returning an empty tensor in that case.
```diff
 if src_list_type == 0 and dst_list_type == 1:
+    if not group_list.numel():
+        return group_list
     group_diff = torch.diff(group_list)
     new_group = torch.cat([group_list[0].unsqueeze(0), group_diff], dim=0)
     return new_group
```
```python
# Test 1→2 (unimplemented)
with pytest.raises(NotImplementedError):
    cumsum_group_list(input_tensor, src_list_type=1, dst_list_type=2)

# Test 2→1 (unimplemented)
input_2d = torch.tensor([[0, 1], [2, 3]])
with pytest.raises(NotImplementedError):
    cumsum_group_list(input_2d, src_list_type=2, dst_list_type=1)
```
To improve test robustness, it's good practice to assert the content of the raised exception. This ensures that the correct NotImplementedError is being raised for the specific unimplemented conversion path, and not for some other reason.
```diff
 # Test 1→2 (unimplemented)
-with pytest.raises(NotImplementedError):
+with pytest.raises(NotImplementedError) as excinfo:
     cumsum_group_list(input_tensor, src_list_type=1, dst_list_type=2)
+assert "Conversion from src_list_type=1 to dst_list_type=2 is not implemented yet" in str(
+    excinfo.value)

 # Test 2→1 (unimplemented)
 input_2d = torch.tensor([[0, 1], [2, 3]])
-with pytest.raises(NotImplementedError):
+with pytest.raises(NotImplementedError) as excinfo:
     cumsum_group_list(input_2d, src_list_type=2, dst_list_type=1)
+assert "Conversion from src_list_type=2 to dst_list_type=1 is not implemented yet" in str(
+    excinfo.value)
```
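`pytest.raises` also accepts a `match` argument (a regex applied to `str(excinfo.value)` via `re.search`), which expresses the same message check more compactly. The sketch below uses a hypothetical stub in place of the real `cumsum_group_list`; the message text is assumed to follow the wording suggested above and would need to match the actual implementation:

```python
import pytest


def cumsum_group_list(group_list, src_list_type, dst_list_type):
    # Hypothetical stub standing in for the real function: only the
    # error path matters for this illustration.
    raise NotImplementedError(
        f"Conversion from src_list_type={src_list_type} to "
        f"dst_list_type={dst_list_type} is not implemented yet")


# match= runs re.search against str(excinfo.value), so the test fails
# if a NotImplementedError with a different message is raised for some
# unrelated reason.
with pytest.raises(NotImplementedError,
                   match=r"src_list_type=1 to dst_list_type=2"):
    cumsum_group_list(None, src_list_type=1, dst_list_type=2)
```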
```python
def test_cumsum_group_list_edge_cases():
    """Test edge cases"""
    # Empty tensor (1→0)
    empty_tensor = torch.tensor([], dtype=torch.int64)
    result = cumsum_group_list(empty_tensor, src_list_type=1, dst_list_type=0)
    assert torch.equal(result, empty_tensor)

    # Single-element tensor (0→1)
    single_tensor = torch.tensor([5])
    result = cumsum_group_list(single_tensor, src_list_type=0, dst_list_type=1)
    assert torch.equal(result, torch.tensor([5]))

    # 2→0 - Empty input
    empty_2d = torch.tensor([], dtype=torch.int64).reshape(0, 2)
    result = cumsum_group_list(empty_2d,
                               src_list_type=2,
                               dst_list_type=0,
                               active_num=0,
                               expert_num=3)
    assert torch.equal(result, torch.tensor([0, 0, 0]))
```
The edge case tests are good, but they are missing a case for an empty tensor input for the 0 -> 1 conversion. This is a critical case that currently raises an IndexError. Please add this test to ensure the bug is fixed and does not regress.
```diff
 def test_cumsum_group_list_edge_cases():
     """Test edge cases"""
     # Empty tensor (1→0)
     empty_tensor = torch.tensor([], dtype=torch.int64)
     result = cumsum_group_list(empty_tensor, src_list_type=1, dst_list_type=0)
     assert torch.equal(result, empty_tensor)

+    # Empty tensor (0→1)
+    result = cumsum_group_list(empty_tensor, src_list_type=0, dst_list_type=1)
+    assert torch.equal(result, empty_tensor)
+
     # Single-element tensor (0→1)
     single_tensor = torch.tensor([5])
     result = cumsum_group_list(single_tensor, src_list_type=0, dst_list_type=1)
     assert torch.equal(result, torch.tensor([5]))

     # 2→0 - Empty input
     empty_2d = torch.tensor([], dtype=torch.int64).reshape(0, 2)
     result = cumsum_group_list(empty_2d,
                                src_list_type=2,
                                dst_list_type=0,
                                active_num=0,
                                expert_num=3)
     assert torch.equal(result, torch.tensor([0, 0, 0]))
```
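Beyond individual edge cases, a useful sanity check is the round-trip property: converting a cumulative-sum list to differences and then accumulating the result should reproduce the input. A torch-free sketch of that property, using hypothetical pure-Python helpers rather than the real `cumsum_group_list`:

```python
from itertools import accumulate


def to_diff(cumsum_list):
    # Mirrors the 0 -> 1 conversion: first value kept, then pairwise diffs.
    if not cumsum_list:
        return []
    return [cumsum_list[0]] + [
        b - a for a, b in zip(cumsum_list, cumsum_list[1:])
    ]


def to_cumsum(diff_list):
    # Mirrors the 1 -> 0 conversion: running totals of the group sizes.
    return list(accumulate(diff_list))


# The round trip holds for empty, single-element, and repeated-value lists,
# covering the same shapes exercised by the edge-case tests.
for sample in ([], [5], [2, 5, 9], [0, 0, 3, 3]):
    assert to_cumsum(to_diff(sample)) == sample
```

A property like this could be added as one extra test to guard both conversion directions at once.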
👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:
If CI fails, you can run linting and testing checks locally according to Contributing and Testing.
What this PR does / why we need it?
Add unit tests for the cumsum_group_list function, which is related to the precision issues stemming from moe_mlp.py.
The related PR is #5025.
Does this PR introduce any user-facing change?
No