@@ -147,9 +147,9 @@ def _initialize_client(self):
147147 if not self ._initialized :
148148 if self .client_name .lower () == "vllm" :
149149 self .client = OpenAI (api_key = "EMPTY" , base_url = self .base_url )
150- elif self .client_name .lower () == "nvidia" :
150+ elif self .client_name .lower () == "nvidia" or self . client_name . lower () == "xai" :
151151 if not self .base_url or not self .base_url .strip ():
152- raise ValueError ("base_url must be provided when using NVIDIA client" )
152+ raise ValueError ("base_url must be provided when using NVIDIA or XAI client" )
153153 self .client = OpenAI (base_url = self .base_url )
154154 elif self .client_name .lower () == "openai" :
155155 # For OpenAI, always use the standard API regardless of base_url
@@ -504,8 +504,8 @@ def create_llm_client(client_config):
504504
505505 def client_factory ():
506506 client_name_lower = client_config .client_name .lower ()
507- if "openai" in client_name_lower or "vllm" in client_name_lower or "nvidia" in client_name_lower :
508- # NVIDIA uses OpenAI-compatible API, so we use the OpenAI wrapper
507+ if "openai" in client_name_lower or "vllm" in client_name_lower or "nvidia" in client_name_lower or "xai" in client_name_lower :
508+ # NVIDIA and XAI use OpenAI-compatible API, so we use the OpenAI wrapper
509509 return OpenAIWrapper (client_config )
510510 elif "gemini" in client_name_lower :
511511 return GoogleGenerativeAIWrapper (client_config )
0 commit comments