
Bug: _colbert_score does not mask query tokens #75

@eelcovdw

Description


Hi, I found a possible bug in _colbert_score:

def _colbert_score(q_reps, p_reps, q_mask: torch.Tensor, p_mask: torch.Tensor):
    # calc max sim
    # base code from: https://github.com/FlagOpen/FlagEmbedding/blob/master/FlagEmbedding/BGE_M3/modeling.py

    # Assert that all q_reps are at least as long as the query length
    assert (
        q_reps.shape[1] >= q_mask.shape[1]
    ), f"q_reps should have at least {q_mask.shape[1]} tokens, but has {q_reps.shape[1]}"

    token_scores = torch.einsum("qin,pjn->qipj", q_reps, p_reps)
    token_scores = token_scores.masked_fill(p_mask.unsqueeze(0).unsqueeze(0) == 0, -1e4)
    scores, _ = token_scores.max(-1)
    scores = scores.sum(1) / q_mask.sum(-1, keepdim=True)
    return scores
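
For context on the shapes: the einsum builds a token-level similarity tensor of shape [num_queries, query_len, num_docs, doc_len], and the max over the last axis is the ColBERT MaxSim per query token. A toy walkthrough with synthetic tensors (not library code; the dimensions mirror the example below):

import torch

q_reps = torch.randn(1, 13, 8)  # [num_queries, padded_query_len, dim]
p_reps = torch.randn(2, 20, 8)  # [num_docs, padded_doc_len, dim]

# similarity of every query token i against every doc token j
token_scores = torch.einsum("qin,pjn->qipj", q_reps, p_reps)
print(token_scores.shape)  # torch.Size([1, 13, 2, 20])

# MaxSim: best-matching doc token per query token
scores, _ = token_scores.max(-1)
print(scores.shape)  # torch.Size([1, 13, 2])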

The _colbert_score implementation does not mask out padded query tokens, but it does divide by the unpadded query length. Since every padded position still contributes its max similarity to the sum, the result is an inflated score. Minimal example:

from rerankers import Reranker

reranker = Reranker("answerdotai/answerai-colbert-small-v1", model_type="colbert")
print("Reranker class:", type(reranker))

query = "machine learning"
doc_1 = "Machine learning is a subset of AI"
doc_2 = "hello world!"

res = reranker.rank(query=query, docs=[doc_1, doc_2])
for r in res.results:
    print(f"Score: {r.score:.4f} - Doc: {r.document.text[:20]}...")

With the above code snippet, I get scores around 2.0. To clarify the issue, I placed some debugging prints in _colbert_score:

def _colbert_score(q_reps, p_reps, q_mask: torch.Tensor, p_mask: torch.Tensor):
    # calc max sim
    # base code from: https://github.com/FlagOpen/FlagEmbedding/blob/master/FlagEmbedding/BGE_M3/modeling.py

    # Assert that all q_reps are at least as long as the query length
    assert (
        q_reps.shape[1] >= q_mask.shape[1]
    ), f"q_reps should have at least {q_mask.shape[1]} tokens, but has {q_reps.shape[1]}"

    token_scores = torch.einsum("qin,pjn->qipj", q_reps, p_reps)
    token_scores = token_scores.masked_fill(p_mask.unsqueeze(0).unsqueeze(0) == 0, -1e4)
    scores_unnormalized, _ = token_scores.max(-1)
    
    print(f"scores_unnormalized shape: {scores_unnormalized.shape}")
    print(f"q_mask sum (actual query length): {q_mask.sum(-1, keepdim=True).item()}")
    print(f"scores_unnormalized:\n{scores_unnormalized}")
    print(f"Bug: summing over {scores_unnormalized.shape[1]} positions but dividing by {q_mask.sum(-1, keepdim=True).item()}")
    print()
    # Old implementation does not mask query
    scores_without_query_mask = scores_unnormalized.sum(1) / q_mask.sum(-1, keepdim=True)
    print(f"Incorrect scores: {scores_without_query_mask}")
    
    # New implementation masks query
    scores = (scores_unnormalized * q_mask.unsqueeze(-1)).sum(1) / q_mask.sum(-1, keepdim=True)
    print(f"Corrected scores: {scores}")

    return scores

The issue is clear in the output: the sum runs over all 13 query positions (padding included), while the division uses the real query length of 5:

scores_unnormalized shape: torch.Size([1, 13, 2])
q_mask sum (actual query length): 5
scores_unnormalized:
tensor([[[0.9986, 0.9955],
         [0.9794, 0.9453],
         [0.9790, 0.6597],
         [0.9794, 0.7352],
         [0.9986, 0.9955],
         [0.9795, 0.9799],
         [0.9807, 0.9810],
         [0.9795, 0.9799],
         [0.9798, 0.9799],
         [0.9807, 0.9808],
         [0.9815, 0.9816],
         [0.9819, 0.9822],
         [0.9824, 0.9825]]], device='mps:0')
Bug: summing over 13 positions but dividing by 5

Incorrect scores: tensor([[2.5562, 2.4358]], device='mps:0')
Corrected scores: tensor([[0.9870, 0.8662]], device='mps:0')
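
As a sanity check on the printed tensor (assuming the 5 real query tokens occupy the first 5 positions, which the corrected score confirms), the arithmetic reproduces both numbers exactly:

vals = [0.9986, 0.9794, 0.9790, 0.9794, 0.9986,
        0.9795, 0.9807, 0.9795, 0.9798, 0.9807,
        0.9815, 0.9819, 0.9824]
print(sum(vals) / 5)      # ~2.5562, the inflated score for doc_1
print(sum(vals[:5]) / 5)  # ~0.9870, the corrected score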

Proposed fix

Mask the query token scores with the query mask before averaging:

scores = (scores_unnormalized * q_mask.unsqueeze(-1)).sum(1) / q_mask.sum(-1, keepdim=True)
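
Note the asymmetry with document padding: padded doc tokens are already handled by the masked_fill with -1e4 before the max, so they can never win the max, but each padded query position contributes a full row to the sum and has to be zeroed out explicitly. A standalone sanity check of the fix (synthetic tensors; names and shapes are mine, not library code):

import torch

torch.manual_seed(0)
q_reps = torch.randn(1, 13, 8)  # 5 real query tokens + 8 padding positions
p_reps = torch.randn(2, 20, 8)
q_mask = torch.zeros(1, 13)
q_mask[:, :5] = 1
p_mask = torch.ones(2, 20)

token_scores = torch.einsum("qin,pjn->qipj", q_reps, p_reps)
token_scores = token_scores.masked_fill(p_mask.unsqueeze(0).unsqueeze(0) == 0, -1e4)
scores_unnormalized, _ = token_scores.max(-1)

# fixed: zero out padded query positions before summing
fixed = (scores_unnormalized * q_mask.unsqueeze(-1)).sum(1) / q_mask.sum(-1, keepdim=True)

# reference: drop the query padding entirely and recompute
ref = torch.einsum("qin,pjn->qipj", q_reps[:, :5], p_reps).max(-1)[0].mean(1)

print(torch.allclose(fixed, ref))  # True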
