Description

Question

Is this done to faithfully follow a paper?

For context: I am new to SAEs and the sae_lens library, and I have never used the dictionary_learning library.

For context 2: I have a problem with the sparsity/dead_features metric being too high after I changed the sae_lens implementation to use .detach() in `via_gate_reconstruction = pi_gate_act @ self.W_dec + self.b_dec`.

For context 3: I noticed that the sparsity/dead_features metric does not decrease over the course of training, and I cannot find any neuron-resampling logic in the sae_lens library. So now I am looking at the dictionary_learning implementation to see how to implement neuron resampling.
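To make context 2 concrete, here is a standalone toy (shapes and tensor values are made up, and this is not sae_lens internals) showing what detaching the decoder in the gated auxiliary reconstruction does: the auxiliary loss can no longer send gradients to `W_dec` / `b_dec`, which is how a "frozen decoder copy" is typically realized.

```python
import torch

# Toy sketch (illustrative shapes, not sae_lens code): the effect of
# calling .detach() on the decoder in the gated aux reconstruction.
torch.manual_seed(0)
d_sae, d_in = 4, 3
W_dec = torch.randn(d_sae, d_in, requires_grad=True)
b_dec = torch.randn(d_in, requires_grad=True)
pi_gate_act = torch.relu(torch.randn(2, d_sae))

# Without .detach(): gradients from the aux loss reach the decoder.
(pi_gate_act @ W_dec + b_dec).pow(2).sum().backward()
grad_flows = W_dec.grad is not None  # True

# With .detach(): the aux reconstruction carries no autograd graph back
# to the decoder, so an aux loss built on it cannot update W_dec/b_dec.
via_gate_reconstruction = pi_gate_act @ W_dec.detach() + b_dec.detach()
decoder_frozen = not via_gate_reconstruction.requires_grad  # True
print(grad_flows, decoder_frozen)
```

Note that freezing the decoder for the auxiliary path changes which losses train which parameters, so a jump in dead features after this change is at least plausible.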
GatedSAETrainer

```python
class GatedSAETrainer(SAETrainer):
```
GatedAnnealTrainer

dictionary_learning/dictionary_learning/trainers/gated_anneal.py (lines 105 to 135 at 60ec6bf):

```python
def resample_neurons(self, deads, activations):
    with t.no_grad():
        if deads.sum() == 0:
            return
        print(f"resampling {deads.sum().item()} neurons")

        # compute loss for each activation
        losses = (activations - self.ae(activations)).norm(dim=-1)

        # sample input to create encoder/decoder weights from
        n_resample = min([deads.sum(), losses.shape[0]])
        indices = t.multinomial(losses, num_samples=n_resample, replacement=False)
        sampled_vecs = activations[indices]

        # reset encoder/decoder weights for dead neurons
        alive_norm = self.ae.encoder.weight[~deads].norm(dim=-1).mean()
        self.ae.encoder.weight[deads][:n_resample] = sampled_vecs * alive_norm * 0.2
        self.ae.decoder.weight[:, deads][:, :n_resample] = (sampled_vecs / sampled_vecs.norm(dim=-1, keepdim=True)).T
        self.ae.encoder.bias[deads][:n_resample] = 0.

        # reset Adam parameters for dead neurons
        state_dict = self.optimizer.state_dict()['state']
        ## encoder weight
        state_dict[1]['exp_avg'][deads] = 0.
        state_dict[1]['exp_avg_sq'][deads] = 0.
        ## encoder bias
        state_dict[2]['exp_avg'][deads] = 0.
        state_dict[2]['exp_avg_sq'][deads] = 0.
        ## decoder weight
        state_dict[3]['exp_avg'][:, deads] = 0.
        state_dict[3]['exp_avg_sq'][:, deads] = 0.
```
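One thing worth double-checking in the gated_anneal snippet (my observation, not something stated in the repo): in PyTorch, a chained assignment like `weight[deads][:n_resample] = ...` writes into the temporary copy produced by the boolean index, so the underlying parameter may be left unchanged. A minimal demonstration:

```python
import torch

# Boolean advanced indexing returns a copy, so a chained in-place
# assignment through it modifies the copy, not the original tensor.
w = torch.zeros(4, 2)
deads = torch.tensor([True, True, False, False])

w[deads][:1] = 1.0
sum_after_chained = w.sum().item()  # 0.0: w itself is untouched

# A single __setitem__ with explicit positions assigns in place.
w[deads.nonzero()[:1, 0]] = 1.0
sum_after_direct = w.sum().item()   # 2.0: row 0 is now all ones
print(sum_after_chained, sum_after_direct)
```

The standard.py version below sidesteps this by first truncating `deads` to the first `n_resample` entries and then assigning through a single boolean index (`weight[deads] = ...`), which does write in place.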
StandardTrainer

dictionary_learning/dictionary_learning/trainers/standard.py (lines 76 to 109 at 60ec6bf):

```python
def resample_neurons(self, deads, activations):
    with t.no_grad():
        if deads.sum() == 0:
            return
        print(f"resampling {deads.sum().item()} neurons")

        # compute loss for each activation
        losses = (activations - self.ae(activations)).norm(dim=-1)

        # sample input to create encoder/decoder weights from
        n_resample = min([deads.sum(), losses.shape[0]])
        indices = t.multinomial(losses, num_samples=n_resample, replacement=False)
        sampled_vecs = activations[indices]

        # get norm of the living neurons
        alive_norm = self.ae.encoder.weight[~deads].norm(dim=-1).mean()

        # resample first n_resample dead neurons
        deads[deads.nonzero()[n_resample:]] = False
        self.ae.encoder.weight[deads] = sampled_vecs * alive_norm * 0.2
        self.ae.decoder.weight[:, deads] = (sampled_vecs / sampled_vecs.norm(dim=-1, keepdim=True)).T
        self.ae.encoder.bias[deads] = 0.

        # reset Adam parameters for dead neurons
        state_dict = self.optimizer.state_dict()['state']
        ## encoder weight
        state_dict[1]['exp_avg'][deads] = 0.
        state_dict[1]['exp_avg_sq'][deads] = 0.
        ## encoder bias
        state_dict[2]['exp_avg'][deads] = 0.
        state_dict[2]['exp_avg_sq'][deads] = 0.
        ## decoder weight
        state_dict[3]['exp_avg'][:, deads] = 0.
        state_dict[3]['exp_avg_sq'][:, deads] = 0.
```
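For anyone porting this to another codebase, here is a self-contained sketch of the same resampling scheme, following the standard.py logic above. The function and argument names are mine, not a library API, and a real port would also need to reset the optimizer moments for the resampled features, as both snippets above do:

```python
import torch

def resample_dead_neurons(encoder_weight, encoder_bias, decoder_weight,
                          deads, activations, reconstructions, scale=0.2):
    """Hypothetical helper (names are mine): re-initialize dead features
    from the inputs the SAE currently reconstructs worst."""
    with torch.no_grad():
        if deads.sum() == 0:
            return
        # sample inputs with probability proportional to reconstruction error
        losses = (activations - reconstructions).norm(dim=-1)
        n_resample = min(int(deads.sum()), losses.shape[0])
        indices = torch.multinomial(losses, num_samples=n_resample,
                                    replacement=False)
        sampled_vecs = activations[indices]
        # match the average norm of the living encoder rows
        alive_norm = encoder_weight[~deads].norm(dim=-1).mean()
        # only resample the first n_resample dead neurons
        deads = deads.clone()
        deads[deads.nonzero()[n_resample:]] = False
        encoder_weight[deads] = sampled_vecs * alive_norm * scale
        decoder_weight[:, deads] = (
            sampled_vecs / sampled_vecs.norm(dim=-1, keepdim=True)).T
        encoder_bias[deads] = 0.0

# Tiny smoke test with made-up shapes (d_in=3, d_sae=5, batch=8).
torch.manual_seed(0)
enc_w = torch.randn(5, 3)
enc_b = torch.randn(5)
dec_w = torch.randn(3, 5)
deads = torch.tensor([True, False, True, False, False])
acts = torch.randn(8, 3)
recons = torch.zeros(8, 3)
resample_dead_neurons(enc_w, enc_b, dec_w, deads, acts, recons)
print(dec_w[:, 0].norm().item())  # resampled decoder columns are unit norm
```

Calling this periodically (e.g. every few thousand steps) on features whose activation frequency has fallen below a threshold is the usual pattern; the choice of which features count as "dead" is up to the training loop.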