diff --git a/balrog/client.py b/balrog/client.py
index be03166f..7588470f 100644
--- a/balrog/client.py
+++ b/balrog/client.py
@@ -181,12 +181,19 @@ def generate(self, messages):
         converted_messages = self.convert_messages(messages)
 
         def api_call():
-            return self.client.chat.completions.create(
-                messages=converted_messages,
-                model=self.model_id,
-                temperature=self.client_kwargs.get("temperature", 0.5),
-                max_tokens=self.client_kwargs.get("max_tokens", 1024),
-            )
+            # Create kwargs for the API call
+            api_kwargs = {
+                "messages": converted_messages,
+                "model": self.model_id,
+                "max_tokens": self.client_kwargs.get("max_tokens", 1024),
+            }
+
+            # Only include temperature if it's not None
+            temperature = self.client_kwargs.get("temperature")
+            if temperature is not None:
+                api_kwargs["temperature"] = temperature
+
+            return self.client.chat.completions.create(**api_kwargs)
 
         response = self.execute_with_retries(api_call)
 
@@ -217,11 +224,16 @@ def _initialize_client(self):
         if not self._initialized:
             self.model = genai.GenerativeModel(self.model_id)
 
+            # Create kwargs dictionary for GenerationConfig
             client_kwargs = {
-                "temperature": self.client_kwargs.get("temperature", 0.5),
                 "max_output_tokens": self.client_kwargs.get("max_tokens", 1024),
             }
 
+            # Only include temperature if it's not None
+            temperature = self.client_kwargs.get("temperature")
+            if temperature is not None:
+                client_kwargs["temperature"] = temperature
+
             self.generation_config = genai.types.GenerationConfig(**client_kwargs)
             self._initialized = True
 
@@ -411,12 +423,19 @@ def generate(self, messages):
         converted_messages = self.convert_messages(messages)
 
         def api_call():
-            return self.client.messages.create(
-                messages=converted_messages,
-                model=self.model_id,
-                temperature=self.client_kwargs.get("temperature", 0.5),
-                max_tokens=self.client_kwargs.get("max_tokens", 1024),
-            )
+            # Create kwargs for the API call
+            api_kwargs = {
+                "messages": converted_messages,
+                "model": self.model_id,
+                "max_tokens": self.client_kwargs.get("max_tokens", 1024),
+            }
+
+            # Only include temperature if it's not None
+            temperature = self.client_kwargs.get("temperature")
+            if temperature is not None:
+                api_kwargs["temperature"] = temperature
+
+            return self.client.messages.create(**api_kwargs)
 
         response = self.execute_with_retries(api_call)
 
diff --git a/balrog/config/config.yaml b/balrog/config/config.yaml
index 31a26ce3..5a48f34c 100644
--- a/balrog/config/config.yaml
+++ b/balrog/config/config.yaml
@@ -30,7 +30,7 @@ client:
   model_id: gpt-4o # Model identifier (e.g., 'gpt-4', 'gpt-3.5-turbo')
   base_url: http://localhost:8080/v1 # Base URL for the API (if using a local server)
   generate_kwargs:
-    temperature: 0.0 # Sampling temperature; 0.0 makes the output deterministic
+    temperature: null # Sampling temperature. If null, the API's default temperature is used instead
     max_tokens: 4096 # Max tokens to generate in the response
   timeout: 60 # Timeout for API requests in seconds
   max_retries: 5 # Max number of retries for failed API calls
diff --git a/docs/evaluation.md b/docs/evaluation.md
index 28bf06d3..870ebd03 100644
--- a/docs/evaluation.md
+++ b/docs/evaluation.md
@@ -93,6 +93,6 @@ python eval.py \
 | **client.is_chat_model** | Indicates if the model follows a chat-based interface. | `True` |
-| **client.generate_kwargs.temperature** | Temperature for model response randomness. | `0.0` |
+| **client.generate_kwargs.temperature** | Temperature for response randomness. If set to `null`, the API's default temperature is used; otherwise use a float from 0.0 to 1.0. | `null` |
 | **client.alternate_roles** | If True the instruction prompt will be fused with first observation. Required by some LLMs. | `False` |
 | **envs.names** | Dash-separated list of environments to evaluate, e.g., `nle-minihack`. | `babyai-babaisai-textworld-crafter-nle-minihack`|
@@ -103,3 +103,5 @@ python eval.py \
   Mac systems might complain about fork when evaluating in multiprocessing mode (`eval.num_workers > 1`). To fix this export the following before running eval: `export OBJC_DISABLE_INITIALIZE_FORK_SAFETY=YES`
 - Alternate roles: Some LLMs/VLMs require alternating roles. You can fuse the instruction prompt with the first observation to comply with this with the following: `client.alternate_roles=True`
+- Temperature:
+  We recommend running models with temperatures around 0.5-0.7, or using the model API's default temperature. Temperatures that are too low can cause some of the more brittle models to endlessly repeat actions or produce incoherent outputs.
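
For reference, a minimal self-contained sketch of the temperature-gating pattern this patch applies in all three client classes. `build_api_kwargs` is a hypothetical helper written for illustration only; it is not part of `balrog/client.py`, and the key names mirror the OpenAI-style call in the first hunk.

```python
def build_api_kwargs(client_kwargs, messages, model_id):
    """Assemble API kwargs, omitting temperature when it is None (YAML null)."""
    api_kwargs = {
        "messages": messages,
        "model": model_id,
        "max_tokens": client_kwargs.get("max_tokens", 1024),
    }
    # A missing or null temperature is left out of the request entirely,
    # so the API falls back to its own default sampling temperature.
    temperature = client_kwargs.get("temperature")
    if temperature is not None:
        api_kwargs["temperature"] = temperature
    return api_kwargs

# temperature: null in config.yaml parses to Python None -> key is omitted.
assert "temperature" not in build_api_kwargs({"temperature": None}, [], "gpt-4o")
# An explicit float is forwarded unchanged.
assert build_api_kwargs({"temperature": 0.7}, [], "gpt-4o")["temperature"] == 0.7
```

Note that `temperature: 0.0` remains a valid, explicitly deterministic setting; only the literal `null` (Python `None`) triggers the omission.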