# Flash-SAE: High-Performance Triton Kernels for Sparse Autoencoders

Memory-optimized implementation of Top-K Sparse Autoencoders for mechanistic interpretability research.
- 13.6x decoder speedup via sparse gather operations
- 97% memory reduction by never materializing dense activations
- Ghost Gradients for dead latent recovery
- FP8 quantization support (Ada Lovelace+)
- Drop-in PyTorch replacement with full autograd support
## Architecture

```mermaid
flowchart TB
    subgraph Input
        X["Input x<br/>[batch, d_model]"]
    end
    subgraph Encoder["FUSED ENCODER"]
        E1["Linear: x @ W_enc + b_enc"]
        E2["Activation: ReLU / JumpReLU"]
        E3["TopK: Select k largest"]
        E1 --> E2 --> E3
    end
    subgraph Sparse["Sparse Representation"]
        S["values + indices<br/>[batch, k]"]
    end
    subgraph Decoder["FUSED DECODER"]
        D1["Sparse gather: Σ aᵢ · W_dec[iₖ]"]
        D2["Reads only k rows, not all N"]
        D1 --> D2
    end
    subgraph Output
        Y["Output x̂<br/>[batch, d_model]"]
    end
    X --> Encoder
    Encoder --> S
    S --> Decoder
    Decoder --> Y
    linkStyle 0,1,2 stroke:#000000,stroke-width:2px
    style Encoder fill:#93c5fd,stroke:#1e40af,stroke-width:2px,color:#1e3a5f
    style Decoder fill:#86efac,stroke:#166534,stroke-width:2px,color:#14532d
    style Sparse fill:#fcd34d,stroke:#b45309,stroke-width:2px,color:#78350f
```
**Key Insight:** The encoder outputs only the k active features per sample (not all n_features), enabling massive memory savings and sparse decoder operations.
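A plain-PyTorch reference for the encoder path (Linear → ReLU → TopK) may help make the fused pipeline concrete. This is an illustrative sketch, not the library's kernel: the function and variable names are assumptions, and the real Triton kernel fuses these three steps into one pass without materializing `pre` in global memory.

```python
import torch

def topk_encode(x, W_enc, b_enc, k):
    """Unfused reference encoder: linear -> ReLU -> top-k."""
    pre = torch.relu(x @ W_enc + b_enc)     # [batch, n_features] dense pre-activations
    values, indices = pre.topk(k, dim=-1)   # keep only the k largest per sample
    return values, indices

# Toy sizes; the benchmark config uses d_model=4096, n_features=65536, k=64.
x = torch.randn(4, 16)
W_enc = torch.randn(16, 128)
b_enc = torch.zeros(128)
values, indices = topk_encode(x, W_enc, b_enc, k=8)
print(values.shape, indices.shape)  # torch.Size([4, 8]) torch.Size([4, 8])
```

The `[batch, k]` values/indices pair is exactly the sparse representation shown in the diagram above.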
## Benchmarks
| Operation | PyTorch | Flash-SAE | Speedup | Memory Saved |
|---|---|---|---|---|
| Encoder | 22.9 ms | 21.7 ms | 1.06x | 20% |
| Decoder | 18.0 ms | 1.3 ms | 13.6x | 97% |
| Full Forward | 40.9 ms | 23.0 ms | 1.78x | 25% |
*RTX 4070, batch=1024, d_model=4096, n_features=65536, k=64, bfloat16*
**Why is the decoder so much faster?** The decoder exploits sparsity: instead of a dense matrix multiplication over all 65,536 features, Flash-SAE reads only the k=64 active features per sample. This cuts memory traffic from O(batch × n_features × d_model) to O(batch × k × d_model), a 1024x reduction (65,536 / 64).
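The same gather can be sketched in plain PyTorch (illustrative only; the Triton kernel does this without the intermediate `[batch, k, d_model]` tensor): each sample's reconstruction sums k weighted rows of `W_dec`, which is numerically identical to the dense path.

```python
import torch

def sparse_decode(values, indices, W_dec, b_dec):
    """x_hat = sum_i a_i * W_dec[i] + b_dec: reads k rows of W_dec, not all n_features."""
    rows = W_dec[indices]                              # gather: [batch, k, d_model]
    return torch.einsum('bk,bkd->bd', values, rows) + b_dec

batch, k, n_features, d_model = 4, 8, 128, 16
values = torch.rand(batch, k)
# unique indices per row, as TopK produces
indices = torch.stack([torch.randperm(n_features)[:k] for _ in range(batch)])
W_dec = torch.randn(n_features, d_model)
b_dec = torch.zeros(d_model)

x_hat = sparse_decode(values, indices, W_dec, b_dec)

# Equivalent dense path: scatter into a full [batch, n_features] vector first
dense = torch.zeros(batch, n_features).scatter(1, indices, values)
assert torch.allclose(x_hat, dense @ W_dec + b_dec, atol=1e-4)
```

The assertion checks that the sparse gather matches the dense matmul exactly; only the memory traffic differs.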
## Installation

```shell
git clone https://github.com/alepot55/flash-sae.git
cd flash-sae
pip install -e ".[dev]"
```

Requirements:
- Python >= 3.10
- PyTorch >= 2.4.0
- Triton >= 3.0.0
- CUDA GPU (Ampere+ recommended)
## Quick Start

```python
import torch
from flash_sae import FlashSAE, FlashSAEConfig

config = FlashSAEConfig(
    d_model=4096,      # LLM hidden dimension
    n_features=65536,  # Dictionary size (16x expansion)
    k=64,              # Top-k sparsity
    activation="relu",
    use_ghost_grads=True,
)

sae = FlashSAE(config).cuda()

x = torch.randn(1024, 4096, device='cuda', dtype=torch.bfloat16)
x_recon, aux_data = sae(x, return_aux=True)

loss, loss_dict = sae.compute_loss(x, x_recon, aux_data)
print(f"MSE: {loss_dict['mse']:.4f}, L0: {loss_dict['l0']:.1f}")
```

## Training Demo

```shell
python examples/train_gpt2_demo.py
```

Features:
- Native PyTorch hooks for activation capture
- Streaming dataset with shuffle buffer
- Ghost Gradient recovery for dead latents
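The activation-capture feature above relies on standard PyTorch forward hooks. A minimal, self-contained sketch of the pattern (using a stand-in module; the demo hooks a real GPT-2 layer instead, and its actual wiring may differ):

```python
import torch
import torch.nn as nn

captured = {}

def make_hook(name):
    def hook(module, inputs, output):
        captured[name] = output.detach()  # detach: don't backprop into the LLM
    return hook

# Stand-in for a transformer block (d_model=768, as in the GPT-2 demo config).
model = nn.Sequential(nn.Linear(16, 768), nn.GELU(), nn.Linear(768, 768))
handle = model[2].register_forward_hook(make_hook("resid_post"))

with torch.no_grad():
    model(torch.randn(32, 16))

acts = captured["resid_post"]  # [32, 768] activations, ready to feed the SAE
handle.remove()                # detach the hook when done
print(acts.shape)              # torch.Size([32, 768])
```

Detaching the captured tensor keeps SAE training from building a graph through the frozen language model.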
**Training code example:**
```python
from flash_sae import FlashSAE, FlashSAEConfig, FlashSAETrainer

config = FlashSAEConfig(
    d_model=768,
    n_features=24576,
    k=32,
    use_ghost_grads=True,
    ghost_coef=0.1,
)

sae = FlashSAE(config).cuda()
trainer = FlashSAETrainer(sae, lr=1e-4, warmup_steps=1000)

for batch in dataloader:
    loss_dict = trainer.train_step(batch)
    print(f"MSE: {loss_dict['mse']:.4f}, Dead: {100*loss_dict['dead_ratio']:.1f}%")
```

## Configuration

| Parameter | Type | Default | Description |
|---|---|---|---|
| `d_model` | int | - | Input/output dimension |
| `n_features` | int | - | Dictionary size |
| `k` | int | 64 | Top-k sparsity |
| `activation` | str | `"relu"` | `"relu"` or `"jumprelu"` |
| `jump_threshold` | float | 0.0 | JumpReLU threshold |
| `l1_coef` | float | 0.0 | L1 sparsity penalty |
| `use_ghost_grads` | bool | False | Enable ghost gradient loss |
| `ghost_coef` | float | 0.1 | Ghost gradient coefficient |
| `dead_threshold` | float | 0.001 | EMA threshold for dead detection |
| `normalize_decoder` | bool | True | Unit-norm decoder columns |
| `dtype` | dtype | bfloat16 | Computation precision |
| `use_triton` | bool | True | Use Triton kernels |
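A sketch of the idea behind `use_ghost_grads`, `ghost_coef`, and `dead_threshold`, following the dead-latent revival recipe popularized in Anthropic's SAE work: latents whose EMA firing frequency drops below `dead_threshold` get an auxiliary loss pushing them to reconstruct the residual error left by the live latents. The function name and the exact scaling here are assumptions, not Flash-SAE's internal formulation.

```python
import torch
import torch.nn.functional as F

def ghost_grad_loss(x, x_recon, pre_acts, W_dec, ema_freq, dead_threshold=1e-3):
    """Auxiliary loss nudging dead latents to explain the reconstruction residual.

    pre_acts: [batch, n_features] pre-TopK encoder activations
    ema_freq: [n_features] EMA of each latent's firing frequency
    """
    dead = ema_freq < dead_threshold           # [n_features] boolean mask
    if not dead.any():
        return x.new_zeros(())
    residual = (x - x_recon).detach()          # error left by the live latents
    ghost_acts = torch.exp(pre_acts[:, dead])  # exp keeps dead units trainable
    ghost_recon = ghost_acts @ W_dec[dead]     # [batch, d_model]
    # Rescale ghost output toward the residual's magnitude (assumed stabilizer)
    scale = residual.norm(dim=-1, keepdim=True) / (
        2 * ghost_recon.norm(dim=-1, keepdim=True) + 1e-8
    )
    return F.mse_loss(ghost_recon * scale.detach(), residual)
```

The total loss would then be `mse + ghost_coef * ghost_grad_loss(...)`; when no latent is dead, the term vanishes.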
## API

```python
values, indices, pre_topk = sae.encode(x, return_pre_topk=True)
x_recon = sae.decode(values, indices)

x_recon, aux_data = sae(x, return_aux=True)
loss, loss_dict = sae.compute_loss(x, x_recon, aux_data)

sae.normalize_decoder_()          # Unit-norm columns
sae.remove_parallel_component_()  # Orthogonalize encoder
```

## Project Structure

```
flash-sae/
├── flash_sae/
│   ├── __init__.py
│   ├── sae.py
│   └── kernels/
│       ├── encoder.py
│       ├── decoder.py
│       ├── topk.py
│       └── ghost_grads.py
├── examples/
│   └── train_gpt2_demo.py
├── benchmarks/
│   └── benchmark.py
├── tests/
│   └── test_sae.py
└── assets/
```

## Tests

```shell
pytest tests/ -v
```

All 19 tests passing.
## Citation

```bibtex
@software{flash_sae,
  title  = {Flash-SAE: High-Performance Triton Kernels for Sparse Autoencoders},
  author = {Alessandro Potenza},
  year   = {2026},
  url    = {https://github.com/alepot55/flash-sae},
}
```

## References

- Templeton et al. (2024). Scaling Monosemanticity. Anthropic.
- Bricken et al. (2023). Towards Monosemanticity. Anthropic.
- Tillet et al. (2019). Triton: An Intermediate Language and Compiler for Tiled Neural Network Computations.

## License

MIT License. See LICENSE for details.



