Commit 01febd3 (2 parents: 5fb18e7 + b1739e3)

Merge pull request #13 from AnswerDotAI/feat/rankllm_
RELEASE: 0.3.0, RankLLM, Document, QoL

File tree: 10 files changed (+396 −160 lines)

README.md

Lines changed: 24 additions & 5 deletions

@@ -14,7 +14,8 @@ Welcome to `rerankers`! Our goal is to provide users with a simple API to use an
 
 ## Updates
 
-- v0.2.0: 🆕 [FlashRank](https://github.com/PrithivirajDamodaran/FlashRank) rerankers, Basic async support thanks to [@tarunamasa](https://github.com/tarunamasa), MixedBread.ai reranking API
+- v0.3.0: 🆕 Many changes! Experimental support for RankLLM, directly backed by the [rank-llm library](https://github.com/castorini/rank_llm). A new `Document` object, courtesy of joint work by [@bclavie](https://github.com/bclavie) and [@Anmol6](https://github.com/Anmol6). This object is transparent, but now offers support for `metadata` stored alongside each document. Many small QoL changes (RankedResults can be iterated over directly...)
+- v0.2.0: [FlashRank](https://github.com/PrithivirajDamodaran/FlashRank) rerankers, basic async support thanks to [@tarunamasa](https://github.com/tarunamasa), MixedBread.ai reranking API
 - v0.1.2: Voyage reranking API
 - v0.1.1: Langchain integration fixed!
 - v0.1.0: Initial release
@@ -59,6 +60,9 @@ pip install "rerankers[api]"
 # FlashRank rerankers (ONNX-optimised, very fast on CPU)
 pip install "rerankers[fastrank]"
 
+# RankLLM rerankers (better RankGPT + support for local models such as RankZephyr and RankVicuna)
+pip install "rerankers[rankllm]"
+
 # All of the above
 pip install "rerankers[all]"
 ```
@@ -105,12 +109,27 @@ ranker = Reranker("rankgpt3", api_key = API_KEY)
 # RankGPT with another LLM provider
 ranker = Reranker("MY_LLM_NAME" (check litellm docs), model_type = "rankgpt", api_key = API_KEY)
 
+# RankLLM with default GPT (GPT-4o)
+ranker = Reranker("rankllm", api_key = API_KEY)
+
+# RankLLM with specified GPT models
+ranker = Reranker('gpt-4-turbo', model_type="rankllm", api_key = API_KEY)
+
+# EXPERIMENTAL: RankLLM with RankZephyr
+ranker = Reranker('rankzephyr')
+
 # ColBERTv2 reranker
 ranker = Reranker("colbert")
 
 # ... Or a non-default colbert model:
 ranker = Reranker(model_name_or_path, model_type = "colbert")
 
+# Flashrank
+ranker = Reranker('flashrank')
+
+# ... Or a specific model
+ranker = Reranker('ms-marco-TinyBERT-L-2-v2', model_type='flashrank')
+
 ```
 
 _Rerankers will always try to infer the model you're trying to use based on its name, but it's always safer to pass a `model_type` argument to it if you can!_
@@ -180,18 +199,18 @@ Legend:
 
 Models:
 - ✅ Any standard SentenceTransformer or Transformers cross-encoder
-- 🟠 RankGPT (Implemented using original repo, but missing the rankllm's repo improvements)
+- ✅ RankGPT (Available both via the original RankGPT implementation and the improved RankLLM one)
 - ✅ T5-based pointwise rankers (InRanker, MonoT5...)
 - ✅ Cohere, Jina, Voyage and MixedBread API rerankers
 - ✅ [FlashRank](https://github.com/PrithivirajDamodaran/FlashRank) rerankers (ONNX-optimised models, very fast on CPU)
 - 🟠 ColBERT-based reranker - not a model initially designed for reranking, but quite strong (Implementation could be optimised and is from a third-party implementation.)
-- 📍 MixedBread API (Reranking API not yet released)
-- 📍⭐ RankLLM/RankZephyr (Proper RankLLM implementation will replace the RankGPT one, and introduce RankZephyr support)
+- 🟠⭐ RankLLM/RankZephyr: supported by wrapping the [rank-llm library](https://github.com/castorini/rank_llm)! Support for RankZephyr/RankVicuna is untested, but RankLLM + GPT models fully work!
 - 📍 LiT5
 
 Features:
+- ✅ Metadata!
 - ✅ Reranking
 - ✅ Consistency notebooks to ensure performance on `scifact` matches the literature for any given model implementation (Except RankGPT, where results are harder to reproduce).
+- ✅ ONNX runtime support --> Offered through [FlashRank](https://github.com/PrithivirajDamodaran/FlashRank) -- in line with the philosophy of the lib, we won't reinvent the wheel when @PrithivirajDamodaran is doing amazing work!
 - 📍 Training on Python >=3.10 (via interfacing with other libraries)
-- 📍 ONNX runtime support --> Unlikely to be immediate
 - ❌(📍Maybe?) Training via rerankers directly
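To make the new `Document`/`metadata` feature concrete, here is a minimal stdlib sketch of the kind of input normalization the release describes. The `Document` dataclass and `prep_docs` function below are simplified illustrative stand-ins, not the library's actual definitions in `rerankers.utils`:

```python
from dataclasses import dataclass, field
from typing import List, Optional, Union

# Illustrative stand-in for rerankers' transparent Document object.
@dataclass
class Document:
    text: str
    doc_id: Union[str, int, None] = None
    metadata: dict = field(default_factory=dict)

def prep_docs(
    docs: Union[str, List[str], "Document", List["Document"]],
    doc_ids: Optional[List[Union[str, int]]] = None,
    metadata: Optional[List[dict]] = None,
) -> List[Document]:
    """Normalize any accepted input shape into a list of Document objects."""
    if isinstance(docs, (str, Document)):
        docs = [docs]
    out = []
    for i, d in enumerate(docs):
        if isinstance(d, Document):
            out.append(d)  # already rich: keep as-is
        else:
            out.append(
                Document(
                    text=d,
                    doc_id=doc_ids[i] if doc_ids else i,
                    metadata=metadata[i] if metadata else {},
                )
            )
    return out
```

The point of the design is that plain strings keep working, while callers who care can attach metadata that survives reranking.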

examples/overview.ipynb

Lines changed: 259 additions & 136 deletions
Large diffs are not rendered by default.

pyproject.toml

Lines changed: 5 additions & 4 deletions

@@ -14,7 +14,7 @@ packages = [
 name = "rerankers"
 
 
-version = "0.2.0"
+version = "0.3.0"
 
 description = "A unified API for various document re-ranking models."
 
@@ -52,12 +52,13 @@ dependencies = [
 ]
 
 [project.optional-dependencies]
-all = ["transformers", "torch", "litellm", "requests", "sentencepiece", "protobuf", "flashrank"]
-transformers = ["transformers", "torch", "sentencepiece", "protobuf"]
+all = ["transformers", "torch", "litellm", "requests", "sentencepiece", "protobuf", "flashrank", "rank-llm"]
+transformers = ["transformers", "torch", "sentencepiece", "protobuf"]
 api = ["requests"]
 gpt = ["litellm"]
 flashrank = ["flashrank"]
+rankllm = ["rank-llm"]
 dev = ["ruff", "isort", "pytest", "ipyprogress", "ipython", "ranx", "ir_datasets", "srsly"]
 
 [project.urls]
-"Homepage" = "https://github.com/bclavie/rerankers"
+"Homepage" = "https://github.com/answerdotai/rerankers"

rerankers/__init__.py

Lines changed: 1 addition & 1 deletion

@@ -2,4 +2,4 @@
 from rerankers.documents import Document
 
 __all__ = ["Reranker", "Document"]
-__version__ = "0.2.0"
+__version__ = "0.3.0"

rerankers/models/__init__.py

Lines changed: 7 additions & 0 deletions

@@ -38,3 +38,10 @@
     AVAILABLE_RANKERS["FlashRankRanker"] = FlashRankRanker
 except ImportError:
     pass
+
+try:
+    from rerankers.models.rankllm_ranker import RankLLMRanker
+
+    AVAILABLE_RANKERS["RankLLMRanker"] = RankLLMRanker
+except ImportError:
+    pass
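The hunk above follows the registry pattern the module already uses for FlashRank: each ranker class is registered only if its backing library imports cleanly, so a missing optional extra degrades gracefully instead of crashing the package import. A runnable stdlib sketch of that pattern (`try_register` and the module names here are illustrative, not part of rerankers):

```python
import importlib

AVAILABLE_RANKERS = {}

def try_register(name: str, module: str, attr: str) -> None:
    """Register `attr` from `module` under `name`; skip silently on ImportError."""
    try:
        mod = importlib.import_module(module)
        AVAILABLE_RANKERS[name] = getattr(mod, attr)
    except ImportError:
        pass

# stdlib module: import succeeds, entry is registered
try_register("JSONRanker", "json", "loads")
# hypothetical missing extra: import fails, entry is simply absent
try_register("RankLLMRanker", "rank_llm_not_installed_here", "Reranker")
```

Callers can then check the registry and raise a helpful "install this extra" message, which is what `reranker.py` does via its dependency map.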

rerankers/models/flashrank_ranker.py

Lines changed: 10 additions & 12 deletions

@@ -4,12 +4,9 @@
 
 
 from typing import Union, List, Optional, Tuple
-from rerankers.utils import (
-    vprint,
-    ensure_docids,
-    ensure_docs_list,
-)
+from rerankers.utils import vprint, prep_docs
 from rerankers.results import RankedResults, Result
+from rerankers.documents import Document
 
 
 class FlashRankRanker(BaseRanker):
@@ -34,20 +31,21 @@ def tokenize(self, inputs: Union[str, List[str], List[Tuple[str, str]]]):
     def rank(
         self,
         query: str,
-        docs: List[str],
-        doc_ids: Optional[List[Union[str, int]]] = None,
+        docs: Union[str, List[str], Document, List[Document]],
+        doc_ids: Optional[Union[List[str], List[int]]] = None,
+        metadata: Optional[List[dict]] = None,
     ) -> RankedResults:
-        docs = ensure_docs_list(docs)
-        doc_ids = ensure_docids(doc_ids, len(docs))
-        passages = [{"id": doc_id, "text": doc} for doc_id, doc in zip(doc_ids, docs)]
+        docs = prep_docs(docs, doc_ids, metadata)
+        passages = [
+            {"id": doc_idx, "text": doc.text} for doc_idx, doc in enumerate(docs)
+        ]
 
         rerank_request = RerankRequest(query=query, passages=passages)
         flashrank_results = self.model.rerank(rerank_request)
 
         ranked_results = [
             Result(
-                doc_id=result["id"],
-                text=result["text"],
+                document=docs[idx],
                 score=result["score"],
                 rank=idx + 1,
             )
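The refactor above sends each passage to the FlashRank backend keyed by its list position, then joins the backend's scored output back to the original rich `Document` objects, so metadata survives the round trip. A stdlib sketch of that index-as-id join (`fake_rerank` is a hypothetical stand-in for flashrank's `model.rerank()`, which this sketch does not call):

```python
from dataclasses import dataclass, field

@dataclass
class Document:
    text: str
    metadata: dict = field(default_factory=dict)

def fake_rerank(passages):
    # Pretend backend: score each passage by text length, best-first.
    return sorted(
        ({"id": p["id"], "score": float(len(p["text"]))} for p in passages),
        key=lambda r: r["score"],
        reverse=True,
    )

docs = [
    Document("short", {"src": "a"}),
    Document("a much longer passage", {"src": "b"}),
]
# Positional index doubles as the passage id...
passages = [{"id": i, "text": d.text} for i, d in enumerate(docs)]
# ...so each scored result can be mapped back to its full Document.
results = [
    {"document": docs[r["id"]], "score": r["score"], "rank": rank + 1}
    for rank, r in enumerate(fake_rerank(passages))
]
```

This keeps the wire format the backend expects (plain `id`/`text` dicts) while the caller's richer objects never leave the wrapper.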

rerankers/models/rankgpt_rankers.py

Lines changed: 1 addition & 1 deletion

@@ -126,7 +126,7 @@ def _query_llm(self, messages: List[Dict[str, str]]) -> str:
     def rank(
         self,
         query: str,
-        docs: Union[Document, List[Document]],
+        docs: Union[str, List[str], Document, List[Document]],
         doc_ids: Optional[Union[List[str], List[int]]] = None,
         metadata: Optional[List[dict]] = None,
         rank_start: int = 0,

rerankers/models/rankllm.py

Lines changed: 0 additions & 1 deletion
This file was deleted.

rerankers/models/rankllm_ranker.py

Lines changed: 76 additions & 0 deletions

@@ -0,0 +1,76 @@
+from typing import Optional, Union, List
+from rerankers.models.ranker import BaseRanker
+from rerankers.documents import Document
+from rerankers.results import RankedResults, Result
+from rerankers.utils import prep_docs
+
+from rank_llm.data import Candidate, Query, Request
+from rank_llm.rerank.vicuna_reranker import VicunaReranker
+from rank_llm.rerank.zephyr_reranker import ZephyrReranker
+from rank_llm.rerank.rank_gpt import SafeOpenai
+from rank_llm.rerank.reranker import Reranker as rankllm_Reranker
+
+
+class RankLLMRanker(BaseRanker):
+    def __init__(
+        self,
+        model: str,
+        api_key: Optional[str] = None,
+        lang: str = "en",
+        verbose: int = 1,
+    ) -> "RankLLMRanker":
+        self.api_key = api_key
+        self.model = model
+        self.verbose = verbose
+        self.lang = lang
+
+        if "zephyr" in self.model.lower():
+            self.rankllm_ranker = ZephyrReranker()
+        elif "vicuna" in self.model.lower():
+            self.rankllm_ranker = VicunaReranker()
+        elif "gpt" in self.model.lower():
+            self.rankllm_ranker = rankllm_Reranker(
+                SafeOpenai(model=self.model, context_size=4096, keys=self.api_key)
+            )
+
+    def rank(
+        self,
+        query: str,
+        docs: Union[str, List[str], Document, List[Document]],
+        doc_ids: Optional[Union[List[str], List[int]]] = None,
+        metadata: Optional[List[dict]] = None,
+        rank_start: int = 0,
+        rank_end: int = 0,
+    ) -> RankedResults:
+        docs = prep_docs(docs, doc_ids, metadata)
+
+        request = Request(
+            query=Query(text=query, qid=1),
+            candidates=[
+                Candidate(doc={"text": doc.text}, docid=doc_idx, score=1)
+                for doc_idx, doc in enumerate(docs)
+            ],
+        )
+
+        rankllm_results = self.rankllm_ranker.rerank(
+            request,
+            rank_end=len(docs) if rank_end == 0 else rank_end,
+            window_size=min(20, len(docs)),
+            step=10,
+        )
+
+        ranked_docs = []
+
+        for rank, result in enumerate(rankllm_results.candidates, start=rank_start):
+            ranked_docs.append(
+                Result(
+                    document=docs[result.docid],
+                    rank=rank,
+                )
+            )
+
+        return RankedResults(results=ranked_docs, query=query, has_scores=False)
+
+    def score(self):
+        print("Listwise ranking models like RankLLM cannot output scores!")
+        return None
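The `window_size`/`step` arguments passed to `rerank()` above come from rank_llm's sliding-window listwise strategy: the LLM can only reorder a bounded window of candidates per call, so the window slides from the bottom of the list toward the top with overlap, letting strong candidates bubble upward. A runnable stdlib sketch of that loop, under the assumption that one LLM call is modeled by the hypothetical `reorder_window` callable (this is an illustration of the strategy, not rank_llm's actual code):

```python
from typing import Callable, List

def sliding_window_rerank(
    docs: List[str],
    reorder_window: Callable[[List[str]], List[str]],
    window_size: int = 20,
    step: int = 10,
) -> List[str]:
    """Reorder docs bottom-up, one overlapping window at a time."""
    docs = list(docs)
    end = len(docs)
    while end > 0:
        start = max(0, end - window_size)
        # One "LLM call": reorder this slice best-first.
        docs[start:end] = reorder_window(docs[start:end])
        if start == 0:
            break
        end -= step  # slide the window up, overlapping the previous one
    return docs

# Toy "LLM": rank window members by length, longest first.
ranked = sliding_window_rerank(
    ["bb", "a", "dddd", "ccc"],
    lambda w: sorted(w, key=len, reverse=True),
    window_size=3,
    step=1,
)
```

The overlap is what lets a good document near the bottom of the list climb through several windows to the top, at the cost of multiple LLM calls.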

rerankers/reranker.py

Lines changed: 13 additions & 0 deletions

@@ -1,4 +1,5 @@
 from typing import Optional
+import warnings
 from rerankers.models import AVAILABLE_RANKERS
 from rerankers.models.ranker import BaseRanker
 from rerankers.utils import vprint
@@ -21,6 +22,7 @@
     "rankgpt": {"en": "gpt-4-turbo-preview", "other": "gpt-4-turbo-preview"},
     "rankgpt3": {"en": "gpt-3.5-turbo", "other": "gpt-3.5-turbo"},
     "rankgpt4": {"en": "gpt-4", "other": "gpt-4"},
+    "rankllm": {"en": "gpt-4o", "other": "gpt-4o"},
     "colbert": {
         "en": "colbert-ir/colbertv2.0",
         "fr": "bclavie/FraColBERTv2",
@@ -38,6 +40,7 @@
     "APIRanker": "api",
     "ColBERTRanker": "transformers",
     "FlashRankRanker": "flashrank",
+    "RankLLMRanker": "rankllm",
 }
 
 PROVIDERS = ["cohere", "jina", "voyage", "mixedbread.ai"]
@@ -72,6 +75,7 @@ def _get_model_type(model_name: str, explicit_model_type: Optional[str] = None)
             "colbert": "ColBERTRanker",
             "cross-encoder": "TransformerRanker",
             "flashrank": "FlashRankRanker",
+            "rankllm": "RankLLMRanker",
         }
         return model_mapping.get(explicit_model_type, explicit_model_type)
     else:
@@ -80,6 +84,8 @@ def _get_model_type(model_name: str, explicit_model_type: Optional[str] = None)
             "lit5": "LiT5Ranker",
             "t5": "T5Ranker",
             "inranker": "T5Ranker",
+            "rankllm": "RankLLMRanker",
+            "rankgpt": "RankGPTRanker",
             "gpt": "RankGPTRanker",
             "zephyr": "RankZephyr",
             "colbert": "ColBERTRanker",
@@ -88,9 +94,16 @@ def _get_model_type(model_name: str, explicit_model_type: Optional[str] = None)
             "voyage": "APIRanker",
             "ms-marco-minilm-l-12-v2": "FlashRankRanker",
             "ms-marco-multibert-l-12": "FlashRankRanker",
+            "vicuna": "RankLLMRanker",
+            "zephyr": "RankLLMRanker",
         }
         for key, value in model_mapping.items():
             if key in model_name:
+                if key == "gpt":
+                    warnings.warn(
+                        "The key 'gpt' currently defaults to the rough rankGPT implementation. From version 0.0.5 onwards, 'gpt' will default to RankLLM instead. Please specify the 'rankgpt' `model_type` if you want to keep the current behaviour",
+                        DeprecationWarning,
+                    )
                 return value
         if (
             any(
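The hunks above extend the substring-based inference that `Reranker()` relies on: the first mapping key found inside the model name decides the ranker class, which is why insertion order matters (`"rankgpt"` must be checked before the bare `"gpt"` fallback, which now also emits a `DeprecationWarning`). A self-contained sketch of that lookup, with a trimmed-down illustrative mapping and a shortened warning message (not the library's exact table):

```python
import warnings
from typing import Optional

# Illustrative subset of the mapping; dict order is the match priority.
MODEL_MAPPING = {
    "rankllm": "RankLLMRanker",
    "rankgpt": "RankGPTRanker",   # must precede the bare "gpt" fallback
    "gpt": "RankGPTRanker",
    "vicuna": "RankLLMRanker",
    "zephyr": "RankLLMRanker",
    "flashrank": "FlashRankRanker",
}

def infer_model_type(model_name: str) -> Optional[str]:
    """Return the ranker class name for the first key found in model_name."""
    for key, value in MODEL_MAPPING.items():
        if key in model_name.lower():
            if key == "gpt":
                warnings.warn(
                    "'gpt' currently maps to the RankGPT implementation; "
                    "a future release will map it to RankLLM instead.",
                    DeprecationWarning,
                )
            return value
    return None
```

Because Python dicts preserve insertion order, moving a key up or down the mapping silently changes which ranker wins for ambiguous names, which is exactly why the library recommends passing `model_type` explicitly.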
