Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 53 additions & 36 deletions src/mcore_bridge/tuners/lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

from mcore_bridge.utils import get_current_device

from .npu_lora import NpuGroupedLoraLinear, is_expert_layer
from .utils import tuners_sharded_state_dict

mcore_016 = version.parse(megatron.core.__version__) >= version.parse('0.16.0rc0')
Expand Down Expand Up @@ -119,7 +120,7 @@ def __init__(
self.is_grouped = isinstance(base_layer, TEGroupedLinear)
self.fan_in_fan_out = fan_in_fan_out
self._active_adapter = adapter_name
self.is_expert = getattr(base_layer, 'is_expert', False)
self.is_expert = is_expert_layer(base_layer)
self.sequence_parallel = getattr(base_layer, 'sequence_parallel', False)
if self.is_expert:
self.tp_size = get_expert_tensor_parallel_world_size()
Expand Down Expand Up @@ -181,21 +182,29 @@ def update_layer(self, adapter_name, r, *, lora_alpha, **kwargs):
elif self.is_parallel_a:
in_features = self.in_features * self.tp_size
if self.is_grouped:
lora_a = TERowParallelGroupedLinear(
num_gemms=self.base_layer.num_gemms,
input_size=in_features,
output_size=r,
bias=False,
**kwargs,
)
lora_b = TEGroupedLinear(
num_gemms=self.base_layer.num_gemms,
input_size=r,
output_size=self.out_features,
bias=lora_bias,
parallel_mode=None,
**kwargs,
)
if is_torch_npu_available():
lora_a = NpuGroupedLoraLinear(
self.base_layer.num_gemms, in_features, r, config=self.config, bias=False,
is_expert=self.is_expert)
lora_b = NpuGroupedLoraLinear(
self.base_layer.num_gemms, r, self.out_features, config=self.config, bias=lora_bias,
is_expert=self.is_expert)
else:
lora_a = TERowParallelGroupedLinear(
num_gemms=self.base_layer.num_gemms,
input_size=in_features,
output_size=r,
bias=False,
**kwargs,
)
lora_b = TEGroupedLinear(
num_gemms=self.base_layer.num_gemms,
input_size=r,
output_size=self.out_features,
bias=lora_bias,
parallel_mode=None,
**kwargs,
)
else:
lora_a = TERowParallelLinear(
input_size=in_features,
Expand All @@ -212,20 +221,28 @@ def update_layer(self, adapter_name, r, *, lora_alpha, **kwargs):
else:
out_features = self.out_features * self.tp_size
if self.is_grouped:
lora_a = TEGroupedLinear(
num_gemms=self.base_layer.num_gemms,
input_size=self.in_features,
output_size=r,
bias=lora_bias,
parallel_mode=None,
**kwargs)
lora_b = TEColumnParallelGroupedLinear(
num_gemms=self.base_layer.num_gemms,
input_size=r,
output_size=out_features,
bias=lora_bias,
**kwargs,
)
if is_torch_npu_available():
lora_a = NpuGroupedLoraLinear(
self.base_layer.num_gemms, self.in_features, r, config=self.config, bias=lora_bias,
is_expert=self.is_expert)
lora_b = NpuGroupedLoraLinear(
self.base_layer.num_gemms, r, out_features, config=self.config, bias=lora_bias,
is_expert=self.is_expert)
else:
lora_a = TEGroupedLinear(
num_gemms=self.base_layer.num_gemms,
input_size=self.in_features,
output_size=r,
bias=lora_bias,
parallel_mode=None,
**kwargs)
lora_b = TEColumnParallelGroupedLinear(
num_gemms=self.base_layer.num_gemms,
input_size=r,
output_size=out_features,
bias=lora_bias,
**kwargs,
)
else:
lora_a = _build_local_te_linear(self.in_features, r, lora_bias, **kwargs)
lora_b = TEColumnParallelLinear(
Expand Down Expand Up @@ -279,11 +296,11 @@ def reset_lora_parameters(self, adapter_name, init_lora_weights):
if adapter_name in self.lora_A.keys():
lora_a = self.lora_A[adapter_name]
lora_b = self.lora_B[adapter_name]
if isinstance(lora_a, TEGroupedLinear):
if isinstance(lora_a, (TEGroupedLinear, NpuGroupedLoraLinear)):
weights_a = [getattr(lora_a, f'weight{i}') for i in range(lora_a.num_gemms)]
else:
weights_a = [lora_a.weight]
if isinstance(lora_b, TEGroupedLinear):
if isinstance(lora_b, (TEGroupedLinear, NpuGroupedLoraLinear)):
weights_b = [getattr(lora_b, f'weight{i}') for i in range(lora_b.num_gemms)]
else:
weights_b = [lora_b.weight]
Expand Down Expand Up @@ -394,15 +411,15 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any):
lora_B = self.lora_B[active_adapter]
dropout = self.lora_dropout[active_adapter]
scaling = self.scaling[active_adapter]
dtype = lora_A.weight0.dtype if isinstance(lora_A, TEGroupedLinear) else lora_A.weight.dtype
dtype = lora_A.weight0.dtype if isinstance(lora_A, (TEGroupedLinear, NpuGroupedLoraLinear)) else lora_A.weight.dtype
x = x.to(dtype)

lora_result = lora_A(dropout(x), *args, **kwargs) if isinstance(lora_A, TEGroupedLinear) else lora_A(
dropout(x))
lora_result = lora_A(dropout(x), *args, **kwargs) if isinstance(
lora_A, (TEGroupedLinear, NpuGroupedLoraLinear)) else lora_A(dropout(x))
if isinstance(lora_result, tuple):
lora_result = lora_result[0]
lora_result = lora_B(lora_result, *args, **kwargs) if isinstance(
lora_B, TEGroupedLinear) else lora_B(lora_result)
lora_B, (TEGroupedLinear, NpuGroupedLoraLinear)) else lora_B(lora_result)
if isinstance(lora_result, tuple):
lora_result = lora_result[0]
lora_result = lora_result * scaling
Expand Down
219 changes: 219 additions & 0 deletions src/mcore_bridge/tuners/npu_lora.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,219 @@
# Copyright (c) ModelScope Contributors. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F
from megatron.core import parallel_state
from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint
from megatron.core.utils import make_sharded_tensor_for_checkpoint
from megatron.core.extensions.transformer_engine import TEGroupedLinear
from transformers.utils import is_torch_npu_available
from typing import Optional, Tuple

from mcore_bridge.utils import get_current_device

_GMM_SPLIT_ITEM_FORWARD_OR_DINPUT = 2
_GMM_SPLIT_ITEM_DWEIGHT = 3
_GMM_GROUP_BY_M_AXIS = 0
_GMM_GROUP_BY_K_AXIS = 2
_GMM_GROUP_LIST_IS_EXPERT_SIZES = 1


def _is_mindspeed_grouped_linear(base_layer) -> bool:
if not (is_torch_npu_available() and isinstance(base_layer, TEGroupedLinear)):
return False
return type(base_layer).__module__.startswith('mindspeed.')


def _has_moe_local_expert_grouping(base_layer) -> bool:
config = getattr(base_layer, 'config', None)
num_moe_experts = getattr(config, 'num_moe_experts', None)
if num_moe_experts is None:
return False
ep_size = getattr(config, 'expert_model_parallel_size', 1)
if not isinstance(ep_size, int) or ep_size <= 0:
return False
if num_moe_experts % ep_size != 0:
return False
return getattr(base_layer, 'num_gemms', None) == num_moe_experts // ep_size


def is_expert_layer(base_layer) -> bool:
is_expert = getattr(base_layer, 'is_expert', None)
if is_expert is not None:
return bool(is_expert)
if _is_mindspeed_grouped_linear(base_layer):
if getattr(base_layer, 'explicit_expert_comm', False):
return True
has_moe_local_expert_grouping = _has_moe_local_expert_grouping(base_layer)
if getattr(base_layer, 'expert_parallel', False):
return has_moe_local_expert_grouping
# MindSpeedTEGroupedLinear receives is_expert but does not keep it as
# an attribute. When EP/ETP does not trigger explicit expert comm,
# fall back to the TEGroupedMLP invariant: one grouped slot per local
# expert.
return has_moe_local_expert_grouping
return False


class NpuGroupedLoraLinear(nn.Module):
"""Generic grouped linear for low-rank NPU LoRA adapters."""

def __init__(self, num_gemms: int, input_size: int, output_size: int, *, config, bias: bool,
is_expert: bool = False):
super().__init__()
self.num_gemms = num_gemms
self.input_size = input_size
self.output_size = output_size
self.use_bias = bias
self.is_expert = is_expert
self.parallel_mode = None
device = torch.device('cpu') if config.use_cpu_initialization else get_current_device()
dtype = config.params_dtype
for i in range(num_gemms):
self.register_parameter(
f'weight{i}',
nn.Parameter(torch.empty(output_size, input_size, device=device, dtype=dtype)),
)
if bias:
self.register_parameter(
f'bias{i}',
nn.Parameter(torch.empty(output_size, device=device, dtype=dtype)),
)

@property
def weight(self):
return self.weight0

def _set_expert_replica_id(self, sharded_tensor):
replica_id = sharded_tensor.replica_id
assert len(replica_id) == 3, f'Expected replica_id to be in (PP, TP, DP) format, got: {replica_id}'
if getattr(sharded_tensor, 'is_data_parallel_fully_shard', False):
edp_replica_id = 0
else:
edp_replica_id = parallel_state.get_expert_data_parallel_rank()
sharded_tensor.replica_id = (*replica_id[:2], edp_replica_id)
return sharded_tensor

def sharded_state_dict(
self,
prefix: str = '',
sharded_offsets: Tuple[Tuple[int, int, int]] = (),
metadata: Optional[dict] = None,
):
if not self.is_expert:
return make_sharded_tensors_for_checkpoint(
self.state_dict(prefix='', keep_vars=True),
prefix,
sharded_offsets=sharded_offsets,
)

singleton_local_shards = (metadata or {}).get('singleton_local_shards', False)
num_global_experts = parallel_state.get_expert_model_parallel_world_size() * self.num_gemms
local_expert_indices_offset = parallel_state.get_expert_model_parallel_rank() * self.num_gemms
ep_axis = len(sharded_offsets)
sharded_state_dict = {}

for i in range(self.num_gemms):
global_expert_idx = local_expert_indices_offset + i
if singleton_local_shards:
key_prefix = f'{global_expert_idx}.{prefix}'
new_sharded_offsets = sharded_offsets
else:
key_prefix = prefix
new_sharded_offsets = (*sharded_offsets, (ep_axis, global_expert_idx, num_global_experts))

for param_name in ('weight', 'bias'):
local_name = f'{param_name}{i}'
param = getattr(self, local_name, None)
if param is None:
continue
sharded_tensor = make_sharded_tensor_for_checkpoint(
param,
f'{key_prefix}{param_name}',
prepend_offsets=new_sharded_offsets,
)
sharded_state_dict[f'{prefix}{local_name}'] = self._set_expert_replica_id(sharded_tensor)

Comment on lines +131 to +142
return sharded_state_dict

def _fallback_forward(self, x, m_splits):
if isinstance(m_splits, torch.Tensor):
m_splits = m_splits.tolist()
outputs = []
offset = 0
Comment on lines +145 to +149

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

In _fallback_forward, there is no validation to ensure that the length of m_splits matches self.num_gemms. If len(m_splits) is greater than self.num_gemms, the loop will attempt to access non-existent attributes (e.g., weight{i}), resulting in a cryptic AttributeError.

Add a defensive check at the beginning of _fallback_forward to verify that len(m_splits) matches self.num_gemms.

    def _fallback_forward(self, x, m_splits):
        if len(m_splits) != self.num_gemms:
            raise RuntimeError(
                f"Expected m_splits length to be equal to num_gemms ({self.num_gemms}), got {len(m_splits)}"
            )
        if isinstance(m_splits, torch.Tensor):
            m_splits = m_splits.tolist()
        outputs = []
        offset = 0

for i, split_size in enumerate(m_splits):
split_size = int(split_size)
x_i = x[offset:offset + split_size]
offset += split_size
weight = getattr(self, f'weight{i}')
bias = getattr(self, f'bias{i}', None)
outputs.append(F.linear(x_i, weight, bias))
if offset != x.shape[0]:
raise RuntimeError(f'Grouped LoRA token split mismatch: got {offset}, expected {x.shape[0]}')
return torch.cat(outputs, dim=0) if outputs else x.new_empty((*x.shape[:-1], self.output_size)), None

def _can_use_grouped_matmul(self, x):
# PEFT's lora_bias=True adds a bias to LoRA-B and is not the mainstream
# Swift path. Keep that uncommon adapter shape on the generic PyTorch
# path until it is explicitly needed and verified.
return is_torch_npu_available() and x.device.type == 'npu' and x.dim() == 2 and not self.use_bias

def forward(self, x, m_splits):
if not self._can_use_grouped_matmul(x):
return self._fallback_forward(x, m_splits)
weights = [getattr(self, f'weight{i}') for i in range(self.num_gemms)]
return _NpuGroupedLoraLinearGMM.apply(x, m_splits, weights, *[weight.T for weight in weights]), None
Comment on lines +167 to +171

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The forward method of NpuGroupedLoraLinear does not accept arbitrary positional (*args) or keyword (**kwargs) arguments. However, in LoraParallelLinear.forward (within src/mcore_bridge/tuners/lora.py), the adapter layers are called with *args and **kwargs (e.g., lora_A(dropout(x), *args, **kwargs)). If any extra keyword arguments or positional arguments are passed, this will raise a TypeError.

To ensure compatibility and prevent runtime crashes, update the signature of forward in NpuGroupedLoraLinear to accept *args and **kwargs.

Suggested change
def forward(self, x, m_splits):
if not self._can_use_grouped_matmul(x):
return self._fallback_forward(x, m_splits)
weights = [getattr(self, f'weight{i}') for i in range(self.num_gemms)]
return _NpuGroupedLoraLinearGMM.apply(x, m_splits, weights, *[weight.T for weight in weights]), None
def forward(self, x, m_splits, *args, **kwargs):
if not self._can_use_grouped_matmul(x):
return self._fallback_forward(x, m_splits)
weights = [getattr(self, f'weight{i}') for i in range(self.num_gemms)]
return _NpuGroupedLoraLinearGMM.apply(x, m_splits, weights, *[weight.T for weight in weights]), None



class _NpuGroupedLoraLinearGMM(torch.autograd.Function):
@staticmethod
def forward(ctx, input_tensor, m_splits, weights, *weight_input_T):
import torch_npu

if isinstance(m_splits, torch.Tensor):
group_list = m_splits
if group_list.device.type == 'cpu':
group_list = group_list.to(device=input_tensor.device, dtype=torch.int64)
else:
group_list = group_list.to(dtype=torch.int64)
else:
group_list = torch.tensor(m_splits, device=input_tensor.device, dtype=torch.int64)
output = torch_npu.npu_grouped_matmul(
[input_tensor],
weight_input_T,
bias=None,
group_list=group_list,
split_item=_GMM_SPLIT_ITEM_FORWARD_OR_DINPUT,
group_type=_GMM_GROUP_BY_M_AXIS,
group_list_type=_GMM_GROUP_LIST_IS_EXPERT_SIZES,
)[0]
Comment on lines +188 to +196

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

In _NpuGroupedLoraLinearGMM.forward, weight_input_T is a tuple of tensors. The torch_npu.npu_grouped_matmul operator expects a list of tensors for its inputs and weights. Passing a tuple instead of a list can lead to type errors or undefined behavior in the C++ bindings of torch_npu.

Convert weight_input_T to a list before passing it to npu_grouped_matmul.

Suggested change
output = torch_npu.npu_grouped_matmul(
[input_tensor],
weight_input_T,
bias=None,
group_list=group_list,
split_item=_GMM_SPLIT_ITEM_FORWARD_OR_DINPUT,
group_type=_GMM_GROUP_BY_M_AXIS,
group_list_type=_GMM_GROUP_LIST_IS_EXPERT_SIZES,
)[0]
output = torch_npu.npu_grouped_matmul(
[input_tensor],
list(weight_input_T),
bias=None,
group_list=group_list,
split_item=_GMM_SPLIT_ITEM_FORWARD_OR_DINPUT,
group_type=_GMM_GROUP_BY_M_AXIS,
group_list_type=_GMM_GROUP_LIST_IS_EXPERT_SIZES,
)[0]

ctx.group_list = group_list
ctx.save_for_backward(input_tensor, *weights)
return output

@staticmethod
def backward(ctx, grad_output):
import torch_npu

input_tensor = ctx.saved_tensors[0]
weights = ctx.saved_tensors[1:]
group_list = ctx.group_list
grad_input = torch_npu.npu_grouped_matmul(
[grad_output],
weights,
bias=None,
group_list=group_list,
split_item=_GMM_SPLIT_ITEM_FORWARD_OR_DINPUT,
group_type=_GMM_GROUP_BY_M_AXIS,
group_list_type=_GMM_GROUP_LIST_IS_EXPERT_SIZES,
)[0]
Comment on lines +208 to +216

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

In _NpuGroupedLoraLinearGMM.backward, weights is retrieved as a slice of ctx.saved_tensors, which is a tuple. Similar to the forward pass, torch_npu.npu_grouped_matmul expects a list of tensors.

Convert weights to a list before passing it to npu_grouped_matmul.

Suggested change
grad_input = torch_npu.npu_grouped_matmul(
[grad_output],
weights,
bias=None,
group_list=group_list,
split_item=_GMM_SPLIT_ITEM_FORWARD_OR_DINPUT,
group_type=_GMM_GROUP_BY_M_AXIS,
group_list_type=_GMM_GROUP_LIST_IS_EXPERT_SIZES,
)[0]
grad_input = torch_npu.npu_grouped_matmul(
[grad_output],
list(weights),
bias=None,
group_list=group_list,
split_item=_GMM_SPLIT_ITEM_FORWARD_OR_DINPUT,
group_type=_GMM_GROUP_BY_M_AXIS,
group_list_type=_GMM_GROUP_LIST_IS_EXPERT_SIZES,
)[0]

grad_weight_T = torch_npu.npu_grouped_matmul(
[input_tensor.T],
[grad_output],
bias=None,
group_list=group_list,
split_item=_GMM_SPLIT_ITEM_DWEIGHT,
group_type=_GMM_GROUP_BY_K_AXIS,
group_list_type=_GMM_GROUP_LIST_IS_EXPERT_SIZES,
)[0]
return grad_input, None, None, *grad_weight_T
Comment on lines +202 to +226
Loading