A comprehensive implementation and analysis of logit lens and tuned lens techniques for understanding how transformer language models make predictions layer by layer.
This project provides implementations of two complementary interpretability techniques:
- Logit Lens: A simple method that applies the model's unembedding matrix to intermediate-layer activations to see what predictions the model has formed at each layer.
- Tuned Lens: An improved variant that trains a learned affine probe for each layer, making it more flexible and typically more accurate than the logit lens.
- Logit Lens: Interpreting GPT: The Logit Lens (nostalgebraist, 2020)
- Tuned Lens: Eliciting Latent Predictions from Transformers with the Tuned Lens (Belrose et al., 2023)
The logit lens reveals that transformer predictions are refined layer by layer:
- Early layers (0-6): Form simple, often incorrect guesses about the next token
- Middle layers (6-18): Converge toward better predictions with clearer uncertainty
- Late layers (18-24): Continue refining and stabilizing the final distribution
Implication: Transformers engage in iterative inference, with each layer successively improving upon previous predictions rather than simply extracting latent information from earlier layers.
One of the most striking findings is that input information is not gradually transformed. Instead:
- After layer 0: Input tokens are immediately converted into predicted-output space
- Very early discontinuous jump: KL divergence from input space shows a sharp cliff after the first layer
- No intermediate representation: Hidden layers never look like the input tokens again
This contradicts intuitions about layered processing that might accumulate information. The model instead implements:
- Input → Immediate prediction hypothesis
- Refinement of prediction through remaining layers
Even before achieving correct top-1 predictions, intermediate layers often rank the correct token highly:
- Target tokens frequently appear in top-10 ranks by middle layers
- Rank steadily improves through layers
- Some tokens take until late layers to be ranked correctly
This suggests the model maintains multiple hypotheses throughout the network, gradually eliminating low-probability candidates.
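As an illustration of how such ranks can be computed, here is a short sketch that scores the target token at every layer. The `layer_logits` dict (layer index → logits tensor) is an assumed input format for this example, not part of this repo's API:

```python
import torch

def target_ranks(layer_logits: dict, target_ids: torch.Tensor) -> dict:
    """For each layer, compute the rank of the correct next token at every position.

    layer_logits: {layer_idx: (seq_len, vocab_size) tensor of logits}
    target_ids:   (seq_len,) tensor of gold next-token ids
    Rank 0 means the target token is the layer's top-1 prediction.
    """
    ranks = {}
    for layer, logits in layer_logits.items():
        target_scores = logits.gather(-1, target_ids.unsqueeze(-1))  # (seq_len, 1)
        # A token's rank is the number of vocabulary entries scored strictly higher.
        ranks[layer] = (logits > target_scores).sum(dim=-1)          # (seq_len,)
    return ranks
```

A target counted toward the "top-10 by middle layers" observation above would have rank below 10 at that layer.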
Prediction entropy (uncertainty) decreases steadily through layers:
- Early layers: High entropy (close to the ~15.6-bit maximum for a ~50K vocabulary)
- Middle layers: Dramatic entropy drop
- Late layers: Lower entropy with refinement of top candidates
The rate of entropy reduction varies by context, suggesting different difficulty levels across inputs.
How it works:

```
# For each layer l with hidden state h_l:
logits_l = h_l @ W_unembedding^T
predictions_l = softmax(logits_l / temperature)
```

Where:
- `h_l` is the activation at layer l (shape: batch × seq_len × hidden_dim)
- `W_unembedding` is the model's unembedding matrix (shape: vocab_size × hidden_dim)
- `temperature` is a scaling parameter (default: 1.0)
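For concreteness, here is a minimal, self-contained sketch of this procedure on GPT-2 with Hugging Face Transformers. It applies the model's final layer norm (`ln_f`) before unembedding, a common variant that stabilizes early-layer readouts; the prompt and variable names are illustrative, and this is not the exact code in `lenses.py`:

```python
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

model = GPT2LMHeadModel.from_pretrained("gpt2")
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model.eval()

input_ids = tokenizer("The capital of France is", return_tensors="pt").input_ids

with torch.no_grad():
    out = model(input_ids, output_hidden_states=True)

temperature = 1.0
for layer, h in enumerate(out.hidden_states):   # embedding output + each block's output
    h = model.transformer.ln_f(h)               # final layer norm before unembedding
    logits = model.lm_head(h)                   # project to vocabulary: (1, seq_len, vocab)
    probs = torch.softmax(logits[0, -1] / temperature, dim=-1)
    top_id = int(probs.argmax())
    print(f"layer {layer:2d}: {tokenizer.decode([top_id])!r} (p={probs[top_id].item():.3f})")
```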
Advantages:
- Simple, no training required
- Uses model's actual output space
- Computationally efficient
- Easy to implement
Disadvantages:
- Fixed projection may not optimally decode all layers
- Can be brittle with poor signal-to-noise in early layers
- Doesn't account for layer-specific information encoding
How it works:
For each layer i, we learn a separate affine probe:
```
P_i = nn.Linear(hidden_dim, vocab_size)
predictions_i = softmax(P_i(h_i) / temperature)
```

Training:
The probes are trained on the next-token prediction task:

```
loss = cross_entropy(P_i(h_i), target_next_tokens)
```

A minimal end-to-end training sketch appears after the lists below.

Advantages:
- More flexible than fixed unembedding
- Better captures layer-specific encoding
- More robust to noise and scale differences
- Learns optimal projection for each layer
- Generally achieves higher accuracy
Disadvantages:
- Requires labeled training data
- Computational overhead for training
- More parameters to tune
- Risk of overfitting to training distribution
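As promised above, here is a minimal sketch of the per-layer probe training loop. It assumes hidden states are extracted from a frozen GPT-2 as in the logit lens sketch; the function name and batching scheme are illustrative and do not reproduce `TunedLensTrainer` exactly:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def train_probes(model, batches, hidden_dim=768, vocab_size=50257, lr=1e-3, device="cpu"):
    """Train one affine probe per layer on next-token prediction.

    batches: iterable of (input_ids, target_ids), where target_ids are input_ids shifted by one.
    Returns an nn.ModuleList of per-layer probes.
    """
    num_layers = model.config.n_layer
    probes = nn.ModuleList(
        [nn.Linear(hidden_dim, vocab_size) for _ in range(num_layers)]
    ).to(device)
    optimizer = torch.optim.Adam(probes.parameters(), lr=lr)
    model.eval()

    for input_ids, target_ids in batches:
        with torch.no_grad():                    # the base model stays frozen
            hs = model(input_ids.to(device), output_hidden_states=True).hidden_states
        loss = 0.0
        for i, probe in enumerate(probes):
            logits = probe(hs[i + 1])            # hs[0] is the embedding output
            loss = loss + F.cross_entropy(
                logits.reshape(-1, vocab_size),
                target_ids.to(device).reshape(-1),
            )
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    return probes
```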
Across tested models (GPT-2 at various scales):
| Layer   | Accuracy  |
|---------|-----------|
| 0       | 0.15-0.25 |
| 6       | 0.45-0.55 |
| 12      | 0.65-0.75 |
| 18      | 0.75-0.85 |
| 24      | 0.85-0.95 |
| (final) | 0.95+     |
Key observation: Final layer doesn't always achieve >99% accuracy, suggesting:
- Some tokens genuinely have similar context histories
- Model maintains appropriate uncertainty
- Top-1 accuracy isn't the only quality metric
When comparing intermediate layer distributions to the final layer:
- Log scale: KL divergence decreases smoothly
- Early layers (0-2): KL ≈ 5-8 nats
- Middle layers (6-12): KL ≈ 0.5-2 nats
- Late layers (18-24): KL ≈ 0.01-0.2 nats
This smooth progression suggests:
- Continuous refinement rather than phase transitions
- No "representation bottleneck" layer
- Each layer makes incremental progress
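The per-layer KL values above can be computed directly from the layer-wise distributions. A minimal sketch, reusing the assumed `layer_logits` format from earlier; it reports KL(final ‖ layer) in nats, averaged over positions:

```python
import torch
import torch.nn.functional as F

def kl_to_final(layer_logits: dict, final_layer: int) -> dict:
    """KL(final || layer) in nats for each layer, averaged over positions."""
    final_logp = F.log_softmax(layer_logits[final_layer], dim=-1)
    final_p = final_logp.exp()
    kls = {}
    for layer, logits in layer_logits.items():
        layer_logp = F.log_softmax(logits, dim=-1)
        kl = (final_p * (final_logp - layer_logp)).sum(dim=-1)  # per-position KL
        kls[layer] = kl.mean().item()
    return kls
```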
Entropy decay is approximately exponential in early layers, then plateaus:
H(p_l) ≈ H(p_0) * exp(-l / τ) + H(p_final)
Where τ ≈ 2-4 (time constant depends on model size).
This suggests:
- Uncertainty is rapidly resolved
- Later layers focus on refinement rather than information gain
- Different position types converge at different rates
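The time constant τ in the formula above can be estimated by fitting per-layer entropies. A minimal sketch, assuming SciPy is available (it is not listed in `requirements.txt`) and the same `layer_logits` format; the initial guesses are illustrative:

```python
import numpy as np
import torch.nn.functional as F
from scipy.optimize import curve_fit

def fit_entropy_decay(layer_logits: dict) -> float:
    """Fit H(l) ≈ H(0) * exp(-l / tau) + H_final and return tau (in layers)."""
    layers = sorted(layer_logits)
    entropies = []
    for l in layers:
        logp = F.log_softmax(layer_logits[l], dim=-1)
        h = -(logp.exp() * logp).sum(dim=-1).mean()   # mean entropy over positions, in nats
        entropies.append(h.item())

    def decay(l, h0, tau, h_final):
        return h0 * np.exp(-l / tau) + h_final

    (h0, tau, h_final), _ = curve_fit(
        decay,
        np.asarray(layers, dtype=float),
        np.asarray(entropies),
        p0=[entropies[0], 3.0, entropies[-1]],
        maxfev=10000,
    )
    return tau
```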
The logit lens reveals the actual computation path:
- Models don't accumulate information then suddenly predict
- Instead, they form predictions early and refine iteratively
- This explains why attention visualization alone is insufficient
The findings suggest:
- Layer-by-layer analysis is more useful than treating the model as a black box
- Early layers can be understood as forming simple heuristic guesses
- Later layers perform sophisticated refinement and uncertainty resolution
The smooth prediction trajectory has implications:
- Models learn to make useful intermediate predictions (good for knowledge distillation)
- Weight decay encourages smooth information flow
- Residual connections preserve the ability to access intermediate predictions
- Probes may overfit to training distribution
- Different domains may require retraining
- Transferability across model sizes is unclear
Mitigation: Use held-out test set, apply regularization during training
- Logits don't capture all model information
- Some computation may be in orthogonal subspaces
- Multi-head attention structure not directly visible
Mitigation: Combine with other interpretability techniques (attribution, attention, SVD)
- Large models require significant GPU memory
- Training tuned lens on 20B+ parameter models is expensive
- Extraction of hidden states can be slow
Mitigation: Use layer sampling, gradient checkpointing, distributed training
- Temperature parameter affects sharpness of predictions
- Not always clear what temperature to use
- Different layers may need different temperatures
Mitigation: Validate on held-out validation set
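One concrete way to follow that mitigation is to choose each layer's temperature by minimizing negative log-likelihood on held-out data. A minimal grid-search sketch (the function name and grid are illustrative, not part of this repo's API):

```python
import torch
import torch.nn.functional as F

def calibrate_temperature(logits: torch.Tensor, targets: torch.Tensor,
                          grid=(0.5, 0.75, 1.0, 1.5, 2.0, 3.0)) -> float:
    """Pick the temperature minimizing held-out NLL for a single layer's logits.

    logits:  (num_positions, vocab_size) held-out logits from that layer
    targets: (num_positions,) gold next-token ids
    """
    best_t, best_nll = grid[0], float("inf")
    for t in grid:
        nll = F.cross_entropy(logits / t, targets).item()
        if nll < best_nll:
            best_t, best_nll = t, nll
    return best_t
```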
- Knowing what predictions are made doesn't fully explain how
- Doesn't directly answer "what features does this layer extract?"
- Requires complementary techniques
Mitigation: Use attention patterns, gradient-based feature importance
- Structural Understanding: Decompose layer predictions by attention head
- Mechanistic Interpretation: Connect predictions to specific circuit components
- Generalization: Study how predictions transfer across domains and model sizes
- Adversarial Analysis: Use prediction trajectory to understand adversarial robustness
- Knowledge Distillation: Leverage intermediate predictions for improved training
- Editing: Use lens to identify and intervene on model behaviors
```
logit-lens-project/
├── lenses.py          # Core LogitLens and TunedLens implementations
├── analysis.py        # Analysis and visualization utilities
├── examples.py        # Example usage
├── README.md          # This file
├── requirements.txt   # Dependencies
└── sample_outputs/    # Sample outputs from running `python examples.py`
```
```bash
pip install -r requirements.txt
```

Or, if the above gives issues:

```bash
pip install torch transformers numpy matplotlib
```

Then run the examples:

```bash
python examples.py
```

What happens:
- Loads GPT-2 model automatically
- Applies logit lens to example text
- Prints results showing accuracy by layer
- Creates visualization PNG files
- Takes 2-5 minutes
You'll see output like:
```
Layer 0:  0.150 accuracy
Layer 6:  0.450 accuracy
Layer 12: 0.650 accuracy
...
Layer 24: 0.950 accuracy

Saved: logit_lens_accuracy.png
Saved: logit_lens_kl_divergence.png
Saved: logit_lens_rank.png
Saved: logit_lens_entropy.png
```
The PNG visualization files show:
- How accuracy improves through layers
- How KL divergence decreases
- How token ranks improve
- How entropy reduces
```python
from lenses import LogitLens
from transformers import GPT2LMHeadModel, GPT2Tokenizer
model = GPT2LMHeadModel.from_pretrained("gpt2")
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
lens = LogitLens(model)
input_ids = tokenizer.encode("Hello world", return_tensors="pt")
predictions = lens.extract_predictions(input_ids)
# predictions is a dict: layer_idx -> (batch, seq_len, vocab_size)
```

```python
from lenses import TunedLens, TunedLensTrainer
import torch.optim as optim
# Create tuned lens
tuned_lens = TunedLens(model, num_layers=12, hidden_dim=768, vocab_size=50257)
# Train
trainer = TunedLensTrainer(model, tuned_lens)
optimizer = optim.Adam(tuned_lens.parameters(), lr=1e-3)
trainer.train_epoch(dataloader, optimizer)
# Use
predictions = tuned_lens.forward(hidden_state, layer_idx=5)
```

```python
from analysis import LensAnalyzer, LensVisualizer
analyzer = LensAnalyzer(tokenizer)
accuracies = analyzer.compute_prediction_accuracy(predictions, target_ids)
visualizer = LensVisualizer(analyzer)
fig = visualizer.plot_accuracy_progression(accuracies)
fig.savefig("accuracy.png")
```

```
torch>=1.9.0
transformers>=4.10.0
numpy>=1.19.0
matplotlib>=3.3.0
seaborn>=0.11.0
```

```bash
git clone <this-repo>
cd logit-lens-project
pip install -r requirements.txt
python examples.py
```

```bibtex
@article{nostalgebraist2020,
title={Interpreting GPT: The Logit Lens},
author={nostalgebraist},
journal={LessWrong},
year={2020}
}
@article{belrose2023tuned,
title={Eliciting Latent Predictions from Transformers with the Tuned Lens},
author={Belrose, Nora and Ostrovsky, Igor and McKinney, Lev and Furman, Zach and Smith, Logan and Halawi, Danny and Biderman, Stella and Steinhardt, Jacob},
journal={arXiv preprint arXiv:2303.08112},
year={2023}
}
```