25 | 25 | import torch |
26 | 26 | import torch.nn as nn |
27 | 27 | |
| 28 | +from megatron.bridge.peft.base import PEFT |
28 | 29 | from megatron.bridge.peft.adapter_wrapper import AdapterWrapper |
| 30 | +from megatron.bridge.peft.lora_layers import LoRALinear |
29 | 31 | |
30 | 32 | |
31 | 33 | class MockLinear(nn.Module): |
@@ -67,6 +69,25 @@ def forward(self, x, *args, **kwargs): |
67 | 69 | return linear_output + adapter_output, bias |
68 | 70 | |
69 | 71 | |
| 72 | +class DummyPEFT(PEFT): |
| 73 | + """Minimal PEFT implementation for adapter enable/disable tests.""" |
| 74 | + |
| 75 | + def transform(self, module: nn.Module, name=None, prefix=None) -> nn.Module: |
| 76 | + return module |
| 77 | + |
| 78 | + |
| 79 | +class AdapterModel(nn.Module): |
| 80 | + """Model with a single LoRALinear adapter wrapper for testing.""" |
| 81 | + |
| 82 | + def __init__(self, to_wrap: nn.Module, adapter: nn.Module) -> None: |
| 83 | + super().__init__() |
| 84 | + self.lora = LoRALinear(to_wrap, adapter) |
| 85 | + |
| 86 | + def forward(self, x: torch.Tensor) -> torch.Tensor: |
| 87 | + output, _ = self.lora(x) |
| 88 | + return output |
| 89 | + |
| 90 | + |
70 | 91 | class TestAdapterWrapper: |
71 | 92 | """Test the AdapterWrapper base class.""" |
72 | 93 | |
@@ -258,3 +279,59 @@ def test_adapter_wrapper_is_abstract(self): |
258 | 279 | assert hasattr(AdapterWrapper, "base_linear_forward") |
259 | 280 | assert hasattr(AdapterWrapper, "state_dict") |
260 | 281 | assert hasattr(AdapterWrapper, "sharded_state_dict") |
| 282 | + |
| 283 | + def test_adapter_wrapper_enable_disable_toggle(self, mock_linear_simple, simple_adapter): |
| 284 | + """Test adapter output toggling via AdapterWrapper methods.""" |
| 285 | + wrapper = LoRALinear(mock_linear_simple, simple_adapter) |
| 286 | + x = torch.randn(5, 10) |
| 287 | + |
| 288 | + base_output, _ = mock_linear_simple(x) |
| 289 | + enabled_output, _ = wrapper(x) |
| 290 | + expected = base_output + simple_adapter(x) |
| 291 | + assert torch.allclose(enabled_output, expected, atol=1e-6) |
| 292 | + |
| 293 | + wrapper.disable_adapter_layers() |
| 294 | + disabled_output, _ = wrapper(x) |
| 295 | + assert torch.allclose(disabled_output, base_output, atol=1e-6) |
| 296 | + |
| 297 | + wrapper.enable_adapter_layers() |
| 298 | + reenabled_output, _ = wrapper(x) |
| 299 | + assert torch.allclose(reenabled_output, enabled_output, atol=1e-6) |
| 300 | + |
| 301 | + def test_peft_disable_adapter_context_manager(self, mock_linear_simple, simple_adapter): |
| 302 | + """Test that PEFT.disable_adapter restores adapter state on exit, even after an exception.""" |
| 303 | + peft = DummyPEFT() |
| 304 | + model = AdapterModel(mock_linear_simple, simple_adapter) |
| 305 | + x = torch.randn(5, 10) |
| 306 | + |
| 307 | + base_output, _ = mock_linear_simple(x) |
| 308 | + enabled_output = model(x) |
| 309 | + |
| 310 | + with peft.disable_adapter(model): |
| 311 | + disabled_output = model(x) |
| 312 | + assert torch.allclose(disabled_output, base_output, atol=1e-6) |
| 313 | + |
| 314 | + assert torch.allclose(model(x), enabled_output, atol=1e-6) |
| 315 | + |
| 316 | + with pytest.raises(RuntimeError): |
| 317 | + with peft.disable_adapter(model): |
| 318 | + raise RuntimeError("boom") |
| 319 | + |
| 320 | + assert torch.allclose(model(x), enabled_output, atol=1e-6) |
| 321 | + |
| 322 | + def test_peft_enable_disable_adapter_layers_manual(self, mock_linear_simple, simple_adapter): |
| 323 | + """Test manual adapter enable/disable via PEFT helpers.""" |
| 324 | + peft = DummyPEFT() |
| 325 | + model = AdapterModel(mock_linear_simple, simple_adapter) |
| 326 | + x = torch.randn(5, 10) |
| 327 | + |
| 328 | + base_output, _ = mock_linear_simple(x) |
| 329 | + enabled_output = model(x) |
| 330 | + |
| 331 | + peft.disable_adapter_layers(model) |
| 332 | + disabled_output = model(x) |
| 333 | + assert torch.allclose(disabled_output, base_output, atol=1e-6) |
| 334 | + |
| 335 | + peft.enable_adapter_layers(model) |
| 336 | + reenabled_output = model(x) |
| 337 | + assert torch.allclose(reenabled_output, enabled_output, atol=1e-6) |
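
For context, the new tests assume that PEFT.disable_adapter behaves as a context manager which turns adapter layers off on entry and back on when the block exits, including when the body raises. Below is a minimal illustrative sketch of that behavior, assuming a simple module traversal and reusing the enable_adapter_layers/disable_adapter_layers method names exercised by the tests; it is a stand-in for the tests' expectations, not the actual megatron.bridge.peft implementation.

from contextlib import contextmanager

import torch.nn as nn


class SketchPEFT:
    """Illustrative stand-in for the context-manager behavior the tests exercise (assumed, not the real API)."""

    @contextmanager
    def disable_adapter(self, model: nn.Module):
        # Turn adapters off for the duration of the block.
        self.disable_adapter_layers(model)
        try:
            yield model
        finally:
            # Restore adapters even if the body raised; the RuntimeError test relies on this.
            self.enable_adapter_layers(model)

    def disable_adapter_layers(self, model: nn.Module) -> None:
        # Walk all submodules and disable any that expose the adapter toggle.
        for module in model.modules():
            if hasattr(module, "disable_adapter_layers"):
                module.disable_adapter_layers()

    def enable_adapter_layers(self, model: nn.Module) -> None:
        # Mirror of disable_adapter_layers: re-enable any adapter-capable submodule.
        for module in model.modules():
            if hasattr(module, "enable_adapter_layers"):
                module.enable_adapter_layers()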