
Commit b20b5e1

feat: google-genai api (#61)
* feat: google-genai api
* chore: add thinking_budget
* remove old import
* fix import
* remove import
1 parent 393505a commit b20b5e1

4 files changed: 28 additions, 50 deletions


balrog/client.py

Lines changed: 25 additions & 47 deletions
@@ -8,9 +8,10 @@
 from collections import namedtuple
 from io import BytesIO
 
-import google.generativeai as genai
+from google import genai
+from google.genai import types
+
 from anthropic import Anthropic
-from google.generativeai import caching
 from openai import OpenAI
 
 LLMResponse = namedtuple(
@@ -230,7 +231,8 @@ def __init__(self, client_config):
     def _initialize_client(self):
         """Initialize the Generative AI client if not already initialized."""
         if not self._initialized:
-            self.model = genai.GenerativeModel(self.model_id)
+            self.client = genai.Client()
+            self.model = None
 
             # Create kwargs dictionary for GenerationConfig
             client_kwargs = {
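
For orientation, a minimal sketch (not taken from this repo) of what the new client object replaces: the google-genai SDK centers on a single genai.Client rather than a per-model GenerativeModel, and the client reads its API key from the environment (GEMINI_API_KEY or GOOGLE_API_KEY) or accepts one explicitly.

# Minimal sketch, assuming only the google-genai package is installed.
from google import genai

client = genai.Client()                 # API key picked up from the environment
# client = genai.Client(api_key="...")  # or passed explicitly
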
@@ -241,71 +243,46 @@ def _initialize_client(self):
             temperature = self.client_kwargs.get("temperature")
             if temperature is not None:
                 client_kwargs["temperature"] = temperature
+
+            thinking_budget = self.client_kwargs.get("thinking_budget", -1)
 
-            self.generation_config = genai.types.GenerationConfig(**client_kwargs)
+            self.generation_config = genai.types.GenerateContentConfig(
+                **client_kwargs,
+                thinking_config=types.ThinkingConfig(thinking_budget=thinking_budget)
+            )
             self._initialized = True
 
     def convert_messages(self, messages):
-        """Convert messages to the format expected by the Generative AI API.
+        """Convert messages to the format expected by the new Google GenAI SDK.
 
         Args:
             messages (list): A list of message objects.
 
         Returns:
-            list: A list of messages formatted for the Generative AI API.
+            list[types.Content]: A list of Content objects formatted for the API.
         """
-        # Convert standard Message objects to Gemini's format
         converted_messages = []
+
         for msg in messages:
             parts = []
+
             role = msg.role
             if role == "assistant":
                 role = "model"
             elif role == "system":
                 role = "user"
+
             if msg.content:
-                parts.append(msg.content)
+                parts.append(types.Part(text=msg.content))
+
             if msg.attachment is not None:
-                parts.append(msg.attachment)
+                parts.append(types.Part(image=msg.attachment))
+
             converted_messages.append(
-                {
-                    "role": role,
-                    "parts": parts,
-                }
+                types.Content(role=role, parts=parts)
             )
         return converted_messages
 
-    def get_completion(self, converted_messages, max_retries=5, delay=5):
-        """Get the completion from the model with retries upon failure.
-
-        Args:
-            converted_messages (list): Messages formatted for the Generative AI API.
-            max_retries (int, optional): Maximum number of retries. Defaults to 5.
-            delay (int, optional): Delay between retries in seconds. Defaults to 5.
-
-        Returns:
-            Response object from the API.
-
-        Raises:
-            Exception: If the API call fails after the maximum number of retries.
-        """
-        retries = 0
-        while retries < max_retries:
-            try:
-                response = self.model.generate_content(
-                    converted_messages,
-                    generation_config=self.generation_config,
-                )
-                return response
-            except Exception as e:
-                retries += 1
-                logger.error(f"Retryable error during generate_content: {e}. Retry {retries}/{max_retries}")
-                sleep_time = delay * (2 ** (retries - 1))  # Exponential backoff
-                time.sleep(sleep_time)
-
-        # If maximum retries are reached and still no valid response
-        raise Exception(f"Failed to get a valid completion after {max_retries} retries.")
-
     def extract_completion(self, response):
         """Extract the completion text from the API response.
 
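
To make the two new request types concrete, here is an illustrative-only sketch of the objects this hunk builds, with placeholder values; in the repo, temperature and thinking_budget come from client_kwargs, and -1 is the fallback used when no thinking_budget is configured.

# Illustrative only: the generation config and one converted message.
from google.genai import types

generation_config = types.GenerateContentConfig(
    temperature=1.0,  # placeholder; the repo reads this from client_kwargs
    thinking_config=types.ThinkingConfig(thinking_budget=-1),  # -1 is the repo's fallback
)

message = types.Content(
    role="user",
    parts=[types.Part(text="example prompt")],  # placeholder text
)
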
@@ -354,9 +331,10 @@ def generate(self, messages):
         converted_messages = self.convert_messages(messages)
 
         def api_call():
-            response = self.model.generate_content(
-                converted_messages,
-                generation_config=self.generation_config,
+            response = self.client.models.generate_content(
+                model=self.model_id,
+                contents=converted_messages,
+                config=self.generation_config,
             )
             # Attempt to extract completion immediately after API call
             completion = self.extract_completion(response)
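
Putting the pieces together, a hedged end-to-end sketch of the new call path; the model id and prompt are placeholders, and the repo's retry and extraction logic are omitted.

# End-to-end sketch (placeholder model id and prompt, no retries).
from google import genai
from google.genai import types

client = genai.Client()
config = types.GenerateContentConfig(
    temperature=1.0,
    thinking_config=types.ThinkingConfig(thinking_budget=-1),
)
contents = [types.Content(role="user", parts=[types.Part(text="Hello")])]

response = client.models.generate_content(
    model="gemini-2.0-flash",  # placeholder; the repo passes self.model_id here
    contents=contents,
    config=config,
)
print(response.text)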

balrog/config/config.yaml

Lines changed: 1 addition & 0 deletions
@@ -32,6 +32,7 @@ client:
   generate_kwargs:
     temperature: 1.0 # Sampling temperature. If null the API default temperature is used instead
     max_tokens: 4096 # Max tokens to generate in the response
+    thinking_budget: null # Thinking budget. Set to a number of tokens to control.
   timeout: 60 # Timeout for API requests in seconds
   max_retries: 5 # Max number of retries for failed API calls
   delay: 2 # Exponential backoff factor between retries in seconds
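
When thinking_budget is left as null here, _initialize_client() in balrog/client.py falls back to -1 before building the ThinkingConfig; setting it to an integer is intended to bound the number of tokens the model may spend on internal "thinking" (on models that support it).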

balrog/utils.py

Lines changed: 1 addition & 2 deletions
@@ -7,7 +7,6 @@
 from collections import defaultdict
 from pathlib import Path
 
-import google.generativeai as genai
 import openai
 
 
@@ -206,7 +205,7 @@ def setup_environment(
     """
     secrets = load_secrets(os.path.join(original_cwd, "SECRETS"))
     if secrets[gemini_tag]:
-        genai.configure(api_key=secrets[gemini_tag])
+        os.environ["GEMINI_API_KEY"] = secrets[gemini_tag]
     if secrets[anthropic_tag]:
         os.environ["ANTHROPIC_API_KEY"] = secrets[anthropic_tag]
     if secrets[openai_tag]:
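
This swap works because the new SDK has no module-level configure() step: genai.Client(), as constructed in balrog/client.py, reads the key from the environment, so exporting GEMINI_API_KEY is sufficient.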

setup.py

Lines changed: 1 addition & 1 deletion
@@ -21,7 +21,7 @@
     install_requires=[
         "openai",
         "anthropic",
-        "google-generativeai",
+        "google-genai",
         "hydra-core",
         "opencv-python-headless",
         "wandb",
