Skip to content

Commit ab3325b

Browse files
TimoImhofcalpt
authored andcommitted
add missing delete method to mixin,
add remove test inv_adapters
1 parent 78ebb98 commit ab3325b

2 files changed

Lines changed: 50 additions & 1 deletion

File tree

src/adapters/model_mixin.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1255,6 +1255,18 @@ def add_adapter(self, adapter_name: str, config=None, overwrite_ok: bool = False
12551255
else:
12561256
self.base_model.add_adapter(adapter_name, config, overwrite_ok=overwrite_ok, set_active=set_active)
12571257

1258+
def delete_adapter(self, adapter_name: str):
1259+
"""
1260+
Deletes the adapter with the specified name from the model.
1261+
1262+
Args:
1263+
adapter_name (str): The name of the adapter.
1264+
"""
1265+
if self.base_model is self:
1266+
super().delete_adapter(adapter_name)
1267+
else:
1268+
self.base_model.delete_adapter(adapter_name)
1269+
12581270
def train_adapter(self, adapter_setup: Union[list, AdapterCompositionBlock], train_embeddings=False):
12591271
"""
12601272
Sets the model into mode for training the given adapters. If self.base_model is self, must inherit from a class

tests_adapters/methods/test_adapter_common.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,14 @@
2727

2828
@require_torch
2929
class BottleneckAdapterTestMixin(AdapterMethodBaseTestMixin):
30-
3130
adapter_configs_to_test = [
3231
(PfeifferConfig(), ["adapters.{name}."]),
3332
(MAMConfig(), ["adapters.{name}.", "prefix_tunings.{name}."]),
3433
]
34+
inv_adapter_configs_to_test = [
35+
(PfeifferInvConfig(), ["invertible_adapters.{name}"]),
36+
(HoulsbyInvConfig(), ["invertible_adapters.{name}"]),
37+
]
3538

3639
def test_add_adapter(self):
3740
model = self.get_model()
@@ -104,6 +107,40 @@ def forward_pre_hook(module, input):
104107
# We expect one call to invertible adapter
105108
self.assertEqual(1, calls)
106109

110+
def test_delete_adapter_with_invertible(self):
111+
"""Tests if the invertible adapters are deleted correctly."""
112+
model = self.get_model().base_model
113+
model.eval()
114+
if not isinstance(model, InvertibleAdaptersMixin) and not isinstance(model, InvertibleAdaptersWrapperMixin):
115+
self.skipTest("Model does not support invertible adapters.")
116+
117+
# iterate through all adapter invertible adapter configs
118+
for adapter_config, filter_keys in self.inv_adapter_configs_to_test:
119+
with self.subTest(model_class=model.__class__.__name__, config=adapter_config.__class__.__name__):
120+
name = adapter_config.__class__.__name__
121+
model.add_adapter(name, config=adapter_config)
122+
model.set_active_adapters([name])
123+
124+
# check if adapter is correctly added to config
125+
self.assert_adapter_available(model, name)
126+
# remove the adapter again
127+
model.delete_adapter(name)
128+
129+
# check if adapter is correctly removed from the model
130+
self.assert_adapter_unavailable(model, name)
131+
132+
# check additionally if invertible adapter is removed correctly from the model
133+
self.assertFalse(name in model.invertible_adapters)
134+
self.assertEqual(None, model.get_invertible_adapter())
135+
136+
# check that weights are available and active
137+
has_weights = False
138+
filter_keys = [k.format(name=name) for k in filter_keys]
139+
print(f"filter_keys = {filter_keys}")
140+
for k, v in self.filter_parameters(model, filter_keys).items():
141+
has_weights = True
142+
self.assertFalse(has_weights)
143+
107144
def test_get_adapter(self):
108145
model = self.get_model()
109146
model.eval()

0 commit comments

Comments
 (0)