Skip to content

Commit 8c30a90

Browse files
authored
Merge branch 'main' into pre-commit-ci-update-config
2 parents 23a1dba + 7e20d4a commit 8c30a90

File tree

10 files changed

+224
-105
lines changed

10 files changed

+224
-105
lines changed

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,13 @@ Install this library as a local editable installation. Run the following command
1616

1717
To run the default pipeline from the command line, use the following command:
1818

19-
`python -m delphi meta-llama/Meta-Llama-3-8B EleutherAI/sae-llama-3-8b-32x --explainer_model 'hugging-quants/Meta-Llama-3.1-70B-Instruct-AWQ-INT4' --dataset_repo 'EleutherAI/rpj-v2-sample' --dataset_split 'train[:1%]' --n_tokens 10_000_000 --max_features 100 --hookpoints layers.5 --filter_bos`
19+
`python -m delphi meta-llama/Meta-Llama-3-8B EleutherAI/sae-llama-3-8b-32x --explainer_model 'hugging-quants/Meta-Llama-3.1-70B-Instruct-AWQ-INT4' --dataset_repo 'EleutherAI/fineweb-edu-dedup-10b' --dataset_split 'train[:1%]' --n_tokens 10_000_000 --max_latents 100 --hookpoints layers.5 --filter_bos --name llama-3-8B`
2020

2121
This command will:
2222
1. Cache activations for the first 10 million tokens of EleutherAI/rpj-v2-sample.
2323
2. Generate explanations for the first 100 features of layer 5 using the specified explainer model.
2424
3. Score the explanations uses fuzzing and detection scorers.
25-
4. Log summary metrics including per-scorer F1 scores and confusion matrices.
25+
4. Log summary metrics including per-scorer F1 scores and confusion matrices, and produce histograms of the scorer classification accuracies.
2626

