@@ -34,11 +34,15 @@ def __init__(self, activation_dim: int, dict_size: int, k: int):
         self.encoder.bias.data.zero_()
         self.b_dec = nn.Parameter(t.zeros(activation_dim))
 
-    def encode(self, x: t.Tensor, return_active: bool = False, use_threshold: bool = True):
+    def encode(
+        self, x: t.Tensor, return_active: bool = False, use_threshold: bool = True
+    ):
         post_relu_feat_acts_BF = nn.functional.relu(self.encoder(x - self.b_dec))
 
         if use_threshold:
-            encoded_acts_BF = post_relu_feat_acts_BF * (post_relu_feat_acts_BF > self.threshold)
+            encoded_acts_BF = post_relu_feat_acts_BF * (
+                post_relu_feat_acts_BF > self.threshold
+            )
         else:
             # Flatten and perform batch top-k
             flattened_acts = post_relu_feat_acts_BF.flatten()
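
For context, a minimal standalone sketch of the threshold gate used in the `use_threshold` branch above; the tensor values and the scalar threshold of 0.03 are illustrative placeholders, not values taken from the trainer:

    import torch as t

    post_relu_feat_acts_BF = t.tensor([[0.00, 0.01, 0.50, 0.04]])  # post-ReLU feature activations (illustrative)
    threshold = 0.03                                               # illustrative scalar threshold
    # Keep only activations above the threshold, zeroing the rest
    encoded_acts_BF = post_relu_feat_acts_BF * (post_relu_feat_acts_BF > threshold)
    # -> tensor([[0.0000, 0.0000, 0.5000, 0.0400]])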
@@ -105,6 +109,7 @@ def __init__(
         decay_start: Optional[int] = None,  # when does the lr decay start
         threshold_beta: float = 0.999,
         threshold_start_step: int = 1000,
+        k_anneal_steps: Optional[int] = None,
         seed: Optional[int] = None,
         device: Optional[str] = None,
         wandb_name: str = "BatchTopKSAE",
@@ -122,6 +127,7 @@ def __init__(
         self.k = k
         self.threshold_beta = threshold_beta
         self.threshold_start_step = threshold_start_step
+        self.k_anneal_steps = k_anneal_steps
 
         if seed is not None:
             t.manual_seed(seed)
@@ -146,17 +152,43 @@ def __init__(
         self.dead_feature_threshold = 10_000_000
         self.top_k_aux = activation_dim // 2  # Heuristic from B.1 of the paper
         self.num_tokens_since_fired = t.zeros(dict_size, dtype=t.long, device=device)
-        self.logging_parameters = ["effective_l0", "dead_features", "pre_norm_auxk_loss"]
+        self.logging_parameters = [
+            "effective_l0",
+            "dead_features",
+            "pre_norm_auxk_loss",
+        ]
         self.effective_l0 = -1
         self.dead_features = -1
         self.pre_norm_auxk_loss = -1
 
-        self.optimizer = t.optim.Adam(self.ae.parameters(), lr=self.lr, betas=(0.9, 0.999))
+        self.optimizer = t.optim.Adam(
+            self.ae.parameters(), lr=self.lr, betas=(0.9, 0.999)
+        )
 
         lr_fn = get_lr_schedule(steps, warmup_steps, decay_start=decay_start)
 
         self.scheduler = t.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lr_fn)
 
+    def update_annealed_k(
+        self, step: int, activation_dim: int, k_anneal_steps: Optional[int] = None
+    ) -> None:
+        """Update k buffer in-place with annealed value"""
+        if k_anneal_steps is None:
+            return
+
+        assert 0 <= k_anneal_steps < self.steps, (
+            "k_anneal_steps must be >= 0 and < steps."
+        )
+        # self.k is the target k set for the trainer, not the dictionary's current k
+        assert activation_dim > self.k, "activation_dim must be greater than k"
+
+        step = min(step, k_anneal_steps)
+        ratio = step / k_anneal_steps
+        annealed_value = activation_dim * (1 - ratio) + self.k * ratio
+
+        # Update in-place
+        self.ae.k.fill_(int(annealed_value))
+
     def get_auxiliary_loss(self, residual_BD: t.Tensor, post_relu_acts_BF: t.Tensor):
         dead_features = self.num_tokens_since_fired >= self.dead_feature_threshold
         self.dead_features = int(dead_features.sum())
@@ -170,19 +202,28 @@ def get_auxiliary_loss(self, residual_BD: t.Tensor, post_relu_acts_BF: t.Tensor)
         auxk_acts, auxk_indices = auxk_latents.topk(k_aux, sorted=False)
 
         auxk_buffer_BF = t.zeros_like(post_relu_acts_BF)
-        auxk_acts_BF = auxk_buffer_BF.scatter_(dim=-1, index=auxk_indices, src=auxk_acts)
+        auxk_acts_BF = auxk_buffer_BF.scatter_(
+            dim=-1, index=auxk_indices, src=auxk_acts
+        )
 
         # Note: decoder(), not decode(), as we don't want to apply the bias
         x_reconstruct_aux = self.ae.decoder(auxk_acts_BF)
         l2_loss_aux = (
-            (residual_BD.float() - x_reconstruct_aux.float()).pow(2).sum(dim=-1).mean()
+            (residual_BD.float() - x_reconstruct_aux.float())
+            .pow(2)
+            .sum(dim=-1)
+            .mean()
         )
 
         self.pre_norm_auxk_loss = l2_loss_aux
 
         # normalization from OpenAI implementation: https://github.com/openai/sparse_autoencoder/blob/main/sparse_autoencoder/kernels.py#L614
-        residual_mu = residual_BD.mean(dim=0)[None, :].broadcast_to(residual_BD.shape)
-        loss_denom = (residual_BD.float() - residual_mu.float()).pow(2).sum(dim=-1).mean()
+        residual_mu = residual_BD.mean(dim=0)[None, :].broadcast_to(
+            residual_BD.shape
+        )
+        loss_denom = (
+            (residual_BD.float() - residual_mu.float()).pow(2).sum(dim=-1).mean()
+        )
         normalized_auxk_loss = l2_loss_aux / loss_denom
 
         return normalized_auxk_loss.nan_to_num(0.0)
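
To make the normalization at the end of get_auxiliary_loss concrete: the aux-k reconstruction error is divided by the error of a baseline that predicts the per-batch mean residual, so a value near 1.0 means the dead-feature reconstruction is no better than that baseline. A minimal standalone sketch under assumed shapes (batch of 8, model dim 16), with random placeholder tensors rather than real activations:

    import torch as t

    residual_BD = t.randn(8, 16)        # residual left after the main reconstruction (assumed shape)
    x_reconstruct_aux = t.randn(8, 16)  # decoder output from the top-k_aux dead features (placeholder)

    l2_loss_aux = (residual_BD - x_reconstruct_aux).pow(2).sum(dim=-1).mean()

    # Baseline: predict the batch-mean residual for every row
    residual_mu = residual_BD.mean(dim=0, keepdim=True).broadcast_to(residual_BD.shape)
    loss_denom = (residual_BD - residual_mu).pow(2).sum(dim=-1).mean()

    normalized_auxk_loss = (l2_loss_aux / loss_denom).nan_to_num(0.0)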
@@ -220,7 +261,7 @@ def loss(self, x, step=None, logging=False):
 
         e = x - x_hat
 
-        self.effective_l0 = self.k
+        self.effective_l0 = self.ae.k.item()
 
         num_tokens_in_step = x.size(0)
         did_fire = t.zeros_like(self.num_tokens_since_fired, dtype=t.bool)
@@ -239,7 +280,11 @@ def loss(self, x, step=None, logging=False):
                 x,
                 x_hat,
                 f,
-                {"l2_loss": l2_loss.item(), "auxk_loss": auxk_loss.item(), "loss": loss.item()},
+                {
+                    "l2_loss": l2_loss.item(),
+                    "auxk_loss": auxk_loss.item(),
+                    "loss": loss.item(),
+                },
             )
 
     def update(self, step, x):
@@ -263,6 +308,7 @@ def update(self, step, x):
         self.optimizer.step()
         self.optimizer.zero_grad()
         self.scheduler.step()
+        self.update_annealed_k(step, self.ae.activation_dim, self.k_anneal_steps)
 
         # Make sure the decoder is still unit-norm
         self.ae.decoder.weight.data = set_decoder_norm_to_unit_norm(
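
update() now calls update_annealed_k every step, which linearly interpolates the dictionary's k from activation_dim down to the trainer's target k over k_anneal_steps. A minimal sketch of that schedule, assuming illustrative values activation_dim=512, target k=32, and k_anneal_steps=1000:

    def annealed_k(step: int, activation_dim: int = 512, k: int = 32, k_anneal_steps: int = 1000) -> int:
        # Same linear interpolation as update_annealed_k, without the in-place buffer update
        step = min(step, k_anneal_steps)
        ratio = step / k_anneal_steps
        return int(activation_dim * (1 - ratio) + k * ratio)

    print(annealed_k(0), annealed_k(500), annealed_k(1000))  # 512 272 32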