Skip to content

Commit a3bb972

Browse files
committed
add test cases
Signed-off-by: Hollow Man <[email protected]>
1 parent 30107e5 commit a3bb972

File tree

1 file changed

+77
-0
lines changed

1 file changed

+77
-0
lines changed

tests/unit_tests/peft/test_adapter_wrapper.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,9 @@
2525
import torch
2626
import torch.nn as nn
2727

28+
from megatron.bridge.peft.base import PEFT
2829
from megatron.bridge.peft.adapter_wrapper import AdapterWrapper
30+
from megatron.bridge.peft.lora_layers import LoRALinear
2931

3032

3133
class MockLinear(nn.Module):
@@ -67,6 +69,25 @@ def forward(self, x, *args, **kwargs):
6769
return linear_output + adapter_output, bias
6870

6971

72+
class DummyPEFT(PEFT):
73+
"""Minimal PEFT implementation for adapter enable/disable tests."""
74+
75+
def transform(self, module: nn.Module, name=None, prefix=None) -> nn.Module:
76+
return module
77+
78+
79+
class AdapterModel(nn.Module):
80+
"""Model with a single LoRALinear adapter wrapper for testing."""
81+
82+
def __init__(self, to_wrap: nn.Module, adapter: nn.Module) -> None:
83+
super().__init__()
84+
self.lora = LoRALinear(to_wrap, adapter)
85+
86+
def forward(self, x: torch.Tensor) -> torch.Tensor:
87+
output, _ = self.lora(x)
88+
return output
89+
90+
7091
class TestAdapterWrapper:
7192
"""Test the AdapterWrapper base class."""
7293

@@ -258,3 +279,59 @@ def test_adapter_wrapper_is_abstract(self):
258279
assert hasattr(AdapterWrapper, "base_linear_forward")
259280
assert hasattr(AdapterWrapper, "state_dict")
260281
assert hasattr(AdapterWrapper, "sharded_state_dict")
282+
283+
def test_adapter_wrapper_enable_disable_toggle(self, mock_linear_simple, simple_adapter):
284+
"""Test adapter output toggling via AdapterWrapper methods."""
285+
wrapper = LoRALinear(mock_linear_simple, simple_adapter)
286+
x = torch.randn(5, 10)
287+
288+
base_output, _ = mock_linear_simple(x)
289+
enabled_output, _ = wrapper(x)
290+
expected = base_output + simple_adapter(x)
291+
assert torch.allclose(enabled_output, expected, atol=1e-6)
292+
293+
wrapper.disable_adapter_layers()
294+
disabled_output, _ = wrapper(x)
295+
assert torch.allclose(disabled_output, base_output, atol=1e-6)
296+
297+
wrapper.enable_adapter_layers()
298+
reenabled_output, _ = wrapper(x)
299+
assert torch.allclose(reenabled_output, enabled_output, atol=1e-6)
300+
301+
def test_peft_disable_adapter_context_manager(self, mock_linear_simple, simple_adapter):
302+
"""Test PEFT.disable_adapter restores adapter state."""
303+
peft = DummyPEFT()
304+
model = AdapterModel(mock_linear_simple, simple_adapter)
305+
x = torch.randn(5, 10)
306+
307+
base_output, _ = mock_linear_simple(x)
308+
enabled_output = model(x)
309+
310+
with peft.disable_adapter(model):
311+
disabled_output = model(x)
312+
assert torch.allclose(disabled_output, base_output, atol=1e-6)
313+
314+
assert torch.allclose(model(x), enabled_output, atol=1e-6)
315+
316+
with pytest.raises(RuntimeError):
317+
with peft.disable_adapter(model):
318+
raise RuntimeError("boom")
319+
320+
assert torch.allclose(model(x), enabled_output, atol=1e-6)
321+
322+
def test_peft_enable_disable_adapter_layers_manual(self, mock_linear_simple, simple_adapter):
323+
"""Test manual adapter enable/disable via PEFT helpers."""
324+
peft = DummyPEFT()
325+
model = AdapterModel(mock_linear_simple, simple_adapter)
326+
x = torch.randn(5, 10)
327+
328+
base_output, _ = mock_linear_simple(x)
329+
enabled_output = model(x)
330+
331+
peft.disable_adapter_layers(model)
332+
disabled_output = model(x)
333+
assert torch.allclose(disabled_output, base_output, atol=1e-6)
334+
335+
peft.enable_adapter_layers(model)
336+
reenabled_output = model(x)
337+
assert torch.allclose(reenabled_output, enabled_output, atol=1e-6)

0 commit comments

Comments
 (0)