Refactor arch configs #468
Codecov Report
Attention: Patch coverage is
Additional details and impacted files
@@ Coverage Diff @@
## alpha #468 +/- ##
========================================
Coverage ? 71.68%
========================================
Files ? 25
Lines ? 3426
Branches ? 447
========================================
Hits ? 2456
Misses ? 803
Partials ? 167
        self.mse_loss_fn = self._get_mse_loss_fn()

    @abstractmethod
    def get_coefficients(self) -> dict[str, float | TrainCoefficientConfig]: ...
Any reason to keep `float |` in the type? There are no errors if we remove it from this method and the subclass methods.
I was thinking that users might want to implement their own SAE classes by extending the base SAE class and just return a float per coefficient, as it could be confusing to need to return a `TrainCoefficientConfig`, and most users likely don't care that much about the warm-up stuff.
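To make the trade-off concrete, here is a minimal sketch of how a union return type could be consumed. It is not the code in this PR: the `TrainCoefficientConfig` below is a stand-in whose field names (`value`, `warm_up_steps`) are assumptions, and `BaseTrainingSAE`, `MyCustomSAE`, and `normalize_coefficients` are hypothetical names used only for illustration.

```python
from __future__ import annotations

from abc import ABC, abstractmethod
from dataclasses import dataclass


@dataclass
class TrainCoefficientConfig:
    # Stand-in for the PR's class; the field names are assumptions.
    value: float
    warm_up_steps: int = 0


class BaseTrainingSAE(ABC):
    # Hypothetical base class exposing the union return type under discussion.
    @abstractmethod
    def get_coefficients(self) -> dict[str, float | TrainCoefficientConfig]: ...


class MyCustomSAE(BaseTrainingSAE):
    """A user-defined SAE that does not care about warm-up."""

    def get_coefficients(self) -> dict[str, float | TrainCoefficientConfig]:
        # Returning a bare float is the simple case the union type allows.
        return {"l1": 5e-3}


def normalize_coefficients(
    coefficients: dict[str, float | TrainCoefficientConfig],
) -> dict[str, TrainCoefficientConfig]:
    """Wrap bare floats so downstream code only handles one type."""
    return {
        name: coeff
        if isinstance(coeff, TrainCoefficientConfig)
        else TrainCoefficientConfig(value=float(coeff))
        for name, coeff in coefficients.items()
    }


normalized = normalize_coefficients(MyCustomSAE().get_coefficients())
```

With a normalization step like this, keeping `float |` in the signature costs the training loop nothing, while subclass authors can skip the config object entirely.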
Co-authored-by: Anthony Duong <[email protected]>
I'm getting an error on this branch when I run `make docs-serve`.
anthonyduong@Anthonys-MacBook-Pro-2 SAELens % make docs-serve
poetry run mkdocs serve
INFO - DeprecationWarning: Importing from 'mkdocs_autorefs.plugin' is deprecated. Import directly from 'mkdocs_autorefs' instead.
File "/Users/anthonyduong/Library/Caches/pypoetry/virtualenvs/sae-lens-xw7-A_jW-py3.11/lib/python3.11/site-packages/mkdocstrings/plugin.py", line 29, in <module>
from mkdocs_autorefs.plugin import AutorefsPlugin
File "/Users/anthonyduong/Library/Caches/pypoetry/virtualenvs/sae-lens-xw7-A_jW-py3.11/lib/python3.11/site-packages/mkdocs_autorefs/plugin.py", line 12, in __getattr__
warnings.warn(
INFO - FutureWarning: Using `TRANSFORMERS_CACHE` is deprecated and will be removed in v5 of Transformers. Use `HF_HOME` instead.
File "/Users/anthonyduong/Library/Caches/pypoetry/virtualenvs/sae-lens-xw7-A_jW-py3.11/lib/python3.11/site-packages/transformers/utils/__init__.py", line 74, in <module>
from .hub import (
File "/Users/anthonyduong/Library/Caches/pypoetry/virtualenvs/sae-lens-xw7-A_jW-py3.11/lib/python3.11/site-packages/transformers/utils/hub.py", line 105, in <module>
warnings.warn(
WARNING - Config value 'plugins': Plugin 'mkdocstrings' option 'watch': Unrecognised configuration name: watch
INFO - Building documentation...
INFO - DeprecationWarning: Setting a fallback anchor function is deprecated and will be removed in a future release.
File "/Users/anthonyduong/Library/Caches/pypoetry/virtualenvs/sae-lens-xw7-A_jW-py3.11/lib/python3.11/site-packages/mkdocstrings/plugin.py", line 188, in on_config
autorefs.get_fallback_anchor = self.handlers.get_anchors
File "/Users/anthonyduong/Library/Caches/pypoetry/virtualenvs/sae-lens-xw7-A_jW-py3.11/lib/python3.11/site-packages/mkdocs_autorefs/_internal/plugin.py", line 562, in get_fallback_anchor
warn(
Generating SAE table...
0%| | 0/63 [00:00<?, ?it/sINFO - DeprecationWarning: open_text is deprecated. Use files() instead. Refer to https://importlib-resources.readthedocs.io/en/latest/using.html#migrating-from-legacy for migration advice. | 0/1 [00:00<?, ?it/s]
File "/Users/anthonyduong/Code/SAELens/sae_lens/loading/pretrained_saes_directory.py", line 27, in get_pretrained_saes_directory
with resources.open_text(package, "pretrained_saes.yaml") as file:
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/importlib/resources/_legacy.py", line 18, in wrapper
warnings.warn(
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 1.43it/s]
0%| | 0/63 [00:00<?, ?it/s]
Traceback (most recent call last):
File "/Users/anthonyduong/Library/Caches/pypoetry/virtualenvs/sae-lens-xw7-A_jW-py3.11/bin/mkdocs", line 8, in <module>
sys.exit(cli())
^^^^^
File "/Users/anthonyduong/Library/Caches/pypoetry/virtualenvs/sae-lens-xw7-A_jW-py3.11/lib/python3.11/site-packages/click/core.py", line 1161, in __call__
return self.main(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/anthonyduong/Library/Caches/pypoetry/virtualenvs/sae-lens-xw7-A_jW-py3.11/lib/python3.11/site-packages/click/core.py", line 1082, in main
rv = self.invoke(ctx)
^^^^^^^^^^^^^^^^
File "/Users/anthonyduong/Library/Caches/pypoetry/virtualenvs/sae-lens-xw7-A_jW-py3.11/lib/python3.11/site-packages/click/core.py", line 1697, in invoke
return _process_result(sub_ctx.command.invoke(sub_ctx))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/anthonyduong/Library/Caches/pypoetry/virtualenvs/sae-lens-xw7-A_jW-py3.11/lib/python3.11/site-packages/click/core.py", line 1443, in invoke
return ctx.invoke(self.callback, **ctx.params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/anthonyduong/Library/Caches/pypoetry/virtualenvs/sae-lens-xw7-A_jW-py3.11/lib/python3.11/site-packages/click/core.py", line 788, in invoke
return __callback(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/anthonyduong/Library/Caches/pypoetry/virtualenvs/sae-lens-xw7-A_jW-py3.11/lib/python3.11/site-packages/mkdocs/__main__.py", line 272, in serve_command
serve.serve(**kwargs)
File "/Users/anthonyduong/Library/Caches/pypoetry/virtualenvs/sae-lens-xw7-A_jW-py3.11/lib/python3.11/site-packages/mkdocs/commands/serve.py", line 85, in serve
builder(config)
File "/Users/anthonyduong/Library/Caches/pypoetry/virtualenvs/sae-lens-xw7-A_jW-py3.11/lib/python3.11/site-packages/mkdocs/commands/serve.py", line 67, in builder
build(config, serve_url=None if is_clean else serve_url, dirty=is_dirty)
File "/Users/anthonyduong/Library/Caches/pypoetry/virtualenvs/sae-lens-xw7-A_jW-py3.11/lib/python3.11/site-packages/mkdocs/commands/build.py", line 268, in build
config.plugins.on_pre_build(config=config)
File "/Users/anthonyduong/Library/Caches/pypoetry/virtualenvs/sae-lens-xw7-A_jW-py3.11/lib/python3.11/site-packages/mkdocs/plugins.py", line 590, in on_pre_build
return self.run_event('pre_build', config=config)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/anthonyduong/Library/Caches/pypoetry/virtualenvs/sae-lens-xw7-A_jW-py3.11/lib/python3.11/site-packages/mkdocs/plugins.py", line 568, in run_event
result = method(**kwargs)
^^^^^^^^^^^^^^^^
File "/Users/anthonyduong/Code/SAELens/docs/generate_sae_table.py", line 31, in on_pre_build
generate_sae_table()
File "/Users/anthonyduong/Code/SAELens/docs/generate_sae_table.py", line 78, in generate_sae_table
df = df[INCLUDED_CFG]
~~^^^^^^^^^^^^^^
File "/Users/anthonyduong/Library/Caches/pypoetry/virtualenvs/sae-lens-xw7-A_jW-py3.11/lib/python3.11/site-packages/pandas/core/frame.py", line 4108, in __getitem__
indexer = self.columns._get_indexer_strict(key, "columns")[1]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/anthonyduong/Library/Caches/pypoetry/virtualenvs/sae-lens-xw7-A_jW-py3.11/lib/python3.11/site-packages/pandas/core/indexes/base.py", line 6200, in _get_indexer_strict
self._raise_if_missing(keyarr, indexer, axis_name)
File "/Users/anthonyduong/Library/Caches/pypoetry/virtualenvs/sae-lens-xw7-A_jW-py3.11/lib/python3.11/site-packages/pandas/core/indexes/base.py", line 6252, in _raise_if_missing
raise KeyError(f"{not_found} not in index")
KeyError: "['hook_name', 'hook_layer', 'context_size', 'dataset_path'] not in index"
make: *** [docs-serve] Error 1
I'm guessing that we need to fix the error in #425 (comment) before fixing the one above though.
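For anyone reproducing this: the `KeyError` is raised because `df[INCLUDED_CFG]` requires every listed column to exist in the DataFrame. A small check like the one below reports exactly which names are absent before the selection runs. This is only a diagnostic sketch, not the fix; `missing_columns` is a hypothetical helper, and the column lists are made up apart from the four names in the traceback.

```python
from collections.abc import Iterable


def missing_columns(available: Iterable[str], required: Iterable[str]) -> list[str]:
    """Return the required column names that are absent, preserving their order."""
    available_set = set(available)
    return [name for name in required if name not in available_set]


# Hypothetical column lists; only the four missing names come from the traceback.
included_cfg = ["architecture", "hook_name", "hook_layer", "context_size", "dataset_path"]
columns_after_refactor = ["architecture", "d_in", "d_sae", "metadata"]

missing = missing_columns(columns_after_refactor, included_cfg)
if missing:
    print(f"columns missing from the table: {missing}")
```

With a real DataFrame, `df.columns` can be passed as `available`, since it iterates over the column labels.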
@anthonyduong9 good catch on the docs table, that should be fixed now in f504d2e
Nice!
Description
This PR is WIP with the goal of simplifying and separating out SAE architecture configs. Each SAE arch now gets its own config, which can be further customized. I've also removed some config options that are legacy / not well used or documented. These deletions include:
I've tried to move stuff into the base `SAEConfig` only if it's actually needed to run the SAE (e.g. the size of the SAE), rather than stuff that's useful to know but not actually needed (e.g. which model / layer / L1 coefficient, etc.). This extra info is moved to a metadata option on the config.
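As a rough illustration of that split, the sketch below separates the fields needed to run an SAE from purely informational ones, and shows an architecture-specific config extending the base. All class and field names here (`BaseSAEConfig`, `SAEMetadata`, `TopKSAEConfig`, `d_in`, `d_sae`, `k`, and so on) are assumptions chosen for the example, not the definitions in this PR.

```python
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any


@dataclass
class SAEMetadata:
    """Informational fields only; nothing here is needed to run the SAE."""

    model_name: str | None = None
    hook_name: str | None = None
    extra: dict[str, Any] = field(default_factory=dict)


@dataclass
class BaseSAEConfig:
    """Only what is required to build and run the SAE (hypothetical fields)."""

    d_in: int
    d_sae: int
    metadata: SAEMetadata = field(default_factory=SAEMetadata)


@dataclass
class TopKSAEConfig(BaseSAEConfig):
    """Architecture-specific config: adds only what this architecture needs."""

    k: int = 32


cfg = TopKSAEConfig(
    d_in=768,
    d_sae=16384,
    k=64,
    metadata=SAEMetadata(model_name="gpt2"),
)
```

Keeping the metadata in its own object means code that runs the SAE never has to read it, and tools that only display information (such as a docs table) know where to look.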
This PR also refactors the way various coefficients work, so each training SAE class must implement `get_coefficients()`, which returns a dict of coefficient names and values / warm-up steps. This solves the problem that L1 SAEs have an L1 coefficient, but JumpReLU has an L0 coefficient, and topk has neither (but may have an aux coefficient in the future).

These changes should also make it easy to add new architectures or tweak existing architectures. You just need to call `register_sae_training_class()` and `register_sae_class()` with your custom SAE class / config, and then you can train with it.

Still TODO
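For readers unfamiliar with the registration pattern described in the description above, here is a generic sketch of a class registry keyed by architecture name. The signatures of `register_sae_training_class()` and `register_sae_class()` are not shown in this thread, so the sketch uses its own hypothetical `register_sae` and `build_sae` functions and toy `MyConfig` / `MySAE` classes rather than the real API.

```python
from __future__ import annotations

from dataclasses import dataclass

# Maps an architecture name to its (SAE class, config class) pair.
_SAE_REGISTRY: dict[str, tuple[type, type]] = {}


def register_sae(architecture: str, sae_cls: type, cfg_cls: type) -> None:
    """Hypothetical registration function; refuses duplicate names."""
    if architecture in _SAE_REGISTRY:
        raise ValueError(f"architecture {architecture!r} is already registered")
    _SAE_REGISTRY[architecture] = (sae_cls, cfg_cls)


def build_sae(architecture: str, **cfg_kwargs):
    """Look up the registered classes and construct the SAE from its config."""
    sae_cls, cfg_cls = _SAE_REGISTRY[architecture]
    return sae_cls(cfg_cls(**cfg_kwargs))


@dataclass
class MyConfig:
    d_in: int
    d_sae: int


class MySAE:
    def __init__(self, cfg: MyConfig) -> None:
        self.cfg = cfg


register_sae("my_arch", MySAE, MyConfig)
sae = build_sae("my_arch", d_in=768, d_sae=4096)
```

The point of the pattern is that training code only needs the architecture name: new architectures plug in through registration instead of edits to a central if/else chain.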