Implementation of CorrSteer, a generation-time steering method using correlated Sparse Autoencoder (SAE) features.
- Correlation-based feature selection from generation-time activations
- Streaming computation with O(1) memory in the number of samples
- Multi-layer strategies (CorrSteer-S/A/P)
- Side Effect Ratio (SER) for measuring unintended changes
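
The first two features above can be illustrated with a minimal sketch of a streaming Pearson correlation. It is not the repository's code: the class name `StreamingCorrelation`, its methods, and the choice of a scalar per-sample score are assumptions made for illustration. The idea is that only running sums are kept, so memory depends on the number of features but not on the number of samples seen.

```python
import math


class StreamingCorrelation:
    """Running Pearson correlation between each feature and a scalar score.

    Hypothetical helper for illustration; not part of the corrsteer package.
    """

    def __init__(self, num_features):
        self.n = 0
        self.sum_x = [0.0] * num_features   # sum of activations per feature
        self.sum_x2 = [0.0] * num_features  # sum of squared activations
        self.sum_xy = [0.0] * num_features  # sum of activation * score
        self.sum_y = 0.0                    # sum of scores
        self.sum_y2 = 0.0                   # sum of squared scores

    def update(self, activations, score):
        """Fold one sample into the running sums; the sample is not stored."""
        self.n += 1
        self.sum_y += score
        self.sum_y2 += score * score
        for i, x in enumerate(activations):
            self.sum_x[i] += x
            self.sum_x2[i] += x * x
            self.sum_xy[i] += x * score

    def correlations(self):
        """Return one Pearson coefficient per feature (0.0 if undefined)."""
        n = self.n
        var_y = n * self.sum_y2 - self.sum_y ** 2
        result = []
        for sx, sx2, sxy in zip(self.sum_x, self.sum_x2, self.sum_xy):
            var_x = n * sx2 - sx ** 2
            if var_x > 0 and var_y > 0:
                result.append((n * sxy - sx * self.sum_y) / math.sqrt(var_x * var_y))
            else:
                # Constant feature or constant score: correlation is undefined.
                result.append(0.0)
        return result

    def top_feature(self):
        """Return (index, correlation) of the most positively correlated feature."""
        corrs = self.correlations()
        best = max(range(len(corrs)), key=corrs.__getitem__)
        return best, corrs[best]
```

In this sketch, `update` would be called once per generated sample with that sample's feature activations and its task score, and `top_feature` would be read once at the end. Whether the repository ranks features by signed or absolute correlation is not stated here, so the signed choice above is an assumption.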
Install Astral UV:
pip install uv

Create a virtual environment and install:
uv venv
source .venv/bin/activate # On Windows: .venv\Scripts\activate
uv pip install -e .

# MMLU with SAE features
python train.py train --model=gemma2b --task=mmlu --layer=global --eval
# MMLU with raw activations
python train.py train --model=gemma2b --task=mmlu --layer=global --raw --eval
# MMLU with mean pooling
python train.py train --model=gemma2b --task=mmlu --layer=global --pool=mean --eval
# BBQ disambiguation
python train.py train --model=gemma2b --task=bbq --layer=global --mask=all --filter_value=disambig --eval
# HarmBench with raw activations
python train.py train --model=gemma2b --task=harmbench --layer=global --raw --eval
# SimpleQA with mean pooling
python train.py train --model=gemma2b --task=simpleqa --layer=global --pool=mean --eval
# GSM8K with mean pooling for both correlation and steering
python train.py train --model=gemma2b --task=gsm8k --layer=foreach --pool=mean --steer_pool=mean --eval

# Baseline evaluation
python eval.py baseline --task=mmlu

# CorrSteer-S: Single best feature globally
python train.py train --task=mmlu --layer=global --eval
# CorrSteer-A: Top feature from each layer
python train.py train --task=mmlu --layer=foreach --eval
# CorrSteer-P: Validation-based pruning
python train.py train --task=mmlu --layer=foreach --validate --eval

corrsteer/
├── config.py # Dataset and model configurations
├── dataset.py # Data loading and processing
├── model.py # Model and SAE integration
├── steer.py # Steering hooks for inference
└── utils.py # Utility functions
train.py # Training with streaming correlation
eval.py # Evaluation with SER computation
sft.py # Supervised fine-tuning
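
The layout above lists `eval.py` as the place where SER is computed. As a rough sketch only, the function below assumes SER is the fraction of changed answers that changed for the worse (correct before steering, incorrect after). That definition, the function name `side_effect_ratio`, and the boolean per-example inputs are assumptions for illustration and may differ from what `eval.py` actually does.

```python
def side_effect_ratio(baseline_correct, steered_correct):
    """Fraction of changed outcomes that went from correct to incorrect.

    Assumed definition for illustration; both inputs are sequences of booleans,
    one entry per evaluation example, in the same order.
    """
    if len(baseline_correct) != len(steered_correct):
        raise ValueError("inputs must have the same length")

    changed = 0
    negative = 0
    for before, after in zip(baseline_correct, steered_correct):
        if before != after:
            changed += 1
            if before and not after:
                negative += 1

    # No changed answers means there are no side effects to measure.
    return negative / changed if changed else 0.0
```

Under this assumed definition, a value near 0 would mean steering mostly fixes answers, and a value near 1 would mean most of its changes are regressions.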
MIT License - see LICENSE file for details.

Acknowledgments:
- SAE Lens for SAE implementation
- Gemma Scope for pretrained SAEs
- HuggingFace for model hosting