- Introduction
- Features
- Installation
- Quick Start
- Usage Examples
- MiniSTU
- Configuration
- Training
- Contributing
- License
- Acknowledgments
This repository complements the Flash STU: Fast Spectral Transform Units paper and contains an optimized, open-source PyTorch implementation of the Spectral Transform Unit (STU) as proposed in Spectral State Space Models by Agarwal et al. (2024).
Flash STU is a hybrid architecture that interleaves spectral state space model layers with sliding window attention, enabling scalability to billions of parameters for language modeling while maintaining near-linear time complexity. The STU module is a fast and flexible building block that can be adapted into a wide range of neural network architectures, especially those that aim to solve tasks with long-range dependencies.
- Hybrid Architecture: Interleaves STU and sliding window attention layers
- Fast Convolutions: Optimized spectral convolutions using Flash FFT
- Efficient Attention: Sliding window attention using Flash Attention 2
- HuggingFace Compatible: Fully compatible with the HuggingFace `transformers` API
- Advanced Features:
  - KV caching for generation
  - Gradient checkpointing
  - STU MLP sandwiching
  - Memory-efficient tiling
- Distributed Training: Support for DDP and FSDP
- Flexible Building Blocks: Use the standalone `FlashSTUBlock` in your own architectures
Note: CUDA is required to run code from this repository.
For optimal performance, this repository was tested with:
- Python 3.12.5
- PyTorch 2.4.1
- Triton 3.0.0
- CUDA 12.4
and may be incompatible with other versions.
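If you want to confirm what your environment actually provides before installing anything else, a quick version check along these lines can help (a minimal sketch; Triton is only reported if it is installed):

```python
import torch

# Report the locally installed versions that matter for this repository.
print("PyTorch:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
print("CUDA (build):", torch.version.cuda)

try:
    import triton
    print("Triton:", triton.__version__)
except ImportError:
    print("Triton: not installed")
```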
For a minimal setup without CUDA-heavy dependencies like Flash Attention and Flash FFT:

- Install PyTorch with CUDA support:

  ```bash
  pip install torch --index-url https://download.pytorch.org/whl/cu124
  ```

- Install core dependencies only:

  ```bash
  pip install -e .
  ```

  Or alternatively, use the minimal requirements file:

  ```bash
  pip install -r requirements-minimal.txt
  ```

This lightweight installation includes all core dependencies (numpy, einops, transformers, etc.) but excludes the optional CUDA-heavy performance optimizations below.
For optimal performance with the Flash Attention and Flash FFT Conv optimizations:

- Install PyTorch with CUDA support:

  ```bash
  pip install torch --index-url https://download.pytorch.org/whl/cu124
  ```

- Install core dependencies:

  ```bash
  pip install -e .
  ```

- Install Flash Attention (optional):

  ```bash
  MAX_JOBS=4 pip install flash-attn --no-build-isolation
  ```

- Install Flash FFT Conv (optional):

  ```bash
  pip install git+https://github.com/HazyResearch/flash-fft-conv.git#subdirectory=csrc/flashfftconv
  pip install git+https://github.com/HazyResearch/flash-fft-conv.git
  ```

Or install directly from GitHub:

```bash
pip install git+https://github.com/hazan-lab/flash-stu-2.git
```

Note: Installing from source will only install the lightweight version. For full performance, manually install Flash Attention and Flash FFT Conv as shown above.
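To verify which of the optional performance extensions ended up installed, a small import check can be run before training. The module names used below (`flash_attn`, `flashfftconv`) are assumptions based on each upstream project's packaging and may need adjusting:

```python
from importlib.util import find_spec

# Check which optional performance extensions are importable.
# The package names are assumptions based on the upstream projects; adjust if needed.
for name in ("flash_attn", "flashfftconv"):
    status = "installed" if find_spec(name) is not None else "not installed"
    print(f"{name}: {status}")
```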
```python
import torch
from flash_stu import FlashSTU, FlashSTUConfig
# Create configuration
config = FlashSTUConfig(
n_embd=512,
n_layers=12,
n_heads=8,
seq_len=2048,
vocab_size=50257,
)
# Initialize model (spectral filters computed automatically)
model = FlashSTU(config).cuda()
# Forward pass (HuggingFace compatible)
input_ids = torch.randint(0, config.vocab_size, (2, 128)).cuda()
outputs = model(input_ids=input_ids)
# Generate text
generated = model.generate(
input_ids=input_ids[:, :10],
max_length=50,
temperature=0.8,
top_k=40,
)
```
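To inspect the generated ids as text, one option is to decode them with a GPT-2 tokenizer from `transformers`, since the configuration above uses a GPT-2-sized vocabulary. This is only a sketch: it assumes `generate` returns a tensor of token ids in the HuggingFace style, and the randomly initialized model above will produce gibberish.

```python
from transformers import AutoTokenizer

# Assumes `generated` is a [batch, seq_len] tensor of token ids, as in the
# HuggingFace API; decode with the GPT-2 tokenizer (vocab_size=50257).
tokenizer = AutoTokenizer.from_pretrained("gpt2")
for text in tokenizer.batch_decode(generated, skip_special_tokens=True):
    print(text)
```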
```python
import torch
from flash_stu import FlashSTU, FlashSTUConfig

# Configure model
config = FlashSTUConfig(
n_embd=768,
n_layers=12,
n_heads=12,
seq_len=2048,
vocab_size=50257,
window_size=512, # Sliding window size
num_eigh=24, # Number of spectral filters
)
# Create model
model = FlashSTU(config).cuda()
# Training loop
input_ids = torch.randint(0, config.vocab_size, (4, 512)).cuda()
labels = input_ids.clone()
outputs = model(input_ids=input_ids, labels=labels)
loss = outputs['loss']
loss.backward()
```

```python
import torch
from flash_stu import FlashSTUBlock
# Create standalone STU block
stu_block = FlashSTUBlock(
d_model=512,
sequence_length=2048,
num_filters=24,
use_attention=False, # Pure STU, no attention
).cuda()
# Use in your own architecture
x = torch.randn(2, 2048, 512).cuda()
output = stu_block(x)
```

```python
from flash_stu import FlashSTUBlock
import torch.nn as nn
class HybridModel(nn.Module):
    def __init__(self, d_model=512, n_layers=12):
        super().__init__()
        self.layers = nn.ModuleList([
            FlashSTUBlock(
                d_model=d_model,
                sequence_length=2048,
                num_filters=24,
                use_attention=(i % 2 == 1),  # Alternate STU and attention layers
            )
            for i in range(n_layers)
        ])

    def forward(self, x):
        for layer in self.layers:
            x = layer(x, use_cache=False)
        return x
```

```python
# Enable sandwiching for better expressiveness
config = FlashSTUConfig(
n_embd=512,
n_layers=12,
stu_enable_mlp_sandwich=True,
stu_mlp_hidden_size=2048, # Sandwich MLP hidden size
)
model = FlashSTU(config).cuda()
```

```python
# Save model
model.save_pretrained("./my_flash_stu_model")
# Load model
from flash_stu import FlashSTU
model = FlashSTU.from_pretrained("./my_flash_stu_model").cuda()
# Generate
input_ids = torch.randint(0, config.vocab_size, (1, 10)).cuda()
output = model.generate(input_ids, max_length=100)
```

```python
# Use tiling for large models with limited memory
config = FlashSTUConfig(
n_embd=2048,
n_layers=24,
use_approx=True, # Required for tiling
d_in_tile=512, # Tile input dimension
d_out_tile=512, # Tile output dimension
)
model = FlashSTU(config).cuda()
```

For research and experimentation with the core spectral filtering innovation, we provide MiniSTU: a lightweight, standalone implementation focused on learning linear dynamical systems.
- Core STU Only: Pure spectral transform without attention/transformer layers
- Minimal Dependencies: Just PyTorch + NumPy
- LDS Learning: Built-in utilities for learning dynamical systems
- Educational: Clean, well-documented code for understanding STU
```python
from mini_stu import MiniSTU, random_LDS, train_stu_on_lds
# Create a random linear dynamical system
lds = random_LDS(state_dim=20, input_dim=10, output_dim=5)
# Train MiniSTU to approximate it
stu, losses = train_stu_on_lds(
lds,
seq_len=128,
num_filters=24,
num_steps=1000,
)
# Use the trained model
import torch
x = torch.randn(1, 128, 10)
y = stu(x)  # Shape: [1, 128, 5]
```

See `mini_stu/README.md` for complete documentation and `examples/mini_stu_example.py` for a full working example.
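For reference, the kind of system `random_LDS` stands in for is a standard linear dynamical system, x_{t+1} = A x_t + B u_t and y_t = C x_t + D u_t. The rollout below is only an illustration of that recurrence, not MiniSTU's implementation:

```python
import torch

# Toy linear dynamical system: x_{t+1} = A x_t + B u_t, y_t = C x_t + D u_t.
# Dimensions mirror the example above; the matrices here are just random.
state_dim, input_dim, output_dim, seq_len = 20, 10, 5, 128
A = 0.9 * torch.eye(state_dim)                       # stable state transition (illustrative)
B = torch.randn(state_dim, input_dim) / input_dim**0.5
C = torch.randn(output_dim, state_dim) / state_dim**0.5
D = torch.randn(output_dim, input_dim) / input_dim**0.5

u = torch.randn(seq_len, input_dim)                  # input sequence
x = torch.zeros(state_dim)                           # initial state
outputs = []
for t in range(seq_len):
    outputs.append(C @ x + D @ u[t])                 # emit y_t from the current state
    x = A @ x + B @ u[t]                             # advance the state
y_seq = torch.stack(outputs)                         # [seq_len, output_dim], the target STU learns
print(y_seq.shape)
```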
| Parameter | Type | Default | Description |
|---|---|---|---|
| `n_embd` | int | 1536 | Embedding/hidden dimension |
| `n_layers` | int | 26 | Total number of layers |
| `n_heads` | int | 8 | Number of attention heads |
| `seq_len` | int | 8192 | Maximum sequence length |
| `window_size` | int/tuple | 1024 | Sliding window size for attention |
| `num_eigh` | int | 24 | Number of spectral filters for STU |
| `vocab_size` | int | 200064 | Vocabulary size |
| `use_hankel_L` | bool | False | Use Hankel-L (single-branch) formulation |
| `use_approx` | bool | True | Use approximation mode (~50x fewer STU parameters, recommended) |
| `use_flash_fft` | bool | True | Use Flash FFT for convolutions |
| `stu_enable_mlp_sandwich` | bool | False | Enable MLP sandwiching for STU |
| `torch_dtype` | dtype | bfloat16 | Model parameter dtype |
Small Model (125M parameters):

```python
config = FlashSTUConfig(
    n_embd=768,
    n_layers=12,
    n_heads=12,
    seq_len=2048,
    num_eigh=24,
)
```

Medium Model (350M parameters):

```python
config = FlashSTUConfig(
    n_embd=1024,
    n_layers=24,
    n_heads=16,
    seq_len=4096,
    num_eigh=32,
)
```

Large Model (1B+ parameters):

```python
config = FlashSTUConfig(
    n_embd=2048,
    n_layers=32,
    n_heads=32,
    seq_len=8192,
    num_eigh=48,
    use_gradient_checkpointing=True,  # Save memory
)
```

An example LLM pretraining script is provided in `example.py` for you to test out the repository.
If your compute cluster does not have internet access, you will need to pre-download the entire dataset before running the example training script.
To download the dataset, run:
```bash
cd training
python data.py
```

Note: The FineWeb-Edu 10B-token sample is a relatively large dataset. It can be swapped out for something smaller, e.g. TinyStories (476.6M tokens); see the sketch below.
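As a sketch of such a swap, a smaller corpus can be pulled from the HuggingFace Hub with the `datasets` library and then fed through whatever preprocessing `training/data.py` applies. The dataset id and column name below are assumptions based on the public TinyStories release:

```python
from datasets import load_dataset

# Pull a smaller corpus from the HuggingFace Hub as a stand-in for FineWeb-Edu.
# "roneneldan/TinyStories" is the public Hub id for TinyStories at the time of writing.
dataset = load_dataset("roneneldan/TinyStories", split="train")
print(dataset)              # inspect size and columns
print(dataset[0]["text"])   # one training story (column name assumed to be "text")
```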
To begin training, make sure you are in the training directory and run the following command in your terminal:
```bash
torchrun example.py
```

If you are in a compute cluster that uses Slurm and environment modules, you can submit a job using the following command:
```bash
sbatch job.slurm
```

Model configurations can be adjusted as needed in `config.json`. Be sure to adjust the Slurm job settings based on your cluster's constraints.
Note: PyTorch's `torch.compile` currently does not have great support for distributed wrapper modules like DDP or FSDP. If you encounter errors during training, try disabling `torch.compile`. For more information on `torch.compile`, see this informal manual.
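Since `torch.compile` is an explicit wrapper call, disabling it usually just means skipping that call before the model is handed to DDP or FSDP. The sketch below illustrates the pattern; the flag name is hypothetical, and example.py may wire this up differently:

```python
import torch
from flash_stu import FlashSTU, FlashSTUConfig

USE_COMPILE = False  # hypothetical flag: set to False if compile + DDP/FSDP misbehave

config = FlashSTUConfig(n_embd=512, n_layers=4, n_heads=8, seq_len=1024, vocab_size=50257)
model = FlashSTU(config).cuda()
if USE_COMPILE:
    model = torch.compile(model)  # wrap before handing the model to DDP/FSDP
# ... wrap `model` in DDP/FSDP and train as usual
```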
Contributions are welcome! Writing performant distributed code is always tricky, and we encourage contributors to:
- Submit pull requests
- Report issues
- Help improve the project overall
Apache 2.0 License
You can freely use, modify, and distribute the software, even in proprietary products, as long as you:
- Include proper attribution
- Include a copy of the license
- Mention any changes made
It also provides an express grant of patent rights from contributors.
See the LICENSE file for more details.
- Y. Isabel Liu, Windsor Nguyen, Yagiz Devre, Evan Dogariu, Anirudha Majumdar, Elad Hazan
- Kia Ghods, Hubert Strauss
Special thanks to (in no particular order):
- Elad Hazan and the authors of the Spectral State Space Models paper
- The Flash Attention team
- The Flash FFT team
- The PyTorch team
- Princeton Research Computing and Princeton Language and Intelligence, for supplying compute
- Andrej Karpathy, for his awesome NanoGPT repository
If you use this repository, or otherwise find our work valuable, please cite Flash STU:
```bibtex
@article{flashstu,
  title={Flash STU: Fast Spectral Transform Units},
  author={Y. Isabel Liu and Windsor Nguyen and Yagiz Devre and Evan Dogariu and Anirudha Majumdar and Elad Hazan},
  journal={arXiv preprint arXiv:2409.10489},
  year={2024},
  url={https://arxiv.org/abs/2409.10489}
}
```