7 changes: 5 additions & 2 deletions src/megatron/bridge/peft/canonical_lora.py
@@ -18,11 +18,12 @@

import torch
from megatron.core.dist_checkpointing.mapping import ShardedStateDict
from megatron.core.transformer.moe.router import TopKRouter
from torch import nn

from megatron.bridge.peft.adapter_wrapper import AdapterWrapper
from megatron.bridge.peft.base import PEFT
from megatron.bridge.peft.lora_layers import LinearAdapter, LoRALinear
from megatron.bridge.peft.lora_layers import LinearAdapter, LoRALinear, LoRATopKRouter
from megatron.bridge.peft.module_matcher import ModuleMatcher
from megatron.bridge.peft.utils import ParallelLinearAdapter, get_adapter_attributes_from_linear, is_expert_linear

@@ -215,7 +216,7 @@ def transform(self, m: nn.Module, name: Optional[str] = None, prefix: Optional[s
"""

# Skip already transformed modules
if isinstance(m, (LinearAdapter, LoRALinear, LoRALinearSplitQKV, LoRALinearSplitFC1UpGate)):
if isinstance(m, (LinearAdapter, LoRALinear, LoRALinearSplitQKV, LoRALinearSplitFC1UpGate, LoRATopKRouter)):
return m

if (ans := self.match(m, name, prefix)) is not None:
@@ -275,6 +276,8 @@ def transform(self, m: nn.Module, name: Optional[str] = None, prefix: Optional[s

adapter = ParallelLinearAdapter(in_features, out_features, **adapter_kwargs)
logger.info(f"Adding lora to: {full_name}")
if isinstance(m, TopKRouter):
return LoRATopKRouter(m, adapter)
return LoRALinear(m, adapter)

return m
6 changes: 5 additions & 1 deletion src/megatron/bridge/peft/lora.py
@@ -20,12 +20,14 @@
import torch.nn as nn
import transformer_engine.pytorch as te
from megatron.core import parallel_state
from megatron.core.transformer.moe.router import TopKRouter
from megatron.core.utils import unwrap_model

from megatron.bridge.peft.base import PEFT
from megatron.bridge.peft.lora_layers import (
LinearAdapter,
LoRALinear,
LoRATopKRouter,
TEFusedLoRALinear,
TELinearAdapter,
patch_linear_module,
@@ -102,7 +104,7 @@ def transform(self, module: nn.Module, name: Optional[str] = None, prefix: Optio
nn.Module: The modified module with LoRA applied, or the original module if not a target.
"""
# Skip already transformed modules
adapter_types = (LinearAdapter, LoRALinear)
adapter_types = (LinearAdapter, LoRALinear, LoRATopKRouter)
adapter_types = adapter_types + (TELinearAdapter,)
if isinstance(module, adapter_types):
return module
@@ -168,6 +170,8 @@ def transform(self, module: nn.Module, name: Optional[str] = None, prefix: Optio
disable_sequence_parallel_comm=disable_sp_comm,
base_linear_is_parallel=base_linear_is_parallel,
)
if isinstance(module, TopKRouter):
return LoRATopKRouter(module, adapter)
if enable_op_fuser:
return TEFusedLoRALinear(module, adapter)
else:
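With this change, a module matched by the LoRA target patterns that is a TopKRouter gets wrapped in LoRATopKRouter instead of LoRALinear. A minimal usage sketch, assuming the model exposes its MoE router under the module name "router" (the same name the new unit test below relies on); target names and LoRA hyperparameters depend on the model being tuned:

from megatron.bridge.peft.lora import LoRA

# Sketch only: "router" is an assumed target-module name.
# model: the user's Megatron MoE model to be adapted.
peft = LoRA(target_modules=["router"])
model = peft(model, training=True)  # matched TopKRouter modules become LoRATopKRouter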
17 changes: 17 additions & 0 deletions src/megatron/bridge/peft/lora_layers.py
@@ -18,6 +18,7 @@
import torch
import torch.nn as nn
import transformer_engine.pytorch as te
from megatron.core.transformer.moe.moe_utils import apply_random_logits

from megatron.bridge.peft.adapter_wrapper import AdapterWrapper
from megatron.bridge.utils.import_utils import safe_import
@@ -58,6 +59,22 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> Tuple[torch.Ten
return linear_output + adapter_output, bias


class LoRATopKRouter(AdapterWrapper):
"""Adapter wrapper that applies LoRA to router gating logits."""

def forward(self, x: torch.Tensor):
"""Forward pass that adds LoRA delta to router logits before routing."""
self.to_wrap._maintain_float32_expert_bias()
jittered_input = self.to_wrap.apply_input_jitter(x)
logits = self.to_wrap.gating(jittered_input)
if self._adapter_enabled:
adapter_output = self.adapter(jittered_input.contiguous())
logits = logits + adapter_output.to(dtype=logits.dtype)
if self.to_wrap.config.moe_router_force_load_balancing:
logits = apply_random_logits(logits)
return self.to_wrap.routing(logits)


class TELinearAdapter(te.Linear):
"""
TELinear + LoRA, maintains ckpts structure (i.e. Linear's weight/bias remain at the same FQN)
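LoRATopKRouter mirrors TopKRouter.forward (expert-bias maintenance, input jitter, gating, optional forced load balancing, routing) but adds the adapter output to the gating logits before routing. A minimal sketch of the resulting computation, reusing the DummyRouter stand-in defined in the new unit tests below and a plain nn.Linear in place of the LoRA adapter; the sizes are illustrative only:

import torch
import torch.nn as nn

from megatron.bridge.peft.lora_layers import LoRATopKRouter

router = DummyRouter(hidden_size=5, num_experts=4)   # test stand-in for a TopKRouter
adapter = nn.Linear(5, 4, bias=False)                # stand-in for ParallelLinearAdapter
wrapped = LoRATopKRouter(router, adapter)

x = torch.randn(2, 5)
logits, routing_map = wrapped(x)
# The adapter delta is added to the gating logits before routing:
# logits == router.gating(x) + adapter(x)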
9 changes: 8 additions & 1 deletion src/megatron/bridge/peft/utils.py
@@ -28,6 +28,7 @@
scatter_to_sequence_parallel_region,
)
from megatron.core.transformer.mlp import apply_swiglu_sharded_factory
from megatron.core.transformer.moe.router import TopKRouter

from megatron.bridge.utils.import_utils import safe_import_from

Expand Down Expand Up @@ -107,7 +108,13 @@ def get_adapter_attributes_from_linear(
tp_size = parallel_state.get_expert_tensor_parallel_world_size()
else:
tp_size = parallel_state.get_tensor_model_parallel_world_size()
if HAVE_TE and any(isinstance(m, te_column_parallel) for te_column_parallel in TECL):
if isinstance(m, TopKRouter):
input_is_parallel = False
in_features = m.weight.shape[1]
out_features = m.weight.shape[0]
base_linear_is_parallel = False
disable_sequence_parallel_comm = True
elif HAVE_TE and any(isinstance(m, te_column_parallel) for te_column_parallel in TECL):
input_is_parallel = False
# m.in_features and m.out_features are divided by tp_size already,
# but in_features and out_features passed to ParallelLinearAdapter are not.
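For a TopKRouter the adapter dimensions are taken directly from the router weight, which is stored as [num_experts, hidden_size], and the adapter is treated as a plain, non-parallel linear (no tensor-parallel input, sequence-parallel communication disabled). An illustrative shape check, with assumed sizes:

import torch

weight = torch.empty(8, 4096)      # assumed example: 8 experts, hidden size 4096; layout [num_experts, hidden_size]
in_features = weight.shape[1]      # 4096 -> adapter input dim (hidden states)
out_features = weight.shape[0]     # 8    -> adapter output dim (one delta logit per expert)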
83 changes: 81 additions & 2 deletions tests/unit_tests/peft/test_lora_layers.py
@@ -21,6 +21,7 @@

import os
from copy import deepcopy
from types import SimpleNamespace

import megatron.core.parallel_state as parallel_state
import pytest
@@ -29,8 +30,14 @@
import torch.nn as nn
import transformer_engine.pytorch as te

from megatron.bridge.peft.lora import TELinearAdapter
from megatron.bridge.peft.lora_layers import LinearAdapter, LoRALinear, TEFusedLoRALinear, patch_linear_module
from megatron.bridge.peft.lora import LoRA, TELinearAdapter
from megatron.bridge.peft.lora_layers import (
LinearAdapter,
LoRALinear,
LoRATopKRouter,
TEFusedLoRALinear,
patch_linear_module,
)


class MockLinearWithTupleReturn(nn.Module):
@@ -661,3 +668,75 @@ def test_linear_adapter_math_correctness(self):
expected = torch.full((1, 5), 10.4)

assert torch.allclose(output, expected, atol=1e-6)


class DummyRouter(nn.Module):
def __init__(self, hidden_size: int = 4, num_experts: int = 3) -> None:
super().__init__()
self.weight = nn.Parameter(torch.randn(num_experts, hidden_size))
self.expert_bias = torch.zeros(num_experts)
self.config = SimpleNamespace(
moe_router_force_load_balancing=False,
sequence_parallel=False,
)

def _maintain_float32_expert_bias(self) -> None:
if isinstance(self.expert_bias, torch.Tensor):
self.expert_bias = self.expert_bias.float()

def apply_input_jitter(self, x: torch.Tensor) -> torch.Tensor:
return x

def gating(self, x: torch.Tensor) -> torch.Tensor:
return x @ self.weight.t()

def routing(self, logits: torch.Tensor):
return logits, logits > 0


class RouterModel(nn.Module):
def __init__(self, router_cls: type[DummyRouter]) -> None:
super().__init__()
self.mlp = nn.Module()
self.mlp.router = router_cls()


class TestLoRATopKRouter:
"""Test LoRA router wrapper behavior."""

def test_forward_adds_adapter_delta(self) -> None:
hidden_size = 5
num_experts = 4
router = DummyRouter(hidden_size=hidden_size, num_experts=num_experts)
adapter = nn.Linear(hidden_size, num_experts, bias=False)
wrapper = LoRATopKRouter(router, adapter)

x = torch.randn(2, hidden_size)
expected_logits = router.gating(x) + adapter(x)

logits, routing_map = wrapper(x)

assert torch.allclose(logits, expected_logits)
assert routing_map.shape == expected_logits.shape

def test_lora_wraps_router_with_lora_topk(self, monkeypatch: pytest.MonkeyPatch) -> None:
from megatron.bridge.peft import lora as lora_module

class DummyTopKRouter(DummyRouter):
pass

def fake_adapter(in_features, out_features, *args, **kwargs):
return nn.Linear(in_features, out_features, bias=False)

def fake_attrs(*args, **kwargs):
return False, 4, 3, False, True, False

monkeypatch.setattr(lora_module, "TopKRouter", DummyTopKRouter, raising=True)
monkeypatch.setattr(lora_module, "ParallelLinearAdapter", fake_adapter, raising=True)
monkeypatch.setattr(lora_module, "get_adapter_attributes_from_linear", fake_attrs, raising=True)

model = RouterModel(DummyTopKRouter)
lora = LoRA(target_modules=["router"])
transformed = lora(model, training=True)

assert isinstance(transformed.mlp.router, LoRATopKRouter)