Fix NPU LoRA for MindSpeed MoE grouped linear #118
Code Review
This pull request introduces support for NPU-optimized grouped LoRA adapters by adding the NpuGroupedLoraLinear module and integrating it into the LoRA tuner. The implementation leverages torch_npu.npu_grouped_matmul for efficient grouped matrix multiplication on NPU devices, falling back to a standard PyTorch loop when necessary. Feedback on these changes highlights several critical issues: the forward method of NpuGroupedLoraLinear needs to accept arbitrary arguments (*args, **kwargs) to prevent runtime type errors when called, and tuple inputs to torch_npu.npu_grouped_matmul should be explicitly converted to lists to avoid compatibility issues with the C++ bindings. Additionally, adding defensive validation in the fallback path to ensure the length of m_splits matches num_gemms is recommended to prevent cryptic attribute errors.
def forward(self, x, m_splits):
    if not self._can_use_grouped_matmul(x):
        return self._fallback_forward(x, m_splits)
    weights = [getattr(self, f'weight{i}') for i in range(self.num_gemms)]
    return _NpuGroupedLoraLinearGMM.apply(x, m_splits, weights, *[weight.T for weight in weights]), None
The forward method of NpuGroupedLoraLinear does not accept arbitrary positional (*args) or keyword (**kwargs) arguments. However, in LoraParallelLinear.forward (within src/mcore_bridge/tuners/lora.py), the adapter layers are called with *args and **kwargs (e.g., lora_A(dropout(x), *args, **kwargs)). If any extra keyword arguments or positional arguments are passed, this will raise a TypeError.
To ensure compatibility and prevent runtime crashes, update the signature of forward in NpuGroupedLoraLinear to accept *args and **kwargs.
Suggested change:
-def forward(self, x, m_splits):
-    if not self._can_use_grouped_matmul(x):
-        return self._fallback_forward(x, m_splits)
-    weights = [getattr(self, f'weight{i}') for i in range(self.num_gemms)]
-    return _NpuGroupedLoraLinearGMM.apply(x, m_splits, weights, *[weight.T for weight in weights]), None
+def forward(self, x, m_splits, *args, **kwargs):
+    if not self._can_use_grouped_matmul(x):
+        return self._fallback_forward(x, m_splits)
+    weights = [getattr(self, f'weight{i}') for i in range(self.num_gemms)]
+    return _NpuGroupedLoraLinearGMM.apply(x, m_splits, weights, *[weight.T for weight in weights]), None
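The failure mode described in this comment can be reproduced without any NPU code. The following is a minimal sketch in plain Python; `StrictAdapter`, `TolerantAdapter`, `call_adapter`, and `some_flag` are hypothetical stand-ins, not names from the PR. `call_adapter` only mimics a wrapper that forwards whatever extra arguments it received, as the comment says `LoraParallelLinear.forward` does.

```python
class StrictAdapter:
    """Hypothetical adapter whose forward accepts only (x, m_splits)."""

    def forward(self, x, m_splits):
        return x, None


class TolerantAdapter:
    """Hypothetical adapter that ignores extra positional/keyword arguments."""

    def forward(self, x, m_splits, *args, **kwargs):
        return x, None


def call_adapter(adapter, x, *args, **kwargs):
    """Mimic a wrapper that passes its extra arguments straight through."""
    return adapter.forward(x, *args, **kwargs)


def accepts(adapter, **extra):
    """Return True if the adapter call succeeds, False if it raises TypeError."""
    try:
        call_adapter(adapter, [1.0, 2.0], [2], **extra)
        return True
    except TypeError:
        return False


strict_ok = accepts(StrictAdapter(), some_flag=True)      # extra kwarg is rejected
tolerant_ok = accepts(TolerantAdapter(), some_flag=True)  # extra kwarg is ignored
```

Both adapters behave identically when no extra arguments are passed, so the wider signature only matters on call paths that forward additional arguments.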
output = torch_npu.npu_grouped_matmul(
    [input_tensor],
    weight_input_T,
    bias=None,
    group_list=group_list,
    split_item=_GMM_SPLIT_ITEM_FORWARD_OR_DINPUT,
    group_type=_GMM_GROUP_BY_M_AXIS,
    group_list_type=_GMM_GROUP_LIST_IS_EXPERT_SIZES,
)[0]
In _NpuGroupedLoraLinearGMM.forward, weight_input_T is a tuple of tensors. The torch_npu.npu_grouped_matmul operator expects a list of tensors for its inputs and weights. Passing a tuple instead of a list can lead to type errors or undefined behavior in the C++ bindings of torch_npu.
Convert weight_input_T to a list before passing it to npu_grouped_matmul.
Suggested change:
-output = torch_npu.npu_grouped_matmul(
-    [input_tensor],
-    weight_input_T,
-    bias=None,
-    group_list=group_list,
-    split_item=_GMM_SPLIT_ITEM_FORWARD_OR_DINPUT,
-    group_type=_GMM_GROUP_BY_M_AXIS,
-    group_list_type=_GMM_GROUP_LIST_IS_EXPERT_SIZES,
-)[0]
+output = torch_npu.npu_grouped_matmul(
+    [input_tensor],
+    list(weight_input_T),
+    bias=None,
+    group_list=group_list,
+    split_item=_GMM_SPLIT_ITEM_FORWARD_OR_DINPUT,
+    group_type=_GMM_GROUP_BY_M_AXIS,
+    group_list_type=_GMM_GROUP_LIST_IS_EXPERT_SIZES,
+)[0]
grad_input = torch_npu.npu_grouped_matmul(
    [grad_output],
    weights,
    bias=None,
    group_list=group_list,
    split_item=_GMM_SPLIT_ITEM_FORWARD_OR_DINPUT,
    group_type=_GMM_GROUP_BY_M_AXIS,
    group_list_type=_GMM_GROUP_LIST_IS_EXPERT_SIZES,
)[0]
In _NpuGroupedLoraLinearGMM.backward, weights is retrieved as a slice of ctx.saved_tensors, which is a tuple. Similar to the forward pass, torch_npu.npu_grouped_matmul expects a list of tensors.
Convert weights to a list before passing it to npu_grouped_matmul.
Suggested change:
-grad_input = torch_npu.npu_grouped_matmul(
-    [grad_output],
-    weights,
-    bias=None,
-    group_list=group_list,
-    split_item=_GMM_SPLIT_ITEM_FORWARD_OR_DINPUT,
-    group_type=_GMM_GROUP_BY_M_AXIS,
-    group_list_type=_GMM_GROUP_LIST_IS_EXPERT_SIZES,
-)[0]
+grad_input = torch_npu.npu_grouped_matmul(
+    [grad_output],
+    list(weights),
+    bias=None,
+    group_list=group_list,
+    split_item=_GMM_SPLIT_ITEM_FORWARD_OR_DINPUT,
+    group_type=_GMM_GROUP_BY_M_AXIS,
+    group_list_type=_GMM_GROUP_LIST_IS_EXPERT_SIZES,
+)[0]
def _fallback_forward(self, x, m_splits):
    if isinstance(m_splits, torch.Tensor):
        m_splits = m_splits.tolist()
    outputs = []
    offset = 0
In _fallback_forward, there is no validation to ensure that the length of m_splits matches self.num_gemms. If len(m_splits) is greater than self.num_gemms, the loop will attempt to access non-existent attributes (e.g., weight{i}), resulting in a cryptic AttributeError.
Add a defensive check at the beginning of _fallback_forward to verify that len(m_splits) matches self.num_gemms.
def _fallback_forward(self, x, m_splits):
    if len(m_splits) != self.num_gemms:
        raise RuntimeError(
            f"Expected m_splits length to be equal to num_gemms ({self.num_gemms}), got {len(m_splits)}"
        )
    if isinstance(m_splits, torch.Tensor):
        m_splits = m_splits.tolist()
    outputs = []
    offset = 0
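For reference, the per-expert loop that a fallback of this kind performs can be sketched in plain Python on nested lists. This is not the PR's implementation: `fallback_grouped_linear`, `matmul`, and `transpose` are hypothetical helpers, weights are assumed to have shape [out, in], and the second check (that the splits sum to the number of rows) is an extra assumption beyond what the review asks for.

```python
def matmul(a, b):
    """Multiply two matrices stored as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def transpose(m):
    """Transpose a matrix stored as a list of rows."""
    return [list(col) for col in zip(*m)]


def fallback_grouped_linear(x, m_splits, weights):
    """Apply weights[i] (shape [out, in]) to the i-th block of rows of x."""
    m_splits = list(m_splits)
    if len(m_splits) != len(weights):
        raise RuntimeError(
            f'Expected m_splits length to equal num_gemms ({len(weights)}), got {len(m_splits)}')
    if sum(m_splits) != len(x):
        raise RuntimeError(f'm_splits sums to {sum(m_splits)}, but x has {len(x)} rows')
    outputs, offset = [], 0
    for rows, weight in zip(m_splits, weights):
        block = x[offset:offset + rows]  # empty when an expert receives no tokens
        outputs.extend(matmul(block, transpose(weight)))
        offset += rows
    return outputs


x = [[1, 2], [3, 4], [5, 6]]
weights = [[[1, 0], [0, 1]], [[2, 0], [0, 2]]]  # expert 0: identity, expert 1: scale by 2
y = fallback_grouped_linear(x, [1, 2], weights)
```

With the split `[1, 2]`, the first row goes through expert 0 and the remaining two rows through expert 1. A split list of the wrong length fails immediately with a readable message instead of an `AttributeError` deep inside the loop.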
Pull request overview
This PR aims to fix LoRA training plus checkpoint save/load for MindSpeed-patched MoE grouped linear layers on NPU by restoring reliable expert-layer detection and adding an NPU-specific grouped LoRA adapter with expert-aware sharded checkpointing.
Changes:
- Add NpuGroupedLoraLinear and an NPU grouped-matmul autograd path for grouped LoRA adapters.
- Replace base_layer.is_expert probing with a more robust is_expert_layer() (including MindSpeed TEGroupedLinear fallbacks).
- Add expert-aware sharded_state_dict() behavior for NPU grouped LoRA weights, including local-to-global expert index mapping.
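The PR's actual is_expert_layer() is not shown in this review. Purely as an illustration of what "more robust than probing is_expert" could mean, the sketch below trusts an explicit `is_expert` attribute when it exists and otherwise falls back to a class-name check. The class-name heuristic and the dummy classes are assumptions made for this example, not the PR's logic.

```python
def is_expert_layer(layer):
    """Hypothetical heuristic; the PR's real helper may differ."""
    flag = getattr(layer, 'is_expert', None)
    if flag is not None:
        return bool(flag)
    # Fallback (assumption): a patched layer may have lost `is_expert`,
    # so look for a grouped-linear class anywhere in the class hierarchy.
    return any('GroupedLinear' in cls.__name__ for cls in type(layer).__mro__)


class TEGroupedLinear:      # dummy stand-in with no is_expert attribute
    pass


class PatchedGrouped(TEGroupedLinear):  # dummy subclass, as a patch might create
    pass


class DenseLinear:          # dummy non-expert layer with an explicit flag
    is_expert = False


results = [is_expert_layer(obj) for obj in (TEGroupedLinear(), PatchedGrouped(), DenseLinear(), object())]
```

The explicit attribute takes priority, so a layer that still reports `is_expert = False` is never reclassified by the fallback.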
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| src/mcore_bridge/tuners/npu_lora.py | Introduces NPU grouped LoRA linear adapter, MindSpeed expert detection, and expert-aware sharded checkpointing utilities. |
| src/mcore_bridge/tuners/lora.py | Switches grouped LoRA adapter construction to the NPU implementation on NPU, and uses the new expert-layer detection helper. |
def backward(ctx, grad_output):
    import torch_npu

    input_tensor = ctx.saved_tensors[0]
    weights = ctx.saved_tensors[1:]
    group_list = ctx.group_list
    grad_input = torch_npu.npu_grouped_matmul(
        [grad_output],
        weights,
        bias=None,
        group_list=group_list,
        split_item=_GMM_SPLIT_ITEM_FORWARD_OR_DINPUT,
        group_type=_GMM_GROUP_BY_M_AXIS,
        group_list_type=_GMM_GROUP_LIST_IS_EXPERT_SIZES,
    )[0]
    grad_weight_T = torch_npu.npu_grouped_matmul(
        [input_tensor.T],
        [grad_output],
        bias=None,
        group_list=group_list,
        split_item=_GMM_SPLIT_ITEM_DWEIGHT,
        group_type=_GMM_GROUP_BY_K_AXIS,
        group_list_type=_GMM_GROUP_LIST_IS_EXPERT_SIZES,
    )[0]
    return grad_input, None, None, *grad_weight_T
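As a reading aid, the two grouped products in this backward pass appear to correspond to the standard per-expert gradients of out_i = x_i @ weight_i.T: the input gradient is grad_out_i @ weight_i, and the gradient with respect to weight_i.T is x_i.T @ grad_out_i. The plain-Python sketch below computes those quantities group by group. `grouped_backward` and its helpers are hypothetical names, and the mapping onto the `npu_grouped_matmul` calls is an interpretation of the snippet, not something the PR states.

```python
def matmul(a, b):
    """Multiply two matrices stored as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def transpose(m):
    """Transpose a matrix stored as a list of rows."""
    return [list(col) for col in zip(*m)]


def grouped_backward(x, grad_output, m_splits, weights):
    """Return (grad_input, [grad w.r.t. weight_i.T]) for out_i = x_i @ weight_i.T."""
    grad_input, grad_weight_T, offset = [], [], 0
    for rows, weight in zip(m_splits, weights):  # weight has shape [out, in]
        if rows == 0:
            # An expert with no tokens contributes a zero gradient of shape [in, out].
            grad_weight_T.append([[0] * len(weight) for _ in weight[0]])
            continue
        x_block = x[offset:offset + rows]
        g_block = grad_output[offset:offset + rows]
        grad_input.extend(matmul(g_block, weight))                  # [rows, in]
        grad_weight_T.append(matmul(transpose(x_block), g_block))   # [in, out]
        offset += rows
    return grad_input, grad_weight_T


x = [[1, 2], [3, 4], [5, 6]]
weights = [[[1, 0], [0, 1]], [[2, 0], [0, 2]]]
grad_output = [[1, 1], [1, 1], [1, 1]]  # gradient of loss = sum(out)
grad_input, grad_weight_T = grouped_backward(x, grad_output, [1, 2], weights)
```

Each entry of `grad_weight_T` has shape [in, out], which matches the transposed weights passed as trailing arguments to `apply` in the forward call.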
for param_name in ('weight', 'bias'):
    local_name = f'{param_name}{i}'
    param = getattr(self, local_name, None)
    if param is None:
        continue
    sharded_tensor = make_sharded_tensor_for_checkpoint(
        param,
        f'{key_prefix}{param_name}',
        prepend_offsets=new_sharded_offsets,
    )
    sharded_state_dict[f'{prefix}{local_name}'] = self._set_expert_replica_id(sharded_tensor)
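The loop above depends on `new_sharded_offsets`, whose construction is not visible in this excerpt. The sketch below shows one plausible way a local expert index could be mapped to a global one and appended as an offset entry. It assumes each expert-parallel rank owns a contiguous block of experts and that an offset entry is an (axis, index, total) tuple; `global_expert_index` and `expert_offsets` are hypothetical helpers, not functions from the PR.

```python
def global_expert_index(local_idx, ep_rank, num_local_experts):
    """Assumed layout: rank r owns experts [r * n_local, (r + 1) * n_local)."""
    return ep_rank * num_local_experts + local_idx


def expert_offsets(sharded_offsets, local_idx, ep_rank, ep_size, num_local_experts):
    """Append an (axis, global expert index, total experts) entry to existing offsets."""
    num_global_experts = ep_size * num_local_experts
    axis = len(sharded_offsets)  # the expert axis is prepended after any existing ones
    entry = (axis, global_expert_index(local_idx, ep_rank, num_local_experts), num_global_experts)
    return (*sharded_offsets, entry)


# EP=2 with 4 local experts per rank: local expert 2 on rank 1 is global expert 6 of 8.
offsets = expert_offsets((), local_idx=2, ep_rank=1, ep_size=2, num_local_experts=4)
nested = expert_offsets(((0, 1, 2),), local_idx=0, ep_rank=0, ep_size=2, num_local_experts=4)
```

Under this layout, the same local parameter name (for example `weight2`) on different expert-parallel ranks resolves to different global expert slots, which is what lets a checkpoint written with one rank assignment be loaded consistently.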
Summary
This PR fixes LoRA training and checkpoint save/load for MindSpeed-patched MoE grouped linear layers on NPU.
Problem
On the NPU path, MindSpeed patches Megatron grouped expert linear layers. After patching, the layer may not reliably keep the is_expert attribute.
The original LoRA logic uses is_expert to distinguish MoE expert parameters from normal tensor-parallel parameters. Because of that, a MindSpeed grouped expert layer can be treated as a normal grouped linear layer.
This leads to two issues:
- all-linear LoRA cannot correctly handle NPU MoE grouped expert linear layers.
Changes
- NpuGroupedLoraLinear for NPU grouped LoRA adapters.
- sharded_state_dict for NPU grouped LoRA weights.
Verification
- TP=2, PP=2, EP=2, ETP=1
- target_modules=all-linear
- TP=2, PP=2, EP=2, ETP=1
- target_modules=all-linear