Skip to content

Commit 2d03c8b

Browse files
committed
Fix bug
1 parent acf8012 commit 2d03c8b

File tree

4 files changed

+153
-109
lines changed

4 files changed

+153
-109
lines changed

src/rank_llm/rerank/listwise/listwise_rankllm.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from typing import Any, Dict, List, Tuple, Union
1010

1111
from ftfy import fix_text
12+
from gguf import Optional
1213
from tqdm import tqdm
1314

1415
from rank_llm.data import RankingExecInfo, Request, Result
@@ -58,7 +59,7 @@ class ListwiseRankLLM(RankLLM, ABC):
5859

5960
def __init__(
6061
self,
61-
reorder_policy: ReorderPolicy,
62+
reorder_policy: Optional[ReorderPolicy],
6263
model: str,
6364
context_size: int,
6465
window_size: int,
@@ -74,7 +75,6 @@ def __init__(
7475
)
7576
self._window_size = window_size
7677
self._use_alpha = use_alpha
77-
7878

7979
def rerank_batch(
8080
self,
@@ -88,16 +88,16 @@ def rerank_batch(
8888
) -> List[Result]:
8989
populate_exec_summary: bool = kwargs.get("populate_exec_summary", False)
9090

91-
batch_size = kwargs.get("batch_size", 1)
91+
batch_size = kwargs.get("batch_size") or len(requests)
9292

9393
if not batched:
9494
batch_size = 1
9595

9696
reorder_policy = self.reorder_policy
9797
model_functions, consumption = self._get_model_function(batched, **kwargs)
9898

99-
# reranking using vllm
100-
if len(set([len(req.candidates) for req in requests])) != 1:
99+
# reranking using batched mode
100+
if batched and len(set([len(req.candidates) for req in requests])) != 1:
101101
raise ValueError("Batched requests must have the same number of candidates")
102102

103103
result: list[Result] = []
@@ -462,12 +462,11 @@ def _clean_response(self, response: str) -> str:
462462
else:
463463
for c in response:
464464
if not c.isdigit():
465-
if len(new_response) == 0 or new_response[-1] != " ":
466465
new_response += " "
467466
else:
468467
new_response += c
469468
new_response = new_response.strip()
470-
469+
new_response = re.sub(r"\s+", " ", new_response)
471470
return new_response
472471

473472
def _remove_duplicate(self, response: List[int]) -> List[int]:

src/rank_llm/rerank/listwise/rank_gpt.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ def __init__(
6262
window_size=window_size,
6363
prompt_mode=prompt_mode,
6464
num_few_shot_examples=num_few_shot_examples,
65+
use_alpha=False, # Alphabet is not supported in OpenAI for now
6566
)
6667
if isinstance(keys, str):
6768
keys = [keys]
@@ -103,6 +104,7 @@ def rerank_batch(
103104
rank_end: int = 100,
104105
shuffle_candidates: bool = False,
105106
logging: bool = False,
107+
batched: bool = False,
106108
**kwargs: Any,
107109
) -> List[Result]:
108110
return super().rerank_batch(
@@ -130,7 +132,7 @@ def _call_completion(
130132
*args, **kwargs, timeout=30
131133
)
132134
elif completion_mode == self.CompletionMode.TEXT:
133-
completion = openai.Completion.create(*args, **kwargs)
135+
completion = openai.Completion.create(*args, **kwargs)
134136
else:
135137
raise ValueError(
136138
"Unsupported completion mode: %s" % completion_mode

0 commit comments

Comments (0)