@@ -1,15 +1,19 @@
 from pathlib import Path
+from typing import Any
 
 import pytest
 import torch
 import yaml
 from safetensors.torch import save_file
 from sparsify import SparseCoder, SparseCoderConfig
 
+from sae_lens import StandardSAE, StandardSAEConfig
 from sae_lens.loading.pretrained_sae_loaders import (
     dictionary_learning_sae_huggingface_loader_1,
     get_deepseek_r1_config_from_hf,
     get_gemma_2_transcoder_config_from_hf,
+    get_goodfire_config_from_hf,
+    get_goodfire_huggingface_loader,
     get_llama_scope_config_from_hf,
     get_llama_scope_r1_distill_config_from_hf,
     get_mntss_clt_layer_config_from_hf,
@@ -21,6 +25,7 @@
     sparsify_huggingface_loader,
 )
 from sae_lens.saes.sae import SAE
+from tests.helpers import assert_close, random_params
 
 
 def test_load_sae_config_from_huggingface():
@@ -500,6 +505,230 @@ def test_get_llama_scope_config_from_hf():
     assert cfg == expected_cfg
 
 
+def test_get_goodfire_config_from_hf():
+    cfg = get_goodfire_config_from_hf(
+        repo_id="Goodfire/Llama-3.3-70B-Instruct-SAE-l50",
+        folder_name="Llama-3.3-70B-Instruct-SAE-l50.pt",
+        device="cpu",
+    )
+    expected_cfg = {
+        "architecture": "standard",
+        "d_in": 8192,
+        "d_sae": 65536,
+        "model_name": "meta-llama/Llama-3.3-70B-Instruct",
+        "hook_name": "blocks.50.hook_resid_post",
+        "hook_head_index": None,
+        "dataset_path": "lmsys/lmsys-chat-1m",
+        "apply_b_dec_to_input": False,
+        "device": "cpu",
+    }
+    assert cfg == expected_cfg
+
+
+def test_get_goodfire_llama_8b_config_from_hf():
+    cfg = get_goodfire_config_from_hf(
+        repo_id="Goodfire/Llama-3.1-8B-Instruct-SAE-l19",
+        folder_name="Llama-3.1-8B-Instruct-SAE-l19.pth",
+        device="cpu",
+    )
+    expected_cfg = {
+        "architecture": "standard",
+        "d_in": 4096,
+        "d_sae": 65536,
+        "model_name": "meta-llama/Llama-3.1-8B-Instruct",
+        "hook_name": "blocks.19.hook_resid_post",
+        "hook_head_index": None,
+        "dataset_path": "lmsys/lmsys-chat-1m",
+        "apply_b_dec_to_input": False,
+        "device": "cpu",
+    }
+    assert cfg == expected_cfg
+
+
+def test_get_goodfire_config_from_hf_errors_on_unsupported_sae():
+    with pytest.raises(
+        ValueError,
+        match="Unsupported Goodfire SAE: wrong/repo",
+    ):
+        get_goodfire_config_from_hf(
+            repo_id="wrong/repo",
+            folder_name="Llama-3.3-70B-Instruct-SAE-l50.pt",
+            device="cpu",
+        )
+    with pytest.raises(
+        ValueError,
+        match="Unsupported Goodfire SAE: Goodfire/Llama-3.3-70B-Instruct-SAE-l50/wrong_filename.pt",
+    ):
+        get_goodfire_config_from_hf(
+            repo_id="Goodfire/Llama-3.3-70B-Instruct-SAE-l50",
+            folder_name="wrong_filename.pt",
+            device="cpu",
+        )
+
+
+def test_our_sae_matches_goodfires_implementation():
+    # from https://colab.research.google.com/drive/1IBMQtJqy8JiRk1Q48jDEgTISmtxhlCRL
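+    # Goodfire's reference SAE is a plain linear encoder followed by a ReLU,
+    # plus a linear decoder; the decoder bias is not subtracted from the input
+    # (matching apply_b_dec_to_input=False in the config above).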
+    class GoodfireSAE(torch.nn.Module):
+        def __init__(
+            self,
+            d_in: int,
+            d_hidden: int,
+            device: torch.device,
+            dtype: torch.dtype = torch.float32,
+        ):
+            super().__init__()
+            self.d_in = d_in
+            self.d_hidden = d_hidden
+            self.device = device
+            self.encoder_linear = torch.nn.Linear(d_in, d_hidden)
+            self.decoder_linear = torch.nn.Linear(d_hidden, d_in)
+            self.dtype = dtype
+            self.to(self.device, self.dtype)
+
+        def encode(self, x: torch.Tensor) -> torch.Tensor:
+            """Encode a batch of data using a linear, followed by a ReLU."""
+            return torch.nn.functional.relu(self.encoder_linear(x))
+
+        def decode(self, x: torch.Tensor) -> torch.Tensor:
+            """Decode a batch of data using a linear."""
+            return self.decoder_linear(x)
+
+        def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+            """SAE forward pass. Returns the reconstruction and the encoded features."""
+            f = self.encode(x)
+            return self.decode(f), f
+
+    cfg_dict = load_sae_config_from_huggingface(
+        release="goodfire-llama-3.3-70b-instruct",
+        sae_id="layer_50",
+        device="cpu",
+    )
+    cfg_dict["d_in"] = 128
+    cfg_dict["d_sae"] = 256
+    cfg_dict["dtype"] = "float32"
+
+    assert cfg_dict["architecture"] == "standard"
+    cfg = StandardSAEConfig.from_dict(cfg_dict)
+
+    # make an SAE based on the Goodfire config, but smaller, since the real SAE is huge
+    sae = StandardSAE(cfg)
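+    # randomize the parameters so the equivalence check below is not a
+    # comparison between default-initialized weights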
+    random_params(sae)
+
+    sae_state_dict = sae.state_dict()
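+    # sae_lens stores W_enc as (d_in, d_sae) and W_dec as (d_sae, d_in), while
+    # nn.Linear.weight has shape (out_features, in_features), hence the transposes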
+    goodfire_state_dict = {
+        "encoder_linear.weight": sae_state_dict["W_enc"].T,
+        "encoder_linear.bias": sae_state_dict["b_enc"],
+        "decoder_linear.weight": sae_state_dict["W_dec"].T,
+        "decoder_linear.bias": sae_state_dict["b_dec"],
+    }
+
+    goodfire_sae = GoodfireSAE(d_in=128, d_hidden=256, device=torch.device("cpu"))
+    goodfire_sae.load_state_dict(goodfire_state_dict)
+
+    test_input = torch.randn(10, 128)
+
+    output = sae(test_input)
+    features = sae.encode(test_input)
+    goodfire_output, goodfire_features = goodfire_sae(test_input)
+
+    assert_close(output, goodfire_output, rtol=1e-4, atol=1e-4)
+    assert_close(features, goodfire_features, rtol=1e-4, atol=1e-4)
+
+
+def test_get_goodfire_huggingface_loader_with_mocked_download(
+    tmp_path: Path, monkeypatch: pytest.MonkeyPatch
+):
+    repo_id = "Goodfire/Llama-3.3-70B-Instruct-SAE-l50"
+    folder_name = "Llama-3.3-70B-Instruct-SAE-l50.pt"
+    device = "cpu"
+
+    d_in = 128
+    d_sae = 256
+
+    encoder_weight = torch.randn(d_sae, d_in)
+    decoder_weight = torch.randn(d_in, d_sae)
+    encoder_bias = torch.randn(d_sae)
+    decoder_bias = torch.randn(d_in)
+
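+    # fake checkpoint mimicking Goodfire's on-disk layout: raw nn.Linear
+    # parameters, keyed by the module names from the reference implementation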
+    raw_state_dict = {
+        "encoder_linear.weight": encoder_weight,
+        "decoder_linear.weight": decoder_weight,
+        "encoder_linear.bias": encoder_bias,
+        "decoder_linear.bias": decoder_bias,
+    }
+
+    sae_file_path = tmp_path / folder_name
+    torch.save(raw_state_dict, sae_file_path)
+
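+    # stub out the config lookup and the hub download so the loader never
+    # touches the network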
+    def mock_get_goodfire_config_from_hf(
+        repo_id: str,  # noqa: ARG001
+        folder_name: str,  # noqa: ARG001
+        device: str,
+        force_download: bool = False,  # noqa: ARG001
+        cfg_overrides: dict[str, Any] | None = None,  # noqa: ARG001
+    ) -> dict[str, Any]:
+        return {
+            "architecture": "standard",
+            "d_in": d_in,
+            "d_sae": d_sae,
+            "model_name": "meta-llama/Llama-3.3-70B-Instruct",
+            "hook_name": "blocks.50.hook_resid_post",
+            "hook_head_index": None,
+            "dataset_path": "lmsys/lmsys-chat-1m",
+            "apply_b_dec_to_input": False,
+            "device": device,
+        }
+
+    def mock_hf_hub_download(
+        repo_id: str,  # noqa: ARG001
+        filename: str,  # noqa: ARG001
+        force_download: bool = False,  # noqa: ARG001
+    ) -> str:
+        return str(sae_file_path)
+
+    monkeypatch.setattr(
+        "sae_lens.loading.pretrained_sae_loaders.get_goodfire_config_from_hf",
+        mock_get_goodfire_config_from_hf,
+    )
+    monkeypatch.setattr(
+        "sae_lens.loading.pretrained_sae_loaders.hf_hub_download", mock_hf_hub_download
+    )
+
+    cfg_dict, state_dict, log_sparsity = get_goodfire_huggingface_loader(
+        repo_id=repo_id,
+        folder_name=folder_name,
+        device=device,
+        force_download=False,
+        cfg_overrides=None,
+    )
+
+    expected_cfg = {
+        "architecture": "standard",
+        "d_in": d_in,
+        "d_sae": d_sae,
+        "model_name": "meta-llama/Llama-3.3-70B-Instruct",
+        "hook_name": "blocks.50.hook_resid_post",
+        "hook_head_index": None,
+        "dataset_path": "lmsys/lmsys-chat-1m",
+        "apply_b_dec_to_input": False,
+        "device": device,
+    }
+
+    assert cfg_dict == expected_cfg
+    assert log_sparsity is None
+
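+    # the loader should have transposed the nn.Linear weights into the
+    # sae_lens (d_in, d_sae) / (d_sae, d_in) layout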
+    assert set(state_dict.keys()) == {"W_enc", "W_dec", "b_enc", "b_dec"}
+    torch.testing.assert_close(state_dict["W_enc"], encoder_weight.T)
+    torch.testing.assert_close(state_dict["W_dec"], decoder_weight.T)
+    torch.testing.assert_close(state_dict["b_enc"], encoder_bias)
+    torch.testing.assert_close(state_dict["b_dec"], decoder_bias)
+
+    assert state_dict["W_enc"].shape == (d_in, d_sae)
+    assert state_dict["W_dec"].shape == (d_sae, d_in)
+    assert state_dict["b_enc"].shape == (d_sae,)
+    assert state_dict["b_dec"].shape == (d_in,)
+
+
 def test_get_llama_scope_r1_distill_config_from_hf():
     """Test that the Llama Scope R1 Distill config is generated correctly."""
     cfg = get_llama_scope_r1_distill_config_from_hf(