Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 91 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ MLX-VLM is a package for inference and fine-tuning of Vision Language Models (VL
- [Model-Specific Documentation](#model-specific-documentation)
- [Vision Feature Caching](#vision-feature-caching)
- [TurboQuant KV Cache](#turboquant-kv-cache)
- [TriAttention KV Cache Compression](#triattention-kv-cache-compression)
- [Fine-tuning](#fine-tuning)

## Model-Specific Documentation
Expand Down Expand Up @@ -601,6 +602,96 @@ Tested on gemma-4-31b-it at 128k context:

TurboQuant automatically quantizes `KVCache` layers (global attention). Models with `RotatingKVCache` (sliding window) or `ArraysCache` (MLA/absorbed keys) keep their native cache format for those layers since they are already memory-efficient.

## TriAttention KV Cache Compression

TriAttention ([arXiv:2604.04921](https://arxiv.org/abs/2604.04921)) compresses the KV cache by **pruning** low-importance tokens instead of quantizing them. It uses trigonometric series derived from pre-RoPE Q/K concentration to score key importance, retaining only the top-B most important tokens in the cache.

### How It Works

1. **Offline calibration** — Run a forward pass to compute per-head Q-center statistics (mean direction and magnitude in the frequency domain)
2. **Online scoring** — During generation, every 128 tokens, score each cached key using:
- **S_trig**: Trigonometric series based on Q-K distance preferences
- **S_norm**: Norm-based signal weighted by Q/K concentration (Mean Resultant Length)
3. **Pruning** — Retain the top-B scoring keys, evict the rest. Attention sinks and recent tokens are always protected.

### Quick Start

**Zero-config (online calibration)** — Q/K centers are computed from prefill tokens automatically:

```sh
mlx_vlm generate \
--model google/gemma-4-31b-it \
--triattention-budget 512 \
--prompt "Your prompt here..." \
--max-tokens 2048
```

```python
from mlx_vlm import generate

result = generate(
model, processor, prompt,
triattention_budget=512,
max_tokens=2048,
)
```

**With offline calibration** (optional, for repeated use with the same model):

```sh
# Calibrate once (~30s)
python -m mlx_vlm.triattention_calibrate \
--model google/gemma-4-31b-it \
--output gemma4_calib.safetensors

# Generate with pre-computed calibration
mlx_vlm generate \
--model google/gemma-4-31b-it \
--triattention-calib gemma4_calib.safetensors \
--triattention-budget 512 \
--prompt "Your prompt here..." \
--max-tokens 2048
```

> **Why online works:** The paper (Appendix H) shows Q/K centers are model-intrinsic properties that converge from as few as 50K tokens and are stable across domains (math, code, chat all yield MRL ~0.98). Even a short prompt provides enough signal.

### Performance

Benchmarked on Gemma4-31B-it with MM-NIAH (Multimodal Needle-in-a-Haystack), Apple M5 Ultra 512GB:

| Context | Imgs | Mode | Prefill t/s | Decode t/s | KV Cache | KV Saved | Peak Mem | Correct |
|---------|------|------|-------------|------------|----------|----------|----------|---------|
| ~1K | 1 | Baseline | 231 | 10.0 | 0.66 GB | — | 59.3 GB | Y |
| | | TA-512 | 236 | 10.0 | 0.64 GB | 3% | 59.3 GB | Y |
| ~7K | 4 | Baseline | 317 | 9.8 | 1.25 GB | — | 62.4 GB | Y |
| | | TA-512 | 328 | 10.0 | 0.82 GB | **34%** | 62.4 GB | Y |
| ~15K | 8 | Baseline | 313 | 9.7 | 1.72 GB | — | 63.6 GB | N |
| | | TA-512 | 331 | 10.0 | 0.82 GB | **52%** | 63.6 GB | N |
| ~30K | 15 | Baseline | 300 | 9.3 | 2.64 GB | — | 66.0 GB | Y |
| | | TA-512 | 333 | 10.0 | 0.82 GB | **69%** | 66.0 GB | N |
| ~60K | 26 | Baseline | 270 | 8.7 | 4.43 GB | — | 71.3 GB | Y |
| | | TA-512 | 337 | 10.0 | 0.82 GB | **81%** | 71.3 GB | N |

Key observations:
- **KV cache capped at budget** regardless of sequence length — 0.82 GB at 60K tokens vs 4.43 GB baseline (**81% reduction**)
- **Decode speed maintained** at ~10 t/s across all lengths (baseline degrades to 8.7 at 60K)
- **Prefill is faster** with TriAttention (lighter cache ops at long contexts)
- Best suited for **generative tasks** (essays, reasoning, code) where distance-based scoring is effective. For retrieval tasks (needle-in-a-haystack), accuracy can degrade at very long contexts.

### Compatibility

TriAttention works with any model using standard `nn.RoPE` or `ProportionalRoPE`:
- Gemma 3/4, LLaVA, Phi-4, Mistral 3/4, InternVL, Idefics, Molmo, Granite, Pixtral, and more

Models with non-standard RoPE (MRoPE, xDRoPE) are not supported:
- Qwen2.5-VL, Qwen3-VL, HunyuanVL, ERNIE 4.5

Sliding-window attention layers (e.g., Gemma4 local attention) are automatically skipped — only full-attention layers are compressed.

### TriAttention + TurboQuant

TriAttention (pruning) and TurboQuant (quantization) are currently mutually exclusive. Use TriAttention when you need to cap total KV tokens (long generation), and TurboQuant when you want to compress all tokens to lower precision.

# Fine-tuning

MLX-VLM supports fine-tuning models with LoRA and QLoRA.
Expand Down
48 changes: 48 additions & 0 deletions mlx_vlm/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,19 @@ def parse_arguments():
default=DEFAULT_THINKING_END_TOKEN,
help="Token that marks the end of a thinking block (default: %(default)s).",
)
parser.add_argument(
"--triattention-calib",
type=str,
default=None,
help="Path to TriAttention calibration file (.safetensors). Enables "
"TriAttention KV cache compression when provided.",
)
parser.add_argument(
"--triattention-budget",
type=int,
default=None,
help="Maximum KV tokens to retain after TriAttention compression.",
)

return parser.parse_args()

Expand Down Expand Up @@ -394,6 +407,8 @@ def generate_step(
sampler: Optional[Callable[[mx.array], mx.array]] = None,
logits_processors: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None,
prefill_step_size: Optional[int] = DEFAULT_PREFILL_STEP_SIZE,
triattention_calib: Optional[str] = None,
triattention_budget: Optional[int] = None,
**kwargs,
) -> Generator[Tuple[mx.array, mx.array], None, None]:
"""
Expand Down Expand Up @@ -470,6 +485,26 @@ def generate_step(
max_kv_size=max_kv_size,
)

# Apply TriAttention KV cache compression
_triattention_online_state = None
if triattention_calib is not None:
# Offline mode: use pre-computed calibration file
from .triattention import maybe_apply_triattention

maybe_apply_triattention(
prompt_cache,
model,
triattention_calib,
budget=triattention_budget,
)
elif triattention_budget is not None:
# Online mode: calibrate from prefill tokens (no calib file needed)
from .triattention import setup_online_triattention

_triattention_online_state = setup_online_triattention(
model, budget=triattention_budget
)

def _step(y, inputs_embeds=None):
nonlocal tokens, kwargs

Expand Down Expand Up @@ -550,6 +585,13 @@ def _step(y, inputs_embeds=None):

y, logprobs = _step(input_ids, inputs_embeds=inputs_embeds)

# Activate online TriAttention after prefill (hooks captured Q during prefill)
if _triattention_online_state is not None:
from .triattention import activate_online_triattention

activate_online_triattention(_triattention_online_state, prompt_cache)
_triattention_online_state = None

mx.async_eval(y)

n = 0
Expand Down Expand Up @@ -1594,6 +1636,9 @@ def main():
"vision_cache": vision_cache,
**kwargs,
}
if args.triattention_calib is not None:
stream_kwargs["triattention_calib"] = args.triattention_calib
stream_kwargs["triattention_budget"] = args.triattention_budget
if args.resize_shape is not None:
stream_kwargs["resize_shape"] = args.resize_shape
if args.prefill_step_size is not None:
Expand Down Expand Up @@ -1629,6 +1674,9 @@ def main():
"quantized_kv_start": args.quantized_kv_start,
**kwargs,
}
if args.triattention_calib is not None:
gen_kwargs["triattention_calib"] = args.triattention_calib
gen_kwargs["triattention_budget"] = args.triattention_budget
if args.resize_shape is not None:
gen_kwargs["resize_shape"] = args.resize_shape
if args.prefill_step_size is not None:
Expand Down
5 changes: 5 additions & 0 deletions mlx_vlm/models/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,11 @@
_BaseCache,
)

try:
from ..triattention import TriAttentionKVCache
except ImportError:
TriAttentionKVCache = None


def make_prompt_cache(
model: nn.Module,
Expand Down
Loading
Loading