99from typing import Any , Dict , List , Tuple , Union
1010
1111from ftfy import fix_text
12+ from gguf import Optional
1213from tqdm import tqdm
1314
1415from rank_llm .data import RankingExecInfo , Request , Result
@@ -58,7 +59,7 @@ class ListwiseRankLLM(RankLLM, ABC):
5859
5960 def __init__ (
6061 self ,
61- reorder_policy : ReorderPolicy ,
62+ reorder_policy : Optional [ ReorderPolicy ] ,
6263 model : str ,
6364 context_size : int ,
6465 window_size : int ,
@@ -74,7 +75,6 @@ def __init__(
7475 )
7576 self ._window_size = window_size
7677 self ._use_alpha = use_alpha
77-
7878
7979 def rerank_batch (
8080 self ,
@@ -88,16 +88,16 @@ def rerank_batch(
8888 ) -> List [Result ]:
8989 populate_exec_summary : bool = kwargs .get ("populate_exec_summary" , False )
9090
91- batch_size = kwargs .get ("batch_size" , 1 )
91+ batch_size = kwargs .get ("batch_size" ) or len ( requests )
9292
9393 if not batched :
9494 batch_size = 1
9595
9696 reorder_policy = self .reorder_policy
9797 model_functions , consumption = self ._get_model_function (batched , ** kwargs )
9898
99- # reranking using vllm
100- if len (set ([len (req .candidates ) for req in requests ])) != 1 :
99+ # reranking using batched mode
100+ if batched and len (set ([len (req .candidates ) for req in requests ])) != 1 :
101101 raise ValueError ("Batched requests must have the same number of candidates" )
102102
103103 result : list [Result ] = []
@@ -462,12 +462,11 @@ def _clean_response(self, response: str) -> str:
462462 else :
463463 for c in response :
464464 if not c .isdigit ():
465- if len (new_response ) == 0 or new_response [- 1 ] != " " :
466465 new_response += " "
467466 else :
468467 new_response += c
469468 new_response = new_response .strip ()
470-
469+ new_response = re . sub ( r"\s+" , " " , new_response )
471470 return new_response
472471
473472 def _remove_duplicate (self , response : List [int ]) -> List [int ]:
0 commit comments