Skip to content

Commit e798cf4

Browse files
authored
Add sample GPT training run (#194)
1 parent 62fce3f commit e798cf4

File tree

2 files changed

+74
-0
lines changed

2 files changed

+74
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
"""Training runs."""
+73
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
"""Run an sweep on all layers of GPT2 Small.
2+
3+
Command:
4+
5+
```bash
6+
git clone https://github.com/ai-safety-foundation/sparse_autoencoder.git && cd sparse_autoencoder &&
7+
poetry env use python3.11 && poetry install &&
8+
poetry run python sparse_autoencoder/training_runs/gpt2.py
9+
```
10+
"""
11+
import os
12+
13+
from sparse_autoencoder import (
14+
ActivationResamplerHyperparameters,
15+
AutoencoderHyperparameters,
16+
Hyperparameters,
17+
LossHyperparameters,
18+
Method,
19+
OptimizerHyperparameters,
20+
Parameter,
21+
PipelineHyperparameters,
22+
SourceDataHyperparameters,
23+
SourceModelHyperparameters,
24+
SweepConfig,
25+
sweep,
26+
)
27+
28+
29+
# Force the TOKENIZERS_PARALLELISM environment variable to "false" for this process.
# NOTE(review): presumably this avoids the Hugging Face tokenizers parallelism/fork
# warning when data loading spawns worker processes — confirm.
os.environ.update({"TOKENIZERS_PARALLELISM": "false"})
30+
31+
32+
def train() -> None:
    """Build the GPT-2 sweep configuration and hand it to `sweep`.

    The configuration uses the grid method and covers:

    * source model `gpt2`, caching the `blocks.{layer}.hook_mlp_out` hooks for
      layers 0-11, with a hook dimension of 768;
    * the first 20 of the 64 pre-tokenized parquet files of the source dataset,
      with a context size of 256;
    * autoencoder expansion factors 32 and 64;
    * a single L1 coefficient (0.0001) and a single learning rate (0.0001).
    """
    loss_params = LossHyperparameters(
        l1_coefficient=Parameter(values=[0.0001]),
    )

    optimizer_params = OptimizerHyperparameters(
        lr=Parameter(value=0.0001),
    )

    # One MLP-output hook per layer.
    hook_names = [f"blocks.{layer}.hook_mlp_out" for layer in range(12)]
    source_model_params = SourceModelHyperparameters(
        name=Parameter("gpt2"),
        cache_names=Parameter(value=hook_names),
        hook_dimension=Parameter(768),
    )

    # The full dataset is roughly 7bn activations spread over 64 files, and storing
    # all of those activations would need roughly 1.5TB. Only the first 20 files
    # are selected here.
    parquet_files = [f"data/train-{index:05d}-of-00064.parquet" for index in range(20)]
    source_data_params = SourceDataHyperparameters(
        dataset_path=Parameter("alancooney/sae-monology-pile-uncopyrighted-tokenizer-gpt2"),
        context_size=Parameter(256),
        pre_tokenized=Parameter(value=True),
        pre_download=Parameter(value=True),
        dataset_files=Parameter(parquet_files),
    )

    autoencoder_params = AutoencoderHyperparameters(
        expansion_factor=Parameter(values=[32, 64]),
    )

    pipeline_params = PipelineHyperparameters()

    resampler_params = ActivationResamplerHyperparameters(
        threshold_is_dead_portion_fires=Parameter(1e-5),
    )

    hyperparameters = Hyperparameters(
        loss=loss_params,
        optimizer=optimizer_params,
        source_model=source_model_params,
        source_data=source_data_params,
        autoencoder=autoencoder_params,
        pipeline=pipeline_params,
        activation_resampler=resampler_params,
    )

    sweep(sweep_config=SweepConfig(parameters=hyperparameters, method=Method.GRID))
70+
71+
72+
if __name__ == "__main__":
    # Script entry point: only start the sweep when this file is executed directly,
    # not when it is imported.
    train()

0 commit comments

Comments
 (0)