|
27 | 27 |
|
28 | 28 | @require_torch |
29 | 29 | class BottleneckAdapterTestMixin(AdapterMethodBaseTestMixin): |
30 | | - |
31 | 30 | adapter_configs_to_test = [ |
32 | 31 | (PfeifferConfig(), ["adapters.{name}."]), |
33 | 32 | (MAMConfig(), ["adapters.{name}.", "prefix_tunings.{name}."]), |
34 | 33 | ] |
| 34 | + inv_adapter_configs_to_test = [ |
| 35 | + (PfeifferInvConfig(), ["invertible_adapters.{name}"]), |
| 36 | + (HoulsbyInvConfig(), ["invertible_adapters.{name}"]), |
| 37 | + ] |
35 | 38 |
|
36 | 39 | def test_add_adapter(self): |
37 | 40 | model = self.get_model() |
@@ -104,6 +107,40 @@ def forward_pre_hook(module, input): |
104 | 107 | # We expect one call to invertible adapter |
105 | 108 | self.assertEqual(1, calls) |
106 | 109 |
|
| 110 | + def test_delete_adapter_with_invertible(self): |
| 111 | + """Tests if the invertible adapters are deleted correctly.""" |
| 112 | + model = self.get_model().base_model |
| 113 | + model.eval() |
| 114 | + if not isinstance(model, InvertibleAdaptersMixin) and not isinstance(model, InvertibleAdaptersWrapperMixin): |
| 115 | + self.skipTest("Model does not support invertible adapters.") |
| 116 | + |
| 117 | + # iterate through all adapter invertible adapter configs |
| 118 | + for adapter_config, filter_keys in self.inv_adapter_configs_to_test: |
| 119 | + with self.subTest(model_class=model.__class__.__name__, config=adapter_config.__class__.__name__): |
| 120 | + name = adapter_config.__class__.__name__ |
| 121 | + model.add_adapter(name, config=adapter_config) |
| 122 | + model.set_active_adapters([name]) |
| 123 | + |
| 124 | + # check if adapter is correctly added to config |
| 125 | + self.assert_adapter_available(model, name) |
| 126 | + # remove the adapter again |
| 127 | + model.delete_adapter(name) |
| 128 | + |
| 129 | + # check if adapter is correctly removed from the model |
| 130 | + self.assert_adapter_unavailable(model, name) |
| 131 | + |
| 132 | + # check additionally if invertible adapter is removed correctly from the model |
| 133 | + self.assertFalse(name in model.invertible_adapters) |
| 134 | + self.assertEqual(None, model.get_invertible_adapter()) |
| 135 | + |
| 136 | + # check that weights are available and active |
| 137 | + has_weights = False |
| 138 | + filter_keys = [k.format(name=name) for k in filter_keys] |
| 139 | + print(f"filter_keys = {filter_keys}") |
| 140 | + for k, v in self.filter_parameters(model, filter_keys).items(): |
| 141 | + has_weights = True |
| 142 | + self.assertFalse(has_weights) |
| 143 | + |
107 | 144 | def test_get_adapter(self): |
108 | 145 | model = self.get_model() |
109 | 146 | model.eval() |
|
0 commit comments