Add Direct Logit Attribution tool for TransformerBridge #1316

Open
TravisHaa wants to merge 2 commits into TransformerLensOrg:dev from TravisHaa:feat/dla-tool-1263
Conversation

@TravisHaa TravisHaa commented May 20, 2026

Description

  • Implemented a Direct Logit Attribution (DLA) tool for the new TransformerBridge system; closes "[Proposal] Direct Logit Attribution Tool" #1263. Based on the stale PR "(Draft) Add DLA function to utils" #466, adapted to use the 3.0 TransformerBridge.
  • Returns per-component (or per-layer) contributions to the logit difference between the correct and wrong tokens, decomposing the residual stream according to the accumulated boolean.
  • Generated a docstring (per contributing.md) that highlights the limitations of the DLA tool: it does not currently support hybrid layers such as Mamba, and it requires bridge compatibility mode.
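
A minimal sketch of the decomposition described above, using plain Python lists in place of tensors so it runs without dependencies. The names `logit_diff_direction` and `attribute`, the toy `W_U`, and the component labels are illustrative assumptions, not this PR's actual API. The final LayerNorm is assumed to be already folded, as compatibility mode requires, and the accumulated view is modeled here as a running sum of per-component scores.

```python
from itertools import accumulate


def dot(u, v):
    """Dot product of two equal-length vectors."""
    return sum(a * b for a, b in zip(u, v))


def logit_diff_direction(W_U, correct, wrong):
    """Residual-space direction whose dot product gives logit[correct] - logit[wrong].

    W_U is a d_model x d_vocab matrix stored as a list of rows (assumption).
    """
    return [row[correct] - row[wrong] for row in W_U]


def attribute(components, direction, accumulated=False):
    """Score each residual-stream component against the logit-diff direction.

    components: list of (label, vector) pairs whose vectors sum to the final residual.
    accumulated=False -> one score per component.
    accumulated=True  -> running total after each component.
    """
    labels = [label for label, _ in components]
    scores = [dot(vec, direction) for _, vec in components]
    if accumulated:
        scores = list(accumulate(scores))
    return labels, scores


# Toy example: d_model = 2, d_vocab = 3 (made-up numbers).
W_U = [[1.0, 0.0, 2.0],
       [0.0, 1.0, -1.0]]
components = [("embed", [1.0, 0.0]),
              ("0_attn_out", [0.0, 2.0]),
              ("0_mlp_out", [1.0, 1.0])]

direction = logit_diff_direction(W_U, correct=0, wrong=1)
labels, per_component = attribute(components, direction)
_, running = attribute(components, direction, accumulated=True)
```

Because the projection is linear, the last entry of `running` equals the sum of `per_component`, which is the bias-free logit difference for the toy residual stream.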

Acknowledged limitations

  • Strict mode only. Hybrid architectures (Mamba, SSM, Mixer, LinearAttention) raise
    NotImplementedError. ActivationCache.decompose_resid only knows how to decompose attn_out + mlp_out per layer; supporting hybrid blocks requires extending that method and is out of scope
    for this PR. (I will be working on this in next steps.)
  • Requires compatibility mode. Raises ValueError if bridge.enable_compatibility_mode()
    hasn't been called. Without folded LayerNorm weights, the projection direction is wrong and
    per-component scores don't reflect actual logit contributions.
  • Excludes unembedding bias. Per-component scores sum to actual_logit_diff − (b_U[correct] − b_U[wrong]). This matches the convention in cache.decompose_resid and PR "(Draft) Add DLA function to utils" #466: the bias is a constant offset.
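
The bias convention in the last bullet can be checked on a toy example. This is a sketch with made-up numbers and plain lists rather than tensors; `W_U`, `b_U`, and the component vectors are illustrative assumptions, and LayerNorm is treated as already folded.

```python
def dot(u, v):
    """Dot product of two equal-length vectors."""
    return sum(a * b for a, b in zip(u, v))


# Toy unembedding: d_model = 2, d_vocab = 2 (columns are tokens).
W_U = [[1.0, 0.0],
       [0.0, 1.0]]
b_U = [0.5, -0.25]
correct, wrong = 0, 1

# Component outputs that sum to the final residual stream.
components = [[2.0, 1.0], [1.0, 3.0]]
resid = [sum(vals) for vals in zip(*components)]

# The model's real logits include the unembedding bias.
logits = [dot(resid, [row[t] for row in W_U]) + b_U[t] for t in range(2)]
actual_logit_diff = logits[correct] - logits[wrong]

# DLA scores project each component onto W_U[:, correct] - W_U[:, wrong],
# with no bias term.
direction = [row[correct] - row[wrong] for row in W_U]
scores = [dot(c, direction) for c in components]

# The scores miss the real logit difference by exactly the bias offset.
bias_offset = b_U[correct] - b_U[wrong]
assert abs(sum(scores) - (actual_logit_diff - bias_offset)) < 1e-9
```

The offset does not depend on the input or on any component, which is why it can be left out of the per-component scores without changing their relative sizes.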

Type of change

  • New feature (non-breaking change which adds functionality)
  • This change requires a documentation update

Checklist:

  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes
  • I have not rewritten tests relating to key interfaces which would affect backward compatibility

Collaborator

@jlarson4 jlarson4 left a comment

Hi @TravisHaa! I have reviewed your PR and left a few comments. I cannot run the code in its current state until some of these comments are addressed. Let me know if you have any questions.

Also, the PR is marked as "tests added", but I am not seeing any tests in the diff. Could you add tests/unit/tools/test_direct_logit_attribution.py covering at least:

  • Correct path on a small bridge (e.g., gpt2 with compatibility mode)
  • ValueError when compatibility mode is off
  • NotImplementedError when a Mamba-like adapter is present
  • Both accumulated=True and accumulated=False
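
The two error-path cases in the list above could be exercised along these lines. This is only a sketch: `FakeBridge`, `check_supported`, the `compatibility_mode` attribute, and the layer-type labels are hypothetical stand-ins invented for illustration, and the real TransformerBridge API and the PR's actual checks may differ.

```python
# Assumed labels for hybrid layer types; the real values may differ.
HYBRID_TYPES = {"mamba", "ssm", "mixer", "linear_attention"}


class FakeBridge:
    """Hypothetical stand-in for a bridge; not the real TransformerBridge."""

    def __init__(self, layer_types, compatibility_mode=False):
        self._layer_types = layer_types
        self.compatibility_mode = compatibility_mode

    def layer_types(self):
        return list(self._layer_types)


def check_supported(bridge):
    """Hypothetical guard mirroring the two documented failure modes."""
    if not bridge.compatibility_mode:
        raise ValueError("enable compatibility mode before running DLA")
    hybrid = [t for t in bridge.layer_types() if t in HYBRID_TYPES]
    if hybrid:
        raise NotImplementedError(f"hybrid layers not supported: {hybrid}")


def raises(exc_type, fn, *args):
    """Return True if fn(*args) raises exc_type, False if it returns normally."""
    try:
        fn(*args)
    except exc_type:
        return True
    return False


# Happy path: attention-only layers with compatibility mode on.
check_supported(FakeBridge(["attn", "attn"], compatibility_mode=True))

# Error paths requested in the review.
no_compat = raises(ValueError, check_supported, FakeBridge(["attn"]))
hybrid = raises(
    NotImplementedError,
    check_supported,
    FakeBridge(["attn", "mamba"], compatibility_mode=True),
)
```

In a real test file these checks would more naturally use `pytest.raises` against the actual tool and a small real bridge.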

Feel free to tag me once these edits are in and I will re-review. Thank you for your work on this, it is coming along nicely!

Comment thread transformer_lens/tools/analysis/direct_logit_attribution.py Outdated
Comment thread transformer_lens/tools/analysis/__init__.py Outdated
Comment thread transformer_lens/tools/analysis/direct_logit_attribution.py Outdated
Comment thread transformer_lens/tools/analysis/direct_logit_attribution.py Outdated
Comment thread transformer_lens/tools/analysis/direct_logit_attribution.py Outdated
Comment thread transformer_lens/tools/analysis/direct_logit_attribution.py Outdated
Comment thread transformer_lens/tools/analysis/direct_logit_attribution.py Outdated
Resolved review feedback from @jlarson4, added tests covering
reconstruction invariants on a distilgpt2 bridge in compatibility mode,
arguments, asserting sum(scores) == logit_diff - (b_U[correct] -
b_U[wrong]) against the model's real logits, plus labels/shape and
batch-averaging checks.

Added additional hardening:
- Fix a latent direction-shape bug: replace the fragile
  answer_tokens.numel()==1 branch with a robust reshape so single-prompt,
  single-token inputs are handled correctly
- Detect hybrid blocks via bridge.layer_types(), the codebase's own semantic
  mechanism, instead of substring matching on named_modules()
- Import get_act_name from transformer_lens.utilities to avoid the
  transformer_lens.utils DeprecationWarning; drop the invalid
  return_type kwarg to run_with_cache
- Register the analysis subpackage in tools/__init__.py
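
The shape fix in the first hardening item can be illustrated with a dependency-free sketch. It assumes answer tokens arrive either as one (correct, wrong) pair or as a list of such pairs. The helper name `as_token_pairs` is hypothetical, and the PR's actual reshape on tensors is not shown here.

```python
def as_token_pairs(answer_tokens):
    """Normalize answer tokens to a list of [correct, wrong] pairs, one per prompt.

    Accepts a single flat pair (single-prompt input) or a list of pairs,
    so callers never need a special case for a batch of one.
    """
    # A flat pair like [5, 7] has non-sequence elements; wrap it as a batch of one.
    if answer_tokens and not isinstance(answer_tokens[0], (list, tuple)):
        answer_tokens = [answer_tokens]
    pairs = [list(p) for p in answer_tokens]
    if any(len(p) != 2 for p in pairs):
        raise ValueError("each entry must be a (correct, wrong) pair")
    return pairs


single = as_token_pairs([5, 7])
batch = as_token_pairs([(5, 7), (1, 2)])
```

Normalizing to one shape up front means every later step can treat the input as a batch, which is the property a special-case branch on element count tends to break.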

Closes TransformerLensOrg#1263.
@TravisHaa
Author

@jlarson4 I apologize for the delay once again; I spent a lot of time exploring the codebase to understand how I could improve this PR on top of your suggested fixes. I will make sure to iterate within a couple of hours of your review of this brand new commit 👍
