
Commit 78ebb98

Add adapter output & parameter averaging (speechbrain#16)
1 parent b2fe481 commit 78ebb98

14 files changed

Lines changed: 458 additions & 29 deletions

docs/adapter_composition.md

Lines changed: 49 additions & 2 deletions
@@ -39,13 +39,15 @@ The basic building blocks of the more advanced setups are simple objects derived
 each representing a different possibility to combine single adapters.
 The following table gives an overview of the supported composition blocks and which adapter methods support them.
 
-| Block | (Bottleneck)<br> Adapters | Prefix<br> Tuning | Compacter | LoRA | (IA)³ |
+| Block | Bottleneck<br> Adapters | Prefix<br> Tuning | Compacter | LoRA | (IA)³ |
 | --- | --- | --- | --- | --- | --- |
 | [`Stack`](#stack) |||| | |
 | [`Fuse`](#fuse) || || | |
 | [`Split`](#split) || || | |
 | [`BatchSplit`](#batchsplit) |||| | |
 | [`Parallel`](#parallel) |||| | |
+| [Output averaging](#output-averaging) || || | |
+| [Parameter averaging](#parameter-averaging) ||||||
 
 Next, we present all composition blocks in more detail.
 
@@ -178,7 +180,8 @@ model.active_adapters = ac.Split("g", "h", split_index=64)
 ```
 
 ## `BatchSplit`
-The `BatchSplit` lock is an alternative to split the input between several adapters. It does not split the input sequences but the
+
+The `BatchSplit` block is an alternative way to split the input between several adapters. It does not split the input sequences but the
 batch into smaller batches. As a result, the input sequences remain untouched.
 
 In the following example, we split the batch between adapters `i`, `k` and `l`. The `batch_sizes` parameter specifies
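The full example referenced here sits outside the hunk. As a purely illustrative sketch of such a setup (the `model` object, adapter names and batch sizes are placeholder assumptions, not the file's actual example):

```python
import adapters.composition as ac

# Split each incoming batch of 8 rows between three adapters: the first
# 2 rows go through "i", the next 3 through "k" and the last 3 through "l".
model.active_adapters = ac.BatchSplit("i", "k", "l", batch_sizes=[2, 3, 3])
```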
@@ -232,6 +235,50 @@ print("STS-B adapter output:", output1[0].item())
 print("MRPC adapter output:", bool(torch.argmax(output2[0]).item()))
 ```
 
+## Averaging Outputs or Parameters
+
+Following approaches that ensemble full models at inference time for better generalization, recent work on adapters has explored methods of averaging pre-trained adapters.
+This includes averaging output representations of adapters ([Wang et al., 2021](https://arxiv.org/pdf/2109.04877.pdf)) as well as averaging adapter parameters ([Wang et al., 2022](https://arxiv.org/pdf/2205.12410.pdf), [Chronopoulou et al., 2023](https://aclanthology.org/2023.findings-eacl.153.pdf)).
+`adapters` provides built-in support for both types of inference-time averaging.
+
+### Output averaging
+
+Output averaging dynamically aggregates the output representations of multiple adapters in a model forward pass via weighted averaging.
+This is realized via the `Average` composition block, which works similarly to the other composition blocks.
+In the example below, the three adapters are averaged with the weights `0.1` for `m`, `0.6` for `n` and `0.3` for `o`.
+
+```python
+import adapters.composition as ac
+
+# ...
+
+model.add_adapter("m")
+model.add_adapter("n")
+model.add_adapter("o")
+
+model.active_adapters = ac.Average("m", "n", "o", weights=[0.1, 0.6, 0.3])
+```
+
+### Parameter averaging
+
+Parameter averaging creates a new adapter via weighted averaging of the parameters of multiple pre-trained adapters.
+As this process is typically not done dynamically at runtime, `adapters` provides `average_adapter()` as a dedicated method for parameter averaging.
+In the example below, the parameters of the adapters `m`, `n` and `o` are averaged (with weights `0.1`, `0.6` and `0.3`, respectively) to create a new adapter `avg`.
+Note that for this to succeed, all averaged adapters must use the same adapter configuration.
+
+```python
+model.add_adapter("m")
+model.add_adapter("n")
+model.add_adapter("o")
+
+model.average_adapter("avg", ["m", "n", "o"], weights=[0.1, 0.6, 0.3])
+```
+
+Compared to output averaging, parameter averaging has the advantage of adding no extra inference time relative to using a single adapter.
+
+For both output and parameter averaging, the passed weights are normalized by default.
+To disable normalization, pass `normalize_weights=False`.
 
 ## Nesting composition blocks
 
 Of course, it is also possible to combine different composition blocks in one adapter setup.
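To see how the two newly documented averaging styles fit together, here is a minimal end-to-end sketch. It is not part of the diff: the `AutoAdapterModel` class, the `bert-base-uncased` checkpoint and the adapter names are placeholder assumptions.

```python
import adapters.composition as ac
from adapters import AutoAdapterModel

model = AutoAdapterModel.from_pretrained("bert-base-uncased")  # placeholder checkpoint

# Three independently added (here: untrained) adapters.
for name in ["m", "n", "o"]:
    model.add_adapter(name)

# Output averaging: combine the adapters' output representations at runtime.
model.active_adapters = ac.Average("m", "n", "o", weights=[0.1, 0.6, 0.3])

# Parameter averaging: materialize a new adapter "avg" from averaged weights,
# then activate it like any single adapter.
model.average_adapter("avg", ["m", "n", "o"], weights=[0.1, 0.6, 0.3])
model.set_active_adapters("avg")
```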

src/adapters/composition.py

Lines changed: 25 additions & 5 deletions
@@ -1,6 +1,6 @@
 import itertools
 from collections.abc import Sequence
-from typing import List, Set, Union
+from typing import List, Optional, Set, Union
 
 
 class AdapterCompositionBlock(Sequence):
@@ -87,13 +87,33 @@ def __init__(self, *split_adapters: List[Union[AdapterCompositionBlock, str]], b
         self.batch_sizes = batch_sizes if isinstance(batch_sizes, list) else [batch_sizes] * len(split_adapters)
 
 
+class Average(AdapterCompositionBlock):
+    def __init__(
+        self,
+        *average_adapters: List[Union[AdapterCompositionBlock, str]],
+        weights: Optional[List[float]] = None,
+        normalize_weights: bool = True
+    ):
+        super().__init__(*average_adapters)
+        if weights is not None:
+            # normalize weights
+            if normalize_weights:
+                sum_weights = sum(weights) if weights else 1
+                self.weights = [w / sum_weights for w in weights]
+            else:
+                self.weights = weights
+        else:
+            self.weights = [1 / len(average_adapters)] * len(average_adapters)
+
+
 # Mapping each composition block type to the allowed nested types
 ALLOWED_NESTINGS = {
-    Stack: [str, Fuse, Split, Parallel, BatchSplit],
+    Stack: [str, Fuse, Split, Parallel, BatchSplit, Average],
     Fuse: [str, Stack],
-    Split: [str, Split, Stack, BatchSplit],
-    Parallel: [str, Stack, BatchSplit],
-    BatchSplit: [str, Stack, Split, BatchSplit],
+    Split: [str, Split, Stack, BatchSplit, Average],
+    Parallel: [str, Stack, BatchSplit, Average],
+    BatchSplit: [str, Stack, Split, BatchSplit, Average],
+    Average: [str, Stack, Split, BatchSplit],
 }
 
 # Some composition blocks might not be supported by all models.
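A quick sketch of the weight handling implemented by the new `Average` block (assuming the `adapters` package as patched here; the adapter names are placeholders):

```python
import adapters.composition as ac

# Weights are normalized to sum to 1 by default ...
avg = ac.Average("a", "b", "c", weights=[1.0, 1.0, 2.0])
print(avg.weights)  # [0.25, 0.25, 0.5]

# ... unless normalization is disabled explicitly.
raw = ac.Average("a", "b", "c", weights=[1.0, 1.0, 2.0], normalize_weights=False)
print(raw.weights)  # [1.0, 1.0, 2.0]

# Without explicit weights, a uniform average is used.
uniform = ac.Average("a", "b", "c")
print(uniform.weights)  # [0.333..., 0.333..., 0.333...]
```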

src/adapters/layer.py

Lines changed: 129 additions & 12 deletions
@@ -1,11 +1,20 @@
 from abc import ABCMeta, abstractmethod
-from typing import List, Mapping, Union
+from typing import Dict, List, Mapping, Union
 
 import numpy as np
 import torch
 from torch import nn
 
-from .composition import AdapterCompositionBlock, BatchSplit, Fuse, Parallel, Split, Stack, adjust_tensors_for_parallel
+from .composition import (
+    AdapterCompositionBlock,
+    Average,
+    BatchSplit,
+    Fuse,
+    Parallel,
+    Split,
+    Stack,
+    adjust_tensors_for_parallel,
+)
 from .configuration import AdapterConfig
 from .context import AdapterSetup, ForwardContext
 from .modeling import Adapter, BertFusion, ParallelAdapter
@@ -71,7 +80,11 @@ def _store_fusion_attentions(self, fusion_name, attentions):
         attention_cache[fusion_name][self.layer_idx][self.location_key] = attentions
 
     @abstractmethod
-    def add_adapter(self, adapter_name: str, layer_idx: int):
+    def add_adapter(self, adapter_name: str, layer_idx: int) -> bool:
+        raise NotImplementedError()
+
+    @abstractmethod
+    def average_adapter(self, adapter_name: str, input_adapters: Dict[str, float]) -> bool:
         raise NotImplementedError()
 
     @abstractmethod
@@ -105,7 +118,7 @@ def init_adapters(self, config):
         self.adapters = nn.ModuleDict(dict())
         self.adapter_fusion_layer = nn.ModuleDict(dict())
 
-    def add_adapter(self, adapter_name: str, layer_idx: int):
+    def add_adapter(self, adapter_name: str, layer_idx: int) -> bool:
         self.layer_idx = layer_idx
         adapter_config = self.config.adapters.match(
             adapter_name,
@@ -139,6 +152,31 @@ def add_adapter(self, adapter_name: str, layer_idx: int):
             )
             adapter.train(self.training)  # make sure training mode is consistent
             self.adapters[adapter_name] = adapter
+            return True
+
+        return False
+
+    def average_adapter(self, adapter_name: str, input_adapters: Dict[str, float]) -> bool:
+        # add new adapter
+        if self.add_adapter(adapter_name, self.layer_idx):
+            # average weights
+            avg_state_dict = {}
+            for name, weight in input_adapters.items():
+                if name in self.adapters:
+                    module = self.adapters[name]
+                    for k, v in module.state_dict().items():
+                        if k in avg_state_dict:
+                            avg_state_dict[k] += weight * v
+                        else:
+                            avg_state_dict[k] = weight * v
+                else:
+                    self.delete_adapter(adapter_name)  # clean up before raising error
+                    raise ValueError("Adapter {} not found.".format(name))
+            # load averaged weights
+            self.adapters[adapter_name].load_state_dict(avg_state_dict)
+            return True
+
+        return False
 
     def delete_adapter(self, adapter_name: str):
         if adapter_name in self.adapters:
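The core of `average_adapter` above is a weighted sum over the input modules' `state_dict`s. A standalone sketch of that logic, with plain `nn.Linear` modules standing in for adapter modules (function name and shapes are illustrative):

```python
from typing import Dict

import torch
from torch import nn


def average_state_dicts(modules: Dict[str, nn.Module], weights: Dict[str, float]) -> Dict[str, torch.Tensor]:
    """Return the weighted sum of the modules' parameter tensors, key by key."""
    avg_state_dict: Dict[str, torch.Tensor] = {}
    for name, weight in weights.items():
        for k, v in modules[name].state_dict().items():
            if k in avg_state_dict:
                avg_state_dict[k] += weight * v
            else:
                avg_state_dict[k] = weight * v
    return avg_state_dict


# Two identically configured modules, averaged 30/70; weights should sum to 1
# for a true average (average_adapter normalizes them by default).
modules = {"m": nn.Linear(4, 4), "n": nn.Linear(4, 4)}
target = nn.Linear(4, 4)
target.load_state_dict(average_state_dicts(modules, {"m": 0.3, "n": 0.7}))
```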
@@ -225,7 +263,12 @@ def adapter_stack(self, adapter_setup: Stack, hidden_states, input_tensor, layer
                 hidden_states = self.adapter_batchsplit(
                     adapter_stack_layer, hidden_states, input_tensor, layer_norm, lvl=lvl + 1
                 )
-            # Case 5: We have a single adapter which is part of this module -> forward pass
+            # Case 5: We have a nested average block -> call average method
+            elif isinstance(adapter_stack_layer, Average):
+                hidden_states = self.adapter_average_output(
+                    adapter_stack_layer, hidden_states, input_tensor, layer_norm, lvl=lvl + 1
+                )
+            # Case 6: We have a single adapter which is part of this module -> forward pass
             elif adapter_stack_layer in self.adapters:
                 adapter_layer = self.adapters[adapter_stack_layer]
                 hidden_states, _, residual = adapter_layer.pre_forward(hidden_states, input_tensor, layer_norm)
@@ -341,7 +384,12 @@ def adapter_split(self, adapter_setup: Split, hidden_states, input_tensor, layer
                 split_hidden_states[i] = self.adapter_batchsplit(
                     adapter_block, split_hidden_states[i], split_input_tensor[i], layer_norm, lvl=lvl + 1
                 )
-            # Case 4: We have a single adapter which is part of this module -> forward pass
+            # Case 4: We have a nested average block -> call average method
+            elif isinstance(adapter_block, Average):
+                split_hidden_states[i] = self.adapter_average_output(
+                    adapter_block, split_hidden_states[i], split_input_tensor[i], layer_norm, lvl=lvl + 1
+                )
+            # Case 5: We have a single adapter which is part of this module -> forward pass
             elif adapter_block in self.adapters:
                 adapter_layer = self.adapters[adapter_block]
                 context = ForwardContext.get_context()
@@ -352,7 +400,7 @@ def adapter_split(self, adapter_setup: Split, hidden_states, input_tensor, layer
                 )
                 split_hidden_states[i] = layer_output[0]
                 self._store_gating_score(adapter_block, layer_output[-1])
-            # Case 5: nesting other composition blocks is invalid
+            # Case 6: nesting other composition blocks is invalid
             elif isinstance(adapter_block, AdapterCompositionBlock):
                 raise ValueError(
                     "Invalid adapter setup. Cannot nest {} in {}".format(
@@ -403,7 +451,7 @@ def adapter_parallel(self, adapter_setup: Parallel, hidden_states, input_tensor,
                     lvl=lvl + 1,
                 )
                 children_hidden.append(child_hidden_states)
-            # Case 2. We have a nested batchsplit block -> call batchsplit method
+            # Case 2: We have a nested batchsplit block -> call batchsplit method
             elif isinstance(child, BatchSplit):
                 child_hidden_states = self.adapter_batchsplit(
                     child,
@@ -413,7 +461,17 @@ def adapter_parallel(self, adapter_setup: Parallel, hidden_states, input_tensor,
                     lvl=lvl + 1,
                 )
                 children_hidden.append(child_hidden_states)
-            # Case 3: We have a single adapter which is part of this module -> forward pass
+            # Case 3: We have a nested average block -> call average method
+            elif isinstance(child, Average):
+                child_hidden_states = self.adapter_average_output(
+                    child,
+                    hidden_states[i * orig_batch_size : (i + 1) * orig_batch_size],
+                    input_tensor[i * orig_batch_size : (i + 1) * orig_batch_size],
+                    layer_norm,
+                    lvl=lvl + 1,
+                )
+                children_hidden.append(child_hidden_states)
+            # Case 4: We have a single adapter which is part of this module -> forward pass
             elif child in self.adapters:
                 adapter_layer = self.adapters[child]
                 context = ForwardContext.get_context()
@@ -425,7 +483,7 @@ def adapter_parallel(self, adapter_setup: Parallel, hidden_states, input_tensor,
                 child_hidden_states = layer_output[0]
                 self._store_gating_score(child, layer_output[-1])
                 children_hidden.append(child_hidden_states)
-            # Case 4: nesting other composition blocks is invalid
+            # Case 5: nesting other composition blocks is invalid
             elif isinstance(child, AdapterCompositionBlock):
                 raise ValueError(
                     "Invalid adapter setup. Cannot nest {} in {}".format(
@@ -487,7 +545,17 @@ def adapter_batchsplit(self, adapter_setup: BatchSplit, hidden_states, input_ten
                     lvl=lvl + 1,
                 )
                 children_hidden.append(child)
-            # Case 4: We have a single adapter which is part of this module -> forward pass
+            # Case 4: We have a nested average block -> call average method
+            elif isinstance(adapter_block, Average):
+                child = self.adapter_average_output(
+                    adapter_block,
+                    hidden_states[batch_idx[0] : batch_idx[1]],
+                    input_tensor[batch_idx[0] : batch_idx[1]],
+                    layer_norm,
+                    lvl=lvl + 1,
+                )
+                children_hidden.append(child)
+            # Case 5: We have a single adapter which is part of this module -> forward pass
             elif adapter_block in self.adapters:
 
                 adapter_layer = self.adapters[adapter_block]
@@ -499,7 +567,7 @@ def adapter_batchsplit(self, adapter_setup: BatchSplit, hidden_states, input_ten
                 )
                 children_hidden.append(layer_output[0])
                 self._store_gating_score(adapter_block, layer_output[-1])
-            # Case 5: nesting other composition blocks is invalid
+            # Case 6: nesting other composition blocks is invalid
             elif isinstance(adapter_block, AdapterCompositionBlock):
                 raise ValueError(
                     "Invalid adapter setup. Cannot nest {} in {}".format(
@@ -513,6 +581,53 @@ def adapter_batchsplit(self, adapter_setup: BatchSplit, hidden_states, input_ten
         hidden_states = torch.cat(children_hidden, 0)
         return hidden_states
 
+    def adapter_average_output(self, adapter_setup: Average, hidden_states, input_tensor, layer_norm, lvl=0):
+        """
+        For averaging the output representations of multiple adapters.
+        """
+        context = ForwardContext.get_context()
+
+        # We assume all adapters have the same config
+        first_adapter = self.adapters[adapter_setup.first()]
+        hidden_states, _, residual = first_adapter.pre_forward(hidden_states, input_tensor, layer_norm)
+
+        children_hidden = []
+
+        for adapter_block in adapter_setup:
+            # Case 1: We have a nested stack -> call stack method
+            if isinstance(adapter_block, Stack):
+                child, _, _ = self.adapter_stack(adapter_block, hidden_states, input_tensor, layer_norm, lvl=lvl + 1)
+                children_hidden.append(child)
+            # Case 2: We have a nested split block -> call split method
+            elif isinstance(adapter_block, Split):
+                child = self.adapter_split(adapter_block, hidden_states, input_tensor, layer_norm, lvl=lvl + 1)
+                children_hidden.append(child)
+            # Case 3: We have a nested batch split block -> call batchsplit method
+            elif isinstance(adapter_block, BatchSplit):
+                child = self.adapter_batchsplit(adapter_block, hidden_states, input_tensor, layer_norm, lvl=lvl + 1)
+                children_hidden.append(child)
+            # Case 4: We have a single adapter which is part of this module -> forward pass
+            elif adapter_block in self.adapters:
+                adapter_layer = self.adapters[adapter_block]
+                layer_output = adapter_layer(
+                    hidden_states, residual_input=residual, output_gating=context.output_adapter_gating_scores
+                )
+                children_hidden.append(layer_output[0])
+                self._store_gating_score(adapter_block, layer_output[-1])
+            # Case 5: nesting other composition blocks is invalid
+            elif isinstance(adapter_block, AdapterCompositionBlock):
+                raise ValueError(
+                    "Invalid adapter setup. Cannot nest {} in {}".format(
+                        adapter_block.__class__.__name__, adapter_setup.__class__.__name__
+                    )
+                )
+            # Case X: No adapter which is part of this module -> ignore
+
+        # combine the per-adapter outputs as a weighted sum, broadcasting the
+        # weights over the batch, sequence and hidden dimensions
+        weights = torch.tensor(adapter_setup.weights)[:, None, None, None].to(hidden_states.device)
+        hidden_states = torch.sum(torch.stack(children_hidden, 0) * weights, dim=0)
+
+        return hidden_states
+
     def adapter_layer_forward(self, hidden_states, residual_input, layer_norm):
         """Forward pass through the adapter layer.
         NOTE: This method should only be called if the calling module directly inherits from AdapterLayer. Otherwise,
@@ -550,6 +665,8 @@ def adapter_layer_forward(self, hidden_states, residual_input, layer_norm):
                 )
             elif isinstance(adapter_setup, BatchSplit):
                 hidden_states = self.adapter_batchsplit(adapter_setup, hidden_states, residual_input, layer_norm)
+            elif isinstance(adapter_setup, Average):
+                hidden_states = self.adapter_average_output(adapter_setup, hidden_states, residual_input, layer_norm)
             else:
                 raise ValueError(f"Invalid adapter setup {adapter_setup}")
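The final combination step of `adapter_average_output` reduces to a weighted sum over the per-adapter outputs. A minimal sketch of that tensor math, with random tensors standing in for adapter outputs (shapes are illustrative):

```python
import torch

batch, seq_len, hidden = 2, 5, 8
weights = [0.1, 0.6, 0.3]  # already normalized, as Average does by default

# One output per averaged adapter, each of shape (batch, seq_len, hidden).
children_hidden = [torch.randn(batch, seq_len, hidden) for _ in weights]

# Broadcast the weights over the batch/sequence/hidden dims, sum over adapters.
w = torch.tensor(weights)[:, None, None, None]
combined = torch.sum(torch.stack(children_hidden, 0) * w, dim=0)
assert combined.shape == (batch, seq_len, hidden)
```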
