Skip to content

Commit aa568d1

Browse files
feat: xai client
1 parent ec49c9b commit aa568d1

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

balrog/client.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -147,9 +147,9 @@ def _initialize_client(self):
147147
if not self._initialized:
148148
if self.client_name.lower() == "vllm":
149149
self.client = OpenAI(api_key="EMPTY", base_url=self.base_url)
150-
elif self.client_name.lower() == "nvidia":
150+
elif self.client_name.lower() == "nvidia" or self.client_name.lower() == "xai":
151151
if not self.base_url or not self.base_url.strip():
152-
raise ValueError("base_url must be provided when using NVIDIA client")
152+
raise ValueError("base_url must be provided when using NVIDIA or XAI client")
153153
self.client = OpenAI(base_url=self.base_url)
154154
elif self.client_name.lower() == "openai":
155155
# For OpenAI, always use the standard API regardless of base_url
@@ -504,8 +504,8 @@ def create_llm_client(client_config):
504504

505505
def client_factory():
506506
client_name_lower = client_config.client_name.lower()
507-
if "openai" in client_name_lower or "vllm" in client_name_lower or "nvidia" in client_name_lower:
508-
# NVIDIA uses OpenAI-compatible API, so we use the OpenAI wrapper
507+
if "openai" in client_name_lower or "vllm" in client_name_lower or "nvidia" in client_name_lower or "xai" in client_name_lower:
508+
# NVIDIA and XAI use OpenAI-compatible API, so we use the OpenAI wrapper
509509
return OpenAIWrapper(client_config)
510510
elif "gemini" in client_name_lower:
511511
return GoogleGenerativeAIWrapper(client_config)

0 commit comments

Comments
 (0)