DishitaS123/Logit-Tuned-Lens

Logit Lens & Tuned Lens: Interpreting Transformer Predictions

A comprehensive implementation and analysis of logit lens and tuned lens techniques for understanding how transformer language models make predictions layer by layer.

Overview

This project provides implementations of two complementary interpretability techniques:

  1. Logit Lens: A simple method that applies the model's unembedding matrix to intermediate layer activations to see what predictions the model has formed at each layer.

  2. Tuned Lens: An improved variant that trains learned affine probes for each layer, making it more flexible and typically more accurate than the logit lens.

Key Findings

1. Iterative Refinement of Predictions

The logit lens reveals that transformer predictions are refined layer by layer:

  • Early layers (0-6): Form simple, often incorrect guesses about the next token
  • Middle layers (6-18): Converge toward better predictions with clearer uncertainty
  • Late layers (18-24): Continue refining and stabilizing the final distribution

Implication: Transformers engage in iterative inference, with each layer successively improving upon previous predictions rather than simply extracting latent information from earlier layers.

2. Input Information is Rapidly Transformed

One of the most striking findings is that input information is not gradually transformed. Instead:

  • After layer 0: Input tokens are immediately converted into predicted-output space
  • Very early discontinuous jump: KL divergence from input space shows a sharp cliff after the first layer
  • No intermediate representation: Hidden layers never look like the input tokens again

This contradicts intuitions about layered processing that might accumulate information. The model instead implements:

  • Input → Immediate prediction hypothesis
  • Refinement of prediction through remaining layers

3. Token Rank Convergence

Even before achieving correct top-1 predictions, intermediate layers often rank the correct token highly:

  • Target tokens frequently appear in top-10 ranks by middle layers
  • Rank steadily improves through layers
  • Some tokens take until late layers to be ranked correctly

This suggests the model maintains multiple hypotheses throughout the network, gradually eliminating low-probability candidates.
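
The rank computation can be sketched with plain numpy; the `layer_logits` array below is a toy stand-in for real per-layer logits, not output from the actual model:

```python
import numpy as np

def target_rank(logits, target_id):
    """Rank of the target token under a logits vector (1 = top prediction)."""
    # Tokens with strictly higher logits than the target outrank it.
    return int(np.sum(logits > logits[target_id])) + 1

# Toy per-layer logits over a 5-token vocabulary; the correct token is id 3.
layer_logits = np.array([
    [0.1, 0.9, 0.5, 0.3, 0.4],  # early layer: target ranked low
    [0.1, 0.6, 0.2, 0.5, 0.0],  # middle layer: target climbs into the top ranks
    [0.1, 0.2, 0.0, 0.9, 0.0],  # late layer: target reaches top-1
])
ranks = [target_rank(l, target_id=3) for l in layer_logits]
print(ranks)  # [4, 2, 1] -- rank improves toward 1 through layers
```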

4. Entropy Reduction

Prediction entropy (uncertainty) decreases steadily through layers:

  • Early layers: High entropy, near the theoretical maximum (~15.6 bits for a ~50K-token vocabulary)
  • Middle layers: Dramatic entropy drop
  • Late layers: Lower entropy with refinement of top candidates

The rate of entropy reduction varies by context, suggesting different difficulty levels across inputs.

Logit Lens

How it works:

# For each layer l with hidden state h_l:
logits_l = h_l @ W_unembedding^T
predictions_l = softmax(logits_l / temperature)

Where:

  • h_l is the activation at layer l (shape: batch × seq_len × hidden_dim)
  • W_unembedding is the model's unembedding matrix (shape: vocab_size × hidden_dim)
  • temperature is a scaling parameter (default: 1.0)
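
The projection above can be sketched in plain numpy; the hidden state `h_l` and unembedding matrix `W_U` below are random stand-ins with toy dimensions, not weights from a real model:

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # subtract max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

rng = np.random.default_rng(0)
batch, seq_len, hidden_dim, vocab_size = 2, 4, 8, 16
h_l = rng.normal(size=(batch, seq_len, hidden_dim))  # layer-l activations
W_U = rng.normal(size=(vocab_size, hidden_dim))      # unembedding matrix
temperature = 1.0

logits_l = h_l @ W_U.T                               # (batch, seq_len, vocab_size)
predictions_l = softmax(logits_l / temperature)      # per-position distributions
print(predictions_l.shape)  # (2, 4, 16)
```

With a real Hugging Face model, `h_l` would come from calling the model with `output_hidden_states=True`, and `W_U` from its output embedding (for GPT-2, the `lm_head` weight).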

Advantages:

  • Simple, no training required
  • Uses model's actual output space
  • Computationally efficient
  • Easy to implement

Disadvantages:

  • Fixed projection may not optimally decode all layers
  • Can be brittle with poor signal-to-noise in early layers
  • Doesn't account for layer-specific information encoding

Tuned Lens

How it works:

For each layer i, we learn a separate affine probe:

P_i = nn.Linear(hidden_dim, vocab_size)
predictions_i = softmax(P_i(h_i) / temperature)

Training:

The probes are trained on next-token prediction task:

loss = cross_entropy(P_i(h_i), target_next_tokens)
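
A single training step for one probe might look like the following sketch (PyTorch, toy dimensions, with random stand-in activations and targets rather than real model outputs):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

hidden_dim, vocab_size, batch = 8, 16, 4
probe = nn.Linear(hidden_dim, vocab_size)      # affine probe P_i for layer i
opt = torch.optim.Adam(probe.parameters(), lr=1e-3)

# In a real run, h_i comes from the frozen model (output_hidden_states=True)
# and targets are the input ids shifted one position to the left.
h_i = torch.randn(batch, hidden_dim)           # stand-in layer-i activations
targets = torch.randint(vocab_size, (batch,))  # stand-in next-token ids

loss = F.cross_entropy(probe(h_i), targets)    # next-token prediction loss
opt.zero_grad()
loss.backward()
opt.step()
```

The base model stays frozen; only the probe's weight and bias receive gradients.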

Advantages:

  • More flexible than fixed unembedding
  • Better captures layer-specific encoding
  • More robust to noise and scale differences
  • Learns optimal projection for each layer
  • Generally achieves higher accuracy

Disadvantages:

  • Requires labeled training data
  • Computational overhead for training
  • More parameters to tune
  • Risk of overfitting to training distribution

Results & Analysis

Accuracy Progression

Across tested models (GPT-2 at various scales):

Layer  | Accuracy
-------+----------
0      | 0.15-0.25
6      | 0.45-0.55
12     | 0.65-0.75
18     | 0.75-0.85
24     | 0.85-0.95
(final)| 0.95+

Key observation: Final layer doesn't always achieve >99% accuracy, suggesting:

  • Some tokens genuinely have similar context histories
  • Model maintains appropriate uncertainty
  • Top-1 accuracy isn't the only quality metric

KL Divergence Analysis

When comparing intermediate layer distributions to the final layer:

  • On a log scale: KL divergence decreases smoothly across layers
  • Early layers (0-2): KL ≈ 5-8 nats
  • Middle layers (6-12): KL ≈ 0.5-2 nats
  • Late layers (18-24): KL ≈ 0.01-0.2 nats

This smooth progression suggests:

  • Continuous refinement rather than phase transitions
  • No "representation bottleneck" layer
  • Each layer makes incremental progress
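
The per-layer comparison reduces to computing KL(p_layer ‖ p_final) at each position. A minimal numpy sketch, with toy three-token distributions standing in for real model outputs:

```python
import numpy as np

def kl_nats(p, q, eps=1e-12):
    """KL(p || q) in nats for two probability vectors."""
    p, q = np.clip(p, eps, None), np.clip(q, eps, None)
    return float(np.sum(p * np.log(p / q)))

# Toy distributions: the final-layer distribution and two earlier layers.
final = np.array([0.70, 0.20, 0.10])
early = np.array([0.34, 0.33, 0.33])  # near-uniform early guess
late  = np.array([0.65, 0.25, 0.10])  # already close to the final layer

# The early layer sits much farther (in nats) from the final distribution.
print(kl_nats(early, final), kl_nats(late, final))
```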

Entropy Reduction

Entropy decay is approximately exponential in early layers, then plateaus:

H(p_l) ≈ (H(p_0) - H(p_final)) * exp(-l / τ) + H(p_final)

Where τ ≈ 2-4 layers (the decay constant depends on model size).

This suggests:

  • Uncertainty is rapidly resolved
  • Later layers focus on refinement rather than information gain
  • Different position types converge at different rates
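
Entropy in bits can be computed directly from each layer's predicted distribution. A numpy sketch with synthetic distributions (a uniform one standing in for an uncertain early layer, a peaked one for a confident late layer):

```python
import numpy as np

def entropy_bits(p, eps=1e-12):
    """Shannon entropy of a probability vector, in bits."""
    p = np.clip(p, eps, None)
    return float(-np.sum(p * np.log2(p)))

vocab = 50257
uniform = np.full(vocab, 1.0 / vocab)        # maximally uncertain "early layer"
peaked = np.full(vocab, 0.01 / (vocab - 1))  # 1% of mass spread over the rest
peaked[0] = 0.99                             # confident "late layer"

print(entropy_bits(uniform))  # ~15.6 bits, i.e. log2(50257)
print(entropy_bits(peaked))   # well under 1 bit
```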

Discussion: Why This Matters

1. Understanding Model Computation

The logit lens reveals the actual computation path:

  • Models don't accumulate information then suddenly predict
  • Instead, they form predictions early and refine iteratively
  • This explains why attention visualization alone is insufficient

2. Implications for Interpretability

The findings suggest:

  • Layer-by-layer analysis is more useful than treating the model as a black box
  • Early layers can be understood as forming simple heuristic guesses
  • Later layers perform sophisticated refinement and uncertainty resolution

3. Training and Optimization Insights

The smooth prediction trajectory has implications:

  • Models learn to make useful intermediate predictions (good for knowledge distillation)
  • Weight decay encourages smooth information flow
  • Residual connections preserve the ability to access intermediate predictions

Problems & Limitations

1. Tuned Lens Overfitting

  • Probes may overfit to training distribution
  • Different domains may require retraining
  • Transferability across model sizes is unclear

Mitigation: Use held-out test set, apply regularization during training

2. Limited View of Model Internals

  • Logits don't capture all model information
  • Some computation may be in orthogonal subspaces
  • Multi-head attention structure not directly visible

Mitigation: Combine with other interpretability techniques (attribution, attention, SVD)

3. Scaling to Large Models

  • Large models require significant GPU memory
  • Training tuned lens on 20B+ parameter models is expensive
  • Extraction of hidden states can be slow

Mitigation: Use layer sampling, gradient checkpointing, distributed training

4. Temperature Calibration

  • Temperature parameter affects sharpness of predictions
  • Not always clear what temperature to use
  • Different layers may need different temperatures

Mitigation: Validate on held-out validation set
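
One simple calibration strategy is a grid search for the temperature that minimizes held-out negative log-likelihood. A numpy sketch with synthetic logits and labels (the labels are just the argmax here, so a cooler temperature wins by construction):

```python
import numpy as np

def nll_at_temperature(logits, targets, T):
    """Average negative log-likelihood under temperature-scaled softmax."""
    z = logits / T
    z = z - z.max(axis=-1, keepdims=True)  # stable log-softmax
    logp = z - np.log(np.exp(z).sum(axis=-1, keepdims=True))
    return float(-logp[np.arange(len(targets)), targets].mean())

rng = np.random.default_rng(0)
logits = rng.normal(scale=4.0, size=(64, 10))  # stand-in held-out layer logits
targets = logits.argmax(-1)                    # stand-in "correct" next tokens
temps = [0.5, 1.0, 2.0, 4.0]
best = min(temps, key=lambda T: nll_at_temperature(logits, targets, T))
print(best)  # 0.5 -- sharper is better when every label matches the argmax
```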

5. Interpretability Gap

  • Knowing what predictions are made doesn't fully explain how
  • Doesn't directly answer "what features does this layer extract?"
  • Requires complementary techniques

Mitigation: Use attention patterns, gradient-based feature importance

Future Directions

  1. Structural Understanding: Decompose layer predictions by attention head
  2. Mechanistic Interpretation: Connect predictions to specific circuit components
  3. Generalization: Study how predictions transfer across domains and model sizes
  4. Adversarial Analysis: Use prediction trajectory to understand adversarial robustness
  5. Knowledge Distillation: Leverage intermediate predictions for improved training
  6. Editing: Use lens to identify and intervene on model behaviors

Code Structure

logit-lens-project/
├── lenses.py              # Core LogitLens and TunedLens implementations
├── analysis.py            # Analysis and visualization utilities
├── examples.py            # Example usage
├── README.md              # This file
├── requirements.txt       # Dependencies
└── sample_outputs/        # Sample outputs from running `python examples.py`

How to Run:

Step 1: Install Dependencies

pip install -r requirements.txt

Or, if the above causes issues, install the core dependencies directly:

pip install torch transformers numpy matplotlib 

Step 2: Run the Examples

python examples.py

What happens:

  • Loads GPT-2 model automatically
  • Applies logit lens to example text
  • Prints results showing accuracy by layer
  • Creates visualization PNG files
  • Takes 2-5 minutes

You'll see output like:

Layer 0: 0.150 accuracy
Layer 6: 0.450 accuracy
Layer 12: 0.650 accuracy
...
Layer 24: 0.950 accuracy

Saved: logit_lens_accuracy.png
Saved: logit_lens_kl_divergence.png
Saved: logit_lens_rank.png
Saved: logit_lens_entropy.png

Step 3: View the Results

The PNG visualization files show:

  • How accuracy improves through layers
  • How KL divergence decreases
  • How token ranks improve
  • How entropy reduces

Usage

Basic Logit Lens

from lenses import LogitLens
from transformers import GPT2LMHeadModel, GPT2Tokenizer

model = GPT2LMHeadModel.from_pretrained("gpt2")
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

lens = LogitLens(model)
input_ids = tokenizer.encode("Hello world", return_tensors="pt")
predictions = lens.extract_predictions(input_ids)

# predictions is a dict: layer_idx -> (batch, seq_len, vocab_size)

Trained Tuned Lens

from lenses import TunedLens, TunedLensTrainer
import torch.optim as optim

# Create tuned lens
tuned_lens = TunedLens(model, num_layers=12, hidden_dim=768, vocab_size=50257)

# Train
trainer = TunedLensTrainer(model, tuned_lens)
optimizer = optim.Adam(tuned_lens.parameters(), lr=1e-3)
trainer.train_epoch(dataloader, optimizer)

# Use
predictions = tuned_lens.forward(hidden_state, layer_idx=5)

Analysis

from analysis import LensAnalyzer, LensVisualizer

analyzer = LensAnalyzer(tokenizer)
accuracies = analyzer.compute_prediction_accuracy(predictions, target_ids)

visualizer = LensVisualizer(analyzer)
fig = visualizer.plot_accuracy_progression(accuracies)
fig.savefig("accuracy.png")

Requirements

torch>=1.9.0
transformers>=4.10.0
numpy>=1.19.0
matplotlib>=3.3.0
seaborn>=0.11.0

Installation

git clone <this-repo>
cd logit-lens-project
pip install -r requirements.txt
python examples.py

References & Citations

@article{nostalgebraist2020,
  title={Interpreting GPT: The Logit Lens},
  author={nostalgebraist},
  journal={LessWrong},
  year={2020}
}

@article{belrose2023tuned,
  title={Eliciting Latent Predictions from Transformers with the Tuned Lens},
  author={Belrose, Nora and Ostrovsky, Igor and McKinney, Lev and Furman, Zach and Smith, Logan and Halawi, Danny and Biderman, Stella and Steinhardt, Jacob},
  journal={arXiv preprint arXiv:2303.08112},
  year={2023}
}
