[Model] Add ColPali late interaction model for multi-modal retrieval#36818

Open
Kaonael wants to merge 1 commit intovllm-project:mainfrom
Kaonael:add-colpali-model

Conversation

@Kaonael Kaonael commented Mar 11, 2026

Purpose

Add support for ColPali (paper: https://arxiv.org/abs/2407.01449), a ColBERT-style late interaction model built on the PaliGemma backbone (SigLIP + Gemma) for multi-vector document retrieval and reranking.
Target model: vidore/colpali-v1.3-hf.

  • Model implementation extending PaliGemmaForConditionalGeneration with a linear projection head for per-token embeddings
  • Custom ColPaliConfig extending PaliGemmaConfig (no trust_remote_code needed)
  • Registration in model registry, config registry, and test model registry
  • Multimodal processor handling for image+text pooling inputs
  • Weight loading with support for multiple checkpoint naming conventions (HF transformers, colpali-engine)
  • Updated docs/models/supported_models.md
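
For context, "late interaction" means queries and documents are compared token-by-token rather than via a single pooled vector: each query token takes the maximum similarity over all document tokens, and these maxima are summed (ColBERT's MaxSim). A minimal numpy sketch of the scoring rule, not the vLLM implementation:

```python
import numpy as np

def maxsim_score(query_emb: np.ndarray, doc_emb: np.ndarray) -> float:
    """ColBERT-style late interaction score. Inputs are L2-normalized
    per-token embedding matrices of shape (num_tokens, dim); the dot
    products are therefore cosine similarities."""
    sim = query_emb @ doc_emb.T          # (num_query_tokens, num_doc_tokens)
    return float(sim.max(axis=1).sum())  # best match per query token, summed

# Toy example: orthonormal token vectors stand in for real embeddings.
q = np.eye(4)[:2]             # 2 query tokens
d_relevant = np.eye(4)[:3]    # covers both query directions
d_irrelevant = np.eye(4)[2:]  # orthogonal to every query token
assert maxsim_score(q, d_relevant) > maxsim_score(q, d_irrelevant)
```

This is why the pooling endpoints below return a matrix per input (e.g. 5x128 for a 5-token query) rather than a single vector.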

Test Plan

uv run vllm serve vidore/colpali-v1.3-hf --runner pooling
# Tokenize
curl -s -X POST http://0.0.0.0:8000/tokenize \
  -H "Content-Type: application/json" \
  -d '{"prompt": "What is machine learning?"}'
# Detokenize
curl -s -X POST http://0.0.0.0:8000/detokenize \
  -H "Content-Type: application/json" \
  -d '{"tokens": [1841, 603, 6479, 6044, 235336]}'
# Pooling: single text
curl -s -X POST http://0.0.0.0:8000/pooling \
  -H "Content-Type: application/json" \
  -d '{"model":"vidore/colpali-v1.3-hf","input":"What is artificial intelligence?"}'
# Pooling: batch texts
curl -s -X POST http://0.0.0.0:8000/pooling \
  -H "Content-Type: application/json" \
  -d '{"model":"vidore/colpali-v1.3-hf","input":["query one","query two","query three"]}'
# Pooling: image (chat format with base64)
curl -s -X POST http://0.0.0.0:8000/pooling \
  -H "Content-Type: application/json" \
  -d '{"model":"vidore/colpali-v1.3-hf","messages":[{"role":"user","content":[{"type":"image_url","image_url":{"url":"data:image/png;base64,<BASE64_STRING>"}}]}]}'
# Score: text vs text
curl -s -X POST http://0.0.0.0:8000/score \
  -H "Content-Type: application/json" \
  -d '{"model":"vidore/colpali-v1.3-hf","text_1":"What is machine learning?","text_2":["Machine learning is a subset of AI.","The weather is nice today."]}'
# Score: text vs image
curl -s -X POST http://0.0.0.0:8000/score \
  -H "Content-Type: application/json" \
  -d '{"model":"vidore/colpali-v1.3-hf","text_1":"a red square image","text_2":[{"content":[{"type":"image_url","image_url":{"url":"data:image/png;base64,<BASE64_STRING>"}}]}]}'
# Rerank
curl -s -X POST http://0.0.0.0:8000/rerank \
  -H "Content-Type: application/json" \
  -d '{"model":"vidore/colpali-v1.3-hf","query":"What is machine learning?","documents":["Machine learning is a branch of artificial intelligence.","The weather today is sunny and warm.","Deep learning uses neural networks with many layers.","Python is a popular programming language."]}'

Test Result

| Endpoint | Status | Result |
|---|---|---|
| /tokenize | OK | [1841, 603, 6479, 6044, 235336] |
| /detokenize | OK | "What is machine learning?" |
| /pooling (text) | OK | 1 item, 5x128 |
| /pooling (batch) | OK | 3 items, 2x128 each |
| /pooling (image) | OK | 1 item, 1026x128 |
| /score (text vs text) | OK | ML=2.28, weather=0.85 |
| /score (text vs image) | OK | score=1.69 |
| /rerank | OK | ML(2.49) > DL(1.32) > Python(0.91) > weather(0.82) |

Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing a test command.
  • The test results, such as pasting a before/after results comparison or e2e results.
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

@github-actions

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small, essential subset of tests to quickly catch errors.

You can ask your reviewers to trigger select CI tests on top of fastcheck CI.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

If you have any questions, please reach out to us on Slack at https://slack.vllm.ai.

🚀

@mergify

mergify bot commented Mar 11, 2026

Documentation preview: https://vllm--36818.org.readthedocs.build/en/36818/

@mergify mergify bot added labels: documentation (Improvements or additions to documentation), multi-modality (Related to multi-modality, #4194), new-model (Requests to new models) on Mar 11, 2026
@Kaonael Kaonael force-pushed the add-colpali-model branch from d43778c to 98c7aa2 on March 11, 2026 at 20:14
Add support for ColPali (ColBERT-style late interaction on PaliGemma
backbone) for multi-vector document retrieval and reranking.

- Model implementation extending PaliGemmaForConditionalGeneration
  with a linear projection head for per-token embeddings
- Custom ColPaliConfig extending PaliGemmaConfig with embedding
  projection fields (no trust_remote_code needed)
- Registration in model registry and config registry
- Multimodal processor handling for image+text pooling inputs
- Weight loading with support for multiple checkpoint naming
  conventions (HF transformers, colpali-engine)
- Comprehensive tests: token embedding, late interaction scoring,
  relevance ordering, and multimodal scoring
- Added to test model registry and supported models documentation

Target model: vidore/colpali-v1.3-hf
Reference: https://arxiv.org/abs/2407.01449

Signed-off-by: Nikita Sukharev <kaonael@gmail.com>
@Kaonael Kaonael force-pushed the add-colpali-model branch from 98c7aa2 to ff8e187 on March 11, 2026 at 20:18
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request introduces support for the ColPali late interaction model, enabling multi-vector document retrieval and reranking. The changes include adding the model implementation, configuration, and necessary registry updates, along with comprehensive test cases for token embedding, late interaction scoring, and multimodal scenarios. While the overall implementation is robust, there are several areas identified for improvement related to test precision, type safety, and potential runtime issues in the model's forward pass and weight loading. Addressing these concerns will enhance the model's correctness, maintainability, and efficiency.

inputs_embeds=inputs_embeds,
**kwargs,
)

Contributor

critical

The type: ignore comment here masks a potential type mismatch. If super().forward can indeed return something other than a torch.Tensor, the subsequent operations (self.custom_text_proj(hidden_states) and torch.nn.functional.normalize) will fail with an AttributeError or TypeError. It's critical to ensure hidden_states is always a tensor or to handle the non-tensor case explicitly and robustly.

Consider refactoring super().forward to always return a torch.Tensor or adding more specific error handling for non-tensor returns.

Comment on lines +195 to +199

if self.custom_text_proj is not None:
proj_dtype = self.custom_text_proj.weight.dtype
if hidden_states.dtype != proj_dtype:
hidden_states = hidden_states.to(proj_dtype)
Contributor

critical

The custom_text_proj can be None if embed_dim is not found in the config during __init__. If forward is called before load_weights has successfully inferred embed_dim and initialized custom_text_proj, this block will be skipped, and the projection will not occur. This could lead to incorrect outputs or unexpected behavior if the projection is a mandatory step for the model's functionality.

Ensure that custom_text_proj is always initialized before forward is called, or raise an error if it's None when it's expected to be present.
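
A minimal sketch of the suggested fix; the names mirror the PR's attributes, but this is not the actual vLLM code:

```python
def project_tokens(hidden_states, custom_text_proj):
    """Fail loudly when the projection head is missing instead of
    silently skipping a step the model requires."""
    if custom_text_proj is None:
        raise RuntimeError(
            "custom_text_proj is not initialized; ensure embed_dim is in the "
            "config or load_weights() has run before calling forward()"
        )
    return custom_text_proj(hidden_states)

# With a trivial stand-in projection, the happy path still works:
assert project_tokens(3, lambda h: h * 2) == 6
```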

Comment on lines +109 to +113
torch.testing.assert_close(
norms,
torch.ones_like(norms),
rtol=1e-2,
atol=1e-2,
Contributor

high

The rtol and atol values of 1e-2 for torch.testing.assert_close might be too lenient for verifying L2 normalization. This could potentially mask subtle precision issues in the token embeddings. Consider tightening these tolerances to ensure higher numerical accuracy.

For example, a tolerance of 1e-4 or 1e-5 might be more appropriate depending on the expected precision.

Suggested change (diff):

  torch.testing.assert_close(
      norms,
      torch.ones_like(norms),
-     rtol=1e-2,
-     atol=1e-2,
+     rtol=1e-4,
+     atol=1e-4,
  )

vllm_scores = vllm_model.score(TEXT_QUERIES[0], TEXT_DOCUMENTS[0])

assert len(vllm_scores) == 1
assert vllm_scores[0] == pytest.approx(manual_score, rel=0.01)
Contributor

high

A relative tolerance of 0.01 (1%) for pytest.approx in the late interaction scoring test is quite high. This could allow for significant discrepancies between manual and vLLM computed scores to pass unnoticed. It's recommended to use a stricter tolerance to ensure the scoring mechanism is highly accurate.

Consider reducing the rel value to 1e-3 or 1e-4.

Suggested change (diff):

- assert vllm_scores[0] == pytest.approx(manual_score, rel=0.01)
+ assert vllm_scores[0] == pytest.approx(manual_score, rel=0.001)

for s in scores:
assert isinstance(s.outputs.score, float)
# Text document about France should score higher than a random image
assert scores[0].outputs.score > scores[1].outputs.score
Contributor

high

The assertion assert scores[0].outputs.score > scores[1].outputs.score provides a weak check for relevance ordering. While it confirms one score is higher, it doesn't quantify the expected difference. For a more robust test, especially in retrieval tasks, it's often beneficial to assert a minimum significant difference or a specific ordering with a margin. This helps catch regressions where the relevant document still scores higher but by a negligible amount.

Consider adding a check for a minimum score difference, e.g., scores[0].outputs.score - scores[1].outputs.score > some_threshold.
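
The strengthened assertion could look like the sketch below; the 0.1 margin is an illustrative threshold, not a value from the PR:

```python
def assert_relevance_margin(scores, margin=0.1):
    """Require the relevant document to beat the distractor by a margin,
    not just by any epsilon, so near-tie regressions are caught."""
    gap = scores[0] - scores[1]
    assert gap > margin, f"relevance gap {gap:.3f} below margin {margin}"

# Using the scores reported in the PR's rerank test (ML=2.49, weather=0.82):
assert_relevance_margin([2.49, 0.82])
```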

"""

def get_hf_config(self):
return self.ctx.get_hf_config()
Contributor

high

Bypassing the strict type check for hf_config by directly returning self.ctx.get_hf_config() can introduce fragility. If self.ctx.get_hf_config() returns an object that is not compatible with the expected PaliGemmaConfig interface, it could lead to runtime errors later. While the comment explains the reason, it's generally safer to ensure type compatibility or explicitly handle potential type mismatches.

Consider adding a runtime check or a more specific type annotation if the returned config is indeed a custom type that behaves like PaliGemmaConfig.

Comment on lines +54 to +55
# Force standard PaliGemmaProcessor even when trust_remote_code=True.
return self.ctx.get_hf_processor(PaliGemmaProcessor, **kwargs)
Contributor

high

Similar to the get_hf_config method, directly calling self.ctx.get_hf_processor without a strict type check for the returned processor can lead to issues. If the ctx provides a processor that is not a PaliGemmaProcessor or a compatible subclass, subsequent operations might fail. Ensuring type safety here is crucial for the model's stability.

Consider adding a runtime assertion or a more explicit cast if the ctx is guaranteed to return a compatible processor.

mm_kwargs: Mapping[str, object],
tok_kwargs: Mapping[str, object],
) -> BatchFeature:
if mm_data and not prompt:
Contributor

high

Providing a default prompt "Describe the image." when the input prompt is empty and multimodal data is present could lead to unintended model behavior or bias. If an empty prompt has a specific semantic meaning in certain contexts, this default could alter the model's interpretation. While it addresses the PaliGemmaProcessor requirement, it's important to consider if this default is always appropriate.

Consider if there's a way to make this default prompt configurable or to provide a more context-aware default if an empty prompt is truly ambiguous.

Comment on lines +197 to +198
proj_dtype = self.custom_text_proj.weight.dtype
if hidden_states.dtype != proj_dtype:
Contributor

high

Performing hidden_states.to(proj_dtype) inside the forward method on every call can introduce an unnecessary performance overhead if the dtype frequently differs. It's more efficient to ensure that hidden_states already has the correct dtype before entering the forward pass, or to perform this conversion only when strictly necessary (e.g., once during initialization or when the dtype changes).

Consider ensuring hidden_states is already in proj_dtype earlier in the pipeline or optimizing this conversion.


for name, weight in proj_weights:
if self.embed_dim is None and "weight" in name:
self.embed_dim = weight.shape[0]
Contributor

high

The logic has_bias = any("bias" in n for n, _ in proj_weights) assumes that if any bias weight is present in the proj_weights list, the nn.Linear layer should be initialized with bias=True. This might not be robust if different ColPali variants have varying bias configurations or if a bias weight is present but not intended for the custom_text_proj layer. It's safer to explicitly check for the bias weight associated with custom_text_proj or rely on a configuration parameter.

Consider checking for a specific bias weight name (e.g., "custom_text_proj.bias") or using a config flag to determine has_bias.
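
A sketch of the name-specific check the reviewer suggests; `proj_weights` is assumed to be a list of (name, tensor) pairs as seen during weight loading, and the exact parameter name is hypothetical:

```python
def proj_has_bias(proj_weights):
    """Decide bias=True for the projection layer only when the checkpoint
    contains the projection head's own bias parameter, rather than any
    parameter whose name merely contains "bias"."""
    return any(name.endswith("custom_text_proj.bias") for name, _ in proj_weights)

# An unrelated bias elsewhere in the checkpoint no longer flips the flag:
assert proj_has_bias([("language_model.layers.0.bias", None)]) is False
assert proj_has_bias([("custom_text_proj.bias", None)]) is True
```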

tok_kwargs: Mapping[str, object],
) -> BatchFeature:
if mm_data and not prompt:
prompt = "Describe the image."
Member

This should be the responsibility of the user, I think there is no need for us to provide a default prompt except through chat template.

hidden_states = self.custom_text_proj(hidden_states)

# L2 normalize
return torch.nn.functional.normalize(hidden_states, p=2, dim=-1)
Member

Normalization should be done by pooler
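
For reference, the normalization being moved is just a per-token L2 rescale; here is what `torch.nn.functional.normalize(x, p=2, dim=-1)` computes, shown in numpy as a dependency-free sketch:

```python
import numpy as np

def l2_normalize(x: np.ndarray, eps: float = 1e-12) -> np.ndarray:
    """Divide each token vector by its L2 norm (with a small eps guard),
    so downstream MaxSim dot products become cosine similarities."""
    norms = np.linalg.norm(x, axis=-1, keepdims=True)
    return x / np.maximum(norms, eps)

tokens = np.array([[3.0, 4.0], [0.0, 2.0]])
assert np.allclose(np.linalg.norm(l2_normalize(tokens), axis=-1), 1.0)
```

Doing this in the pooler keeps the model's forward() a pure projection and lets the pooling layer own all output post-processing.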
