
Commit f4abe32

Merge pull request #50 from saprmarks/topk_annealing
Topk annealing
2 parents: eb07533 + 8f87957

6 files changed: +230 -44 lines

dictionary_learning/buffer.py

Lines changed: 17 additions & 3 deletions
@@ -56,8 +56,14 @@ def __init__(self,
         self.refresh_batch_size = refresh_batch_size
         self.out_batch_size = out_batch_size
         self.device = device
-        self.remove_bos = remove_bos and (self.model.tokenizer.bos_token_id is not None)
         self.add_special_tokens = add_special_tokens
+        self.remove_bos = remove_bos
+
+        if remove_bos and self.model.tokenizer.bos_token_id is None:
+            print(
+                "\n\n\nWARNING: remove_bos is True but tokenizer does not have a bos token. We are removing the first non-pad token instead. Don't use sequence packing.\n\n\n"
+            )
+

     def __iter__(self):
         return self
@@ -138,9 +144,17 @@ def refresh(self):
             hidden_states = hidden_states.value
             if isinstance(hidden_states, tuple):
                 hidden_states = hidden_states[0]
+
             if self.remove_bos:
-                bos_mask = (input.value[1]["input_ids"] == self.model.tokenizer.bos_token_id)
-                mask = mask & ~bos_mask
+                if self.model.tokenizer.bos_token_id is not None:
+                    bos_mask = input.value[1]["input_ids"] == self.model.tokenizer.bos_token_id
+                    mask = mask & ~bos_mask
+                else:
+                    # some models (like Qwen) don't have a bos token, so we need to remove the first non-pad token
+                    assert mask.dim() == 2, "expected shape (batch_size, seq_len)"
+                    first_one = (mask.to(t.int64).cumsum(dim=1) == 1) & mask
+                    mask = mask & ~first_one
+
             hidden_states = hidden_states[mask]

             remaining_space = self.activation_buffer_size - current_idx
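
A minimal standalone sketch of the first-non-pad-token trick used in the fallback branch above (toy mask values; torch is imported as t to match the repo's alias). Along each row, cumsum(dim=1) == 1 holds at the first real token (and at any pad positions before the second real token), so intersecting it with the attention mask isolates exactly one position per row, which is then dropped:

import torch as t

# Toy attention mask: row 0 is left-padded, row 1 has no padding (1 = real token, 0 = pad).
mask = t.tensor([[0, 0, 1, 1, 1],
                 [1, 1, 1, 1, 1]], dtype=t.bool)

# Mark the first non-pad token of each row.
first_one = (mask.to(t.int64).cumsum(dim=1) == 1) & mask
# first_one.int() -> tensor([[0, 0, 1, 0, 0],
#                            [1, 0, 0, 0, 0]])

# Drop that position, as the remove_bos fallback does when the tokenizer has no BOS token.
mask = mask & ~first_one

With sequence packing a row holds several documents but only the first one loses its leading token, which is why the warning advises against packing in this mode.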

dictionary_learning/pytorch_buffer.py

Lines changed: 16 additions & 4 deletions
@@ -119,7 +119,12 @@ def __init__(
         self.device = device
         self.add_special_tokens = add_special_tokens
         self.tokenizer = AutoTokenizer.from_pretrained(model.name_or_path)
-        self.remove_bos = remove_bos and (self.tokenizer.bos_token_id is not None)
+        self.remove_bos = remove_bos
+
+        if remove_bos and self.tokenizer.bos_token_id is None:
+            print(
+                "\n\n\nWARNING: remove_bos is True but tokenizer does not have a bos token. We are removing the first non-pad token instead. Don't use sequence packing.\n\n\n"
+            )

         if not self.tokenizer.pad_token:
             self.tokenizer.pad_token = self.tokenizer.eos_token
@@ -192,10 +197,17 @@ def refresh(self):
             with t.no_grad():
                 input = self.tokenized_batch()
                 hidden_states = collect_activations(self.model, self.submodule, input)
-                mask = (input["attention_mask"] != 0)
+                mask = input["attention_mask"] != 0
                 if self.remove_bos:
-                    bos_mask = (input["input_ids"] == self.tokenizer.bos_token_id)
-                    mask = mask & ~bos_mask
+                    if self.tokenizer.bos_token_id is not None:
+                        bos_mask = input["input_ids"] == self.tokenizer.bos_token_id
+                        mask = mask & ~bos_mask
+                    else:
+                        # some models (like Qwen) don't have a bos token, so we need to remove the first non-pad token
+                        assert mask.dim() == 2, "expected shape (batch_size, seq_len)"
+                        first_one = (mask.to(t.int64).cumsum(dim=1) == 1) & mask
+                        mask = mask & ~first_one
+
                 hidden_states = hidden_states[mask]

                 remaining_space = self.activation_buffer_size - current_idx
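
The fallback comment above names Qwen as a family whose tokenizer ships without a BOS token. A quick way to check which branch a given tokenizer will take (the model name is only an illustrative choice, not something this PR depends on):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")  # example model only
if tok.bos_token_id is None:
    print("No BOS token: the buffer will drop the first non-pad token of each sequence.")
else:
    print(f"BOS token id is {tok.bos_token_id}: BOS positions will be masked out directly.")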

dictionary_learning/trainers/batch_top_k.py

Lines changed: 56 additions & 10 deletions
@@ -34,11 +34,15 @@ def __init__(self, activation_dim: int, dict_size: int, k: int):
         self.encoder.bias.data.zero_()
         self.b_dec = nn.Parameter(t.zeros(activation_dim))

-    def encode(self, x: t.Tensor, return_active: bool = False, use_threshold: bool = True):
+    def encode(
+        self, x: t.Tensor, return_active: bool = False, use_threshold: bool = True
+    ):
         post_relu_feat_acts_BF = nn.functional.relu(self.encoder(x - self.b_dec))

         if use_threshold:
-            encoded_acts_BF = post_relu_feat_acts_BF * (post_relu_feat_acts_BF > self.threshold)
+            encoded_acts_BF = post_relu_feat_acts_BF * (
+                post_relu_feat_acts_BF > self.threshold
+            )
         else:
             # Flatten and perform batch top-k
             flattened_acts = post_relu_feat_acts_BF.flatten()
@@ -105,6 +109,7 @@ def __init__(
         decay_start: Optional[int] = None,  # when does the lr decay start
         threshold_beta: float = 0.999,
         threshold_start_step: int = 1000,
+        k_anneal_steps: Optional[int] = None,
         seed: Optional[int] = None,
         device: Optional[str] = None,
         wandb_name: str = "BatchTopKSAE",
@@ -122,6 +127,7 @@ def __init__(
         self.k = k
         self.threshold_beta = threshold_beta
         self.threshold_start_step = threshold_start_step
+        self.k_anneal_steps = k_anneal_steps

         if seed is not None:
             t.manual_seed(seed)
@@ -146,17 +152,43 @@ def __init__(
         self.dead_feature_threshold = 10_000_000
         self.top_k_aux = activation_dim // 2  # Heuristic from B.1 of the paper
         self.num_tokens_since_fired = t.zeros(dict_size, dtype=t.long, device=device)
-        self.logging_parameters = ["effective_l0", "dead_features", "pre_norm_auxk_loss"]
+        self.logging_parameters = [
+            "effective_l0",
+            "dead_features",
+            "pre_norm_auxk_loss",
+        ]
         self.effective_l0 = -1
         self.dead_features = -1
         self.pre_norm_auxk_loss = -1

-        self.optimizer = t.optim.Adam(self.ae.parameters(), lr=self.lr, betas=(0.9, 0.999))
+        self.optimizer = t.optim.Adam(
+            self.ae.parameters(), lr=self.lr, betas=(0.9, 0.999)
+        )

         lr_fn = get_lr_schedule(steps, warmup_steps, decay_start=decay_start)

         self.scheduler = t.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lr_fn)

+    def update_annealed_k(
+        self, step: int, activation_dim: int, k_anneal_steps: Optional[int] = None
+    ) -> None:
+        """Update k buffer in-place with annealed value"""
+        if k_anneal_steps is None:
+            return
+
+        assert 0 <= k_anneal_steps < self.steps, (
+            "k_anneal_steps must be >= 0 and < steps."
+        )
+        # self.k is the target k set for the trainer, not the dictionary's current k
+        assert activation_dim > self.k, "activation_dim must be greater than k"
+
+        step = min(step, k_anneal_steps)
+        ratio = step / k_anneal_steps
+        annealed_value = activation_dim * (1 - ratio) + self.k * ratio
+
+        # Update in-place
+        self.ae.k.fill_(int(annealed_value))
+
     def get_auxiliary_loss(self, residual_BD: t.Tensor, post_relu_acts_BF: t.Tensor):
         dead_features = self.num_tokens_since_fired >= self.dead_feature_threshold
         self.dead_features = int(dead_features.sum())
@@ -170,19 +202,28 @@ def get_auxiliary_loss(self, residual_BD: t.Tensor, post_relu_acts_BF: t.Tensor)
             auxk_acts, auxk_indices = auxk_latents.topk(k_aux, sorted=False)

             auxk_buffer_BF = t.zeros_like(post_relu_acts_BF)
-            auxk_acts_BF = auxk_buffer_BF.scatter_(dim=-1, index=auxk_indices, src=auxk_acts)
+            auxk_acts_BF = auxk_buffer_BF.scatter_(
+                dim=-1, index=auxk_indices, src=auxk_acts
+            )

             # Note: decoder(), not decode(), as we don't want to apply the bias
             x_reconstruct_aux = self.ae.decoder(auxk_acts_BF)
             l2_loss_aux = (
-                (residual_BD.float() - x_reconstruct_aux.float()).pow(2).sum(dim=-1).mean()
+                (residual_BD.float() - x_reconstruct_aux.float())
+                .pow(2)
+                .sum(dim=-1)
+                .mean()
             )

             self.pre_norm_auxk_loss = l2_loss_aux

             # normalization from OpenAI implementation: https://github.com/openai/sparse_autoencoder/blob/main/sparse_autoencoder/kernels.py#L614
-            residual_mu = residual_BD.mean(dim=0)[None, :].broadcast_to(residual_BD.shape)
-            loss_denom = (residual_BD.float() - residual_mu.float()).pow(2).sum(dim=-1).mean()
+            residual_mu = residual_BD.mean(dim=0)[None, :].broadcast_to(
+                residual_BD.shape
+            )
+            loss_denom = (
+                (residual_BD.float() - residual_mu.float()).pow(2).sum(dim=-1).mean()
+            )
             normalized_auxk_loss = l2_loss_aux / loss_denom

             return normalized_auxk_loss.nan_to_num(0.0)
@@ -220,7 +261,7 @@ def loss(self, x, step=None, logging=False):

         e = x - x_hat

-        self.effective_l0 = self.k
+        self.effective_l0 = self.ae.k.item()

         num_tokens_in_step = x.size(0)
         did_fire = t.zeros_like(self.num_tokens_since_fired, dtype=t.bool)
@@ -239,7 +280,11 @@ def loss(self, x, step=None, logging=False):
                 x,
                 x_hat,
                 f,
-                {"l2_loss": l2_loss.item(), "auxk_loss": auxk_loss.item(), "loss": loss.item()},
+                {
+                    "l2_loss": l2_loss.item(),
+                    "auxk_loss": auxk_loss.item(),
+                    "loss": loss.item(),
+                },
             )

     def update(self, step, x):
@@ -263,6 +308,7 @@ def update(self, step, x):
         self.optimizer.step()
         self.optimizer.zero_grad()
         self.scheduler.step()
+        self.update_annealed_k(step, self.ae.activation_dim, self.k_anneal_steps)

         # Make sure the decoder is still unit-norm
         self.ae.decoder.weight.data = set_decoder_norm_to_unit_norm(
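
The new update_annealed_k method interpolates the dictionary's active k linearly from activation_dim down to the trainer's target k over k_anneal_steps updates, then holds it at k, and update() now calls it once per step. A standalone sketch of the same arithmetic with made-up numbers (activation_dim=512, k=32, k_anneal_steps=1000 are illustrative, not values from the PR):

# Toy settings, chosen only to show the shape of the schedule.
activation_dim, k, k_anneal_steps = 512, 32, 1000

def annealed_k(step: int) -> int:
    step = min(step, k_anneal_steps)  # clamp so the value stays at k once annealing ends
    ratio = step / k_anneal_steps     # 0.0 at the start, 1.0 when annealing is done
    return int(activation_dim * (1 - ratio) + k * ratio)

print([annealed_k(s) for s in (0, 250, 500, 750, 1000, 5000)])
# [512, 392, 272, 152, 32, 32]

In the trainer itself the result is written into the autoencoder's k buffer via self.ae.k.fill_(int(annealed_value)), so the batch top-k selection in encode() uses the annealed sparsity on subsequent batches.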
