Skip to content

Commit f4d4ddd

Browse files
fix: execute with retries for gemini (#48)
* fix: execute with retries for gemini * fix: gemini retry
1 parent e006561 commit f4d4ddd

File tree

1 file changed

+57
-29
lines changed

1 file changed

+57
-29
lines changed

balrog/client.py

Lines changed: 57 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@
22
import datetime
33
import logging
44
import time
5+
import json
6+
import csv
7+
import os
58
from collections import namedtuple
69
from io import BytesIO
710

@@ -351,38 +354,63 @@ def generate(self, messages):
351354
converted_messages = self.convert_messages(messages)
352355

353356
def api_call():
354-
return self.model.generate_content(
357+
response = self.model.generate_content(
355358
converted_messages,
356359
generation_config=self.generation_config,
357360
)
358-
359-
response = self.execute_with_retries(api_call)
360-
361-
def extract_completion_call():
362-
return self.extract_completion(response)
363-
364-
completion = self.execute_with_retries(extract_completion_call)
365-
366-
return LLMResponse(
367-
model_id=self.model_id,
368-
completion=completion,
369-
stop_reason=(
370-
getattr(response.candidates[0], "finish_reason", "unknown")
371-
if response and getattr(response, "candidates", [])
372-
else "unknown"
373-
),
374-
input_tokens=(
375-
getattr(response.usage_metadata, "prompt_token_count", 0)
376-
if response and getattr(response, "usage_metadata", None)
377-
else 0
378-
),
379-
output_tokens=(
380-
getattr(response.usage_metadata, "candidates_token_count", 0)
381-
if response and getattr(response, "usage_metadata", None)
382-
else 0
383-
),
384-
reasoning=None,
385-
)
361+
# Attempt to extract completion immediately after API call
362+
completion = self.extract_completion(response)
363+
# Return both response and completion if successful
364+
return response, completion
365+
366+
try:
367+
# Execute the API call and extraction together with retries
368+
response, completion = self.execute_with_retries(api_call)
369+
370+
# Check if the successful response contains an empty completion
371+
if not completion or completion.strip() == "":
372+
logger.warning(f"Gemini returned an empty completion for model {self.model_id}. Returning default empty response.")
373+
return LLMResponse(
374+
model_id=self.model_id,
375+
completion="",
376+
stop_reason="empty_response",
377+
input_tokens=getattr(response.usage_metadata, "prompt_token_count", 0) if response and getattr(response, "usage_metadata", None) else 0,
378+
output_tokens=getattr(response.usage_metadata, "candidates_token_count", 0) if response and getattr(response, "usage_metadata", None) else 0,
379+
reasoning=None,
380+
)
381+
else:
382+
# If completion is not empty, return the normal response
383+
return LLMResponse(
384+
model_id=self.model_id,
385+
completion=completion,
386+
stop_reason=(
387+
getattr(response.candidates[0], "finish_reason", "unknown")
388+
if response and getattr(response, "candidates", [])
389+
else "unknown"
390+
),
391+
input_tokens=(
392+
getattr(response.usage_metadata, "prompt_token_count", 0)
393+
if response and getattr(response, "usage_metadata", None)
394+
else 0
395+
),
396+
output_tokens=(
397+
getattr(response.usage_metadata, "candidates_token_count", 0)
398+
if response and getattr(response, "usage_metadata", None)
399+
else 0
400+
),
401+
reasoning=None,
402+
)
403+
except Exception as e:
404+
logger.error(f"API call failed after {self.max_retries} retries: {e}. Returning empty completion.")
405+
# Return a default response indicating failure
406+
return LLMResponse(
407+
model_id=self.model_id,
408+
completion="",
409+
stop_reason="error_max_retries",
410+
input_tokens=0, # Assuming 0 tokens consumed if call failed
411+
output_tokens=0,
412+
reasoning=None,
413+
)
386414

387415

388416
class ClaudeWrapper(LLMClientWrapper):

0 commit comments

Comments
 (0)