88from collections import namedtuple
99from io import BytesIO
1010
11- import google .generativeai as genai
11+ from google import genai
12+ from google .genai import types
13+
1214from anthropic import Anthropic
1315from google .generativeai import caching
1416from openai import OpenAI
@@ -230,7 +232,8 @@ def __init__(self, client_config):
230232 def _initialize_client (self ):
231233 """Initialize the Generative AI client if not already initialized."""
232234 if not self ._initialized :
233- self .model = genai .GenerativeModel (self .model_id )
235+ self .client = genai .Client ()
236+ self .model = None
234237
235238 # Create kwargs dictionary for GenerationConfig
236239 client_kwargs = {
@@ -241,71 +244,46 @@ def _initialize_client(self):
241244 temperature = self .client_kwargs .get ("temperature" )
242245 if temperature is not None :
243246 client_kwargs ["temperature" ] = temperature
247+
248+ thinking_budget = self .client_kwargs .get ("thinking_budget" , - 1 )
244249
245- self .generation_config = genai .types .GenerationConfig (** client_kwargs )
250+ self .generation_config = genai .types .GenerateContentConfig (
251+ ** client_kwargs ,
252+ thinking_config = types .ThinkingConfig (thinking_budget = thinking_budget )
253+ )
246254 self ._initialized = True
247255
248256 def convert_messages (self , messages ):
249- """Convert messages to the format expected by the Generative AI API .
257+ """Convert messages to the format expected by the new Google GenAI SDK .
250258
251259 Args:
252260 messages (list): A list of message objects.
253261
254262 Returns:
255- list: A list of messages formatted for the Generative AI API.
263+ list[types.Content] : A list of Content objects formatted for the API.
256264 """
257- # Convert standard Message objects to Gemini's format
258265 converted_messages = []
266+
259267 for msg in messages :
260268 parts = []
269+
261270 role = msg .role
262271 if role == "assistant" :
263272 role = "model"
264273 elif role == "system" :
265274 role = "user"
275+
266276 if msg .content :
267- parts .append (msg .content )
277+ parts .append (types .Part (text = msg .content ))
278+
268279 if msg .attachment is not None :
269- parts .append (msg .attachment )
280+ parts .append (types .Part (image = msg .attachment ))
281+
270282 converted_messages .append (
271- {
272- "role" : role ,
273- "parts" : parts ,
274- }
283+ types .Content (role = role , parts = parts )
275284 )
276285 return converted_messages
277286
278- def get_completion (self , converted_messages , max_retries = 5 , delay = 5 ):
279- """Get the completion from the model with retries upon failure.
280-
281- Args:
282- converted_messages (list): Messages formatted for the Generative AI API.
283- max_retries (int, optional): Maximum number of retries. Defaults to 5.
284- delay (int, optional): Delay between retries in seconds. Defaults to 5.
285-
286- Returns:
287- Response object from the API.
288-
289- Raises:
290- Exception: If the API call fails after the maximum number of retries.
291- """
292- retries = 0
293- while retries < max_retries :
294- try :
295- response = self .model .generate_content (
296- converted_messages ,
297- generation_config = self .generation_config ,
298- )
299- return response
300- except Exception as e :
301- retries += 1
302- logger .error (f"Retryable error during generate_content: { e } . Retry { retries } /{ max_retries } " )
303- sleep_time = delay * (2 ** (retries - 1 )) # Exponential backoff
304- time .sleep (sleep_time )
305-
306- # If maximum retries are reached and still no valid response
307- raise Exception (f"Failed to get a valid completion after { max_retries } retries." )
308-
309287 def extract_completion (self , response ):
310288 """Extract the completion text from the API response.
311289
@@ -354,9 +332,10 @@ def generate(self, messages):
354332 converted_messages = self .convert_messages (messages )
355333
356334 def api_call ():
357- response = self .model .generate_content (
358- converted_messages ,
359- generation_config = self .generation_config ,
335+ response = self .client .models .generate_content (
336+ model = self .model_id ,
337+ contents = converted_messages ,
338+ config = self .generation_config ,
360339 )
361340 # Attempt to extract completion immediately after API call
362341 completion = self .extract_completion (response )
0 commit comments