44from rerankers .results import RankedResults , Result
55from rerankers .utils import prep_docs
66
7- from rank_llm .data import Candidate , Query , Request
8- from rank_llm .rerank .listwise .vicuna_reranker import VicunaReranker
9- from rank_llm .rerank .listwise .zephyr_reranker import ZephyrReranker
10- from rank_llm .rerank .listwise .rank_gpt import SafeOpenai
7+ # from rerankers import Reranker
8+
119from rank_llm .rerank .reranker import Reranker as rankllm_Reranker
10+ from rank_llm .rerank import PromptMode , get_azure_openai_args , get_genai_api_key , get_openai_api_key
11+ from rank_llm .data import Candidate , Query , Request
1212
1313
1414class RankLLMRanker (BaseRanker ):
1515 def __init__ (
1616 self ,
17- model : str ,
17+ model : str = "rank_zephyr" ,
1818 api_key : Optional [str ] = None ,
1919 lang : str = "en" ,
2020 verbose : int = 1 ,
21+ # RankLLM specific arguments
22+ window_size : int = 20 ,
23+ context_size : int = 4096 ,
24+ prompt_mode : PromptMode = PromptMode .RANK_GPT ,
25+ num_few_shot_examples : int = 0 ,
26+ few_shot_file : Optional [str ] = None ,
27+ num_gpus : int = 1 ,
28+ variable_passages : bool = False ,
29+ use_logits : bool = False ,
30+ use_alpha : bool = False ,
31+ stride : int = 10 ,
32+ use_azure_openai : bool = False ,
2133 ) -> "RankLLMRanker" :
2234 self .api_key = api_key
2335 self .model = model
2436 self .verbose = verbose
2537 self .lang = lang
26-
27- if "zephyr" in self .model .lower ():
28- self .rankllm_ranker = ZephyrReranker ()
29- elif "vicuna" in self .model .lower ():
30- self .rankllm_ranker = VicunaReranker ()
31- elif "gpt" in self .model .lower ():
32- self .rankllm_ranker = rankllm_Reranker (
33- SafeOpenai (model = self .model , context_size = 4096 , keys = self .api_key )
34- )
38+
39+ # RankLLM-specific parameters
40+ self .window_size = window_size
41+ self .context_size = context_size
42+ self .prompt_mode = prompt_mode
43+ self .num_few_shot_examples = num_few_shot_examples
44+ self .few_shot_file = few_shot_file
45+ self .num_gpus = num_gpus
46+ self .variable_passages = variable_passages
47+ self .use_logits = use_logits
48+ self .use_alpha = use_alpha
49+ self .stride = stride
50+ self .use_azure_openai = use_azure_openai
51+
52+ kwargs = {
53+ "model_path" : self .model ,
54+ "default_model_coordinator" : None ,
55+ "context_size" : self .context_size ,
56+ "prompt_mode" : self .prompt_mode ,
57+ "num_gpus" : self .num_gpus ,
58+ "use_logits" : self .use_logits ,
59+ "use_alpha" : self .use_alpha ,
60+ "num_few_shot_examples" : self .num_few_shot_examples ,
61+ "few_shot_file" : self .few_shot_file ,
62+ "variable_passages" : self .variable_passages ,
63+ "interactive" : False ,
64+ "window_size" : self .window_size ,
65+ "stride" : self .stride ,
66+ "use_azure_openai" : self .use_azure_openai ,
67+ }
68+ model_coordinator = rankllm_Reranker .create_model_coordinator (** kwargs )
69+ self .reranker = rankllm_Reranker (model_coordinator )
3570
3671 def rank (
3772 self ,
@@ -52,7 +87,7 @@ def rank(
5287 ],
5388 )
5489
55- rankllm_results = self .rankllm_ranker .rerank (
90+ rankllm_results = self .reranker .rerank (
5691 request ,
5792 rank_end = len (docs ) if rank_end == 0 else rank_end ,
5893 window_size = min (20 , len (docs )),
0 commit comments