Changes from 52 commits
60 commits
60d5e7f
First step update - during create_prompt, use selected_index to subst…
XKTZ Aug 13, 2024
74118c4
Allow selected index to be variable on each step
XKTZ Aug 13, 2024
36bbcdc
Added basic _get_model_function - it is wrong now, we want to return …
XKTZ Aug 13, 2024
6d2094a
Renamed indexes to indices, [indices] is renamed to indices_batch
XKTZ Aug 14, 2024
cc78522
Renamed ReorderExecutor to Reorder Policy
XKTZ Aug 14, 2024
4936c8c
Moved LiT5Distill to policy
XKTZ Aug 15, 2024
95f065a
Transition rank listwise os llm to using reorder policy
XKTZ Aug 18, 2024
438c8bd
Added reorder policy for rank listwise os and fid score
XKTZ Aug 18, 2024
2aa8644
Revised bug in OS LLM, add Rank GPT, deprecated old functions
XKTZ Aug 18, 2024
7a19b87
Finish the tournament sort node
XKTZ Aug 25, 2024
c8734f0
Finish tournament sort
XKTZ Aug 26, 2024
7981c06
Finish reorganize of parameters, move window_size to ListwiseRankLLM
XKTZ Aug 30, 2024
bbbdc31
Added window size back
XKTZ Sep 5, 2024
bb0ad2b
Fix Rerankers
XKTZ Sep 5, 2024
62785bc
Merge pull request #4 from castorini/main
XKTZ Sep 5, 2024
9b1972a
Merge branch 'reorder' into main
XKTZ Sep 5, 2024
06b9f1f
Merge pull request #5 from XKTZ/main
XKTZ Sep 5, 2024
28e3a23
Added r parameter
XKTZ Sep 5, 2024
0e9edce
Some bug fix
XKTZ Sep 6, 2024
3005455
Reformatted
XKTZ Sep 6, 2024
3795308
Added top down
XKTZ Sep 12, 2024
edd1ff9
Merge pull request #6 from castorini/main
XKTZ Sep 12, 2024
89bab7c
Merge pull request #7 from XKTZ/main
XKTZ Sep 12, 2024
65d3508
Added final some stuff
XKTZ Sep 15, 2024
69e3d65
Updated readme
XKTZ Sep 15, 2024
8f4f729
Clean the README
XKTZ Sep 15, 2024
cc22c95
Changed filename
XKTZ Sep 16, 2024
26beeed
Updated silence
XKTZ Sep 16, 2024
36beed1
Fixed bug when batch size > 32, which is in rank_listwise_os_llm's cr…
XKTZ Sep 20, 2024
406fca7
Updated for removing use_tqdm=False
XKTZ Sep 25, 2024
9e7f846
Update parallel for topdown
XKTZ Sep 27, 2024
329d6eb
Reformat
XKTZ Sep 27, 2024
4f561e8
Added vllm_chunked_prefill
XKTZ Oct 3, 2024
8470d3d
Updated README
XKTZ Oct 19, 2024
9caa5fb
Put back step size for maintaining backward compatibility
XKTZ Oct 19, 2024
d7ecc8b
Support early stop in topdown
XKTZ Nov 18, 2024
891ca8f
Reformat
XKTZ Nov 21, 2024
db3a509
temp .gitignore
XKTZ Nov 21, 2024
743b5af
Add parameter include padding,
XKTZ Dec 15, 2024
cca845e
Merge remote-tracking branch 'origin/main' into reorder
XKTZ Dec 24, 2024
66c8742
Format
XKTZ Dec 24, 2024
acf8012
Add more fix
XKTZ Dec 24, 2024
f595054
Fix bug
XKTZ Jan 3, 2025
94a5aa5
Merge remote-tracking branch 'rankllm/main' into reorder
XKTZ Jan 3, 2025
8d0ea17
update
XKTZ Jan 4, 2025
ba2c3ef
Merge remote-tracking branch 'rankllm/main' into reorder
XKTZ Jan 4, 2025
ca0a1f0
Fix TensorRT
XKTZ Jan 4, 2025
2efdbd5
Fix gitignore
XKTZ Jan 4, 2025
dc9fdbf
Fix import
XKTZ Jan 4, 2025
4669a3f
Added Readme
XKTZ Jan 4, 2025
2a38464
MonoT5 type def
XKTZ Jan 4, 2025
75c4a74
Let RankFID able to fit the type def
XKTZ Jan 4, 2025
0b0426d
Remove Tournament sort method with ListT5
XKTZ Jan 17, 2025
fbb163c
Merge remote-tracking branch 'rankllm/main' into reorder
XKTZ Jan 17, 2025
e5842a2
Add a comment
XKTZ Jan 17, 2025
70f3914
lint
XKTZ Jan 24, 2025
4944927
Update rankfid
XKTZ Jan 29, 2025
0d5c5e7
Add some comments
XKTZ Jan 29, 2025
8b0f963
Merge branch 'reorder' of github.com:XKTZ/rank_llm into reorder
XKTZ Jan 29, 2025
1115247
Updated
XKTZ Jan 29, 2025
2 changes: 1 addition & 1 deletion .gitignore
@@ -170,4 +170,4 @@ token_counts/
retrieve_results/
ranking_execution_summary/
repro/
demo_outputs/
demo_outputs/
40 changes: 40 additions & 0 deletions README.md
@@ -147,6 +147,18 @@ Omit `--use_logits` if you wish to perform traditional listwise reranking.

If you would like to contribute to the project, please refer to the [contribution guidelines](CONTRIBUTING.md).

### Run end to end - Rank Zephyr with [Tournament Sort](https://arxiv.org/abs/2402.15838) (`top_k = 10, r = 1`)

* Like the other LLM reranking methods, tournament sort is selected via the `--reorder_policy` flag, for example:

```
python src/rank_llm/scripts/run_rank_llm.py --model_path=castorini/rank_zephyr_7b_v1_full --top_k_candidates=100 --dataset=dl20 \
--retrieval_method=SPLADE++_EnsembleDistil_ONNX --prompt_mode=rank_GPT --context_size=4096 --variable_passages \
--reorder_policy="tournament_sort:{top_k: 10, r: 1}"
```

### Run end to end - Rank Zephyr with [Top Down](https://arxiv.org/abs/2405.14589)
@XKTZ what is this?


## 🦙🐧 Model Zoo

The following is a table of the listwise models our repository was primarily built to handle (with the models hosted on HuggingFace):
@@ -283,6 +295,34 @@ If you would like to cite the FIRST methodology, please consider citing:
}
```

If you would like to cite ListT5's tournament sort methodology, please consider citing:
I don't think this is the first instance of tournament sort, something has probably been done before, maybe we drop it for now


[[2402.15838] ListT5: Listwise Reranking with Fusion-in-Decoder Improves Zero-shot Retrieval](https://arxiv.org/abs/2402.15838)

```
@ARTICLE{yoon2024listt5,
  title   = {ListT5: Listwise Reranking with Fusion-in-Decoder Improves Zero-shot Retrieval},
  author  = {Yoon, Soyoung and Choi, Eunbi and Kim, Jiyeon and Yun, Hyeongu and Kim, Yireun and Hwang, Seung-won},
  journal = {arXiv preprint arXiv:2402.15838},
  year    = {2024}
}
```

If you would like to cite the Top-Down Partitioning methodology, please consider citing:

[[2405.14589] Top-Down Partitioning for Efficient List-Wise Ranking](https://arxiv.org/abs/2405.14589)

```
@ARTICLE{parry2024top,
  title   = {Top-Down Partitioning for Efficient List-Wise Ranking},
  author  = {Parry, Andrew and MacAvaney, Sean and Ganguly, Debasis},
  journal = {arXiv preprint arXiv:2405.14589},
  year    = {2024}
}
```

Comment on lines 206 to 311
Remove

## 🙏 Acknowledgments

This research is supported in part by the Natural Sciences and Engineering Research Council (NSERC) of Canada.
3 changes: 2 additions & 1 deletion requirements.txt
@@ -7,4 +7,5 @@ python-dotenv>=1.0.1
faiss-cpu>=1.8.0
ftfy>=6.2.0
dacite>=1.8.1
fschat[model_worker]>=0.2.36
fschat[model_worker]>=0.2.36
json-repair
211 changes: 205 additions & 6 deletions src/rank_llm/rerank/listwise/listwise_rankllm.py
@@ -1,22 +1,50 @@
import copy
import json
import logging
import random
import re
from abc import ABC
from dataclasses import dataclass
from datetime import datetime
from typing import Any, Dict, List, Tuple
from typing import Any, Dict, List, Optional, Tuple, Union

import json_repair
from ftfy import fix_text
from tqdm import tqdm

from rank_llm.data import RankingExecInfo, Request, Result
from rank_llm.rerank import PromptMode, RankLLM
from rank_llm.rerank.listwise.reorder.reorder_policy import (
ModelFunction,
ReorderPolicy,
SlidingWindowReorderPolicy,
)
from rank_llm.rerank.listwise.reorder.top_down_reorder_policy import (
TopDownReorderPolicy,
)
from rank_llm.rerank.listwise.reorder.tournament_sort_reorder_policy import (
TournamentSortReorderPolicy,
)

logger = logging.getLogger(__name__)

ALPH_START_IDX = ord("A") - 1


SUPPORT_REORDER_POLICIES = [
SlidingWindowReorderPolicy,
TournamentSortReorderPolicy,
TopDownReorderPolicy,
]


@dataclass
class RerankConsumption:
consumption_reference_by_batch: int
consumption_reference_by_item: int


class ListwiseRankLLM(RankLLM, ABC):
"""
All children of ListwiseRankLLM must implement these functions:
@@ -32,18 +60,77 @@ class ListwiseRankLLM(RankLLM, ABC):

def __init__(
self,
reorder_policy: Optional[ReorderPolicy],
model: str,
context_size: int,
window_size: int,
prompt_mode: PromptMode,
num_few_shot_examples: int,
window_size: int,
use_alpha: bool = False,
) -> None:
super().__init__(model, context_size, prompt_mode)
self._num_few_shot_examples = num_few_shot_examples

self.reorder_policy = reorder_policy or SlidingWindowReorderPolicy()
self._window_size = window_size
self._use_alpha = use_alpha

def rerank_batch(
self,
requests: List[Request],
rank_start: int = 0,
rank_end: int = 100,
shuffle_candidates: bool = False,
logging: bool = False,
batched: bool = False,
**kwargs: Any,
) -> List[Result]:
populate_exec_summary: bool = kwargs.get("populate_exec_summary", False)

batch_size = kwargs.get("batch_size") or len(requests)

if not batched:
batch_size = 1

reorder_policy = self.reorder_policy
model_functions, consumption = self._get_model_function(batched, **kwargs)

# reranking using batched mode
if batched and len(set([len(req.candidates) for req in requests])) != 1:
raise ValueError("Batched requests must have the same number of candidates")

result: list[Result] = []

with tqdm(range(0, len(requests)), leave=False) as bar:
for i in range(0, len(requests), batch_size):
batch = requests[i : min(i + batch_size, len(requests))]
batch_result = reorder_policy.reorder(
requests=[
Result(
query=copy.deepcopy(request.query),
candidates=copy.deepcopy(request.candidates),
ranking_exec_summary=[],
)
for request in batch
],
rank_start=max(rank_start, 0),
rank_end=min(
rank_end, len(requests[0].candidates)
), # TODO: Fails arbitrary hit sizes
model=model_functions,
shuffle_candidates=shuffle_candidates,
logging=logging,
populate_exec_summary=populate_exec_summary,
)
result.extend(batch_result)
bar.update(len(batch))

logger.info(
f"\n\nAverage consumption per request: {consumption.consumption_reference_by_item / len(requests) : .2f}\n\n"
)

return result
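The batching loop above is the standard chunking pattern; a minimal standalone sketch, with generic items standing in for the real `Request` objects:

```python
from typing import Iterator, List, TypeVar

T = TypeVar("T")


def chunked(items: List[T], batch_size: int) -> Iterator[List[T]]:
    """Yield consecutive slices of at most batch_size items."""
    for i in range(0, len(items), batch_size):
        yield items[i : i + batch_size]
```

The slice `items[i : i + batch_size]` is already clamped by Python slicing, so the final partial batch needs no special case.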

def get_output_filename(
self,
top_k_candidates: int,
@@ -61,6 +148,7 @@ def get_output_filename(
name = f"{name}_{dataset_name}"
if self._num_few_shot_examples > 0:
name += f"_{self._num_few_shot_examples}_shot"
name += f"_{self.reorder_policy.param_name()}"
return (
f"{name}_shuffled_{datetime.isoformat(datetime.now())}"
if shuffle_candidates
@@ -76,6 +164,7 @@ def max_tokens(self) -> int:
"""
return self._context_size

# @deprecated("old sliding window pipeline is deprecated. please use reorder policy")
def permutation_pipeline_batched(
self,
results: List[Result],
@@ -99,7 +188,9 @@ def permutation_pipeline_batched(
prompts = []
logger.info("Loading prompts.")
prompts = self.create_prompt_batched(
results, rank_start, rank_end, batch_size=32
results,
[list(range(rank_start, rank_end)) for _ in range(len(results))],
batch_size=32,
)
if logging:
for prompt in prompts:
@@ -126,6 +217,7 @@ def permutation_pipeline_batched(

return results

# @deprecated("old sliding window pipeline is deprecated. please use reorder policy")
def permutation_pipeline(
self,
result: Result,
@@ -146,7 +238,9 @@ def permutation_pipeline(
Returns:
Result: The processed result object after applying permutation.
"""
prompt, in_token_count = self.create_prompt(result, rank_start, rank_end)
prompt, in_token_count = self.create_prompt(
result, list(range(rank_start, rank_end))
)
if logging:
logger.info(f"Prompt: {prompt}\n")
permutation, out_token_count = self.run_llm(
@@ -184,6 +278,7 @@ def shuffle_and_rescore(
cand["score"] = 1.0 / (i + 1)
cand["rank"] = i + 1

# @deprecated("old sliding window pipeline is deprecated. please use reorder policy")
def sliding_windows_batched(
self,
requests: List[Request],
@@ -234,6 +329,7 @@ def sliding_windows_batched(
start_pos = start_pos - step
return rerank_results

# @deprecated("old sliding window pipeline is deprecated. please use reorder policy")
def sliding_windows(
self,
request: Request,
@@ -342,7 +438,7 @@ def get_ranking_cost(
start_pos = rank_end - window_size
while start_pos >= rank_start:
start_pos = max(start_pos, rank_start)
prompt, _ = self.create_prompt(result, start_pos, end_pos)
prompt, _ = self.create_prompt(result, list(range(start_pos, end_pos)))
input_token_count += self.get_num_tokens(prompt)
end_pos = end_pos - step
start_pos = start_pos - step
@@ -369,7 +465,7 @@ def _clean_response(self, response: str) -> str:
else:
new_response += c
new_response = new_response.strip()

new_response = re.sub(r"\s+", " ", new_response)
return new_response

def _remove_duplicate(self, response: List[int]) -> List[int]:
@@ -452,3 +548,106 @@ def convert_doc_to_prompt_content(
# For Japanese should cut by character: content = content[:int(max_length)]
content = " ".join(content.split()[: int(max_length)])
return self._replace_number(content)

def _permutation_to_rank(self, perm_string: str, selected_indices: List[int]):
perm = [
int(x) - 1 for x in self._clean_response(perm_string).strip().split(" ")
]
perm = [
int(x)
for x in self._remove_duplicate(perm)
if 0 <= int(x) < len(selected_indices)
]
perm = perm + [i for i in range(len(selected_indices)) if i not in perm]
return perm
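The parsing above can be sketched in isolation (a hypothetical standalone helper; the real code routes through `_clean_response` and `_remove_duplicate`):

```python
def permutation_to_rank(perm_string: str, num_selected: int) -> list[int]:
    """Turn a model's 1-based permutation string into a 0-based rank list.

    Out-of-range ids are dropped, duplicates keep their first occurrence,
    and any indices the model omitted are appended in ascending order.
    """
    raw = [int(tok) - 1 for tok in perm_string.split() if tok.isdigit()]
    seen: set[int] = set()
    perm: list[int] = []
    for idx in raw:
        if 0 <= idx < num_selected and idx not in seen:
            seen.add(idx)
            perm.append(idx)
    return perm + [i for i in range(num_selected) if i not in seen]
```

The fallback append is what makes the result a total permutation even when the model repeats or skips ids.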

def _get_model_function(
self, batched: bool = False, silence: bool = False, **kwargs
) -> Tuple[ModelFunction, RerankConsumption]:
# [(Request, SelectIndex)] -> [Prompt]

consumption = RerankConsumption(0, 0)

if batched:

def create_prompt(batch: List[Tuple[Result, List[int]]]):
return [
prompt
for prompt, _ in self.create_prompt_batched(
[result for result, selected_indices in batch],
[selected_indices for result, selected_indices in batch],
32,
)
]

def execute(
batch: List[Union[str, Dict[str, str]]],
selected_indices_batch: List[List[int]],
):
consumption.consumption_reference_by_batch += 1
consumption.consumption_reference_by_item += len(batch)

return [
self._permutation_to_rank(s, selected_indices)
for (s, _), selected_indices in zip(
self.run_llm_batched(batch, silence=silence, **kwargs),
selected_indices_batch,
)
]

else:

def create_prompt(batch: List[Tuple[Result, List[int]]]):
return [
self.create_prompt(result, selected_indices)[0]
for result, selected_indices in batch
]

def execute(
batch: List[Union[str, Dict[str, str]]],
selected_indices_batch: List[List[int]],
):
consumption.consumption_reference_by_batch += 1
consumption.consumption_reference_by_item += len(batch)

return [
self._permutation_to_rank(
self.run_llm(x, silence=silence, **kwargs)[0], selected_indices
)
for x, selected_indices in zip(batch, selected_indices_batch)
]

return (
ModelFunction(
create_prompt=create_prompt,
execute=execute,
window_size=self._window_size,
),
consumption,
)
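The shared `RerankConsumption` record works because both closures mutate the same object the caller holds; a minimal sketch of that pattern (the `run` callable is a hypothetical stand-in for `run_llm`):

```python
from dataclasses import dataclass
from typing import Callable, List, Tuple


@dataclass
class Consumption:
    batches: int = 0
    items: int = 0


def make_counting_execute(
    run: Callable[[str], str],
) -> Tuple[Callable[[List[str]], List[str]], Consumption]:
    """Wrap run() in an execute() that tallies usage in a shared record."""
    consumption = Consumption()

    def execute(batch: List[str]) -> List[str]:
        # Mutating the captured dataclass keeps the caller's view current.
        consumption.batches += 1
        consumption.items += len(batch)
        return [run(x) for x in batch]

    return execute, consumption
```

Returning the record alongside the closure is what lets `rerank_batch` report average consumption after the reorder policy has driven all the calls.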

@staticmethod
def get_reorder_policy(reorder_policy: str, **kwargs) -> ReorderPolicy:
for policy in SUPPORT_REORDER_POLICIES:
if reorder_policy.startswith(policy.name()):
reorder_params = reorder_policy[len(policy.name()) :]
reorder_params = reorder_params.strip()
if len(reorder_params) <= 1:
return policy()
else:
assert reorder_params[0] == ":" and reorder_params[1] == "{"
reorder_params = reorder_params[1:]
try:
reorder_param_dict = json_repair.repair_json(reorder_params)
if isinstance(reorder_param_dict, str):
reorder_param_dict = json.loads(reorder_param_dict)
if not isinstance(reorder_param_dict, dict):
raise Exception(
f"Failed to parse reorder parameters into a dict; got {reorder_param_dict!r} of type {type(reorder_param_dict)}"
)
except Exception as e:
raise Exception(f"Cannot load reorder policy {reorder_policy}") from e
return policy(**reorder_param_dict, extra_args=dict(**kwargs))

raise Exception(f"Cannot find reorder policy {reorder_policy}")
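The spec format `name:{param: value, ...}` can also be parsed with the stdlib alone; a simplified stand-in for the `json_repair`-based path above (the regex key-quoting is an assumption of this sketch, not the library's behavior):

```python
import json
import re
from typing import Any, Dict, Tuple


def parse_policy_spec(spec: str) -> Tuple[str, Dict[str, Any]]:
    """Split 'name:{k: v, ...}' into (name, params).

    Bare object keys such as {top_k: 10} are quoted so json.loads
    accepts them; an empty or missing parameter block yields {}.
    """
    name, sep, params = spec.partition(":")
    params = params.strip()
    if not sep or not params:
        return name, {}
    # Quote unquoted object keys before handing the string to json.loads.
    repaired = re.sub(r'([{,]\s*)(\w+)(\s*:)', r'\1"\2"\3', params)
    return name, json.loads(repaired)
```

Usage mirrors the CLI flag, e.g. `parse_policy_spec("tournament_sort:{top_k: 10, r: 1}")` yields the policy name plus a keyword dict for the policy constructor.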