fix: GatedSAE needs to prevent the model from cheating L0 over steps #546


Description
Summary: This PR fixes an issue where the model cheats the L0 over training steps. No dependencies are required for this change. This PR is committed on top of PR #545.
PR #545 fixes an issue where the loss from `pi_via_gate` also flows to `W_dec` and `b_dec`: https://github.com/jbloomAus/SAELens/blob/3432f0059bc1d90119fe48f88e17ac07aea3a620/sae_lens/saes/gated_sae.py#L196
This PR fixes an issue where the model cheats the L0 over steps: the loss from `pi_via_gate` shrinks the encoder in one step, and the decoder then grows in the next step to compensate for the changed encoder. Repeated over many steps, the encoder becomes very small and the decoder very large.
Note: The above is an assumption. However, the experiment results below show that GatedSAE with a normalized decoder does have higher `metrics/explained_variance` than without normalization.
Fixes #544
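The hypothesised mechanism can be illustrated with a toy example. This is a minimal sketch in pure Python, not the SAELens code: it uses a plain ReLU encoder instead of the gated path, assumes an L1-style sparsity penalty on the activations, and all names (`encode`, `decode`, `rescale`, `normalize_decoder`) are hypothetical. It shows that shrinking the encoder by a factor `c` while growing the decoder by the same factor leaves the reconstruction unchanged but reduces the penalty, and that constraining decoder rows to unit norm removes this freedom.

```python
import math

def encode(x, W_enc, b_enc):
    # W_enc holds one row per latent; activations are ReLU(W_enc @ x + b_enc).
    pre = [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(W_enc, b_enc)]
    return [max(0.0, p) for p in pre]

def decode(f, W_dec):
    # W_dec holds one row (decoder direction) per latent.
    d_in = len(W_dec[0])
    return [sum(fi * row[j] for fi, row in zip(f, W_dec)) for j in range(d_in)]

def l1(f):
    # Assumed sparsity penalty: sum of absolute activations.
    return sum(abs(a) for a in f)

def rescale(W_enc, b_enc, W_dec, c):
    # The suspected "cheat": encoder shrinks by c, decoder grows by c.
    W_enc_s = [[w / c for w in row] for row in W_enc]
    b_enc_s = [b / c for b in b_enc]
    W_dec_s = [[w * c for w in row] for row in W_dec]
    return W_enc_s, b_enc_s, W_dec_s

def normalize_decoder(W_dec):
    # Rescale every decoder row to unit L2 norm.
    return [[w / math.sqrt(sum(u * u for u in row)) for w in row] for row in W_dec]

# Toy parameters (illustrative values only).
x = [1.0, 2.0]
W_enc = [[0.5, 0.25], [-1.0, 0.5]]
b_enc = [0.0, 0.5]
W_dec = [[1.0, 0.0], [0.0, 2.0]]

f = encode(x, W_enc, b_enc)
recon = decode(f, W_dec)

W_enc_s, b_enc_s, W_dec_s = rescale(W_enc, b_enc, W_dec, c=4.0)
f_s = encode(x, W_enc_s, b_enc_s)
recon_s = decode(f_s, W_dec_s)

# Same reconstruction, but the penalty is 4x smaller after rescaling.
penalty_before = l1(f)
penalty_after = l1(f_s)

# With unit-norm decoder rows, the decoder cannot absorb the factor c.
row_norms = [math.sqrt(sum(u * u for u in row)) for row in normalize_decoder(W_dec_s)]
```

In this toy setting, `recon` and `recon_s` match while `penalty_after` is a quarter of `penalty_before`. Whether the same dynamics explain the GatedSAE training behaviour is the assumption stated in the note above.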
Type of change
Please delete options that are not relevant.
Checklist:
You have tested formatting, typing and tests
`make check-ci` to check format and linting. (You can run `make format` to format code if needed.)
Performance Check.
If you have implemented a training change, please indicate precisely how performance changes with respect to the following metrics:
Please link to WandB dashboards with a control and test group.
WandB dashboard
Control and test group: Tag Gated No Scale vs Norm Aux