Skip to content

Commit 7810443

Browse files
author
SrGonao
committed
Renaming hooks
1 parent bcae2b8 commit 7810443

File tree

5 files changed

+140
-15
lines changed

5 files changed

+140
-15
lines changed

delphi/__main__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
from delphi.log.result_analysis import log_results
3232
from delphi.pipeline import Pipe, Pipeline, process_wrapper
3333
from delphi.scorers import DetectionScorer, FuzzingScorer
34-
from delphi.sparse_coders import load_sparse_coders
34+
from delphi.sparse_coders import load_hooks_sparse_coders
3535

3636

3737
def load_artifacts(run_cfg: RunConfig):
@@ -54,7 +54,7 @@ def load_artifacts(run_cfg: RunConfig):
5454
token=run_cfg.hf_token,
5555
)
5656

57-
hookpoint_to_sparse_encode = load_sparse_coders(model, run_cfg, compile=True)
57+
hookpoint_to_sparse_encode = load_hooks_sparse_coders(model, run_cfg, compile=True)
5858

5959
return run_cfg.hookpoints, hookpoint_to_sparse_encode, model
6060

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
from typing import Callable
2+
3+
import torch.nn as nn
4+
from transformers import PreTrainedModel
5+
6+
from delphi.config import RunConfig
7+
8+
from .custom.gemmascope import load_gemma_autoencoders
9+
from .load_sparsify import load_sparsify_hooks, load_sparsify_sparse_coders
10+
11+
12+
def load_hooks_sparse_coders(
13+
model: PreTrainedModel,
14+
run_cfg: RunConfig,
15+
compile: bool = False,
16+
) -> dict[str, Callable]:
17+
"""
18+
Load sparse coders for specified hookpoints.
19+
20+
Args:
21+
model (PreTrainedModel): The model to load sparse coders for.
22+
run_cfg (RunConfig): The run configuration.
23+
24+
Returns:
25+
dict[str, Callable]: A dictionary mapping hookpoints to sparse coders.
26+
"""
27+
28+
# Add SAE hooks to the model
29+
if "gemma" not in run_cfg.sparse_model:
30+
hookpoint_to_sparse_encode = load_sparsify_hooks(
31+
model,
32+
run_cfg.sparse_model,
33+
run_cfg.hookpoints,
34+
compile=compile,
35+
)
36+
else:
37+
# model path will always be of the form google/gemma-scope-<size>-pt-<type>/
38+
# where <size> is the size of the model and <type> is either res or mlp
39+
model_path = "google/" + run_cfg.sparse_model.split("/")[1]
40+
type = model_path.split("-")[-1]
41+
# we can use the hookpoints to determine the layer, size and l0,
42+
# because the module is determined by the model name
43+
# the hookpoint should be in the format
44+
# layer_<layer>/width_<sae_size>/average_l0_<l0>
45+
layers = []
46+
l0s = []
47+
sae_sizes = []
48+
for hookpoint in run_cfg.hookpoints:
49+
layer = int(hookpoint.split("/")[0].split("_")[1])
50+
sae_size = hookpoint.split("/")[1].split("_")[1]
51+
l0 = int(hookpoint.split("/")[2].split("_")[2])
52+
layers.append(layer)
53+
sae_sizes.append(sae_size)
54+
l0s.append(l0)
55+
56+
hookpoint_to_sparse_encode = load_gemma_autoencoders(
57+
model_path=model_path,
58+
ae_layers=layers,
59+
average_l0s=l0s,
60+
sizes=sae_sizes,
61+
type=type,
62+
dtype=model.dtype,
63+
device=model.device,
64+
)
65+
66+
return hookpoint_to_sparse_encode
67+
68+
69+
def load_sparse_coders(
70+
model: PreTrainedModel,
71+
run_cfg: RunConfig,
72+
compile: bool = False,
73+
) -> dict[str, nn.Module]:
74+
"""
75+
Load sparse coders for specified hookpoints.
76+
77+
Args:
78+
model (PreTrainedModel): The model to load sparse coders for.
79+
run_cfg (RunConfig): The run configuration.
80+
81+
Returns:
82+
dict[str, Callable]: A dictionary mapping hookpoints to sparse coders.
83+
"""
84+
85+
# Add SAE hooks to the model
86+
if "gemma" not in run_cfg.sparse_model:
87+
hookpoint_to_sparse_model = load_sparsify_sparse_coders(
88+
model,
89+
run_cfg.sparse_model,
90+
run_cfg.hookpoints,
91+
compile=compile,
92+
)
93+
else:
94+
# model path will always be of the form google/gemma-scope-<size>-pt-<type>/
95+
# where <size> is the size of the model and <type> is either res or mlp
96+
model_path = "google/" + run_cfg.sparse_model.split("/")[1]
97+
type = model_path.split("-")[-1]
98+
# we can use the hookpoints to determine the layer, size and l0,
99+
# because the module is determined by the model name
100+
# the hookpoint should be in the format
101+
# layer_<layer>/width_<sae_size>/average_l0_<l0>
102+
layers = []
103+
l0s = []
104+
sae_sizes = []
105+
for hookpoint in run_cfg.hookpoints:
106+
layer = int(hookpoint.split("/")[0].split("_")[1])
107+
sae_size = hookpoint.split("/")[1].split("_")[1]
108+
l0 = int(hookpoint.split("/")[2].split("_")[2])
109+
layers.append(layer)
110+
sae_sizes.append(sae_size)
111+
l0s.append(l0)
112+
113+
hookpoint_to_sparse_model = load_gemma_autoencoders(
114+
model_path=model_path,
115+
ae_layers=layers,
116+
average_l0s=l0s,
117+
sizes=sae_sizes,
118+
type=type,
119+
dtype=model.dtype,
120+
device=model.device,
121+
)
122+
123+
return hookpoint_to_sparse_model

