@@ -15,6 +15,9 @@ class ActivationBuffer:
1515 """
1616 Implements a buffer of activations. The buffer stores activations from a model,
1717 yields them in batches, and refreshes them when the buffer is less than half full.
18+
19+ max_activation_norm_multiple: remove all activations with norm greater than median norm * max_activation_norm_multiple. 10 is a good default.
20+ This is useful for models like Qwen which have random, unpredictable high norm activation sinks which reduce training effectiveness.
1821 """
1922 def __init__ (self ,
2023 data , # generator which yields text data
@@ -28,8 +31,9 @@ def __init__(self,
2831 out_batch_size = 8192 , # size of batches in which to yield activations
2932 device = 'cpu' , # device on which to store the activations
3033 remove_bos : bool = False ,
31- add_special_tokens : bool = True ,
32- ):
34+ add_special_tokens : bool = True ,
35+ max_activation_norm_multiple : int | None = None ,
36+ ):
3337
3438 if io not in ['in' , 'out' ]:
3539 raise ValueError ("io must be either 'in' or 'out'" )
@@ -58,6 +62,7 @@ def __init__(self,
5862 self .device = device
5963 self .add_special_tokens = add_special_tokens
6064 self .remove_bos = remove_bos
65+ self .remove_high_norm = max_activation_norm_multiple
6166
6267 if remove_bos and self .model .tokenizer .bos_token_id is None :
6368 print (
@@ -155,6 +160,13 @@ def refresh(self):
155160 first_one = (mask .to (t .int64 ).cumsum (dim = 1 ) == 1 ) & mask
156161 mask = mask & ~ first_one
157162
163+ if self .remove_high_norm is not None :
164+ # some models (like Qwen) have random high norm activation sinks which reduce training effectiveness
165+ norms_BL = hidden_states .norm (dim = - 1 )
166+ median_norm = norms_BL .median ()
167+ norm_mask = norms_BL > median_norm * self .remove_high_norm
168+ mask = mask & ~ norm_mask
169+
158170 hidden_states = hidden_states [mask ]
159171
160172 remaining_space = self .activation_buffer_size - current_idx
0 commit comments