
Commit dd61524

Merge pull request #21 from AnswerDotAI/improved_colbert ("Improved colbert")
2 parents: b407dc5 + e3e12a8

4 files changed: +184 additions, -19 deletions


README.md

Lines changed: 1 addition & 0 deletions
```diff
@@ -14,6 +14,7 @@ Welcome to `rerankers`! Our goal is to provide users with a simple API to use an
 
 ## Updates
 
+- v0.4.0: ColBERT performance improvement! It should now be faster and produce stronger results, following the implementation of the JaColBERTv2.5 dynamic query length method. This version also adds support for HuggingFace's Text Embeddings Inference (TEI) as an API reranker option, thanks to [@srisudarsan](https://github.com/srisudarsan).
 - v0.3.1: T5 bugfix and native default support for new Portuguese T5 rerankers.
 - v0.3.0: 🆕 Many changes! Experimental support for RankLLM, directly backed by the [rank-llm library](https://github.com/castorini/rank_llm). A new `Document` object, courtesy of joint work by [@bclavie](https://github.com/bclavie) and [@Anmol6](https://github.com/Anmol6). This object is transparent, but now offers support for `metadata` stored alongside each document. Many small QoL changes (RankedResults can be iterated on directly...).
 - v0.2.0: [FlashRank](https://github.com/PrithivirajDamodaran/FlashRank) rerankers, basic async support thanks to [@tarunamasa](https://github.com/tarunamasa), MixedBread.ai reranking API.
```
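For context, the improved ColBERT path is reached through the same high-level API as before. A minimal usage sketch in the style of the project README; the `"colbert"` shorthand, `doc_ids`, and `top_k` follow the README, but exact signatures and defaults may differ across versions:

```python
# Minimal sketch: reranking with the default ColBERT model via rerankers.
from rerankers import Reranker

# As of v0.4.0, the dynamic query-length padding from this PR is applied
# automatically when a ColBERT model is loaded.
ranker = Reranker("colbert")

results = ranker.rank(
    query="What is late interaction in ColBERT?",
    docs=[
        "ColBERT compares query and document token embeddings with MaxSim.",
        "Cross-encoders jointly encode the query-document pair.",
    ],
    doc_ids=[0, 1],
)
print(results.top_k(1))
```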

pyproject.toml

Lines changed: 1 addition & 1 deletion
```diff
@@ -14,7 +14,7 @@ packages = [
 name = "rerankers"
 
 
-version = "0.3.1"
+version = "0.4.0"
 
 description = "A unified API for various document re-ranking models."
```

rerankers/__init__.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -2,4 +2,4 @@
 from rerankers.documents import Document
 
 __all__ = ["Reranker", "Document"]
-__version__ = "0.3.1"
+__version__ = "0.4.0"
```

rerankers/models/colbert_ranker.py

Lines changed: 181 additions & 17 deletions
```diff
@@ -2,7 +2,8 @@
 Modifications include packaging into a BaseRanker, dynamic query/doc length and batch size handling."""
 
 import torch
-from transformers import AutoModel, AutoTokenizer
+import torch.nn as nn
+from transformers import BertPreTrainedModel, BertModel, AutoModel, AutoTokenizer
 from typing import List, Optional, Union
 from math import ceil
```

```diff
@@ -67,17 +68,140 @@ def _insert_token(
     return updated_output
 
 
-def _colbert_score(
-    q_reps,
-    p_reps,
-    q_mask: torch.Tensor,
-    p_mask: torch.Tensor,
-):
+def _colbert_score(q_reps, p_reps, q_mask: torch.Tensor, p_mask: torch.Tensor):
+    # calc max sim
+    # base code from: https://github.com/FlagOpen/FlagEmbedding/blob/master/FlagEmbedding/BGE_M3/modeling.py
+
+    # Assert that all q_reps are at least as long as the query length
+    assert (
+        q_reps.shape[1] >= q_mask.shape[1]
+    ), f"q_reps should have at least {q_mask.shape[1]} tokens, but has {q_reps.shape[1]}"
+
     token_scores = torch.einsum("qin,pjn->qipj", q_reps, p_reps)
     token_scores = token_scores.masked_fill(p_mask.unsqueeze(0).unsqueeze(0) == 0, -1e4)
     scores, _ = token_scores.max(-1)
+    scores = scores.sum(1) / q_mask.sum(-1, keepdim=True)
+    return scores
+
+
+class ColBERTModel(BertPreTrainedModel):
+    def __init__(self, config):
+        super().__init__(config)
+        self.bert = BertModel(config)
+        self.linear = nn.Linear(config.hidden_size, 128, bias=False)
+        self.init_weights()
+
+    def forward(
+        self,
+        input_ids=None,
+        attention_mask=None,
+        token_type_ids=None,
+        position_ids=None,
+        head_mask=None,
+        inputs_embeds=None,
+        encoder_hidden_states=None,
+        encoder_attention_mask=None,
+        output_attentions=None,
+        output_hidden_states=None,
+    ):
+        outputs = self.bert(
+            input_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            encoder_hidden_states=encoder_hidden_states,
+            encoder_attention_mask=encoder_attention_mask,
+            output_attentions=output_attentions,
+            output_hidden_states=True,  # Always output hidden states
+        )
+
+        sequence_output = outputs[0]
+
+        return self.linear(sequence_output)
+
+    def _encode(self, texts: list[str], insert_token_id: int, is_query: bool = False):
+        encoding = self.tokenizer(
+            texts,
+            return_tensors="pt",
+            padding=True,
+            max_length=self.max_length - 1,  # for insert token
+            truncation=True,
+        )
+        encoding = _insert_token(encoding, insert_token_id)  # type: ignore
 
-    return scores.sum(1) / q_mask[:, 1:].sum(-1, keepdim=True)
+        if is_query:
+            mask_token_id = self.tokenizer.mask_token_id
+
+            new_encodings = {"input_ids": [], "attention_mask": []}
+
+            for i, input_ids in enumerate(encoding["input_ids"]):
+                original_length = (
+                    (input_ids != self.tokenizer.pad_token_id).sum().item()
+                )
+
+                # Calculate QLEN dynamically for each query
+                if original_length % 32 <= 8:
+                    QLEN = original_length + 8
+                else:
+                    QLEN = ceil(original_length / 32) * 32
+
+                if original_length < QLEN:
+                    pad_length = QLEN - original_length
+                    padded_input_ids = input_ids.tolist() + [mask_token_id] * pad_length
+                    padded_attention_mask = (
+                        encoding["attention_mask"][i].tolist() + [0] * pad_length
+                    )
+                else:
+                    padded_input_ids = input_ids[:QLEN].tolist()
+                    padded_attention_mask = encoding["attention_mask"][i][
+                        :QLEN
+                    ].tolist()
+
+                new_encodings["input_ids"].append(padded_input_ids)
+                new_encodings["attention_mask"].append(padded_attention_mask)
+
+            for key in new_encodings:
+                new_encodings[key] = torch.tensor(
+                    new_encodings[key], device=self.device
+                )
+
+            encoding = new_encodings
+
+        encoding = {key: value.to(self.device) for key, value in encoding.items()}
+        return encoding
+
+    def _query_encode(self, query: list[str]):
+        return self._encode(query, self.query_token_id, is_query=True)
+
+    def _document_encode(self, documents: list[str]):
+        return self._encode(documents, self.document_token_id)
+
+    def _to_embs(self, encoding) -> torch.Tensor:
+        with torch.no_grad():
+            # embs = self.model(**encoding).last_hidden_state.squeeze(1)
+            embs = self.model(**encoding)
+        if self.normalize:
+            embs = embs / embs.norm(dim=-1, keepdim=True)
+        return embs
+
+    def _rerank(self, query: str, documents: list[str]) -> list[float]:
+        query_encoding = self._query_encode([query])
+        documents_encoding = self._document_encode(documents)
+        query_embeddings = self._to_embs(query_encoding)
+        document_embeddings = self._to_embs(documents_encoding)
+        scores = (
+            _colbert_score(
+                query_embeddings,
+                document_embeddings,
+                query_encoding["attention_mask"],
+                documents_encoding["attention_mask"],
+            )
+            .cpu()
+            .tolist()[0]
+        )
+        return scores
 
 
 class ColBERTRanker(BaseRanker):
```
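The reworked `_colbert_score` above is the standard ColBERT MaxSim: each query token is matched with its best-scoring document token, and the per-token maxima are averaged over the query attention mask. Note that the denominator is now `q_mask.sum(-1)` over the full mask rather than `q_mask[:, 1:]`, consistent with the new assertion on `q_reps`. A shape-level sketch with random stand-in tensors (all dimensions are illustrative):

```python
# Shape-level sketch of the MaxSim scoring in _colbert_score.
# Random stand-ins: 1 query of 9 tokens, 2 docs of 20 tokens, 128-dim embeddings.
import torch

q_reps = torch.randn(1, 9, 128)               # (num_queries, q_len, dim)
p_reps = torch.randn(2, 20, 128)              # (num_docs, d_len, dim)
q_mask = torch.ones(1, 9, dtype=torch.long)   # query attention mask
p_mask = torch.ones(2, 20, dtype=torch.long)  # document attention mask

# Similarity of every query token with every document token.
token_scores = torch.einsum("qin,pjn->qipj", q_reps, p_reps)  # (1, 9, 2, 20)
# Mask out padded document tokens so they can never be the max.
token_scores = token_scores.masked_fill(p_mask.unsqueeze(0).unsqueeze(0) == 0, -1e4)

scores, _ = token_scores.max(-1)                       # best doc token per query token
scores = scores.sum(1) / q_mask.sum(-1, keepdim=True)  # mean over real query tokens
print(scores.shape)  # torch.Size([1, 2]): one score per (query, document) pair
```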
```diff
@@ -159,14 +283,9 @@ def _colbert_rank(
         return scores
 
     def _query_encode(self, query: list[str]):
-        tokenized_query_length = len(self.tokenizer.encode(query[0]))
-        max_length = max(
-            ceil(tokenized_query_length / 16) * 16, self.query_max_length
-        )  # Ensure not smaller than query_max_length
-        max_length = int(
-            min(max_length, self.doc_max_length)
-        )  # Ensure not larger than doc_max_length
-        return self._encode(query, self.query_token_id, max_length)
+        return self._encode(
+            query, self.query_token_id, max_length=self.doc_max_length, is_query=True
+        )
 
     def _document_encode(self, documents: list[str]):
         tokenized_doc_lengths = [
```
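The deleted lines above were the old static scheme: round the tokenized query length up to a multiple of 16, then clamp it between `query_max_length` and `doc_max_length`. Query sizing now happens per query inside `_encode`, shown in the next two hunks. As a standalone sketch, the new rule looks like this:

```python
# Standalone sketch of the JaColBERTv2.5-style dynamic query length (QLEN) rule
# from the is_query branch of _encode: lengths at or just past a multiple of 32
# get 8 extra tokens of headroom; everything else rounds up to the next
# multiple of 32.
from math import ceil

def dynamic_qlen(original_length: int) -> int:
    if original_length % 32 <= 8:
        return original_length + 8
    return ceil(original_length / 32) * 32

for n in (5, 9, 20, 32, 33, 40):
    print(n, "->", dynamic_qlen(n))
# 5 -> 13, 9 -> 32, 20 -> 32, 32 -> 40, 33 -> 41, 40 -> 48
```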
```diff
@@ -189,7 +308,13 @@ def _document_encode(self, documents: list[str]):
         )  # Ensure not larger than doc_max_length
         return self._encode(documents, self.document_token_id, max_length)
 
-    def _encode(self, texts: list[str], insert_token_id: int, max_length: int):
+    def _encode(
+        self,
+        texts: list[str],
+        insert_token_id: int,
+        max_length: int,
+        is_query: bool = False,
+    ):
         encoding = self.tokenizer(
             texts,
             return_tensors="pt",
```
```diff
@@ -198,6 +323,45 @@ def _encode(self, texts: list[str], insert_token_id: int, max_length: int):
             truncation=True,
         )
         encoding = _insert_token(encoding, insert_token_id)  # type: ignore
+
+        if is_query:
+            mask_token_id = self.tokenizer.mask_token_id
+
+            new_encodings = {"input_ids": [], "attention_mask": []}
+
+            for i, input_ids in enumerate(encoding["input_ids"]):
+                original_length = (
+                    (input_ids != self.tokenizer.pad_token_id).sum().item()
+                )
+
+                # Calculate QLEN dynamically for each query
+                if original_length % 32 <= 8:
+                    QLEN = original_length + 8
+                else:
+                    QLEN = ceil(original_length / 32) * 32
+
+                if original_length < QLEN:
+                    pad_length = QLEN - original_length
+                    padded_input_ids = input_ids.tolist() + [mask_token_id] * pad_length
+                    padded_attention_mask = (
+                        encoding["attention_mask"][i].tolist() + [0] * pad_length
+                    )
+                else:
+                    padded_input_ids = input_ids[:QLEN].tolist()
+                    padded_attention_mask = encoding["attention_mask"][i][
+                        :QLEN
+                    ].tolist()
+
+                new_encodings["input_ids"].append(padded_input_ids)
+                new_encodings["attention_mask"].append(padded_attention_mask)
+
+            for key in new_encodings:
+                new_encodings[key] = torch.tensor(
+                    new_encodings[key], device=self.device
+                )
+
+            encoding = new_encodings
+
         encoding = {key: value.to(self.device) for key, value in encoding.items()}
         return encoding
```
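To make the `is_query` branch concrete, here is a toy view of a mask-augmented query. `bert-base-uncased` is an assumed stand-in tokenizer for illustration, not necessarily the one a given ColBERT checkpoint ships with; as in the code above, the `[MASK]` padding positions would receive attention-mask value 0:

```python
# Toy illustration of query augmentation with [MASK] padding, mirroring the
# is_query branch of _encode. bert-base-uncased is an assumed stand-in tokenizer.
from math import ceil
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("bert-base-uncased")
ids = tok("what is colbert?")["input_ids"]

n = len(ids)
qlen = n + 8 if n % 32 <= 8 else ceil(n / 32) * 32

# Pad to QLEN with [MASK] tokens; the corresponding attention-mask entries
# are 0, exactly as in the diff above.
padded = ids + [tok.mask_token_id] * (qlen - n)
print(tok.convert_ids_to_tokens(padded))
# e.g. ['[CLS]', 'what', 'is', 'col', '##bert', '?', '[SEP]', '[MASK]', ..., '[MASK]']
```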
