Skip to content

Commit f3a3f66

Browse files
feat: nvidia API support
1 parent 32cfcde commit f3a3f66

File tree

1 file changed

+7
-1
lines changed

1 file changed

+7
-1
lines changed

balrog/client.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,12 @@ def _initialize_client(self):
144144
if not self._initialized:
145145
if self.client_name.lower() == "vllm":
146146
self.client = OpenAI(api_key="EMPTY", base_url=self.base_url)
147+
elif self.client_name.lower() == "nvidia":
148+
if not self.base_url or not self.base_url.strip():
149+
raise ValueError("base_url must be provided when using NVIDIA client")
150+
self.client = OpenAI(base_url=self.base_url)
147151
elif self.client_name.lower() == "openai":
152+
# For OpenAI, always use the standard API regardless of base_url
148153
self.client = OpenAI()
149154
self._initialized = True
150155

@@ -462,7 +467,8 @@ def create_llm_client(client_config):
462467

463468
def client_factory():
464469
client_name_lower = client_config.client_name.lower()
465-
if "openai" in client_name_lower or "vllm" in client_name_lower:
470+
if "openai" in client_name_lower or "vllm" in client_name_lower or "nvidia" in client_name_lower:
471+
# NVIDIA uses OpenAI-compatible API, so we use the OpenAI wrapper
466472
return OpenAIWrapper(client_config)
467473
elif "gemini" in client_name_lower:
468474
return GoogleGenerativeAIWrapper(client_config)

0 commit comments

Comments
 (0)