Refactor arch configs #468

Merged
merged 31 commits into from
May 22, 2025

@chanind chanind commented Apr 30, 2025

Description

This PR is a WIP with the goal of simplifying and separating out SAE architecture configs. Each SAE arch now gets its own config, which can be further customized. I've also removed some config options that are legacy, rarely used, or poorly documented. These deletions include:

  • ghost grads
  • most decoder init stuff (the encoder is always initialized as the decoder transpose)
  • decoder finetuning
  • normalize decoder (L1 SAEs now always scale the L1 loss by the decoder norm; doing otherwise is basically always wrong)

I've tried to move stuff into the base SAEConfig only if it's actually needed to run the SAE (e.g. the size of the SAE), rather than stuff that's useful to know but not actually needed (e.g. the model, layer, or L1 coefficient). This extra info is moved to a metadata option on the config.
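A minimal sketch of the split between "needed to run" and "useful to know", assuming plain dataclasses. `SketchSAEConfig`, `SketchSAEMetadata`, and all of their field names are hypothetical stand-ins for illustration, not the actual classes in this PR, and the model / hook values are just example strings:

```python
from dataclasses import dataclass, field
from typing import Any, Dict, Optional


@dataclass
class SketchSAEMetadata:
    # Hypothetical: descriptive info that is useful to know but not needed to run the SAE.
    model_name: Optional[str] = None
    hook_name: Optional[str] = None
    l1_coefficient: Optional[float] = None
    extra: Dict[str, Any] = field(default_factory=dict)


@dataclass
class SketchSAEConfig:
    # Hypothetical base config: only what is required to build and run the SAE.
    d_in: int
    d_sae: int
    metadata: SketchSAEMetadata = field(default_factory=SketchSAEMetadata)


# Example values are illustrative only.
cfg = SketchSAEConfig(
    d_in=768,
    d_sae=768 * 16,
    metadata=SketchSAEMetadata(model_name="gpt2", hook_name="blocks.0.hook_resid_pre"),
)
```

With this shape, code that runs the SAE only ever reads `cfg.d_in` / `cfg.d_sae`, and anything under `cfg.metadata` can be missing without breaking inference.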

This PR also refactors the way the various coefficients work: each training SAE class must implement get_coefficients(), which returns a dict mapping coefficient names to values / warm-up steps. This solves the problem that L1 SAEs have an L1 coefficient, JumpReLU has an L0 coefficient, and topk has neither (but may have an aux coefficient in the future).
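To make the get_coefficients() idea concrete, here is a rough sketch of how three architectures could each report their own coefficients. The `Sketch*` classes are hypothetical, and the `TrainCoefficientConfig` below is a stand-in whose fields (`value`, `warm_up_steps`) are assumed rather than taken from the PR:

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Dict, Union


@dataclass
class TrainCoefficientConfig:
    # Stand-in for the PR's class of the same name; field names are assumptions.
    value: float
    warm_up_steps: int


Coefficients = Dict[str, Union[float, TrainCoefficientConfig]]


class SketchTrainingSAE(ABC):
    @abstractmethod
    def get_coefficients(self) -> Coefficients:
        """Return the loss coefficients this architecture uses, keyed by name."""


class SketchL1SAE(SketchTrainingSAE):
    def __init__(self, l1_coefficient: float, l1_warm_up_steps: int = 0) -> None:
        self.l1_coefficient = l1_coefficient
        self.l1_warm_up_steps = l1_warm_up_steps

    def get_coefficients(self) -> Coefficients:
        # L1 SAEs expose an L1 coefficient, here with a warm-up.
        return {"l1": TrainCoefficientConfig(self.l1_coefficient, self.l1_warm_up_steps)}


class SketchJumpReLUSAE(SketchTrainingSAE):
    def __init__(self, l0_coefficient: float) -> None:
        self.l0_coefficient = l0_coefficient

    def get_coefficients(self) -> Coefficients:
        # JumpReLU exposes an L0 coefficient instead; a bare float is allowed.
        return {"l0": self.l0_coefficient}


class SketchTopKSAE(SketchTrainingSAE):
    def get_coefficients(self) -> Coefficients:
        # TopK currently has no coefficients at all.
        return {}
```

The trainer can then loop over whatever the dict contains instead of hard-coding an L1 coefficient for every architecture.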

These changes should also make it easy to add new architectures or tweak existing ones. You just need to call register_sae_training_class() and register_sae_class() with your custom SAE class / config, and then you can train with it.
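The exact signatures of register_sae_class() and register_sae_training_class() aren't shown in this thread, so the following is only a toy illustration of the general registry pattern. Every name here (`register_sketch_sae_class`, `get_sketch_sae_class`, `_SKETCH_REGISTRY`, `MySAE`, `MySAEConfig`) is hypothetical:

```python
from typing import Dict, Tuple

# Hypothetical registry: architecture name -> (SAE class, config class).
_SKETCH_REGISTRY: Dict[str, Tuple[type, type]] = {}


def register_sketch_sae_class(architecture: str, sae_class: type, config_class: type) -> None:
    """Register a custom SAE class and its config class under an architecture name."""
    if architecture in _SKETCH_REGISTRY:
        raise ValueError(f"architecture {architecture!r} is already registered")
    _SKETCH_REGISTRY[architecture] = (sae_class, config_class)


def get_sketch_sae_class(architecture: str) -> Tuple[type, type]:
    """Look up the (SAE class, config class) pair for an architecture name."""
    return _SKETCH_REGISTRY[architecture]


class MySAEConfig:
    pass


class MySAE:
    pass


register_sketch_sae_class("my_arch", MySAE, MySAEConfig)
sae_cls, cfg_cls = get_sketch_sae_class("my_arch")
```

The point of the pattern is that loading and training code only needs the architecture string from a saved config to find the right class, so custom architectures don't require changes to the library itself.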

Still TODO

  • backwards compatibility with old configs
  • make sure saving / loading works
  • save the training config alongside the inference config and upload that to huggingface as well
  • autopopulate config metadata
  • fix remaining tests
  • update docs
  • ensure new init code is reasonable
  • test real SAE training to verify everything is the same

codecov bot commented May 4, 2025

Codecov Report

Attention: Patch coverage is 79.83015% with 95 lines in your changes missing coverage. Please review.

Please upload report for BASE (alpha@09f2457). Learn more about missing BASE report.

| Files with missing lines | Patch % | Lines |
|---|---|---|
| sae_lens/saes/sae.py | 71.89% | 28 Missing and 15 partials ⚠️ |
| sae_lens/training/optim.py | 44.00% | 11 Missing and 3 partials ⚠️ |
| sae_lens/training/activations_store.py | 16.66% | 10 Missing ⚠️ |
| sae_lens/registry.py | 64.70% | 3 Missing and 3 partials ⚠️ |
| sae_lens/util.py | 62.50% | 3 Missing and 3 partials ⚠️ |
| sae_lens/config.py | 86.66% | 1 Missing and 1 partial ⚠️ |
| sae_lens/evals.py | 77.77% | 1 Missing and 1 partial ⚠️ |
| sae_lens/loading/pretrained_sae_loaders.py | 91.30% | 1 Missing and 1 partial ⚠️ |
| sae_lens/sae_training_runner.py | 86.66% | 2 Missing ⚠️ |
| sae_lens/saes/gated_sae.py | 94.28% | 2 Missing ⚠️ |
... and 3 more
Additional details and impacted files
@@           Coverage Diff            @@
##             alpha     #468   +/-   ##
========================================
  Coverage         ?   71.68%           
========================================
  Files            ?       25           
  Lines            ?     3426           
  Branches         ?      447           
========================================
  Hits             ?     2456           
  Misses           ?      803           
  Partials         ?      167           

@chanind chanind marked this pull request as ready for review May 4, 2025 22:35
@chanind chanind requested a review from anthonyduong9 May 4, 2025 22:37
self.mse_loss_fn = self._get_mse_loss_fn()

@abstractmethod
def get_coefficients(self) -> dict[str, float | TrainCoefficientConfig]: ...

Any reason to keep float | in the type? There are no errors if we remove it from this method and the subclass methods.

@chanind chanind May 10, 2025

I was thinking that users might want to implement their own SAE classes by extending the base SAE class and just return a float per coefficient. It could be confusing to need to return a TrainCoefficientConfig, and most users likely don't care that much about the warm-up stuff.
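As a sketch of why allowing a bare float is cheap to support, the training loop could normalize both forms in one place. This assumes a linear warm-up and the same stand-in `TrainCoefficientConfig` fields (`value`, `warm_up_steps`); `resolve_coefficient` is a hypothetical helper, not something from this PR:

```python
from dataclasses import dataclass
from typing import Union


@dataclass
class TrainCoefficientConfig:
    # Stand-in for the PR's class of the same name; field names are assumptions.
    value: float
    warm_up_steps: int


def resolve_coefficient(coeff: Union[float, TrainCoefficientConfig], step: int) -> float:
    """Return the effective coefficient at a training step (assumed linear warm-up)."""
    if isinstance(coeff, (int, float)):
        # A bare number means "constant, no warm-up".
        return float(coeff)
    if coeff.warm_up_steps <= 0 or step >= coeff.warm_up_steps:
        return coeff.value
    # Scale linearly from 0 up to the full value over warm_up_steps.
    return coeff.value * step / coeff.warm_up_steps
```

So a custom SAE that returns `{"l0": 0.1}` and one that returns a full TrainCoefficientConfig can both be handled by the same trainer code.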

@chanind chanind requested a review from anthonyduong9 May 17, 2025 12:54

@anthonyduong9 anthonyduong9 left a comment

I'm getting an error on this branch when I run make docs-serve.

anthonyduong@Anthonys-MacBook-Pro-2 SAELens % make docs-serve
poetry run mkdocs serve
INFO    -  DeprecationWarning: Importing from 'mkdocs_autorefs.plugin' is deprecated. Import directly from 'mkdocs_autorefs' instead.
             File "/Users/anthonyduong/Library/Caches/pypoetry/virtualenvs/sae-lens-xw7-A_jW-py3.11/lib/python3.11/site-packages/mkdocstrings/plugin.py", line 29, in <module>
               from mkdocs_autorefs.plugin import AutorefsPlugin
             File "/Users/anthonyduong/Library/Caches/pypoetry/virtualenvs/sae-lens-xw7-A_jW-py3.11/lib/python3.11/site-packages/mkdocs_autorefs/plugin.py", line 12, in __getattr__
               warnings.warn(
INFO    -  FutureWarning: Using `TRANSFORMERS_CACHE` is deprecated and will be removed in v5 of Transformers. Use `HF_HOME` instead.
             File "/Users/anthonyduong/Library/Caches/pypoetry/virtualenvs/sae-lens-xw7-A_jW-py3.11/lib/python3.11/site-packages/transformers/utils/__init__.py", line 74, in <module>
               from .hub import (
             File "/Users/anthonyduong/Library/Caches/pypoetry/virtualenvs/sae-lens-xw7-A_jW-py3.11/lib/python3.11/site-packages/transformers/utils/hub.py", line 105, in <module>
               warnings.warn(
WARNING -  Config value 'plugins': Plugin 'mkdocstrings' option 'watch': Unrecognised configuration name: watch
INFO    -  Building documentation...
INFO    -  DeprecationWarning: Setting a fallback anchor function is deprecated and will be removed in a future release.
             File "/Users/anthonyduong/Library/Caches/pypoetry/virtualenvs/sae-lens-xw7-A_jW-py3.11/lib/python3.11/site-packages/mkdocstrings/plugin.py", line 188, in on_config
               autorefs.get_fallback_anchor = self.handlers.get_anchors
             File "/Users/anthonyduong/Library/Caches/pypoetry/virtualenvs/sae-lens-xw7-A_jW-py3.11/lib/python3.11/site-packages/mkdocs_autorefs/_internal/plugin.py", line 562, in get_fallback_anchor
               warn(
Generating SAE table...
  0%|                                                                                                                                                                                                                         | 0/63 [00:00<?, ?it/sINFO    -  DeprecationWarning: open_text is deprecated. Use files() instead. Refer to https://importlib-resources.readthedocs.io/en/latest/using.html#migrating-from-legacy for migration advice.                               | 0/1 [00:00<?, ?it/s]
             File "/Users/anthonyduong/Code/SAELens/sae_lens/loading/pretrained_saes_directory.py", line 27, in get_pretrained_saes_directory
               with resources.open_text(package, "pretrained_saes.yaml") as file:
             File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/importlib/resources/_legacy.py", line 18, in wrapper
               warnings.warn(
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  1.43it/s]
  0%|                                                                                                                                                                                                                         | 0/63 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "/Users/anthonyduong/Library/Caches/pypoetry/virtualenvs/sae-lens-xw7-A_jW-py3.11/bin/mkdocs", line 8, in <module>
    sys.exit(cli())
             ^^^^^
  File "/Users/anthonyduong/Library/Caches/pypoetry/virtualenvs/sae-lens-xw7-A_jW-py3.11/lib/python3.11/site-packages/click/core.py", line 1161, in __call__
    return self.main(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/anthonyduong/Library/Caches/pypoetry/virtualenvs/sae-lens-xw7-A_jW-py3.11/lib/python3.11/site-packages/click/core.py", line 1082, in main
    rv = self.invoke(ctx)
         ^^^^^^^^^^^^^^^^
  File "/Users/anthonyduong/Library/Caches/pypoetry/virtualenvs/sae-lens-xw7-A_jW-py3.11/lib/python3.11/site-packages/click/core.py", line 1697, in invoke
    return _process_result(sub_ctx.command.invoke(sub_ctx))
                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/anthonyduong/Library/Caches/pypoetry/virtualenvs/sae-lens-xw7-A_jW-py3.11/lib/python3.11/site-packages/click/core.py", line 1443, in invoke
    return ctx.invoke(self.callback, **ctx.params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/anthonyduong/Library/Caches/pypoetry/virtualenvs/sae-lens-xw7-A_jW-py3.11/lib/python3.11/site-packages/click/core.py", line 788, in invoke
    return __callback(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/anthonyduong/Library/Caches/pypoetry/virtualenvs/sae-lens-xw7-A_jW-py3.11/lib/python3.11/site-packages/mkdocs/__main__.py", line 272, in serve_command
    serve.serve(**kwargs)
  File "/Users/anthonyduong/Library/Caches/pypoetry/virtualenvs/sae-lens-xw7-A_jW-py3.11/lib/python3.11/site-packages/mkdocs/commands/serve.py", line 85, in serve
    builder(config)
  File "/Users/anthonyduong/Library/Caches/pypoetry/virtualenvs/sae-lens-xw7-A_jW-py3.11/lib/python3.11/site-packages/mkdocs/commands/serve.py", line 67, in builder
    build(config, serve_url=None if is_clean else serve_url, dirty=is_dirty)
  File "/Users/anthonyduong/Library/Caches/pypoetry/virtualenvs/sae-lens-xw7-A_jW-py3.11/lib/python3.11/site-packages/mkdocs/commands/build.py", line 268, in build
    config.plugins.on_pre_build(config=config)
  File "/Users/anthonyduong/Library/Caches/pypoetry/virtualenvs/sae-lens-xw7-A_jW-py3.11/lib/python3.11/site-packages/mkdocs/plugins.py", line 590, in on_pre_build
    return self.run_event('pre_build', config=config)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/anthonyduong/Library/Caches/pypoetry/virtualenvs/sae-lens-xw7-A_jW-py3.11/lib/python3.11/site-packages/mkdocs/plugins.py", line 568, in run_event
    result = method(**kwargs)
             ^^^^^^^^^^^^^^^^
  File "/Users/anthonyduong/Code/SAELens/docs/generate_sae_table.py", line 31, in on_pre_build
    generate_sae_table()
  File "/Users/anthonyduong/Code/SAELens/docs/generate_sae_table.py", line 78, in generate_sae_table
    df = df[INCLUDED_CFG]
         ~~^^^^^^^^^^^^^^
  File "/Users/anthonyduong/Library/Caches/pypoetry/virtualenvs/sae-lens-xw7-A_jW-py3.11/lib/python3.11/site-packages/pandas/core/frame.py", line 4108, in __getitem__
    indexer = self.columns._get_indexer_strict(key, "columns")[1]
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/anthonyduong/Library/Caches/pypoetry/virtualenvs/sae-lens-xw7-A_jW-py3.11/lib/python3.11/site-packages/pandas/core/indexes/base.py", line 6200, in _get_indexer_strict
    self._raise_if_missing(keyarr, indexer, axis_name)
  File "/Users/anthonyduong/Library/Caches/pypoetry/virtualenvs/sae-lens-xw7-A_jW-py3.11/lib/python3.11/site-packages/pandas/core/indexes/base.py", line 6252, in _raise_if_missing
    raise KeyError(f"{not_found} not in index")
KeyError: "['hook_name', 'hook_layer', 'context_size', 'dataset_path'] not in index"
make: *** [docs-serve] Error 1

I'm guessing that we need to fix the error in #425 (comment) before fixing the one above though.
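The traceback ends in `df = df[INCLUDED_CFG]` raising a KeyError because four expected columns are absent. One possible cause, and this is only a guess based on the PR description, is that those fields now live under the config's metadata rather than at the top level. A plain-Python sketch of a more tolerant selection under that assumption (`flatten_cfg`, `select_columns`, and `SKETCH_COLUMNS` are hypothetical names, and this is not a description of the actual fix):

```python
from typing import Any, Dict, List

# Illustrative column list; the four names below come from the KeyError message.
SKETCH_COLUMNS: List[str] = ["hook_name", "hook_layer", "context_size", "dataset_path"]


def flatten_cfg(cfg: Dict[str, Any]) -> Dict[str, Any]:
    """Lift keys from an assumed nested 'metadata' dict up to the top level."""
    flat = {k: v for k, v in cfg.items() if k != "metadata"}
    flat.update(cfg.get("metadata") or {})
    return flat


def select_columns(cfg: Dict[str, Any], columns: List[str]) -> Dict[str, Any]:
    """Pick the requested columns, filling missing ones with None instead of raising."""
    flat = flatten_cfg(cfg)
    return {col: flat.get(col) for col in columns}


row = select_columns(
    {"d_sae": 16, "metadata": {"hook_name": "blocks.0.hook_resid_pre", "context_size": 128}},
    SKETCH_COLUMNS,
)
```

Rows built this way always have every requested column, so a later column selection on the resulting table can't fail on a missing key.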

chanind commented May 20, 2025

@anthonyduong9 good catch on the docs table, that should be fixed now in f504d2e

@anthonyduong9 anthonyduong9 self-requested a review May 22, 2025 06:12

@anthonyduong9 anthonyduong9 left a comment

Nice!

@chanind chanind merged commit 5063a29 into alpha May 22, 2025
4 checks passed
@chanind chanind deleted the refactor-arch-configs branch May 22, 2025 15:30