
Conversation

@callummcdougall
Contributor

This allows seqpos slicing during training. Basically we add a seqpos_slice arg to the LanguageModelSAERunnerConfig (in the form of a tuple, which gets converted to a slice via slice(*seqpos_slice) - this is because slice objects aren't serializable when we're saving the config).

Apart from the config, the only other file changed is activations_store.py. It now has a seqpos_slice attribute, which it uses to slice the activations fetched by get_activations (and consumed by get_buffer).

Note that the default behaviour is seqpos_slice = (None,), which slices over all sequence positions. Also note that seqpos_slice can be used in conjunction with context_size (i.e. one doesn't make the other redundant).
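A minimal sketch of these mechanics (shapes follow the (batch, seq, num_hook_points, d_in) layout used by get_activations; variable names here are illustrative, not the library's exact attributes):

```python
import torch

# The config stores a tuple because slice objects aren't serializable.
seqpos_slice = (2, 8)  # keep positions 2..7; the default (None,) keeps everything
context_size = 10
d_in = 64

# Fake activations with shape (batch, seq, num_hook_points, d_in).
activations = torch.randn(1, context_size, 1, d_in)

# The activations store converts the tuple to a slice and indexes the seq dim,
# so seqpos_slice composes with context_size rather than replacing it.
sliced = activations[:, slice(*seqpos_slice)]
assert sliced.shape == (1, 6, 1, d_in)  # 6 of the 10 positions survive
```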

@jbloomAus
Collaborator

Hey @callummcdougall I've pushed:

  • tests (in the future, please add tests!)
  • ensured that this is in the SAE config (since using / evaluating an SAE correctly will rely on knowing the seqpos slice)
  • ensured that we can serialize / deserialize without modification

@callummcdougall
Contributor Author

Got it, sorry for causing undue work - yes, in the future I'll make sure to add tests! I wasn't sure about putting it in the SAE config because it concerns the SAE's training data (or what inputs make sense for it), not e.g. the SAE's actual architecture. I was basing this on the fact that ActivationsStore gets initialized via from_config, which takes a LanguageModelSAERunnerConfig rather than an SAEConfig (although now that I look at that file, I see it can also be initialized via from_sae, so I get why this should be added).
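For context, a rough sketch of the two initialization paths (signatures abbreviated; check activations_store.py for the exact arguments, and assume a TransformerLens model and a trained SAE are already in scope):

```python
from sae_lens import ActivationsStore  # import path may vary by SAELens version

# Training path: the full runner config (including seqpos_slice) is available.
store = ActivationsStore.from_config(model, cfg)  # cfg: LanguageModelSAERunnerConfig

# Evaluation path: only the SAE itself is available, so anything needed to
# build the store correctly - like seqpos_slice - must live in sae.cfg.
store = ActivationsStore.from_sae(model, sae)
```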

@jbloomAus
Collaborator

@callummcdougall I think the idea is that if you couldn't evaluate the SAE without knowing about this property, then it needs to be in the SAE config.

Speaking of which, I don't see any changes to evals.py, but presumably we should ensure that evals are only run on the sliced sequence positions? Are you able to do this?
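For illustration, a hypothetical masking helper (not the current evals.py API) might look like:

```python
import torch

def masked_reconstruction_mse(
    original: torch.Tensor,      # (batch, seq, d_in) model activations
    reconstructed: torch.Tensor, # (batch, seq, d_in) SAE reconstructions
    seqpos_slice: tuple,         # e.g. (2, 8); (None,) means all positions
) -> torch.Tensor:
    # Score only the positions the SAE was trained on.
    s = slice(*seqpos_slice)
    return ((original[:, s] - reconstructed[:, s]) ** 2).mean()
```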

@chanind
Collaborator

Code-wise this looks good to me, and it looks like a reasonable addition to the library! Will defer to @jbloomAus on whether this is OK to merge. I guess there's a question of whether the expectation is that this would require different evals, or if this is something that only affects training.

activations = activation_store.get_activations(batch)

assert batch.shape == (1, 10) # Full context size
assert activations.shape == (1, 6, 1, cfg.d_in) # Only 6 positions (2 to 7)
Collaborator

nice! Really great test 🥇

@callummcdougall
Contributor Author

> Code-wise this looks good to me, and it looks like a reasonable addition to the library! Will defer to @jbloomAus on whether this is OK to merge. I guess there's a question of whether the expectation is that this would require different evals, or if this is something that only affects training.

I think it does seem valuable to also have the logged metrics during training only apply to the right sequence positions - is that what you meant @jbloomAus, or did you mean evals that are applied in a non-training context? Either way, I can likely get to that later this week.

@callummcdougall callummcdougall marked this pull request as draft October 2, 2024 15:34
@callummcdougall callummcdougall marked this pull request as ready for review October 11, 2024 09:33
@callummcdougall callummcdougall marked this pull request as draft October 11, 2024 09:35