Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 25 additions & 47 deletions balrog/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,10 @@
from collections import namedtuple
from io import BytesIO

import google.generativeai as genai
from google import genai
from google.genai import types

from anthropic import Anthropic
from google.generativeai import caching
from openai import OpenAI

LLMResponse = namedtuple(
Expand Down Expand Up @@ -230,7 +231,8 @@ def __init__(self, client_config):
def _initialize_client(self):
"""Initialize the Generative AI client if not already initialized."""
if not self._initialized:
self.model = genai.GenerativeModel(self.model_id)
self.client = genai.Client()
self.model = None

# Create kwargs dictionary for GenerationConfig
client_kwargs = {
Expand All @@ -241,71 +243,46 @@ def _initialize_client(self):
temperature = self.client_kwargs.get("temperature")
if temperature is not None:
client_kwargs["temperature"] = temperature

thinking_budget = self.client_kwargs.get("thinking_budget", -1)

self.generation_config = genai.types.GenerationConfig(**client_kwargs)
self.generation_config = genai.types.GenerateContentConfig(
**client_kwargs,
thinking_config=types.ThinkingConfig(thinking_budget=thinking_budget)
)
self._initialized = True

def convert_messages(self, messages):
"""Convert messages to the format expected by the Generative AI API.
"""Convert messages to the format expected by the new Google GenAI SDK.

Args:
messages (list): A list of message objects.

Returns:
list: A list of messages formatted for the Generative AI API.
list[types.Content]: A list of Content objects formatted for the API.
"""
# Convert standard Message objects to Gemini's format
converted_messages = []

for msg in messages:
parts = []

role = msg.role
if role == "assistant":
role = "model"
elif role == "system":
role = "user"

if msg.content:
parts.append(msg.content)
parts.append(types.Part(text=msg.content))

if msg.attachment is not None:
parts.append(msg.attachment)
parts.append(types.Part(image=msg.attachment))

converted_messages.append(
{
"role": role,
"parts": parts,
}
types.Content(role=role, parts=parts)
)
return converted_messages

def get_completion(self, converted_messages, max_retries=5, delay=5):
    """Get the completion from the model with retries upon failure.

    Args:
        converted_messages (list): Messages formatted for the Generative AI API.
        max_retries (int, optional): Maximum number of attempts. Defaults to 5.
        delay (int, optional): Base delay between retries in seconds; actual
            wait grows exponentially as ``delay * 2**(attempt - 1)``.
            Defaults to 5.

    Returns:
        Response object from the API.

    Raises:
        Exception: If the API call fails after the maximum number of retries.
            The last underlying SDK error is attached as ``__cause__``.
    """
    last_error = None
    for attempt in range(1, max_retries + 1):
        try:
            return self.model.generate_content(
                converted_messages,
                generation_config=self.generation_config,
            )
        except Exception as e:  # broad on purpose: the SDK raises many transient error types
            last_error = e
            logger.error(f"Retryable error during generate_content: {e}. Retry {attempt}/{max_retries}")
            # Only back off if another attempt remains; the previous version
            # slept a full exponential-backoff interval after the FINAL
            # failure too, delaying the terminal exception for no benefit.
            if attempt < max_retries:
                time.sleep(delay * (2 ** (attempt - 1)))  # exponential backoff

    # Chain the last error so callers can see the root cause of the failure.
    raise Exception(f"Failed to get a valid completion after {max_retries} retries.") from last_error

def extract_completion(self, response):
"""Extract the completion text from the API response.

Expand Down Expand Up @@ -354,9 +331,10 @@ def generate(self, messages):
converted_messages = self.convert_messages(messages)

def api_call():
response = self.model.generate_content(
converted_messages,
generation_config=self.generation_config,
response = self.client.models.generate_content(
model=self.model_id,
contents=converted_messages,
config=self.generation_config,
)
# Attempt to extract completion immediately after API call
completion = self.extract_completion(response)
Expand Down
1 change: 1 addition & 0 deletions balrog/config/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ client:
generate_kwargs:
temperature: 1.0 # Sampling temperature. If null the API default temperature is used instead
max_tokens: 4096 # Max tokens to generate in the response
thinking_budget: null # Thinking budget in tokens; set a number to cap the model's thinking, or null to use the API default.
timeout: 60 # Timeout for API requests in seconds
max_retries: 5 # Max number of retries for failed API calls
delay: 2 # Exponential backoff factor between retries in seconds
Expand Down
3 changes: 1 addition & 2 deletions balrog/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from collections import defaultdict
from pathlib import Path

import google.generativeai as genai
import openai


Expand Down Expand Up @@ -206,7 +205,7 @@ def setup_environment(
"""
secrets = load_secrets(os.path.join(original_cwd, "SECRETS"))
if secrets[gemini_tag]:
genai.configure(api_key=secrets[gemini_tag])
os.environ["GEMINI_API_KEY"] = secrets[gemini_tag]
if secrets[anthropic_tag]:
os.environ["ANTHROPIC_API_KEY"] = secrets[anthropic_tag]
if secrets[openai_tag]:
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
install_requires=[
"openai",
"anthropic",
"google-generativeai",
"google-genai",
"hydra-core",
"opencv-python-headless",
"wandb",
Expand Down