delphi/tests/conftest.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from delphi.config import CacheConfig, RunConfig
66
from delphi.latents import LatentCache
7-
from delphi.sparse_coders import load_sparse_coders
7+
from delphi.sparse_coders import load_hooks_sparse_coders
88

99
random_text = [
1010
"Lorem ipsum dolor sit amet, consectetur adipiscing elit.",
@@ -62,7 +62,7 @@ def cache_setup(
6262
sparse_model="EleutherAI/sae-pythia-70m-32k",
6363
hookpoints=["layers.1"],
6464
)
65-
hookpoint_to_sparse_encode = load_sparse_coders(model, run_cfg_gemma)
65+
hookpoint_to_sparse_encode = load_hooks_sparse_coders(model, run_cfg_gemma)
6666

6767
# Define cache config and initialize cache
6868
cache_cfg = CacheConfig(batch_size=1, ctx_len=16, n_tokens=100)

delphi/tests/test_autoencoders/test_sparse_coders.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import torch.nn as nn
44

55
# Import the function to be tested
6-
from delphi.sparse_coders import load_sparse_coders
6+
from delphi.sparse_coders import load_hooks_sparse_coders
77

88

99
# A simple dummy run configuration for testing.
@@ -69,16 +69,17 @@ def run_cfg_gemma():
6969

7070
def test_retrieve_autoencoders_from_sparsify(dummy_model, run_cfg_sparsify):
7171
"""
72-
Tests that load_sparse_coders retrieves autoencoders from Sparsify.
72+
Tests that load_hooks_sparse_coders retrieves autoencoders from Sparsify.
7373
"""
74-
submodules = load_sparse_coders(dummy_model, run_cfg_sparsify)
74+
hookpoint_to_sparse_encode = load_hooks_sparse_coders(dummy_model, run_cfg_sparsify)
7575
# Verify that we received a dictionary of autoencoders.
7676
assert (
77-
isinstance(submodules, dict) and len(submodules) > 0
77+
isinstance(hookpoint_to_sparse_encode, dict)
78+
and len(hookpoint_to_sparse_encode) > 0
7879
), "No autoencoders retrieved from the Sparsify branch."
7980

8081
# Validate that at least one autoencoder is callable.
81-
for key, autoencoder in submodules.items():
82+
for key, autoencoder in hookpoint_to_sparse_encode.items():
8283
dummy_input = torch.randn(2, 512)
8384
try:
8485
_ = autoencoder(dummy_input)
@@ -91,16 +92,17 @@ def test_retrieve_autoencoders_from_sparsify(dummy_model, run_cfg_sparsify):
9192

9293
def test_retrieve_autoencoders_from_gemma(dummy_model, run_cfg_gemma):
9394
"""
94-
Tests that load_sparse_coders retrieves autoencoders from Gemma.
95+
Tests that load_hooks_sparse_coders retrieves autoencoders from Gemma.
9596
"""
96-
submodules = load_sparse_coders(dummy_model, run_cfg_gemma)
97+
hookpoint_to_sparse_encode = load_hooks_sparse_coders(dummy_model, run_cfg_gemma)
9798
# Verify that we received a dictionary of autoencoders.
9899
assert (
99-
isinstance(submodules, dict) and len(submodules) > 0
100+
isinstance(hookpoint_to_sparse_encode, dict)
101+
and len(hookpoint_to_sparse_encode) > 0
100102
), "No autoencoders retrieved from the Gemma branch."
101103

102104
# Validate that at least one autoencoder is callable.
103-
for key, autoencoder in submodules.items():
105+
for key, autoencoder in hookpoint_to_sparse_encode.items():
104106
dummy_input = torch.randn(2, 2304)
105107
try:
106108
_ = autoencoder(dummy_input)

examples/caching_activations.ipynb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
"source": [
3737
"from transformers import AutoModel\n",
3838
"\n",
39-
"from delphi.sparse_coders import load_sparse_coders\n",
39+
"from delphi.sparse_coders import load_hooks_sparse_coders\n",
4040
"from delphi.config import RunConfig\n"
4141
]
4242
},
@@ -87,7 +87,7 @@
8787
" hookpoints=[\"layer_10/width_16k/average_l0_39\"],\n",
8888
")\n",
8989
"\n",
90-
"hookpoint_to_sparse_encode = load_sparse_coders(model, run_cfg)"
90+
"hookpoint_to_sparse_encode = load_hooks_sparse_coders(model, run_cfg)"
9191
]
9292
},
9393
{

0 commit comments

Comments
 (0)