Hi, I found a possible bug in rerankers/rerankers/models/colbert_ranker.py, lines 71 to 84 in 7bb2521:
def _colbert_score(q_reps, p_reps, q_mask: torch.Tensor, p_mask: torch.Tensor):
    # calc max sim
    # base code from: https://github.com/FlagOpen/FlagEmbedding/blob/master/FlagEmbedding/BGE_M3/modeling.py
    # Assert that all q_reps are at least as long as the query length
    assert (
        q_reps.shape[1] >= q_mask.shape[1]
    ), f"q_reps should have at least {q_mask.shape[1]} tokens, but has {q_reps.shape[1]}"
    token_scores = torch.einsum("qin,pjn->qipj", q_reps, p_reps)
    token_scores = token_scores.masked_fill(p_mask.unsqueeze(0).unsqueeze(0) == 0, -1e4)
    scores, _ = token_scores.max(-1)
    scores = scores.sum(1) / q_mask.sum(-1, keepdim=True)
    return scores
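For reference, these are the tensor shapes as I read them from the einsum string and the debug output further down (an inference on my part, not documented behaviour):

# q_reps:        (n_queries, q_len, dim)             - query token embeddings
# p_reps:        (n_docs, p_len, dim)                - passage token embeddings
# token_scores:  (n_queries, q_len, n_docs, p_len)   - per-token similarities after the einsum
# after max(-1): (n_queries, q_len, n_docs)          - max-sim over passage tokens
# after sum(1) / q_mask.sum(-1, keepdim=True): (n_queries, n_docs)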
The _colbert_score implementation does not mask out padded query tokens when summing, but it does divide by the masked (actual) query length. The result is an inflated score. Minimal example:
from rerankers import Reranker
reranker = Reranker("answerdotai/answerai-colbert-small-v1", model_type="colbert")
print("Reranker class:", type(reranker))
query = "machine learning"
doc_1 = "Machine learning is a subset of AI"
doc_2 = "hello world!"
res = reranker.rank(query=query, docs=[doc_1, doc_2])
for r in res.results:
    print(f"Score: {r.score:.4f} - Doc: {r.document.text[:20]}...")
With the above code snippet, I get scores around ~2.0. To clarify the issue, I placed some debugging prints in _colbert_score:
def _colbert_score(q_reps, p_reps, q_mask: torch.Tensor, p_mask: torch.Tensor):
    # calc max sim
    # base code from: https://github.com/FlagOpen/FlagEmbedding/blob/master/FlagEmbedding/BGE_M3/modeling.py
    # Assert that all q_reps are at least as long as the query length
    assert (
        q_reps.shape[1] >= q_mask.shape[1]
    ), f"q_reps should have at least {q_mask.shape[1]} tokens, but has {q_reps.shape[1]}"
    token_scores = torch.einsum("qin,pjn->qipj", q_reps, p_reps)
    token_scores = token_scores.masked_fill(p_mask.unsqueeze(0).unsqueeze(0) == 0, -1e4)
    scores_unnormalized, _ = token_scores.max(-1)
    print(f"scores_unnormalized shape: {scores_unnormalized.shape}")
    print(f"q_mask sum (actual query length): {q_mask.sum(-1, keepdim=True).item()}")
    print(f"scores_unnormalized:\n{scores_unnormalized}")
    print(f"Bug: summing over {scores_unnormalized.shape[1]} positions but dividing by {q_mask.sum(-1, keepdim=True).item()}")
    print()
    # Old implementation does not mask query
    scores_without_query_mask = scores_unnormalized.sum(1) / q_mask.sum(-1, keepdim=True)
    print(f"Incorrect scores: {scores_without_query_mask}")
    # New implementation masks query
    scores = (scores_unnormalized * q_mask.unsqueeze(-1)).sum(1) / q_mask.sum(-1, keepdim=True)
    print(f"Corrected scores: {scores}")
    return scores
The issue is clear in the output below: the padded query tokens are never masked out of the sum.
scores_unnormalized shape: torch.Size([1, 13, 2])
q_mask sum (actual query length): 5
scores_unnormalized:
tensor([[[0.9986, 0.9955],
         [0.9794, 0.9453],
         [0.9790, 0.6597],
         [0.9794, 0.7352],
         [0.9986, 0.9955],
         [0.9795, 0.9799],
         [0.9807, 0.9810],
         [0.9795, 0.9799],
         [0.9798, 0.9799],
         [0.9807, 0.9808],
         [0.9815, 0.9816],
         [0.9819, 0.9822],
         [0.9824, 0.9825]]], device='mps:0')
Bug: summing over 13 positions but dividing by 5
Incorrect scores: tensor([[2.5562, 2.4358]], device='mps:0')
Corrected scores: tensor([[0.9870, 0.8662]], device='mps:0')
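As a rough sanity check on the numbers above (assuming the 5 real query tokens are the first 5 positions), the inflated score is exactly the sum over all 13 max-sim values for doc_1 divided by 5:

doc_1_maxsims = [0.9986, 0.9794, 0.9790, 0.9794, 0.9986, 0.9795, 0.9807,
                 0.9795, 0.9798, 0.9807, 0.9815, 0.9819, 0.9824]
print(sum(doc_1_maxsims) / 5)      # ~2.5562 -> matches the "incorrect" score
print(sum(doc_1_maxsims[:5]) / 5)  # ~0.9870 -> matches the "corrected" score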
Proposed fix
Mask the query token scores with the query mask before averaging:
scores = (scores_unnormalized * q_mask.unsqueeze(-1)).sum(1) / q_mask.sum(-1, keepdim=True)
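To check the proposed line in isolation (independent of the model), here is a small self-contained sketch comparing the masked averaging against an explicit slice over the real query tokens. The shapes and names mirror the function above, but the tensors are random and the dimensions are assumed, so treat it as a consistency check rather than a test for the repo:

import torch

torch.manual_seed(0)
q_reps = torch.randn(1, 13, 96)   # 1 query, 13 token slots (5 real + 8 padded), dim 96 (assumed)
p_reps = torch.randn(2, 20, 96)   # 2 docs, 20 tokens each
q_mask = torch.zeros(1, 13)
q_mask[:, :5] = 1                 # only the first 5 query positions are real
p_mask = torch.ones(2, 20)

token_scores = torch.einsum("qin,pjn->qipj", q_reps, p_reps)
token_scores = token_scores.masked_fill(p_mask.unsqueeze(0).unsqueeze(0) == 0, -1e4)
maxsims, _ = token_scores.max(-1)  # (1, 13, 2): max-sim per query token per doc

# proposed fix: zero out padded query positions before averaging
fixed = (maxsims * q_mask.unsqueeze(-1)).sum(1) / q_mask.sum(-1, keepdim=True)

# reference: mean max-sim over only the real query tokens, taken explicitly
reference = maxsims[:, :5, :].mean(1)

print(torch.allclose(fixed, reference))  # True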