Refactor arch configs #468

Draft · wants to merge 8 commits into alpha

Conversation

chanind (Collaborator) commented Apr 30, 2025

Description

This PR is a WIP with the goal of simplifying and separating out SAE architecture configs. Each SAE arch now gets its own config, which can be further customized. I've also removed some config options that are legacy / not well used or documented. These deletions include:

  • ghost grads
  • decoder init options (the decoder is always randomly initialized as the encoder transpose with unit norm)
  • decoder finetuning
  • normalize decoder (L1 SAEs now always scale the L1 loss by the decoder norm; doing otherwise is basically always wrong)

I've tried to move into the base SAEConfig only what's actually needed to run the SAE (e.g. the size of the SAE), rather than info that's useful to know but not actually needed (e.g. what model / layer / L1 coefficient, etc.). This extra info is moved to a metadata option on the config, roughly as sketched below.
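
A rough sketch of what this split could look like. The metadata field names (model_name, hook_name, hook_layer) are illustrative assumptions, not the final API:

```python
from dataclasses import dataclass, field


@dataclass
class SAEMetadata:
    # Illustrative: useful-to-know info that isn't needed to run the SAE.
    model_name: str | None = None
    hook_name: str | None = None
    hook_layer: int | None = None


@dataclass
class SAEConfig:
    # Only what's actually needed to run the SAE.
    d_in: int
    d_sae: int
    dtype: str = "float32"
    device: str = "cpu"
    # Everything else lives on a metadata option.
    metadata: SAEMetadata = field(default_factory=SAEMetadata)
```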

This PR also refactors the way the various coefficients work: each training SAE class must now implement get_coefficients(), which returns a dict mapping coefficient names to their values / warm-up steps. This solves the problem that L1 SAEs have an L1 coefficient, but JumpReLU has an L0 coefficient, and TopK has neither (though it may have an aux coefficient in the future).
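
A hedged sketch of how this might look for the L1 case. The TrainCoefficientConfig container and the cfg field names are assumptions based on the description above, not necessarily the final API:

```python
from dataclasses import dataclass


@dataclass
class TrainCoefficientConfig:
    # Hypothetical container: target value plus warm-up steps for one coefficient.
    value: float
    warm_up_steps: int = 0


class StandardTrainingSAE:
    """Sketch of an L1 (standard) training SAE; cfg fields are assumed."""

    def __init__(self, cfg):
        self.cfg = cfg

    def get_coefficients(self) -> dict[str, TrainCoefficientConfig]:
        # An L1 SAE exposes an "l1" coefficient; a JumpReLU SAE would return
        # {"l0": ...} here, and a TopK SAE would return {} (or an "aux"
        # coefficient in the future).
        return {
            "l1": TrainCoefficientConfig(
                value=self.cfg.l1_coefficient,
                warm_up_steps=self.cfg.l1_warm_up_steps,
            ),
        }
```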

These changes should also make it easy to add new architectures or tweak existing ones. You just need to call register_sae_training_class() and register_sae_class() with your custom SAE class / config, and then you can train with it.
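
For example, registering a custom architecture might look something like this. The import path and the exact registry signatures are assumptions, and the MyCustom* classes are hypothetical subclasses of the new SAE / TrainingSAE bases and configs:

```python
from sae_lens import register_sae_class, register_sae_training_class  # assumed import path

# MyCustomSAE / MyCustomTrainingSAE and their configs are hypothetical classes.
register_sae_class("my_custom_sae", MyCustomSAE, MyCustomSAEConfig)
register_sae_training_class(
    "my_custom_sae", MyCustomTrainingSAE, MyCustomTrainingSAEConfig
)

# After registration, "my_custom_sae" can be used as the architecture name
# when training or loading SAEs.
```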

Still TODO

  • backwards compatibility with old configs
  • make sure saving / loading works
  • save the training config alongside the inference config and upload that to huggingface as well
  • autopopulate config metadata
  • fix remaining tests
  • update docs
