@Clorist33
Contributor

What this PR does / why we need it?

Add unit tests (UT) for the cumsum_group_list function, which is related to the precision issues stemming from moe_mlp.py.
The related PR is #5023.

Does this PR introduce any user-facing change?

No

tanqingshan (A) added 2 commits December 15, 2025 18:51
Signed-off-by: tanqingshan (A) <[email protected]>

Signed-off-by: tanqingshan (A)  <[email protected]>
Signed-off-by: tanqingshan (A) <[email protected]>

Signed-off-by: tanqingshan (A)  <[email protected]>
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request introduces unit tests for the cumsum_group_list function and includes a bug fix for the conversion from a cumulative sum list to a difference list. While the fix is correct for non-empty tensors, it doesn't handle the edge case of an empty input tensor, which would lead to a crash. I've provided a more robust and simplified solution for this. Additionally, I've suggested adding a new test case to cover this empty tensor scenario to prevent future regressions.

Comment on lines 43 to 45
  group_diff = torch.diff(group_list)
- new_group = torch.cat([group_diff[0].unsqueeze(0), group_diff], dim=0)
+ new_group = torch.cat([group_list[0].unsqueeze(0), group_diff], dim=0)
  return new_group

critical

The current implementation of the src_list_type=0 to dst_list_type=1 conversion is not robust against empty input tensors. Accessing group_list[0] raises an IndexError if group_list is empty, causing a crash. The code can be simplified and made safe for all inputs, including empty tensors, by using the slice group_list[:1], which handles the empty case gracefully.

Suggested change
- group_diff = torch.diff(group_list)
- new_group = torch.cat([group_diff[0].unsqueeze(0), group_diff], dim=0)
- new_group = torch.cat([group_list[0].unsqueeze(0), group_diff], dim=0)
- return new_group
+ return torch.cat([group_list[:1], torch.diff(group_list)])
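
To make the three variants discussed in this thread concrete, here is a minimal sketch that uses plain Python lists in place of tensors. This is an assumption made purely for illustration: the helper names (`diff`, `cumsum_to_counts_old`, `cumsum_to_counts_indexed`, `cumsum_to_counts_sliced`) are hypothetical and are not part of the PR, and list indexing and slicing only approximate how the tensor code behaves.

```python
def diff(values):
    """Adjacent differences, analogous to torch.diff on a 1-D tensor."""
    return [b - a for a, b in zip(values, values[1:])]


def cumsum_to_counts_old(group_list):
    # Pre-PR logic: prepends the first *difference*, not the first element.
    d = diff(group_list)
    return [d[0]] + d


def cumsum_to_counts_indexed(group_list):
    # Logic in this PR: prepends group_list[0]; raises IndexError on empty input.
    return [group_list[0]] + diff(group_list)


def cumsum_to_counts_sliced(group_list):
    # Suggested logic: slicing yields [] for empty input instead of raising.
    return group_list[:1] + diff(group_list)


cumulative = [2, 5, 9]  # cumulative sums of per-group counts 2, 3, 4
print(cumsum_to_counts_old(cumulative))      # [3, 3, 4]
print(cumsum_to_counts_indexed(cumulative))  # [2, 3, 4]
print(cumsum_to_counts_sliced(cumulative))   # [2, 3, 4]
print(cumsum_to_counts_sliced([]))           # []
```

In this sketch the old variant loses the first group's count, while the indexed and sliced variants agree on non-empty input and differ only in how they treat an empty list.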

dst_list_type=0,
active_num=0,
expert_num=3)
assert torch.equal(result, torch.tensor([0, 0, 0]))

high

The edge case tests are missing a crucial scenario: converting an empty tensor from src_list_type=0 to dst_list_type=1. The logic being modified in this PR has a bug that would cause a crash in this case. Adding this test case is important to ensure the function is robust and to prevent future regressions.

Suggested change
- assert torch.equal(result, torch.tensor([0, 0, 0]))
+ assert torch.equal(result, torch.tensor([0, 0, 0]))
+ # Empty tensor (0→1)
+ empty_tensor = torch.tensor([], dtype=torch.int64)
+ result = cumsum_group_list(empty_tensor, src_list_type=0, dst_list_type=1)
+ assert torch.equal(result, empty_tensor)

@wangxiyuan
Collaborator

no need for 0.11.0-dev

@wangxiyuan wangxiyuan closed this Dec 16, 2025