Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 57 additions & 29 deletions balrog/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
import datetime
import logging
import time
import json
import csv
import os
from collections import namedtuple
from io import BytesIO

Expand Down Expand Up @@ -351,38 +354,63 @@ def generate(self, messages):
converted_messages = self.convert_messages(messages)

def api_call():
return self.model.generate_content(
response = self.model.generate_content(
converted_messages,
generation_config=self.generation_config,
)

response = self.execute_with_retries(api_call)

def extract_completion_call():
return self.extract_completion(response)

completion = self.execute_with_retries(extract_completion_call)

return LLMResponse(
model_id=self.model_id,
completion=completion,
stop_reason=(
getattr(response.candidates[0], "finish_reason", "unknown")
if response and getattr(response, "candidates", [])
else "unknown"
),
input_tokens=(
getattr(response.usage_metadata, "prompt_token_count", 0)
if response and getattr(response, "usage_metadata", None)
else 0
),
output_tokens=(
getattr(response.usage_metadata, "candidates_token_count", 0)
if response and getattr(response, "usage_metadata", None)
else 0
),
reasoning=None,
)
# Attempt to extract completion immediately after API call
completion = self.extract_completion(response)
# Return both response and completion if successful
return response, completion

try:
# Execute the API call and extraction together with retries
response, completion = self.execute_with_retries(api_call)

# Check if the successful response contains an empty completion
if not completion or completion.strip() == "":
logger.warning(f"Gemini returned an empty completion for model {self.model_id}. Returning default empty response.")
return LLMResponse(
model_id=self.model_id,
completion="",
stop_reason="empty_response",
input_tokens=getattr(response.usage_metadata, "prompt_token_count", 0) if response and getattr(response, "usage_metadata", None) else 0,
output_tokens=getattr(response.usage_metadata, "candidates_token_count", 0) if response and getattr(response, "usage_metadata", None) else 0,
reasoning=None,
)
else:
# If completion is not empty, return the normal response
return LLMResponse(
model_id=self.model_id,
completion=completion,
stop_reason=(
getattr(response.candidates[0], "finish_reason", "unknown")
if response and getattr(response, "candidates", [])
else "unknown"
),
input_tokens=(
getattr(response.usage_metadata, "prompt_token_count", 0)
if response and getattr(response, "usage_metadata", None)
else 0
),
output_tokens=(
getattr(response.usage_metadata, "candidates_token_count", 0)
if response and getattr(response, "usage_metadata", None)
else 0
),
reasoning=None,
)
except Exception as e:
logger.error(f"API call failed after {self.max_retries} retries: {e}. Returning empty completion.")
# Return a default response indicating failure
return LLMResponse(
model_id=self.model_id,
completion="",
stop_reason="error_max_retries",
input_tokens=0, # Assuming 0 tokens consumed if call failed
output_tokens=0,
reasoning=None,
)


class ClaudeWrapper(LLMClientWrapper):
Expand Down