[Model] Add ColPali late interaction model for multi-modal retrieval #36818

Kaonael wants to merge 1 commit into vllm-project:main
Conversation
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default; only a small subset of checks runs automatically. You can ask your reviewers to trigger select CI tests on top of those. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging. To run CI, PR reviewers can either add the `ready` label to the PR or enable auto-merge.

If you have any questions, please reach out to us on Slack at https://slack.vllm.ai. 🚀
Documentation preview: https://vllm--36818.org.readthedocs.build/en/36818/
Force-pushed d43778c to 98c7aa2
Add support for ColPali (ColBERT-style late interaction on a PaliGemma backbone) for multi-vector document retrieval and reranking.

- Model implementation extending PaliGemmaForConditionalGeneration with a linear projection head for per-token embeddings
- Custom ColPaliConfig extending PaliGemmaConfig with embedding projection fields (no trust_remote_code needed)
- Registration in the model registry and config registry
- Multimodal processor handling for image+text pooling inputs
- Weight loading with support for multiple checkpoint naming conventions (HF transformers, colpali-engine)
- Comprehensive tests: token embedding, late interaction scoring, relevance ordering, and multimodal scoring
- Added to the test model registry and supported models documentation

Target model: vidore/colpali-v1.3-hf
Reference: https://arxiv.org/abs/2407.01449

Signed-off-by: Nikita Sukharev <kaonael@gmail.com>
Force-pushed 98c7aa2 to ff8e187
Code Review
This pull request introduces support for the ColPali late interaction model, enabling multi-vector document retrieval and reranking. The changes include adding the model implementation, configuration, and necessary registry updates, along with comprehensive test cases for token embedding, late interaction scoring, and multimodal scenarios. While the overall implementation is robust, there are several areas identified for improvement related to test precision, type safety, and potential runtime issues in the model's forward pass and weight loading. Addressing these concerns will enhance the model's correctness, maintainability, and efficiency.
```python
        inputs_embeds=inputs_embeds,
        **kwargs,
    )
```
The type: ignore comment here masks a potential type mismatch. If super().forward can indeed return something other than a torch.Tensor, the subsequent operations (self.custom_text_proj(hidden_states) and torch.nn.functional.normalize) will fail with an AttributeError or TypeError. It's critical to ensure hidden_states is always a tensor or to handle the non-tensor case explicitly and robustly.
Consider refactoring super().forward to always return a torch.Tensor or adding more specific error handling for non-tensor returns.
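A minimal sketch of the explicit-handling alternative (the helper name and signature are hypothetical, not part of the PR):

```python
import torch


def project_hidden_states(hidden_states, proj: torch.nn.Module) -> torch.Tensor:
    # Instead of a `type: ignore`, fail fast with a clear message when
    # the backbone returns something other than a tensor.
    if not isinstance(hidden_states, torch.Tensor):
        raise TypeError(
            "expected torch.Tensor from super().forward, got "
            f"{type(hidden_states).__name__}"
        )
    return proj(hidden_states)
```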
```python
if self.custom_text_proj is not None:
    proj_dtype = self.custom_text_proj.weight.dtype
    if hidden_states.dtype != proj_dtype:
        hidden_states = hidden_states.to(proj_dtype)
```
The custom_text_proj can be None if embed_dim is not found in the config during __init__. If forward is called before load_weights has successfully inferred embed_dim and initialized custom_text_proj, this block will be skipped, and the projection will not occur. This could lead to incorrect outputs or unexpected behavior if the projection is a mandatory step for the model's functionality.
Ensure that custom_text_proj is always initialized before forward is called, or raise an error if it's None when it's expected to be present.
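One hedged way to enforce that invariant (the attribute names mirror the diff, but this wrapper class itself is a hypothetical illustration):

```python
import torch
from torch import nn


class ProjectionGuard(nn.Module):
    """Raise instead of silently skipping the mandatory projection."""

    def __init__(self, hidden_size: int, embed_dim=None):
        super().__init__()
        self.custom_text_proj = (
            nn.Linear(hidden_size, embed_dim) if embed_dim is not None else None
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        if self.custom_text_proj is None:
            # embed_dim was never inferred: load_weights has not run yet.
            raise RuntimeError(
                "custom_text_proj is not initialized; run load_weights "
                "before forward"
            )
        return self.custom_text_proj(hidden_states)
```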
```python
torch.testing.assert_close(
    norms,
    torch.ones_like(norms),
    rtol=1e-2,
    atol=1e-2,
```
The rtol and atol values of 1e-2 for torch.testing.assert_close might be too lenient for verifying L2 normalization. This could potentially mask subtle precision issues in the token embeddings. Consider tightening these tolerances to ensure higher numerical accuracy.
For example, a tolerance of 1e-4 or 1e-5 might be more appropriate depending on the expected precision.
```diff
 torch.testing.assert_close(
     norms,
     torch.ones_like(norms),
-    rtol=1e-2,
-    atol=1e-2,
+    rtol=1e-4,
+    atol=1e-4,
 )
```
```python
vllm_scores = vllm_model.score(TEXT_QUERIES[0], TEXT_DOCUMENTS[0])

assert len(vllm_scores) == 1
assert vllm_scores[0] == pytest.approx(manual_score, rel=0.01)
```
A relative tolerance of 0.01 (1%) for pytest.approx in the late interaction scoring test is quite high. This could allow for significant discrepancies between manual and vLLM computed scores to pass unnoticed. It's recommended to use a stricter tolerance to ensure the scoring mechanism is highly accurate.
Consider reducing the rel value to 1e-3 or 1e-4.
```diff
-assert vllm_scores[0] == pytest.approx(manual_score, rel=0.01)
+assert vllm_scores[0] == pytest.approx(manual_score, rel=0.001)
```
```python
for s in scores:
    assert isinstance(s.outputs.score, float)
# Text document about France should score higher than a random image
assert scores[0].outputs.score > scores[1].outputs.score
```
The assertion assert scores[0].outputs.score > scores[1].outputs.score provides a weak check for relevance ordering. While it confirms one score is higher, it doesn't quantify the expected difference. For a more robust test, especially in retrieval tasks, it's often beneficial to assert a minimum significant difference or a specific ordering with a margin. This helps catch regressions where the relevant document still scores higher but by a negligible amount.
Consider adding a check for a minimum score difference, e.g., scores[0].outputs.score - scores[1].outputs.score > some_threshold.
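A sketch of such a margin check (the helper name and the 0.05 threshold are illustrative assumptions, not values from the PR):

```python
def assert_relevance_margin(scores, min_margin: float = 0.05) -> None:
    # The relevant document must not only win, it must win by a margin
    # large enough that a regression cannot hide behind float noise.
    relevant, distractor = scores[0], scores[1]
    assert relevant > distractor, "relevance ordering violated"
    gap = relevant - distractor
    assert gap > min_margin, f"score gap {gap:.4f} below threshold {min_margin}"
```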
| """ | ||
|
|
||
| def get_hf_config(self): | ||
| return self.ctx.get_hf_config() |
Bypassing the strict type check for hf_config by directly returning self.ctx.get_hf_config() can introduce fragility. If self.ctx.get_hf_config() returns an object that is not compatible with the expected PaliGemmaConfig interface, it could lead to runtime errors later. While the comment explains the reason, it's generally safer to ensure type compatibility or explicitly handle potential type mismatches.
Consider adding a runtime check or a more specific type annotation if the returned config is indeed a custom type that behaves like PaliGemmaConfig.
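A duck-typed runtime guard along those lines might look as follows (the attribute names here are an assumption about what the model actually reads from the config):

```python
def validate_hf_config(config, required_attrs=("text_config", "vision_config")):
    # Verify the config behaves like PaliGemmaConfig where it matters,
    # instead of bypassing the strict type check entirely.
    missing = [attr for attr in required_attrs if not hasattr(config, attr)]
    if missing:
        raise TypeError(f"hf_config is missing expected attributes: {missing}")
    return config
```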
```python
# Force standard PaliGemmaProcessor even when trust_remote_code=True.
return self.ctx.get_hf_processor(PaliGemmaProcessor, **kwargs)
```
Similar to the get_hf_config method, directly calling self.ctx.get_hf_processor without a strict type check for the returned processor can lead to issues. If the ctx provides a processor that is not a PaliGemmaProcessor or a compatible subclass, subsequent operations might fail. Ensuring type safety here is crucial for the model's stability.
Consider adding a runtime assertion or a more explicit cast if the ctx is guaranteed to return a compatible processor.
```python
    mm_kwargs: Mapping[str, object],
    tok_kwargs: Mapping[str, object],
) -> BatchFeature:
    if mm_data and not prompt:
```
Providing a default prompt "Describe the image." when the input prompt is empty and multimodal data is present could lead to unintended model behavior or bias. If an empty prompt has a specific semantic meaning in certain contexts, this default could alter the model's interpretation. While it addresses the PaliGemmaProcessor requirement, it's important to consider if this default is always appropriate.
Consider if there's a way to make this default prompt configurable or to provide a more context-aware default if an empty prompt is truly ambiguous.
```python
proj_dtype = self.custom_text_proj.weight.dtype
if hidden_states.dtype != proj_dtype:
```
Performing hidden_states.to(proj_dtype) inside the forward method on every call can introduce an unnecessary performance overhead if the dtype frequently differs. It's more efficient to ensure that hidden_states already has the correct dtype before entering the forward pass, or to perform this conversion only when strictly necessary (e.g., once during initialization or when the dtype changes).
Consider ensuring hidden_states is already in proj_dtype earlier in the pipeline or optimizing this conversion.
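One sketch of the load-time alternative: cast the projection layer once to the model's compute dtype, so the forward pass never needs a per-call conversion (the helper is hypothetical):

```python
import torch
from torch import nn


def align_projection_dtype(proj: nn.Linear, model_dtype: torch.dtype) -> nn.Linear:
    # One-time cast at load time; afterwards, hidden_states produced in
    # model_dtype flow through the projection without any conversion.
    if proj.weight.dtype != model_dtype:
        proj = proj.to(model_dtype)
    return proj
```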
```python
for name, weight in proj_weights:
    if self.embed_dim is None and "weight" in name:
        self.embed_dim = weight.shape[0]
```
The logic has_bias = any("bias" in n for n, _ in proj_weights) assumes that if any bias weight is present in the proj_weights list, the nn.Linear layer should be initialized with bias=True. This might not be robust if different ColPali variants have varying bias configurations or if a bias weight is present but not intended for the custom_text_proj layer. It's safer to explicitly check for the bias weight associated with custom_text_proj or rely on a configuration parameter.
Consider checking for a specific bias weight name (e.g., "custom_text_proj.bias") or using a config flag to determine has_bias.
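A narrower check could target the projection's own bias entry explicitly (the helper name is illustrative):

```python
def proj_has_bias(proj_weights) -> bool:
    # Only a tensor explicitly named as custom_text_proj's bias counts;
    # an unrelated "*bias*" entry no longer flips the flag.
    return any(
        name.split(".")[-1] == "bias" and "custom_text_proj" in name
        for name, _ in proj_weights
    )
```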
```python
    tok_kwargs: Mapping[str, object],
) -> BatchFeature:
    if mm_data and not prompt:
        prompt = "Describe the image."
```
This should be the user's responsibility; I don't think we need to provide a default prompt, except through the chat template.
```python
hidden_states = self.custom_text_proj(hidden_states)

# L2 normalize
return torch.nn.functional.normalize(hidden_states, p=2, dim=-1)
```
Normalization should be done by the pooler.
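A minimal sketch of what moving L2 normalization into a pooling step could look like (this is not vLLM's actual Pooler API, just an illustration of the separation of concerns):

```python
import torch


class NormalizingPooler:
    # Applies L2 normalization to per-token embeddings as a pooling step,
    # keeping the model's forward pass free of output post-processing.
    def __call__(self, token_embeddings: torch.Tensor) -> torch.Tensor:
        return torch.nn.functional.normalize(token_embeddings, p=2, dim=-1)
```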
Add support for ColPali (ColBERT-style late interaction on a PaliGemma backbone) for multi-vector document retrieval and reranking.
Target model: vidore/colpali-v1.3-hf
Reference: https://arxiv.org/abs/2407.01449
Purpose
Add support for ColPali (https://arxiv.org/abs/2407.01449), a ColBERT-style late interaction model built on a PaliGemma backbone (SigLIP + Gemma) for multi-vector document retrieval and reranking.
Target model: vidore/colpali-v1.3-hf.
Test Plan
Test Result
- `/tokenize` → `[1841, 603, 6479, 6044, 235336]`
- `/detokenize` → `"What is machine learning?"`
- `/pooling` (text)
- `/pooling` (batch)
- `/pooling` (image)
- `/score` (text vs text)
- `/score` (text vs image)
- `/rerank`

Essential Elements of an Effective PR Description Checklist
Update `supported_models.md` and `examples` for a new model.