
Commit f422f39

feat: default to api temperature
Parent: 342484a

2 files changed: 33 additions, 14 deletions


balrog/client.py

Lines changed: 32 additions & 13 deletions
@@ -181,12 +181,19 @@ def generate(self, messages):
         converted_messages = self.convert_messages(messages)

         def api_call():
-            return self.client.chat.completions.create(
-                messages=converted_messages,
-                model=self.model_id,
-                temperature=self.client_kwargs.get("temperature", 0.5),
-                max_tokens=self.client_kwargs.get("max_tokens", 1024),
-            )
+            # Create kwargs for the API call
+            api_kwargs = {
+                "messages": converted_messages,
+                "model": self.model_id,
+                "max_tokens": self.client_kwargs.get("max_tokens", 1024),
+            }
+
+            # Only include temperature if it's not None
+            temperature = self.client_kwargs.get("temperature")
+            if temperature is not None:
+                api_kwargs["temperature"] = temperature
+
+            return self.client.chat.completions.create(**api_kwargs)

         response = self.execute_with_retries(api_call)

@@ -217,11 +224,16 @@ def _initialize_client(self):
         if not self._initialized:
             self.model = genai.GenerativeModel(self.model_id)

+            # Create kwargs dictionary for GenerationConfig
             client_kwargs = {
-                "temperature": self.client_kwargs.get("temperature", 0.5),
                 "max_output_tokens": self.client_kwargs.get("max_tokens", 1024),
             }

+            # Only include temperature if it's not None
+            temperature = self.client_kwargs.get("temperature")
+            if temperature is not None:
+                client_kwargs["temperature"] = temperature
+
             self.generation_config = genai.types.GenerationConfig(**client_kwargs)
             self._initialized = True

@@ -411,12 +423,19 @@ def generate(self, messages):
         converted_messages = self.convert_messages(messages)

         def api_call():
-            return self.client.messages.create(
-                messages=converted_messages,
-                model=self.model_id,
-                temperature=self.client_kwargs.get("temperature", 0.5),
-                max_tokens=self.client_kwargs.get("max_tokens", 1024),
-            )
+            # Create kwargs for the API call
+            api_kwargs = {
+                "messages": converted_messages,
+                "model": self.model_id,
+                "max_tokens": self.client_kwargs.get("max_tokens", 1024),
+            }
+
+            # Only include temperature if it's not None
+            temperature = self.client_kwargs.get("temperature")
+            if temperature is not None:
+                api_kwargs["temperature"] = temperature
+
+            return self.client.messages.create(**api_kwargs)

         response = self.execute_with_retries(api_call)

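Taken together, the client.py changes build the request kwargs incrementally and only attach temperature when one is configured, so the provider's own default sampling temperature applies otherwise. A minimal self-contained sketch of that pattern (build_api_kwargs is a hypothetical helper written for illustration, not a function in the repository):

def build_api_kwargs(client_kwargs, messages, model_id):
    # Illustrative only: mirrors the conditional-kwargs pattern from the diff.
    api_kwargs = {
        "messages": messages,
        "model": model_id,
        "max_tokens": client_kwargs.get("max_tokens", 1024),
    }
    temperature = client_kwargs.get("temperature")
    if temperature is not None:  # omit the key entirely when unset
        api_kwargs["temperature"] = temperature
    return api_kwargs

# With temperature unset, the key is omitted and the API default applies:
assert "temperature" not in build_api_kwargs({"temperature": None}, [], "gpt-4o")
# With an explicit value, behaviour is unchanged:
assert build_api_kwargs({"temperature": 0.7}, [], "gpt-4o")["temperature"] == 0.7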
balrog/config/config.yaml

Lines changed: 1 addition & 1 deletion
@@ -30,7 +30,7 @@ client:
   model_id: gpt-4o # Model identifier (e.g., 'gpt-4', 'gpt-3.5-turbo')
   base_url: http://localhost:8080/v1 # Base URL for the API (if using a local server)
   generate_kwargs:
-    temperature: 0.0 # Sampling temperature; 0.0 makes the output deterministic
+    temperature: null # Sampling temperature. If null the API default temperature is used instead
     max_tokens: 4096 # Max tokens to generate in the response
   timeout: 60 # Timeout for API requests in seconds
   max_retries: 5 # Max number of retries for failed API calls
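The config change relies on the YAML null arriving in client_kwargs as Python None, which standard YAML loaders produce by default; a quick check using PyYAML purely for illustration (the project's actual config loader may differ):

import yaml  # PyYAML, shown only to illustrate the null -> None mapping

cfg = yaml.safe_load("temperature: null\nmax_tokens: 4096\n")
assert cfg["temperature"] is None   # so client_kwargs.get("temperature") returns None
assert cfg["max_tokens"] == 4096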
