Skip to content

Commit 4ba30cc

Browse files
authored
Merge pull request #10 from AnswerDotAI/feat/flashrank_and_mixedbread
feat: support mixedbread API and flashrank
2 parents ac69cb9 + b133a90 commit 4ba30cc

File tree

6 files changed, with 122 additions and 17 deletions.

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ __pycache__/
44
*.py[cod]
55
*$py.class
66

7+
.flashrank_cache
8+
79
# C extensions
810
*.so
911

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,10 +52,11 @@ dependencies = [
5252
]
5353

5454
[project.optional-dependencies]
55-
all = ["transformers", "torch", "litellm", "requests", "sentencepiece", "protobuf"]
55+
all = ["transformers", "torch", "litellm", "requests", "sentencepiece", "protobuf", "flashrank"]
5656
transformers = ["transformers", "torch", "sentencepiece", "protobuf"]
5757
api = ["requests"]
5858
gpt = ["litellm"]
59+
flashrank = ["flashrank"]
5960
dev = ["ruff", "isort", "pytest", "ipyprogress", "ipython", "ranx", "ir_datasets", "srsly"]
6061

6162
[project.urls]

rerankers/models/__init__.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,3 +31,10 @@
3131
AVAILABLE_RANKERS["ColBERTRanker"] = ColBERTRanker
3232
except ImportError:
3333
pass
34+
35+
try:
36+
from rerankers.models.flashrank_ranker import FlashRankRanker
37+
38+
AVAILABLE_RANKERS["FlashRankRanker"] = FlashRankRanker
39+
except ImportError:
40+
pass

rerankers/models/api_rankers.py

Lines changed: 33 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
"cohere": "https://api.cohere.ai/v1/rerank",
1212
"jina": "https://api.jina.ai/v1/rerank",
1313
"voyage": "https://api.voyageai.com/v1/rerank",
14-
"mixedbread": NotImplemented,
14+
"mixedbread.ai": "https://api.mixedbread.ai/v1/reranking",
1515
}
1616

1717

@@ -29,24 +29,36 @@ def __init__(self, model: str, api_key: str, api_provider: str, verbose: int = 1
2929
}
3030
self.url = URLS[self.api_provider]
3131

32+
def _get_document_text(self, r: dict) -> str:
33+
if self.api_provider == "voyage":
34+
return r["document"]
35+
elif self.api_provider == "mixedbread.ai":
36+
return r["input"]
37+
else:
38+
return r["document"]["text"]
39+
40+
def _get_score(self, r: dict) -> float:
41+
if self.api_provider == "mixedbread.ai":
42+
return r["score"]
43+
return r["relevance_score"]
44+
3245
def _parse_response(
3346
self, response: dict, doc_ids: Union[List[str], List[int]]
3447
) -> RankedResults:
3548
ranked_docs = []
36-
results_key = "results" if self.api_provider != "voyage" else "data"
49+
results_key = (
50+
"results"
51+
if self.api_provider not in ["voyage", "mixedbread.ai"]
52+
else "data"
53+
)
3754
print(response)
3855

