henliveira/vlm-token-pruner

Token Pruning for Vision-Language Models

Drop the visual tokens you don't need. Keep the ones you do.

vlm-token-pruner is a small, hookable library that implements several training-free visual-token reduction methods for multimodal LLMs (LLaVA, Qwen-VL, InternVL, …). The goal is cheaper MLLM inference without retraining anything: you load your model, wrap it with a pruner, and the number of visual tokens entering the language model drops by 50–80%, at a small cost in downstream accuracy (at most about 2 points for FastV and VisionZip in the Benchmarks table, often under 1).

Overview

A visual encoder produces hundreds of patch tokens for a single image (LLaVA-1.5 emits 576 from a CLIP ViT-L/14 @ 336²; LLaVA-Next emits up to 2880). The language model then attends to all of them at every layer. For most queries, the vast majority of these tokens are redundant — either visually unimportant or never attended to in practice.
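The token counts above follow from the patch grid. A small sketch of the arithmetic; the LLaVA-Next figure is an upper bound that assumes one downscaled base image plus at most four 336×336 tiles, and it ignores the row-separator tokens and padding removal the real model applies.

```python
def vit_patch_tokens(image_size: int, patch_size: int) -> int:
    """Patch tokens a ViT emits for a square image (CLS token excluded)."""
    if image_size % patch_size != 0:
        raise ValueError("image side must be a multiple of the patch size")
    side = image_size // patch_size          # patches per row and per column
    return side * side


# LLaVA-1.5: CLIP ViT-L/14 at 336x336 gives a 24 x 24 grid.
llava15_tokens = vit_patch_tokens(336, 14)

# LLaVA-Next upper bound (assumption): base image + up to 4 tiles,
# each encoded like a LLaVA-1.5 image. Separator tokens are ignored.
llava_next_max = (1 + 4) * llava15_tokens

print(llava15_tokens, llava_next_max)  # 576 2880
```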

This library implements:

  • FastV-style pruning — drop tokens with the lowest attention weight from the language model side, at a chosen layer.
  • VisionZip-style keep-then-merge — pick top-K dominant visual tokens by encoder-side importance, then merge the rest by cosine similarity.
  • Random / uniform baselines — for sanity checks.
  • Bigram merging (à la ToMe) on the visual-token stream.
  • Spatial-grid pooling — naïve but a strong baseline.
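ToMe's merge step is bipartite soft matching: split the tokens into two alternating sets A and B, match each token in A to its most similar token in B, and merge the r best-matched pairs. A minimal single-image sketch, not the library's tome.py; the function name is hypothetical, it compares raw token features (ToMe uses attention keys), and it uses plain means where ToMe keeps size-weighted averages.

```python
import torch
import torch.nn.functional as F


def bipartite_merge(x: torch.Tensor, r: int) -> torch.Tensor:
    """Merge r tokens of x ([N, D]) by ToMe-style bipartite soft matching.

    Returns [N - r, D]: unmerged A tokens first, then the (possibly merged)
    B tokens, so the original spatial order is not preserved.
    """
    a, b = x[0::2], x[1::2]                   # alternate tokens into sets A and B
    r = min(r, a.shape[0])                    # at most |A| tokens can be merged away
    if r <= 0:
        return x
    # Cosine similarity between every A token and every B token: [|A|, |B|].
    sim = F.normalize(a, dim=-1) @ F.normalize(b, dim=-1).T
    best_sim, best_b = sim.max(dim=-1)        # best partner in B for each A token
    order = best_sim.argsort(descending=True)
    merged_a, kept_a = order[:r], order[r:]   # the r most similar A tokens are merged

    dst = best_b[merged_a]
    sums = b.clone().index_add_(0, dst, a[merged_a])  # add merged A tokens onto partners
    counts = torch.ones(b.shape[0], 1, dtype=x.dtype, device=x.device)
    counts.index_add_(0, dst, torch.ones(r, 1, dtype=x.dtype, device=x.device))
    b_merged = sums / counts                  # mean of each B token and its merged A tokens

    return torch.cat([a[kept_a.sort().values], b_merged], dim=0)
```

Applying it repeatedly (for example once per layer, as ToMe does) removes r tokens per call.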

All methods share a common Pruner interface and can be applied either via a one-shot transform on the visual tokens, or as a forward hook on the language model.
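A VisionZip-style keep-then-merge is the kind of step that runs as a one-shot transform on the encoder output. A minimal sketch, not the library's visionzip.py: the importance scores are assumed to be given (for example CLS-token attention from the vision encoder), targets for merging are picked at even spacing, similarity is computed on token features rather than encoder keys, and the function name is hypothetical.

```python
import torch
import torch.nn.functional as F


def keep_then_merge(tokens: torch.Tensor, scores: torch.Tensor,
                    n_dominant: int, n_contextual: int) -> torch.Tensor:
    """Reduce [N, D] visual tokens to n_dominant + n_contextual tokens.

    scores: [N] encoder-side importance per token (assumed precomputed).
    """
    n = tokens.shape[0]
    # 1) Keep the highest-scoring ("dominant") tokens unchanged, in original order.
    dom_idx = scores.topk(n_dominant).indices.sort().values
    rest_mask = torch.ones(n, dtype=torch.bool, device=tokens.device)
    rest_mask[dom_idx] = False
    rest = tokens[rest_mask]

    # 2) Pick evenly spaced merge targets among the remaining tokens.
    m = rest.shape[0]
    if not 0 < n_contextual <= m:
        raise ValueError("need 0 < n_contextual <= N - n_dominant")
    tgt_idx = torch.arange(n_contextual, device=tokens.device) * m // n_contextual
    src_mask = torch.ones(m, dtype=torch.bool, device=tokens.device)
    src_mask[tgt_idx] = False
    targets, sources = rest[tgt_idx], rest[src_mask]

    # 3) Merge every other token into its most cosine-similar target (plain mean).
    sim = F.normalize(sources, dim=-1) @ F.normalize(targets, dim=-1).T
    assign = sim.argmax(dim=-1)
    sums = targets.clone().index_add_(0, assign, sources)
    counts = torch.ones(n_contextual, 1, dtype=tokens.dtype, device=tokens.device)
    counts.index_add_(0, assign, torch.ones_like(sources[:, :1]))
    contextual = sums / counts

    return torch.cat([tokens[dom_idx], contextual], dim=0)
```

For example, `keep_then_merge(tokens, scores, 108, 36)` maps 576 LLaVA-1.5 tokens to 144 (25%); the 108/36 split is illustrative only.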

Architecture

  image ──▶ vision encoder ──▶ [visual tokens]
                                     │
                                     ▼
                            ┌─────────────────┐
                            │  Pruner.apply() │
                            └────────┬────────┘
                                     │  (fewer tokens)
                                     ▼
            text tokens ─────▶  language model  ─────▶ output

The pruner is just a Callable[[Tensor, Context], Tensor]. You can stack several (e.g. spatial pool then FastV) by composing them.
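Because a pruner is a plain callable, composing pruners is function chaining. A minimal sketch under stated assumptions: `Context` here is a hypothetical stand-in (this README does not document the real one's fields), and `uniform_keep` / `random_keep` are illustrative versions of the uniform and random baselines, not the library's code. Only pre-LM transforms are chained; a FastV stage, which needs attention weights from inside the language model, is not modeled.

```python
import functools
from dataclasses import dataclass
from typing import Callable

import torch


@dataclass
class Context:
    """Hypothetical stand-in for the library's Context (real fields unknown)."""
    seed: int = 0


Pruner = Callable[[torch.Tensor, Context], torch.Tensor]


def uniform_keep(tokens: torch.Tensor, ctx: Context, keep_ratio: float) -> torch.Tensor:
    """Uniform baseline: keep k evenly spaced tokens of a [N, D] tensor."""
    n = tokens.shape[0]
    k = max(1, int(n * keep_ratio))
    idx = torch.arange(k) * n // k            # strictly increasing because n / k >= 1
    return tokens[idx]


def random_keep(tokens: torch.Tensor, ctx: Context, keep_ratio: float) -> torch.Tensor:
    """Random baseline: keep k random tokens, preserving their original order."""
    n = tokens.shape[0]
    k = max(1, int(n * keep_ratio))
    gen = torch.Generator().manual_seed(ctx.seed)
    idx = torch.randperm(n, generator=gen)[:k].sort().values
    return tokens[idx]


def compose(*pruners: Pruner) -> Pruner:
    """Chain pruners left to right; each sees the previous one's output."""
    def composed(tokens: torch.Tensor, ctx: Context) -> torch.Tensor:
        for prune in pruners:
            tokens = prune(tokens, ctx)
        return tokens
    return composed


pipeline = compose(
    functools.partial(uniform_keep, keep_ratio=0.5),
    functools.partial(random_keep, keep_ratio=0.5),
)
out = pipeline(torch.randn(576, 4096), Context())
print(out.shape)  # torch.Size([144, 4096])
```

Two 50% stages compound to 25% of the original 576 tokens.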

Installation

git clone https://github.com/henliveira/vlm-token-pruner
cd vlm-token-pruner
pip install -e .

PyTorch ≥ 2.1 and transformers ≥ 4.40 are required.

Quick Start

from vlm_pruner import attach
from vlm_pruner.methods import FastV

import torch
from PIL import Image
from transformers import LlavaForConditionalGeneration, AutoProcessor

model = LlavaForConditionalGeneration.from_pretrained(
    "llava-hf/llava-1.5-7b-hf", torch_dtype=torch.bfloat16
).cuda()
proc = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")

# After layer 2, keep the top-50% most attended visual tokens.
pruner = FastV(layer=2, keep_ratio=0.5)
attach(model, pruner, family="llava")

# Then use the model as usual ("example.jpg" is a placeholder path).
img = Image.open("example.jpg")
# Cast pixel values to the model's dtype; integer tensors (input_ids) are left as-is.
inputs = proc(text="USER: <image>\nWhat is this?\nASSISTANT:",
              images=img, return_tensors="pt").to("cuda", torch.bfloat16)
out = model.generate(**inputs, max_new_tokens=64)
print(proc.decode(out[0], skip_special_tokens=True))
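The comment in the example above (keep the top-50% most attended visual tokens after layer 2) describes FastV's selection rule. A minimal standalone sketch of that rule, not the library's FastV class: it assumes you already have one layer's attention probabilities for a single sequence and that the visual tokens occupy one contiguous span; the function name and arguments are hypothetical.

```python
import torch


def fastv_keep_indices(attn: torch.Tensor, img_start: int, img_end: int,
                       keep_ratio: float) -> torch.Tensor:
    """Sequence positions to keep after a FastV-style pruning step.

    attn: one layer's attention probabilities, shape [heads, seq_len, seq_len]
          (query x key). Visual tokens occupy [img_start, img_end); all other
          positions (system prompt, text) are always kept.
    """
    seq_len = attn.shape[-1]
    # Average attention each key position receives, over heads and queries.
    received = attn.mean(dim=0).mean(dim=0)            # [seq_len]
    n_img = img_end - img_start
    k = max(1, int(n_img * keep_ratio))                # visual tokens to retain
    top = received[img_start:img_end].topk(k).indices + img_start
    keep = torch.ones(seq_len, dtype=torch.bool, device=attn.device)
    keep[img_start:img_end] = False                    # drop all visual tokens...
    keep[top] = True                                   # ...except the most attended
    return keep.nonzero().squeeze(-1)                  # sorted, original order kept


# Toy example: 6 positions, visual span [1, 5), position 3 receives all attention.
attn = torch.zeros(2, 6, 6)
attn[:, :, 3] = 1.0
print(fastv_keep_indices(attn, img_start=1, img_end=5, keep_ratio=0.25).tolist())  # [0, 3, 5]
```

The returned indices would select the hidden states passed to the layers after the pruning layer.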

A CLI is provided for quick benchmarking:

vlm-prune bench --model llava-hf/llava-1.5-7b-hf \
    --method fastv --keep-ratio 0.4 \
    --task vqav2 --limit 500

Benchmarks

LLaVA-1.5-7B, A100-40G, greedy decoding. All numbers from bench/ scripts.

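| Method | Visual tokens kept | VQAv2 (acc) | TextVQA (acc) | POPE (acc) | Latency, img + gen (× baseline) |
|---|---|---|---|---|---|
| Baseline | 100% | 78.4 | 58.3 | 85.4 | 1.00× |
| Spatial pool 2×2 | 25% | 75.1 | 54.0 | 83.7 | 0.55× |
| Random | 50% | 73.8 | 51.2 | 82.1 | 0.62× |
| FastV (L=2) | 50% | 78.0 | 57.4 | 85.1 | 0.63× |
| FastV (L=2) | 25% | 76.9 | 56.1 | 84.6 | 0.42× |
| VisionZip | 25% | 77.6 | 56.7 | 84.9 | 0.43× |
| FastV+ToMe | 25% | 77.4 | 56.4 | 84.8 | 0.41× |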

(These are my reproductions and may differ from published numbers by ±0.5 pt.)
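The 25% keep of the spatial-pool row follows from the grid: LLaVA-1.5's 24×24 map of 576 tokens, average-pooled with a 2×2 window, becomes 12×12 = 144 tokens. A minimal PyTorch sketch of that pooling, assuming the visual tokens are stored row-major as [N, D]; it is not the library's pool.py, and the function name is hypothetical.

```python
import torch
import torch.nn.functional as F


def spatial_pool(tokens: torch.Tensor, grid: int, k: int = 2) -> torch.Tensor:
    """Average-pool a row-major [N, D] token grid (N = grid * grid) with a k x k window."""
    n, d = tokens.shape
    if n != grid * grid:
        raise ValueError("tokens must form a square grid")
    x = tokens.T.reshape(1, d, grid, grid)            # [1, D, H, W] layout for avg_pool2d
    x = F.avg_pool2d(x, kernel_size=k, stride=k)      # [1, D, H/k, W/k]
    return x.flatten(2).squeeze(0).T                  # back to [N / k^2, D]


print(spatial_pool(torch.randn(576, 1024), grid=24).shape)  # torch.Size([144, 1024])
```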

Configuration

Methods can be constructed from a YAML file:

method:
  name: fastv
  layer: 2
  keep_ratio: 0.4
  ratio_schedule: constant    # or "linear", "step"

family: llava                  # llava | llava_next | qwen_vl | internvl
attach: hook                   # hook | transform
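Parsed (for example with PyYAML's `yaml.safe_load`), the file above becomes a nested dict. A hypothetical validator, not part of the library, that checks such a dict against the options listed in the YAML comments; treating `method.name`, `family` and `attach` as required and the other fields as optional is an assumption.

```python
# Options copied from the comments in the example YAML.
ALLOWED = {
    "ratio_schedule": {"constant", "linear", "step"},
    "family": {"llava", "llava_next", "qwen_vl", "internvl"},
    "attach": {"hook", "transform"},
}


def validate_config(cfg: dict) -> dict:
    """Return cfg unchanged, or raise ValueError/KeyError if it is malformed."""
    method = cfg["method"]                                   # required
    if not isinstance(method["name"], str):                  # required
        raise ValueError("method.name must be a string")
    if "layer" in method and not (isinstance(method["layer"], int) and method["layer"] >= 0):
        raise ValueError("method.layer must be a non-negative integer")
    if "keep_ratio" in method and not 0.0 < float(method["keep_ratio"]) <= 1.0:
        raise ValueError("method.keep_ratio must be in (0, 1]")
    checks = [("family", cfg["family"]), ("attach", cfg["attach"])]  # required
    if "ratio_schedule" in method:
        checks.append(("ratio_schedule", method["ratio_schedule"]))
    for key, value in checks:
        if value not in ALLOWED[key]:
            raise ValueError(f"{key}={value!r}; expected one of {sorted(ALLOWED[key])}")
    return cfg


# The example YAML above, as a dict.
config = {
    "method": {"name": "fastv", "layer": 2, "keep_ratio": 0.4,
               "ratio_schedule": "constant"},
    "family": "llava",
    "attach": "hook",
}
validate_config(config)
```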

Repository layout

vlm_pruner/
  methods/
    fastv.py
    visionzip.py
    tome.py
    pool.py
    random.py
  attach/
    llava.py
    qwen_vl.py
    internvl.py
  bench/
    vqav2.py
    textvqa.py
    pope.py
  cli.py
configs/
  llava_fastv.yaml
  llava_visionzip.yaml

Roadmap

  • FastV + VisionZip + ToMe + spatial pool
  • LLaVA-1.5 / LLaVA-Next / Qwen-VL attach points
  • VQAv2 / TextVQA / POPE benchmarks
  • Video-MLLM support (LLaVA-Video, Qwen2-VL-Video)
  • Token reduction during generation (KV-cache-aware)

Citation

If you use this in a paper, please cite:

@misc{vlmtokenpruner,
  author = {Hao Lin},
  title  = {vlm-token-pruner: training-free visual token reduction for MLLMs},
  year   = {2025},
  url    = {https://github.com/henliveira/vlm-token-pruner}
}

License

MIT — see LICENSE.

About

Training-free visual token pruning for multimodal LLMs — FastV, VisionZip, ToMe, spatial pooling.
