Skip to content

Commit bd9ff2a

Browse files
authored
Add RankLLM legacy CLI wrappers (#366)
* Add packaged rank-llm entrypoint baseline * Fix ruff formatting in CLI packaging test * Add RankLLM legacy CLI wrappers * Restore legacy wrapper behavior
1 parent feab819 commit bd9ff2a

File tree

9 files changed

+385
-331
lines changed

9 files changed

+385
-331
lines changed

src/rank_llm/cli/legacy.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
from __future__ import annotations
2+
3+
import argparse
4+
from collections.abc import Sequence
5+
from enum import Enum
6+
7+
8+
def translate_legacy_argv(
9+
argv: Sequence[str],
10+
*,
11+
mapping: dict[str, str] | None = None,
12+
drop_flags: set[str] | None = None,
13+
) -> list[str]:
14+
translated: list[str] = []
15+
mapping = mapping or {}
16+
drop_flags = drop_flags or set()
17+
skip_next = False
18+
19+
for token in argv:
20+
if skip_next:
21+
skip_next = False
22+
continue
23+
if not token.startswith("--"):
24+
translated.append(token)
25+
continue
26+
27+
if "=" in token:
28+
flag, value = token.split("=", 1)
29+
normalized = flag[2:]
30+
if normalized in drop_flags:
31+
continue
32+
mapped_flag = mapping.get(normalized, normalized.replace("_", "-"))
33+
translated.append(f"--{mapped_flag}={value}")
34+
continue
35+
36+
normalized = token[2:]
37+
if normalized in drop_flags:
38+
skip_next = True
39+
continue
40+
mapped_flag = mapping.get(normalized, normalized.replace("_", "-"))
41+
translated.append(f"--{mapped_flag}")
42+
43+
return translated
44+
45+
46+
def namespace_to_legacy_argv(
47+
args: argparse.Namespace,
48+
*,
49+
mapping: dict[str, str] | None = None,
50+
drop_flags: set[str] | None = None,
51+
) -> list[str]:
52+
argv: list[str] = []
53+
mapping = mapping or {}
54+
drop_flags = drop_flags or set()
55+
56+
for key, value in vars(args).items():
57+
if key in drop_flags or value is None or value is False:
58+
continue
59+
flag = mapping.get(key, key.replace("_", "-"))
60+
if value is True:
61+
argv.append(f"--{flag}")
62+
continue
63+
if isinstance(value, list):
64+
if value:
65+
argv.append(f"--{flag}")
66+
argv.extend(str(item) for item in value)
67+
continue
68+
if isinstance(value, Enum):
69+
value = value.value
70+
argv.extend([f"--{flag}", str(value)])
71+
return argv

src/rank_llm/scripts/generate_retrieve_results_json_cache.py

Lines changed: 13 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -46,11 +46,11 @@
4646
--topk 20
4747
"""
4848

49-
import argparse
5049
import json
5150
import os
5251
import sys
5352
from collections import defaultdict
53+
from collections.abc import Sequence
5454

5555
from rank_llm._optional import missing_extra_error
5656

@@ -62,6 +62,8 @@
6262
get_topics = None
6363
from tqdm import tqdm
6464

65+
from rank_llm.cli.legacy import namespace_to_legacy_argv, translate_legacy_argv
66+
from rank_llm.cli.main import main as cli_main
6567
from rank_llm.retrieve import TOPICS
6668

6769
sys.path.append(os.getcwd())
@@ -183,24 +185,16 @@ def write_output_file(output_file_path, data):
183185
json.dump(data, file)
184186

185187

186-
def main():
187-
parser = argparse.ArgumentParser()
188-
parser.add_argument("--trec_file", required=True)
189-
parser.add_argument("--collection_file", required=True)
190-
parser.add_argument("--query_file", required=True)
191-
parser.add_argument("--output_file", required=True)
192-
parser.add_argument("--output_trec_file", type=str, default=None)
193-
parser.add_argument("--topk", type=int, default=20)
194-
args = parser.parse_args()
195-
196-
results = generate_retrieve_results(
197-
args.trec_file,
198-
args.collection_file,
199-
args.query_file,
200-
args.topk,
201-
args.output_trec_file,
202-
)
203-
write_output_file(args.output_file, results)
188+
def main(args: Sequence[str] | object | None = None) -> int:
189+
if hasattr(args, "__dict__"):
190+
argv = namespace_to_legacy_argv(args)
191+
elif args is None:
192+
argv = sys.argv[1:]
193+
else:
194+
argv = list(args)
195+
196+
translated = translate_legacy_argv(argv)
197+
return cli_main(["retrieve-cache", *translated])
204198

205199

206200
if __name__ == "__main__":

0 commit comments

Comments
 (0)