88from collections import namedtuple
99from io import BytesIO
1010
11- import google .generativeai as genai
11+ from google import genai
12+ from google .genai import types
13+
1214from anthropic import Anthropic
13- from google .generativeai import caching
1415from openai import OpenAI
1516
1617LLMResponse = namedtuple (
@@ -230,7 +231,8 @@ def __init__(self, client_config):
230231 def _initialize_client (self ):
231232 """Initialize the Generative AI client if not already initialized."""
232233 if not self ._initialized :
233- self .model = genai .GenerativeModel (self .model_id )
234+ self .client = genai .Client ()
235+ self .model = None
234236
235237 # Create kwargs dictionary for GenerationConfig
236238 client_kwargs = {
@@ -241,71 +243,46 @@ def _initialize_client(self):
241243 temperature = self .client_kwargs .get ("temperature" )
242244 if temperature is not None :
243245 client_kwargs ["temperature" ] = temperature
246+
247+ thinking_budget = self .client_kwargs .get ("thinking_budget" , - 1 )
244248
245- self .generation_config = genai .types .GenerationConfig (** client_kwargs )
249+ self .generation_config = genai .types .GenerateContentConfig (
250+ ** client_kwargs ,
251+ thinking_config = types .ThinkingConfig (thinking_budget = thinking_budget )
252+ )
246253 self ._initialized = True
247254
248255 def convert_messages (self , messages ):
249- """Convert messages to the format expected by the Generative AI API .
256+ """Convert messages to the format expected by the new Google GenAI SDK .
250257
251258 Args:
252259 messages (list): A list of message objects.
253260
254261 Returns:
255- list: A list of messages formatted for the Generative AI API.
262+ list[types.Content] : A list of Content objects formatted for the API.
256263 """
257- # Convert standard Message objects to Gemini's format
258264 converted_messages = []
265+
259266 for msg in messages :
260267 parts = []
268+
261269 role = msg .role
262270 if role == "assistant" :
263271 role = "model"
264272 elif role == "system" :
265273 role = "user"
274+
266275 if msg .content :
267- parts .append (msg .content )
276+ parts .append (types .Part (text = msg .content ))
277+
268278 if msg .attachment is not None :
269- parts .append (msg .attachment )
279+ parts .append (types .Part (image = msg .attachment ))
280+
270281 converted_messages .append (
271- {
272- "role" : role ,
273- "parts" : parts ,
274- }
282+ types .Content (role = role , parts = parts )
275283 )
276284 return converted_messages
277285
278- def get_completion (self , converted_messages , max_retries = 5 , delay = 5 ):
279- """Get the completion from the model with retries upon failure.
280-
281- Args:
282- converted_messages (list): Messages formatted for the Generative AI API.
283- max_retries (int, optional): Maximum number of retries. Defaults to 5.
284- delay (int, optional): Delay between retries in seconds. Defaults to 5.
285-
286- Returns:
287- Response object from the API.
288-
289- Raises:
290- Exception: If the API call fails after the maximum number of retries.
291- """
292- retries = 0
293- while retries < max_retries :
294- try :
295- response = self .model .generate_content (
296- converted_messages ,
297- generation_config = self .generation_config ,
298- )
299- return response
300- except Exception as e :
301- retries += 1
302- logger .error (f"Retryable error during generate_content: { e } . Retry { retries } /{ max_retries } " )
303- sleep_time = delay * (2 ** (retries - 1 )) # Exponential backoff
304- time .sleep (sleep_time )
305-
306- # If maximum retries are reached and still no valid response
307- raise Exception (f"Failed to get a valid completion after { max_retries } retries." )
308-
309286 def extract_completion (self , response ):
310287 """Extract the completion text from the API response.
311288
@@ -354,9 +331,10 @@ def generate(self, messages):
354331 converted_messages = self .convert_messages (messages )
355332
356333 def api_call ():
357- response = self .model .generate_content (
358- converted_messages ,
359- generation_config = self .generation_config ,
334+ response = self .client .models .generate_content (
335+ model = self .model_id ,
336+ contents = converted_messages ,
337+ config = self .generation_config ,
360338 )
361339 # Attempt to extract completion immediately after API call
362340 completion = self .extract_completion (response )
0 commit comments