-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdebug.py
More file actions
41 lines (36 loc) · 1.34 KB
/
debug.py
File metadata and controls
41 lines (36 loc) · 1.34 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
from notebooks.notebook_utils import GEMMA_SCOPE_RELEASE_ID, GEMMA_SCOPE_VERSIONS
from src.extract_data import batch_create_feature_dataset, batch_create_granularity_dataset
from src.models.TopKSaes import MatryoshkaTopKSae
import JumpReLUSae, MatryoshkaJumpReLU
import torch
import JumpReLU, MatryoshkaJumpReLU
import os
from src.utils import load_gemma_scope_versions
from src.auto_interp.auto_interp_pipeline import batch_interpret_models
from main import NEW_SAE_PATHS
if __name__ == "__main__":
    # Debug entry point: run the auto-interp pipeline in "simulator" mode
    # (explanations skipped) on the first Gemma Scope SAE version, on CPU.
    import asyncio
    import warnings

    # Credentials must come from the environment, never from source code.
    # SECURITY: the WANDB_API_KEY and HF_TOKEN values that were previously
    # hard-coded here are exposed in version-control history and must be
    # revoked and rotated. Export fresh values before running, e.g.
    #   export HF_TOKEN=...  WANDB_API_KEY=...
    for _var in ("WANDB_API_KEY", "HF_TOKEN"):
        if not os.environ.get(_var):
            warnings.warn(
                f"{_var} is not set in the environment; "
                "downstream calls that need it may fail.",
                stacklevel=1,
            )

    # NOTE: Hugging Face and Fireworks credits are available; prefer those
    # providers over OpenAI when running auto-interp.

    # Root directory shared by the per-version data dirs and the output dir.
    DATA_ROOT = "/root/pile_uncopyrighted"

    # One entry per Gemma Scope version. The version string has "/" replaced
    # by "_" to make it filesystem-safe. Each entry is presumably
    # (sae_model_or_None, data_dir, run_name) — confirm against
    # batch_interpret_models' `sae_entries` contract.
    GEMMA_MODELS = [
        (None, f"{DATA_ROOT}/{safe_name}", safe_name)
        for safe_name in (v.replace("/", "_") for v in GEMMA_SCOPE_VERSIONS)
    ]
    if not GEMMA_MODELS:
        raise SystemExit("GEMMA_SCOPE_VERSIONS is empty; nothing to interpret.")

    # Destination directory passed to the pipeline as `copy_to_path`.
    output_dir = os.path.join(DATA_ROOT, "datasets", "auto_interp")
    os.makedirs(output_dir, exist_ok=True)

    async def main():
        """Interpret only the first Gemma Scope entry (debug run)."""
        await batch_interpret_models(
            sae_entries=[GEMMA_MODELS[0]],
            copy_to_path=output_dir,
            device=torch.device("cpu"),
            type="simulator",
            skip_explanations=True,
        )

    asyncio.run(main())