Video diffusion models generate stunningly realistic content, but their computational demands, driven largely by the self-attention layers, are staggering. To address this, we present LiteAttention, a temporal sparse attention mechanism that directly targets the redundancy in attention computations across diffusion timesteps.
By identifying non-essential tiles early in the generation process and propagating these "skip decisions" forward, LiteAttention eliminates redundant computations without repeated profiling overheads. The result? Up to 54% attention sparsity on production-grade models like Wan2.1 and Wan2.2, with zero degradation in visual quality. This translates to a nearly 1.9x speedup in wall-clock time.
LiteAttention is actively developed to provide the fastest, most flexible sparse attention for diffusion models. Here is the recent evolution of the codebase:
- INT8 Quantization: Added support for INT8 quantization (`use_int8=True`) for Q (per-block) and K (per-block with channel-wise mean smoothing), significantly reducing memory usage and boosting performance.
- Fixes: Resolved sequence parallelism correctness issues for rectangular QK skip lists and fixed default modes for `torch.compile` support.
- Full Producer-Consumer Pipeline: Introduced q-pad and bi-directionality for enhanced execution efficiency and sequence handling.
- Fine-Grained Sequence Control: Added `must_do_list` and `must_skip_list` parameters. You can now explicitly define token ranges (e.g., prompt tokens vs padding) that must always be computed or that can always be skipped, bypassing the threshold logic entirely.
- Evolutionary Computation Skips (QK-Skip): The core algorithm that maintains a Skip-Mask, identifying non-essential tiles and completely bypassing the attention iteration (QK product, softmax, PV product) in later timesteps.
- Sequence Parallelism: Introduced `SeqParallelLiteAttention` for multi-GPU scale-out.
- Softmax LSE: Added the ability to return the softmax log-sum-exp (`return_softmax_lse=True`) for combining partial attention computations (e.g., separating text-to-video vs video-to-video attention).
Traditional dynamic sparse attention methods evaluate sparsity criteria at every single timestep. This incurs a massive 10-20% runtime overhead just to figure out what to compute.
LiteAttention introduces evolutionary computation skips by leveraging the temporal coherence of diffusion attention.
- Early Profiling: During the initial diffusion timesteps, we compute the full attention matrix and track the maximum log-attention score for each tile.
- The Skip-Mask: If a tile's score falls below a set `threshold`, it's marked as skippable.
- Forward Propagation: Once a tile is marked as skippable, the entire attention computation for that tile is bypassed for all subsequent timesteps.
This gives us the content adaptivity of dynamic sparsity without the overhead, acting like an efficient, static pre-computation.
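The three steps above can be sketched in plain Python. This is an illustrative model of the schedule, not the actual fused kernel, and `qk_skip_schedule` and its arguments are hypothetical names:

```python
def qk_skip_schedule(scores, profile_steps, threshold):
    """scores: per-timestep lists of the max log-attention score per tile.
    Returns, per timestep, a list of booleans (True = compute the tile)."""
    num_tiles = len(scores[0])
    skippable = [False] * num_tiles
    mask = []
    for t, step_scores in enumerate(scores):
        if t < profile_steps:
            # Early profiling: full attention is computed while tracking
            # which tiles fall below the log-space threshold.
            mask.append([True] * num_tiles)
            for i, s in enumerate(step_scores):
                if s < threshold:
                    skippable[i] = True
        else:
            # Forward propagation: skippable tiles are bypassed for all
            # subsequent timesteps, with no per-step re-profiling.
            mask.append([not sk for sk in skippable])
    return mask

# Tile 0 scores below threshold -3 during profiling, so it is skipped
# from timestep 2 onward; tile 1 keeps being computed.
mask = qk_skip_schedule([[-5.0, 1.0], [-6.0, 2.0], [0.0, 0.0]],
                        profile_steps=2, threshold=-3.0)
```

Once a tile's entry in the mask flips to skipped it never flips back, which is what makes the schedule effectively static after the profiling window.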
LiteAttention achieves state-of-the-art speeds while maintaining top-tier visual consistency metrics (evaluated via VBench).
| Model | AQ ↑ | BC ↑ | DD ↑ | IQ ↑ | SC ↑ | TF ↑ | TS ↑ | Sparsity ↑ | Runtime ↓ (Speedup) |
|---|---|---|---|---|---|---|---|---|---|
| Wan2.1-14B Base | 0.676 | 0.977 | 0.417 | 68.74 | 0.965 | 0.962 | 0.137 | 0% | 1707 sec (1.00x) |
| Wan2.1-14B + LiteAttn | 0.677 | 0.975 | 0.500 | 66.76 | 0.963 | 0.962 | 0.142 | 42% | 902 sec (1.89x) |
| Wan2.2-14B Base | 0.693 | 0.977 | 0.583 | 72.73 | 0.970 | 0.953 | 0.133 | 0% | 1473 sec (1.00x) |
| Wan2.2-14B + LiteAttn | 0.698 | 0.977 | 0.500 | 71.44 | 0.969 | 0.953 | 0.135 | 32% | 893 sec (1.65x) |
VBench Metrics: AQ (Aesthetic Quality), BC (Background Consistency), DD (Dynamic Degree), IQ (Imaging Quality), SC (Subject Consistency), TF (Temporal Flickering), TS (Temporal Style)
Ablation Study: Sparsity vs Runtime
| Sparsity | Self-Attention Runtime | Runtime Improvement |
|---|---|---|
| 0% | 695 sec | 0% (baseline) |
| 21% | 573 sec | 18% |
| 42% | 418 sec | 40% |
| 57% | 308 sec | 56% |
| 77% | 163 sec | 77% |
The near-linear scaling demonstrates the efficiency of the QK-Skip algorithm.
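To make the near-linearity concrete, the measured runtimes can be compared against the ideal model `runtime ≈ baseline × (1 − sparsity)`; the residual gap (the bookkeeping cost of the skip machinery) stays small and shrinks as sparsity grows:

```python
baseline = 695  # self-attention runtime at 0% sparsity, in seconds (from the table)
measured = {0.21: 573, 0.42: 418, 0.57: 308, 0.77: 163}  # sparsity -> seconds

for sparsity, actual in measured.items():
    ideal = baseline * (1 - sparsity)  # perfectly linear scaling
    print(f"{sparsity:.0%} sparsity: ideal {ideal:.0f}s, measured {actual}s")
```

At every operating point the measured runtime is within a few tens of seconds of the linear prediction.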
| Threshold | Generation Time |
|---|---|
| Baseline (no skip) | 23m 51s |
| Threshold -10 | 14m 19s |
| Threshold -3 | 11m 46s |
| Threshold 0 | 8m 31s |
Requirements: Hopper H100/H200 GPU, CUDA >= 12.8, C++20, PyTorch 2.2+, Linux.
LiteAttention requires `ninja` for fast compilation.
Note: Pre-built wheels for common environments will be added soon to simplify installation.
uv is a fast Rust-based Python package installer.
```bash
# Clone the repository
git clone https://github.com/moonmath-ai/LiteAttention.git
cd LiteAttention

# Create a virtual environment and activate it
uv venv
source .venv/bin/activate

# Install dependencies
uv pip install ninja torch packaging einops structlog tomli-w

# Build and install LiteAttention
uv pip install --no-build-isolation .
```

Alternatively, with plain pip:

```bash
# Ensure ninja is working properly
pip uninstall -y ninja && pip install ninja

# Install dependencies
pip install torch packaging einops structlog tomli-w

# Clone and build
git clone https://github.com/moonmath-ai/LiteAttention.git
cd LiteAttention
pip install --no-build-isolation .
```

LiteAttention is designed as a drop-in replacement for standard flash attention modules in DiT (Diffusion Transformer) models.
The complete initialization API for the core module is as follows:
```python
def LiteAttention(
    enable_skipping: bool = True,
    threshold: float | None = None,
    max_batch_size: int = 2,
    reverse_skip_list: bool = True,
    use_int8: bool = False
)
```

Parameters:

- `enable_skipping` (bool): Whether to enable skip list optimizations. Defaults to `True`. When `False`, performs standard Flash Attention.
- `max_batch_size` (int): Maximum batch size to pre-allocate memory for. Defaults to `2`. The actual batch size used during inference can be smaller than this value, but not larger.
- `reverse_skip_list` (bool): Whether to use the reversed skip list format (internal optimization). Defaults to `True`.
- `use_int8` (bool): Whether to use INT8 quantization for Q and K. Defaults to `False`. Enables per-block quantization for Q and channel-smoothed per-block quantization for K.
- `threshold` (float): Log-space threshold for skipping tiles. Controlled from the Registry; changing it here should be done only for testing.
Replace your standard attention call with a LiteAttention instance. Crucially, instantiate a separate LiteAttention object for each layer so they maintain independent skip states.
```python
from lite_attention import LiteAttention

class MyDiTBlock(nn.Module):
    def __init__(self, ...):
        super().__init__()
        # Enable skipping and INT8 quantization!
        self.lite_attention = LiteAttention(enable_skipping=True, use_int8=True)

    def forward(self, q, k, v, must_do_list=None):
        # ...
        # Standard input format: (batch, seq_len, heads, head_dim)
        x = self.lite_attention(q, k, v, must_do_list=must_do_list)
        return x
```

For parts of the sequence that should explicitly be computed or skipped, you can pass the `must_do_list` and `must_skip_list` parameters during the forward pass:

```python
output = self.lite_attention(query, key, value, must_do_list=must_do_list, must_skip_list=must_skip_list)
```

These lists define ranges of tokens. The format is a flat list of start and end indices:
```
[start_0, end_0, start_1, end_1, ...]
```

- `start_i`: Start index of the range (inclusive).
- `end_i`: End index of the range (exclusive).
- Important: Indices must be in strict ascending order: `start_i < end_i < start_(i+1) < end_(i+1)`.
Example: If you have a sequence of length 100, and you want to ensure tokens 2-11, 40-44, and 60-79 are always computed, and tokens 80-99 are always skipped:
```python
must_do_list = [2, 12, 40, 45, 60, 80]
must_skip_list = [80, 100]
```

> Important: Skip optimizations can be disabled at runtime via `self.lite_attention.enable_skip_optimization(enable=False)`.
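A small helper (hypothetical, not part of the library) makes the ordering rules explicit by validating a flat range list before passing it in:

```python
def validate_range_list(ranges, seq_len):
    """Check the flat [start_0, end_0, start_1, end_1, ...] format:
    inclusive starts, exclusive ends, strictly ascending overall."""
    if len(ranges) % 2 != 0:
        raise ValueError("range list must have an even number of entries")
    prev_end = -1
    for i in range(0, len(ranges), 2):
        start, end = ranges[i], ranges[i + 1]
        # Enforce start_i < end_i and end_i < start_(i+1), within bounds.
        if not (prev_end < start < end <= seq_len):
            raise ValueError(f"bad range ({start}, {end}) at position {i}")
        prev_end = end
    return True

validate_range_list([2, 12, 40, 45, 60, 80], seq_len=100)  # OK
validate_range_list([80, 100], seq_len=100)                # OK
```

Overlapping or touching ranges, such as `[2, 12, 12, 20]`, fail the strict-ascending check and raise `ValueError`.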
When using multi-GPU with sequence parallelism, use SeqParallelLiteAttention:
```python
def SeqParallelLiteAttention(
    num_nodes: int,
    enable_skipping: bool = True,
    max_batch_size: int = 2,
    use_int8: bool = False
)
```

Parameters:

- `num_nodes` (int): Number of GPUs/nodes across which the sequence is split.
- `enable_skipping` (bool): Whether to enable skip list optimizations. Defaults to `True`.
- `max_batch_size` (int): Maximum batch size to pre-allocate memory for. Defaults to `2`.
- `use_int8` (bool): Whether to use INT8 quantization for Q and K. Defaults to `False`.
Replace your standard attention call with a `SeqParallelLiteAttention` instance. You must pass `split_idx` indicating the K/V split being processed by the current node (`0` to `num_nodes - 1`), not the current GPU index.
```python
from lite_attention import SeqParallelLiteAttention

class MySeqParDiTBlock(nn.Module):
    def __init__(self, num_nodes=8, **kwargs):
        super().__init__()
        # Initialize with the number of nodes
        self.attn = SeqParallelLiteAttention(num_nodes=num_nodes, enable_skipping=True)

    def forward(self, query, key, value, split_idx, scale=None):
        # ...
        # Pass split_idx to indicate which split of K and V we are processing
        output = self.attn(query, key, value, split_idx, scale)
        return output
```

To unlock optimal generation/speed ratios, employ the new Registry to automatically calibrate thresholds for your specific model.
```python
from lite_attention import LiteAttentionRegistry

model = build_my_model(...)  # Initializes modules utilizing LiteAttention()

# Wrap the model. Modes: "calib", "load", "const"
registry = LiteAttentionRegistry.from_model(
    model,
    mode="calib",
    filename="optimized_thresholds.toml",
    calib_config={"target_error": 0.05, "metric": "L1"},
)

# Run inference
video = model.generate(prompt, ...)

# Save the calibrated thresholds (triggers only if mode="calib")
registry.save_if_calib()
```

To run normally using a fixed static threshold, just initialize with `mode="const"` and `threshold=-10.0`.
If you utilize LiteAttention in your research or deployment, please consider citing:
```bibtex
@misc{shmilovich2025liteattentiontemporalsparseattention,
  title={LiteAttention: A Temporal Sparse Attention for Diffusion Transformers},
  author={Dor Shmilovich and Tony Wu and Aviad Dahan and Yuval Domb},
  year={2025},
  eprint={2511.11062},
  archivePrefix={arXiv},
  primaryClass={cs.CV}
}
```

Built upon the incredible foundation of FlashAttention3 by Tri Dao.
License: LiteAttention inherits the BSD 3-Clause license from FA3 for original code; new LiteAttention additions are distributed under the MIT license. See LICENSE-BSD and LICENSE-MIT.