Skip to content

Fix NPU LoRA for MindSpeed MoE grouped linear#118

Open
addsubmuldiv wants to merge 3 commits into
modelscope:mainfrom
addsubmuldiv:npu-grouped-lora-gmm
Open

Fix NPU LoRA for MindSpeed MoE grouped linear#118
addsubmuldiv wants to merge 3 commits into
modelscope:mainfrom
addsubmuldiv:npu-grouped-lora-gmm

Conversation

@addsubmuldiv

Copy link
Copy Markdown
Collaborator

Summary

This PR fixes LoRA training and checkpoint save/load for MindSpeed patched MoE grouped linear layers on NPU.

Problem

On the NPU path, MindSpeed patches Megatron grouped expert linear layers. After patching, the layer may not reliably keep the is_expert attribute.

The original LoRA logic uses is_expert to distinguish MoE expert parameters from normal tensor-parallel parameters. Because of that, a MindSpeed grouped expert layer can be treated as a normal grouped linear layer.

This leads to two issues:

  • all-linear LoRA cannot correctly handle NPU MoE grouped expert linear layers.
  • LoRA expert weights are saved without expert-parallel checkpoint sharding metadata.

Changes

  • Add NpuGroupedLoraLinear for NPU grouped LoRA adapters.
  • Use the NPU adapter only on NPU grouped linear paths; keep the existing TE implementation for non-NPU paths.
  • Recover expert-layer detection for MindSpeed patched grouped linear layers.
  • Add expert-aware sharded_state_dict for NPU grouped LoRA weights.
  • Map local expert slots to global expert indices when saving checkpoints.

Verification

  • NPU Megatron LoRA smoke: Qwen3.5 MoE
    • TP=2, PP=2, EP=2, ETP=1
    • target_modules=all-linear
    • 1-step train, save, load, and resume all succeeded.
  • NPU Megatron LoRA smoke: Qwen3 MoE
    • TP=2, PP=2, EP=2, ETP=1
    • target_modules=all-linear
    • 1-step train, save, load, and resume all succeeded.

@addsubmuldiv addsubmuldiv changed the title Npu grouped lora gmm Fix NPU LoRA for MindSpeed MoE grouped linear Jun 9, 2026

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces support for NPU-optimized grouped LoRA adapters by adding the NpuGroupedLoraLinear module and integrating it into the LoRA tuner. The implementation leverages torch_npu.npu_grouped_matmul for efficient grouped matrix multiplication on NPU devices, falling back to a standard PyTorch loop when necessary. Feedback on these changes highlights several critical issues: the forward method of NpuGroupedLoraLinear needs to accept arbitrary arguments (*args, **kwargs) to prevent runtime type errors when called, and tuple inputs to torch_npu.npu_grouped_matmul should be explicitly converted to lists to avoid compatibility issues with the C++ bindings. Additionally, adding defensive validation in the fallback path to ensure the length of m_splits matches num_gemms is recommended to prevent cryptic attribute errors.

Important

The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.

Comment on lines +161 to +165
def forward(self, x, m_splits):
if not self._can_use_grouped_matmul(x):
return self._fallback_forward(x, m_splits)
weights = [getattr(self, f'weight{i}') for i in range(self.num_gemms)]
return _NpuGroupedLoraLinearGMM.apply(x, m_splits, weights, *[weight.T for weight in weights]), None

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The forward method of NpuGroupedLoraLinear does not accept arbitrary positional (*args) or keyword (**kwargs) arguments. However, in LoraParallelLinear.forward (within src/mcore_bridge/tuners/lora.py), the adapter layers are called with *args and **kwargs (e.g., lora_A(dropout(x), *args, **kwargs)). If any extra keyword arguments or positional arguments are passed, this will raise a TypeError.

To ensure compatibility and prevent runtime crashes, update the signature of forward in NpuGroupedLoraLinear to accept *args and **kwargs.

Suggested change
def forward(self, x, m_splits):
if not self._can_use_grouped_matmul(x):
return self._fallback_forward(x, m_splits)
weights = [getattr(self, f'weight{i}') for i in range(self.num_gemms)]
return _NpuGroupedLoraLinearGMM.apply(x, m_splits, weights, *[weight.T for weight in weights]), None
def forward(self, x, m_splits, *args, **kwargs):
if not self._can_use_grouped_matmul(x):
return self._fallback_forward(x, m_splits)
weights = [getattr(self, f'weight{i}') for i in range(self.num_gemms)]
return _NpuGroupedLoraLinearGMM.apply(x, m_splits, weights, *[weight.T for weight in weights]), None

Comment on lines +181 to +189
output = torch_npu.npu_grouped_matmul(
[input_tensor],
weight_input_T,
bias=None,
group_list=group_list,
split_item=_GMM_SPLIT_ITEM_FORWARD_OR_DINPUT,
group_type=_GMM_GROUP_BY_M_AXIS,
group_list_type=_GMM_GROUP_LIST_IS_EXPERT_SIZES,
)[0]

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

In _NpuGroupedLoraLinearGMM.forward, weight_input_T is a tuple of tensors. The torch_npu.npu_grouped_matmul operator expects a list of tensors for its inputs and weights. Passing a tuple instead of a list can lead to type errors or undefined behavior in the C++ bindings of torch_npu.

Convert weight_input_T to a list before passing it to npu_grouped_matmul.

