diff --git a/balrog/client.py b/balrog/client.py index ccabe97d..df387f48 100644 --- a/balrog/client.py +++ b/balrog/client.py @@ -147,9 +147,9 @@ def _initialize_client(self): if not self._initialized: if self.client_name.lower() == "vllm": self.client = OpenAI(api_key="EMPTY", base_url=self.base_url) - elif self.client_name.lower() == "nvidia": + elif self.client_name.lower() == "nvidia" or self.client_name.lower() == "xai": if not self.base_url or not self.base_url.strip(): - raise ValueError("base_url must be provided when using NVIDIA client") + raise ValueError("base_url must be provided when using NVIDIA or XAI client") self.client = OpenAI(base_url=self.base_url) elif self.client_name.lower() == "openai": # For OpenAI, always use the standard API regardless of base_url @@ -504,8 +504,8 @@ def create_llm_client(client_config): def client_factory(): client_name_lower = client_config.client_name.lower() - if "openai" in client_name_lower or "vllm" in client_name_lower or "nvidia" in client_name_lower: - # NVIDIA uses OpenAI-compatible API, so we use the OpenAI wrapper + if "openai" in client_name_lower or "vllm" in client_name_lower or "nvidia" in client_name_lower or "xai" in client_name_lower: + # NVIDIA and XAI use OpenAI-compatible API, so we use the OpenAI wrapper return OpenAIWrapper(client_config) elif "gemini" in client_name_lower: return GoogleGenerativeAIWrapper(client_config)