11from abc import ABCMeta , abstractmethod
2- from typing import List , Mapping , Union
2+ from typing import Dict , List , Mapping , Union
33
44import numpy as np
55import torch
66from torch import nn
77
8- from .composition import AdapterCompositionBlock , BatchSplit , Fuse , Parallel , Split , Stack , adjust_tensors_for_parallel
8+ from .composition import (
9+ AdapterCompositionBlock ,
10+ Average ,
11+ BatchSplit ,
12+ Fuse ,
13+ Parallel ,
14+ Split ,
15+ Stack ,
16+ adjust_tensors_for_parallel ,
17+ )
918from .configuration import AdapterConfig
1019from .context import AdapterSetup , ForwardContext
1120from .modeling import Adapter , BertFusion , ParallelAdapter
@@ -71,7 +80,11 @@ def _store_fusion_attentions(self, fusion_name, attentions):
7180 attention_cache [fusion_name ][self .layer_idx ][self .location_key ] = attentions
7281
    @abstractmethod
    def add_adapter(self, adapter_name: str, layer_idx: int) -> bool:
        """Instantiate a new adapter module in this layer.

        Args:
            adapter_name: Name of the adapter to add.
            layer_idx: Index of the transformer layer this module belongs to.

        Returns:
            True if an adapter module was created in this layer (i.e. the
            adapter's config matches this layer), False otherwise.
        """
        raise NotImplementedError()
85+
    @abstractmethod
    def average_adapter(self, adapter_name: str, input_adapters: Dict[str, float]) -> bool:
        """Create a new adapter whose parameters average those of existing adapters.

        Args:
            adapter_name: Name under which the averaged adapter is registered.
            input_adapters: Mapping from existing adapter name to its averaging weight.

        Returns:
            True if an averaged adapter module was created in this layer,
            False otherwise.
        """
        raise NotImplementedError()
7689
7790 @abstractmethod
@@ -105,7 +118,7 @@ def init_adapters(self, config):
105118 self .adapters = nn .ModuleDict (dict ())
106119 self .adapter_fusion_layer = nn .ModuleDict (dict ())
107120
108- def add_adapter (self , adapter_name : str , layer_idx : int ):
121+ def add_adapter (self , adapter_name : str , layer_idx : int ) -> bool :
109122 self .layer_idx = layer_idx
110123 adapter_config = self .config .adapters .match (
111124 adapter_name ,
@@ -139,6 +152,31 @@ def add_adapter(self, adapter_name: str, layer_idx: int):
139152 )
140153 adapter .train (self .training ) # make sure training mode is consistent
141154 self .adapters [adapter_name ] = adapter
155+ return True
156+
157+ return False
158+
159+ def average_adapter (self , adapter_name : str , input_adapters : Dict [str , float ]) -> bool :
160+ # add new adapter
161+ if self .add_adapter (adapter_name , self .layer_idx ):
162+ # average weights
163+ avg_state_dict = {}
164+ for name , weight in input_adapters .items ():
165+ if name in self .adapters :
166+ module = self .adapters [name ]
167+ for k , v in module .state_dict ().items ():
168+ if k in avg_state_dict :
169+ avg_state_dict [k ] += weight * v
170+ else :
171+ avg_state_dict [k ] = weight * v
172+ else :
173+ self .delete_adapter (adapter_name ) # clean up before raising error
174+ raise ValueError ("Adapter {} not found." .format (name ))
175+ # load averaged weights
176+ self .adapters [adapter_name ].load_state_dict (avg_state_dict )
177+ return True
178+
179+ return False
142180
143181 def delete_adapter (self , adapter_name : str ):
144182 if adapter_name in self .adapters :
@@ -225,7 +263,12 @@ def adapter_stack(self, adapter_setup: Stack, hidden_states, input_tensor, layer
225263 hidden_states = self .adapter_batchsplit (
226264 adapter_stack_layer , hidden_states , input_tensor , layer_norm , lvl = lvl + 1
227265 )
228- # Case 5: We have a single adapter which is part of this module -> forward pass
266+ # Case 5: We have a nested average block -> call average method
267+ elif isinstance (adapter_stack_layer , Average ):
268+ hidden_states = self .adapter_average_output (
269+ adapter_stack_layer , hidden_states , input_tensor , layer_norm , lvl = lvl + 1
270+ )
271+ # Case 6: We have a single adapter which is part of this module -> forward pass
229272 elif adapter_stack_layer in self .adapters :
230273 adapter_layer = self .adapters [adapter_stack_layer ]
231274 hidden_states , _ , residual = adapter_layer .pre_forward (hidden_states , input_tensor , layer_norm )
@@ -341,7 +384,12 @@ def adapter_split(self, adapter_setup: Split, hidden_states, input_tensor, layer
341384 split_hidden_states [i ] = self .adapter_batchsplit (
342385 adapter_block , split_hidden_states [i ], split_input_tensor [i ], layer_norm , lvl = lvl + 1
343386 )
344- # Case 4: We have a single adapter which is part of this module -> forward pass
387+ # Case 4: We have a nested average -> call average method
388+ elif isinstance (adapter_block , Average ):
389+ split_hidden_states [i ] = self .adapter_average_output (
390+ adapter_block , split_hidden_states [i ], split_input_tensor [i ], layer_norm , lvl = lvl + 1
391+ )
392+ # Case 5: We have a single adapter which is part of this module -> forward pass
345393 elif adapter_block in self .adapters :
346394 adapter_layer = self .adapters [adapter_block ]
347395 context = ForwardContext .get_context ()
@@ -352,7 +400,7 @@ def adapter_split(self, adapter_setup: Split, hidden_states, input_tensor, layer
352400 )
353401 split_hidden_states [i ] = layer_output [0 ]
354402 self ._store_gating_score (adapter_block , layer_output [- 1 ])
355- # Case 5 : nesting other composition blocks is invalid
403+ # Case 6 : nesting other composition blocks is invalid
356404 elif isinstance (adapter_block , AdapterCompositionBlock ):
357405 raise ValueError (
358406 "Invalid adapter setup. Cannot nest {} in {}" .format (
@@ -403,7 +451,7 @@ def adapter_parallel(self, adapter_setup: Parallel, hidden_states, input_tensor,
403451 lvl = lvl + 1 ,
404452 )
405453 children_hidden .append (child_hidden_states )
406- # Case 2. We have a nested batchsplit block -> call batchsplit method
454+ # Case 2: We have a nested batchsplit block -> call batchsplit method
407455 elif isinstance (child , BatchSplit ):
408456 child_hidden_states = self .adapter_batchsplit (
409457 child ,
@@ -413,7 +461,17 @@ def adapter_parallel(self, adapter_setup: Parallel, hidden_states, input_tensor,
413461 lvl = lvl + 1 ,
414462 )
415463 children_hidden .append (child_hidden_states )
416- # Case 3: We have a single adapter which is part of this module -> forward pass
464+ # Case 3: We have a nested average block -> call average method
465+ elif isinstance (child , Average ):
466+ child_hidden_states = self .adapter_average_output (
467+ child ,
468+ hidden_states [i * orig_batch_size : (i + 1 ) * orig_batch_size ],
469+ input_tensor [i * orig_batch_size : (i + 1 ) * orig_batch_size ],
470+ layer_norm ,
471+ lvl = lvl + 1 ,
472+ )
473+ children_hidden .append (child_hidden_states )
474+ # Case 4: We have a single adapter which is part of this module -> forward pass
417475 elif child in self .adapters :
418476 adapter_layer = self .adapters [child ]
419477 context = ForwardContext .get_context ()
@@ -425,7 +483,7 @@ def adapter_parallel(self, adapter_setup: Parallel, hidden_states, input_tensor,
425483 child_hidden_states = layer_output [0 ]
426484 self ._store_gating_score (child , layer_output [- 1 ])
427485 children_hidden .append (child_hidden_states )
428- # Case 4 : nesting other composition blocks is invalid
486+ # Case 5 : nesting other composition blocks is invalid
429487 elif isinstance (child , AdapterCompositionBlock ):
430488 raise ValueError (
431489 "Invalid adapter setup. Cannot nest {} in {}" .format (
@@ -487,7 +545,17 @@ def adapter_batchsplit(self, adapter_setup: BatchSplit, hidden_states, input_ten
487545 lvl = lvl + 1 ,
488546 )
489547 children_hidden .append (child )
490- # Case 4: We have a single adapter which is part of this module -> forward pass
548+ # Case 4: We have a nested average block -> call average method
549+ elif isinstance (adapter_block , Average ):
550+ child = self .adapter_average_output (
551+ adapter_block ,
552+ hidden_states [batch_idx [0 ] : batch_idx [1 ]],
553+ input_tensor [batch_idx [0 ] : batch_idx [1 ]],
554+ layer_norm ,
555+ lvl = lvl + 1 ,
556+ )
557+ children_hidden .append (child )
558+ # Case 5: We have a single adapter which is part of this module -> forward pass
491559 elif adapter_block in self .adapters :
492560
493561 adapter_layer = self .adapters [adapter_block ]
@@ -499,7 +567,7 @@ def adapter_batchsplit(self, adapter_setup: BatchSplit, hidden_states, input_ten
499567 )
500568 children_hidden .append (layer_output [0 ])
501569 self ._store_gating_score (adapter_block , layer_output [- 1 ])
502- # Case 5 : nesting other composition blocks is invalid
570+ # Case 6 : nesting other composition blocks is invalid
503571 elif isinstance (adapter_block , AdapterCompositionBlock ):
504572 raise ValueError (
505573 "Invalid adapter setup. Cannot nest {} in {}" .format (
@@ -513,6 +581,53 @@ def adapter_batchsplit(self, adapter_setup: BatchSplit, hidden_states, input_ten
513581 hidden_states = torch .cat (children_hidden , 0 )
514582 return hidden_states
515583
584+ def adapter_average_output (self , adapter_setup : Average , hidden_states , input_tensor , layer_norm , lvl = 0 ):
585+ """
586+ For averaging the output representations of multiple adapters.
587+ """
588+ context = ForwardContext .get_context ()
589+
590+ # We assume all adapters have the same config
591+ first_adapter = self .adapters [adapter_setup .first ()]
592+ hidden_states , _ , residual = first_adapter .pre_forward (hidden_states , input_tensor , layer_norm )
593+
594+ children_hidden = []
595+
596+ for adapter_block in adapter_setup :
597+ # Case 1: We have a nested stack -> call stack method
598+ if isinstance (adapter_block , Stack ):
599+ child , _ , _ = self .adapter_stack (adapter_block , hidden_states , input_tensor , layer_norm , lvl = lvl + 1 )
600+ children_hidden .append (child )
601+ # Case 2: We have a nested split block -> call split method
602+ elif isinstance (adapter_block , Split ):
603+ child = self .adapter_split (adapter_block , hidden_states , input_tensor , layer_norm , lvl = lvl + 1 )
604+ children_hidden .append (child )
605+ # Case 3: We have a nested batch split block -> call batchsplit method
606+ elif isinstance (adapter_block , BatchSplit ):
607+ child = self .adapter_batchsplit (adapter_block , hidden_states , input_tensor , layer_norm , lvl = lvl + 1 )
608+ children_hidden .append (child )
609+ # Case 4: We have a single adapter which is part of this module -> forward pass
610+ elif adapter_block in self .adapters :
611+ adapter_layer = self .adapters [adapter_block ]
612+ layer_output = adapter_layer (
613+ hidden_states , residual_input = residual , output_gating = context .output_adapter_gating_scores
614+ )
615+ children_hidden .append (layer_output [0 ])
616+ self ._store_gating_score (adapter_block , layer_output [- 1 ])
617+ # Case 5: nesting other composition blocks is invalid
618+ elif isinstance (adapter_block , AdapterCompositionBlock ):
619+ raise ValueError (
620+ "Invalid adapter setup. Cannot nest {} in {}" .format (
621+ adapter_block .__class__ .__name__ , adapter_setup .__class__ .__name__
622+ )
623+ )
624+ # Case X: No adapter which is part of this module -> ignore
625+
626+ weights = torch .tensor (adapter_setup .weights ).unsqueeze (1 ).unsqueeze (1 ).to (hidden_states .device )
627+ hidden_states = torch .mean (torch .cat (children_hidden , 0 ) * weights , 0 )
628+
629+ return hidden_states
630+
516631 def adapter_layer_forward (self , hidden_states , residual_input , layer_norm ):
517632 """Forward pass through the adapter layer.
518633 NOTE: This method should only be called if the calling module directly inherits from AdapterLayer. Otherwise,
@@ -550,6 +665,8 @@ def adapter_layer_forward(self, hidden_states, residual_input, layer_norm):
550665 )
551666 elif isinstance (adapter_setup , BatchSplit ):
552667 hidden_states = self .adapter_batchsplit (adapter_setup , hidden_states , residual_input , layer_norm )
668+ elif isinstance (adapter_setup , Average ):
669+ hidden_states = self .adapter_average_output (adapter_setup , hidden_states , residual_input , layer_norm )
553670 else :
554671 raise ValueError (f"Invalid adapter setup { adapter_setup } " )
555672
0 commit comments