Flash-SAE ⚡

High-Performance Triton Kernels for Sparse Autoencoders


Memory-optimized implementation of Top-K Sparse Autoencoders for mechanistic interpretability research.


Highlights

  • 13.6x decoder speedup via sparse gather operations
  • 97% memory reduction by never materializing dense activations
  • Ghost Gradients for dead latent recovery
  • FP8 quantization support (Ada Lovelace+)
  • Drop-in PyTorch replacement with full autograd support
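A back-of-envelope calculation with the benchmark configuration used below (batch=1024, n_features=65536, k=64, bfloat16) shows where the saving comes from: the dense hidden-activation tensor dwarfs the sparse values+indices pair. (The headline 97% figure refers to measured end-to-end memory; this sketch covers the activation tensor alone.)

```python
batch, n_features, k = 1024, 65536, 64

# Dense hidden activations: every feature stored per sample (bfloat16 = 2 bytes)
dense_bytes = batch * n_features * 2            # 128 MiB

# Sparse representation: k values (bfloat16) + k indices (int32) per sample
sparse_bytes = batch * k * (2 + 4)              # 0.375 MiB

print(f"dense:  {dense_bytes / 2**20:.0f} MiB")    # → dense:  128 MiB
print(f"sparse: {sparse_bytes / 2**20:.3f} MiB")   # → sparse: 0.375 MiB
print(f"ratio:  {dense_bytes / sparse_bytes:.0f}x")  # → ratio:  341x
```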

Architecture

flowchart TB
    subgraph Input
        X["Input x<br/>[batch, d_model]"]
    end
    
    subgraph Encoder["FUSED ENCODER"]
        E1["Linear: x @ W_enc + b_enc"]
        E2["Activation: ReLU / JumpReLU"]
        E3["TopK: Select k largest"]
        E1 --> E2 --> E3
    end
    
    subgraph Sparse["Sparse Representation"]
        S["values + indices<br/>[batch, k]"]
    end
    
    subgraph Decoder["FUSED DECODER"]
        D1["Sparse gather: Σ aᵢ · W_dec[iₖ]"]
        D2["Reads only k rows, not all N"]
        D1 --> D2
    end
    
    subgraph Output
        Y["Output x̂<br/>[batch, d_model]"]
    end
    
    X --> Encoder
    Encoder --> S
    S --> Decoder
    Decoder --> Y
    
    linkStyle 0,1,2 stroke:#000000,stroke-width:2px
    style Encoder fill:#93c5fd,stroke:#1e40af,stroke-width:2px,color:#1e3a5f
    style Decoder fill:#86efac,stroke:#166534,stroke-width:2px,color:#14532d
    style Sparse fill:#fcd34d,stroke:#b45309,stroke-width:2px,color:#78350f

Key Insight: The encoder outputs only k active features per sample (not all n_features), enabling massive memory savings and sparse decoder operations.
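In plain PyTorch, the three encoder stages in the diagram (linear, activation, top-k) reduce to a few lines. This is a reference sketch of the operation, not the fused Triton kernel, and the function name is illustrative:

```python
import torch

def encode_reference(x, W_enc, b_enc, k):
    """Reference top-k encode: linear -> ReLU -> top-k.
    Returns only the k active (value, index) pairs per sample
    instead of the full [batch, n_features] activation."""
    pre = torch.relu(x @ W_enc + b_enc)        # [batch, n_features]
    values, indices = pre.topk(k, dim=-1)      # [batch, k] each
    return values, indices

d_model, n_features, k = 16, 128, 4
x = torch.randn(8, d_model)
W_enc = torch.randn(d_model, n_features) / d_model**0.5
b_enc = torch.zeros(n_features)

values, indices = encode_reference(x, W_enc, b_enc, k)
print(values.shape, indices.shape)  # torch.Size([8, 4]) torch.Size([8, 4])
```

The fused kernel produces the same sparse pair without ever writing the intermediate `[batch, n_features]` tensor to global memory.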


Benchmarks

| Operation    | PyTorch | Flash-SAE | Speedup | Memory Saved |
|--------------|---------|-----------|---------|--------------|
| Encoder      | 22.9 ms | 21.7 ms   | 1.06x   | 20%          |
| Decoder      | 18.0 ms | 1.3 ms    | 13.6x   | 97%          |
| Full Forward | 40.9 ms | 23.0 ms   | 1.78x   | 25%          |
RTX 4070, batch=1024, d_model=4096, n_features=65536, k=64, bfloat16


Why is the decoder so much faster?

The decoder exploits sparsity: instead of a dense matrix multiplication over all 65,536 features, Flash-SAE reads only the k=64 active features per sample. This cuts memory traffic from O(batch × n_features × d_model) to O(batch × k × d_model), a 1024x (n_features / k) reduction.
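A dense-PyTorch equivalent of the sparse gather (a sketch of the operation, not the fused kernel; names are illustrative) makes the bandwidth argument concrete: only the k selected rows of W_dec are touched per sample.

```python
import torch

def decode_reference(values, indices, W_dec, b_dec):
    """Sparse gather decode: x_hat = sum_i values[i] * W_dec[indices[i]] + b_dec.
    Reads k rows of W_dec per sample instead of all n_features rows."""
    rows = W_dec[indices]                              # [batch, k, d_model]
    return (values.unsqueeze(-1) * rows).sum(dim=1) + b_dec

batch, d_model, n_features, k = 8, 16, 128, 4
W_dec = torch.randn(n_features, d_model)
b_dec = torch.zeros(d_model)
values = torch.rand(batch, k)
# Unique indices per row (top-k of random scores guarantees no duplicates)
indices = torch.rand(batch, n_features).topk(k, dim=-1).indices

x_hat = decode_reference(values, indices, W_dec, b_dec)

# Equivalent dense computation for comparison: scatter into a full
# [batch, n_features] activation, then multiply by ALL of W_dec.
dense = torch.zeros(batch, n_features).scatter_(1, indices, values)
assert torch.allclose(x_hat, dense @ W_dec + b_dec, atol=1e-5)
```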


Installation

git clone https://github.com/alepot55/flash-sae.git
cd flash-sae
pip install -e ".[dev]"

Requirements:

  • Python >= 3.10
  • PyTorch >= 2.4.0
  • Triton >= 3.0.0
  • CUDA GPU (Ampere+ recommended)

Quick Start

import torch
from flash_sae import FlashSAE, FlashSAEConfig

config = FlashSAEConfig(
    d_model=4096,        # LLM hidden dimension
    n_features=65536,    # Dictionary size (16x expansion)
    k=64,                # Top-k sparsity
    activation="relu",
    use_ghost_grads=True,
)

sae = FlashSAE(config).cuda()
x = torch.randn(1024, 4096, device='cuda', dtype=torch.bfloat16)

x_recon, aux_data = sae(x, return_aux=True)
loss, loss_dict = sae.compute_loss(x, x_recon, aux_data)

print(f"MSE: {loss_dict['mse']:.4f}, L0: {loss_dict['l0']:.1f}")

Training on LLM Activations

python examples/train_gpt2_demo.py


Features:

  • Native PyTorch hooks for activation capture
  • Streaming dataset with shuffle buffer
  • Ghost Gradient recovery for dead latents
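Ghost Gradients give dead latents a learning signal by asking them to reconstruct the residual error the live latents leave behind. The sketch below follows the commonly used recipe (an exp activation on dead pre-activations, rescaled toward the residual norm); the exact formulation in flash_sae/kernels/ghost_grads.py may differ, and all names here are illustrative.

```python
import torch
import torch.nn.functional as F

def ghost_grads_loss(x, x_recon, pre_acts, dead_mask, W_dec):
    """Illustrative ghost-gradient loss: dead latents are trained to
    reconstruct the residual left behind by the live reconstruction."""
    residual = (x - x_recon).detach()              # what the SAE failed to explain
    # exp keeps gradients alive even for strongly negative pre-activations;
    # the mask restricts the auxiliary task to dead latents only.
    ghost_acts = torch.exp(pre_acts) * dead_mask   # [batch, n_features]
    ghost_recon = ghost_acts @ W_dec               # [batch, d_model]
    # Rescale (treated as a constant) so the auxiliary target is well-conditioned.
    scale = residual.norm(dim=-1, keepdim=True) / (
        2 * ghost_recon.norm(dim=-1, keepdim=True) + 1e-6)
    return F.mse_loss(ghost_recon * scale.detach(), residual)

torch.manual_seed(0)
batch, d_model, n_features = 8, 16, 32
x = torch.randn(batch, d_model)
x_recon = torch.randn(batch, d_model)
pre_acts = torch.randn(batch, n_features, requires_grad=True)
W_dec = torch.randn(n_features, d_model)
dead_mask = torch.zeros(n_features)
dead_mask[:4] = 1.0   # pretend the first 4 latents are dead

loss = ghost_grads_loss(x, x_recon, pre_acts, dead_mask, W_dec)
loss.backward()
# Only the dead latents receive gradient from the ghost loss:
print(pre_acts.grad[:, 4:].abs().max().item())  # 0.0
```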
Training code example
from flash_sae import FlashSAE, FlashSAEConfig, FlashSAETrainer

config = FlashSAEConfig(
    d_model=768,
    n_features=24576,
    k=32,
    use_ghost_grads=True,
    ghost_coef=0.1,
)

sae = FlashSAE(config).cuda()
trainer = FlashSAETrainer(sae, lr=1e-4, warmup_steps=1000)

for batch in dataloader:
    loss_dict = trainer.train_step(batch)
    print(f"MSE: {loss_dict['mse']:.4f}, Dead: {100*loss_dict['dead_ratio']:.1f}%")

API Reference

FlashSAEConfig

| Parameter           | Type    | Default     | Description                             |
|---------------------|---------|-------------|-----------------------------------------|
| `d_model`           | int     | -           | Input/output dimension                  |
| `n_features`        | int     | -           | Dictionary size                         |
| `k`                 | int     | 64          | Top-k sparsity                          |
| `activation`        | str     | `"relu"`    | `"relu"` or `"jumprelu"`                |
| `jump_threshold`    | float   | 0.0         | JumpReLU threshold                      |
| `l1_coef`           | float   | 0.0         | L1 sparsity penalty                     |
| `use_ghost_grads`   | bool    | False       | Enable ghost gradient loss              |
| `ghost_coef`        | float   | 0.1         | Ghost gradient coefficient              |
| `dead_threshold`    | float   | 0.001       | EMA threshold for dead-latent detection |
| `normalize_decoder` | bool    | True        | Unit-norm decoder columns               |
| `dtype`             | dtype   | bfloat16    | Computation precision                   |
| `use_triton`        | bool    | True        | Use Triton kernels                      |

FlashSAE Methods

values, indices, pre_topk = sae.encode(x, return_pre_topk=True)
x_recon = sae.decode(values, indices)
x_recon, aux_data = sae(x, return_aux=True)
loss, loss_dict = sae.compute_loss(x, x_recon, aux_data)

sae.normalize_decoder_()          # Unit-norm columns
sae.remove_parallel_component_()  # Orthogonalize encoder
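The unit-norm constraint enforced by `normalize_decoder_` keeps dictionary directions comparable in scale, so feature activation magnitudes are meaningful. A minimal sketch of the operation, assuming `W_dec` is stored as `[n_features, d_model]` with one row per feature (the actual layout in sae.py may differ):

```python
import torch

@torch.no_grad()
def normalize_decoder_(W_dec, eps=1e-8):
    """Scale each dictionary direction to unit L2 norm, in place.
    Assumes W_dec is [n_features, d_model], one row per feature."""
    W_dec /= W_dec.norm(dim=-1, keepdim=True) + eps
    return W_dec

W_dec = torch.randn(128, 16)
normalize_decoder_(W_dec)
print(W_dec.norm(dim=-1))  # all ≈ 1.0
```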

Project Structure

flash-sae/
├── flash_sae/
│   ├── __init__.py
│   ├── sae.py
│   └── kernels/
│       ├── encoder.py
│       ├── decoder.py
│       ├── topk.py
│       └── ghost_grads.py
├── examples/
│   └── train_gpt2_demo.py
├── benchmarks/
│   └── benchmark.py
├── tests/
│   └── test_sae.py
└── assets/

Tests

pytest tests/ -v

All 19 tests passing.


Citation

@software{flash_sae,
  title = {Flash-SAE: High-Performance Triton Kernels for Sparse Autoencoders},
  author = {Alessandro Potenza},
  year = {2026},
  url = {https://github.com/alepot55/flash-sae},
}

License

MIT License. See LICENSE for details.
