 Modifications include packaging into a BaseRanker, dynamic query/doc length and batch size handling."""
 
 import torch
-from transformers import AutoModel, AutoTokenizer
+import torch.nn as nn
+from transformers import BertPreTrainedModel, BertModel, AutoModel, AutoTokenizer
 from typing import List, Optional, Union
 from math import ceil
 
@@ -67,17 +68,140 @@ def _insert_token(
     return updated_output
 
 
-def _colbert_score(
-    q_reps,
-    p_reps,
-    q_mask: torch.Tensor,
-    p_mask: torch.Tensor,
-):
+def _colbert_score(q_reps, p_reps, q_mask: torch.Tensor, p_mask: torch.Tensor):
+    # calc max sim
+    # base code from: https://github.com/FlagOpen/FlagEmbedding/blob/master/FlagEmbedding/BGE_M3/modeling.py
+
+    # Assert that all q_reps are at least as long as the query length
+    assert (
+        q_reps.shape[1] >= q_mask.shape[1]
+    ), f"q_reps should have at least {q_mask.shape[1]} tokens, but has {q_reps.shape[1]}"
+
     token_scores = torch.einsum("qin,pjn->qipj", q_reps, p_reps)
     token_scores = token_scores.masked_fill(p_mask.unsqueeze(0).unsqueeze(0) == 0, -1e4)
     scores, _ = token_scores.max(-1)
+    scores = scores.sum(1) / q_mask.sum(-1, keepdim=True)
+    return scores
+
+
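A minimal, self-contained sketch of the MaxSim computation that _colbert_score performs, using toy tensors whose shapes and values are purely illustrative:

import torch

# One query with 3 token embeddings, two documents with 4 token embeddings
# each, embedding dimension 8 (random values, illustrative only).
q_reps = torch.randn(1, 3, 8)           # (num_queries, q_len, dim)
p_reps = torch.randn(2, 4, 8)           # (num_docs, p_len, dim)
q_mask = torch.ones(1, 3)               # no query padding in this toy case
p_mask = torch.tensor([[1, 1, 1, 0],    # last token of doc 0 is padding
                       [1, 1, 1, 1]])

# All pairwise token similarities: (num_queries, q_len, num_docs, p_len)
token_scores = torch.einsum("qin,pjn->qipj", q_reps, p_reps)
# Padded document tokens get a large negative score so they never win the max
token_scores = token_scores.masked_fill(p_mask.unsqueeze(0).unsqueeze(0) == 0, -1e4)
# MaxSim: best-matching document token per query token, averaged over query tokens
scores, _ = token_scores.max(-1)
scores = scores.sum(1) / q_mask.sum(-1, keepdim=True)
print(scores.shape)  # torch.Size([1, 2]): one score per (query, document) pair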
+class ColBERTModel(BertPreTrainedModel):
+    def __init__(self, config):
+        super().__init__(config)
+        self.bert = BertModel(config)
+        self.linear = nn.Linear(config.hidden_size, 128, bias=False)
+        self.init_weights()
+
+    def forward(
+        self,
+        input_ids=None,
+        attention_mask=None,
+        token_type_ids=None,
+        position_ids=None,
+        head_mask=None,
+        inputs_embeds=None,
+        encoder_hidden_states=None,
+        encoder_attention_mask=None,
+        output_attentions=None,
+        output_hidden_states=None,
+    ):
+        outputs = self.bert(
+            input_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            encoder_hidden_states=encoder_hidden_states,
+            encoder_attention_mask=encoder_attention_mask,
+            output_attentions=output_attentions,
+            output_hidden_states=True,  # Always output hidden states
+        )
+
+        sequence_output = outputs[0]
+
+        return self.linear(sequence_output)
+
+    def _encode(self, texts: list[str], insert_token_id: int, is_query: bool = False):
+        encoding = self.tokenizer(
+            texts,
+            return_tensors="pt",
+            padding=True,
+            max_length=self.max_length - 1,  # for insert token
+            truncation=True,
+        )
+        encoding = _insert_token(encoding, insert_token_id)  # type: ignore
 
-    return scores.sum(1) / q_mask[:, 1:].sum(-1, keepdim=True)
+        if is_query:
+            mask_token_id = self.tokenizer.mask_token_id
+
+            new_encodings = {"input_ids": [], "attention_mask": []}
+
+            for i, input_ids in enumerate(encoding["input_ids"]):
+                original_length = (
+                    (input_ids != self.tokenizer.pad_token_id).sum().item()
+                )
+
+                # Calculate QLEN dynamically for each query
+                if original_length % 32 <= 8:
+                    QLEN = original_length + 8
+                else:
+                    QLEN = ceil(original_length / 32) * 32
+
+                if original_length < QLEN:
+                    pad_length = QLEN - original_length
+                    padded_input_ids = input_ids.tolist() + [mask_token_id] * pad_length
+                    padded_attention_mask = (
+                        encoding["attention_mask"][i].tolist() + [0] * pad_length
+                    )
+                else:
+                    padded_input_ids = input_ids[:QLEN].tolist()
+                    padded_attention_mask = encoding["attention_mask"][i][
+                        :QLEN
+                    ].tolist()
+
+                new_encodings["input_ids"].append(padded_input_ids)
+                new_encodings["attention_mask"].append(padded_attention_mask)
+
+            for key in new_encodings:
+                new_encodings[key] = torch.tensor(
+                    new_encodings[key], device=self.device
+                )
+
+            encoding = new_encodings
+
+        encoding = {key: value.to(self.device) for key, value in encoding.items()}
+        return encoding
+
+    def _query_encode(self, query: list[str]):
+        return self._encode(query, self.query_token_id, is_query=True)
+
+    def _document_encode(self, documents: list[str]):
+        return self._encode(documents, self.document_token_id)
+
+    def _to_embs(self, encoding) -> torch.Tensor:
+        with torch.no_grad():
+            # embs = self.model(**encoding).last_hidden_state.squeeze(1)
+            embs = self.model(**encoding)
+        if self.normalize:
+            embs = embs / embs.norm(dim=-1, keepdim=True)
+        return embs
+
+    def _rerank(self, query: str, documents: list[str]) -> list[float]:
+        query_encoding = self._query_encode([query])
+        documents_encoding = self._document_encode(documents)
+        query_embeddings = self._to_embs(query_encoding)
+        document_embeddings = self._to_embs(documents_encoding)
+        scores = (
+            _colbert_score(
+                query_embeddings,
+                document_embeddings,
+                query_encoding["attention_mask"],
+                documents_encoding["attention_mask"],
+            )
+            .cpu()
+            .tolist()[0]
+        )
+        return scores
 
 
 class ColBERTRanker(BaseRanker):
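The dynamic query handling mentioned in the module docstring comes down to the QLEN rule used in ColBERTModel._encode above and repeated in ColBERTRanker._encode below: each tokenized query is padded with [MASK] tokens up to a length derived from its original length. A small worked sketch of that rule (the helper name dynamic_qlen is purely illustrative):

from math import ceil

def dynamic_qlen(original_length: int) -> int:
    # Same rule as in the diff: lengths whose remainder mod 32 is at most 8
    # get a fixed +8 cushion; otherwise round up to the next multiple of 32.
    if original_length % 32 <= 8:
        return original_length + 8
    return ceil(original_length / 32) * 32

for n in (5, 20, 32, 40, 45):
    print(n, "->", dynamic_qlen(n))
# 5 -> 13, 20 -> 32, 32 -> 40, 40 -> 48, 45 -> 64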
@@ -159,14 +283,9 @@ def _colbert_rank(
         return scores
 
     def _query_encode(self, query: list[str]):
-        tokenized_query_length = len(self.tokenizer.encode(query[0]))
-        max_length = max(
-            ceil(tokenized_query_length / 16) * 16, self.query_max_length
-        )  # Ensure not smaller than query_max_length
-        max_length = int(
-            min(max_length, self.doc_max_length)
-        )  # Ensure not larger than doc_max_length
-        return self._encode(query, self.query_token_id, max_length)
+        return self._encode(
+            query, self.query_token_id, max_length=self.doc_max_length, is_query=True
+        )
 
     def _document_encode(self, documents: list[str]):
         tokenized_doc_lengths = [
@@ -189,7 +308,13 @@ def _document_encode(self, documents: list[str]):
         )  # Ensure not larger than doc_max_length
         return self._encode(documents, self.document_token_id, max_length)
 
-    def _encode(self, texts: list[str], insert_token_id: int, max_length: int):
+    def _encode(
+        self,
+        texts: list[str],
+        insert_token_id: int,
+        max_length: int,
+        is_query: bool = False,
+    ):
         encoding = self.tokenizer(
             texts,
             return_tensors="pt",
@@ -198,6 +323,45 @@ def _encode(self, texts: list[str], insert_token_id: int, max_length: int):
             truncation=True,
         )
         encoding = _insert_token(encoding, insert_token_id)  # type: ignore
+
+        if is_query:
+            mask_token_id = self.tokenizer.mask_token_id
+
+            new_encodings = {"input_ids": [], "attention_mask": []}
+
+            for i, input_ids in enumerate(encoding["input_ids"]):
+                original_length = (
+                    (input_ids != self.tokenizer.pad_token_id).sum().item()
+                )
+
+                # Calculate QLEN dynamically for each query
+                if original_length % 32 <= 8:
+                    QLEN = original_length + 8
+                else:
+                    QLEN = ceil(original_length / 32) * 32
+
+                if original_length < QLEN:
+                    pad_length = QLEN - original_length
+                    padded_input_ids = input_ids.tolist() + [mask_token_id] * pad_length
+                    padded_attention_mask = (
+                        encoding["attention_mask"][i].tolist() + [0] * pad_length
+                    )
+                else:
+                    padded_input_ids = input_ids[:QLEN].tolist()
+                    padded_attention_mask = encoding["attention_mask"][i][
+                        :QLEN
+                    ].tolist()
+
+                new_encodings["input_ids"].append(padded_input_ids)
+                new_encodings["attention_mask"].append(padded_attention_mask)
+
+            for key in new_encodings:
+                new_encodings[key] = torch.tensor(
+                    new_encodings[key], device=self.device
+                )
+
+            encoding = new_encodings
+
         encoding = {key: value.to(self.device) for key, value in encoding.items()}
         return encoding
 
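To close the loop, a hedged end-to-end sketch of the scoring path using the ColBERTModel and _colbert_score introduced in this diff. It skips the ranker's marker-token insertion and [MASK] query padding for brevity, and it assumes that "colbert-ir/colbertv2.0" is a BERT-style checkpoint whose weights (bert.* plus a 128-dimensional linear head) line up with this class; the query and document strings are illustrative only:

import torch
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("colbert-ir/colbertv2.0")
model = ColBERTModel.from_pretrained("colbert-ir/colbertv2.0")
model.eval()

query = ["what is maxsim scoring?"]
docs = [
    "MaxSim keeps the best-matching document token for each query token.",
    "Unrelated text about cooking pasta.",
]

q_enc = tokenizer(query, return_tensors="pt", padding=True, truncation=True)
d_enc = tokenizer(docs, return_tensors="pt", padding=True, truncation=True)

with torch.no_grad():
    q_reps = model(**q_enc)  # (1, query_len, 128) token embeddings
    d_reps = model(**d_enc)  # (2, doc_len, 128) token embeddings

scores = _colbert_score(q_reps, d_reps, q_enc["attention_mask"], d_enc["attention_mask"])
print(scores)  # shape (1, 2): one relevance score per document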