Skip to content

Commit 60c9740

Browse files
authored
expose reading from file in run_rankllm (#320)
* add support for passing a retrieval file (cached_file retrieval mode) * add output-file and qrels-file args; add a populate_invocations_history arg to enable writing results in cached_file mode.
1 parent 6f880ba commit 60c9740

File tree

3 files changed

+112
-10
lines changed

3 files changed

+112
-10
lines changed

src/rank_llm/retrieve/retriever.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
class RetrievalMode(Enum):
1818
DATASET = "dataset"
1919
CUSTOM = "custom"
20+
CACHED_FILE = "cached_file"
2021

2122
def __str__(self):
2223
return self.value

src/rank_llm/retrieve_and_rerank.py

Lines changed: 57 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
import copy
2+
from pathlib import Path
23
from typing import Any, Dict, List, Optional, Union
34

4-
from rank_llm.data import Query, Request
5+
from rank_llm.data import DataWriter, Query, Request, read_requests_from_file
56
from rank_llm.rerank import IdentityReranker, RankLLM, Reranker
67
from rank_llm.rerank.reranker import extract_kwargs
78
from rank_llm.retrieve import (
@@ -79,7 +80,6 @@ def retrieve_and_rerank(
7980
# Reranker is of type RankLLM
8081
for pass_ct in range(num_passes):
8182
print(f"Pass {pass_ct + 1} of {num_passes}:")
82-
8383
rerank_results = reranker.rerank_batch(
8484
requests,
8585
rank_end=top_k_retrieve,
@@ -125,6 +125,55 @@ def retrieve_and_rerank(
125125
EvalFunction.eval(["-c", "-m", "ndcg_cut.10", TOPICS[dataset], file_name])
126126
else:
127127
print(f"Skipping evaluation as {dataset} is not in TOPICS.")
128+
elif (
129+
retrieval_mode == RetrievalMode.CACHED_FILE
130+
and reranker.get_model_coordinator() is not None
131+
):
132+
writer = DataWriter(rerank_results)
133+
keys_and_defaults = [
134+
("output_jsonl_file", ""),
135+
("output_trec_file", ""),
136+
("invocations_history_file", ""),
137+
]
138+
[
139+
output_jsonl_file,
140+
output_trec_file,
141+
invocations_history_file,
142+
] = extract_kwargs(keys_and_defaults, **kwargs)
143+
if output_jsonl_file:
144+
path = Path(output_jsonl_file)
145+
path.parent.mkdir(parents=True, exist_ok=True)
146+
writer.write_in_jsonl_format(output_jsonl_file)
147+
if output_trec_file:
148+
path = Path(output_trec_file)
149+
path.parent.mkdir(parents=True, exist_ok=True)
150+
writer.write_in_trec_eval_format(output_trec_file)
151+
keys_and_defaults = [("populate_invocations_history", False)]
152+
[populate_invocations_history] = extract_kwargs(keys_and_defaults, **kwargs)
153+
if populate_invocations_history:
154+
if invocations_history_file:
155+
path = Path(invocations_history_file)
156+
path.parent.mkdir(parents=True, exist_ok=True)
157+
writer.write_inference_invocations_history(invocations_history_file)
158+
else:
159+
raise ValueError(
160+
"--invocations_history_file must be a valid jsonl file to store invocations history."
161+
)
162+
keys_and_defaults = [("qrels_file", "")]
163+
[qrels_file] = extract_kwargs(keys_and_defaults, **kwargs)
164+
if qrels_file:
165+
from rank_llm.evaluation.trec_eval import EvalFunction
166+
167+
print("Evaluating:")
168+
EvalFunction.from_results(
169+
rerank_results, qrels_file, ["-c", "-m", "ndcg_cut.1"]
170+
)
171+
EvalFunction.from_results(
172+
rerank_results, qrels_file, ["-c", "-m", "ndcg_cut.5"]
173+
)
174+
EvalFunction.from_results(
175+
rerank_results, qrels_file, ["-c", "-m", "ndcg_cut.10"]
176+
)
128177

129178
if interactive:
130179
return (rerank_results, reranker.get_model_coordinator())
@@ -211,5 +260,11 @@ def retrieve(
211260
requests = Retriever.from_custom_index(
212261
index_path=index_path, topics_path=topics_path, index_type=index_type
213262
)
263+
elif retrieval_mode == RetrievalMode.CACHED_FILE:
264+
keys_and_defaults = [
265+
("requests_file", ""),
266+
]
267+
[requests_file] = extract_kwargs(keys_and_defaults, **kwargs)
268+
requests = read_requests_from_file(requests_file)
214269

215270
return requests

src/rank_llm/scripts/run_rank_llm.py

Lines changed: 54 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,11 @@ def main(args):
2828
dataset = args.dataset
2929
num_gpus = args.num_gpus
3030
retrieval_method = args.retrieval_method
31+
requests_file = args.requests_file
32+
qrels_file = args.qrels_file
33+
output_jsonl_file = args.output_jsonl_file
34+
output_trec_file = args.output_trec_file
35+
invocations_history_file = args.invocations_history_file
3136
prompt_template_path = args.prompt_template_path
3237
num_few_shot_examples = args.num_few_shot_examples
3338
few_shot_file = args.few_shot_file
@@ -36,7 +41,9 @@ def main(args):
3641
num_few_shot_examples = args.num_few_shot_examples
3742
device = "cuda" if torch.cuda.is_available() else "cpu"
3843
variable_passages = args.variable_passages
39-
retrieval_mode = RetrievalMode.DATASET
44+
retrieval_mode = (
45+
RetrievalMode.DATASET if args.dataset else RetrievalMode.CACHED_FILE
46+
)
4047
num_passes = args.num_passes
4148
stride = args.stride
4249
window_size = args.window_size
@@ -49,12 +56,26 @@ def main(args):
4956
sglang_batched = args.sglang_batched
5057
tensorrt_batched = args.tensorrt_batched
5158

59+
if args.requests_file:
60+
if args.retrieval_method:
61+
parser.error("--retrieval_method must not be used with --requests_file")
62+
if not os.path.exists(args.requests_file):
63+
parser.error(f"--requests_file not found: {args.requests_file}")
64+
65+
if args.dataset and not args.retrieval_method:
66+
parser.error("--retrieval_method is required when --dataset is provided")
67+
5268
_ = retrieve_and_rerank(
5369
model_path=model_path,
5470
query=query,
5571
batch_size=batch_size,
5672
dataset=dataset,
5773
retrieval_mode=retrieval_mode,
74+
requests_file=requests_file,
75+
qrels_file=qrels_file,
76+
output_jsonl_file=output_jsonl_file,
77+
output_trec_file=output_trec_file,
78+
invocations_history_file=invocations_history_file,
5879
retrieval_method=retrieval_method,
5980
top_k_retrieve=top_k_candidates,
6081
top_k_rerank=top_k_rerank,
@@ -142,21 +163,46 @@ def main(args):
142163
default=None,
143164
help="the max number of queries to process from the dataset",
144165
)
145-
parser.add_argument(
166+
retrieval_input_group = parser.add_mutually_exclusive_group(required=True)
167+
retrieval_input_group.add_argument(
146168
"--dataset",
147169
type=str,
148-
required=True,
149-
help=f"Should be one of 1- dataset name, must be in {TOPICS.keys()}, 2- a list of inline documents 3- a list of inline hits 4- filename containing retrieved results",
150-
)
151-
parser.add_argument(
152-
"--num_gpus", type=int, default=1, help="the number of GPUs to use"
170+
help=f"Should be one of 1- dataset name, must be in {TOPICS.keys()}, 2- a list of inline documents 3- a list of inline hits; must be used when --requests_file is not specified",
153171
)
154172
parser.add_argument(
155173
"--retrieval_method",
156174
type=RetrievalMethod,
157-
required=True,
175+
help="Required if --dataset is used; must be omitted with --requests_file",
158176
choices=list(RetrievalMethod),
159177
)
178+
retrieval_input_group.add_argument(
179+
"--requests_file",
180+
type=str,
181+
help=f"Path to a JSONL file containing requests; must be used when --dataset is not specified.",
182+
)
183+
parser.add_argument(
184+
"--qrels_file",
185+
type=str,
186+
help="Only used with --requests_file; when present the Trec eval will be executed using this qrels file",
187+
)
188+
parser.add_argument(
189+
"--output_jsonl_file",
190+
type=str,
191+
help="Only used with --requests_file; when present, the ranked results will be saved in this JSONL file.",
192+
)
193+
parser.add_argument(
194+
"--output_trec_file",
195+
type=str,
196+
help="Only used with --requests_file; when present, the ranked results will be saved in this txt file in trec format.",
197+
)
198+
parser.add_argument(
199+
"--invocations_history_file",
200+
type=str,
201+
help="Only used with --requests_file and --populate_invocations_history; when present, the LLM invocations history (prompts, completions, and input/output token counts) will be stored in this file.",
202+
)
203+
parser.add_argument(
204+
"--num_gpus", type=int, default=1, help="the number of GPUs to use"
205+
)
160206
parser.add_argument(
161207
"--prompt_mode",
162208
type=PromptMode,

0 commit comments

Comments
 (0)