9 changes: 9 additions & 0 deletions src/megatron/bridge/peft/adapter_wrapper.py
@@ -100,6 +100,15 @@ def __init__(self, to_wrap: nn.Module, adapter: nn.Module) -> None:
super(AdapterWrapper, self).__init__()
self.to_wrap = to_wrap
self.adapter = adapter
self._adapter_enabled = True

def enable_adapter_layers(self) -> None:
"""Enable the adapter layers, allowing them to contribute to the forward pass output."""
self._adapter_enabled = True

def disable_adapter_layers(self) -> None:
"""Disable the adapter layers, making the forward pass return only the base module output."""
self._adapter_enabled = False

def base_linear_forward(
self, x: torch.Tensor, *args: Any, **kwargs: Any
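To make the new toggle concrete, here is a minimal sketch of the pattern this hunk introduces: the wrapper adds adapter output only while `_adapter_enabled` is True. `ToyAdapterWrapper` is a hypothetical stand-in for an `AdapterWrapper` subclass such as `LoRALinear`, not code from this PR.

```python
# Minimal sketch (not part of this PR) of the enable/disable pattern added to
# AdapterWrapper. ToyAdapterWrapper is hypothetical; real subclasses such as
# LoRALinear return (output, bias) tuples and go through base_linear_forward.
import torch
import torch.nn as nn


class ToyAdapterWrapper(nn.Module):
    def __init__(self, to_wrap: nn.Module, adapter: nn.Module) -> None:
        super().__init__()
        self.to_wrap = to_wrap
        self.adapter = adapter
        self._adapter_enabled = True  # same default as in the PR

    def enable_adapter_layers(self) -> None:
        self._adapter_enabled = True

    def disable_adapter_layers(self) -> None:
        self._adapter_enabled = False

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        base_out = self.to_wrap(x)
        if not self._adapter_enabled:  # early return mirrors the forward changes below
            return base_out
        return base_out + self.adapter(x)


wrapper = ToyAdapterWrapper(nn.Linear(10, 10), nn.Linear(10, 10))
x = torch.randn(2, 10)
wrapper.disable_adapter_layers()
assert torch.allclose(wrapper(x), wrapper.to_wrap(x))  # base output only
wrapper.enable_adapter_layers()
assert not torch.allclose(wrapper(x), wrapper.to_wrap(x))  # adapter contributes again
```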
63 changes: 45 additions & 18 deletions src/megatron/bridge/peft/base.py
@@ -14,6 +14,7 @@

import logging
from abc import ABC, abstractmethod
from contextlib import contextmanager
from dataclasses import dataclass, field
from typing import Optional, TypeVar, Union

@@ -96,17 +97,7 @@ def __call__(self, model: ModelType, training: bool = True) -> ModelType:
"""
self.freeze_model(model, training=training)

if isinstance(model, list) and len(model) > 1:
for model_chunk in model:
walk(model_chunk, self.transform)
elif isinstance(model, torch.nn.parallel.DistributedDataParallel):
walk(model.module, self.transform)
else:
if isinstance(model, list):
model_to_walk = model[0] if len(model) == 1 else model
else:
model_to_walk = model
walk(model_to_walk, self.transform)
self._walk_model(model, self.transform)

if training:
maybe_enable_recompute_inputs_grad(model)
@@ -123,6 +114,48 @@ def __call__(self, model: ModelType, training: bool = True) -> ModelType:

return model

def _walk_model(self, model: ModelType, func) -> None:
if isinstance(model, list):
for model_chunk in model:
walk(model_chunk, func)
elif isinstance(model, torch.nn.parallel.DistributedDataParallel):
walk(model.module, func)
else:
walk(model, func)

def enable_adapter_layers(self, model: ModelType) -> None:
"""Enable adapter layers for all PEFT-wrapped modules in the model."""

def enable(module: nn.Module) -> nn.Module:
method = getattr(module, "enable_adapter_layers", None)
if callable(method):
method()
return module

self._walk_model(model, enable)

def disable_adapter_layers(self, model: ModelType) -> None:
"""Disable adapter layers for all PEFT-wrapped modules in the model."""

def disable(module: nn.Module) -> nn.Module:
method = getattr(module, "disable_adapter_layers", None)
if callable(method):
method()
return module

self._walk_model(model, disable)

@contextmanager
def disable_adapter(self, model: ModelType):
"""
Disables the adapter module.
"""
try:
self.disable_adapter_layers(model)
yield
finally:
self.enable_adapter_layers(model)

def freeze_model(self, model: ModelType, training: bool = True) -> None:
"""Apply a default freeze method to the model.

@@ -140,13 +173,7 @@ def freeze_parameters(module):
param.requires_grad = False
return module

if isinstance(model, list):
for model_chunk in model:
walk(model_chunk, freeze_parameters)
elif isinstance(model, torch.nn.parallel.DistributedDataParallel):
walk(model.module, freeze_parameters)
else:
walk(model, freeze_parameters)
self._walk_model(model, freeze_parameters)

if training:
if isinstance(model, list):
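A hedged usage sketch for the new PEFT-level helpers: `peft`, `model`, and `batch` are placeholders (an applied PEFT instance such as LoRA, the transformed model or list of model chunks, and an input batch), not objects defined in this PR.

```python
# Hypothetical names: `peft`, `model`, and `batch` stand in for an applied PEFT
# instance, the transformed Megatron model (module, list of chunks, or
# DDP-wrapped), and an input batch.
peft.disable_adapter_layers(model)  # walk every chunk and switch adapters off
peft.enable_adapter_layers(model)   # walk every chunk and switch adapters back on

# Temporarily bypass adapters, e.g. to compute reference/base-model outputs;
# they are re-enabled on exit even if the body raises.
with peft.disable_adapter(model):
    base_output = model(batch)
```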
4 changes: 4 additions & 0 deletions src/megatron/bridge/peft/canonical_lora.py
@@ -74,6 +74,8 @@ class to provide a specific implementation of the forward method.
def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
# pylint: disable=C0115,C0116
linear_output, bias, layernorm_output = self.base_linear_forward(x, *args, **kwargs)
if not self._adapter_enabled:
return linear_output, bias
query = self.adapter.adapter_q(layernorm_output)
key = self.adapter.adapter_k(layernorm_output)
value = self.adapter.adapter_v(layernorm_output)
@@ -100,6 +102,8 @@ class to provide a specific implementation of the forward method.
def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
# pylint: disable=C0115,C0116
linear_output, bias, layernorm_output = self.base_linear_forward(x, *args, **kwargs)
if not self._adapter_enabled:
return linear_output, bias
adapter_output_gate = self.adapter.adapter_gate(layernorm_output)
adapter_output_up = self.adapter.adapter_up(layernorm_output)
adapter_output = torch.cat([adapter_output_gate, adapter_output_up], dim=-1)
2 changes: 2 additions & 0 deletions src/megatron/bridge/peft/dora_layers.py
@@ -152,6 +152,8 @@ def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
tuple: A tuple containing the DoRA output and bias term.
"""
linear_output, bias, layernorm_output = self.base_linear_forward(x)
if not self._adapter_enabled:
return linear_output, bias
adapter_output = self.adapter(layernorm_output.contiguous())

# mag_norm_scale is ||W_0 + B_0 A_0|| / ||W_0 + B A|| (scaling in front of BA not shown)
34 changes: 33 additions & 1 deletion src/megatron/bridge/peft/lora_layers.py
@@ -48,11 +48,14 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> Tuple[torch.Ten

Returns:
A tuple containing:
- Combined output (linear_output + adapter_output)
- Combined output (linear_output + adapter_output) if adapter is enabled,
otherwise just the linear_output
- Bias term (if present, otherwise None)
"""
# pylint: disable=C0115,C0116
linear_output, bias, layernorm_output = self.base_linear_forward(x, *args, **kwargs)
if not self._adapter_enabled:
return linear_output, bias
adapter_output = self.adapter(layernorm_output.contiguous())
adapter_output = adapter_output.reshape(linear_output.shape)
return linear_output + adapter_output, bias
@@ -122,6 +125,15 @@ def __init__(
lora_A_init_method=lora_A_init_method,
lora_dtype=lora_dtype,
)
self._adapter_enabled = True

def enable_adapter_layers(self) -> None:
"""Enable the adapter layers, allowing them to contribute to the forward pass output."""
self._adapter_enabled = True

def disable_adapter_layers(self) -> None:
"""Disable the adapter layers, making the forward pass return only the base module output."""
self._adapter_enabled = False

@torch.no_grad
@staticmethod
@@ -181,6 +193,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
# pylint: disable=C0115,C0116
res = super(TELinearAdapter, self).forward(x)

if not self._adapter_enabled:
return res

if self.dropout_position == "pre":
x = self.dropout(x)
# LoRA fwd is performed in original precision regardless of FP8 enabled
@@ -428,6 +444,10 @@ def _make_lora_branch(
def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, None]:
# pylint: disable=C0115,C0116

# If adapter is disabled, fall back to base forward
if not self._adapter_enabled:
return super().forward(x)

# Construct fused impl if needed
# Note: We initialize during the first forward pass in
# case the params are modified after the constructor.
@@ -506,6 +526,15 @@ def __init__(
lora_A_init_method=lora_A_init_method,
lora_dtype=lora_dtype,
)
self._adapter_enabled = True

def enable_adapter_layers(self) -> None:
"""Enable the adapter layers, allowing them to contribute to the forward pass output."""
self._adapter_enabled = True

def disable_adapter_layers(self) -> None:
"""Disable the adapter layers, making the forward pass return only the base module output."""
self._adapter_enabled = False

@torch.no_grad
@staticmethod
@@ -573,6 +602,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
else:
res = torch.nn.functional.linear(x, self.weight, self.bias)

if not self._adapter_enabled:
return res

if self.dropout_position == "pre":
x = self.dropout(x)
lora_res = self.lora_b(self.lora_a(x))
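Unlike the `AdapterWrapper` subclasses, the `LinearAdapter`/`TELinearAdapter` classes subclass the base linear layer directly, so the flag and toggle methods are duplicated on each class and the early return sits right after the base `forward`. A hedged, self-contained sketch of that variant, using a plain `nn.Linear` stand-in (`ToyLinearAdapter` is hypothetical, not this PR's implementation, and omits dropout/scaling):

```python
# Hypothetical stand-in for the LinearAdapter-style classes: the adapter state
# lives on the subclass itself rather than on a wrapper.
import torch
import torch.nn as nn


class ToyLinearAdapter(nn.Linear):
    def __init__(self, in_features: int, out_features: int, dim: int = 8) -> None:
        super().__init__(in_features, out_features)
        self.lora_a = nn.Linear(in_features, dim, bias=False)
        self.lora_b = nn.Linear(dim, out_features, bias=False)
        nn.init.zeros_(self.lora_b.weight)  # LoRA B starts at zero
        self._adapter_enabled = True

    def enable_adapter_layers(self) -> None:
        self._adapter_enabled = True

    def disable_adapter_layers(self) -> None:
        self._adapter_enabled = False

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        res = super().forward(x)
        if not self._adapter_enabled:  # bypass the LoRA branch entirely
            return res
        return res + self.lora_b(self.lora_a(x))
```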
77 changes: 77 additions & 0 deletions tests/unit_tests/peft/test_adapter_wrapper.py
@@ -25,7 +25,9 @@
import torch
import torch.nn as nn

from megatron.bridge.peft.base import PEFT
from megatron.bridge.peft.adapter_wrapper import AdapterWrapper
from megatron.bridge.peft.lora_layers import LoRALinear


class MockLinear(nn.Module):
@@ -67,6 +69,25 @@ def forward(self, x, *args, **kwargs):
return linear_output + adapter_output, bias


class DummyPEFT(PEFT):
"""Minimal PEFT implementation for adapter enable/disable tests."""

def transform(self, module: nn.Module, name=None, prefix=None) -> nn.Module:
return module


class AdapterModel(nn.Module):
"""Model with a single LoRALinear adapter wrapper for testing."""

def __init__(self, to_wrap: nn.Module, adapter: nn.Module) -> None:
super().__init__()
self.lora = LoRALinear(to_wrap, adapter)

def forward(self, x: torch.Tensor) -> torch.Tensor:
output, _ = self.lora(x)
return output


class TestAdapterWrapper:
"""Test the AdapterWrapper base class."""

@@ -258,3 +279,59 @@ def test_adapter_wrapper_is_abstract(self):
assert hasattr(AdapterWrapper, "base_linear_forward")
assert hasattr(AdapterWrapper, "state_dict")
assert hasattr(AdapterWrapper, "sharded_state_dict")

def test_adapter_wrapper_enable_disable_toggle(self, mock_linear_simple, simple_adapter):
"""Test adapter output toggling via AdapterWrapper methods."""
wrapper = LoRALinear(mock_linear_simple, simple_adapter)
x = torch.randn(5, 10)

base_output, _ = mock_linear_simple(x)
enabled_output, _ = wrapper(x)
expected = base_output + simple_adapter(x)
assert torch.allclose(enabled_output, expected, atol=1e-6)

wrapper.disable_adapter_layers()
disabled_output, _ = wrapper(x)
assert torch.allclose(disabled_output, base_output, atol=1e-6)

wrapper.enable_adapter_layers()
reenabled_output, _ = wrapper(x)
assert torch.allclose(reenabled_output, enabled_output, atol=1e-6)

def test_peft_disable_adapter_context_manager(self, mock_linear_simple, simple_adapter):
"""Test PEFT.disable_adapter restores adapter state."""
peft = DummyPEFT()
model = AdapterModel(mock_linear_simple, simple_adapter)
x = torch.randn(5, 10)

base_output, _ = mock_linear_simple(x)
enabled_output = model(x)

with peft.disable_adapter(model):
disabled_output = model(x)
assert torch.allclose(disabled_output, base_output, atol=1e-6)

assert torch.allclose(model(x), enabled_output, atol=1e-6)

with pytest.raises(RuntimeError):
with peft.disable_adapter(model):
raise RuntimeError("boom")

assert torch.allclose(model(x), enabled_output, atol=1e-6)

def test_peft_enable_disable_adapter_layers_manual(self, mock_linear_simple, simple_adapter):
"""Test manual adapter enable/disable via PEFT helpers."""
peft = DummyPEFT()
model = AdapterModel(mock_linear_simple, simple_adapter)
x = torch.randn(5, 10)

base_output, _ = mock_linear_simple(x)
enabled_output = model(x)

peft.disable_adapter_layers(model)
disabled_output = model(x)
assert torch.allclose(disabled_output, base_output, atol=1e-6)

peft.enable_adapter_layers(model)
reenabled_output = model(x)
assert torch.allclose(reenabled_output, enabled_output, atol=1e-6)