Skip to content

Commit 02732ae

Browse files
feat: google-genai api
1 parent 393505a commit 02732ae

File tree

3 files changed

+27
-47
lines changed

3 files changed

+27
-47
lines changed

balrog/client.py

Lines changed: 25 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,9 @@
88
from collections import namedtuple
99
from io import BytesIO
1010

11-
import google.generativeai as genai
11+
from google import genai
12+
from google.genai import types
13+
1214
from anthropic import Anthropic
1315
from google.generativeai import caching
1416
from openai import OpenAI
@@ -230,7 +232,8 @@ def __init__(self, client_config):
230232
def _initialize_client(self):
231233
"""Initialize the Generative AI client if not already initialized."""
232234
if not self._initialized:
233-
self.model = genai.GenerativeModel(self.model_id)
235+
self.client = genai.Client()
236+
self.model = None
234237

235238
# Create kwargs dictionary for GenerationConfig
236239
client_kwargs = {
@@ -241,71 +244,46 @@ def _initialize_client(self):
241244
temperature = self.client_kwargs.get("temperature")
242245
if temperature is not None:
243246
client_kwargs["temperature"] = temperature
247+
248+
thinking_budget = self.client_kwargs.get("thinking_budget", -1)
244249

245-
self.generation_config = genai.types.GenerationConfig(**client_kwargs)
250+
self.generation_config = genai.types.GenerateContentConfig(
251+
**client_kwargs,
252+
thinking_config=types.ThinkingConfig(thinking_budget=thinking_budget)
253+
)
246254
self._initialized = True
247255

248256
def convert_messages(self, messages):
249-
"""Convert messages to the format expected by the Generative AI API.
257+
"""Convert messages to the format expected by the new Google GenAI SDK.
250258
251259
Args:
252260
messages (list): A list of message objects.
253261
254262
Returns:
255-
list: A list of messages formatted for the Generative AI API.
263+
list[types.Content]: A list of Content objects formatted for the API.
256264
"""
257-
# Convert standard Message objects to Gemini's format
258265
converted_messages = []
266+
259267
for msg in messages:
260268
parts = []
269+
261270
role = msg.role
262271
if role == "assistant":
263272
role = "model"
264273
elif role == "system":
265274
role = "user"
275+
266276
if msg.content:
267-
parts.append(msg.content)
277+
parts.append(types.Part(text=msg.content))
278+
268279
if msg.attachment is not None:
269-
parts.append(msg.attachment)
280+
parts.append(types.Part(image=msg.attachment))
281+
270282
converted_messages.append(
271-
{
272-
"role": role,
273-
"parts": parts,
274-
}
283+
types.Content(role=role, parts=parts)
275284
)
276285
return converted_messages
277286

278-
def get_completion(self, converted_messages, max_retries=5, delay=5):
279-
"""Get the completion from the model with retries upon failure.
280-
281-
Args:
282-
converted_messages (list): Messages formatted for the Generative AI API.
283-
max_retries (int, optional): Maximum number of retries. Defaults to 5.
284-
delay (int, optional): Delay between retries in seconds. Defaults to 5.
285-
286-
Returns:
287-
Response object from the API.
288-
289-
Raises:
290-
Exception: If the API call fails after the maximum number of retries.
291-
"""
292-
retries = 0
293-
while retries < max_retries:
294-
try:
295-
response = self.model.generate_content(
296-
converted_messages,
297-
generation_config=self.generation_config,
298-
)
299-
return response
300-
except Exception as e:
301-
retries += 1
302-
logger.error(f"Retryable error during generate_content: {e}. Retry {retries}/{max_retries}")
303-
sleep_time = delay * (2 ** (retries - 1)) # Exponential backoff
304-
time.sleep(sleep_time)
305-
306-
# If maximum retries are reached and still no valid response
307-
raise Exception(f"Failed to get a valid completion after {max_retries} retries.")
308-
309287
def extract_completion(self, response):
310288
"""Extract the completion text from the API response.
311289
@@ -354,9 +332,10 @@ def generate(self, messages):
354332
converted_messages = self.convert_messages(messages)
355333

356334
def api_call():
357-
response = self.model.generate_content(
358-
converted_messages,
359-
generation_config=self.generation_config,
335+
response = self.client.models.generate_content(
336+
model=self.model_id,
337+
contents=converted_messages,
338+
config=self.generation_config,
360339
)
361340
# Attempt to extract completion immediately after API call
362341
completion = self.extract_completion(response)

balrog/config/config.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ client:
3232
generate_kwargs:
3333
temperature: 1.0 # Sampling temperature. If null the API default temperature is used instead
3434
max_tokens: 4096 # Max tokens to generate in the response
35+
thinking_budget: null # Thinking budget in tokens; set a number to cap the model's thinking, or null to use the API default.
3536
timeout: 60 # Timeout for API requests in seconds
3637
max_retries: 5 # Max number of retries for failed API calls
3738
delay: 2 # Exponential backoff factor between retries in seconds

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
install_requires=[
2222
"openai",
2323
"anthropic",
24-
"google-generativeai",
24+
"google-genai",
2525
"hydra-core",
2626
"opencv-python-headless",
2727
"wandb",

0 commit comments

Comments
 (0)