|
17 | 17 | logger = logging.getLogger("AutoRAG") |
18 | 18 |
|
19 | 19 | MAX_TOKEN_DICT = { # model name : token limit |
| 20 | + "gpt-5.1-2025-11-13": 272_000, |
| 21 | + "gpt-5.1": 272_000, |
20 | 22 | "gpt-5": 272_000, |
| 23 | + "gpt-5-pro": 272_000, |
21 | 24 | "gpt-5-2025-08-07": 272_000, |
22 | 25 | "gpt-5-chat-latest": 272_000, |
23 | 26 | "gpt-5-mini-2025-08-07": 272_000, |
@@ -147,12 +150,17 @@ def _pure( |
147 | 150 | self.llm.startswith("o1") |
148 | 151 | or self.llm.startswith("o3") |
149 | 152 | or self.llm.startswith("o4") |
150 | | - or self.llm.startswith("gpt-5") |
151 | 153 | ): |
152 | 154 | tasks = [ |
153 | 155 | self.get_result_reasoning(prompt, **openai_chat_params) |
154 | 156 | for prompt in prompts |
155 | 157 | ] |
| 158 | + elif self.llm.startswith("gpt-5"): |
| 159 | + responses_create_params = pop_params(self.client.responses.create, kwargs) |
| 160 | + tasks = [ |
| 161 | + self.get_result_gpt_5(prompt, **responses_create_params) |
| 162 | + for prompt in prompts |
| 163 | + ] |
156 | 164 | else: |
157 | 165 | tasks = [ |
158 | 166 | self.get_result(prompt, **openai_chat_params) for prompt in prompts |
@@ -269,7 +277,6 @@ async def get_result_reasoning(self, prompt: Union[str, List[dict]], **kwargs): |
269 | 277 | self.llm.startswith("o1") |
270 | 278 | or self.llm.startswith("o3") |
271 | 279 | or self.llm.startswith("o4") |
272 | | - or self.llm.startswith("gpt-5") |
273 | 280 | ): |
274 | 281 | raise ValueError("get_result_reasoning is only for o1,o3,o4,gpt-5 models.") |
275 | 282 | # The default temperature for the o1 model is 1. 1 is only supported. |
@@ -299,6 +306,33 @@ async def get_result_reasoning(self, prompt: Union[str, List[dict]], **kwargs): |
299 | 306 | pseudo_log_probs = [0.5] * len(tokens) |
300 | 307 | return answer, tokens, pseudo_log_probs |
301 | 308 |
|
| 309 | + async def get_result_gpt_5(self, prompt: Union[str, List[dict]], **kwargs): |
| 310 | + if not self.llm.startswith("gpt-5"): |
| 311 | + raise ValueError("get_result_gpt_5 is only for gpt-5 models.") |
| 312 | + api_key = getattr(self.client, "api_key", None) |
| 313 | + if isinstance(api_key, str) and api_key.startswith("mock_"): |
| 314 | + answer = "Why not" |
| 315 | + tokens = self.tokenizer.encode(answer, allowed_special="all") |
| 316 | + pseudo_log_probs = [0.5] * len(tokens) |
| 317 | + return answer, tokens, pseudo_log_probs |
| 318 | + messages = parse_prompt(prompt) |
| 319 | + instruction = "\n\n".join( |
| 320 | + [msg["content"] for msg in messages if msg["role"] == "system"] |
| 321 | + ) |
| 322 | + user_input = "\n\n".join( |
| 323 | + [msg["content"] for msg in messages if msg["role"] == "user"] |
| 324 | + ) |
| 325 | + response = await self.client.responses.create( |
| 326 | + model=self.llm, |
| 327 | + instructions=instruction, |
| 328 | + input=user_input, |
| 329 | + **kwargs, |
| 330 | + ) |
| 331 | + answer: str = response.output_text |
| 332 | + tokens = self.tokenizer.encode(answer, allowed_special="all") |
| 333 | + pseudo_log_probs = [0.5] * len(tokens) |
| 334 | + return answer, tokens, pseudo_log_probs |
| 335 | + |
302 | 336 |
|
303 | 337 | def truncate_by_token( |
304 | 338 | prompt: Union[str, List[Dict]], tokenizer: Encoding, max_token_size: int |
|
0 commit comments