
Commit 888c586

canrager, chanind, and hijohnnylin authored
feat: Temporal SAE integration (#575)
* clauded temporal SAE integration
* setting up tests and removing intermediate files
* use SAELens names for weights
* updating implementation
* disallow folding w dec norm for temporal saes
* added temporal sae
* Revert "added temporal sae" This reverts commit c644790.
* added warning to standalone decoding
* updated loading temporal sae to safetensors format
* fixing syntax issues
* testing sae inference
* fix: 1. W_enc not initialized for tied weights and 2. added scaling factor.
* added end-to-end comparison with original implementation of TemporalSAE
* fixed linting
* fix: set temporal hook_name, fix lint
* add neuronpedia entries to yaml
* use gemma-2-2b instead of google/gemma-2-2b in pretrained yaml
* make W_enc optional for Temporal SAE
* adapted tests
* ruff formatting
* fixed layer index of temporal Llama SAEs
* fix: temporal pretrained yaml
* fix: undo formatting change
* Fix hook_resid_post ID in pretrained_saes.yaml
* fix: final corrections for temporal SAEs llama yaml
* moving scaling into temporal SAEs for now

---------

Co-authored-by: David Chanin <[email protected]>
Co-authored-by: Johnny Lin <[email protected]>
1 parent 0507574 commit 888c586

13 files changed: +1685 -19 lines changed

sae_lens/__init__.py

Lines changed: 5 additions & 0 deletions
@@ -28,6 +28,8 @@
     StandardSAEConfig,
     StandardTrainingSAE,
     StandardTrainingSAEConfig,
+    TemporalSAE,
+    TemporalSAEConfig,
     TopKSAE,
     TopKSAEConfig,
     TopKTrainingSAE,
@@ -105,6 +107,8 @@
     "JumpReLUTranscoderConfig",
     "MatryoshkaBatchTopKTrainingSAE",
     "MatryoshkaBatchTopKTrainingSAEConfig",
+    "TemporalSAE",
+    "TemporalSAEConfig",
 ]


@@ -127,3 +131,4 @@
 register_sae_class("transcoder", Transcoder, TranscoderConfig)
 register_sae_class("skip_transcoder", SkipTranscoder, SkipTranscoderConfig)
 register_sae_class("jumprelu_transcoder", JumpReLUTranscoder, JumpReLUTranscoderConfig)
+register_sae_class("temporal", TemporalSAE, TemporalSAEConfig)
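With this change the new classes are exported from the package root, and register_sae_class maps the "temporal" architecture string to them, which is how the loading code further down resolves the right SAE class. A minimal sketch of the new import surface (assumes an environment with this commit installed):

from sae_lens import TemporalSAE, TemporalSAEConfig

# The "temporal" architecture string used in configs and in pretrained_saes.yaml
# now resolves to these classes when an SAE is loaded.
print(TemporalSAE.__name__, TemporalSAEConfig.__name__)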

sae_lens/loading/pretrained_sae_loaders.py

Lines changed: 110 additions & 0 deletions
@@ -1551,6 +1551,114 @@ def get_mntss_clt_layer_config_from_hf(
     }


+def get_temporal_sae_config_from_hf(
+    repo_id: str,
+    folder_name: str,
+    device: str,
+    force_download: bool = False,
+    cfg_overrides: dict[str, Any] | None = None,
+) -> dict[str, Any]:
+    """Get TemporalSAE config without loading weights."""
+    # Download config file
+    conf_path = hf_hub_download(
+        repo_id=repo_id,
+        filename=f"{folder_name}/conf.yaml",
+        force_download=force_download,
+    )
+
+    # Load and parse config
+    with open(conf_path) as f:
+        yaml_config = yaml.safe_load(f)
+
+    # Extract parameters
+    d_in = yaml_config["llm"]["dimin"]
+    exp_factor = yaml_config["sae"]["exp_factor"]
+    d_sae = int(d_in * exp_factor)
+
+    # extract layer from folder_name eg : "layer_12/temporal"
+    layer = re.search(r"layer_(\d+)", folder_name)
+    if layer is None:
+        raise ValueError(f"Could not find layer in folder_name: {folder_name}")
+    layer = int(layer.group(1))
+
+    # Build config dict
+    cfg_dict = {
+        "architecture": "temporal",
+        "hook_name": f"blocks.{layer}.hook_resid_post",
+        "d_in": d_in,
+        "d_sae": d_sae,
+        "n_heads": yaml_config["sae"]["n_heads"],
+        "n_attn_layers": yaml_config["sae"]["n_attn_layers"],
+        "bottleneck_factor": yaml_config["sae"]["bottleneck_factor"],
+        "sae_diff_type": yaml_config["sae"]["sae_diff_type"],
+        "kval_topk": yaml_config["sae"]["kval_topk"],
+        "tied_weights": yaml_config["sae"]["tied_weights"],
+        "dtype": yaml_config["data"]["dtype"],
+        "device": device,
+        "normalize_activations": "constant_scalar_rescale",
+        "activation_normalization_factor": yaml_config["sae"]["scaling_factor"],
+        "apply_b_dec_to_input": True,
+    }
+
+    if cfg_overrides:
+        cfg_dict.update(cfg_overrides)
+
+    return cfg_dict
+
+
+def temporal_sae_huggingface_loader(
+    repo_id: str,
+    folder_name: str,
+    device: str = "cpu",
+    force_download: bool = False,
+    cfg_overrides: dict[str, Any] | None = None,
+) -> tuple[dict[str, Any], dict[str, torch.Tensor], torch.Tensor | None]:
+    """
+    Load TemporalSAE from canrager/temporalSAEs format (safetensors version).
+
+    Expects folder_name to contain:
+    - conf.yaml (configuration)
+    - latest_ckpt.safetensors (model weights)
+    """
+
+    cfg_dict = get_temporal_sae_config_from_hf(
+        repo_id=repo_id,
+        folder_name=folder_name,
+        device=device,
+        force_download=force_download,
+        cfg_overrides=cfg_overrides,
+    )
+
+    # Download checkpoint (safetensors format)
+    ckpt_path = hf_hub_download(
+        repo_id=repo_id,
+        filename=f"{folder_name}/latest_ckpt.safetensors",
+        force_download=force_download,
+    )
+
+    # Load checkpoint from safetensors
+    state_dict_raw = load_file(ckpt_path, device=device)
+
+    # Convert to SAELens naming convention
+    # TemporalSAE uses: D (decoder), E (encoder), b (bias), attn_layers.*
+    state_dict = {}
+
+    # Copy attention layers as-is
+    for key, value in state_dict_raw.items():
+        if key.startswith("attn_layers."):
+            state_dict[key] = value.to(device)
+
+    # Main parameters
+    state_dict["W_dec"] = state_dict_raw["D"].to(device)
+    state_dict["b_dec"] = state_dict_raw["b"].to(device)
+
+    # Handle tied/untied weights
+    if "E" in state_dict_raw:
+        state_dict["W_enc"] = state_dict_raw["E"].to(device)
+
+    return cfg_dict, state_dict, None
+
+
 NAMED_PRETRAINED_SAE_LOADERS: dict[str, PretrainedSaeHuggingfaceLoader] = {
     "sae_lens": sae_lens_huggingface_loader,
     "connor_rob_hook_z": connor_rob_hook_z_huggingface_loader,
@@ -1563,6 +1671,7 @@ def get_mntss_clt_layer_config_from_hf(
     "gemma_2_transcoder": gemma_2_transcoder_huggingface_loader,
     "mwhanna_transcoder": mwhanna_transcoder_huggingface_loader,
     "mntss_clt_layer_transcoder": mntss_clt_layer_huggingface_loader,
+    "temporal": temporal_sae_huggingface_loader,
     "goodfire": get_goodfire_huggingface_loader,
 }

@@ -1579,5 +1688,6 @@ def get_mntss_clt_layer_config_from_hf(
     "gemma_2_transcoder": get_gemma_2_transcoder_config_from_hf,
     "mwhanna_transcoder": get_mwhanna_transcoder_config_from_hf,
     "mntss_clt_layer_transcoder": get_mntss_clt_layer_config_from_hf,
+    "temporal": get_temporal_sae_config_from_hf,
     "goodfire": get_goodfire_config_from_hf,
 }
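
For orientation, here is a sketch of exercising the new loader directly rather than going through SAE.from_pretrained (the repo_id and folder path are the ones registered for the gemma-2-2b release in pretrained_saes.yaml below; printed values are illustrative):

from sae_lens.loading.pretrained_sae_loaders import temporal_sae_huggingface_loader

# Downloads conf.yaml and latest_ckpt.safetensors from the Hugging Face Hub and
# returns a config dict plus a state dict renamed to SAELens conventions.
cfg_dict, state_dict, log_sparsity = temporal_sae_huggingface_loader(
    repo_id="canrager/temporalSAEs",
    folder_name="gemma-2-2B/layer_12/temporal",
    device="cpu",
)

print(cfg_dict["architecture"])                      # "temporal"
print(cfg_dict["hook_name"])                         # "blocks.12.hook_resid_post", parsed from "layer_12"
print("W_dec" in state_dict, "b_dec" in state_dict)  # decoder weight and bias under SAELens names
print(log_sparsity)                                  # None: this loader ships no sparsity tensor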

sae_lens/loading/pretrained_saes_directory.py

Lines changed: 5 additions & 3 deletions
@@ -1,6 +1,6 @@
 from dataclasses import dataclass
 from functools import cache
-from importlib import resources
+from importlib.resources import files
 from typing import Any

 import yaml
@@ -24,7 +24,8 @@ def get_pretrained_saes_directory() -> dict[str, PretrainedSAELookup]:
     package = "sae_lens"
     # Access the file within the package using importlib.resources
     directory: dict[str, PretrainedSAELookup] = {}
-    with resources.open_text(package, "pretrained_saes.yaml") as file:
+    yaml_file = files(package).joinpath("pretrained_saes.yaml")
+    with yaml_file.open("r") as file:
         # Load the YAML file content
         data = yaml.safe_load(file)
         for release, value in data.items():
@@ -68,7 +69,8 @@ def get_norm_scaling_factor(release: str, sae_id: str) -> float | None:
         float | None: The norm_scaling_factor if it exists, None otherwise.
     """
     package = "sae_lens"
-    with resources.open_text(package, "pretrained_saes.yaml") as file:
+    yaml_file = files(package).joinpath("pretrained_saes.yaml")
+    with yaml_file.open("r") as file:
         data = yaml.safe_load(file)
         if release in data:
             for sae_info in data[release]["saes"]:
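
The change swaps the legacy resources.open_text helper for the files()-based traversable API. The same pattern in isolation, reading the packaged YAML directly (assumes sae_lens and pyyaml are installed):

from importlib.resources import files

import yaml

# files() returns a Traversable for data shipped inside an installed package;
# joinpath/open work the same whether the package lives in a directory or a zip.
yaml_file = files("sae_lens").joinpath("pretrained_saes.yaml")
with yaml_file.open("r") as f:
    releases = yaml.safe_load(f)
print(len(releases), "releases listed")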

sae_lens/pretrained_saes.yaml

Lines changed: 33 additions & 1 deletion
@@ -1,3 +1,35 @@
+temporal-sae-gemma-2-2b:
+  conversion_func: temporal
+  model: gemma-2-2b
+  repo_id: canrager/temporalSAEs
+  config_overrides:
+    model_name: gemma-2-2b
+    hook_name: blocks.12.hook_resid_post
+    dataset_path: monology/pile-uncopyrighted
+  saes:
+  - id: blocks.12.hook_resid_post
+    l0: 192
+    norm_scaling_factor: 0.00666666667
+    path: gemma-2-2B/layer_12/temporal
+    neuronpedia: gemma-2-2b/12-temporal-res
+temporal-sae-llama-3.1-8b:
+  conversion_func: temporal
+  model: meta-llama/Llama-3.1-8B
+  repo_id: canrager/temporalSAEs
+  config_overrides:
+    model_name: meta-llama/Llama-3.1-8B
+    dataset_path: monology/pile-uncopyrighted
+  saes:
+  - id: blocks.15.hook_resid_post
+    l0: 256
+    norm_scaling_factor: 0.029
+    path: llama-3.1-8B/layer_15/temporal
+    neuronpedia: llama3.1-8b/15-temporal-res
+  - id: blocks.26.hook_resid_post
+    l0: 256
+    norm_scaling_factor: 0.029
+    path: llama-3.1-8B/layer_26/temporal
+    neuronpedia: llama3.1-8b/26-temporal-res
 deepseek-r1-distill-llama-8b-qresearch:
   conversion_func: deepseek_r1
   model: deepseek-ai/DeepSeek-R1-Distill-Llama-8B
@@ -14900,4 +14932,4 @@ goodfire-llama-3.1-8b-instruct:
   saes:
   - id: layer_19
     path: Llama-3.1-8B-Instruct-SAE-l19.pth
-    l0: 91
+    l0: 91
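
These YAML entries make the Temporal SAEs addressable by release name and SAE ID. A usage sketch (assumes a SAELens build containing this commit and network access to the Hub; note that older SAELens versions return a (sae, cfg_dict, sparsity) tuple from from_pretrained rather than the SAE alone):

from sae_lens import SAE

# Release and SAE IDs come straight from the entries above.
gemma_sae = SAE.from_pretrained(
    release="temporal-sae-gemma-2-2b",
    sae_id="blocks.12.hook_resid_post",
    device="cpu",
)
llama_sae = SAE.from_pretrained(
    release="temporal-sae-llama-3.1-8b",
    sae_id="blocks.15.hook_resid_post",
    device="cpu",
)

# Both resolve to the TemporalSAE class registered in sae_lens/__init__.py.
print(type(gemma_sae).__name__, gemma_sae.cfg.d_in, gemma_sae.cfg.d_sae)
print(type(llama_sae).__name__, llama_sae.cfg.d_in, llama_sae.cfg.d_sae)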

sae_lens/saes/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -25,6 +25,7 @@
     StandardTrainingSAE,
     StandardTrainingSAEConfig,
 )
+from .temporal_sae import TemporalSAE, TemporalSAEConfig
 from .topk_sae import (
     TopKSAE,
     TopKSAEConfig,
@@ -71,4 +72,6 @@
     "JumpReLUTranscoderConfig",
     "MatryoshkaBatchTopKTrainingSAE",
     "MatryoshkaBatchTopKTrainingSAEConfig",
+    "TemporalSAE",
+    "TemporalSAEConfig",
 ]

sae_lens/saes/sae.py

Lines changed: 4 additions & 12 deletions
@@ -155,9 +155,9 @@ class SAEConfig(ABC):
     dtype: str = "float32"
     device: str = "cpu"
     apply_b_dec_to_input: bool = True
-    normalize_activations: Literal[
-        "none", "expected_average_only_in", "constant_norm_rescale", "layer_norm"
-    ] = "none"  # none, expected_average_only_in (Anthropic April Update), constant_norm_rescale (Anthropic Feb Update)
+    normalize_activations: Literal["none", "expected_average_only_in", "layer_norm"] = (
+        "none"  # none, expected_average_only_in (Anthropic April Update)
+    )
     reshape_activations: Literal["none", "hook_z"] = "none"
     metadata: SAEMetadata = field(default_factory=SAEMetadata)

@@ -309,6 +309,7 @@ def run_time_activation_norm_fn_out(x: torch.Tensor) -> torch.Tensor:

            self.run_time_activation_norm_fn_in = run_time_activation_norm_fn_in
            self.run_time_activation_norm_fn_out = run_time_activation_norm_fn_out
+
        elif self.cfg.normalize_activations == "layer_norm":
            # we need to scale the norm of the input and store the scaling factor
            def run_time_activation_ln_in(
@@ -452,23 +453,14 @@ def to(self: T_SAE, *args: Any, **kwargs: Any) -> T_SAE:  # type: ignore
     def process_sae_in(
         self, sae_in: Float[torch.Tensor, "... d_in"]
     ) -> Float[torch.Tensor, "... d_in"]:
-        # print(f"Input shape to process_sae_in: {sae_in.shape}")
-        # print(f"self.cfg.hook_name: {self.cfg.hook_name}")
-        # print(f"self.b_dec shape: {self.b_dec.shape}")
-        # print(f"Hook z reshaping mode: {getattr(self, 'hook_z_reshaping_mode', False)}")
-
         sae_in = sae_in.to(self.dtype)
-
-        # print(f"Shape before reshape_fn_in: {sae_in.shape}")
         sae_in = self.reshape_fn_in(sae_in)
-        # print(f"Shape after reshape_fn_in: {sae_in.shape}")

         sae_in = self.hook_sae_input(sae_in)
         sae_in = self.run_time_activation_norm_fn_in(sae_in)

         # Here's where the error happens
         bias_term = self.b_dec * self.cfg.apply_b_dec_to_input
-        # print(f"Bias term shape: {bias_term.shape}")

         return sae_in - bias_term
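
One detail in the cleaned-up process_sae_in above: multiplying b_dec by the boolean apply_b_dec_to_input keeps or zeroes the bias term without a branch. A standalone illustration in plain torch (hypothetical values, no SAELens types):

import torch

b_dec = torch.tensor([0.5, -1.0, 2.0])
sae_in = torch.ones(3)

for apply_b_dec_to_input in (True, False):
    # bool * tensor broadcasts to 1 or 0, so the subtraction is a no-op when False.
    bias_term = b_dec * apply_b_dec_to_input
    print(apply_b_dec_to_input, (sae_in - bias_term).tolist())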

0 commit comments
