Commit f5beb32
feat: adding support for loading goodfire llama 3 SAEs (#579)
1 parent 137e3c7 commit f5beb32

File tree

3 files changed: +326 −1 lines changed

sae_lens/loading/pretrained_sae_loaders.py

Lines changed: 78 additions & 0 deletions
@@ -523,6 +523,82 @@ def gemma_2_sae_huggingface_loader(
     return cfg_dict, state_dict, log_sparsity


+def get_goodfire_config_from_hf(
+    repo_id: str,
+    folder_name: str,
+    device: str,
+    force_download: bool = False,  # noqa: ARG001
+    cfg_overrides: dict[str, Any] | None = None,
+) -> dict[str, Any]:
+    cfg_dict = None
+    if repo_id == "Goodfire/Llama-3.3-70B-Instruct-SAE-l50":
+        if folder_name != "Llama-3.3-70B-Instruct-SAE-l50.pt":
+            raise ValueError(f"Unsupported Goodfire SAE: {repo_id}/{folder_name}")
+        cfg_dict = {
+            "architecture": "standard",
+            "d_in": 8192,
+            "d_sae": 65536,
+            "model_name": "meta-llama/Llama-3.3-70B-Instruct",
+            "hook_name": "blocks.50.hook_resid_post",
+            "hook_head_index": None,
+            "dataset_path": "lmsys/lmsys-chat-1m",
+            "apply_b_dec_to_input": False,
+        }
+    elif repo_id == "Goodfire/Llama-3.1-8B-Instruct-SAE-l19":
+        if folder_name != "Llama-3.1-8B-Instruct-SAE-l19.pth":
+            raise ValueError(f"Unsupported Goodfire SAE: {repo_id}/{folder_name}")
+        cfg_dict = {
+            "architecture": "standard",
+            "d_in": 4096,
+            "d_sae": 65536,
+            "model_name": "meta-llama/Llama-3.1-8B-Instruct",
+            "hook_name": "blocks.19.hook_resid_post",
+            "hook_head_index": None,
+            "dataset_path": "lmsys/lmsys-chat-1m",
+            "apply_b_dec_to_input": False,
+        }
+    if cfg_dict is None:
+        raise ValueError(f"Unsupported Goodfire SAE: {repo_id}/{folder_name}")
+    if device is not None:
+        cfg_dict["device"] = device
+    if cfg_overrides is not None:
+        cfg_dict.update(cfg_overrides)
+    return cfg_dict
+
+
+def get_goodfire_huggingface_loader(
+    repo_id: str,
+    folder_name: str,
+    device: str = "cpu",
+    force_download: bool = False,
+    cfg_overrides: dict[str, Any] | None = None,
+) -> tuple[dict[str, Any], dict[str, torch.Tensor], torch.Tensor | None]:
+    cfg_dict = get_goodfire_config_from_hf(
+        repo_id,
+        folder_name,
+        device,
+        force_download,
+        cfg_overrides,
+    )
+
+    # Download the SAE weights
+    sae_path = hf_hub_download(
+        repo_id=repo_id,
+        filename=folder_name,
+        force_download=force_download,
+    )
+    raw_state_dict = torch.load(sae_path, map_location=device)
+
+    state_dict = {
+        "W_enc": raw_state_dict["encoder_linear.weight"].T,
+        "W_dec": raw_state_dict["decoder_linear.weight"].T,
+        "b_enc": raw_state_dict["encoder_linear.bias"],
+        "b_dec": raw_state_dict["decoder_linear.bias"],
+    }
+
+    return cfg_dict, state_dict, None
+
+
 def get_llama_scope_config_from_hf(
     repo_id: str,
     folder_name: str,
@@ -1487,6 +1563,7 @@ def get_mntss_clt_layer_config_from_hf(
     "gemma_2_transcoder": gemma_2_transcoder_huggingface_loader,
     "mwhanna_transcoder": mwhanna_transcoder_huggingface_loader,
     "mntss_clt_layer_transcoder": mntss_clt_layer_huggingface_loader,
+    "goodfire": get_goodfire_huggingface_loader,
 }

@@ -1502,4 +1579,5 @@ def get_mntss_clt_layer_config_from_hf(
     "gemma_2_transcoder": get_gemma_2_transcoder_config_from_hf,
     "mwhanna_transcoder": get_mwhanna_transcoder_config_from_hf,
     "mntss_clt_layer_transcoder": get_mntss_clt_layer_config_from_hf,
+    "goodfire": get_goodfire_config_from_hf,
 }
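
A note on the weight layout: torch.nn.Linear stores its weight as (out_features, in_features), while sae_lens expects W_enc with shape (d_in, d_sae) and W_dec with shape (d_sae, d_in), hence the .T transposes in the loader above. A minimal self-contained sketch of that bookkeeping, using toy sizes rather than the real 8192 x 65536 checkpoint:

import torch

d_in, d_sae = 4, 8
raw_state_dict = {
    "encoder_linear.weight": torch.randn(d_sae, d_in),  # nn.Linear: (out, in)
    "encoder_linear.bias": torch.randn(d_sae),
    "decoder_linear.weight": torch.randn(d_in, d_sae),
    "decoder_linear.bias": torch.randn(d_in),
}
state_dict = {
    "W_enc": raw_state_dict["encoder_linear.weight"].T,  # -> (d_in, d_sae)
    "W_dec": raw_state_dict["decoder_linear.weight"].T,  # -> (d_sae, d_in)
    "b_enc": raw_state_dict["encoder_linear.bias"],
    "b_dec": raw_state_dict["decoder_linear.bias"],
}
assert state_dict["W_enc"].shape == (d_in, d_sae)
assert state_dict["W_dec"].shape == (d_sae, d_in)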

sae_lens/pretrained_saes.yaml

Lines changed: 19 additions & 1 deletion
@@ -14882,4 +14882,22 @@ qwen2.5-7b-instruct-andyrdt:
     neuronpedia: qwen2.5-7b-it/23-resid-post-aa
   - id: resid_post_layer_27_trainer_1
     path: resid_post_layer_27/trainer_1
-    neuronpedia: qwen2.5-7b-it/27-resid-post-aa
+    neuronpedia: qwen2.5-7b-it/27-resid-post-aa
+
+goodfire-llama-3.3-70b-instruct:
+  conversion_func: goodfire
+  model: meta-llama/Llama-3.3-70B-Instruct
+  repo_id: Goodfire/Llama-3.3-70B-Instruct-SAE-l50
+  saes:
+  - id: layer_50
+    path: Llama-3.3-70B-Instruct-SAE-l50.pt
+    l0: 121
+
+goodfire-llama-3.1-8b-instruct:
+  conversion_func: goodfire
+  model: meta-llama/Llama-3.1-8B-Instruct
+  repo_id: Goodfire/Llama-3.1-8B-Instruct-SAE-l19
+  saes:
+  - id: layer_19
+    path: Llama-3.1-8B-Instruct-SAE-l19.pth
+    l0: 91
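
With these registry entries in place, the SAEs should be loadable by release name. A minimal usage sketch (illustrative only: the exact return type of SAE.from_pretrained has varied across sae_lens versions, returning either the SAE alone or a (sae, cfg_dict, sparsity) tuple):

from sae_lens import SAE

sae = SAE.from_pretrained(
    release="goodfire-llama-3.1-8b-instruct",
    sae_id="layer_19",
    device="cpu",
)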

tests/loading/test_pretrained_sae_loaders.py

Lines changed: 229 additions & 0 deletions
@@ -1,15 +1,19 @@
 from pathlib import Path
+from typing import Any

 import pytest
 import torch
 import yaml
 from safetensors.torch import save_file
 from sparsify import SparseCoder, SparseCoderConfig

+from sae_lens import StandardSAE, StandardSAEConfig
 from sae_lens.loading.pretrained_sae_loaders import (
     dictionary_learning_sae_huggingface_loader_1,
     get_deepseek_r1_config_from_hf,
     get_gemma_2_transcoder_config_from_hf,
+    get_goodfire_config_from_hf,
+    get_goodfire_huggingface_loader,
     get_llama_scope_config_from_hf,
     get_llama_scope_r1_distill_config_from_hf,
     get_mntss_clt_layer_config_from_hf,
@@ -21,6 +25,7 @@
     sparsify_huggingface_loader,
 )
 from sae_lens.saes.sae import SAE
+from tests.helpers import assert_close, random_params


 def test_load_sae_config_from_huggingface():
@@ -500,6 +505,230 @@ def test_get_llama_scope_config_from_hf():
     assert cfg == expected_cfg


+def test_get_goodfire_config_from_hf():
+    cfg = get_goodfire_config_from_hf(
+        repo_id="Goodfire/Llama-3.3-70B-Instruct-SAE-l50",
+        folder_name="Llama-3.3-70B-Instruct-SAE-l50.pt",
+        device="cpu",
+    )
+    expected_cfg = {
+        "architecture": "standard",
+        "d_in": 8192,
+        "d_sae": 65536,
+        "model_name": "meta-llama/Llama-3.3-70B-Instruct",
+        "hook_name": "blocks.50.hook_resid_post",
+        "hook_head_index": None,
+        "dataset_path": "lmsys/lmsys-chat-1m",
+        "apply_b_dec_to_input": False,
+        "device": "cpu",
+    }
+    assert cfg == expected_cfg
+
+
+def test_get_goodfire_llama_8b_config_from_hf():
+    cfg = get_goodfire_config_from_hf(
+        repo_id="Goodfire/Llama-3.1-8B-Instruct-SAE-l19",
+        folder_name="Llama-3.1-8B-Instruct-SAE-l19.pth",
+        device="cpu",
+    )
+    expected_cfg = {
+        "architecture": "standard",
+        "d_in": 4096,
+        "d_sae": 65536,
+        "model_name": "meta-llama/Llama-3.1-8B-Instruct",
+        "hook_name": "blocks.19.hook_resid_post",
+        "hook_head_index": None,
+        "dataset_path": "lmsys/lmsys-chat-1m",
+        "apply_b_dec_to_input": False,
+        "device": "cpu",
+    }
+    assert cfg == expected_cfg
+
+
+def test_get_goodfire_config_from_hf_errors_on_unsupported_sae():
+    with pytest.raises(
+        ValueError,
+        match="Unsupported Goodfire SAE: wrong/repo",
+    ):
+        get_goodfire_config_from_hf(
+            repo_id="wrong/repo",
+            folder_name="Llama-3.3-70B-Instruct-SAE-l50.pt",
+            device="cpu",
+        )
+    with pytest.raises(
+        ValueError,
+        match="Unsupported Goodfire SAE: Goodfire/Llama-3.3-70B-Instruct-SAE-l50/wrong_filename.pt",
+    ):
+        get_goodfire_config_from_hf(
+            repo_id="Goodfire/Llama-3.3-70B-Instruct-SAE-l50",
+            folder_name="wrong_filename.pt",
+            device="cpu",
+        )
+
+
+def test_our_sae_matches_goodfires_implementation():
+    # from https://colab.research.google.com/drive/1IBMQtJqy8JiRk1Q48jDEgTISmtxhlCRL
+    class GoodfireSAE(torch.nn.Module):
+        def __init__(
+            self,
+            d_in: int,
+            d_hidden: int,
+            device: torch.device,
+            dtype: torch.dtype = torch.float32,
+        ):
+            super().__init__()
+            self.d_in = d_in
+            self.d_hidden = d_hidden
+            self.device = device
+            self.encoder_linear = torch.nn.Linear(d_in, d_hidden)
+            self.decoder_linear = torch.nn.Linear(d_hidden, d_in)
+            self.dtype = dtype
+            self.to(self.device, self.dtype)
+
+        def encode(self, x: torch.Tensor) -> torch.Tensor:
+            """Encode a batch of data using a linear, followed by a ReLU."""
+            return torch.nn.functional.relu(self.encoder_linear(x))
+
+        def decode(self, x: torch.Tensor) -> torch.Tensor:
+            """Decode a batch of data using a linear."""
+            return self.decoder_linear(x)
+
+        def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+            """SAE forward pass. Returns the reconstruction and the encoded features."""
+            f = self.encode(x)
+            return self.decode(f), f
+
+    cfg_dict = load_sae_config_from_huggingface(
+        release="goodfire-llama-3.3-70b-instruct",
+        sae_id="layer_50",
+        device="cpu",
+    )
+    cfg_dict["d_in"] = 128
+    cfg_dict["d_sae"] = 256
+    cfg_dict["dtype"] = "float32"
+
+    assert cfg_dict["architecture"] == "standard"
+    cfg = StandardSAEConfig.from_dict(cfg_dict)
+
+    # make an SAE based on the Goodfire config, but smaller, since the real SAE class is huge
+    sae = StandardSAE(cfg)
+    random_params(sae)
+
+    sae_state_dict = sae.state_dict()
+    goodfire_state_dict = {
+        "encoder_linear.weight": sae_state_dict["W_enc"].T,
+        "encoder_linear.bias": sae_state_dict["b_enc"],
+        "decoder_linear.weight": sae_state_dict["W_dec"].T,
+        "decoder_linear.bias": sae_state_dict["b_dec"],
+    }
+
+    goodfire_sae = GoodfireSAE(d_in=128, d_hidden=256, device=torch.device("cpu"))
+    goodfire_sae.load_state_dict(goodfire_state_dict)
+
+    test_input = torch.randn(10, 128)
+
+    output = sae(test_input)
+    features = sae.encode(test_input)
+    goodfire_output, goodfire_features = goodfire_sae(test_input)
+
+    assert_close(output, goodfire_output, rtol=1e-4, atol=1e-4)
+    assert_close(features, goodfire_features, rtol=1e-4, atol=1e-4)
+
+
+def test_get_goodfire_huggingface_loader_with_mocked_download(
+    tmp_path: Path, monkeypatch: pytest.MonkeyPatch
+):
+    repo_id = "Goodfire/Llama-3.3-70B-Instruct-SAE-l50"
+    folder_name = "Llama-3.3-70B-Instruct-SAE-l50.pt"
+    device = "cpu"
+
+    d_in = 128
+    d_sae = 256
+
+    encoder_weight = torch.randn(d_sae, d_in)
+    decoder_weight = torch.randn(d_in, d_sae)
+    encoder_bias = torch.randn(d_sae)
+    decoder_bias = torch.randn(d_in)
+
+    raw_state_dict = {
+        "encoder_linear.weight": encoder_weight,
+        "decoder_linear.weight": decoder_weight,
+        "encoder_linear.bias": encoder_bias,
+        "decoder_linear.bias": decoder_bias,
+    }
+
+    sae_file_path = tmp_path / folder_name
+    torch.save(raw_state_dict, sae_file_path)
+
+    def mock_get_goodfire_config_from_hf(
+        repo_id: str,  # noqa: ARG001
+        folder_name: str,  # noqa: ARG001
+        device: str,
+        force_download: bool = False,  # noqa: ARG001
+        cfg_overrides: dict[str, Any] | None = None,  # noqa: ARG001
+    ) -> dict[str, Any]:
+        return {
+            "architecture": "standard",
+            "d_in": d_in,
+            "d_sae": d_sae,
+            "model_name": "meta-llama/Llama-3.3-70B-Instruct",
+            "hook_name": "blocks.50.hook_resid_post",
+            "hook_head_index": None,
+            "dataset_path": "lmsys/lmsys-chat-1m",
+            "apply_b_dec_to_input": False,
+            "device": device,
+        }
+
+    def mock_hf_hub_download(
+        repo_id: str,  # noqa: ARG001
+        filename: str,  # noqa: ARG001
+        force_download: bool = False,  # noqa: ARG001
+    ) -> str:
+        return str(sae_file_path)
+
+    monkeypatch.setattr(
+        "sae_lens.loading.pretrained_sae_loaders.get_goodfire_config_from_hf",
+        mock_get_goodfire_config_from_hf,
+    )
+    monkeypatch.setattr(
+        "sae_lens.loading.pretrained_sae_loaders.hf_hub_download", mock_hf_hub_download
+    )
+
+    cfg_dict, state_dict, log_sparsity = get_goodfire_huggingface_loader(
+        repo_id=repo_id,
+        folder_name=folder_name,
+        device=device,
+        force_download=False,
+        cfg_overrides=None,
+    )
+
+    expected_cfg = {
+        "architecture": "standard",
+        "d_in": d_in,
+        "d_sae": d_sae,
+        "model_name": "meta-llama/Llama-3.3-70B-Instruct",
+        "hook_name": "blocks.50.hook_resid_post",
+        "hook_head_index": None,
+        "dataset_path": "lmsys/lmsys-chat-1m",
+        "apply_b_dec_to_input": False,
+        "device": device,
+    }
+
+    assert cfg_dict == expected_cfg
+    assert log_sparsity is None
+
+    assert set(state_dict.keys()) == {"W_enc", "W_dec", "b_enc", "b_dec"}
+    torch.testing.assert_close(state_dict["W_enc"], encoder_weight.T)
+    torch.testing.assert_close(state_dict["W_dec"], decoder_weight.T)
+    torch.testing.assert_close(state_dict["b_enc"], encoder_bias)
+    torch.testing.assert_close(state_dict["b_dec"], decoder_bias)
+
+    assert state_dict["W_enc"].shape == (d_in, d_sae)
+    assert state_dict["W_dec"].shape == (d_sae, d_in)
+    assert state_dict["b_enc"].shape == (d_sae,)
+    assert state_dict["b_dec"].shape == (d_in,)
+
+
 def test_get_llama_scope_r1_distill_config_from_hf():
     """Test that the Llama Scope R1 Distill config is generated correctly."""
     cfg = get_llama_scope_r1_distill_config_from_hf(