
Commit 227da89

Fix demo
1 parent f4abe32 commit 227da89

File tree

1 file changed: +28 -9 lines changed

README.md

Lines changed: 28 additions & 9 deletions
@@ -75,22 +75,28 @@ Another key object is the `ActivationBuffer`, defined in `buffer.py`. Following
An `ActivationBuffer` is initialized from an `nnsight` `LanguageModel` object, a submodule (e.g. an MLP), and a generator which yields strings (the text data). It processes a large number of strings, up to some capacity, and saves the submodule's activations. You sample batches from it, and when it is half-depleted, it refreshes itself with new text data.

Here's an example for training a dictionary; in it we load a language model as an `nnsight` `LanguageModel` (this will work for any Huggingface model), specify a submodule, create an `ActivationBuffer`, and then train an autoencoder with `trainSAE`.
+
+NOTE: This is a simple reference example. For an example with standard hyperparameter settings, HuggingFace dataset usage, etc, we recommend referring to this [demonstration](https://github.com/adamkarvonen/dictionary_learning_demo).
```python
from nnsight import LanguageModel
-from dictionary_learning import ActivationBuffer, AutoEncoder
-from dictionary_learning.trainers import StandardTrainer
+from dictionary_learning import ActivationBuffer
+from dictionary_learning.trainers.top_k import TopKTrainer, AutoEncoderTopK
from dictionary_learning.training import trainSAE

device = "cuda:0"
-model_name = "EleutherAI/pythia-70m-deduped" # can be any Huggingface model
+model_name = "EleutherAI/pythia-70m-deduped"  # can be any Huggingface model

model = LanguageModel(
    model_name,
    device_map=device,
)
-submodule = model.gpt_neox.layers[1].mlp # layer 1 MLP
-activation_dim = 512 # output dimension of the MLP
+layer = 1
+submodule = model.gpt_neox.layers[1].mlp  # layer 1 MLP
+activation_dim = 512  # output dimension of the MLP
dictionary_size = 16 * activation_dim
+llm_batch_size = 16
+sae_batch_size = 128
+training_steps = 20

# data must be an iterator that outputs strings
data = iter(
@@ -99,30 +105,43 @@ data = iter(
        "In real life, for training a dictionary",
        "you would need much more data than this",
    ]
+    * 100000
)
+
buffer = ActivationBuffer(
    data=data,
    model=model,
    submodule=submodule,
-    d_submodule=activation_dim, # output dimension of the model component
-    n_ctxs=3e4, # you can set this higher or lower dependong on your available memory
+    d_submodule=activation_dim,  # output dimension of the model component
+    n_ctxs=int(
+        1e2
+    ),  # you can set this higher or lower depending on your available memory
    device=device,
+    refresh_batch_size=llm_batch_size,
+    out_batch_size=sae_batch_size,
) # buffer will yield batches of tensors of dimension = submodule's output dimension

trainer_cfg = {
-    "trainer": StandardTrainer,
-    "dict_class": AutoEncoder,
+    "trainer": TopKTrainer,
+    "dict_class": AutoEncoderTopK,
    "activation_dim": activation_dim,
    "dict_size": dictionary_size,
    "lr": 1e-3,
    "device": device,
+    "steps": training_steps,
+    "layer": layer,
+    "lm_name": model_name,
+    "warmup_steps": 1,
+    "k": 100,
}

# train the sparse autoencoder (SAE)
ae = trainSAE(
    data=buffer, # you could also use another (i.e. pytorch dataloader) here instead of buffer
    trainer_configs=[trainer_cfg],
+    steps=training_steps,  # The number of training steps. Total trained tokens = steps * batch_size
)
+
```
Some technical notes about our training infrastructure and supported features:
* Training uses the `ConstrainedAdam` optimizer defined in `training.py`. This is a variant of Adam which supports constraining the `AutoEncoder`'s decoder weights to be norm 1.
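
As a rough illustration of that constraint (this is not the library's actual `ConstrainedAdam` from `training.py`, which may additionally project gradients so updates stay tangent to the unit-norm sphere), here is a minimal sketch of an Adam variant that renormalizes selected parameters after every step. The class and argument names below are invented for the example:

```python
import torch


class ConstrainedAdamSketch(torch.optim.Adam):
    """Adam variant that renormalizes selected parameters to unit norm after each step.

    A minimal sketch only; see training.py for the real ConstrainedAdam.
    """

    def __init__(self, params, constrained_params, lr):
        super().__init__(params, lr=lr)
        # parameters whose columns (dictionary vectors) should stay at norm 1
        self.constrained_params = list(constrained_params)

    @torch.no_grad()
    def step(self, closure=None):
        loss = super().step(closure=closure)
        for p in self.constrained_params:
            # rescale each column back to unit norm after the Adam update
            p /= p.norm(dim=0, keepdim=True)
        return loss
```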
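
To round out the training example above, here is a short usage sketch that continues from the variables defined in that code block (`buffer`, `ae`, `sae_batch_size`, etc.). It assumes the buffer can be consumed as an iterator of activation batches and that the trained dictionary exposes `encode`/`decode`; double-check these against your installed version of the library:

```python
# Continues from the example above: `buffer` and `ae` come from that code block.
acts = next(buffer)          # activation batch, expected shape (sae_batch_size, activation_dim)
features = ae.encode(acts)   # sparse feature activations, expected shape (sae_batch_size, dictionary_size)
recon = ae.decode(features)  # reconstructed activations, same shape as `acts`
print((acts - recon).pow(2).mean())  # rough reconstruction error
```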
