-
Notifications
You must be signed in to change notification settings - Fork 661
[UT]Ut for function cumsum_group_list in v11.0-dev (ref #5023) #5037
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
The head ref may contain hidden characters: "ut-for-bugfix\u2014dev"
Conversation
Signed-off-by: tanqingshan (A) <[email protected]> Signed-off-by: tanqingshan (A) <[email protected]>
Signed-off-by: tanqingshan (A) <[email protected]> Signed-off-by: tanqingshan (A) <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code Review
This pull request introduces unit tests for the cumsum_group_list function and includes a bug fix for the conversion from a cumulative sum list to a difference list. While the fix is correct for non-empty tensors, it doesn't handle the edge case of an empty input tensor, which would lead to a crash. I've provided a more robust and simplified solution for this. Additionally, I've suggested adding a new test case to cover this empty tensor scenario to prevent future regressions.
| group_diff = torch.diff(group_list) | ||
| new_group = torch.cat([group_diff[0].unsqueeze(0), group_diff], dim=0) | ||
| new_group = torch.cat([group_list[0].unsqueeze(0), group_diff], dim=0) | ||
| return new_group |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The current implementation for src_list_type=0 to dst_list_type=1 conversion is not robust against empty input tensors. Accessing group_list[0] will raise an IndexError if group_list is empty, causing a crash. This can be simplified and made safe for all inputs, including empty tensors, by using slicing group_list[:1], which gracefully handles the empty case.
| group_diff = torch.diff(group_list) | |
| new_group = torch.cat([group_diff[0].unsqueeze(0), group_diff], dim=0) | |
| new_group = torch.cat([group_list[0].unsqueeze(0), group_diff], dim=0) | |
| return new_group | |
| return torch.cat([group_list[:1], torch.diff(group_list)]) |
| dst_list_type=0, | ||
| active_num=0, | ||
| expert_num=3) | ||
| assert torch.equal(result, torch.tensor([0, 0, 0])) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The edge case tests are missing a crucial scenario: converting an empty tensor from src_list_type=0 to dst_list_type=1. The logic being modified in this PR has a bug that would cause a crash in this case. Adding this test case is important to ensure the function is robust and to prevent future regressions.
| assert torch.equal(result, torch.tensor([0, 0, 0])) | |
| assert torch.equal(result, torch.tensor([0, 0, 0])) | |
| # Empty tensor (0→1) | |
| empty_tensor = torch.tensor([], dtype=torch.int64) | |
| result = cumsum_group_list(empty_tensor, src_list_type=0, dst_list_type=1) | |
| assert torch.equal(result, empty_tensor) |
|
no need for 0.11.0-dev |
What this PR does / why we need it?
Add ut for the cumsum_group_list function, which is related to the precision issues stemming from the moe_mlp.py .
The ralated PR is #5023.
Does this PR introduce any user-facing change?
No