3956
for i, r in enumerate(response[results_key]):
40-
document_text = (
41-
r["document"]
42-
if self.api_provider == "voyage"
43-
else r["document"]["text"]
44-
)
4557
ranked_docs.append(
4658
Result(
4759
doc_id=doc_ids[r["index"]],
48-
text=document_text,
49-
score=r["relevance_score"],
60+
text=self._get_document_text(r),
61+
score=self._get_score(r),
5062
rank=i + 1,
5163
)
5264
)
@@ -67,13 +79,22 @@ def rank(
6779
return RankedResults(results=results, query=query, has_scores=True)
6880

6981
def _format_payload(self, query: str, docs: List[str]) -> str:
70-
top_key = "top_n" if self.api_provider != "voyage" else "top_k"
82+
top_key = (
83+
"top_n" if self.api_provider not in ["voyage", "mixedbread.ai"] else "top_k"
84+
)
85+
documents_key = "documents" if self.api_provider != "mixedbread.ai" else "input"
86+
return_documents_key = (
87+
"return_documents"
88+
if self.api_provider != "mixedbread.ai"
89+
else "return_input"
90+
)
91+
7192
payload = {
7293
"model": self.model,
7394
"query": query,
74-
"documents": docs,
95+
documents_key: docs,
7596
top_key: len(docs),
76-
"return_documents": True,
97+
return_documents_key: True,
7798
}
7899
return json.dumps(payload)
79100

rerankers/models/flashrank_ranker.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
from rerankers.models.ranker import BaseRanker
2+
3+
from flashrank import Ranker, RerankRequest
4+
5+
6+
from typing import Union, List, Optional, Tuple
7+
from rerankers.utils import (
8+
vprint,
9+
ensure_docids,
10+
ensure_docs_list,
11+
)
12+
from rerankers.results import RankedResults, Result
13+
14+
15+
class FlashRankRanker(BaseRanker):
16+
def __init__(
17+
self,
18+
model_name_or_path: str,
19+
verbose: int = 1,
20+
cache_dir: str = "./.flashrank_cache",
21+
):
22+
self.verbose = verbose
23+
vprint(
24+
f"Loading model FlashRank model {model_name_or_path}...", verbose=verbose
25+
)
26+
self.model = Ranker(model_name=model_name_or_path, cache_dir=cache_dir)
27+
self.ranking_type = "pointwise"
28+
29+
def tokenize(self, inputs: Union[str, List[str], List[Tuple[str, str]]]):
30+
return self.tokenizer(
31+
inputs, return_tensors="pt", padding=True, truncation=True
32+
).to(self.device)
33+
34+
def rank(
35+
self,
36+
query: str,
37+
docs: List[str],
38+
doc_ids: Optional[List[Union[str, int]]] = None,
39+
) -> RankedResults:
40+
docs = ensure_docs_list(docs)
41+
doc_ids = ensure_docids(doc_ids, len(docs))
42+
passages = [{"id": doc_id, "text": doc} for doc_id, doc in zip(doc_ids, docs)]
43+
44+
rerank_request = RerankRequest(query=query, passages=passages)
45+
flashrank_results = self.model.rerank(rerank_request)
46+
47+
ranked_results = [
48+
Result(
49+
doc_id=result["id"],
50+
text=result["text"],
51+
score=result["score"],
52+
rank=idx + 1,
53+
)
54+
for idx, result in enumerate(flashrank_results)
55+
]
56+
return RankedResults(results=ranked_results, query=query, has_scores=True)
57+
58+
def score(self, query: str, doc: str) -> float:
59+
rerank_request = RerankRequest(
60+
query=query, passages=[{"id": "temp_id", "text": doc}]
61+
)
62+
flashrank_result = self.model.rerank(rerank_request)
63+
score = flashrank_result[0]["score"]
64+
return score

rerankers/reranker.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,9 @@
55

66
DEFAULTS = {
77
"jina": {"en": "jina-reranker-v1-base-en"},
8-
"cohere": {"en": "rerank-english-v2.0", "other": "rerank-multilingual-v2.0"},
8+
"cohere": {"en": "rerank-english-v3.0", "other": "rerank-multilingual-v3.0"},
99
"voyage": {"en": "rerank-lite-1"},
10+
"mixedbread.ai": {"en": "mixedbread-ai/mxbai-rerank-large-v1"},
1011
"cross-encoder": {
1112
"en": "mixedbread-ai/mxbai-rerank-base-v1",
1213
"fr": "antoinelouis/crossencoder-camembert-base-mmarcoFR",
@@ -26,6 +27,7 @@
2627
"ja": "bclavie/JaColBERTv2",
2728
"es": "AdrienB134/ColBERTv2.0-spanish-mmarcoES",
2829
},
30+
"flashrank": {"en": "ms-marco-MiniLM-L-12-v2", "other": "ms-marco-MultiBERT-L-12"},
2931
}
3032

3133
DEPS_MAPPING = {
@@ -35,9 +37,10 @@
3537
"RankGPTRanker": "gpt",
3638
"APIRanker": "api",
3739
"ColBERTRanker": "transformers",
40+
"FlashRankRanker": "flashrank",
3841
}
3942

40-
PROVIDERS = ["cohere", "jina", "voyage"]
43+
PROVIDERS = ["cohere", "jina", "voyage", "mixedbread.ai"]
4144

4245

4346
def _get_api_provider(model_name: str, model_type: Optional[str] = None) -> str:
@@ -68,6 +71,7 @@ def _get_model_type(model_name: str, explicit_model_type: Optional[str] = None)
6871
"t5": "T5Ranker",
6972
"colbert": "ColBERTRanker",
7073
"cross-encoder": "TransformerRanker",
74+
"flashrank": "FlashRankRanker",
7175
}
7276
return model_mapping.get(explicit_model_type, explicit_model_type)
7377
else:
@@ -82,12 +86,18 @@ def _get_model_type(model_name: str, explicit_model_type: Optional[str] = None)
8286
"cohere": "APIRanker",
8387
"jina": "APIRanker",
8488
"voyage": "APIRanker",
89+
"ms-marco-minilm-l-12-v2": "FlashRankRanker",
90+
"ms-marco-multibert-l-12": "FlashRankRanker",
8591
}
8692
for key, value in model_mapping.items():
8793
if key in model_name:
8894
return value
89-
if any(
90-
keyword in model_name for keyword in ["minilm", "bert", "cross-encoders/"]
95+
if (
96+
any(
97+
keyword in model_name
98+
for keyword in ["minilm", "bert", "cross-encoders/"]
99+
)
100+
and "/" in model_name
91101
):
92102
return "TransformerRanker"
93103
print(

0 commit comments

Comments (0)