Skip to content

Commit e064a46

Browse files
lukekim and Copilot authored
Improve NSQL UX and harden internal LLM tools (sampling SQL quoting, tool descriptions) (spiceai#10715)
* Improve NSQL defaults and add datetime tool
* feat: update tool descriptions and add current datetime function to snapshots
* refactor(tests): improve variable naming for clarity in tools_by_name test
* fix: remove dereference operator for tool_name in all_available_tools function
* feat: enhance tool descriptions for clarity and detail across multiple tools

Co-authored-by: Copilot <copilot@github.com>

* feat: update tool descriptions for improved clarity and detail in Spice.ai runtime
* feat: enhance tool descriptions for clarity and detail in snapshots
* feat: clarify SQL query description in SqlToolParams struct
* feat: update SQL query description for consistency in Spice.ai SQL Dialect

Co-authored-by: Copilot <copilot@github.com>

---------

Co-authored-by: Copilot <copilot@github.com>
1 parent 10e80bc commit e064a46

24 files changed

Lines changed: 608 additions & 243 deletions

crates/runtime/src/http/v1/nsql.rs

Lines changed: 148 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -53,6 +53,7 @@ use arrow::array::RecordBatch;
5353
use itertools::Itertools;
5454
use llms::chat::nsql::{FailedAttempt, QueryGenerationContext, default::DefaultSqlGeneration};
5555
use serde::{Deserialize, Serialize};
56+
use spicepod::component::model::ModelType;
5657
use std::{sync::Arc, time::Duration};
5758
use tokio::sync::RwLock;
5859
use tracing::Span;
@@ -138,15 +139,15 @@ pub struct Request {
138139
/// The natural language query to be converted into SQL
139140
pub query: String,
140141

141-
/// The name of the model to use for SQL generation. Default: "nql"
142-
#[serde(default = "default_model")]
143-
pub model: String,
142+
/// The name of the model to use for SQL generation. If omitted, Spice defaults to the only compatible LLM model configured in the Spicepod.
143+
#[serde(skip_serializing_if = "Option::is_none")]
144+
pub model: Option<String>,
144145

145146
/// If true, streams the response instead of waiting for completion
146147
#[serde(default)]
147148
pub stream: bool,
148149

149-
/// Whether sample data is included in the context for SQL generation. Default: true
150+
/// Whether sample data is included in the context for SQL generation. Default: false
150151
#[serde(default = "default_sample_data_enabled")]
151152
pub sample_data_enabled: bool,
152153

@@ -160,11 +161,7 @@ pub struct Request {
160161
}
161162

162163
fn default_sample_data_enabled() -> bool {
163-
true
164-
}
165-
166-
fn default_model() -> String {
167-
"nql".to_string()
164+
false
168165
}
169166

170167
/// Checks if the request is asking to only generate SQL.
@@ -193,9 +190,8 @@ fn return_sql_only(accept: Option<&TypedHeader<Accept>>) -> bool {
193190
Request = "application/json",
194191
example = json!({
195192
"query": "Get the top 5 customers by total sales",
196-
"model": "nql",
197193
"stream": false,
198-
"sample_data_enabled": true,
194+
"sample_data_enabled": false,
199195
"datasets": ["sales_data"],
200196
"prompt_cache_key": "sales-dashboard"
201197
})
@@ -323,12 +319,18 @@ pub(crate) async fn handle_nsql_query(
323319

324320
let Request {
325321
query,
326-
model,
322+
model: requested_model,
327323
sample_data_enabled,
328324
datasets,
329325
prompt_cache_key,
330326
..
331327
} = payload;
328+
329+
let model = match resolve_nsql_model_name(requested_model, &rt).await {
330+
Ok(model) => model,
331+
Err((status, message)) => return (status, headers, message),
332+
};
333+
332334
let table_allowlist_opt = match table_allowlist(&model, &rt).await {
333335
Ok(ta) => ta,
334336
Err(e) => {
@@ -405,16 +407,20 @@ pub(crate) async fn handle_nsql_query(
405407
vec![]
406408
};
407409

408-
let models = llms.read().await;
409-
let Some(nql_model) = models.get(&model) else {
410-
return (
411-
StatusCode::BAD_REQUEST,
412-
headers,
413-
format!("Model {model} not found"),
414-
);
410+
let nql_model = {
411+
let models = llms.read().await;
412+
let Some(nql_model) = models.get(&model) else {
413+
return (
414+
StatusCode::BAD_REQUEST,
415+
headers,
416+
format!("Model {model} not found"),
417+
);
418+
};
419+
Arc::clone(nql_model)
415420
};
416421

417-
let sql_gen = nql_model.as_sql().unwrap_or(&DefaultSqlGeneration {});
422+
let default_sql_generation = DefaultSqlGeneration {};
423+
let sql_gen = nql_model.as_sql().unwrap_or(&default_sql_generation);
418424
// Tracks previously generated queries and associated errors to enable an efficient retry mechanism
419425
let mut sql_gen_ctx = QueryGenerationContext::default();
420426
let mut num_retries = 0;
@@ -555,6 +561,49 @@ pub(crate) async fn handle_nsql_query(
555561
}
556562
}
557563

564+
/// Resolves the name of the model to use for an NSQL request.
///
/// When `requested_model` is `Some`, that name is returned as-is; it is not
/// checked against the configured models in this function. When it is `None`,
/// the runtime's app is read and the choice is delegated to
/// [`resolve_nsql_model_name_from_app`].
///
/// # Errors
///
/// Returns a `(StatusCode, String)` pair:
/// - `INTERNAL_SERVER_ERROR` when `rt.read_app()` yields no app.
/// - `BAD_REQUEST`, carrying the message produced by
///   [`resolve_nsql_model_name_from_app`], when no default can be chosen.
async fn resolve_nsql_model_name(
565+
requested_model: Option<String>,
566+
rt: &Arc<Runtime>,
567+
) -> Result<String, (StatusCode, String)> {
568+
// An explicitly requested model short-circuits default resolution.
if let Some(model) = requested_model {
569+
return Ok(model);
570+
}
571+
572+
// No model in the request: the app definition is needed to pick a default.
let Some(app) = rt.read_app().await else {
573+
return Err((
574+
StatusCode::INTERNAL_SERVER_ERROR,
575+
"Unexpected internal error. App not prepared in runtime.".to_string(),
576+
));
577+
};
578+
579+
// Resolution failures are reported to the client as a bad request.
resolve_nsql_model_name_from_app(app.as_ref())
580+
.map_err(|message| (StatusCode::BAD_REQUEST, message))
581+
}
582+
583+
/// Chooses the default NSQL model name from the models configured in `app`.
///
/// Candidates are the names returned by [`compatible_nsql_model_names`]. The
/// single candidate's name is returned when exactly one exists.
///
/// # Errors
///
/// Returns a human-readable message when there are no candidates, or when
/// there is more than one candidate (the message lists the candidate names
/// joined by `", "`).
fn resolve_nsql_model_name_from_app(app: &app::App) -> Result<String, String> {
584+
let compatible_models = compatible_nsql_model_names(app);
585+
586+
// Slice patterns: empty, exactly one element, or the catch-all (two or more).
match compatible_models.as_slice() {
587+
[] => Err(
588+
"No model specified and no compatible LLM model is configured. Add exactly one LLM model to the Spicepod or include the 'model' field in the request."
589+
.to_string(),
590+
),
591+
[model] => Ok(model.clone()),
592+
// Ambiguous: name every candidate so the caller can pick one explicitly.
models => Err(format!(
593+
"No model specified and multiple compatible LLM models are configured ({}). Include the 'model' field in the request.",
594+
models.join(", ")
595+
)),
596+
}
597+
}
598+
599+
/// Collects the names of the models in `app.models` that can serve NSQL.
///
/// A model is kept only when its `model_type()` equals
/// `Some(ModelType::Llm)`; models reporting any other type, or `None`, are
/// dropped. Names are cloned and returned in the iteration order of
/// `app.models`.
fn compatible_nsql_model_names(app: &app::App) -> Vec<String> {
600+
app.models
601+
.iter()
602+
.filter(|model| model.model_type() == Some(ModelType::Llm))
603+
.map(|model| model.name.clone())
604+
.collect()
605+
}
606+
558607
/// Construct a [`ResolvedTableAwareAllowlist`] based on the `App`'s `model.datasets`.
559608
async fn table_allowlist(
560609
model_name: &str,
@@ -591,3 +640,82 @@ async fn table_allowlist(
591640
};
592641
Ok(table_allowlist)
593642
}
643+
644+
// Unit tests for request deserialization defaults and for default NSQL model
// resolution via `resolve_nsql_model_name_from_app`.
#[cfg(test)]
645+
mod tests {
646+
use super::*;
647+
use app::AppBuilder;
648+
use serde_json::json;
649+
use spicepod::component::model::Model;
650+
651+
/// Builds an app named "test" containing the given models, added in order.
fn app_with_models(models: Vec<Model>) -> app::App {
652+
let mut builder = AppBuilder::new("test");
653+
for model in models {
654+
builder = builder.with_model(model);
655+
}
656+
builder.build()
657+
}
658+
659+
/// A request containing only `query` deserializes with `model == None` and
/// `sample_data_enabled == false`.
#[test]
660+
fn request_defaults_to_no_model_and_no_sample_data() {
661+
let request: Request = serde_json::from_value(json!({
662+
"query": "show total sales"
663+
}))
664+
.expect("request should deserialize with omitted optional fields");
665+
666+
assert_eq!(request.model, None);
667+
assert!(!request.sample_data_enabled);
668+
}
669+
670+
/// With exactly one compatible model configured, its name is chosen.
#[test]
671+
fn omitted_model_uses_single_compatible_model() {
672+
// NOTE(review): `Model::new` appears to take (source, name) — the asserted
// name below is the second argument; confirm against `spicepod`.
let app = app_with_models(vec![Model::new("openai:gpt-4o-mini", "llm_model")]);
673+
674+
let model_name = resolve_nsql_model_name_from_app(&app)
675+
.expect("single compatible model should be selected");
676+
677+
assert_eq!(model_name, "llm_model");
678+
}
679+
680+
/// A model that is not classified as an LLM does not count as a candidate.
#[test]
681+
fn omitted_model_ignores_non_llm_models() {
682+
// The test expects the `spiceai:` model to be excluded, i.e. its
// `model_type()` is presumably not `Some(ModelType::Llm)` — verify in `spicepod`.
let app = app_with_models(vec![
683+
Model::new("spiceai:my-org/my-app/models/runnable", "ml_model"),
684+
Model::new("openai:gpt-4o-mini", "llm_model"),
685+
]);
686+
687+
let model_name = resolve_nsql_model_name_from_app(&app)
688+
.expect("single compatible model should be selected");
689+
690+
assert_eq!(model_name, "llm_model");
691+
}
692+
693+
/// With no models configured, resolution fails with the "no compatible
/// LLM model" message.
#[test]
694+
fn omitted_model_errors_when_no_compatible_model_exists() {
695+
let app = app_with_models(vec![]);
696+
697+
let error = resolve_nsql_model_name_from_app(&app)
698+
.expect_err("omitted model should fail without compatible models");
699+
700+
assert_eq!(
701+
error,
702+
"No model specified and no compatible LLM model is configured. Add exactly one LLM model to the Spicepod or include the 'model' field in the request."
703+
);
704+
}
705+
706+
/// With two compatible models configured, resolution fails and the message
/// lists both names in the order they were added.
#[test]
707+
fn omitted_model_errors_when_multiple_compatible_models_exist() {
708+
let app = app_with_models(vec![
709+
Model::new("openai:gpt-4o-mini", "first_model"),
710+
Model::new("openai:gpt-4o", "second_model"),
711+
]);
712+
713+
let error = resolve_nsql_model_name_from_app(&app)
714+
.expect_err("omitted model should fail with multiple compatible models");
715+
716+
assert_eq!(
717+
error,
718+
"No model specified and multiple compatible LLM models are configured (first_model, second_model). Include the 'model' field in the request."
719+
);
720+
}
721+
}

crates/runtime/src/http/v1/tools.rs

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -72,8 +72,8 @@ pub(crate) struct SearchToolsQuery {
7272
status = 200, body = [ListToolElement],
7373
description = "All tools available in the Spice runtime",
7474
example = json!([
75-
{"name": "get_readiness", "description": "Retrieves the readiness status of all runtime components including registered datasets, models, and embeddings.", "parameters": null},
76-
{"name": "list_datasets", "description": "List all SQL tables available.", "parameters": null}
75+
{"name": "get_readiness", "description": "Report the readiness state of every Spice runtime component (datasets, accelerators, models, embeddings, catalogs).", "parameters": null},
76+
{"name": "list_datasets", "description": "List every dataset, view, and catalog visible to this runtime.", "parameters": null}
7777
])
7878
),
7979
(
@@ -110,7 +110,7 @@ pub(crate) async fn list(Extension(rt): Extension<Arc<Runtime>>) -> Response {
110110
example = json!([
111111
{"name": "tool_search", "description": "Search the Spice tool registry for tools relevant to the current task.", "parameters": {"type": "object"}},
112112
{"name": "tool_invoke", "description": "Invoke one Spice tool returned by tool_search.", "parameters": {"type": "object"}},
113-
{"name": "list_datasets", "description": "List all SQL tables available.", "parameters": null}
113+
{"name": "list_datasets", "description": "List every dataset, view, and catalog visible to this runtime.", "parameters": null}
114114
])
115115
),
116116
(status = 400, description = "Searchable tool registry is not configured", body = serde_json::Value),

crates/runtime/src/model/params/mod.rs

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -57,7 +57,7 @@ pub const PARAM_WITH_DEPRE_LEN: usize = 52;
5757
pub const COMMON_MODEL_PARAMETERS: [ParameterSpec; PARAM_LEN] = [
5858
// Common parameters for all models
5959
ParameterSpec::runtime("tools")
60-
.description("Which tools should be made available to the model. Set to 'auto' to automatically choose between direct tools and searchable discovery, 'all' to use built-in and Spicepod-configured tools directly, or 'search_registry' to require searchable tool discovery."),
60+
.description("Which tools should be made available to the model. Set to 'auto' to automatically choose between direct tools and searchable discovery without data sampling tools, 'all' to use built-in and Spicepod-configured tools directly, or 'search_registry' to require searchable tool discovery."),
6161
ParameterSpec::runtime("tool_embedding_model")
6262
.description("Embedding model name to use for searchable tool discovery. tools: search_registry requires a model configured in the embeddings section and uses it when only one embedding model is configured; tools: auto falls back to direct tools if embeddings are unavailable."),
6363
ParameterSpec::runtime("system_prompt")
@@ -120,7 +120,7 @@ pub const COMMON_MODEL_PARAMETERS: [ParameterSpec; PARAM_LEN] = [
120120
pub const COMMON_MODEL_PARAMETERS_WITH_DEPRECATED: [ParameterSpec; PARAM_WITH_DEPRE_LEN] = [
121121
// Common parameters for all models
122122
ParameterSpec::runtime("tools")
123-
.description("Which tools should be made available to the model. Set to 'auto' to automatically choose between direct tools and searchable discovery, 'all' to use built-in and Spicepod-configured tools directly, or 'search_registry' to require searchable tool discovery."),
123+
.description("Which tools should be made available to the model. Set to 'auto' to automatically choose between direct tools and searchable discovery without data sampling tools, 'all' to use built-in and Spicepod-configured tools directly, or 'search_registry' to require searchable tool discovery."),
124124
ParameterSpec::runtime("tool_embedding_model")
125125
.description("Embedding model name to use for searchable tool discovery. tools: search_registry requires a model configured in the embeddings section and uses it when only one embedding model is configured; tools: auto falls back to direct tools if embeddings are unavailable."),
126126
ParameterSpec::runtime("system_prompt")

crates/runtime/src/model/tool_use.rs

Lines changed: 13 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -46,6 +46,7 @@ use tracing::{Instrument, Span};
4646

4747
use crate::Runtime;
4848
use crate::model::ModelContextExtension;
49+
use crate::tools::utils::tool_call_error_response;
4950
use llms::progress::Progress;
5051
use runtime_request_context::{AsyncMarker, RequestContext};
5152

@@ -107,10 +108,17 @@ impl ToolUsingChat {
107108
&self,
108109
list_datasets: &Arc<dyn SpiceModelTool>,
109110
) -> Result<Vec<ChatCompletionRequestMessage>, OpenAIError> {
110-
let t_resp = list_datasets
111-
.call("")
112-
.await
113-
.map_err(|e| OpenAIError::InvalidArgument(e.to_string()))?;
111+
let t_resp = match list_datasets.call("").await {
112+
Ok(resp) => resp,
113+
Err(e) => {
114+
let tool_name = list_datasets.name();
115+
let error = e.to_string();
116+
tracing::warn!(
117+
"Tool '{tool_name}' failed while creating initial tool-use messages: {error}"
118+
);
119+
tool_call_error_response(tool_name.as_ref(), error)
120+
}
121+
};
114122
Ok(vec![
115123
ChatCompletionRequestAssistantMessageArgs::default()
116124
.tool_calls(vec![ChatCompletionMessageToolCalls::Function(
@@ -166,10 +174,7 @@ impl ToolUsingChat {
166174
.content(e.to_string())
167175
.to_jsonl(),
168176
);
169-
Value::String(format!(
170-
"Failed to call the tool {}.\nAn error occurred: {e}",
171-
t.name()
172-
))
177+
tool_call_error_response(t.name().as_ref(), e)
173178
}
174179
},
175180
None => {

0 commit comments

Comments
 (0)