Skip to content

Commit 4065d44

Browse files
authored
Add NVIDIA provider, and improve declarative provider UX (#8798)
Signed-off-by: jh-block <jhugo@block.xyz>
1 parent 8f16ec6 commit 4065d44

9 files changed

Lines changed: 154 additions & 21 deletions

File tree

crates/goose/src/config/declarative_providers.rs

Lines changed: 35 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -77,6 +77,10 @@ pub struct DeclarativeProviderConfig {
7777
#[serde(default)]
7878
pub skip_canonical_filtering: bool,
7979
#[serde(default, deserialize_with = "deserialize_non_empty_string")]
80+
pub model_doc_link: Option<String>,
81+
#[serde(default)]
82+
pub setup_steps: Vec<String>,
83+
#[serde(default, deserialize_with = "deserialize_non_empty_string")]
8084
pub fast_model: Option<String>,
8185
}
8286

@@ -233,6 +237,8 @@ pub fn create_custom_provider(
233237
env_vars: None,
234238
dynamic_models: None,
235239
skip_canonical_filtering: false,
240+
model_doc_link: None,
241+
setup_steps: vec![],
236242
fast_model: None,
237243
};
238244

@@ -300,6 +306,8 @@ pub fn update_custom_provider(params: UpdateCustomProviderParams) -> Result<()>
300306
env_vars: existing_config.env_vars,
301307
dynamic_models: existing_config.dynamic_models,
302308
skip_canonical_filtering: existing_config.skip_canonical_filtering,
309+
model_doc_link: existing_config.model_doc_link,
310+
setup_steps: existing_config.setup_steps,
303311
fast_model: existing_config.fast_model.clone(),
304312
};
305313

@@ -587,6 +595,33 @@ mod tests {
587595
serde_json::from_str(json).expect("groq.json should parse without env_vars");
588596
assert!(config.env_vars.is_none());
589597
assert!(config.dynamic_models.is_none());
598+
assert!(config.model_doc_link.is_none());
599+
assert!(config.setup_steps.is_empty());
600+
}
601+
602+
#[test]
603+
fn test_nvidia_json_deserializes() {
604+
let json = include_str!("../providers/declarative/nvidia.json");
605+
let config: DeclarativeProviderConfig =
606+
serde_json::from_str(json).expect("nvidia.json should parse");
607+
assert_eq!(config.name, "nvidia");
608+
assert_eq!(config.display_name, "NVIDIA");
609+
assert!(matches!(config.engine, ProviderEngine::OpenAI));
610+
assert_eq!(config.api_key_env, "NVIDIA_API_KEY");
611+
assert_eq!(config.base_url, "https://integrate.api.nvidia.com/v1");
612+
assert_eq!(config.catalog_provider_id, Some("nvidia".to_string()));
613+
assert_eq!(config.dynamic_models, Some(true));
614+
assert_eq!(config.supports_streaming, Some(true));
615+
assert!(!config.skip_canonical_filtering);
616+
assert_eq!(
617+
config.model_doc_link,
618+
Some("https://build.nvidia.com/models".to_string())
619+
);
620+
assert_eq!(config.setup_steps.len(), 4);
621+
622+
assert_eq!(config.models.len(), 1);
623+
assert_eq!(config.models[0].name, "z-ai/glm-4.7");
624+
assert_eq!(config.models[0].context_limit, 131072);
590625
}
591626

592627
#[test]

crates/goose/src/model.rs

Lines changed: 19 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -158,7 +158,11 @@ impl ModelConfig {
158158
self.context_limit = Some(canonical.limit.context);
159159
}
160160
if self.max_tokens.is_none() {
161-
self.max_tokens = canonical.limit.output.map(|o| o as i32);
161+
self.max_tokens = canonical
162+
.limit
163+
.output
164+
.filter(|&output| output < canonical.limit.context)
165+
.map(|output| output as i32);
162166
}
163167
if self.reasoning.is_none() {
164168
self.reasoning = canonical.reasoning;
@@ -491,6 +495,20 @@ mod tests {
491495
assert_eq!(config.max_tokens, Some(1_000));
492496
}
493497

498+
#[test]
499+
fn skips_canonical_output_limit_when_it_equals_context_limit() {
500+
let _guard = env_lock::lock_env([
501+
("GOOSE_MAX_TOKENS", None::<&str>),
502+
("GOOSE_CONTEXT_LIMIT", None::<&str>),
503+
]);
504+
let config =
505+
ModelConfig::new_or_fail("moonshotai/kimi-k2.5").with_canonical_limits("nvidia");
506+
507+
assert_eq!(config.context_limit, Some(262_144));
508+
assert_eq!(config.max_tokens, None);
509+
assert_eq!(config.max_output_tokens(), 4_096);
510+
}
511+
494512
#[test]
495513
fn unknown_model_leaves_fields_none() {
496514
let _guard = env_lock::lock_env([
crates/goose/src/providers/declarative/nvidia.json

Lines changed: 24 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,24 @@
1+
{
2+
"name": "nvidia",
3+
"engine": "openai",
4+
"display_name": "NVIDIA",
5+
"description": "Hosted NVIDIA NIM models through the OpenAI-compatible API.",
6+
"api_key_env": "NVIDIA_API_KEY",
7+
"base_url": "https://integrate.api.nvidia.com/v1",
8+
"catalog_provider_id": "nvidia",
9+
"dynamic_models": true,
10+
"models": [
11+
{
12+
"name": "z-ai/glm-4.7",
13+
"context_limit": 131072
14+
}
15+
],
16+
"supports_streaming": true,
17+
"model_doc_link": "https://build.nvidia.com/models",
18+
"setup_steps": [
19+
"Sign in to https://build.nvidia.com",
20+
"Choose a Free Endpoint model from the model catalog",
21+
"Create an API key",
22+
"Copy the key and paste it above"
23+
]
24+
}

crates/goose/src/providers/init.rs

Lines changed: 35 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -238,6 +238,41 @@ mod tests {
238238
assert!(!endpoint.secret, "Endpoint should not be secret");
239239
}
240240

241+
#[tokio::test]
242+
async fn test_nvidia_declarative_provider_registry_wiring() {
243+
let nvidia = get_from_registry("nvidia")
244+
.await
245+
.expect("nvidia provider should be registered");
246+
let meta = nvidia.metadata();
247+
248+
assert_eq!(nvidia.provider_type(), ProviderType::Declarative);
249+
assert!(nvidia.supports_inventory_refresh());
250+
assert_eq!(meta.display_name, "NVIDIA");
251+
assert_eq!(meta.default_model, "z-ai/glm-4.7");
252+
assert_eq!(meta.model_doc_link, "https://build.nvidia.com/models");
253+
assert!(!meta.setup_steps.is_empty());
254+
255+
let api_key = meta
256+
.config_keys
257+
.iter()
258+
.find(|k| k.name == "NVIDIA_API_KEY")
259+
.expect("NVIDIA_API_KEY config key should exist");
260+
assert!(api_key.required, "NVIDIA_API_KEY should be required");
261+
assert!(api_key.secret, "NVIDIA_API_KEY should be secret");
262+
assert!(api_key.primary, "NVIDIA_API_KEY should be primary");
263+
assert!(
264+
!meta.config_keys.iter().any(|k| k.name == "OPENAI_HOST"),
265+
"NVIDIA should not expose OpenAI host configuration"
266+
);
267+
assert!(
268+
!meta
269+
.config_keys
270+
.iter()
271+
.any(|k| k.name == "OPENAI_BASE_PATH"),
272+
"NVIDIA should not expose OpenAI base path configuration"
273+
);
274+
}
275+
241276
#[tokio::test]
242277
async fn test_openai_compatible_providers_config_keys() {
243278
let providers_list = providers().await;

crates/goose/src/providers/provider_registry.rs

Lines changed: 26 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,4 @@
1-
use super::base::{ModelInfo, Provider, ProviderDef, ProviderMetadata, ProviderType};
1+
use super::base::{ConfigKey, ModelInfo, Provider, ProviderDef, ProviderMetadata, ProviderType};
22
use super::inventory::InventoryIdentityInput;
33
use crate::config::{DeclarativeProviderConfig, ExtensionConfig};
44
use crate::model::ModelConfig;
@@ -165,28 +165,32 @@ impl ProviderRegistry {
165165
})
166166
.collect();
167167

168-
let mut config_keys = base_metadata.config_keys.clone();
169-
170-
if let Some(api_key_index) = config_keys.iter().position(|key| key.secret) {
171-
if !config.requires_auth {
172-
config_keys.remove(api_key_index);
173-
} else if !config.api_key_env.is_empty() {
174-
let api_key_required = provider_type == ProviderType::Declarative;
175-
config_keys[api_key_index] = super::base::ConfigKey::new(
176-
&config.api_key_env,
177-
api_key_required,
178-
true,
179-
None,
180-
true,
181-
);
168+
let mut config_keys = if provider_type == ProviderType::Declarative {
169+
if config.requires_auth && !config.api_key_env.is_empty() {
170+
vec![ConfigKey::new(&config.api_key_env, true, true, None, true)]
171+
} else {
172+
Vec::new()
182173
}
183-
}
174+
} else {
175+
let mut config_keys = base_metadata.config_keys.clone();
176+
177+
if let Some(api_key_index) = config_keys.iter().position(|key| key.secret) {
178+
if !config.requires_auth {
179+
config_keys.remove(api_key_index);
180+
} else if !config.api_key_env.is_empty() {
181+
config_keys[api_key_index] =
182+
ConfigKey::new(&config.api_key_env, false, true, None, true);
183+
}
184+
}
185+
186+
config_keys
187+
};
184188

185189
if let Some(ref env_vars) = config.env_vars {
186190
for ev in env_vars {
187191
// Default primary to `required` so required fields show prominently in the UI
188192
let primary = ev.primary.unwrap_or(ev.required);
189-
config_keys.push(super::base::ConfigKey::new(
193+
config_keys.push(ConfigKey::new(
190194
&ev.name,
191195
ev.required,
192196
ev.secret,
@@ -202,9 +206,12 @@ impl ProviderRegistry {
202206
description,
203207
default_model,
204208
known_models,
205-
model_doc_link: base_metadata.model_doc_link,
209+
model_doc_link: config
210+
.model_doc_link
211+
.clone()
212+
.unwrap_or(base_metadata.model_doc_link),
206213
config_keys,
207-
setup_steps: vec![],
214+
setup_steps: config.setup_steps.clone(),
208215
model_selection_hint: None,
209216
};
210217
let inventory_config_keys = custom_metadata.config_keys.clone();

ui/desktop/openapi.json

Lines changed: 10 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -4643,6 +4643,10 @@
46434643
},
46444644
"nullable": true
46454645
},
4646+
"model_doc_link": {
4647+
"type": "string",
4648+
"nullable": true
4649+
},
46464650
"models": {
46474651
"type": "array",
46484652
"items": {
@@ -4655,6 +4659,12 @@
46554659
"requires_auth": {
46564660
"type": "boolean"
46574661
},
4662+
"setup_steps": {
4663+
"type": "array",
4664+
"items": {
4665+
"type": "string"
4666+
}
4667+
},
46584668
"skip_canonical_filtering": {
46594669
"type": "boolean"
46604670
},

ui/desktop/src/api/types.gen.ts

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -220,9 +220,11 @@ export type DeclarativeProviderConfig = {
220220
headers?: {
221221
[key: string]: string;
222222
} | null;
223+
model_doc_link?: string | null;
223224
models: Array<ModelInfo>;
224225
name: string;
225226
requires_auth?: boolean;
227+
setup_steps?: Array<string>;
226228
skip_canonical_filtering?: boolean;
227229
supports_streaming?: boolean | null;
228230
timeout_seconds?: number | null;

ui/desktop/src/components/ConfigContext.tsx

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -178,6 +178,7 @@ export const ConfigProvider: React.FC<ConfigProviderProps> = ({ children }) => {
178178
try {
179179
const response = await providers();
180180
const providersData = response.data || [];
181+
providersListRef.current = providersData;
181182
setProvidersList(providersData);
182183
return providersData;
183184
} catch (error) {
@@ -199,6 +200,7 @@ export const ConfigProvider: React.FC<ConfigProviderProps> = ({ children }) => {
199200
try {
200201
const providersResponse = await providers();
201202
const providersData = providersResponse.data || [];
203+
providersListRef.current = providersData;
202204
setProvidersList(providersData);
203205
} catch (error) {
204206
console.error('Failed to load providers:', error);

ui/desktop/src/components/settings/providers/ProviderGrid.tsx

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -117,7 +117,7 @@ function ProviderCards({
117117

118118
const configureProviderViaModal = useCallback(
119119
async (provider: ProviderDetails) => {
120-
if (provider.provider_type === 'Custom' || provider.provider_type === 'Declarative') {
120+
if (provider.provider_type === 'Custom') {
121121
const { getCustomProvider } = await import('../../../api');
122122
const result = await getCustomProvider({ path: { id: provider.name }, throwOnError: true });
123123

0 commit comments

Comments
 (0)