Skip to content

Commit e006561

Browse files
Feat/nvidia (#42)
* feat: default to api temperature * chore: docs * feat: nvidia API support
1 parent 4a3ca72 commit e006561

File tree

1 file changed

+7
-1
lines changed

1 file changed

+7
-1
lines changed

balrog/client.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,12 @@ def _initialize_client(self):
144144
if not self._initialized:
145145
if self.client_name.lower() == "vllm":
146146
self.client = OpenAI(api_key="EMPTY", base_url=self.base_url)
147+
elif self.client_name.lower() == "nvidia":
148+
if not self.base_url or not self.base_url.strip():
149+
raise ValueError("base_url must be provided when using NVIDIA client")
150+
self.client = OpenAI(base_url=self.base_url)
147151
elif self.client_name.lower() == "openai":
152+
# For OpenAI, always use the standard API regardless of base_url
148153
self.client = OpenAI()
149154
self._initialized = True
150155

@@ -471,7 +476,8 @@ def create_llm_client(client_config):
471476

472477
def client_factory():
473478
client_name_lower = client_config.client_name.lower()
474-
if "openai" in client_name_lower or "vllm" in client_name_lower:
479+
if "openai" in client_name_lower or "vllm" in client_name_lower or "nvidia" in client_name_lower:
480+
# NVIDIA uses OpenAI-compatible API, so we use the OpenAI wrapper
475481
return OpenAIWrapper(client_config)
476482
elif "gemini" in client_name_lower:
477483
return GoogleGenerativeAIWrapper(client_config)

0 commit comments

Comments
 (0)