Suggested change
output = torch_npu.npu_grouped_matmul(
[input_tensor],
weight_input_T,
bias=None,
group_list=group_list,
split_item=_GMM_SPLIT_ITEM_FORWARD_OR_DINPUT,
group_type=_GMM_GROUP_BY_M_AXIS,
group_list_type=_GMM_GROUP_LIST_IS_EXPERT_SIZES,
)[0]
output = torch_npu.npu_grouped_matmul(
[input_tensor],
list(weight_input_T),
bias=None,
group_list=group_list,
split_item=_GMM_SPLIT_ITEM_FORWARD_OR_DINPUT,
group_type=_GMM_GROUP_BY_M_AXIS,
group_list_type=_GMM_GROUP_LIST_IS_EXPERT_SIZES,
)[0]

Comment on lines +201 to +209
grad_input = torch_npu.npu_grouped_matmul(
[grad_output],
weights,
bias=None,
group_list=group_list,
split_item=_GMM_SPLIT_ITEM_FORWARD_OR_DINPUT,
group_type=_GMM_GROUP_BY_M_AXIS,
group_list_type=_GMM_GROUP_LIST_IS_EXPERT_SIZES,
)[0]

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

In _NpuGroupedLoraLinearGMM.backward, weights is retrieved as a slice of ctx.saved_tensors, which is a tuple. Similar to the forward pass, torch_npu.npu_grouped_matmul expects a list of tensors.

Convert weights to a list before passing it to npu_grouped_matmul.

Suggested change
grad_input = torch_npu.npu_grouped_matmul(
[grad_output],
weights,
bias=None,
group_list=group_list,
split_item=_GMM_SPLIT_ITEM_FORWARD_OR_DINPUT,
group_type=_GMM_GROUP_BY_M_AXIS,
group_list_type=_GMM_GROUP_LIST_IS_EXPERT_SIZES,
)[0]
grad_input = torch_npu.npu_grouped_matmul(
[grad_output],
list(weights),
bias=None,
group_list=group_list,
split_item=_GMM_SPLIT_ITEM_FORWARD_OR_DINPUT,
group_type=_GMM_GROUP_BY_M_AXIS,
group_list_type=_GMM_GROUP_LIST_IS_EXPERT_SIZES,
)[0]

Comment on lines +139 to +143
def _fallback_forward(self, x, m_splits):
if isinstance(m_splits, torch.Tensor):
m_splits = m_splits.tolist()
outputs = []
offset = 0

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

In _fallback_forward, there is no validation to ensure that the length of m_splits matches self.num_gemms. If len(m_splits) is greater than self.num_gemms, the loop will attempt to access non-existent attributes (e.g., weight{i}), resulting in a cryptic AttributeError.

Add a defensive check at the beginning of _fallback_forward to verify that len(m_splits) matches self.num_gemms.

    def _fallback_forward(self, x, m_splits):
        if len(m_splits) != self.num_gemms:
            raise RuntimeError(
                f"Expected m_splits length to be equal to num_gemms ({self.num_gemms}), got {len(m_splits)}"
            )
        if isinstance(m_splits, torch.Tensor):
            m_splits = m_splits.tolist()
        outputs = []
        offset = 0

@addsubmuldiv addsubmuldiv marked this pull request as ready for review June 9, 2026 18:45
Copilot AI review requested due to automatic review settings June 9, 2026 18:45

Copilot AI left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR aims to fix LoRA training plus checkpoint save/load for MindSpeed-patched MoE grouped linear layers on NPU by restoring reliable expert-layer detection and adding an NPU-specific grouped LoRA adapter with expert-aware sharded checkpointing.

Changes:

  • Add NpuGroupedLoraLinear and an NPU grouped-matmul autograd path for grouped LoRA adapters.
  • Replace base_layer.is_expert probing with a more robust is_expert_layer() (including MindSpeed TEGroupedLinear fallbacks).
  • Add expert-aware sharded_state_dict() behavior for NPU grouped LoRA weights, including local-to-global expert index mapping.

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.

File Description
src/mcore_bridge/tuners/npu_lora.py Introduces NPU grouped LoRA linear adapter, MindSpeed expert detection, and expert-aware sharded checkpointing utilities.
src/mcore_bridge/tuners/lora.py Switches grouped LoRA adapter construction to the NPU implementation on NPU, and uses the new expert-layer detection helper.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment on lines +202 to +226
def backward(ctx, grad_output):
import torch_npu

input_tensor = ctx.saved_tensors[0]
weights = ctx.saved_tensors[1:]
group_list = ctx.group_list
grad_input = torch_npu.npu_grouped_matmul(
[grad_output],
weights,
bias=None,
group_list=group_list,
split_item=_GMM_SPLIT_ITEM_FORWARD_OR_DINPUT,
group_type=_GMM_GROUP_BY_M_AXIS,
group_list_type=_GMM_GROUP_LIST_IS_EXPERT_SIZES,
)[0]
grad_weight_T = torch_npu.npu_grouped_matmul(
[input_tensor.T],
[grad_output],
bias=None,
group_list=group_list,
split_item=_GMM_SPLIT_ITEM_DWEIGHT,
group_type=_GMM_GROUP_BY_K_AXIS,
group_list_type=_GMM_GROUP_LIST_IS_EXPERT_SIZES,
)[0]
return grad_input, None, None, *grad_weight_T
Comment on lines +131 to +142
for param_name in ('weight', 'bias'):
local_name = f'{param_name}{i}'
param = getattr(self, local_name, None)
if param is None:
continue
sharded_tensor = make_sharded_tensor_for_checkpoint(
param,
f'{key_prefix}{param_name}',
prepend_offsets=new_sharded_offsets,
)
sharded_state_dict[f'{prefix}{local_name}'] = self._set_expert_replica_id(sharded_tensor)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants