@@ -181,12 +181,19 @@ def generate(self, messages):
181181 converted_messages = self .convert_messages (messages )
182182
183183 def api_call ():
184- return self .client .chat .completions .create (
185- messages = converted_messages ,
186- model = self .model_id ,
187- temperature = self .client_kwargs .get ("temperature" , 0.5 ),
188- max_tokens = self .client_kwargs .get ("max_tokens" , 1024 ),
189- )
184+ # Create kwargs for the API call
185+ api_kwargs = {
186+ "messages" : converted_messages ,
187+ "model" : self .model_id ,
188+ "max_tokens" : self .client_kwargs .get ("max_tokens" , 1024 ),
189+ }
190+
191+ # Only include temperature if it's not None
192+ temperature = self .client_kwargs .get ("temperature" )
193+ if temperature is not None :
194+ api_kwargs ["temperature" ] = temperature
195+
196+ return self .client .chat .completions .create (** api_kwargs )
190197
191198 response = self .execute_with_retries (api_call )
192199
@@ -217,11 +224,16 @@ def _initialize_client(self):
217224 if not self ._initialized :
218225 self .model = genai .GenerativeModel (self .model_id )
219226
227+ # Create kwargs dictionary for GenerationConfig
220228 client_kwargs = {
221- "temperature" : self .client_kwargs .get ("temperature" , 0.5 ),
222229 "max_output_tokens" : self .client_kwargs .get ("max_tokens" , 1024 ),
223230 }
224231
232+ # Only include temperature if it's not None
233+ temperature = self .client_kwargs .get ("temperature" )
234+ if temperature is not None :
235+ client_kwargs ["temperature" ] = temperature
236+
225237 self .generation_config = genai .types .GenerationConfig (** client_kwargs )
226238 self ._initialized = True
227239
@@ -411,12 +423,19 @@ def generate(self, messages):
411423 converted_messages = self .convert_messages (messages )
412424
413425 def api_call ():
414- return self .client .messages .create (
415- messages = converted_messages ,
416- model = self .model_id ,
417- temperature = self .client_kwargs .get ("temperature" , 0.5 ),
418- max_tokens = self .client_kwargs .get ("max_tokens" , 1024 ),
419- )
426+ # Create kwargs for the API call
427+ api_kwargs = {
428+ "messages" : converted_messages ,
429+ "model" : self .model_id ,
430+ "max_tokens" : self .client_kwargs .get ("max_tokens" , 1024 ),
431+ }
432+
433+ # Only include temperature if it's not None
434+ temperature = self .client_kwargs .get ("temperature" )
435+ if temperature is not None :
436+ api_kwargs ["temperature" ] = temperature
437+
438+ return self .client .messages .create (** api_kwargs )
420439
421440 response = self .execute_with_retries (api_call )
422441
0 commit comments