|
1 | 1 | import copy |
| 2 | +from pathlib import Path |
2 | 3 | from typing import Any, Dict, List, Optional, Union |
3 | 4 |
|
4 | | -from rank_llm.data import Query, Request |
| 5 | +from rank_llm.data import DataWriter, Query, Request, read_requests_from_file |
5 | 6 | from rank_llm.rerank import IdentityReranker, RankLLM, Reranker |
6 | 7 | from rank_llm.rerank.reranker import extract_kwargs |
7 | 8 | from rank_llm.retrieve import ( |
@@ -79,7 +80,6 @@ def retrieve_and_rerank( |
79 | 80 | # Reranker is of type RankLLM |
80 | 81 | for pass_ct in range(num_passes): |
81 | 82 | print(f"Pass {pass_ct + 1} of {num_passes}:") |
82 | | - |
83 | 83 | rerank_results = reranker.rerank_batch( |
84 | 84 | requests, |
85 | 85 | rank_end=top_k_retrieve, |
@@ -125,6 +125,55 @@ def retrieve_and_rerank( |
125 | 125 | EvalFunction.eval(["-c", "-m", "ndcg_cut.10", TOPICS[dataset], file_name]) |
126 | 126 | else: |
127 | 127 | print(f"Skipping evaluation as {dataset} is not in TOPICS.") |
| 128 | + elif ( |
| 129 | + retrieval_mode == RetrievalMode.CACHED_FILE |
| 130 | + and reranker.get_model_coordinator() is not None |
| 131 | + ): |
| 132 | + writer = DataWriter(rerank_results) |
| 133 | + keys_and_defaults = [ |
| 134 | + ("output_jsonl_file", ""), |
| 135 | + ("output_trec_file", ""), |
| 136 | + ("invocations_history_file", ""), |
| 137 | + ] |
| 138 | + [ |
| 139 | + output_jsonl_file, |
| 140 | + output_trec_file, |
| 141 | + invocations_history_file, |
| 142 | + ] = extract_kwargs(keys_and_defaults, **kwargs) |
| 143 | + if output_jsonl_file: |
| 144 | + path = Path(output_jsonl_file) |
| 145 | + path.parent.mkdir(parents=True, exist_ok=True) |
| 146 | + writer.write_in_jsonl_format(output_jsonl_file) |
| 147 | + if output_trec_file: |
| 148 | + path = Path(output_trec_file) |
| 149 | + path.parent.mkdir(parents=True, exist_ok=True) |
| 150 | + writer.write_in_trec_eval_format(output_trec_file) |
| 151 | + keys_and_defaults = [("populate_invocations_history", False)] |
| 152 | + [populate_invocations_history] = extract_kwargs(keys_and_defaults, **kwargs) |
| 153 | + if populate_invocations_history: |
| 154 | + if invocations_history_file: |
| 155 | + path = Path(invocations_history_file) |
| 156 | + path.parent.mkdir(parents=True, exist_ok=True) |
| 157 | + writer.write_inference_invocations_history(invocations_history_file) |
| 158 | + else: |
| 159 | + raise ValueError( |
| 160 | + "--invocations_history_file must be a valid jsonl file to store invocations history." |
| 161 | + ) |
| 162 | + keys_and_defaults = [("qrels_file", "")] |
| 163 | + [qrels_file] = extract_kwargs(keys_and_defaults, **kwargs) |
| 164 | + if qrels_file: |
| 165 | + from rank_llm.evaluation.trec_eval import EvalFunction |
| 166 | + |
| 167 | + print("Evaluating:") |
| 168 | + EvalFunction.from_results( |
| 169 | + rerank_results, qrels_file, ["-c", "-m", "ndcg_cut.1"] |
| 170 | + ) |
| 171 | + EvalFunction.from_results( |
| 172 | + rerank_results, qrels_file, ["-c", "-m", "ndcg_cut.5"] |
| 173 | + ) |
| 174 | + EvalFunction.from_results( |
| 175 | + rerank_results, qrels_file, ["-c", "-m", "ndcg_cut.10"] |
| 176 | + ) |
128 | 177 |
|
129 | 178 | if interactive: |
130 | 179 | return (rerank_results, reranker.get_model_coordinator()) |
@@ -211,5 +260,11 @@ def retrieve( |
211 | 260 | requests = Retriever.from_custom_index( |
212 | 261 | index_path=index_path, topics_path=topics_path, index_type=index_type |
213 | 262 | ) |
| 263 | + elif retrieval_mode == RetrievalMode.CACHED_FILE: |
| 264 | + keys_and_defaults = [ |
| 265 | + ("requests_file", ""), |
| 266 | + ] |
| 267 | + [requests_file] = extract_kwargs(keys_and_defaults, **kwargs) |
| 268 | + requests = read_requests_from_file(requests_file) |
214 | 269 |
|
215 | 270 | return requests |
0 commit comments