
Commit bcae2b8

Author: SrGonao
Add load sparse_coders and rename load_hooks
1 parent 09e888e commit bcae2b8

File tree

2 files changed: +77 -21 lines changed

delphi/sparse_coders/custom/gemmascope.py

Lines changed: 33 additions & 7 deletions
@@ -15,18 +15,14 @@ def load_gemma_autoencoders(
     dtype: torch.dtype = torch.bfloat16,
     device: torch.device = torch.device("cuda"),
 ):
-    submodules = {}
+    saes = {}
 
     for layer, size, l0 in zip(ae_layers, sizes, average_l0s):
         path = f"layer_{layer}/width_{size}/average_l0_{l0}"
         sae = JumpReluSae.from_pretrained(model_path, path, device)
 
         sae.to(dtype)
 
-        def _forward(sae, x):
-            encoded = sae.encode(x)
-            return encoded
-
         assert type in [
             "res",
             "mlp",
@@ -37,9 +33,39 @@ def _forward(sae, x):
             else f"layers.{layer}.post_feedforward_layernorm"
         )
 
-        submodules[hookpoint] = partial(_forward, sae)
+        saes[hookpoint] = sae
+
+    return saes
+
+
+def load_gemma_hooks(
+    model_path: str,
+    ae_layers: list[int],
+    average_l0s: list[int],
+    sizes: list[str],
+    type: str,
+    dtype: torch.dtype = torch.bfloat16,
+    device: torch.device = torch.device("cuda"),
+):
+    saes = load_gemma_autoencoders(
+        model_path,
+        ae_layers,
+        average_l0s,
+        sizes,
+        type,
+        dtype,
+        device,
+    )
+    hookpoint_to_sparse_encode = {}
+    for hookpoint, sae in saes.items():
+
+        def _forward(sae, x):
+            encoded = sae.encode(x)
+            return encoded
+
+        hookpoint_to_sparse_encode[hookpoint] = partial(_forward, sae)
 
-    return submodules
+    return hookpoint_to_sparse_encode
 
 
 # This is from the GemmaScope tutorial
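The net effect in gemmascope.py: load_gemma_autoencoders now returns the raw JumpReluSae modules keyed by hookpoint, and the new load_gemma_hooks wraps each module's encode method in a callable. Binding the current sae through partial(_forward, sae) freezes the loop value per iteration, which sidesteps Python's late-binding closure pitfall. A minimal usage sketch; the repo id, layer/width/L0 values, and the hidden size of the fake activation are illustrative assumptions, not values taken from this diff:

import torch

from delphi.sparse_coders.custom.gemmascope import (
    load_gemma_autoencoders,
    load_gemma_hooks,
)

# Hypothetical GemmaScope release and settings, for illustration only.
# Defaults load onto CUDA in bfloat16, per the function signatures above.
kwargs = dict(
    model_path="google/gemma-scope-2b-pt-res",
    ae_layers=[5],
    average_l0s=[47],
    sizes=["16k"],
    type="res",
)

saes = load_gemma_autoencoders(**kwargs)  # hookpoint -> JumpReluSae module
hooks = load_gemma_hooks(**kwargs)        # hookpoint -> encode callable

hookpoint = next(iter(hooks))  # avoids guessing the exact hookpoint string
x = torch.randn(1, 2304, dtype=torch.bfloat16, device="cuda")  # assumed width
latents = hooks[hookpoint](x)  # equivalent to saes[hookpoint].encode(x)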

delphi/sparse_coders/load_sparsify.py

Lines changed: 44 additions & 14 deletions
@@ -54,49 +54,79 @@ def load_sparsify_sparse_coders(
         model (Any): The model to load autoencoders for.
         name (str): The name of the sparse model to load. If the model is on-disk
             this is the path to the directory containing the sparse model weights.
-        hookpoints (list[str]): list of hookpoints to load autoencoders for.
+        hookpoints (list[str]): list of hookpoints to identify the sparse models.
         device (str | torch.device | None, optional): The device to load the
             sparse models on. If not specified the sparse models will be loaded
             on the same device as the base model.
 
     Returns:
-        tuple[dict[str, Any], Any]: A tuple containing the submodules dictionary
-            and the edited model.
+        dict[str, Any]: A dictionary mapping hookpoints to sparse models.
     """
     if device is None:
         device = model.device or "cpu"
 
     # Load the sparse models
-    hookpoint_to_sparse = {}
+    sparse_model_dict = {}
     name_path = Path(name)
     if name_path.exists():
         for hookpoint in hookpoints:
-            hookpoint_to_sparse[hookpoint] = Sae.load_from_disk(
+            sparse_model_dict[hookpoint] = Sae.load_from_disk(
                 name_path / hookpoint, device=device
             )
             if compile:
-                hookpoint_to_sparse[hookpoint] = torch.compile(
-                    hookpoint_to_sparse[hookpoint]
+                sparse_model_dict[hookpoint] = torch.compile(
+                    sparse_model_dict[hookpoint]
                 )
     else:
         sparse_models = Sae.load_many(name, device=device)
         for hookpoint in hookpoints:
-            hookpoint_to_sparse[hookpoint] = sparse_models[hookpoint]
+            sparse_model_dict[hookpoint] = sparse_models[hookpoint]
             if compile:
-                hookpoint_to_sparse[hookpoint] = torch.compile(
-                    hookpoint_to_sparse[hookpoint]
+                sparse_model_dict[hookpoint] = torch.compile(
+                    sparse_model_dict[hookpoint]
                 )
 
         del sparse_models
+    return sparse_model_dict
 
-    submodules = {}
-    for hookpoint, sparse_model in hookpoint_to_sparse.items():
+
+def load_sparsify_hooks(
+    model: PreTrainedModel,
+    name: str,
+    hookpoints: list[str],
+    device: str | torch.device | None = None,
+    compile: bool = False,
+) -> dict[str, Callable]:
+    """
+    Load the encode functions for sparsify sparse coders on specified hookpoints.
+
+    Args:
+        model (Any): The model to load autoencoders for.
+        name (str): The name of the sparse model to load. If the model is on-disk
+            this is the path to the directory containing the sparse model weights.
+        hookpoints (list[str]): list of hookpoints to identify the sparse models.
+        device (str | torch.device | None, optional): The device to load the
+            sparse models on. If not specified the sparse models will be loaded
+            on the same device as the base model.
+
+    Returns:
+        dict[str, Callable]: A dictionary mapping hookpoints to encode functions.
+    """
+    sparse_model_dict = load_sparsify_sparse_coders(
+        model,
+        name,
+        hookpoints,
+        device,
+        compile,
+    )
+    hookpoint_to_sparse_encode = {}
+    for hookpoint, sparse_model in sparse_model_dict.items():
         path_segments = resolve_path(model, hookpoint.split("."))
         if path_segments is None:
             raise ValueError(f"Could not find valid path for hookpoint: {hookpoint}")
 
-        submodules[".".join(path_segments)] = partial(
+        hookpoint_to_sparse_encode[".".join(path_segments)] = partial(
             sae_dense_latents, sae=sparse_model
         )
 
-    return submodules
+    return hookpoint_to_sparse_encode
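load_sparsify.py gets the same split: load_sparsify_sparse_coders now returns a plain hookpoint-to-model dictionary (no longer a tuple with an edited model), while the new load_sparsify_hooks resolves each hookpoint to its module path on the base model and wraps the coder with sae_dense_latents. Note the keys differ: the first function keys by the hookpoints you pass in, the second by the resolved path segments. A hedged sketch of the calling convention; the base model and sparse-coder names below are placeholders, not checkpoints referenced by this commit:

from transformers import AutoModelForCausalLM

from delphi.sparse_coders.load_sparsify import (
    load_sparsify_hooks,
    load_sparsify_sparse_coders,
)

# Placeholder identifiers; substitute your own base model and sparse-coder
# checkpoint (a hub name, or a local directory of weights).
model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-160m")
name = "path/or/hub-id/of-sparse-coders"
hookpoints = ["layers.5"]

# Raw sparse models, keyed by the hookpoints passed in:
coders = load_sparsify_sparse_coders(model, name, hookpoints)

# Encode callables, keyed by the module path resolved on `model`:
encode_fns = load_sparsify_hooks(model, name, hookpoints)
for path, encode in encode_fns.items():
    print(path)  # e.g. the resolved form of "layers.5" on this architecture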