2727
The pipeline is highly configurable and can also be called programmatically (see the [end-to-end test](https://github.com/EleutherAI/delphi/blob/main/delphi/tests/e2e.py) for an example).
2828

delphi/__main__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
from delphi.log.result_analysis import log_results
3232
from delphi.pipeline import Pipe, Pipeline, process_wrapper
3333
from delphi.scorers import DetectionScorer, FuzzingScorer
34-
from delphi.sparse_coders import load_sparse_coders
34+
from delphi.sparse_coders import load_hooks_sparse_coders
3535

3636

3737
def load_artifacts(run_cfg: RunConfig):
@@ -54,7 +54,7 @@ def load_artifacts(run_cfg: RunConfig):
5454
token=run_cfg.hf_token,
5555
)
5656

57-
hookpoint_to_sparse_encode = load_sparse_coders(model, run_cfg, compile=True)
57+
hookpoint_to_sparse_encode = load_hooks_sparse_coders(model, run_cfg, compile=True)
5858

5959
return run_cfg.hookpoints, hookpoint_to_sparse_encode, model
6060

delphi/explainers/default/prompt_builder.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,14 +41,15 @@ def build_prompt(
4141

4242
messages.extend(few_shot_examples)
4343

44-
user_start = f"WORDS: {examples}"
44+
user_start = f"\n{examples}\n"
4545

4646
messages.append(
4747
{
4848
"role": "user",
4949
"content": user_start,
5050
}
5151
)
52+
print(messages)
5253

5354
return messages
5455

delphi/sparse_coders/__init__.py

Lines changed: 2 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -1,67 +1,4 @@
1-
from typing import Callable
1+
from .sparse_model import load_hooks_sparse_coders, load_sparse_coders
22

3-
from transformers import PreTrainedModel
3+
__all__ = ["load_hooks_sparse_coders", "load_sparse_coders"]
44

5-
from delphi.config import RunConfig
6-
7-
from .custom.gemmascope import load_gemma_autoencoders
8-
from .load_sparsify import load_sparsify_sparse_coders
9-
10-
__all__ = ["load_sparse_coders"]
11-
12-
13-
def load_sparse_coders(
14-
model: PreTrainedModel,
15-
run_cfg: RunConfig,
16-
compile: bool = False,
17-
) -> dict[str, Callable]:
18-
"""
19-
Load sparse coders for specified hookpoints.
20-
21-
Args:
22-
model (PreTrainedModel): The model to load sparse coders for.
23-
run_cfg (RunConfig): The run configuration.
24-
25-
Returns:
26-
dict[str, Callable]: A dictionary mapping hookpoints to sparse coders.
27-
"""
28-
29-
# Add SAE hooks to the model
30-
if "gemma" not in run_cfg.sparse_model:
31-
hookpoint_to_sparse_encode = load_sparsify_sparse_coders(
32-
model,
33-
run_cfg.sparse_model,
34-
run_cfg.hookpoints,
35-
compile=compile,
36-
)
37-
else:
38-
# model path will always be of the form google/gemma-scope-<size>-pt-<type>/
39-
# where <size> is the size of the model and <type> is either res or mlp
40-
model_path = "google/" + run_cfg.sparse_model.split("/")[1]
41-
type = model_path.split("-")[-1]
42-
# we can use the hookpoints to determine the layer, size and l0,
43-
# because the module is determined by the model name
44-
# the hookpoint should be in the format
45-
# layer_<layer>/width_<sae_size>/average_l0_<l0>
46-
layers = []
47-
l0s = []
48-
sae_sizes = []
49-
for hookpoint in run_cfg.hookpoints:
50-
layer = int(hookpoint.split("/")[0].split("_")[1])
51-
sae_size = hookpoint.split("/")[1].split("_")[1]
52-
l0 = int(hookpoint.split("/")[2].split("_")[2])
53-
layers.append(layer)
54-
sae_sizes.append(sae_size)
55-
l0s.append(l0)
56-
57-
hookpoint_to_sparse_encode = load_gemma_autoencoders(
58-
model_path=model_path,
59-
ae_layers=layers,
60-
average_l0s=l0s,
61-
sizes=sae_sizes,
62-
type=type,
63-
dtype=model.dtype,
64-
device=model.device,
65-
)
66-
67-
return hookpoint_to_sparse_encode

delphi/sparse_coders/custom/gemmascope.py

Lines changed: 33 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,18 +15,14 @@ def load_gemma_autoencoders(
1515
dtype: torch.dtype = torch.bfloat16,
1616
device: torch.device = torch.device("cuda"),
1717
):
18-
submodules = {}
18+
saes = {}
1919

2020
for layer, size, l0 in zip(ae_layers, sizes, average_l0s):
2121
path = f"layer_{layer}/width_{size}/average_l0_{l0}"
2222
sae = JumpReluSae.from_pretrained(model_path, path, device)
2323

2424
sae.to(dtype)
2525

26-
def _forward(sae, x):
27-
encoded = sae.encode(x)
28-
return encoded
29-
3026
assert type in [
3127
"res",
3228
"mlp",
@@ -37,9 +33,39 @@ def _forward(sae, x):
3733
else f"layers.{layer}.post_feedforward_layernorm"
3834
)
3935

40-
submodules[hookpoint] = partial(_forward, sae)
36+
saes[hookpoint] = sae
37+
38+
return saes
39+
40+
41+
def load_gemma_hooks(
42+
model_path: str,
43+
ae_layers: list[int],
44+
average_l0s: list[int],
45+
sizes: list[str],
46+
type: str,
47+
dtype: torch.dtype = torch.bfloat16,
48+
device: torch.device = torch.device("cuda"),
49+
):
50+
saes = load_gemma_autoencoders(
51+
model_path,
52+
ae_layers,
53+
average_l0s,
54+
sizes,
55+
type,
56+
dtype,
57+
device,
58+
)
59+
hookpoint_to_sparse_encode = {}
60+
for hookpoint, sae in saes.items():
61+
62+
def _forward(sae, x):
63+
encoded = sae.encode(x)
64+
return encoded
65+
66+
hookpoint_to_sparse_encode[hookpoint] = partial(_forward, sae)
4167

42-
return submodules
68+
return hookpoint_to_sparse_encode
4369

4470

4571
# This is from the GemmaScope tutorial

delphi/sparse_coders/load_sparsify.py

Lines changed: 44 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -54,49 +54,79 @@ def load_sparsify_sparse_coders(
5454
model (Any): The model to load autoencoders for.
5555
name (str): The name of the sparse model to load. If the model is on-disk
5656
this is the path to the directory containing the sparse model weights.
57-
hookpoints (list[str]): list of hookpoints to load autoencoders for.
57+
hookpoints (list[str]): list of hookpoints to identify the sparse models.
5858
device (str | torch.device | None, optional): The device to load the
5959
sparse models on. If not specified the sparse models will be loaded
6060
on the same device as the base model.
6161
6262
Returns:
63-
tuple[dict[str, Any], Any]: A tuple containing the submodules dictionary
64-
and the edited model.
63+
dict[str, Any]: A dictionary mapping hookpoints to sparse models.
6564
"""
6665
if device is None:
6766
device = model.device or "cpu"
6867

6968
# Load the sparse models
70-
hookpoint_to_sparse = {}
69+
sparse_model_dict = {}
7170
name_path = Path(name)
7271
if name_path.exists():
7372
for hookpoint in hookpoints:
74-
hookpoint_to_sparse[hookpoint] = Sae.load_from_disk(
73+
sparse_model_dict[hookpoint] = Sae.load_from_disk(
7574
name_path / hookpoint, device=device
7675
)
7776
if compile:
78-
hookpoint_to_sparse[hookpoint] = torch.compile(
79-
hookpoint_to_sparse[hookpoint]
77+
sparse_model_dict[hookpoint] = torch.compile(
78+
sparse_model_dict[hookpoint]
8079
)
8180
else:
8281
sparse_models = Sae.load_many(name, device=device)
8382
for hookpoint in hookpoints:
84-
hookpoint_to_sparse[hookpoint] = sparse_models[hookpoint]
83+
sparse_model_dict[hookpoint] = sparse_models[hookpoint]
8584
if compile:
86-
hookpoint_to_sparse[hookpoint] = torch.compile(
87-
hookpoint_to_sparse[hookpoint]
85+
sparse_model_dict[hookpoint] = torch.compile(
86+
sparse_model_dict[hookpoint]
8887
)
8988

9089
del sparse_models
90+
return sparse_model_dict
9191

92-
submodules = {}
93-
for hookpoint, sparse_model in hookpoint_to_sparse.items():
92+
93+
def load_sparsify_hooks(
94+
model: PreTrainedModel,
95+
name: str,
96+
hookpoints: list[str],
97+
device: str | torch.device | None = None,
98+
compile: bool = False,
99+
) -> dict[str, Callable]:
100+
"""
101+
Load the encode functions for sparsify sparse coders on specified hookpoints.
102+
103+
Args:
104+
model (Any): The model to load autoencoders for.
105+
name (str): The name of the sparse model to load. If the model is on-disk
106+
this is the path to the directory containing the sparse model weights.
107+
hookpoints (list[str]): list of hookpoints to identify the sparse models.
108+
device (str | torch.device | None, optional): The device to load the
109+
sparse models on. If not specified the sparse models will be loaded
110+
on the same device as the base model.
111+
112+
Returns:
113+
dict[str, Callable]: A dictionary mapping hookpoints to encode functions.
114+
"""
115+
sparse_model_dict = load_sparsify_sparse_coders(
116+
model,
117+
name,
118+
hookpoints,
119+
device,
120+
compile,
121+
)
122+
hookpoint_to_sparse_encode = {}
123+
for hookpoint, sparse_model in sparse_model_dict.items():
94124
path_segments = resolve_path(model, hookpoint.split("."))
95125
if path_segments is None:
96126
raise ValueError(f"Could not find valid path for hookpoint: {hookpoint}")
97127

98-
submodules[".".join(path_segments)] = partial(
128+
hookpoint_to_sparse_encode[".".join(path_segments)] = partial(
99129
sae_dense_latents, sae=sparse_model
100130
)
101131

102-
return submodules
132+
return hookpoint_to_sparse_encode

0 commit comments

Comments
 (0)