1616 "mixedbread.ai" : "https://api.mixedbread.ai/v1/reranking" ,
1717}
1818
# Per-provider field-name overrides. Each table is read with ``dict.get``;
# a provider that has no entry uses the default given at the lookup.

# Request key that replaces the default "documents".
DOCUMENT_KEY_MAPPING = {
    "mixedbread.ai": "input",
    "text-embeddings-inference": "texts",
}

# Request key that replaces the default "return_documents".
RETURN_DOCUMENTS_KEY_MAPPING = {
    "mixedbread.ai": "return_input",
    "text-embeddings-inference": "return_text",
}

# Response key that holds the result list (default "results").
# ``None`` means the response object itself is iterated instead of a
# sub-key of it.
RESULTS_KEY_MAPPING = {
    "voyage": "data",
    "mixedbread.ai": "data",
    "text-embeddings-inference": None,
}

# Key of the score inside one result entry (default "relevance_score").
SCORE_KEY_MAPPING = {
    "mixedbread.ai": "score",
    "text-embeddings-inference": "score",
}
1936
2037class APIRanker (BaseRanker ):
21- def __init__ (self , model : str , api_key : str , api_provider : str , verbose : int = 1 ):
38+ def __init__ (self , model : str , api_key : str , api_provider : str , verbose : int = 1 , url : str = None ):
2239 self .api_key = api_key
2340 self .model = model
2441 self .api_provider = api_provider .lower ()
@@ -29,34 +46,31 @@ def __init__(self, model: str, api_key: str, api_provider: str, verbose: int = 1
2946 "content-type" : "application/json" ,
3047 "Authorization" : f"Bearer { self .api_key } " ,
3148 }
32- self .url = URLS [self .api_provider ]
49+ self .url = url if url else URLS [self .api_provider ]
3350
3451
def _get_document_text(self, r: dict) -> str:
    """Return the document text stored in a single result entry ``r``.

    The field that holds the text is chosen from ``self.api_provider``.
    """
    provider = self.api_provider
    if provider == "mixedbread.ai":
        return r["input"]
    if provider == "text-embeddings-inference":
        return r["text"]
    if provider == "voyage":
        return r["document"]
    # Every other provider nests the text one level below "document".
    return r["document"]["text"]
4261
def _get_score(self, r: dict) -> float:
    """Return the score stored in a single result entry ``r``.

    The key is looked up for ``self.api_provider`` in ``SCORE_KEY_MAPPING``;
    providers without an entry use ``"relevance_score"``.
    """
    return r[SCORE_KEY_MAPPING.get(self.api_provider, "relevance_score")]
4765
4866 def _parse_response (
4967 self , response : dict , docs : List [Document ],
5068 ) -> RankedResults :
5169 ranked_docs = []
52- results_key = (
53- "results"
54- if self .api_provider not in ["voyage" , "mixedbread.ai" ]
55- else "data"
56- )
70+ results_key = RESULTS_KEY_MAPPING .get (self .api_provider ,"results" )
5771 print (response )
5872
59- for i , r in enumerate (response [results_key ]):
73+ for i , r in enumerate (response [results_key ] if results_key else response ):
6074 ranked_docs .append (
6175 Result (
6276 document = docs [r ["index" ]],
@@ -86,12 +100,8 @@ def _format_payload(self, query: str, docs: List[str]) -> str:
86100 top_key = (
87101 "top_n" if self .api_provider not in ["voyage" , "mixedbread.ai" ] else "top_k"
88102 )
89- documents_key = "documents" if self .api_provider != "mixedbread.ai" else "input"
90- return_documents_key = (
91- "return_documents"
92- if self .api_provider != "mixedbread.ai"
93- else "return_input"
94- )
103+ documents_key = DOCUMENT_KEY_MAPPING .get (self .api_provider ,"documents" )
104+ return_documents_key = RETURN_DOCUMENTS_KEY_MAPPING .get (self .api_provider ,"return_documents" )
95105
96106 payload = {
97107 "model" : self .model ,
0 commit comments