|
4 | 4 | from dotenv import load_dotenv |
5 | 5 | from dataclasses import dataclass |
6 | 6 | from typing import Any, List, Optional, Dict |
7 | | -from litellm import completion, supports_response_schema, completion_cost |
| 7 | +from litellm import completion, supports_response_schema, completion_cost, rerank |
8 | 8 |
|
9 | 9 | logging.getLogger("LiteLLM").setLevel(logging.WARNING) |
10 | 10 |
|
@@ -91,82 +91,80 @@ def get(name, default=None): |
91 | 91 |
|
92 | 92 | @dataclass |
93 | 93 | class Reranker: |
94 | | - def __init__(self, tokenizer: Any, model: Any, device: Any): |
95 | | - self.tokenizer = tokenizer |
96 | | - self.model = model |
97 | | - self.device = device |
98 | | - |
99 | | - def score( |
100 | | - self, query: str, passages: List[str], max_length: int = 512 |
101 | | - ) -> Optional[List[float]]: |
102 | | - """Score (query, passage) pairs. Higher = more relevant.""" |
103 | | - if not passages: |
| 94 | + """Reranker backed by LiteLLM's hosted_vllm provider.""" |
| 95 | + |
| 96 | + model_name: str |
| 97 | + api_base: Optional[str] = None |
| 98 | + api_key: Optional[str] = None |
| 99 | + |
| 100 | + @staticmethod |
| 101 | + def read_field(obj: Any, name: str, default: Any = None) -> Any: |
| 102 | + if isinstance(obj, dict): |
| 103 | + return obj.get(name, default) |
| 104 | + return getattr(obj, name, default) |
| 105 | + |
| 106 | + @classmethod |
| 107 | + def extract_scores(cls, response: Any, total_docs: int) -> Optional[List[float]]: |
| 108 | + scores = [0.0] * total_docs |
| 109 | + results = cls.read_field(response, "results", []) or [] |
| 110 | + valid_count = 0 |
| 111 | + |
| 112 | + for item in results: |
| 113 | + index = cls.read_field(item, "index") |
| 114 | + relevance_score = cls.read_field(item, "relevance_score") |
| 115 | + if relevance_score is None: |
| 116 | + relevance_score = cls.read_field(item, "score") |
| 117 | + |
| 118 | + try: |
| 119 | + index = int(index) |
| 120 | + relevance_score = float(relevance_score) |
| 121 | + except (TypeError, ValueError): |
| 122 | + continue |
| 123 | + |
| 124 | + if 0 <= index < total_docs: |
| 125 | + scores[index] = relevance_score |
| 126 | + valid_count += 1 |
| 127 | + |
| 128 | + if valid_count == 0: |
104 | 129 | return None |
105 | 130 |
|
106 | | - pairs = [(query, passage) for passage in passages] |
| 131 | + return scores |
107 | 132 |
|
108 | | - try: |
109 | | - return self.compute_scores(pairs, max_length) |
110 | | - except Exception: |
| 133 | + def score(self, query: str, passages: List[str]) -> Optional[List[float]]: |
| 134 | + """Score passages through LiteLLM's rerank endpoint.""" |
| 135 | + if not passages: |
111 | 136 | return None |
112 | 137 |
|
113 | | - def compute_scores(self, pairs: List[tuple], max_length: int) -> List[float]: |
114 | | - """Compute relevance scores for query-passage pairs.""" |
115 | | - import torch |
116 | | - |
117 | | - with torch.no_grad(): |
118 | | - inputs = self.tokenizer( |
119 | | - pairs, |
120 | | - padding=True, |
121 | | - truncation=True, |
122 | | - return_tensors="pt", |
123 | | - max_length=max_length, |
| 138 | + try: |
| 139 | + response = rerank( |
| 140 | + model=self.model_name, |
| 141 | + query=query, |
| 142 | + documents=passages, |
| 143 | + top_n=len(passages), |
| 144 | + return_documents=False, |
| 145 | + api_base=self.api_base, |
| 146 | + api_key=self.api_key, |
124 | 147 | ) |
125 | | - inputs = {k: v.to(self.device) for k, v in inputs.items()} |
126 | | - outputs = self.model(**inputs, return_dict=True) |
127 | | - logits = outputs.logits.view(-1).float() |
128 | | - return torch.sigmoid(logits).tolist() |
129 | | - |
130 | | - |
131 | | -def get_device(device_index: int): |
132 | | - import torch |
133 | | - |
134 | | - try: |
135 | | - device_index = int(device_index) |
136 | | - except Exception: |
137 | | - device_index = 0 |
138 | | - |
139 | | - if torch.cuda.is_available() and device_index >= 0: |
140 | | - return torch.device(f"cuda:{device_index}") |
141 | | - return torch.device("cpu") |
142 | | - |
143 | | - |
144 | | -def load_reranker_model(model_name: str, device: Any): |
145 | | - """Load tokenizer and model for reranking.""" |
146 | | - from transformers import AutoModelForSequenceClassification, AutoTokenizer |
| 148 | + except Exception: |
| 149 | + return None |
147 | 150 |
|
148 | | - tokenizer = AutoTokenizer.from_pretrained(model_name) |
149 | | - model = AutoModelForSequenceClassification.from_pretrained(model_name) |
150 | | - model.to(device) |
151 | | - model.eval() |
152 | | - return tokenizer, model |
| 151 | + return self.extract_scores(response, len(passages)) |
153 | 152 |
|
154 | 153 |
|
155 | 154 | def get_reranker(configs: Any) -> Optional[Reranker]: |
156 | | - """Get a Reranker instance from config, or None if unavailable.""" |
| 155 | + """Get a LOCAL reranker backed by LiteLLM's hosted_vllm provider.""" |
157 | 156 |
|
158 | 157 | def get(name, default=None): |
159 | 158 | return get_config_value(configs, name, default) |
160 | 159 |
|
161 | | - model_type = get("model_type") |
162 | | - model_name = get("model_name") |
| 160 | + model_type = get("reranker_model_type") |
| 161 | + model_name = get("reranker_model_name") |
163 | 162 |
|
164 | | - if model_type not in ("huggingface", "local") or not model_name: |
| 163 | + if model_type != "LOCAL" or not model_name: |
165 | 164 | return None |
166 | 165 |
|
167 | | - try: |
168 | | - device = get_device(get("device", 0)) |
169 | | - tokenizer, model = load_reranker_model(model_name, device) |
170 | | - return Reranker(tokenizer=tokenizer, model=model, device=device) |
171 | | - except Exception: |
172 | | - return None |
| 166 | + return Reranker( |
| 167 | + model_name=model_name, |
| 168 | + api_base=os.environ.get("LOCAL_BASE_URL"), |
| 169 | + api_key=os.environ.get("LOCAL_API_KEY"), |
| 170 | + ) |
0 commit comments