
Commit 1e1fc6f

[PEFT] feat: Add support for temporarily disabling adapters

Similar to: https://github.com/huggingface/peft/blob/261366de2e40cde64b702d6b9c527081ad850549/src/peft/mixed_model.py#L192-L201. `enable_adapter_layers` and `disable_adapter_layers` are provided as alternatives for users who want to control the adapter state manually.

Signed-off-by: Hollow Man <[email protected]>

Parent: 7695d4a
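A minimal usage sketch of the new API. Hedged: `peft`, `model`, and `batch` are assumed names for an applied PEFT instance (e.g. a LoRA configuration from this package), the transformed Megatron model, and a valid input; none of them appear in this commit.

```python
# Assumed setup (not shown in this commit): `peft` is an instance of a PEFT
# subclass that has already been applied to `model`.

# Temporarily run the base model only; adapters are re-enabled on exit,
# even if the block raises.
with peft.disable_adapter(model):
    base_logits = model(batch)

# Manual control, if preferred over the context manager:
peft.disable_adapter_layers(model)
# ... run whatever needs the base model here ...
peft.enable_adapter_layers(model)
```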

File tree

5 files changed: +68 −19 lines

src/megatron/bridge/peft/adapter_wrapper.py (9 additions, 0 deletions)

@@ -100,6 +100,15 @@ def __init__(self, to_wrap: nn.Module, adapter: nn.Module) -> None:
         super(AdapterWrapper, self).__init__()
         self.to_wrap = to_wrap
         self.adapter = adapter
+        self._adapter_enabled = True
+
+    def enable_adapter_layers(self) -> None:
+        """Enable the adapter layers, allowing them to contribute to the forward pass output."""
+        self._adapter_enabled = True
+
+    def disable_adapter_layers(self) -> None:
+        """Disable the adapter layers, making the forward pass return only the base module output."""
+        self._adapter_enabled = False
 
     def base_linear_forward(
         self, x: torch.Tensor, *args: Any, **kwargs: Any
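The flag lives on each wrapped layer, so adapters can also be toggled per module rather than model-wide. A sketch under stated assumptions: the attribute path below is hypothetical, and `wrapped` stands for any `AdapterWrapper` subclass instance.

```python
# Hypothetical example: toggle the adapter on a single wrapped layer only.
wrapped = model.decoder.layers[0].self_attention.linear_qkv  # assumed path

wrapped.disable_adapter_layers()   # forward now returns only the base layer's output
output, bias = wrapped(x)          # adapter branch is skipped
wrapped.enable_adapter_layers()    # restore the adapter contribution
```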

src/megatron/bridge/peft/base.py (45 additions, 18 deletions)

@@ -14,6 +14,7 @@
 
 import logging
 from abc import ABC, abstractmethod
+from contextlib import contextmanager
 from dataclasses import dataclass, field
 from typing import Optional, TypeVar, Union
 
@@ -95,17 +96,7 @@ def __call__(self, model: ModelType, training: bool = True) -> ModelType:
         """
         self.freeze_model(model, training=training)
 
-        if isinstance(model, list) and len(model) > 1:
-            for model_chunk in model:
-                walk(model_chunk, self.transform)
-        elif isinstance(model, torch.nn.parallel.DistributedDataParallel):
-            walk(model.module, self.transform)
-        else:
-            if isinstance(model, list):
-                model_to_walk = model[0] if len(model) == 1 else model
-            else:
-                model_to_walk = model
-            walk(model_to_walk, self.transform)
+        self._walk_model(model, self.transform)
 
         if not training:
             self.freeze_model(model, training=training)
@@ -119,6 +110,48 @@ def __call__(self, model: ModelType, training: bool = True) -> ModelType:
 
         return model
 
+    def _walk_model(self, model: ModelType, func) -> None:
+        if isinstance(model, list):
+            for model_chunk in model:
+                walk(model_chunk, func)
+        elif isinstance(model, torch.nn.parallel.DistributedDataParallel):
+            walk(model.module, func)
+        else:
+            walk(model, func)
+
+    def enable_adapter_layers(self, model: ModelType) -> None:
+        """Enable adapter layers for all PEFT-wrapped modules in the model."""
+
+        def enable(module: nn.Module) -> nn.Module:
+            method = getattr(module, "enable_adapter_layers", None)
+            if callable(method):
+                method()
+            return module
+
+        self._walk_model(model, enable)
+
+    def disable_adapter_layers(self, model: ModelType) -> None:
+        """Disable adapter layers for all PEFT-wrapped modules in the model."""
+
+        def disable(module: nn.Module) -> nn.Module:
+            method = getattr(module, "disable_adapter_layers", None)
+            if callable(method):
+                method()
+            return module
+
+        self._walk_model(model, disable)
+
+    @contextmanager
+    def disable_adapter(self, model: ModelType):
+        """
+        Disables the adapter module.
+        """
+        try:
+            self.disable_adapter_layers(model)
+            yield
+        finally:
+            self.enable_adapter_layers(model)
+
     def freeze_model(self, model: ModelType, training: bool = True) -> None:
         """Apply a default freeze method to the model.
 
@@ -136,13 +169,7 @@ def freeze_parameters(module):
                 param.requires_grad = False
             return module
 
-        if isinstance(model, list):
-            for model_chunk in model:
-                walk(model_chunk, freeze_parameters)
-        elif isinstance(model, torch.nn.parallel.DistributedDataParallel):
-            walk(model.module, freeze_parameters)
-        else:
-            walk(model, freeze_parameters)
+        self._walk_model(model, freeze_parameters)
 
         if training:
             if isinstance(model, list):
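The enable/disable callbacks deliberately duck-type instead of isinstance-checking `AdapterWrapper`, so plain modules visited during the walk are left untouched. A standalone sketch of that pattern; the toy classes below are illustrative only, not part of the codebase.

```python
import torch.nn as nn

class Plain(nn.Module):
    """A module without adapter support; the toggle should be a no-op for it."""

class Wrapped(nn.Module):
    """A module exposing the same toggle API as AdapterWrapper."""
    def __init__(self):
        super().__init__()
        self._adapter_enabled = True

    def disable_adapter_layers(self):
        self._adapter_enabled = False

def disable(module: nn.Module) -> nn.Module:
    # Same shape as the callback passed to _walk_model: only call the
    # method if the module actually provides it.
    method = getattr(module, "disable_adapter_layers", None)
    if callable(method):
        method()
    return module

for m in (Plain(), Wrapped()):
    disable(m)  # safe for both; only Wrapped flips its flag
```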

src/megatron/bridge/peft/canonical_lora.py (4 additions, 0 deletions)

@@ -74,6 +74,8 @@ class to provide a specific implementation of the forward method.
     def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         # pylint: disable=C0115,C0116
         linear_output, bias, layernorm_output = self.base_linear_forward(x, *args, **kwargs)
+        if not self._adapter_enabled:
+            return linear_output, bias
         query = self.adapter.adapter_q(layernorm_output)
         key = self.adapter.adapter_k(layernorm_output)
         value = self.adapter.adapter_v(layernorm_output)
@@ -100,6 +102,8 @@ class to provide a specific implementation of the forward method.
     def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         # pylint: disable=C0115,C0116
         linear_output, bias, layernorm_output = self.base_linear_forward(x, *args, **kwargs)
+        if not self._adapter_enabled:
+            return linear_output, bias
         adapter_output_gate = self.adapter.adapter_gate(layernorm_output)
         adapter_output_up = self.adapter.adapter_up(layernorm_output)
         adapter_output = torch.cat([adapter_output_gate, adapter_output_up], dim=-1)

src/megatron/bridge/peft/dora_layers.py (2 additions, 0 deletions)

@@ -152,6 +152,8 @@ def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
             tuple: A tuple containing the DoRA output and bias term.
         """
         linear_output, bias, layernorm_output = self.base_linear_forward(x)
+        if not self._adapter_enabled:
+            return linear_output, bias
         adapter_output = self.adapter(layernorm_output.contiguous())
 
         # mag_norm_scale is ||W_0 + B_0 A_0|| / ||W_0 + B A|| (scaling in front of BA not shown)

src/megatron/bridge/peft/lora_layers.py (8 additions, 1 deletion)

@@ -48,11 +48,14 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> Tuple[torch.Ten
 
         Returns:
             A tuple containing:
-                - Combined output (linear_output + adapter_output)
+                - Combined output (linear_output + adapter_output) if adapter is enabled,
+                  otherwise just the linear_output
                 - Bias term (if present, otherwise None)
         """
         # pylint: disable=C0115,C0116
         linear_output, bias, layernorm_output = self.base_linear_forward(x, *args, **kwargs)
+        if not self._adapter_enabled:
+            return linear_output, bias
         adapter_output = self.adapter(layernorm_output.contiguous())
         adapter_output = adapter_output.reshape(linear_output.shape)
         return linear_output + adapter_output, bias
@@ -428,6 +431,10 @@ def _make_lora_branch(
     def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, None]:
         # pylint: disable=C0115,C0116
 
+        # If adapter is disabled, fall back to base forward
+        if not self._adapter_enabled:
+            return super().forward(x)
+
         # Construct fused impl if needed
         # Note: We initialize during the first forward pass in
         # case the params are modified after the constructor.
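With the flag cleared, every wrapped forward short-circuits before the adapter branch, so a disabled layer is numerically identical to the base layer. A hedged sanity-check sketch; `peft`, `model`, and `batch` are assumed names as above, and the model is assumed to return a single logits tensor.

```python
import torch

with torch.no_grad():
    with peft.disable_adapter(model):
        base_only = model(batch)      # adapter branches are skipped
    with_adapters = model(batch)      # adapters contribute again here

# The two outputs differ only by the adapter contribution; with freshly
# initialized LoRA adapters (second projection conventionally zero-initialized)
# they are expected to match.
print(torch.allclose(base_only, with_adapters))
```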
