Conversation

@Clorist33 (Contributor) commented Dec 15, 2025

What this PR does / why we need it?

Add unit tests for the cumsum_group_list function, which relates to the precision issues stemming from moe_mlp.py.
The related PR is #5025

Does this PR introduce any user-facing change?

No

tanqingshan (A) added 2 commits December 15, 2025 18:37
Signed-off-by: tanqingshan (A) <[email protected]>
Signed-off-by: tanqingshan (A) <[email protected]>
@gemini-code-assist bot (Contributor) left a comment

Code Review

This pull request introduces a bug fix for the cumsum_group_list function and adds a comprehensive set of unit tests. The fix correctly handles the conversion from a cumulative sum list to a difference list. The new tests cover various valid conversions, invalid inputs, unimplemented paths, and edge cases.

My review identifies a critical unhandled edge case where an empty tensor input for the 0 -> 1 conversion causes an IndexError. I've provided a suggestion to fix this and another to add a corresponding test case. Additionally, I've suggested improvements to the new tests to make them more robust by asserting exception messages. Overall, this is a valuable contribution that will be even stronger with these changes.

Comment on lines 47 to 50

if src_list_type == 0 and dst_list_type == 1:
    group_diff = torch.diff(group_list)
    new_group = torch.cat([group_diff[0].unsqueeze(0), group_diff], dim=0)
    new_group = torch.cat([group_list[0].unsqueeze(0), group_diff], dim=0)
    return new_group

critical

The current implementation for converting from src_list_type=0 to dst_list_type=1 does not handle empty input tensors. If group_list is an empty tensor, group_list[0] on line 49 will raise an IndexError. You should add a check to handle empty tensors gracefully, returning an empty tensor in that case.

Suggested change

Before:
if src_list_type == 0 and dst_list_type == 1:
    group_diff = torch.diff(group_list)
    new_group = torch.cat([group_diff[0].unsqueeze(0), group_diff], dim=0)
    new_group = torch.cat([group_list[0].unsqueeze(0), group_diff], dim=0)
    return new_group

After:
if src_list_type == 0 and dst_list_type == 1:
    if not group_list.numel():
        return group_list
    group_diff = torch.diff(group_list)
    new_group = torch.cat([group_list[0].unsqueeze(0), group_diff], dim=0)
    return new_group
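For context, the 0 → 1 conversion turns a cumulative-sum group list back into per-group counts: torch.diff recovers every count except the first, and the first cumulative value is the first count itself. A minimal standalone sketch of the suggested logic (function name is hypothetical, not the vllm-ascend API):

```python
import torch


def cumsum_to_counts(group_list: torch.Tensor) -> torch.Tensor:
    # Hypothetical standalone version of the suggested fix:
    # convert a cumulative-sum group list into per-group counts.
    if not group_list.numel():
        # Empty in, empty out -- avoids the IndexError on group_list[0].
        return group_list
    group_diff = torch.diff(group_list)
    # The first group's count equals the first cumulative value.
    return torch.cat([group_list[0].unsqueeze(0), group_diff], dim=0)


print(cumsum_to_counts(torch.tensor([2, 5, 9])))  # -> tensor([2, 3, 4])
print(cumsum_to_counts(torch.tensor([], dtype=torch.int64)))
```

With the numel() guard in place, the empty case degenerates cleanly instead of raising.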

Comment on lines +113 to +120

# Test 1→2 (unimplemented)
with pytest.raises(NotImplementedError):
    cumsum_group_list(input_tensor, src_list_type=1, dst_list_type=2)

# Test 2→1 (unimplemented)
input_2d = torch.tensor([[0, 1], [2, 3]])
with pytest.raises(NotImplementedError):
    cumsum_group_list(input_2d, src_list_type=2, dst_list_type=1)

high

To improve test robustness, it's good practice to assert the content of the raised exception. This ensures that the correct NotImplementedError is being raised for the specific unimplemented conversion path, and not for some other reason.

Suggested change

Before:
# Test 1→2 (unimplemented)
with pytest.raises(NotImplementedError):
    cumsum_group_list(input_tensor, src_list_type=1, dst_list_type=2)

# Test 2→1 (unimplemented)
input_2d = torch.tensor([[0, 1], [2, 3]])
with pytest.raises(NotImplementedError):
    cumsum_group_list(input_2d, src_list_type=2, dst_list_type=1)

After:
# Test 1→2 (unimplemented)
with pytest.raises(NotImplementedError) as excinfo:
    cumsum_group_list(input_tensor, src_list_type=1, dst_list_type=2)
assert "Conversion from src_list_type=1 to dst_list_type=2 is not implemented yet" in str(
    excinfo.value)

# Test 2→1 (unimplemented)
input_2d = torch.tensor([[0, 1], [2, 3]])
with pytest.raises(NotImplementedError) as excinfo:
    cumsum_group_list(input_2d, src_list_type=2, dst_list_type=1)
assert "Conversion from src_list_type=2 to dst_list_type=1 is not implemented yet" in str(
    excinfo.value)
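As an aside, pytest can fold the message assertion into the context manager via the match parameter, which takes a regex and reads slightly tighter than a separate assert on excinfo.value. A self-contained sketch against a stand-in function (the real cumsum_group_list lives in vllm-ascend; this stub only mimics the error style):

```python
import re

import pytest


def cumsum_group_list(group_list, src_list_type, dst_list_type):
    # Stand-in for the real function, raising the same style of error.
    raise NotImplementedError(
        f"Conversion from src_list_type={src_list_type} to "
        f"dst_list_type={dst_list_type} is not implemented yet")


def test_unimplemented_path():
    # match= is interpreted as a regex, so escape the literal message.
    with pytest.raises(NotImplementedError,
                       match=re.escape("src_list_type=1 to dst_list_type=2")):
        cumsum_group_list(None, src_list_type=1, dst_list_type=2)
```

Either form works; match= just fails with a clearer diff when the message drifts.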

Comment on lines +123 to +142

def test_cumsum_group_list_edge_cases():
    """Test edge cases"""
    # Empty tensor (1→0)
    empty_tensor = torch.tensor([], dtype=torch.int64)
    result = cumsum_group_list(empty_tensor, src_list_type=1, dst_list_type=0)
    assert torch.equal(result, empty_tensor)

    # Single-element tensor (0→1)
    single_tensor = torch.tensor([5])
    result = cumsum_group_list(single_tensor, src_list_type=0, dst_list_type=1)
    assert torch.equal(result, torch.tensor([5]))

    # 2→0 - Empty input
    empty_2d = torch.tensor([], dtype=torch.int64).reshape(0, 2)
    result = cumsum_group_list(empty_2d,
                               src_list_type=2,
                               dst_list_type=0,
                               active_num=0,
                               expert_num=3)
    assert torch.equal(result, torch.tensor([0, 0, 0]))

high

The edge case tests are good, but they are missing a case for an empty tensor input for the 0 -> 1 conversion. This is a critical case that currently raises an IndexError. Please add this test to ensure the bug is fixed and does not regress.

def test_cumsum_group_list_edge_cases():
    """Test edge cases"""
    # Empty tensor (1→0)
    empty_tensor = torch.tensor([], dtype=torch.int64)
    result = cumsum_group_list(empty_tensor, src_list_type=1, dst_list_type=0)
    assert torch.equal(result, empty_tensor)

    # Empty tensor (0→1)
    result = cumsum_group_list(empty_tensor, src_list_type=0, dst_list_type=1)
    assert torch.equal(result, empty_tensor)

    # Single-element tensor (0→1)
    single_tensor = torch.tensor([5])
    result = cumsum_group_list(single_tensor, src_list_type=0, dst_list_type=1)
    assert torch.equal(result, torch.tensor([5]))

    # 2→0 - Empty input
    empty_2d = torch.tensor([], dtype=torch.int64).reshape(0, 2)
    result = cumsum_group_list(empty_2d,
                               src_list_type=2,
                               dst_list_type=0,
                               active_num=0,
                               expert_num=3)
    assert torch.equal(result, torch.tensor([0, 0, 0]))
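The single-element and empty cases asserted above follow from a round-trip property: converting counts to a cumulative sum (the 1 → 0 direction, i.e. torch.cumsum) and back with the diff-plus-first-element trick should be the identity. A standalone sketch of that property, with hypothetical helper names rather than the project's API:

```python
import torch


def counts_to_cumsum(counts: torch.Tensor) -> torch.Tensor:
    # 1 -> 0 direction: per-group counts to cumulative sums.
    return torch.cumsum(counts, dim=0)


def cumsum_to_counts(cumsum: torch.Tensor) -> torch.Tensor:
    # 0 -> 1 direction, including the empty-tensor guard suggested above.
    if not cumsum.numel():
        return cumsum
    return torch.cat([cumsum[0].unsqueeze(0), torch.diff(cumsum)], dim=0)


# Round trip is the identity for single-element, multi-element, and empty inputs.
for counts in (torch.tensor([5]), torch.tensor([2, 3, 4]),
               torch.tensor([], dtype=torch.int64)):
    assert torch.equal(cumsum_to_counts(counts_to_cumsum(counts)), counts)
```

This is also a natural shape for a property-based test, should the suite grow beyond hand-picked edge cases.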

@github-actions
👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:

  • A PR should do only one thing; smaller PRs enable faster reviews.
  • Every PR should include unit tests and end-to-end tests to ensure it works and is not broken by future PRs.
  • Write the commit message and fill out the PR description to help reviewers and future developers understand the change.

If CI fails, you can run linting and testing checks locally according to Contributing and Testing.
