
[Question] Why does GatedSAETrainer not implement resample_neurons when GatedAnnealTrainer and StandardTrainer do? #53

@jasonrichdarmawan


Question

Is this done to faithfully follow a paper?

For context: I am new to SAEs and the sae_lens library, and I have never used the dictionary_learning library.

For context 2: I have a problem with the sparsity/dead_features metric being too high after I changed the sae_lens implementation to use `.detach()` in `via_gate_reconstruction = pi_gate_act @ self.W_dec + self.b_dec` (code).
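To illustrate the change (a sketch, not the exact sae_lens source; `sae_in` and the loss line are my assumptions), the Gated SAE paper's auxiliary loss reconstructs through a frozen copy of the decoder, which is what the `.detach()` expresses here:

```python
# Sketch of the gated SAE auxiliary reconstruction path; names other than
# via_gate_reconstruction, pi_gate_act, W_dec, b_dec are assumptions.

# before: gradients from the auxiliary loss also update W_dec and b_dec
via_gate_reconstruction = pi_gate_act @ self.W_dec + self.b_dec

# after: the decoder is treated as frozen in the auxiliary path,
# matching the frozen-decoder auxiliary loss of the Gated SAE paper
via_gate_reconstruction = pi_gate_act @ self.W_dec.detach() + self.b_dec.detach()

aux_loss = (sae_in - via_gate_reconstruction).pow(2).sum(dim=-1).mean()
```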

For context 3: I noticed that the sparsity/dead_features metric does not decrease over the course of training, and I can't find any neuron-resampling logic in the sae_lens library. So now I am looking at the dictionary_learning implementation to see how to implement neuron resampling.
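For reference, dead features are usually detected by counting how many consecutive steps each feature has gone without firing. A minimal sketch of building such a mask (all names here are hypothetical, not from sae_lens or dictionary_learning):

```python
import torch as t

dict_size = 16384  # hypothetical dictionary size
steps_since_active = t.zeros(dict_size, dtype=t.long)

def update_dead_mask(feature_acts, steps_since_active, dead_threshold=1000):
    """feature_acts: (batch, dict_size) feature activations for one step."""
    fired = (feature_acts > 0).any(dim=0)
    steps_since_active[fired] = 0
    steps_since_active[~fired] += 1
    # a feature silent for dead_threshold consecutive steps counts as dead
    return steps_since_active > dead_threshold
```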

GatedSAETrainer

```python
class GatedSAETrainer(SAETrainer):
    # ... (no resample_neurons method is defined in this trainer)
```

GatedAnnealTrainer

```python
def resample_neurons(self, deads, activations):
    with t.no_grad():
        if deads.sum() == 0: return
        print(f"resampling {deads.sum().item()} neurons")
        # compute loss for each activation
        losses = (activations - self.ae(activations)).norm(dim=-1)
        # sample input to create encoder/decoder weights from
        n_resample = min([deads.sum(), losses.shape[0]])
        indices = t.multinomial(losses, num_samples=n_resample, replacement=False)
        sampled_vecs = activations[indices]
        # reset encoder/decoder weights for dead neurons
        alive_norm = self.ae.encoder.weight[~deads].norm(dim=-1).mean()
        # note: boolean-mask indexing (weight[deads]) returns a copy, so these
        # chained assignments write into temporaries, not the parameters;
        # compare StandardTrainer below, which shrinks the mask first
        self.ae.encoder.weight[deads][:n_resample] = sampled_vecs * alive_norm * 0.2
        self.ae.decoder.weight[:, deads][:, :n_resample] = (sampled_vecs / sampled_vecs.norm(dim=-1, keepdim=True)).T
        self.ae.encoder.bias[deads][:n_resample] = 0.
        # reset Adam parameters for dead neurons
        state_dict = self.optimizer.state_dict()['state']
        ## encoder weight
        state_dict[1]['exp_avg'][deads] = 0.
        state_dict[1]['exp_avg_sq'][deads] = 0.
        ## encoder bias
        state_dict[2]['exp_avg'][deads] = 0.
        state_dict[2]['exp_avg_sq'][deads] = 0.
        ## decoder weight
        state_dict[3]['exp_avg'][:, deads] = 0.
        state_dict[3]['exp_avg_sq'][:, deads] = 0.
```
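For completeness, a hypothetical call site for the method above, reusing the hypothetical `steps_since_active` tracker sketched earlier (these names are my assumptions, not dictionary_learning's actual API):

```python
# Hypothetical training-loop hook (names assumed)
if step > 0 and step % resample_interval == 0:
    deads = steps_since_active > dead_threshold   # boolean mask over features
    trainer.resample_neurons(deads, activations)  # activations: current batch
    steps_since_active[deads] = 0                 # resampled features start fresh
```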

StandardTrainer

```python
def resample_neurons(self, deads, activations):
    with t.no_grad():
        if deads.sum() == 0: return
        print(f"resampling {deads.sum().item()} neurons")
        # compute loss for each activation
        losses = (activations - self.ae(activations)).norm(dim=-1)
        # sample input to create encoder/decoder weights from
        n_resample = min([deads.sum(), losses.shape[0]])
        indices = t.multinomial(losses, num_samples=n_resample, replacement=False)
        sampled_vecs = activations[indices]
        # get norm of the living neurons
        alive_norm = self.ae.encoder.weight[~deads].norm(dim=-1).mean()
        # resample first n_resample dead neurons
        # (the mask is shrunk to n_resample entries first, so the boolean-mask
        # assignments below write the parameters in place, unlike the chained
        # indexing in GatedAnnealTrainer above)
        deads[deads.nonzero()[n_resample:]] = False
        self.ae.encoder.weight[deads] = sampled_vecs * alive_norm * 0.2
        self.ae.decoder.weight[:, deads] = (sampled_vecs / sampled_vecs.norm(dim=-1, keepdim=True)).T
        self.ae.encoder.bias[deads] = 0.
        # reset Adam parameters for dead neurons
        state_dict = self.optimizer.state_dict()['state']
        ## encoder weight
        state_dict[1]['exp_avg'][deads] = 0.
        state_dict[1]['exp_avg_sq'][deads] = 0.
        ## encoder bias
        state_dict[2]['exp_avg'][deads] = 0.
        state_dict[2]['exp_avg_sq'][deads] = 0.
        ## decoder weight
        state_dict[3]['exp_avg'][:, deads] = 0.
        state_dict[3]['exp_avg_sq'][:, deads] = 0.
```
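One caveat with both snippets: the hard-coded `state_dict` indices 1, 2, and 3 assume a fixed parameter registration order, so they would not transfer unchanged to an autoencoder with extra parameters, such as a gated autoencoder with its additional gating parameters. A sketch of an order-independent reset, keyed by parameter object (assumes `self.optimizer` is `torch.optim.Adam`):

```python
def reset_adam_state(optimizer, param, row_mask=None, col_mask=None):
    # torch.optim.Optimizer.state maps each parameter to its moment buffers
    state = optimizer.state.get(param, {})
    for key in ("exp_avg", "exp_avg_sq"):
        if key in state:
            if row_mask is not None:
                state[key][row_mask] = 0.
            if col_mask is not None:
                state[key][:, col_mask] = 0.

# usage mirroring the snippets above
reset_adam_state(self.optimizer, self.ae.encoder.weight, row_mask=deads)
reset_adam_state(self.optimizer, self.ae.encoder.bias, row_mask=deads)
reset_adam_state(self.optimizer, self.ae.decoder.weight, col_mask=deads)
```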
