
Commit 60ec6bf

Merge pull request #52 from saprmarks/remove-high-norm
Add optional mask for high norm tokens
2 parents 227da89 + 1d2737a commit 60ec6bf

File tree

2 files changed: +26 -2 lines changed


dictionary_learning/buffer.py

Lines changed: 14 additions & 2 deletions
@@ -15,6 +15,9 @@ class ActivationBuffer:
     """
     Implements a buffer of activations. The buffer stores activations from a model,
     yields them in batches, and refreshes them when the buffer is less than half full.
+
+    max_activation_norm_multiple: remove all activations with norm greater than median norm * max_activation_norm_multiple. 10 is a good default.
+    This is useful for models like Qwen which have random, unpredictable high norm activation sinks which reduce training effectiveness.
     """
     def __init__(self,
                  data, # generator which yields text data
@@ -28,8 +31,9 @@ def __init__(self,
                  out_batch_size=8192, # size of batches in which to yield activations
                  device='cpu', # device on which to store the activations
                  remove_bos: bool = False,
-                 add_special_tokens: bool = True,
-                 ):
+                 add_special_tokens: bool = True,
+                 max_activation_norm_multiple: int | None = None,
+                 ):
 
         if io not in ['in', 'out']:
             raise ValueError("io must be either 'in' or 'out'")
@@ -58,6 +62,7 @@ def __init__(self,
         self.device = device
         self.add_special_tokens = add_special_tokens
         self.remove_bos = remove_bos
+        self.remove_high_norm = max_activation_norm_multiple
 
         if remove_bos and self.model.tokenizer.bos_token_id is None:
             print(
@@ -155,6 +160,13 @@ def refresh(self):
                 first_one = (mask.to(t.int64).cumsum(dim=1) == 1) & mask
                 mask = mask & ~first_one
 
+            if self.remove_high_norm is not None:
+                # some models (like Qwen) have random high norm activation sinks which reduce training effectiveness
+                norms_BL = hidden_states.norm(dim=-1)
+                median_norm = norms_BL.median()
+                norm_mask = norms_BL > median_norm * self.remove_high_norm
+                mask = mask & ~norm_mask
+
             hidden_states = hidden_states[mask]
 
             remaining_space = self.activation_buffer_size - current_idx
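
In isolation, the added masking step works as follows. The snippet below is a minimal standalone sketch of the same logic on a toy batch of activations, assuming only the shapes implied by the diff (batch x seq_len x d_model hidden states and a batch x seq_len boolean keep-mask); the tensor values and the threshold of 10 are illustrative, not taken from the commit.

    import torch as t

    # toy activations: 2 sequences of 5 tokens, hidden size 4
    hidden_states = t.randn(2, 5, 4)
    hidden_states[0, 3] *= 100  # simulate one high-norm "activation sink" token

    # keep-mask analogous to the one built earlier in refresh() (True = keep token)
    mask = t.ones(2, 5, dtype=t.bool)

    multiple = 10  # plays the role of max_activation_norm_multiple

    # same computation as the added block above
    norms_BL = hidden_states.norm(dim=-1)          # per-token L2 norms, shape (2, 5)
    median_norm = norms_BL.median()                # median over all tokens in the batch
    norm_mask = norms_BL > median_norm * multiple  # True where a token is an outlier
    mask = mask & ~norm_mask                       # drop outliers from the keep-mask

    kept = hidden_states[mask]                     # flattened (n_kept, 4) activations
    print(f"removed {norm_mask.sum().item()} token(s), kept {kept.shape[0]}")

Because the threshold is relative to the batch median rather than an absolute value, the filter adapts to the typical activation scale of whichever layer the buffer is reading from.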

dictionary_learning/pytorch_buffer.py

Lines changed: 12 additions & 0 deletions
@@ -73,6 +73,9 @@ class ActivationBuffer:
     """
     Implements a buffer of activations. The buffer stores activations from a model,
     yields them in batches, and refreshes them when the buffer is less than half full.
+
+    max_activation_norm_multiple: remove all activations with norm greater than median norm * max_activation_norm_multiple. 10 is a good default.
+    This is useful for models like Qwen which have random, unpredictable high norm activation sinks which reduce training effectiveness.
     """
 
     def __init__(
@@ -89,6 +92,7 @@ def __init__(
         device="cpu", # device on which to store the activations
         remove_bos: bool = False,
         add_special_tokens: bool = True,
+        max_activation_norm_multiple: int | None = None,
     ):
         if io not in ["in", "out"]:
             raise ValueError("io must be either 'in' or 'out'")
@@ -120,6 +124,7 @@ def __init__(
         self.add_special_tokens = add_special_tokens
         self.tokenizer = AutoTokenizer.from_pretrained(model.name_or_path)
         self.remove_bos = remove_bos
+        self.remove_high_norm = max_activation_norm_multiple
 
         if remove_bos and self.tokenizer.bos_token_id is None:
             print(
@@ -208,6 +213,13 @@ def refresh(self):
                 first_one = (mask.to(t.int64).cumsum(dim=1) == 1) & mask
                 mask = mask & ~first_one
 
+            if self.remove_high_norm is not None:
+                # some models (like Qwen) have random high norm activation sinks which reduce training effectiveness
+                norms_BL = hidden_states.norm(dim=-1)
+                median_norm = norms_BL.median()
+                norm_mask = norms_BL > median_norm * self.remove_high_norm
+                mask = mask & ~norm_mask
+
             hidden_states = hidden_states[mask]
 
             remaining_space = self.activation_buffer_size - current_idx
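
A caller opts into the filter simply by passing the new keyword argument at construction time. The sketch below is illustrative only: the model, submodule, and every argument not visible in this diff (d_submodule, the positional order, and so on) are assumptions about typical usage of the library, not taken from the commit.

    # illustrative usage; everything except max_activation_norm_multiple and the
    # arguments shown in the diff above is an assumption about the library's API
    from nnsight import LanguageModel
    from dictionary_learning.buffer import ActivationBuffer

    model = LanguageModel("Qwen/Qwen2.5-0.5B", device_map="cuda")  # assumed model id
    submodule = model.model.layers[8]                              # assumed hook point

    def text_data():
        while True:
            yield "example text from your training corpus"        # placeholder data

    buffer = ActivationBuffer(
        text_data(),
        model,
        submodule,
        d_submodule=896,                  # assumed hidden size for this model
        io="out",
        out_batch_size=8192,
        device="cuda",
        remove_bos=False,
        add_special_tokens=True,
        max_activation_norm_multiple=10,  # new: drop tokens with norm > 10x the batch median
    )

Leaving max_activation_norm_multiple at its default of None keeps the previous behaviour, so existing callers are unaffected.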
