Skip to content

Commit b407dc5

Browse files
authored
Merge pull request #20 from srisudarsan/main
feat: adds compatibility for reranking using text-embeddings-inference server
2 parents dab1eed + 3061c75 commit b407dc5

File tree

2 files changed

+31
-18
lines changed

2 files changed

+31
-18
lines changed

rerankers/models/api_rankers.py

Lines changed: 27 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,26 @@
1616
"mixedbread.ai": "https://api.mixedbread.ai/v1/reranking",
1717
}
1818

19+
DOCUMENT_KEY_MAPPING = {
20+
"mixedbread.ai": "input",
21+
"text-embeddings-inference":"texts"
22+
}
23+
RETURN_DOCUMENTS_KEY_MAPPING = {
24+
"mixedbread.ai":"return_input",
25+
"text-embeddings-inference":"return_text"
26+
}
27+
RESULTS_KEY_MAPPING = {
28+
"voyage": "data",
29+
"mixedbread.ai": "data",
30+
"text-embeddings-inference": None
31+
}
32+
SCORE_KEY_MAPPING = {
33+
"mixedbread.ai": "score",
34+
"text-embeddings-inference":"score"
35+
}
1936

2037
class APIRanker(BaseRanker):
21-
def __init__(self, model: str, api_key: str, api_provider: str, verbose: int = 1):
38+
def __init__(self, model: str, api_key: str, api_provider: str, verbose: int = 1, url: str = None):
2239
self.api_key = api_key
2340
self.model = model
2441
self.api_provider = api_provider.lower()
@@ -29,34 +46,31 @@ def __init__(self, model: str, api_key: str, api_provider: str, verbose: int = 1
2946
"content-type": "application/json",
3047
"Authorization": f"Bearer {self.api_key}",
3148
}
32-
self.url = URLS[self.api_provider]
49+
self.url = url if url else URLS[self.api_provider]
3350

3451

3552
def _get_document_text(self, r: dict) -> str:
3653
if self.api_provider == "voyage":
3754
return r["document"]
3855
elif self.api_provider == "mixedbread.ai":
3956
return r["input"]
57+
elif self.api_provider == "text-embeddings-inference":
58+
return r["text"]
4059
else:
4160
return r["document"]["text"]
4261

4362
def _get_score(self, r: dict) -> float:
44-
if self.api_provider == "mixedbread.ai":
45-
return r["score"]
46-
return r["relevance_score"]
63+
score_key = SCORE_KEY_MAPPING.get(self.api_provider,"relevance_score")
64+
return r[score_key]
4765

4866
def _parse_response(
4967
self, response: dict, docs: List[Document],
5068
) -> RankedResults:
5169
ranked_docs = []
52-
results_key = (
53-
"results"
54-
if self.api_provider not in ["voyage", "mixedbread.ai"]
55-
else "data"
56-
)
70+
results_key = RESULTS_KEY_MAPPING.get(self.api_provider,"results")
5771
print(response)
5872

59-
for i, r in enumerate(response[results_key]):
73+
for i, r in enumerate(response[results_key] if results_key else response):
6074
ranked_docs.append(
6175
Result(
6276
document=docs[r["index"]],
@@ -86,12 +100,8 @@ def _format_payload(self, query: str, docs: List[str]) -> str:
86100
top_key = (
87101
"top_n" if self.api_provider not in ["voyage", "mixedbread.ai"] else "top_k"
88102
)
89-
documents_key = "documents" if self.api_provider != "mixedbread.ai" else "input"
90-
return_documents_key = (
91-
"return_documents"
92-
if self.api_provider != "mixedbread.ai"
93-
else "return_input"
94-
)
103+
documents_key = DOCUMENT_KEY_MAPPING.get(self.api_provider,"documents")
104+
return_documents_key = RETURN_DOCUMENTS_KEY_MAPPING.get(self.api_provider,"return_documents")
95105

96106
payload = {
97107
"model": self.model,

rerankers/reranker.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
"es": "AdrienB134/ColBERTv2.0-spanish-mmarcoES",
3131
},
3232
"flashrank": {"en": "ms-marco-MiniLM-L-12-v2", "other": "ms-marco-MultiBERT-L-12"},
33+
"text-embeddings-inference": {"other": "BAAI/bge-reranker-base"}
3334
}
3435

3536
DEPS_MAPPING = {
@@ -43,7 +44,7 @@
4344
"RankLLMRanker": "rankllm",
4445
}
4546

46-
PROVIDERS = ["cohere", "jina", "voyage", "mixedbread.ai"]
47+
PROVIDERS = ["cohere", "jina", "voyage", "mixedbread.ai", "text-embeddings-inference"]
4748

4849

4950
def _get_api_provider(model_name: str, model_type: Optional[str] = None) -> str:
@@ -69,6 +70,7 @@ def _get_model_type(model_name: str, explicit_model_type: Optional[str] = None)
6970
"cohere": "APIRanker",
7071
"jina": "APIRanker",
7172
"voyage": "APIRanker",
73+
"text-embeddings-inference": "APIRanker",
7274
"rankgpt": "RankGPTRanker",
7375
"lit5": "LiT5Ranker",
7476
"t5": "T5Ranker",
@@ -92,6 +94,7 @@ def _get_model_type(model_name: str, explicit_model_type: Optional[str] = None)
9294
"cohere": "APIRanker",
9395
"jina": "APIRanker",
9496
"voyage": "APIRanker",
97+
"text-embeddings-inference": "APIRanker",
9598
"ms-marco-minilm-l-12-v2": "FlashRankRanker",
9699
"ms-marco-multibert-l-12": "FlashRankRanker",
97100
"vicuna": "RankLLMRanker",

0 commit comments

Comments
 (0)