Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion balrog/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,12 @@ def _initialize_client(self):
if not self._initialized:
if self.client_name.lower() == "vllm":
self.client = OpenAI(api_key="EMPTY", base_url=self.base_url)
elif self.client_name.lower() == "nvidia":
if not self.base_url or not self.base_url.strip():
raise ValueError("base_url must be provided when using NVIDIA client")
self.client = OpenAI(base_url=self.base_url)
elif self.client_name.lower() == "openai":
# For OpenAI, always use the standard API regardless of base_url
self.client = OpenAI()
self._initialized = True

Expand Down Expand Up @@ -462,7 +467,8 @@ def create_llm_client(client_config):

def client_factory():
client_name_lower = client_config.client_name.lower()
if "openai" in client_name_lower or "vllm" in client_name_lower:
if "openai" in client_name_lower or "vllm" in client_name_lower or "nvidia" in client_name_lower:
# NVIDIA uses OpenAI-compatible API, so we use the OpenAI wrapper
return OpenAIWrapper(client_config)
elif "gemini" in client_name_lower:
return GoogleGenerativeAIWrapper(client_config)
Expand Down