Add Direct Logit Attribution tool for TransformerBridge #1316
jlarson4
left a comment
Hi @TravisHaa! I have reviewed your PR and left a few comments. I cannot run the code in its current state until some of these comments are addressed. Let me know if you have any questions.
Also, the PR is marked as "tests added", but I am not seeing any tests in the diff. Could you add tests/unit/tools/test_direct_logit_attribution.py covering at least:
- Correct path on a small bridge (e.g., gpt2 with compatibility mode)
- `ValueError` when compatibility mode is off
- `NotImplementedError` when a Mamba-like adapter is present
- Both `accumulated=True` and `accumulated=False`
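As a rough illustration of the two error cases requested above, here is a minimal, self-contained sketch. Everything in it is hypothetical: `StubBridge`, `check_dla_preconditions`, the `compatibility_mode` attribute, and the `"attn+mlp"` / `"mamba"` labels are stand-ins invented for this example, not the real TransformerBridge API or the PR's actual function. Real tests would build a small bridge and call the tool itself.

```python
from dataclasses import dataclass, field


@dataclass
class StubBridge:
    """Hypothetical stand-in for a bridge; NOT the real TransformerBridge."""

    compatibility_mode: bool = False
    # Assumed labels; the real layer-type values may look different.
    block_types: list = field(default_factory=lambda: ["attn+mlp", "attn+mlp"])

    def layer_types(self):
        return list(self.block_types)


def check_dla_preconditions(bridge):
    """Hypothetical guard mirroring the two error cases from the review."""
    if not bridge.compatibility_mode:
        raise ValueError("enable compatibility mode before running DLA")
    unsupported = [t for t in bridge.layer_types() if t != "attn+mlp"]
    if unsupported:
        raise NotImplementedError(f"unsupported block types: {unsupported}")


def raises(exc_type, fn, *args):
    """Return True if fn(*args) raises exc_type, otherwise False."""
    try:
        fn(*args)
    except exc_type:
        return True
    return False


# Compatibility mode off: expect ValueError.
off_case = raises(ValueError, check_dla_preconditions, StubBridge(False))
# Mamba-like block present: expect NotImplementedError.
hybrid_case = raises(
    NotImplementedError,
    check_dla_preconditions,
    StubBridge(True, ["attn+mlp", "mamba"]),
)
# Plain attention + MLP stack in compatibility mode: expect no error.
ok_case = not raises(Exception, check_dla_preconditions, StubBridge(True))
```

With pytest, the same checks would typically be written with `pytest.raises(ValueError)` and `pytest.raises(NotImplementedError)` context managers around the actual tool call.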
Feel free to tag me once these edits are in and I will re-review. Thank you for your work on this, it is coming along nicely!
Resolved review feedback from @jlarson4. Added tests covering reconstruction invariants on a distilgpt2 bridge in compatibility mode, arguments, asserting `sum(scores) == logit_diff - (b_U[correct] - b_U[wrong])` against the model's real logits, plus labels/shape and batch-averaging checks.

Added additional hardening:
- Fix a latent direction-shape bug: replace the fragile `answer_tokens.numel()==1` branch with a robust reshape so single-prompt, single-token inputs are handled correctly
- Detect hybrid blocks via `bridge.layer_types()`, the codebase's own semantic mechanism, instead of substring matching `named_modules()`
- Import `get_act_name` from `transformer_lens.utilities` to avoid the `transformer_lens.utils` DeprecationWarning; drop the invalid `return_type` kwarg to `run_with_cache`
- Register the analysis subpackage in `tools/__init__.py`

Closes TransformerLensOrg#1263.
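The reconstruction invariant mentioned in this commit can be checked by hand on a toy example. The sketch below uses plain Python and made-up numbers; the component names, sizes, and weights are all assumptions for illustration, and it leaves out the final LayerNorm entirely (roughly the situation the folded-weights requirement is meant to approximate). It is not the PR's test code.

```python
import math


def dot(u, v):
    """Inner product of two equal-length lists."""
    return sum(a * b for a, b in zip(u, v))


# Toy sizes: d_model = 3, vocab = 4. All values are invented.
W_U = [  # unembedding matrix, indexed [d_model][vocab]
    [0.5, -1.0, 0.25, 2.0],
    [1.5, 0.5, -0.75, 0.0],
    [-0.5, 1.0, 1.25, -1.0],
]
b_U = [0.125, -0.25, 0.5, 0.0]  # unembedding bias

# Per-component contributions to the final residual stream (one position).
components = {
    "embed": [1.0, 0.0, -1.0],
    "0_attn_out": [0.5, 0.25, 0.0],
    "0_mlp_out": [-0.25, 1.0, 0.5],
}
correct, wrong = 0, 1  # token ids of the correct and wrong answers

# Logit-difference direction: W_U[:, correct] - W_U[:, wrong].
direction = [row[correct] - row[wrong] for row in W_U]

# Direct logit attribution: project each component onto that direction.
scores = {name: dot(vec, direction) for name, vec in components.items()}

# "Real" logits from the summed residual stream, including the bias.
resid = [sum(vec[i] for vec in components.values()) for i in range(3)]
logits = [dot(resid, [row[t] for row in W_U]) + b_U[t] for t in range(4)]
logit_diff = logits[correct] - logits[wrong]
bias_gap = b_U[correct] - b_U[wrong]

# The scores account for everything except the constant bias offset.
assert math.isclose(sum(scores.values()), logit_diff - bias_gap)
```

Because the bias term does not pass through any component, it shows up only as the constant `bias_gap` and is subtracted on the right-hand side rather than attributed to a component.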
@jlarson4 I apologize for the delay once again. I spent a lot of time exploring the codebase to understand how I could improve this PR beyond your suggested fixes. I will make sure to iterate within a couple of hours of your review of this new commit 👍
Description
TransformerBridge system, closes [Proposal] Direct Logit Attribution Tool #1263. Based on the stale PR (Draft) Add DLA function to utils #466 but adapted to use the 3.0 TransformerBridge.
Acknowledged limitations
`NotImplementedError`. `ActivationCache.decompose_resid` only knows how to decompose `attn_out + mlp_out` per layer; supporting hybrid blocks requires extending that method and is out of scope for this PR. (Will be working on this in next steps.)
`ValueError` if `bridge.enable_compatibility_mode()` hasn't been called. Without folded LayerNorm weights, the projection direction is wrong and per-component scores don't reflect actual logit contributions.
`actual_logit_diff − (b_U[correct] − b_U[wrong])`. This matches the convention in `cache.decompose_resid` and PR (Draft) Add DLA function to utils #466: the bias is a constant offset
Type of change
Screenshots
Checklist: