Skip to content

Commit ff8e187

Browse files
committed
[Model] Add ColPali late interaction model for multi-modal retrieval
Add support for ColPali (ColBERT-style late interaction on PaliGemma backbone) for multi-vector document retrieval and reranking. - Model implementation extending PaliGemmaForConditionalGeneration with a linear projection head for per-token embeddings - Custom ColPaliConfig extending PaliGemmaConfig with embedding projection fields (no trust_remote_code needed) - Registration in model registry and config registry - Multimodal processor handling for image+text pooling inputs - Weight loading with support for multiple checkpoint naming conventions (HF transformers, colpali-engine) - Comprehensive tests: token embedding, late interaction scoring, relevance ordering, and multimodal scoring - Added to test model registry and supported models documentation Target model: vidore/colpali-v1.3-hf Reference: https://arxiv.org/abs/2407.01449 Signed-off-by: Nikita Sukharev <kaonael@gmail.com>
1 parent 2a68464 commit ff8e187

File tree

8 files changed

+643
-0
lines changed

8 files changed

+643
-0
lines changed

docs/models/supported_models.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -827,6 +827,7 @@ The following table lists those that are tested in vLLM.
827827
| ------------ | ------ | ------ | ----------------- | -------------------- | ------------------------- |
828828
| `CLIPModel` | CLIP | T / I | `openai/clip-vit-base-patch32`, `openai/clip-vit-large-patch14`, etc. | | |
829829
| `ColModernVBertForRetrieval` | ColModernVBERT | T / I | `ModernVBERT/colmodernvbert-merged` | | |
830+
| `ColPaliForRetrieval` | ColPali | T / I | `vidore/colpali-v1.3-hf` | | |
830831
| `LlamaNemotronVLModel` | Llama Nemotron Embedding + SigLIP | T + I | `nvidia/llama-nemotron-embed-vl-1b-v2` | | |
831832
| `LlavaNextForConditionalGeneration`<sup>C</sup> | LLaVA-NeXT-based | T / I | `royokong/e5-v` | | ✅︎ |
832833
| `Phi3VForCausalLM`<sup>C</sup> | Phi-3-Vision-based | T + I | `TIGER-Lab/VLM2Vec-Full` | | ✅︎ |
Lines changed: 330 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,330 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
"""Tests for ColPali late interaction model for multi-modal retrieval.
4+
5+
ColPali is a multi-vector retrieval model based on PaliGemma backbone
6+
(SigLIP + Gemma) with ColBERT-style late interaction scoring (MaxSim).
7+
It produces per-token embeddings for both text and image inputs.
8+
"""
9+
10+
import base64
11+
from io import BytesIO
12+
13+
import pytest
14+
import torch
15+
from PIL import Image
16+
17+
from vllm.entrypoints.chat_utils import (
18+
ChatCompletionContentPartImageParam,
19+
ChatCompletionContentPartTextParam,
20+
)
21+
from vllm.entrypoints.pooling.score.utils import ScoreMultiModalParam
22+
23+
from ....conftest import VllmRunner
24+
25+
# Checkpoints exercised by every test in this file.
MODELS = [
    "vidore/colpali-v1.3-hf",
]

# Expected per-token embedding dimension for each checkpoint.
EMBED_DIMS = {
    "vidore/colpali-v1.3-hf": 128,
}

# Short text inputs used as retrieval queries.
TEXT_QUERIES = [
    "What is the capital of France?",
    "Describe the contents of the document.",
]

# Short text inputs used as retrieval documents.
TEXT_DOCUMENTS = [
    "The capital of France is Paris.",
    "This document contains important financial data.",
]

# Engine settings shared by all tests in this file.
DTYPE = "half"
GPU_MEMORY_UTILIZATION = 0.7
45+
46+
47+
def _make_base64_image(
    width: int = 64, height: int = 64, color: tuple[int, int, int] = (255, 0, 0)
) -> str:
    """Render a solid-color PNG of the given size as a base64 data URI."""
    buffer = BytesIO()
    Image.new("RGB", (width, height), color).save(buffer, format="PNG")
    encoded = base64.b64encode(buffer.getvalue()).decode()
    return f"data:image/png;base64,{encoded}"
56+
57+
58+
def _make_image_mm_param(
    image_uri: str,
    text: str | None = None,
) -> ScoreMultiModalParam:
    """Wrap an image URI (plus an optional text part) in a ScoreMultiModalParam."""
    image_part = ChatCompletionContentPartImageParam(
        type="image_url",
        image_url={"url": image_uri},
    )
    parts: list = [image_part]
    if text is not None:
        text_part = ChatCompletionContentPartTextParam(type="text", text=text)
        parts.append(text_part)
    return ScoreMultiModalParam(content=parts)
74+
75+
76+
def _make_text_mm_param(text: str) -> ScoreMultiModalParam:
    """Wrap a plain-text string in a ScoreMultiModalParam."""
    text_part = ChatCompletionContentPartTextParam(type="text", text=text)
    return ScoreMultiModalParam(content=[text_part])
81+
82+
83+
def _run_token_embed_test(
    vllm_runner: type[VllmRunner],
    model: str,
    *,
    dtype: str,
) -> None:
    """Check per-token embedding output: 2D shape and unit L2 norms."""
    engine_kwargs = dict(
        runner="pooling",
        dtype=dtype,
        max_model_len=4096,
        enforce_eager=True,
        gpu_memory_utilization=GPU_MEMORY_UTILIZATION,
    )
    with vllm_runner(model, **engine_kwargs) as vllm_model:
        outputs = vllm_model.token_embed([TEXT_QUERIES[0]])

    assert len(outputs) == 1
    embeddings = torch.tensor(outputs[0])
    # Late interaction produces one vector per token: [num_tokens, embed_dim].
    assert embeddings.dim() == 2
    assert embeddings.shape[1] == EMBED_DIMS[model]
    assert embeddings.shape[0] > 1

    # Each token embedding should be L2-normalized.
    token_norms = embeddings.norm(p=2, dim=-1)
    torch.testing.assert_close(
        token_norms,
        torch.ones_like(token_norms),
        rtol=1e-2,
        atol=1e-2,
    )
115+
116+
117+
def _run_late_interaction_test(
    vllm_runner: type[VllmRunner],
    model: str,
    *,
    dtype: str,
) -> None:
    """Check that the scoring path reproduces a manual MaxSim computation."""
    from vllm.entrypoints.pooling.score.utils import compute_maxsim_score

    query = TEXT_QUERIES[0]
    document = TEXT_DOCUMENTS[0]

    engine_kwargs = dict(
        runner="pooling",
        dtype=dtype,
        max_model_len=4096,
        enforce_eager=True,
        gpu_memory_utilization=GPU_MEMORY_UTILIZATION,
    )
    with vllm_runner(model, **engine_kwargs) as vllm_model:
        query_emb = torch.tensor(vllm_model.token_embed([query])[0])
        doc_emb = torch.tensor(vllm_model.token_embed([document])[0])
        # Manual MaxSim over the raw token embeddings is the reference value.
        expected = compute_maxsim_score(query_emb, doc_emb).item()
        scores = vllm_model.score(query, document)

    assert len(scores) == 1
    assert scores[0] == pytest.approx(expected, rel=0.01)
146+
147+
148+
def _run_relevance_test(
    vllm_runner: type[VllmRunner],
    model: str,
    *,
    dtype: str,
) -> None:
    """Check that on-topic documents outscore an off-topic one."""
    query = "What is machine learning?"
    documents = [
        "Machine learning is a subset of artificial intelligence.",
        "The weather forecast shows rain tomorrow.",
        "Deep learning uses neural networks for complex tasks.",
    ]

    engine_kwargs = dict(
        runner="pooling",
        dtype=dtype,
        max_model_len=4096,
        enforce_eager=True,
        gpu_memory_utilization=GPU_MEMORY_UTILIZATION,
    )
    with vllm_runner(model, **engine_kwargs) as vllm_model:
        scores = vllm_model.score(query, documents)

    assert len(scores) == 3
    assert scores[0] > scores[1], "ML doc should score higher than weather doc"
    assert scores[2] > scores[1], "DL doc should score higher than weather doc"
175+
176+
177+
@pytest.mark.parametrize("model", MODELS)
178+
@pytest.mark.parametrize("dtype", [DTYPE])
179+
def test_colpali_token_embed(
180+
vllm_runner,
181+
model: str,
182+
dtype: str,
183+
) -> None:
184+
_run_token_embed_test(vllm_runner, model, dtype=dtype)
185+
186+
187+
@pytest.mark.parametrize("model", MODELS)
188+
@pytest.mark.parametrize("dtype", [DTYPE])
189+
def test_colpali_late_interaction_scoring(
190+
vllm_runner,
191+
model: str,
192+
dtype: str,
193+
) -> None:
194+
_run_late_interaction_test(vllm_runner, model, dtype=dtype)
195+
196+
197+
@pytest.mark.parametrize("model", MODELS)
198+
@pytest.mark.parametrize("dtype", [DTYPE])
199+
def test_colpali_relevance_ordering(
200+
vllm_runner,
201+
model: str,
202+
dtype: str,
203+
) -> None:
204+
_run_relevance_test(vllm_runner, model, dtype=dtype)
205+
206+
207+
# ── Multimodal scoring tests ────────────────────────────────
208+
209+
210+
def _run_multimodal_text_query_image_docs_test(
    vllm_runner: type[VllmRunner],
    model: str,
    *,
    dtype: str,
) -> None:
    """Score a text query against image-only documents."""
    image_docs = [
        _make_image_mm_param(_make_base64_image(64, 64, color=(255, 0, 0))),
        _make_image_mm_param(_make_base64_image(64, 64, color=(0, 0, 255))),
    ]
    query = "Describe the red object"

    engine_kwargs = dict(
        runner="pooling",
        dtype=dtype,
        max_model_len=4096,
        enforce_eager=True,
        gpu_memory_utilization=GPU_MEMORY_UTILIZATION,
    )
    with vllm_runner(model, **engine_kwargs) as vllm_model:
        # Multimodal documents go through the LLM-level score API.
        scores = vllm_model.llm.score(query, image_docs)

    assert len(scores) == 2
    for result in scores:
        assert isinstance(result.outputs.score, float)
239+
240+
241+
def _run_multimodal_mixed_docs_test(
    vllm_runner: type[VllmRunner],
    model: str,
    *,
    dtype: str,
) -> None:
    """Score a text query against one text and one image document."""
    query = "What is the capital of France?"
    documents: list = [
        "The capital of France is Paris.",
        _make_image_mm_param(_make_base64_image(64, 64, color=(255, 0, 0))),
    ]

    engine_kwargs = dict(
        runner="pooling",
        dtype=dtype,
        max_model_len=4096,
        enforce_eager=True,
        gpu_memory_utilization=GPU_MEMORY_UTILIZATION,
    )
    with vllm_runner(model, **engine_kwargs) as vllm_model:
        scores = vllm_model.llm.score(query, documents)

    assert len(scores) == 2
    for result in scores:
        assert isinstance(result.outputs.score, float)
    # The on-topic text document should beat the unrelated image.
    assert scores[0].outputs.score > scores[1].outputs.score
271+
272+
273+
def _run_multimodal_image_query_text_docs_test(
    vllm_runner: type[VllmRunner],
    model: str,
    *,
    dtype: str,
) -> None:
    """Score an image(+text) query against plain-text documents."""
    image_query = _make_image_mm_param(
        _make_base64_image(64, 64, color=(255, 0, 0)),
        text="red color",
    )
    documents = [
        "A bright red sports car.",
        "The weather forecast shows rain tomorrow.",
    ]

    engine_kwargs = dict(
        runner="pooling",
        dtype=dtype,
        max_model_len=4096,
        enforce_eager=True,
        gpu_memory_utilization=GPU_MEMORY_UTILIZATION,
    )
    with vllm_runner(model, **engine_kwargs) as vllm_model:
        scores = vllm_model.llm.score(image_query, documents)

    assert len(scores) == 2
    for result in scores:
        assert isinstance(result.outputs.score, float)
301+
302+
303+
@pytest.mark.parametrize("model", MODELS)
304+
@pytest.mark.parametrize("dtype", [DTYPE])
305+
def test_colpali_multimodal_text_query_image_docs(
306+
vllm_runner,
307+
model: str,
308+
dtype: str,
309+
) -> None:
310+
_run_multimodal_text_query_image_docs_test(vllm_runner, model, dtype=dtype)
311+
312+
313+
@pytest.mark.parametrize("model", MODELS)
314+
@pytest.mark.parametrize("dtype", [DTYPE])
315+
def test_colpali_multimodal_mixed_docs(
316+
vllm_runner,
317+
model: str,
318+
dtype: str,
319+
) -> None:
320+
_run_multimodal_mixed_docs_test(vllm_runner, model, dtype=dtype)
321+
322+
323+
@pytest.mark.parametrize("model", MODELS)
324+
@pytest.mark.parametrize("dtype", [DTYPE])
325+
def test_colpali_multimodal_image_query_text_docs(
326+
vllm_runner,
327+
model: str,
328+
dtype: str,
329+
) -> None:
330+
_run_multimodal_image_query_text_docs_test(vllm_runner, model, dtype=dtype)

tests/models/registry.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -625,6 +625,7 @@ def check_available_online(
625625
"ColModernVBertForRetrieval": _HfExamplesInfo(
626626
"ModernVBERT/colmodernvbert-merged",
627627
),
628+
"ColPaliForRetrieval": _HfExamplesInfo("vidore/colpali-v1.3-hf"),
628629
"ColQwen3": _HfExamplesInfo(
629630
"TomoroAI/tomoro-colqwen3-embed-4b", trust_remote_code=True
630631
),

0 commit comments

Comments
 (0)