From a1163a308f92475f21009af97ed625eaf13c9a34 Mon Sep 17 00:00:00 2001 From: addsubmuldiv Date: Sat, 6 Jun 2026 03:18:56 +0800 Subject: [PATCH 1/3] Support NPU grouped LoRA adapters --- src/mcore_bridge/tuners/lora.py | 85 ++++++++------- src/mcore_bridge/tuners/npu_lora.py | 161 ++++++++++++++++++++++++++++ 2 files changed, 210 insertions(+), 36 deletions(-) create mode 100644 src/mcore_bridge/tuners/npu_lora.py diff --git a/src/mcore_bridge/tuners/lora.py b/src/mcore_bridge/tuners/lora.py index 4a050cf..b43aa98 100644 --- a/src/mcore_bridge/tuners/lora.py +++ b/src/mcore_bridge/tuners/lora.py @@ -28,6 +28,7 @@ from mcore_bridge.utils import get_current_device +from .npu_lora import NpuGroupedLoraLinear, is_expert_layer from .utils import tuners_sharded_state_dict mcore_016 = version.parse(megatron.core.__version__) >= version.parse('0.16.0rc0') @@ -119,7 +120,7 @@ def __init__( self.is_grouped = isinstance(base_layer, TEGroupedLinear) self.fan_in_fan_out = fan_in_fan_out self._active_adapter = adapter_name - self.is_expert = getattr(base_layer, 'is_expert', False) + self.is_expert = is_expert_layer(base_layer) self.sequence_parallel = getattr(base_layer, 'sequence_parallel', False) if self.is_expert: self.tp_size = get_expert_tensor_parallel_world_size() @@ -181,21 +182,27 @@ def update_layer(self, adapter_name, r, *, lora_alpha, **kwargs): elif self.is_parallel_a: in_features = self.in_features * self.tp_size if self.is_grouped: - lora_a = TERowParallelGroupedLinear( - num_gemms=self.base_layer.num_gemms, - input_size=in_features, - output_size=r, - bias=False, - **kwargs, - ) - lora_b = TEGroupedLinear( - num_gemms=self.base_layer.num_gemms, - input_size=r, - output_size=self.out_features, - bias=lora_bias, - parallel_mode=None, - **kwargs, - ) + if is_torch_npu_available(): + lora_a = NpuGroupedLoraLinear( + self.base_layer.num_gemms, in_features, r, config=self.config, bias=False) + lora_b = NpuGroupedLoraLinear( + self.base_layer.num_gemms, r, self.out_features, config=self.config, bias=lora_bias) + else: + lora_a = TERowParallelGroupedLinear( + num_gemms=self.base_layer.num_gemms, + input_size=in_features, + output_size=r, + bias=False, + **kwargs, + ) + lora_b = TEGroupedLinear( + num_gemms=self.base_layer.num_gemms, + input_size=r, + output_size=self.out_features, + bias=lora_bias, + parallel_mode=None, + **kwargs, + ) else: lora_a = TERowParallelLinear( input_size=in_features, @@ -212,20 +219,26 @@ def update_layer(self, adapter_name, r, *, lora_alpha, **kwargs): else: out_features = self.out_features * self.tp_size if self.is_grouped: - lora_a = TEGroupedLinear( - num_gemms=self.base_layer.num_gemms, - input_size=self.in_features, - output_size=r, - bias=lora_bias, - parallel_mode=None, - **kwargs) - lora_b = TEColumnParallelGroupedLinear( - num_gemms=self.base_layer.num_gemms, - input_size=r, - output_size=out_features, - bias=lora_bias, - **kwargs, - ) + if is_torch_npu_available(): + lora_a = NpuGroupedLoraLinear( + self.base_layer.num_gemms, self.in_features, r, config=self.config, bias=lora_bias) + lora_b = NpuGroupedLoraLinear( + self.base_layer.num_gemms, r, out_features, config=self.config, bias=lora_bias) + else: + lora_a = TEGroupedLinear( + num_gemms=self.base_layer.num_gemms, + input_size=self.in_features, + output_size=r, + bias=lora_bias, + parallel_mode=None, + **kwargs) + lora_b = TEColumnParallelGroupedLinear( + num_gemms=self.base_layer.num_gemms, + input_size=r, + output_size=out_features, + bias=lora_bias, + **kwargs, + ) else: lora_a = _build_local_te_linear(self.in_features, r, lora_bias, **kwargs) lora_b = TEColumnParallelLinear( @@ -279,11 +292,11 @@ def reset_lora_parameters(self, adapter_name, init_lora_weights): if adapter_name in self.lora_A.keys(): lora_a = self.lora_A[adapter_name] lora_b = self.lora_B[adapter_name] - if isinstance(lora_a, TEGroupedLinear): + if isinstance(lora_a, (TEGroupedLinear, NpuGroupedLoraLinear)): weights_a = [getattr(lora_a, f'weight{i}') for i in range(lora_a.num_gemms)] else: weights_a = [lora_a.weight] - if isinstance(lora_b, TEGroupedLinear): + if isinstance(lora_b, (TEGroupedLinear, NpuGroupedLoraLinear)): weights_b = [getattr(lora_b, f'weight{i}') for i in range(lora_b.num_gemms)] else: weights_b = [lora_b.weight] @@ -394,15 +407,15 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any): lora_B = self.lora_B[active_adapter] dropout = self.lora_dropout[active_adapter] scaling = self.scaling[active_adapter] - dtype = lora_A.weight0.dtype if isinstance(lora_A, TEGroupedLinear) else lora_A.weight.dtype + dtype = lora_A.weight0.dtype if isinstance(lora_A, (TEGroupedLinear, NpuGroupedLoraLinear)) else lora_A.weight.dtype x = x.to(dtype) - lora_result = lora_A(dropout(x), *args, **kwargs) if isinstance(lora_A, TEGroupedLinear) else lora_A( - dropout(x)) + lora_result = lora_A(dropout(x), *args, **kwargs) if isinstance( + lora_A, (TEGroupedLinear, NpuGroupedLoraLinear)) else lora_A(dropout(x)) if isinstance(lora_result, tuple): lora_result = lora_result[0] lora_result = lora_B(lora_result, *args, **kwargs) if isinstance( - lora_B, TEGroupedLinear) else lora_B(lora_result) + lora_B, (TEGroupedLinear, NpuGroupedLoraLinear)) else lora_B(lora_result) if isinstance(lora_result, tuple): lora_result = lora_result[0] lora_result = lora_result * scaling diff --git a/src/mcore_bridge/tuners/npu_lora.py b/src/mcore_bridge/tuners/npu_lora.py new file mode 100644 index 0000000..1f572ea --- /dev/null +++ b/src/mcore_bridge/tuners/npu_lora.py @@ -0,0 +1,161 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +import torch +import torch.nn as nn +import torch.nn.functional as F +from megatron.core.extensions.transformer_engine import TEGroupedLinear +from transformers.utils import is_torch_npu_available + +from mcore_bridge.utils import get_current_device + +_GMM_SPLIT_ITEM_FORWARD_OR_DINPUT = 2 +_GMM_SPLIT_ITEM_DWEIGHT = 3 +_GMM_GROUP_BY_M_AXIS = 0 +_GMM_GROUP_BY_K_AXIS = 2 +_GMM_GROUP_LIST_IS_EXPERT_SIZES = 1 + + +def _is_mindspeed_grouped_linear(base_layer) -> bool: + if not (is_torch_npu_available() and isinstance(base_layer, TEGroupedLinear)): + return False + return type(base_layer).__module__.startswith('mindspeed.') + + +def _has_moe_local_expert_grouping(base_layer) -> bool: + config = getattr(base_layer, 'config', None) + num_moe_experts = getattr(config, 'num_moe_experts', None) + if num_moe_experts is None: + return False + ep_size = getattr(config, 'expert_model_parallel_size', 1) + if not isinstance(ep_size, int) or ep_size <= 0: + return False + if num_moe_experts % ep_size != 0: + return False + return getattr(base_layer, 'num_gemms', None) == num_moe_experts // ep_size + + +def is_expert_layer(base_layer) -> bool: + is_expert = getattr(base_layer, 'is_expert', None) + if is_expert is not None: + return bool(is_expert) + if _is_mindspeed_grouped_linear(base_layer): + if getattr(base_layer, 'explicit_expert_comm', False): + return True + has_moe_local_expert_grouping = _has_moe_local_expert_grouping(base_layer) + if getattr(base_layer, 'expert_parallel', False): + return has_moe_local_expert_grouping + # MindSpeedTEGroupedLinear receives is_expert but does not keep it as + # an attribute. When EP/ETP does not trigger explicit expert comm, + # fall back to the TEGroupedMLP invariant: one grouped slot per local + # expert. + return has_moe_local_expert_grouping + return False + + +class NpuGroupedLoraLinear(nn.Module): + """Generic grouped linear for low-rank NPU LoRA adapters.""" + + def __init__(self, num_gemms: int, input_size: int, output_size: int, *, config, bias: bool): + super().__init__() + self.num_gemms = num_gemms + self.input_size = input_size + self.output_size = output_size + self.use_bias = bias + self.parallel_mode = None + device = torch.device('cpu') if config.use_cpu_initialization else get_current_device() + dtype = config.params_dtype + for i in range(num_gemms): + self.register_parameter( + f'weight{i}', + nn.Parameter(torch.empty(output_size, input_size, device=device, dtype=dtype)), + ) + if bias: + self.register_parameter( + f'bias{i}', + nn.Parameter(torch.empty(output_size, device=device, dtype=dtype)), + ) + + @property + def weight(self): + return self.weight0 + + def _fallback_forward(self, x, m_splits): + if isinstance(m_splits, torch.Tensor): + m_splits = m_splits.tolist() + outputs = [] + offset = 0 + for i, split_size in enumerate(m_splits): + split_size = int(split_size) + x_i = x[offset:offset + split_size] + offset += split_size + weight = getattr(self, f'weight{i}') + bias = getattr(self, f'bias{i}', None) + outputs.append(F.linear(x_i, weight, bias)) + if offset != x.shape[0]: + raise RuntimeError(f'Grouped LoRA token split mismatch: got {offset}, expected {x.shape[0]}') + return torch.cat(outputs, dim=0) if outputs else x.new_empty((*x.shape[:-1], self.output_size)), None + + def _can_use_grouped_matmul(self, x): + # PEFT's lora_bias=True adds a bias to LoRA-B and is not the mainstream + # Swift path. Keep that uncommon adapter shape on the generic PyTorch + # path until it is explicitly needed and verified. + return is_torch_npu_available() and x.device.type == 'npu' and x.dim() == 2 and not self.use_bias + + def forward(self, x, m_splits): + if not self._can_use_grouped_matmul(x): + return self._fallback_forward(x, m_splits) + weights = [getattr(self, f'weight{i}') for i in range(self.num_gemms)] + return _NpuGroupedLoraLinearGMM.apply(x, m_splits, weights, *[weight.T for weight in weights]), None + + +class _NpuGroupedLoraLinearGMM(torch.autograd.Function): + @staticmethod + def forward(ctx, input_tensor, m_splits, weights, *weight_input_T): + import torch_npu + + if isinstance(m_splits, torch.Tensor): + group_list = m_splits + if group_list.device.type == 'cpu': + group_list = group_list.to(device=input_tensor.device, dtype=torch.int64) + else: + group_list = group_list.to(dtype=torch.int64) + else: + group_list = torch.tensor(m_splits, device=input_tensor.device, dtype=torch.int64) + output = torch_npu.npu_grouped_matmul( + [input_tensor], + weight_input_T, + bias=None, + group_list=group_list, + split_item=_GMM_SPLIT_ITEM_FORWARD_OR_DINPUT, + group_type=_GMM_GROUP_BY_M_AXIS, + group_list_type=_GMM_GROUP_LIST_IS_EXPERT_SIZES, + )[0] + ctx.group_list = group_list + ctx.save_for_backward(input_tensor, *weights) + return output + + @staticmethod + def backward(ctx, grad_output): + import torch_npu + + input_tensor = ctx.saved_tensors[0] + weights = ctx.saved_tensors[1:] + group_list = ctx.group_list + grad_input = torch_npu.npu_grouped_matmul( + [grad_output], + weights, + bias=None, + group_list=group_list, + split_item=_GMM_SPLIT_ITEM_FORWARD_OR_DINPUT, + group_type=_GMM_GROUP_BY_M_AXIS, + group_list_type=_GMM_GROUP_LIST_IS_EXPERT_SIZES, + )[0] + grad_weight_T = torch_npu.npu_grouped_matmul( + [input_tensor.T], + [grad_output], + bias=None, + group_list=group_list, + split_item=_GMM_SPLIT_ITEM_DWEIGHT, + group_type=_GMM_GROUP_BY_K_AXIS, + group_list_type=_GMM_GROUP_LIST_IS_EXPERT_SIZES, + )[0] + return grad_input, None, None, *grad_weight_T From 2cc4abf3aa35c6f6f69f61c4c6b0d4eaaafc6235 Mon Sep 17 00:00:00 2001 From: addsubmuldiv Date: Tue, 9 Jun 2026 20:31:38 +0800 Subject: [PATCH 2/3] Fix NPU grouped LoRA checkpoint sharding --- src/mcore_bridge/tuners/lora.py | 12 ++++-- src/mcore_bridge/tuners/npu_lora.py | 60 ++++++++++++++++++++++++++++- 2 files changed, 67 insertions(+), 5 deletions(-) diff --git a/src/mcore_bridge/tuners/lora.py b/src/mcore_bridge/tuners/lora.py index b43aa98..b1a27cf 100644 --- a/src/mcore_bridge/tuners/lora.py +++ b/src/mcore_bridge/tuners/lora.py @@ -184,9 +184,11 @@ def update_layer(self, adapter_name, r, *, lora_alpha, **kwargs): if self.is_grouped: if is_torch_npu_available(): lora_a = NpuGroupedLoraLinear( - self.base_layer.num_gemms, in_features, r, config=self.config, bias=False) + self.base_layer.num_gemms, in_features, r, config=self.config, bias=False, + is_expert=self.is_expert) lora_b = NpuGroupedLoraLinear( - self.base_layer.num_gemms, r, self.out_features, config=self.config, bias=lora_bias) + self.base_layer.num_gemms, r, self.out_features, config=self.config, bias=lora_bias, + is_expert=self.is_expert) else: lora_a = TERowParallelGroupedLinear( num_gemms=self.base_layer.num_gemms, @@ -221,9 +223,11 @@ def update_layer(self, adapter_name, r, *, lora_alpha, **kwargs): if self.is_grouped: if is_torch_npu_available(): lora_a = NpuGroupedLoraLinear( - self.base_layer.num_gemms, self.in_features, r, config=self.config, bias=lora_bias) + self.base_layer.num_gemms, self.in_features, r, config=self.config, bias=lora_bias, + is_expert=self.is_expert) lora_b = NpuGroupedLoraLinear( - self.base_layer.num_gemms, r, out_features, config=self.config, bias=lora_bias) + self.base_layer.num_gemms, r, out_features, config=self.config, bias=lora_bias, + is_expert=self.is_expert) else: lora_a = TEGroupedLinear( num_gemms=self.base_layer.num_gemms, diff --git a/src/mcore_bridge/tuners/npu_lora.py b/src/mcore_bridge/tuners/npu_lora.py index 1f572ea..adf63fc 100644 --- a/src/mcore_bridge/tuners/npu_lora.py +++ b/src/mcore_bridge/tuners/npu_lora.py @@ -2,8 +2,12 @@ import torch import torch.nn as nn import torch.nn.functional as F +from megatron.core import parallel_state +from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint +from megatron.core.utils import make_sharded_tensor_for_checkpoint from megatron.core.extensions.transformer_engine import TEGroupedLinear from transformers.utils import is_torch_npu_available +from typing import Optional, Tuple from mcore_bridge.utils import get_current_device @@ -54,12 +58,14 @@ def is_expert_layer(base_layer) -> bool: class NpuGroupedLoraLinear(nn.Module): """Generic grouped linear for low-rank NPU LoRA adapters.""" - def __init__(self, num_gemms: int, input_size: int, output_size: int, *, config, bias: bool): + def __init__(self, num_gemms: int, input_size: int, output_size: int, *, config, bias: bool, + is_expert: bool = False): super().__init__() self.num_gemms = num_gemms self.input_size = input_size self.output_size = output_size self.use_bias = bias + self.is_expert = is_expert self.parallel_mode = None device = torch.device('cpu') if config.use_cpu_initialization else get_current_device() dtype = config.params_dtype @@ -78,6 +84,58 @@ def __init__(self, num_gemms: int, input_size: int, output_size: int, *, config, def weight(self): return self.weight0 + def _set_expert_replica_id(self, sharded_tensor): + replica_id = sharded_tensor.replica_id + assert len(replica_id) == 3, f'Expected replica_id to be in (PP, TP, DP) format, got: {replica_id}' + if getattr(sharded_tensor, 'is_data_parallel_fully_shard', False): + edp_replica_id = 0 + else: + edp_replica_id = parallel_state.get_expert_data_parallel_rank() + sharded_tensor.replica_id = (*replica_id[:2], edp_replica_id) + return sharded_tensor + + def sharded_state_dict( + self, + prefix: str = '', + sharded_offsets: Tuple[Tuple[int, int, int]] = (), + metadata: Optional[dict] = None, + ): + if not self.is_expert: + return make_sharded_tensors_for_checkpoint( + self.state_dict(prefix='', keep_vars=True), + prefix, + sharded_offsets=sharded_offsets, + ) + + singleton_local_shards = (metadata or {}).get('singleton_local_shards', False) + num_global_experts = parallel_state.get_expert_model_parallel_world_size() * self.num_gemms + local_expert_indices_offset = parallel_state.get_expert_model_parallel_rank() * self.num_gemms + ep_axis = len(sharded_offsets) + sharded_state_dict = {} + + for i in range(self.num_gemms): + global_expert_idx = local_expert_indices_offset + i + if singleton_local_shards: + key_prefix = f'{global_expert_idx}.{prefix}' + new_sharded_offsets = sharded_offsets + else: + key_prefix = prefix + new_sharded_offsets = (*sharded_offsets, (ep_axis, global_expert_idx, num_global_experts)) + + for param_name in ('weight', 'bias'): + local_name = f'{param_name}{i}' + param = getattr(self, local_name, None) + if param is None: + continue + sharded_tensor = make_sharded_tensor_for_checkpoint( + param, + f'{key_prefix}{param_name}', + prepend_offsets=new_sharded_offsets, + ) + sharded_state_dict[f'{prefix}{local_name}'] = self._set_expert_replica_id(sharded_tensor) + + return sharded_state_dict + def _fallback_forward(self, x, m_splits): if isinstance(m_splits, torch.Tensor): m_splits = m_splits.tolist() From 3c75f9eca1c019c7eec0bd4be48135a610761386 Mon Sep 17 00:00:00 2001 From: addsubmuldiv Date: Wed, 10 Jun 2026 02:08:42 +0800 Subject: [PATCH 3/3] fix lint --- src/mcore_bridge/tuners/lora.py | 27 ++++++++++++++++++++++----- src/mcore_bridge/tuners/npu_lora.py | 11 +++++++++-- 2 files changed, 31 insertions(+), 7 deletions(-) diff --git a/src/mcore_bridge/tuners/lora.py b/src/mcore_bridge/tuners/lora.py index b1a27cf..8b15093 100644 --- a/src/mcore_bridge/tuners/lora.py +++ b/src/mcore_bridge/tuners/lora.py @@ -184,10 +184,18 @@ def update_layer(self, adapter_name, r, *, lora_alpha, **kwargs): if self.is_grouped: if is_torch_npu_available(): lora_a = NpuGroupedLoraLinear( - self.base_layer.num_gemms, in_features, r, config=self.config, bias=False, + self.base_layer.num_gemms, + in_features, + r, + config=self.config, + bias=False, is_expert=self.is_expert) lora_b = NpuGroupedLoraLinear( - self.base_layer.num_gemms, r, self.out_features, config=self.config, bias=lora_bias, + self.base_layer.num_gemms, + r, + self.out_features, + config=self.config, + bias=lora_bias, is_expert=self.is_expert) else: lora_a = TERowParallelGroupedLinear( @@ -223,10 +231,18 @@ def update_layer(self, adapter_name, r, *, lora_alpha, **kwargs): if self.is_grouped: if is_torch_npu_available(): lora_a = NpuGroupedLoraLinear( - self.base_layer.num_gemms, self.in_features, r, config=self.config, bias=lora_bias, + self.base_layer.num_gemms, + self.in_features, + r, + config=self.config, + bias=lora_bias, is_expert=self.is_expert) lora_b = NpuGroupedLoraLinear( - self.base_layer.num_gemms, r, out_features, config=self.config, bias=lora_bias, + self.base_layer.num_gemms, + r, + out_features, + config=self.config, + bias=lora_bias, is_expert=self.is_expert) else: lora_a = TEGroupedLinear( @@ -411,7 +427,8 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any): lora_B = self.lora_B[active_adapter] dropout = self.lora_dropout[active_adapter] scaling = self.scaling[active_adapter] - dtype = lora_A.weight0.dtype if isinstance(lora_A, (TEGroupedLinear, NpuGroupedLoraLinear)) else lora_A.weight.dtype + dtype = lora_A.weight0.dtype if isinstance(lora_A, (TEGroupedLinear, + NpuGroupedLoraLinear)) else lora_A.weight.dtype x = x.to(dtype) lora_result = lora_A(dropout(x), *args, **kwargs) if isinstance( diff --git a/src/mcore_bridge/tuners/npu_lora.py b/src/mcore_bridge/tuners/npu_lora.py index adf63fc..099b6a8 100644 --- a/src/mcore_bridge/tuners/npu_lora.py +++ b/src/mcore_bridge/tuners/npu_lora.py @@ -3,9 +3,9 @@ import torch.nn as nn import torch.nn.functional as F from megatron.core import parallel_state +from megatron.core.extensions.transformer_engine import TEGroupedLinear from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint from megatron.core.utils import make_sharded_tensor_for_checkpoint -from megatron.core.extensions.transformer_engine import TEGroupedLinear from transformers.utils import is_torch_npu_available from typing import Optional, Tuple @@ -58,7 +58,13 @@ def is_expert_layer(base_layer) -> bool: class NpuGroupedLoraLinear(nn.Module): """Generic grouped linear for low-rank NPU LoRA adapters.""" - def __init__(self, num_gemms: int, input_size: int, output_size: int, *, config, bias: bool, + def __init__(self, + num_gemms: int, + input_size: int, + output_size: int, + *, + config, + bias: bool, is_expert: bool = False): super().__init__() self.num_gemms = num_gemms @@ -166,6 +172,7 @@ def forward(self, x, m_splits): class _NpuGroupedLoraLinearGMM(torch.autograd.Function): + @staticmethod def forward(ctx, input_tensor, m_splits, weights, *weight_input_T): import torch_npu