From 9a0b4f8b2766bd23ec72af1630413f4cf93e30ef Mon Sep 17 00:00:00 2001 From: Luke Kim <80174+lukekim@users.noreply.github.com> Date: Mon, 4 May 2026 07:44:08 -0700 Subject: [PATCH 1/6] Add searchable registry mode for LLM tools (#10647) * feat(tools): add tool registry support and enhance tool usage options Co-authored-by: Copilot * style: format code for better readability in ToolDocument and tokenization functions * feat(tools): add support for tool embedding model in registry and responses Co-authored-by: Copilot * feat(tools): enhance tool options and registry embedding model handling Co-authored-by: Copilot * feat(rrf): integrate DEFAULT_RRF_K for configurable ranking and enhance fusion scoring Co-authored-by: Copilot * feat(tools): enhance tool handling with new options for searchable discovery and auto-selection Co-authored-by: Copilot * feat(tools): add 'All' option for direct access to all builtin tools Co-authored-by: Copilot * feat(tools): update descriptions for tool options and enhance tool retrieval logic * feat(tools): enhance error handling for table allowlist creation and update tool descriptions * feat(tools): improve embedding model selection logic for searchable tool discovery Co-authored-by: Copilot * feat(tools): add searchable tool registry endpoint and enhance tool retrieval logic Co-authored-by: Copilot * feat(tools): add optional embedding model parameter to searchable tool registry endpoint * feat(tools): rename tool_embedding_model to embedding_model in SearchToolsQuery * feat(tools): refactor table allowlist creation and enhance tool document handling Co-authored-by: Copilot * feat(tools): refactor table allowlist creation and enhance error handling in tool registry * feat(tests): simplify mock tool parameters in test cases Co-authored-by: Copilot * refactor(tools): improve tool registry and utility functions for better clarity and performance Co-authored-by: Copilot * feat(tools): enhance tool registry with caching and improve tool 
retrieval logic Co-authored-by: Copilot * ci: guard Install protoc step with relevant_changes check (#10646) The Install protoc step in the Build Test Binary job referenced the local ./.github/actions/install-protoc action but lacked the relevant_changes conditional guard that gates the surrounding steps. When a PR contains no relevant code changes (e.g. openapi.json or spicepod schema only), actions/checkout is skipped, leaving the workspace empty. The Install protoc step then fails immediately at job setup with "Can't find action.yml... install-protoc". This blocks all non-Rust PRs from completing the integration test workflow. Adding the same `if: needs.check_changes.outputs.relevant_changes == 'true'` guard restores the intended skip-on-no-relevant-changes behavior. * fix(benchmarks): redact non-deterministic partition_sizes in explain plan snapshots (#10641) * refactor(tools): update tool_registry_tools and caching logic for improved clarity and performance Co-authored-by: Copilot * fix(tools): address searchable registry review feedback * feat(tools): enhance tool registry preparation logic and add specific tool handling Co-authored-by: Copilot --------- Co-authored-by: Copilot Co-authored-by: claudespice Co-authored-by: Sergei Grebnov --- crates/runtime/src/http/routes.rs | 2 + crates/runtime/src/http/v1/tools.rs | 118 +- crates/runtime/src/model/chat.rs | 28 +- crates/runtime/src/model/params/mod.rs | 12 +- crates/runtime/src/model/responses.rs | 28 +- crates/runtime/src/search/rrf.rs | 3 +- crates/runtime/src/tools/builtin/catalog.rs | 2 +- crates/runtime/src/tools/mod.rs | 1 + crates/runtime/src/tools/options.rs | 90 +- crates/runtime/src/tools/registry.rs | 1718 +++++++++++++++++ crates/runtime/src/tools/utils.rs | 312 ++- .../search/src/aggregation/reciprocal_rank.rs | 98 +- 12 files changed, 2294 insertions(+), 118 deletions(-) create mode 100644 crates/runtime/src/tools/registry.rs diff --git a/crates/runtime/src/http/routes.rs 
b/crates/runtime/src/http/routes.rs index 78ed90b0be..458bd9f595 100644 --- a/crates/runtime/src/http/routes.rs +++ b/crates/runtime/src/http/routes.rs @@ -99,6 +99,7 @@ use tower_http::limit::RequestBodyLimitLayer; v1::inference::get, v1::inference::post, v1::tools::list, + v1::tools::search, v1::tools::post, v1::iceberg::get_config, v1::iceberg::get_namespaces, @@ -328,6 +329,7 @@ pub(crate) fn routes( let tools_auth_message = "Tool invocation (/v1/tools/*) requires `runtime.auth` to be configured. Configure an API key provider in your Spicepod (see https://spiceai.org/docs/reference/runtime#auth) and retry with credentials."; let tools_router = Router::new() .route("/v1/tools", get(v1::tools::list)) + .route("/v1/tools/search", get(v1::tools::search)) .route("/v1/tools/{*name}", post(v1::tools::post)) // Deprecated, use /v1/tools/:name instead .route("/v1/tool/{name}", post(v1::tools::post)) diff --git a/crates/runtime/src/http/v1/tools.rs b/crates/runtime/src/http/v1/tools.rs index b1138b9a91..143ac7c5e8 100644 --- a/crates/runtime/src/http/v1/tools.rs +++ b/crates/runtime/src/http/v1/tools.rs @@ -18,15 +18,19 @@ use std::sync::Arc; use axum::{ Extension, Json, - extract::Path, + extract::{Path, Query}, http::StatusCode, response::{IntoResponse, Response}, }; use futures::StreamExt; use serde::{Deserialize, Serialize}; use serde_json::json; +use tools::SpiceModelTool; -use crate::Runtime; +use crate::{ + Runtime, + tools::registry::{get_tool_registry_tool, tool_registry_prompt_tools}, +}; /// Summary of a tool available to run, and the schema of its input parameters. 
#[derive(Serialize, Debug, Clone, PartialEq, Eq, Hash, Default, Deserialize)] @@ -37,6 +41,24 @@ struct ListToolElement { parameters: Option, } +impl ListToolElement { + fn from_tool(tool: &Arc) -> Self { + Self { + name: tool.name().to_string(), + description: tool.description().map(|d| d.to_string()), + parameters: tool.parameters(), + } + } +} + +#[derive(Debug, Default, Deserialize)] +#[cfg_attr(feature = "openapi", derive(utoipa::IntoParams))] +#[cfg_attr(feature = "openapi", into_params(parameter_in = Query))] +pub(crate) struct SearchToolsQuery { + /// Embedding model name to use for searchable tool discovery. Required only when multiple embedding models are configured. + embedding_model: Option, +} + /// List Tools /// /// Returns a list of all available tools in the Spice runtime. Tools provide reusable functionality that can be invoked programmatically or by AI agents. @@ -53,23 +75,69 @@ struct ListToolElement { {"name": "get_readiness", "description": "Retrieves the readiness status of all runtime components including registered datasets, models, and embeddings.", "parameters": null}, {"name": "list_datasets", "description": "List all SQL tables available.", "parameters": null} ]) + ), + ( + status = 401, + description = "Tool routes require runtime auth to be configured", + body = serde_json::Value, + example = json!({"message": "Tool invocation (/v1/tools/*) requires `runtime.auth` to be configured."}) ) ) ))] pub(crate) async fn list(Extension(rt): Extension>) -> Response { let tools = rt .list_all_tools() - .map(|tool| ListToolElement { - name: tool.name().to_string(), - description: tool.description().map(|d| d.to_string()), - parameters: tool.parameters(), - }) + .map(|tool| ListToolElement::from_tool(&tool)) .collect::>() .await; (StatusCode::OK, Json(tools)).into_response() } +/// List Searchable Tool Registry Tools +/// +/// Returns the small set of tool definitions an external LLM client should inject to use Spice's searchable tool registry. 
Invoke returned tools with `POST /v1/tools/{name}`. +#[cfg_attr(feature = "openapi", utoipa::path( + get, + path = "/v1/tools/search", + operation_id = "list_searchable_tool_registry_tools", + tag = "Tools", + params(SearchToolsQuery), + responses( + ( + status = 200, body = [ListToolElement], + description = "Searchable tool registry tools to inject into an external LLM prompt", + example = json!([ + {"name": "tool_search", "description": "Search the Spice tool registry for tools relevant to the current task.", "parameters": {"type": "object"}}, + {"name": "tool_invoke", "description": "Invoke one Spice tool returned by tool_search.", "parameters": {"type": "object"}}, + {"name": "list_datasets", "description": "List all SQL tables available.", "parameters": null} + ]) + ), + (status = 400, description = "Searchable tool registry is not configured", body = serde_json::Value), + ( + status = 401, + description = "Tool routes require runtime auth to be configured", + body = serde_json::Value, + example = json!({"message": "Tool invocation (/v1/tools/*) requires `runtime.auth` to be configured."}) + ) + ) +))] +pub(crate) async fn search( + Extension(rt): Extension>, + Query(query): Query, +) -> Response { + match tool_registry_prompt_tools(Arc::clone(&rt), query.embedding_model.as_deref()).await { + Ok(tools) => { + let tools = tools + .iter() + .map(ListToolElement::from_tool) + .collect::>(); + (StatusCode::OK, Json(tools)).into_response() + } + Err(error) => bad_request(error.to_string().as_str()), + } +} + /// Run Tool /// /// Execute a specific tool by name. The request body schema and response format are defined by each individual tool's specification. Use `GET /v1/tools` to discover available tools and their parameter schemas. 
@@ -79,7 +147,8 @@ pub(crate) async fn list(Extension(rt): Extension>) -> Response { operation_id = "run_tool", tag = "Tools", params( - ("name" = String, Path, description = "Name of the tool") + ("name" = String, Path, description = "Name of the tool"), + ("embedding_model" = Option, Query, description = "Embedding model to use when invoking the searchable tool registry's tool_search meta-tool") ), request_body( description = "Tool specific input parameters. See /v1/tools for parameter schema.", @@ -111,7 +180,19 @@ pub(crate) async fn list(Extension(rt): Extension>) -> Response { "passenger_count": 2 }])) ))), - (status = 404, description = "Tool not found", body = String, example="Tool no_sql not found"), + (status = 400, description = "Invalid searchable tool registry configuration", body = serde_json::Value), + ( + status = 401, + description = "Tool routes require runtime auth to be configured", + body = serde_json::Value, + example = json!({"message": "Tool invocation (/v1/tools/*) requires `runtime.auth` to be configured."}) + ), + ( + status = 404, + description = "Tool not found", + body = serde_json::Value, + example = json!({"message": "Tool 'no_sql' not found"}) + ), (status = 500, description = "An error occurred while calling the tool", body = serde_json::Value, example=json!({"message": "Error calling tool no_sql: No such tool"})) ) @@ -119,9 +200,22 @@ pub(crate) async fn list(Extension(rt): Extension>) -> Response { pub(crate) async fn post( Extension(rt): Extension>, Path(tool_name): Path, + Query(query): Query, body: String, ) -> Response { - let Some(tool) = rt.get_tool(tool_name.as_str()).await else { + let tool = match get_tool_registry_tool( + Arc::clone(&rt), + tool_name.as_str(), + query.embedding_model.as_deref(), + ) + .await + { + Ok(Some(tool)) => Some(tool), + Ok(None) => rt.get_tool(tool_name.as_str()).await, + Err(error) => return bad_request(error.to_string().as_str()), + }; + + let Some(tool) = tool else { return 
not_found(format!("Tool '{tool_name}' not found").as_str()); }; @@ -138,3 +232,7 @@ pub(crate) async fn post( fn not_found(message: &str) -> Response { (StatusCode::NOT_FOUND, Json(json!({"message": message}))).into_response() } + +fn bad_request(message: &str) -> Response { + (StatusCode::BAD_REQUEST, Json(json!({"message": message}))).into_response() +} diff --git a/crates/runtime/src/model/chat.rs b/crates/runtime/src/model/chat.rs index 54170a316f..83f1190901 100644 --- a/crates/runtime/src/model/chat.rs +++ b/crates/runtime/src/model/chat.rs @@ -43,6 +43,7 @@ use crate::{ parameters::Parameters, tools::{ options::SpiceToolsOptions, + registry::{TOOL_EMBEDDING_MODEL_PARAM, prepare_model_tools}, utils::{create_table_allowlist, get_tools_with_allowlist}, }, }; @@ -110,16 +111,27 @@ pub async fn try_to_chat_model( // Prevent infinite recursion in case of circular tool calls. .or(Some(DEFAULT_SPICE_TOOL_RECURSION_LIMIT)); - // Create table allowlist from model's datasets if specified - let table_allowlist = create_table_allowlist(&component.datasets); + let tool_embedding_model = extract_secret!(params, TOOL_EMBEDDING_MODEL_PARAM); let tool_model = match spice_tool_opt { - Some(opts) if opts.can_use_tools() => Arc::new(ToolUsingChat::new( - model, - Arc::clone(&rt), - get_tools_with_allowlist(Arc::clone(&rt), &opts, table_allowlist).await, - spice_recursion_limit, - )), + Some(opts) if opts.can_use_tools() => { + let table_allowlist = create_table_allowlist(&component.datasets).map_err(|e| { + LlmError::ModelParameterFailed { + model: component.name.clone(), + source: e, + } + })?; + let tools = get_tools_with_allowlist(Arc::clone(&rt), &opts, table_allowlist).await; + let tools = prepare_model_tools(Arc::clone(&rt), &opts, tools, tool_embedding_model) + .await + .map_err(|e| LlmError::FailedToLoadModel { source: e })?; + Arc::new(ToolUsingChat::new( + model, + Arc::clone(&rt), + tools, + spice_recursion_limit, + )) + } Some(_) | None => model, }; Ok(tool_model) 
diff --git a/crates/runtime/src/model/params/mod.rs b/crates/runtime/src/model/params/mod.rs index a0308b2265..ad6918054b 100644 --- a/crates/runtime/src/model/params/mod.rs +++ b/crates/runtime/src/model/params/mod.rs @@ -49,15 +49,17 @@ pub fn get_params_spec(source: &ModelSource) -> Option<&'static [ParameterSpec]> } } -pub const PARAM_LEN: usize = 46; -pub const PARAM_WITH_DEPRE_LEN: usize = 47; +pub const PARAM_LEN: usize = 47; +pub const PARAM_WITH_DEPRE_LEN: usize = 48; // Model parameters that are used for openai model provider. Those parameters are supported by other (non-openai) models as well. // OpenAI model is prefixed with `openai_`, use separate PARAMETERS constant to avoid confusion with other model providers. pub const COMMON_MODEL_PARAMETERS: [ParameterSpec; PARAM_LEN] = [ // Common parameters for all models ParameterSpec::runtime("tools") - .description("Which tools should be made available to the model. Set to 'auto' to use all available tools."), + .description("Which tools should be made available to the model. Set to 'auto' to automatically choose between direct tools and searchable discovery, 'all' to use built-in and Spicepod-configured tools directly, or 'search_registry' to require searchable tool discovery."), + ParameterSpec::runtime("tool_embedding_model") + .description("Embedding model name to use for searchable tool discovery. 
tools: search_registry requires a model configured in the embeddings section and uses it when only one embedding model is configured; tools: auto falls back to direct tools if embeddings are unavailable."), ParameterSpec::runtime("system_prompt") .description("An additional system prompt used for all chat completions to this model."), ParameterSpec::runtime("parameterized_prompt"), @@ -114,7 +116,9 @@ pub const COMMON_MODEL_PARAMETERS: [ParameterSpec; PARAM_LEN] = [ pub const COMMON_MODEL_PARAMETERS_WITH_DEPRECATED: [ParameterSpec; PARAM_WITH_DEPRE_LEN] = [ // Common parameters for all models ParameterSpec::runtime("tools") - .description("Which tools should be made available to the model. Set to 'auto' to use all available tools."), + .description("Which tools should be made available to the model. Set to 'auto' to automatically choose between direct tools and searchable discovery, 'all' to use built-in and Spicepod-configured tools directly, or 'search_registry' to require searchable tool discovery."), + ParameterSpec::runtime("tool_embedding_model") + .description("Embedding model name to use for searchable tool discovery. 
tools: search_registry requires a model configured in the embeddings section and uses it when only one embedding model is configured; tools: auto falls back to direct tools if embeddings are unavailable."), ParameterSpec::runtime("system_prompt") .description("An additional system prompt used for all chat completions to this model."), ParameterSpec::runtime("parameterized_prompt"), diff --git a/crates/runtime/src/model/responses.rs b/crates/runtime/src/model/responses.rs index a8cb540e8a..9997e8b51e 100644 --- a/crates/runtime/src/model/responses.rs +++ b/crates/runtime/src/model/responses.rs @@ -21,6 +21,7 @@ use crate::model::tool_use_responses::OpenAIResponsesTools; use crate::model::wrapper::responses::ResponsesWrapper; use crate::parameters::Parameters; use crate::tools::options::SpiceToolsOptions; +use crate::tools::registry::{TOOL_EMBEDDING_MODEL_PARAM, prepare_model_tools}; use crate::tools::utils::{create_table_allowlist, get_tools_with_allowlist}; use llms::chat::Error as LlmError; use llms::openai::{DEFAULT_LLM_MODEL, UsageTier}; @@ -103,16 +104,27 @@ pub async fn try_to_responses_model( .transpose() .map_err(|_| unreachable!("SpiceToolsOptions::from_str has no error condition"))?; - // Create table allowlist from model's datasets if specified - let table_allowlist = create_table_allowlist(&component.datasets); + let tool_embedding_model = extract_secret!(params, TOOL_EMBEDDING_MODEL_PARAM); let tool_model = match spice_tool_opt { - Some(opts) if opts.can_use_tools() => Arc::new(ToolUsingResponses::new( - model, - openai_responses_tools.unwrap_or_default(), - get_tools_with_allowlist(Arc::clone(&rt), &opts, table_allowlist).await, - spice_recursion_limit, - )), + Some(opts) if opts.can_use_tools() => { + let table_allowlist = create_table_allowlist(&component.datasets).map_err(|e| { + LlmError::ModelParameterFailed { + model: component.name.clone(), + source: e, + } + })?; + let tools = get_tools_with_allowlist(Arc::clone(&rt), &opts, 
table_allowlist).await; + let tools = prepare_model_tools(Arc::clone(&rt), &opts, tools, tool_embedding_model) + .await + .map_err(|e| LlmError::FailedToLoadModel { source: e })?; + Arc::new(ToolUsingResponses::new( + model, + openai_responses_tools.unwrap_or_default(), + tools, + spice_recursion_limit, + )) + } Some(_) | None => model, }; diff --git a/crates/runtime/src/search/rrf.rs b/crates/runtime/src/search/rrf.rs index 4a2e1aad2c..9398edf9cc 100644 --- a/crates/runtime/src/search/rrf.rs +++ b/crates/runtime/src/search/rrf.rs @@ -13,6 +13,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +use ::search::aggregation::reciprocal_rank::DEFAULT_RRF_K; use arrow_schema::{DataType, SchemaRef}; use async_trait::async_trait; use datafusion::catalog::{Session, TableFunctionImpl, TableProvider}; @@ -312,7 +313,7 @@ impl ReciprocalRankFusionArgs { Ok(Self { search_udtf_exprs: search_udtfs, rrf_subquery_arguments: subquery_args, - k: extract_f64!(rrf_args, "k").unwrap_or(60.0), + k: extract_f64!(rrf_args, "k").unwrap_or(DEFAULT_RRF_K), join_key: extract_string!(rrf_args, "join_key").map(ident), time_column: extract_string!(rrf_args, "time_column").map(ident), recency_decay: extract_string!(rrf_args, "recency_decay") diff --git a/crates/runtime/src/tools/builtin/catalog.rs b/crates/runtime/src/tools/builtin/catalog.rs index 47a693eea2..d148835cd2 100644 --- a/crates/runtime/src/tools/builtin/catalog.rs +++ b/crates/runtime/src/tools/builtin/catalog.rs @@ -212,7 +212,7 @@ impl IndividualToolFactory for BuiltinToolCatalog { impl SpiceToolCatalog for BuiltinToolCatalog { async fn all(&self) -> Vec> { let mut tools = vec![]; - for t in SpiceToolsOptions::Auto.tools_by_name() { + for t in SpiceToolsOptions::All.tools_by_name() { match self.construct_builtin(t, None, None, &HashMap::new()) { Ok(tool) => tools.push(tool), Err(e) => 
tracing::warn!("Failed to construct builtin tool: '{}'. Error: {}", t, e), diff --git a/crates/runtime/src/tools/mod.rs b/crates/runtime/src/tools/mod.rs index 92f6da0eff..80e89915d3 100644 --- a/crates/runtime/src/tools/mod.rs +++ b/crates/runtime/src/tools/mod.rs @@ -27,6 +27,7 @@ pub mod factory; pub mod mcp; pub mod memory; pub mod options; +pub(crate) mod registry; pub mod utils; /// [`Tooling`] define several ways to access and load tools into the runtime. diff --git a/crates/runtime/src/tools/options.rs b/crates/runtime/src/tools/options.rs index d9054a56cf..dd2657dd5d 100644 --- a/crates/runtime/src/tools/options.rs +++ b/crates/runtime/src/tools/options.rs @@ -24,9 +24,16 @@ use serde::{Deserialize, Serialize}; #[derive(Clone, Debug, Serialize, Deserialize)] #[serde(rename_all = "lowercase")] pub enum SpiceToolsOptions { - /// Automatically use all available builtin tools. + /// Automatically choose between direct configured tools and searchable discovery. Auto, + /// Use built-in and Spicepod-configured tools directly. + All, + + /// Use a searchable registry over built-in and Spicepod-configured tools. + #[serde(rename = "search_registry")] + SearchRegistry, + /// Use builtin tools relevant for text-to-SQL. 
Nsql, @@ -42,15 +49,28 @@ impl SpiceToolsOptions { #[must_use] pub fn can_use_tools(&self) -> bool { match self { - SpiceToolsOptions::Auto | SpiceToolsOptions::Nsql => true, + SpiceToolsOptions::Auto + | SpiceToolsOptions::All + | SpiceToolsOptions::SearchRegistry + | SpiceToolsOptions::Nsql => true, SpiceToolsOptions::Disabled => false, SpiceToolsOptions::Specific(t) => !t.is_empty(), } } + #[must_use] + pub(crate) fn includes_all_available_tools(&self) -> bool { + matches!( + self, + SpiceToolsOptions::Auto | SpiceToolsOptions::All | SpiceToolsOptions::SearchRegistry + ) + } + pub(crate) fn tools_by_name(&self) -> Vec<&str> { match self { - SpiceToolsOptions::Auto => vec![ + SpiceToolsOptions::Auto + | SpiceToolsOptions::All + | SpiceToolsOptions::SearchRegistry => vec![ "search", "table_schema", "sql", @@ -73,6 +93,11 @@ impl SpiceToolsOptions { .iter() // Handle nested groupings. e.g: `spiced_tools: nsql, my_other_tool`. .flat_map(|s| match s.parse() { + Ok( + SpiceToolsOptions::Auto + | SpiceToolsOptions::All + | SpiceToolsOptions::SearchRegistry, + ) => SpiceToolsOptions::All.tools_by_name(), Ok(SpiceToolsOptions::Nsql) => SpiceToolsOptions::Nsql.tools_by_name(), _ => vec![s.as_str()], }) @@ -88,6 +113,8 @@ impl FromStr for SpiceToolsOptions { fn from_str(s: &str) -> Result { match s.trim().to_lowercase().as_str() { "auto" => Ok(SpiceToolsOptions::Auto), + "all" => Ok(SpiceToolsOptions::All), + "search_registry" => Ok(SpiceToolsOptions::SearchRegistry), "nsql" => Ok(SpiceToolsOptions::Nsql), "disabled" => Ok(SpiceToolsOptions::Disabled), _ => Ok(SpiceToolsOptions::Specific( @@ -134,5 +161,62 @@ mod tests { tools.iter().unique().count(), "'SpiceToolsOptions::tools_by_name' should not produce duplicates" ); + + assert_eq!( + SpiceToolsOptions::Specific(vec![ + "search_registry".to_string(), + "my_other_tool".to_string(), + ]) + .tools_by_name(), + SpiceToolsOptions::All + .tools_by_name() + .into_iter() + .chain(["my_other_tool"]) + .collect::>() + ); + } + + 
#[test] + fn test_all_tool_opts() { + assert!(SpiceToolsOptions::All.can_use_tools()); + assert!(SpiceToolsOptions::All.includes_all_available_tools()); + assert_eq!( + SpiceToolsOptions::All.tools_by_name(), + SpiceToolsOptions::Auto.tools_by_name() + ); + + assert!(matches!( + "all" + .parse::() + .expect("all should parse as a tool option"), + SpiceToolsOptions::All + )); + } + + #[test] + fn test_search_registry_tool_opts() { + assert!(SpiceToolsOptions::SearchRegistry.can_use_tools()); + assert!(SpiceToolsOptions::SearchRegistry.includes_all_available_tools()); + assert_eq!( + SpiceToolsOptions::SearchRegistry.tools_by_name(), + SpiceToolsOptions::Auto.tools_by_name() + ); + + assert!(matches!( + "search_registry" + .parse::() + .expect("search_registry should parse as a tool option"), + SpiceToolsOptions::SearchRegistry + )); + } + + #[test] + fn test_search_tool_name_is_specific_tool() { + let option = "search" + .parse::() + .expect("search should parse as a specific tool name"); + assert!( + matches!(option, SpiceToolsOptions::Specific(tools) if tools == vec!["search".to_string()]) + ); } } diff --git a/crates/runtime/src/tools/registry.rs b/crates/runtime/src/tools/registry.rs new file mode 100644 index 0000000000..f96d7d912c --- /dev/null +++ b/crates/runtime/src/tools/registry.rs @@ -0,0 +1,1718 @@ +/* +Copyright 2026 The Spice.ai OSS Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +use std::{ + borrow::Cow, + collections::{HashMap, HashSet, hash_map::DefaultHasher}, + hash::{Hash, Hasher}, + sync::{Arc, LazyLock}, +}; + +use ::search::aggregation::reciprocal_rank::{ + DEFAULT_RRF_K, reciprocal_rank_fusion_scores, usize_to_f64, +}; +use async_trait::async_trait; +use llms::embeddings::{Embed, EmbeddingInput}; +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; +use serde_json::{Value, json}; +use snafu::Snafu; +use tokio::sync::OnceCell; +use tools::SpiceModelTool; + +use crate::{ + Runtime, + tools::{ + options::SpiceToolsOptions, + utils::{get_tools, parameters}, + }, +}; + +const TOOL_SEARCH_NAME: &str = "tool_search"; +const TOOL_INVOKE_NAME: &str = "tool_invoke"; +const LIST_DATASETS_TOOL_NAME: &str = "list_datasets"; +pub(crate) const TOOL_EMBEDDING_MODEL_PARAM: &str = "tool_embedding_model"; +const DEFAULT_SEARCH_LIMIT: usize = 5; +const MAX_SEARCH_LIMIT: usize = 20; +const AUTO_SEARCH_TOOL_THRESHOLD: usize = 20; +const TOOL_REGISTRY_SEARCH_TOOL_CACHE_MAX_ENTRIES: usize = 64; + +static TOOL_REGISTRY_SEARCH_TOOL_CACHE: LazyLock< + tokio::sync::RwLock>>, +> = LazyLock::new(|| tokio::sync::RwLock::new(HashMap::new())); + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +struct ToolRegistrySearchCacheKey { + runtime_id: usize, + embedding_model_name: String, + tools_hash: u64, +} + +#[derive(Debug, Snafu)] +enum ToolRegistryError { + #[snafu(display( + "Tool '{tool_id}' was not found in the searchable tool registry. Available tools: {available_tools}" + ))] + ToolNotFound { + tool_id: String, + available_tools: String, + }, + + #[snafu(display("Failed to invoke tool '{tool_id}' from searchable registry: {source}"))] + ToolInvokeFailed { + tool_id: String, + source: Box, + }, + + #[snafu(display( + "Tool name '{tool_name}' is reserved for the searchable tool registry. Rename the configured tool or disable searchable registry tools." 
+ ))] + ReservedToolName { tool_name: String }, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum ToolRegistryPreparationMode { + Required, + Auto, +} + +pub(crate) async fn prepare_model_tools( + rt: Arc, + opts: &SpiceToolsOptions, + tools: Vec>, + embedding_model_name: Option<&str>, +) -> Result>, Box> { + match tool_registry_preparation_mode(opts, tools.len()) { + Some(ToolRegistryPreparationMode::Required) => { + ensure_no_reserved_tool_registry_name_conflicts(&tools)?; + let embedding_model = + resolve_tool_registry_embedding_model(Arc::clone(&rt), embedding_model_name) + .await?; + Ok(tool_registry_tools(tools, embedding_model)) + } + Some(ToolRegistryPreparationMode::Auto) => { + if let Some(tool_name) = reserved_tool_registry_name_conflict(&tools) { + tracing::warn!( + "Unable to use searchable tool registry for tools: auto: tool name '{}' is reserved for the searchable tool registry. Falling back to direct tool definitions.", + tool_name + ); + return Ok(tools); + } + + match resolve_tool_registry_embedding_model(Arc::clone(&rt), embedding_model_name).await + { + Ok(embedding_model) => Ok(tool_registry_tools(tools, embedding_model)), + Err(error) => { + tracing::warn!( + "Unable to use searchable tool registry for tools: auto: {}. 
Falling back to direct tool definitions.", + error + ); + Ok(tools) + } + } + } + None => Ok(tools), + } +} + +fn tool_registry_preparation_mode( + opts: &SpiceToolsOptions, + tool_count: usize, +) -> Option { + match opts { + SpiceToolsOptions::SearchRegistry => Some(ToolRegistryPreparationMode::Required), + SpiceToolsOptions::Auto if should_auto_search(tool_count) => { + Some(ToolRegistryPreparationMode::Auto) + } + SpiceToolsOptions::Specific(requested_tools) => { + specific_tool_registry_preparation_mode(requested_tools, tool_count) + } + SpiceToolsOptions::Auto + | SpiceToolsOptions::All + | SpiceToolsOptions::Nsql + | SpiceToolsOptions::Disabled => None, + } +} + +fn specific_tool_registry_preparation_mode( + requested_tools: &[String], + tool_count: usize, +) -> Option { + let mut auto_requested = false; + + for requested_tool in requested_tools { + match requested_tool.parse::() { + Ok(SpiceToolsOptions::SearchRegistry) => { + return Some(ToolRegistryPreparationMode::Required); + } + Ok(SpiceToolsOptions::Auto) => auto_requested = true, + Ok( + SpiceToolsOptions::All + | SpiceToolsOptions::Nsql + | SpiceToolsOptions::Disabled + | SpiceToolsOptions::Specific(_), + ) + | Err(_) => {} + } + } + + (auto_requested && should_auto_search(tool_count)).then_some(ToolRegistryPreparationMode::Auto) +} + +fn should_auto_search(tool_count: usize) -> bool { + tool_count > AUTO_SEARCH_TOOL_THRESHOLD +} + +fn is_reserved_tool_registry_name(tool_name: &str) -> bool { + matches!(tool_name, TOOL_SEARCH_NAME | TOOL_INVOKE_NAME) +} + +fn reserved_tool_registry_name_conflict(tools: &[Arc]) -> Option { + tools.iter().find_map(|tool| { + let tool_name = tool.name(); + is_reserved_tool_registry_name(tool_name.as_ref()).then(|| tool_name.into_owned()) + }) +} + +fn ensure_no_reserved_tool_registry_name_conflicts( + tools: &[Arc], +) -> Result<(), Box> { + if let Some(tool_name) = reserved_tool_registry_name_conflict(tools) { + return Err(Box::new(ToolRegistryError::ReservedToolName 
{ tool_name })); + } + + Ok(()) +} + +#[must_use] +fn tool_registry_tools( + tools: Vec>, + embedding_model: Arc, +) -> Vec> { + if tools.is_empty() { + return Vec::new(); + } + + let registry = Arc::new(tools); + let search_tool = Arc::new(ToolRegistrySearchTool::new( + registry.as_slice(), + embedding_model, + )) as Arc; + + tool_registry_tools_with_search_tool(&registry, search_tool) +} + +fn tool_registry_tools_with_search_tool( + registry: &Arc>>, + search_tool: Arc, +) -> Vec> { + let direct_tools = registry + .iter() + .filter(|tool| tool.name() == LIST_DATASETS_TOOL_NAME) + .cloned() + .collect::>(); + let mut advertised_tools = vec![ + search_tool, + Arc::new(ToolRegistryInvokeTool::new(Arc::clone(registry))) as Arc, + ]; + advertised_tools.extend(direct_tools); + advertised_tools +} + +pub(crate) async fn tool_registry_prompt_tools( + rt: Arc, + embedding_model_name: Option<&str>, +) -> Result>, Box> { + let tools = Arc::new(get_tools(Arc::clone(&rt), &SpiceToolsOptions::SearchRegistry).await); + ensure_no_reserved_tool_registry_name_conflicts(tools.as_slice())?; + let (resolved_embedding_model_name, embedding_model) = + resolve_tool_registry_embedding_model_with_name(Arc::clone(&rt), embedding_model_name) + .await?; + let search_tool = cached_tool_registry_search_tool( + &rt, + Arc::clone(&tools), + &resolved_embedding_model_name, + embedding_model, + ) + .await as Arc; + + Ok(tool_registry_tools_with_search_tool(&tools, search_tool)) +} + +pub(crate) async fn get_tool_registry_tool( + rt: Arc, + tool_name: &str, + embedding_model_name: Option<&str>, +) -> Result>, Box> { + if is_reserved_tool_registry_name(tool_name) && rt.get_tool(tool_name).await.is_some() { + return Ok(None); + } + + match tool_name { + TOOL_SEARCH_NAME => { + let tools = + Arc::new(get_tools(Arc::clone(&rt), &SpiceToolsOptions::SearchRegistry).await); + let (resolved_embedding_model_name, embedding_model) = + resolve_tool_registry_embedding_model_with_name( + Arc::clone(&rt), +
embedding_model_name, + ) + .await?; + let search_tool = cached_tool_registry_search_tool( + &rt, + tools, + &resolved_embedding_model_name, + embedding_model, + ) + .await as Arc; + Ok(Some(search_tool)) + } + TOOL_INVOKE_NAME => { + let tools = get_tools(Arc::clone(&rt), &SpiceToolsOptions::SearchRegistry).await; + let registry = Arc::new(tools); + Ok(Some( + Arc::new(ToolRegistryInvokeTool::new(registry)) as Arc + )) + } + _ => Ok(None), + } +} + +pub(crate) async fn resolve_tool_registry_embedding_model( + rt: Arc, + model_name: Option<&str>, +) -> Result, Box> { + let (_, embedding_model) = + resolve_tool_registry_embedding_model_with_name(rt, model_name).await?; + Ok(embedding_model) +} + +async fn resolve_tool_registry_embedding_model_with_name( + rt: Arc, + model_name: Option<&str>, +) -> Result<(String, Arc), Box> { + let configured_model_names = configured_embedding_model_names(&rt).await; + let model_name = + select_tool_registry_embedding_model_name(&configured_model_names, model_name)?; + + let Some(embedding_model) = rt.embeds().read().await.get(&model_name).cloned() else { + return Err(format!("Embedding model '{model_name}' configured for searchable tool discovery was not loaded. 
Check earlier embedding model errors and verify the `embeddings` configuration").into()); + }; + Ok((model_name, embedding_model)) +} + +async fn cached_tool_registry_search_tool( + rt: &Arc, + tools: Arc>>, + embedding_model_name: &str, + embedding_model: Arc, +) -> Arc { + let key = ToolRegistrySearchCacheKey { + runtime_id: Arc::as_ptr(rt).addr(), + embedding_model_name: embedding_model_name.to_string(), + tools_hash: tool_registry_tools_hash(&tools), + }; + + if let Some(tool) = TOOL_REGISTRY_SEARCH_TOOL_CACHE.read().await.get(&key) { + return Arc::clone(tool); + } + + let mut cache = TOOL_REGISTRY_SEARCH_TOOL_CACHE.write().await; + if cache.len() >= TOOL_REGISTRY_SEARCH_TOOL_CACHE_MAX_ENTRIES + && !cache.contains_key(&key) + && let Some(evicted_key) = cache.keys().next().cloned() + { + cache.remove(&evicted_key); + } + + let tool = cache.entry(key).or_insert_with(|| { + Arc::new(ToolRegistrySearchTool::new( + tools.as_slice(), + embedding_model, + )) + }); + Arc::clone(tool) +} + +fn tool_registry_tools_hash(tools: &[Arc]) -> u64 { + let mut hasher = DefaultHasher::new(); + tools.len().hash(&mut hasher); + for tool in tools { + tool.name().hash(&mut hasher); + tool.description().hash(&mut hasher); + tool.parameters() + .map(|parameters| parameters.to_string()) + .hash(&mut hasher); + } + hasher.finish() +} + +async fn configured_embedding_model_names(rt: &Arc) -> Vec { + let mut names = rt + .read_app() + .await + .map(|app| { + app.embeddings + .iter() + .map(|embedding| embedding.name.clone()) + .collect::>() + }) + .unwrap_or_default(); + names.sort(); + names +} + +fn select_tool_registry_embedding_model_name( + configured_model_names: &[String], + model_name: Option<&str>, +) -> Result> { + if let Some(model_name) = model_name { + return configured_model_names + .iter() + .find(|configured_model_name| configured_model_name.as_str() == model_name) + .cloned() + .ok_or_else(|| { + format!("Embedding model '{model_name}' specified by 
`{TOOL_EMBEDDING_MODEL_PARAM}` was not found in the `embeddings` section").into() + }); + } + + match configured_model_names { + [] => Err(format!("No embedding model configured for searchable tool discovery. Add one model to the `embeddings` section, or set `{TOOL_EMBEDDING_MODEL_PARAM}` to reference a configured embedding model").into()), + [model_name] => Ok(model_name.clone()), + model_names => Err(format!("Multiple embedding models are configured for searchable tool discovery: {}. Set `{TOOL_EMBEDDING_MODEL_PARAM}` to one of them", model_names.join(", ")).into()), + } +} + +struct ToolRegistrySearchTool { + documents: Arc>, + document_texts: Arc>, + embedding_model: Arc, + tool_embeddings: OnceCell>>, +} + +impl ToolRegistrySearchTool { + fn new(tools: &[Arc], embedding_model: Arc) -> Self { + let documents = tools.iter().map(ToolDocument::new).collect::>(); + let document_texts = documents + .iter() + .map(ToolDocument::vector_text) + .collect::>(); + + Self { + documents: Arc::new(documents), + document_texts: Arc::new(document_texts), + embedding_model, + tool_embeddings: OnceCell::new(), + } + } +} + +#[async_trait] +impl SpiceModelTool for ToolRegistrySearchTool { + fn name(&self) -> Cow<'_, str> { + Cow::Borrowed(TOOL_SEARCH_NAME) + } + + fn description(&self) -> Option> { + Some(Cow::Borrowed( + "Search the Spice tool registry for tools relevant to the current task. 
Call this before tool_invoke; it returns tool_id, description, parameters, and score for the best matches.", + )) + } + + fn parameters(&self) -> Option { + parameters::() + } + + async fn call(&self, arg: &str) -> Result> { + let params: ToolSearchParams = serde_json::from_str(arg)?; + let limit = params + .limit + .unwrap_or(DEFAULT_SEARCH_LIMIT) + .clamp(1, MAX_SEARCH_LIMIT); + let min_score = params.min_score.unwrap_or(0.0).clamp(0.0, 1.0); + + let mut ranked_tools = hybrid_rank_tools( + self.documents.as_slice(), + self.document_texts.as_slice(), + ¶ms, + &self.embedding_model, + &self.tool_embeddings, + ) + .await?; + ranked_tools.sort_by(|left, right| { + right + .score + .total_cmp(&left.score) + .then_with(|| left.tool_id.cmp(&right.tool_id)) + }); + + let max_score = ranked_tools + .first() + .map_or(0.0, |ranked_tool| ranked_tool.score); + let tools = ranked_tools + .into_iter() + .filter(|ranked_tool| ranked_tool.score >= min_score || max_score == 0.0) + .take(limit) + .map(ToolSearchResult::from) + .collect::>(); + + Ok(json!({ + "query": params.query, + "keywords": params.keywords, + "search_mode": "hybrid_rrf", + "tools": tools, + })) + } +} + +struct ToolRegistryInvokeTool { + tools: Arc>>, +} + +impl ToolRegistryInvokeTool { + fn new(tools: Arc>>) -> Self { + Self { tools } + } + + fn find_tool(&self, tool_id: &str) -> Option> { + self.tools + .iter() + .find(|tool| tool.name() == tool_id) + .cloned() + } +} + +#[async_trait] +impl SpiceModelTool for ToolRegistryInvokeTool { + fn name(&self) -> Cow<'_, str> { + Cow::Borrowed(TOOL_INVOKE_NAME) + } + + fn description(&self) -> Option> { + Some(Cow::Borrowed( + "Invoke one Spice tool returned by tool_search. 
Pass the selected tool_id and an arguments object matching that tool's parameters.", + )) + } + + fn parameters(&self) -> Option { + parameters::() + } + + async fn call(&self, arg: &str) -> Result> { + let params: ToolInvokeParams = serde_json::from_str(arg)?; + let Some(tool) = self.find_tool(¶ms.tool_id) else { + let available_tools = self + .tools + .iter() + .map(|tool| tool.name().to_string()) + .take(MAX_SEARCH_LIMIT) + .collect::>() + .join(", "); + return Err(Box::new(ToolRegistryError::ToolNotFound { + tool_id: params.tool_id, + available_tools, + })); + }; + + let tool_id = tool.name().to_string(); + let arguments = match params.arguments { + Some(Value::String(arguments)) => arguments, + Some(Value::Null) | None => "{}".to_string(), + Some(arguments) => serde_json::to_string(&arguments)?, + }; + + let result = + tool.call(&arguments) + .await + .map_err(|source| ToolRegistryError::ToolInvokeFailed { + tool_id: tool_id.clone(), + source, + })?; + Ok(json!({ + "tool_id": tool_id, + "result": result, + })) + } +} + +#[derive(Debug, Clone, Deserialize, Serialize, JsonSchema)] +#[serde(deny_unknown_fields)] +struct ToolSearchParams { + /// Natural-language description of the capability needed. + query: String, + + /// Optional keywords to boost exact lexical matches during hybrid lookup. + #[serde(default)] + keywords: Vec, + + /// Maximum number of matching tools to return. Defaults to 5 and is capped at 20. + #[serde(default)] + limit: Option, + + /// Optional minimum score from 0.0 to 1.0. Leave unset for fallback results. + #[serde(default)] + min_score: Option, +} + +#[derive(Debug, Clone, Deserialize, Serialize, JsonSchema)] +#[serde(deny_unknown_fields)] +struct ToolInvokeParams { + /// Tool identifier returned by `tool_search`. + tool_id: String, + + /// JSON object matching the selected tool's parameter schema. 
+ #[serde(default)] + arguments: Option, +} + +#[derive(Debug)] +struct RankedTool { + tool_id: String, + description: Option, + parameters: Option, + score: f64, + matched_terms: Vec, + match_sources: Vec, +} + +#[derive(Debug, Serialize)] +struct ToolSearchResult { + tool_id: String, + description: Option, + parameters: Option, + score: f64, + matched_terms: Vec, + match_sources: Vec, +} + +impl From for ToolSearchResult { + fn from(ranked_tool: RankedTool) -> Self { + Self { + tool_id: ranked_tool.tool_id, + description: ranked_tool.description, + parameters: ranked_tool.parameters, + score: (ranked_tool.score * 1000.0).round() / 1000.0, + matched_terms: ranked_tool.matched_terms, + match_sources: ranked_tool.match_sources, + } + } +} + +#[derive(Debug, Clone, Serialize)] +struct MatchSource { + source: &'static str, + rank: usize, + score: f64, +} + +#[derive(Debug)] +struct ToolDocument { + tool_id: String, + description: Option, + parameters: Option, + name_text: String, + description_text: String, + parameter_text: String, + name_tokens: Vec, + description_tokens: Vec, + parameter_tokens: Vec, + name_token_set: HashSet, + description_token_set: HashSet, + parameter_token_set: HashSet, + all_token_set: HashSet, + name_token_counts: HashMap, + description_token_counts: HashMap, + parameter_token_counts: HashMap, + parameter_key_tokens: HashSet, +} + +impl ToolDocument { + fn new(tool: &Arc) -> Self { + let tool_id = tool.name().to_string(); + let description = tool + .description() + .map(|description| description.to_string()); + let parameters = tool.parameters(); + let parameter_text = parameters + .as_ref() + .map(ToString::to_string) + .unwrap_or_default(); + let mut parameter_key_tokens = HashSet::new(); + if let Some(parameters) = parameters.as_ref() { + collect_json_key_tokens(parameters, &mut parameter_key_tokens); + } + let name_tokens = tokenize_to_vec(&tool_id); + let description_tokens = tokenize_to_vec(description.as_deref().unwrap_or_default()); 
+ let parameter_tokens = tokenize_to_vec(¶meter_text); + let name_token_set = token_set(&name_tokens); + let description_token_set = token_set(&description_tokens); + let parameter_token_set = token_set(¶meter_tokens); + let all_token_set = name_tokens + .iter() + .chain(&description_tokens) + .chain(¶meter_tokens) + .cloned() + .collect::>(); + let name_token_counts = token_counts(&name_tokens); + let description_token_counts = token_counts(&description_tokens); + let parameter_token_counts = token_counts(¶meter_tokens); + + Self { + name_text: normalize_text(&tool_id), + description_text: normalize_text(description.as_deref().unwrap_or_default()), + parameter_text: normalize_text(¶meter_text), + name_tokens, + description_tokens, + parameter_tokens, + name_token_set, + description_token_set, + parameter_token_set, + all_token_set, + name_token_counts, + description_token_counts, + parameter_token_counts, + parameter_key_tokens, + tool_id, + description, + parameters, + } + } + + fn total_tokens(&self) -> usize { + self.name_tokens.len() + self.description_tokens.len() + self.parameter_tokens.len() + } + + fn vector_text(&self) -> String { + format!( + "tool_id: {}\ndescription: {}\nparameters: {}", + self.tool_id, self.description_text, self.parameter_text + ) + } +} + +#[derive(Debug)] +struct ChannelMatch { + document_index: usize, + score: f64, + matched_terms: Vec, +} + +#[derive(Debug, Default, Clone)] +struct FusedMatch { + fused_score: f64, + matched_terms: Vec, + match_sources: Vec, +} + +async fn hybrid_rank_tools( + documents: &[ToolDocument], + document_texts: &[String], + params: &ToolSearchParams, + embedding_model: &Arc, + tool_embeddings: &OnceCell>>, +) -> Result, Box> { + let query_tokens = tokenize_to_vec(¶ms.query); + let keyword_tokens = params + .keywords + .iter() + .flat_map(|keyword| tokenize_to_vec(keyword)) + .collect::>(); + let search_tokens = unique_tokens(query_tokens.iter().chain(&keyword_tokens).cloned()); + + let mut channels = 
vec![ + ( + "full_text", + full_text_channel_matches(documents, &search_tokens), + ), + ( + "keyword", + keyword_channel_matches(documents, ¶ms.query, ¶ms.keywords, &search_tokens), + ), + ("schema", schema_channel_matches(documents, &search_tokens)), + ]; + channels.push(( + "vector", + vector_channel_matches( + document_texts, + ¶ms.query, + embedding_model, + tool_embeddings, + ) + .await?, + )); + let fused_matches = reciprocal_rank_fusion(channels); + let max_score = fused_matches + .values() + .map(|fused_match| fused_match.fused_score) + .fold(0.0, f64::max); + + Ok(documents + .iter() + .enumerate() + .map(|(document_index, document)| { + let mut fused_match = fused_matches + .get(&document_index) + .cloned() + .unwrap_or_default(); + fused_match.matched_terms.sort(); + fused_match.matched_terms.dedup(); + fused_match + .match_sources + .sort_by_key(|source| (source.rank, source.source)); + + RankedTool { + tool_id: document.tool_id.clone(), + description: document.description.clone(), + parameters: document.parameters.clone(), + score: if max_score > 0.0 { + fused_match.fused_score / max_score + } else { + 0.0 + }, + matched_terms: fused_match.matched_terms, + match_sources: fused_match.match_sources, + } + }) + .collect()) +} + +async fn vector_channel_matches( + document_texts: &[String], + query: &str, + embedding_model: &Arc, + tool_embeddings: &OnceCell>>, +) -> Result, llms::embeddings::Error> { + if query.trim().is_empty() || document_texts.is_empty() { + return Ok(Vec::new()); + } + + let query_embeddings = embedding_model + .embed(EmbeddingInput::String(query.to_string())) + .await?; + let Some(query_embedding) = query_embeddings.first() else { + return Ok(Vec::new()); + }; + + let document_embeddings = tool_embeddings + .get_or_try_init(|| { + let embedding_model = Arc::clone(embedding_model); + let document_texts = document_texts.to_vec(); + async move { + embedding_model + .embed(EmbeddingInput::StringArray(document_texts)) + .await + } + }) + 
.await?; + + Ok(document_embeddings + .iter() + .enumerate() + .filter_map(|(document_index, document_embedding)| { + non_zero_channel_match( + document_index, + cosine_similarity(query_embedding, document_embedding).max(0.0), + Vec::new(), + ) + }) + .collect()) +} + +fn full_text_channel_matches( + documents: &[ToolDocument], + query_tokens: &[String], +) -> Vec { + if query_tokens.is_empty() { + return Vec::new(); + } + + let document_count = usize_to_f64(documents.len()); + let document_frequency = document_frequency(documents, query_tokens); + documents + .iter() + .enumerate() + .filter_map(|(document_index, document)| { + let mut score = 0.0; + let mut matched_terms = Vec::new(); + + for query_token in query_tokens { + let field_score = + (usize_to_f64(token_count(&document.name_token_counts, query_token)) * 3.0) + + (usize_to_f64(token_count( + &document.description_token_counts, + query_token, + )) * 2.0) + + usize_to_f64(token_count(&document.parameter_token_counts, query_token)); + if field_score > 0.0 { + let frequency = usize_to_f64( + document_frequency + .get(query_token) + .copied() + .unwrap_or_default(), + ); + let inverse_document_frequency = + ((document_count + 1.0) / (frequency + 0.5)).ln().max(0.0) + 1.0; + score += inverse_document_frequency * field_score; + matched_terms.push(query_token.clone()); + } + } + + let length_normalizer = usize_to_f64(document.total_tokens().max(1)).sqrt(); + non_zero_channel_match(document_index, score / length_normalizer, matched_terms) + }) + .collect() +} + +fn keyword_channel_matches( + documents: &[ToolDocument], + query: &str, + keywords: &[String], + query_tokens: &[String], +) -> Vec { + let phrases = if keywords.is_empty() { + vec![query.to_string()] + } else { + keywords.to_vec() + }; + + documents + .iter() + .enumerate() + .filter_map(|(document_index, document)| { + let mut score = 0.0; + let mut matched_terms = Vec::new(); + + for phrase in &phrases { + let normalized_phrase = 
normalize_text(phrase); + if normalized_phrase.is_empty() { + continue; + } + + if document.name_text == normalized_phrase { + score += 10.0; + matched_terms.push(normalized_phrase.clone()); + } else if document.name_text.contains(&normalized_phrase) { + score += 6.0; + matched_terms.push(normalized_phrase.clone()); + } else if document.description_text.contains(&normalized_phrase) { + score += 4.0; + matched_terms.push(normalized_phrase.clone()); + } else if document.parameter_text.contains(&normalized_phrase) { + score += 2.0; + matched_terms.push(normalized_phrase.clone()); + } + } + + for query_token in query_tokens { + let token_score = if document.name_token_set.contains(query_token) { + 3.0 + } else if document.description_token_set.contains(query_token) { + 2.0 + } else if document.parameter_token_set.contains(query_token) { + 1.0 + } else { + 0.0 + }; + if token_score > 0.0 { + score += token_score; + matched_terms.push(query_token.clone()); + } + } + + non_zero_channel_match(document_index, score, matched_terms) + }) + .collect() +} + +fn schema_channel_matches( + documents: &[ToolDocument], + query_tokens: &[String], +) -> Vec { + if query_tokens.is_empty() { + return Vec::new(); + } + + documents + .iter() + .enumerate() + .filter_map(|(document_index, document)| { + let mut score = 0.0; + let mut matched_terms = Vec::new(); + for query_token in query_tokens { + if document.parameter_key_tokens.contains(query_token) { + score += 4.0; + matched_terms.push(query_token.clone()); + } else if document.parameter_token_set.contains(query_token) { + score += 1.0; + matched_terms.push(query_token.clone()); + } + } + non_zero_channel_match(document_index, score, matched_terms) + }) + .collect() +} + +fn reciprocal_rank_fusion( + channels: Vec<(&'static str, Vec)>, +) -> HashMap { + let mut fused_matches = HashMap::new(); + let mut ranked_channels = Vec::with_capacity(channels.len()); + + for (source, mut channel_matches) in channels { + 
channel_matches.sort_by(|left, right| { + right + .score + .total_cmp(&left.score) + .then_with(|| left.document_index.cmp(&right.document_index)) + }); + ranked_channels.push( + channel_matches + .iter() + .map(|channel_match| channel_match.document_index) + .collect::>(), + ); + + for (rank_index, channel_match) in channel_matches.into_iter().enumerate() { + let rank = rank_index + 1; + let fused_match = fused_matches + .entry(channel_match.document_index) + .or_insert_with(FusedMatch::default); + fused_match + .matched_terms + .extend(channel_match.matched_terms); + fused_match.match_sources.push(MatchSource { + source, + rank, + score: (channel_match.score * 1000.0).round() / 1000.0, + }); + } + } + + for (document_index, fused_score) in + reciprocal_rank_fusion_scores(ranked_channels, DEFAULT_RRF_K) + { + fused_matches + .entry(document_index) + .or_insert_with(FusedMatch::default) + .fused_score = fused_score; + } + + fused_matches +} + +fn non_zero_channel_match( + document_index: usize, + score: f64, + matched_terms: Vec, +) -> Option { + (score > 0.0).then_some(ChannelMatch { + document_index, + score, + matched_terms, + }) +} + +fn document_frequency( + documents: &[ToolDocument], + query_tokens: &[String], +) -> HashMap { + let mut frequency = HashMap::new(); + for document in documents { + for query_token in query_tokens { + if document.all_token_set.contains(query_token) { + frequency + .entry(query_token.clone()) + .and_modify(|count| *count += 1) + .or_insert(1); + } + } + } + frequency +} + +fn token_set(tokens: &[String]) -> HashSet { + tokens.iter().cloned().collect() +} + +fn token_counts(tokens: &[String]) -> HashMap { + let mut counts = HashMap::new(); + for token in tokens { + counts + .entry(token.clone()) + .and_modify(|count| *count += 1) + .or_insert(1); + } + counts +} + +fn token_count(counts: &HashMap, token: &str) -> usize { + counts.get(token).copied().unwrap_or_default() +} + +fn cosine_similarity(left: &[f32], right: &[f32]) -> f64 
{ + if left.len() != right.len() || left.is_empty() { + return 0.0; + } + + let (dot_product, left_norm, right_norm) = left.iter().zip(right).fold( + (0.0_f64, 0.0_f64, 0.0_f64), + |(dot_product, left_norm, right_norm), (&left_value, &right_value)| { + let left_value = f64::from(left_value); + let right_value = f64::from(right_value); + ( + dot_product + (left_value * right_value), + left_norm + (left_value * left_value), + right_norm + (right_value * right_value), + ) + }, + ); + + if left_norm == 0.0 || right_norm == 0.0 { + 0.0 + } else { + dot_product / (left_norm.sqrt() * right_norm.sqrt()) + } +} + +fn normalize_text(text: impl AsRef) -> String { + text.as_ref() + .chars() + .map(|character| { + if character.is_ascii_alphanumeric() { + character.to_ascii_lowercase() + } else { + ' ' + } + }) + .collect::() + .split_whitespace() + .collect::>() + .join(" ") +} + +fn tokenize_to_vec(text: &str) -> Vec { + unique_tokens( + normalize_text(text) + .split_whitespace() + .filter_map(normalize_search_token), + ) +} + +fn unique_tokens(tokens: impl IntoIterator) -> Vec { + let mut seen = HashSet::new(); + let mut unique = Vec::new(); + for token in tokens { + if seen.insert(token.clone()) { + unique.push(token); + } + } + unique +} + +fn normalize_search_token(token: &str) -> Option { + if token.len() <= 1 || is_stop_word(token) { + return None; + } + + let token = if token.len() > 4 && token.ends_with("ies") { + format!("{}y", token.trim_end_matches("ies")) + } else if token.len() > 3 + && token.ends_with('s') + && !token.ends_with("ss") + && !token.ends_with("us") + { + token.trim_end_matches('s').to_string() + } else { + token.to_string() + }; + + (!token.is_empty()).then_some(token) +} + +fn collect_json_key_tokens(value: &Value, tokens: &mut HashSet) { + match value { + Value::Object(object) => { + for (key, value) in object { + tokens.extend(tokenize_to_vec(key)); + collect_json_key_tokens(value, tokens); + } + } + Value::Array(values) => { + for value in values 
{ + collect_json_key_tokens(value, tokens); + } + } + Value::Null | Value::Bool(_) | Value::Number(_) | Value::String(_) => {} + } +} + +fn is_stop_word(token: &str) -> bool { + matches!( + token, + "a" | "an" + | "and" + | "are" + | "as" + | "at" + | "be" + | "by" + | "for" + | "from" + | "in" + | "into" + | "is" + | "of" + | "on" + | "or" + | "that" + | "the" + | "this" + | "to" + | "with" + ) +} + +#[cfg(test)] +mod tests { + use std::sync::Mutex; + + use llms::embeddings::Error as EmbeddingError; + + use super::*; + + struct MockTool { + name: &'static str, + description: &'static str, + parameters: Value, + response: Value, + received_args: Arc>>, + } + + #[async_trait] + impl SpiceModelTool for MockTool { + fn name(&self) -> Cow<'_, str> { + Cow::Borrowed(self.name) + } + + fn description(&self) -> Option> { + Some(Cow::Borrowed(self.description)) + } + + fn parameters(&self) -> Option { + Some(self.parameters.clone()) + } + + async fn call(&self, arg: &str) -> Result> { + if let Ok(mut received_args) = self.received_args.lock() { + received_args.push(arg.to_string()); + } + Ok(self.response.clone()) + } + } + + struct FailingTool { + name: &'static str, + } + + #[async_trait] + impl SpiceModelTool for FailingTool { + fn name(&self) -> Cow<'_, str> { + Cow::Borrowed(self.name) + } + + fn description(&self) -> Option> { + Some(Cow::Borrowed("Always fails")) + } + + fn parameters(&self) -> Option { + Some(json!({"type": "object"})) + } + + async fn call( + &self, + _arg: &str, + ) -> Result> { + Err("tool failed".into()) + } + } + + fn mock_tool( + name: &'static str, + description: &'static str, + response: Value, + ) -> (Arc, Arc>>) { + mock_tool_with_parameters( + name, + description, + json!({"type": "object", "properties": {}}), + response, + ) + } + + fn mock_tool_with_parameters( + name: &'static str, + description: &'static str, + parameters: Value, + response: Value, + ) -> (Arc, Arc>>) { + let received_args = Arc::new(Mutex::new(Vec::new())); + ( + 
Arc::new(MockTool { + name, + description, + parameters, + response, + received_args: Arc::clone(&received_args), + }), + received_args, + ) + } + + #[derive(Debug)] + struct MockEmbed; + + #[async_trait] + impl Embed for MockEmbed { + async fn embed(&self, input: EmbeddingInput) -> Result>, EmbeddingError> { + let texts = match input { + EmbeddingInput::String(text) => vec![text], + EmbeddingInput::StringArray(texts) => texts, + EmbeddingInput::IntegerArray(_) | EmbeddingInput::ArrayOfIntegerArray(_) => { + Vec::new() + } + }; + + Ok(texts + .iter() + .map(|text| { + let normalized = normalize_text(text); + if normalized.contains("forecast") || normalized.contains("weather") { + vec![1.0, 0.0] + } else if normalized.contains("sql") || normalized.contains("query") { + vec![0.0, 1.0] + } else { + vec![0.1, 0.1] + } + }) + .collect()) + } + + fn size(&self) -> i32 { + 2 + } + } + + fn mock_embed() -> Arc { + Arc::new(MockEmbed) + } + + #[test] + fn embedding_selection_uses_explicit_configured_model() { + let configured_models = vec!["first".to_string(), "second".to_string()]; + let selected = + select_tool_registry_embedding_model_name(&configured_models, Some("second")) + .expect("explicit configured embedding model should be selected"); + + assert_eq!(selected, "second"); + } + + #[test] + fn embedding_selection_rejects_explicit_missing_model() { + let configured_models = vec!["configured".to_string()]; + let error = select_tool_registry_embedding_model_name(&configured_models, Some("missing")) + .expect_err("missing explicit embedding model should fail"); + + assert!( + error + .to_string() + .contains("was not found in the `embeddings` section") + ); + } + + #[test] + fn embedding_selection_uses_single_configured_model_when_unset() { + let configured_models = vec!["only_embedding".to_string()]; + let selected = select_tool_registry_embedding_model_name(&configured_models, None) + .expect("single configured embedding model should be inferred"); + + 
assert_eq!(selected, "only_embedding"); + } + + #[test] + fn embedding_selection_requires_configuration_when_unset() { + let configured_models = Vec::new(); + let error = select_tool_registry_embedding_model_name(&configured_models, None) + .expect_err("missing embedding model should fail"); + + assert!( + error + .to_string() + .contains("No embedding model configured for searchable tool discovery") + ); + } + + #[test] + fn embedding_selection_requires_explicit_model_when_multiple_configured() { + let configured_models = vec!["first".to_string(), "second".to_string()]; + let error = select_tool_registry_embedding_model_name(&configured_models, None) + .expect_err("ambiguous embedding model should fail"); + + assert!(error.to_string().contains(TOOL_EMBEDDING_MODEL_PARAM)); + } + + #[test] + fn auto_search_threshold_only_triggers_for_large_tool_sets() { + assert!(!should_auto_search(AUTO_SEARCH_TOOL_THRESHOLD)); + assert!(should_auto_search(AUTO_SEARCH_TOOL_THRESHOLD + 1)); + } + + #[test] + fn specific_search_registry_requests_registry_wrapping() { + let opts = + SpiceToolsOptions::Specific(vec!["search_registry".to_string(), "my_tool".to_string()]); + + assert_eq!( + tool_registry_preparation_mode(&opts, 1), + Some(ToolRegistryPreparationMode::Required) + ); + } + + #[test] + fn specific_auto_uses_registry_threshold() { + let opts = SpiceToolsOptions::Specific(vec!["auto".to_string(), "my_tool".to_string()]); + + assert_eq!( + tool_registry_preparation_mode(&opts, AUTO_SEARCH_TOOL_THRESHOLD), + None + ); + assert_eq!( + tool_registry_preparation_mode(&opts, AUTO_SEARCH_TOOL_THRESHOLD + 1), + Some(ToolRegistryPreparationMode::Auto) + ); + } + + #[test] + fn specific_all_keeps_direct_tool_mode() { + let opts = SpiceToolsOptions::Specific(vec!["all".to_string(), "my_tool".to_string()]); + + assert_eq!(tool_registry_preparation_mode(&opts, usize::MAX), None); + } + + #[test] + fn reserved_registry_tool_names_are_rejected() { + let (reserved_tool, _) = mock_tool( + 
TOOL_SEARCH_NAME, + "User configured tool with a reserved registry name", + json!(null), + ); + let tools = vec![reserved_tool]; + + let error = ensure_no_reserved_tool_registry_name_conflicts(&tools) + .expect_err("reserved registry tool names should be rejected"); + + assert!(error.to_string().contains("reserved")); + } + + #[tokio::test] + async fn search_ranks_relevant_tools_first() { + let (sql_tool, _) = mock_tool( + "sql", + "Run SQL queries against datasets and return query results", + json!(null), + ); + let (readiness_tool, _) = mock_tool_with_parameters( + "get_readiness", + "Retrieve component readiness status", + json!(null), + json!(null), + ); + let advertised_tools = tool_registry_tools(vec![readiness_tool, sql_tool], mock_embed()); + let search_tool = advertised_tools + .iter() + .find(|tool| tool.name() == TOOL_SEARCH_NAME) + .expect("tool_search should be advertised"); + + let result = search_tool + .call(r#"{"query":"run a SQL query","limit":2}"#) + .await + .expect("tool search should succeed"); + let tools = result + .get("tools") + .and_then(Value::as_array) + .expect("tool search response should contain tools array"); + let first_tool_id = tools + .first() + .and_then(|tool| tool.get("tool_id")) + .and_then(Value::as_str) + .expect("first search result should have a tool_id"); + + assert_eq!(first_tool_id, "sql"); + assert_eq!(result.get("search_mode"), Some(&json!("hybrid_rrf"))); + let match_sources = tools + .first() + .and_then(|tool| tool.get("match_sources")) + .and_then(Value::as_array) + .expect("first search result should include match sources"); + assert!( + match_sources.iter().any(|source| { + source.get("source").and_then(Value::as_str) == Some("full_text") + }), + "hybrid search should include full-text matches" + ); + } + + #[tokio::test] + async fn search_uses_keyword_channel() { + let (sql_tool, _) = mock_tool( + "sql", + "Run SQL queries against datasets and return query results", + json!(null), + ); + let (readiness_tool, 
_) = mock_tool( + "get_readiness", + "Retrieve component readiness status", + json!(null), + ); + let advertised_tools = tool_registry_tools(vec![sql_tool, readiness_tool], mock_embed()); + let search_tool = advertised_tools + .iter() + .find(|tool| tool.name() == TOOL_SEARCH_NAME) + .expect("tool_search should be advertised"); + + let result = search_tool + .call(r#"{"query":"component state","keywords":["get readiness"],"limit":2}"#) + .await + .expect("tool search should succeed"); + let first_tool = result + .get("tools") + .and_then(Value::as_array) + .and_then(|tools| tools.first()) + .expect("tool search should return at least one result"); + + assert_eq!(first_tool.get("tool_id"), Some(&json!("get_readiness"))); + assert!( + first_tool + .get("match_sources") + .and_then(Value::as_array) + .is_some_and(|sources| sources.iter().any(|source| { + source.get("source").and_then(Value::as_str) == Some("keyword") + })), + "hybrid search should include keyword matches" + ); + } + + #[tokio::test] + async fn search_uses_parameter_schema_channel() { + let (weather_tool, _) = mock_tool_with_parameters( + "weather", + "Fetch conditions for a location", + json!({ + "type": "object", + "properties": { + "city": {"type": "string"} + } + }), + json!(null), + ); + let (calculator_tool, _) = mock_tool_with_parameters( + "calculator", + "Evaluate arithmetic expressions", + json!({ + "type": "object", + "properties": { + "expression": {"type": "string"} + } + }), + json!(null), + ); + let advertised_tools = + tool_registry_tools(vec![calculator_tool, weather_tool], mock_embed()); + let search_tool = advertised_tools + .iter() + .find(|tool| tool.name() == TOOL_SEARCH_NAME) + .expect("tool_search should be advertised"); + + let result = search_tool + .call(r#"{"query":"tool with city argument","limit":2}"#) + .await + .expect("tool search should succeed"); + let first_tool = result + .get("tools") + .and_then(Value::as_array) + .and_then(|tools| tools.first()) + .expect("tool 
search should return at least one result"); + + assert_eq!(first_tool.get("tool_id"), Some(&json!("weather"))); + assert!( + first_tool + .get("match_sources") + .and_then(Value::as_array) + .is_some_and(|sources| sources.iter().any(|source| { + source.get("source").and_then(Value::as_str) == Some("schema") + })), + "hybrid search should include parameter-schema matches" + ); + } + + #[tokio::test] + async fn search_uses_vector_channel_when_embedding_model_is_available() { + let (weather_tool, _) = mock_tool( + "weather", + "Get weather forecasts and current conditions", + json!(null), + ); + let (sql_tool, _) = mock_tool("sql", "Run SQL queries", json!(null)); + let advertised_tools = tool_registry_tools(vec![sql_tool, weather_tool], mock_embed()); + let search_tool = advertised_tools + .iter() + .find(|tool| tool.name() == TOOL_SEARCH_NAME) + .expect("tool_search should be advertised"); + + let result = search_tool + .call(r#"{"query":"weather outlook","limit":2}"#) + .await + .expect("tool search should succeed"); + let first_tool = result + .get("tools") + .and_then(Value::as_array) + .and_then(|tools| tools.first()) + .expect("tool search should return at least one result"); + + assert_eq!(first_tool.get("tool_id"), Some(&json!("weather"))); + assert!( + first_tool + .get("match_sources") + .and_then(Value::as_array) + .is_some_and(|sources| sources.iter().any(|source| { + source.get("source").and_then(Value::as_str) == Some("vector") + })), + "hybrid search should include vector matches" + ); + } + + #[tokio::test] + async fn invoke_calls_selected_tool_with_arguments() { + let (sql_tool, received_args) = mock_tool("sql", "Run SQL queries", json!({"rows": 1})); + let advertised_tools = tool_registry_tools(vec![sql_tool], mock_embed()); + let invoke_tool = advertised_tools + .iter() + .find(|tool| tool.name() == TOOL_INVOKE_NAME) + .expect("tool_invoke should be advertised"); + + let result = invoke_tool + .call(r#"{"tool_id":"sql","arguments":{"query":"select 
1"}}"#) + .await + .expect("tool invoke should succeed"); + + assert_eq!( + result.get("tool_id"), + Some(&Value::String("sql".to_string())) + ); + assert_eq!(result.get("result"), Some(&json!({"rows": 1}))); + + let received_args = received_args + .lock() + .expect("received args lock should not be poisoned"); + assert_eq!(received_args.as_slice(), [r#"{"query":"select 1"}"#]); + } + + #[tokio::test] + async fn invoke_returns_error_when_tool_id_is_missing() { + let (sql_tool, _) = mock_tool("sql", "Run SQL queries", json!({"rows": 1})); + let advertised_tools = tool_registry_tools(vec![sql_tool], mock_embed()); + let invoke_tool = advertised_tools + .iter() + .find(|tool| tool.name() == TOOL_INVOKE_NAME) + .expect("tool_invoke should be advertised"); + + let error = invoke_tool + .call(r#"{"tool_id":"missing","arguments":{}}"#) + .await + .expect_err("missing tool id should return an error"); + + assert!( + error + .to_string() + .contains("was not found in the searchable tool registry") + ); + } + + #[tokio::test] + async fn invoke_returns_error_when_selected_tool_fails() { + let failing_tool = Arc::new(FailingTool { name: "failing" }) as Arc; + let advertised_tools = tool_registry_tools(vec![failing_tool], mock_embed()); + let invoke_tool = advertised_tools + .iter() + .find(|tool| tool.name() == TOOL_INVOKE_NAME) + .expect("tool_invoke should be advertised"); + + let error = invoke_tool + .call(r#"{"tool_id":"failing","arguments":{}}"#) + .await + .expect_err("failing tool should return an error"); + + assert!( + error + .to_string() + .contains("Failed to invoke tool 'failing' from searchable registry") + ); + } + + #[tokio::test] + async fn invoke_requires_exact_tool_id_match() { + let (namespaced_tool, _) = mock_tool( + "catalog/sql", + "Run SQL queries through a catalog tool", + json!({"namespaced": true}), + ); + let (encoded_name_tool, _) = mock_tool( + "catalog_sql", + "A different tool whose raw name matches the encoded namespaced tool", + 
json!({"encoded": true}), + ); + let advertised_tools = + tool_registry_tools(vec![namespaced_tool, encoded_name_tool], mock_embed()); + let invoke_tool = advertised_tools + .iter() + .find(|tool| tool.name() == TOOL_INVOKE_NAME) + .expect("tool_invoke should be advertised"); + + let result = invoke_tool + .call(r#"{"tool_id":"catalog_sql","arguments":{}}"#) + .await + .expect("exact tool id should invoke matching raw tool name"); + + assert_eq!(result.get("result"), Some(&json!({"encoded": true}))); + } + + #[test] + fn registry_keeps_list_datasets_directly_advertised() { + let (list_datasets_tool, _) = mock_tool( + LIST_DATASETS_TOOL_NAME, + "List all SQL tables available", + json!([]), + ); + let (sql_tool, _) = mock_tool("sql", "Run SQL queries", json!(null)); + + let mut names = tool_registry_tools(vec![list_datasets_tool, sql_tool], mock_embed()) + .iter() + .map(|tool| tool.name().to_string()) + .collect::>(); + names.sort(); + + assert_eq!(names, vec!["list_datasets", "tool_invoke", "tool_search"]); + } +} diff --git a/crates/runtime/src/tools/utils.rs b/crates/runtime/src/tools/utils.rs index 03bb0e3a18..278a3ca9cf 100644 --- a/crates/runtime/src/tools/utils.rs +++ b/crates/runtime/src/tools/utils.rs @@ -25,16 +25,24 @@ use runtime_datafusion::allowlist::ResolvedTableAwareAllowlist; use schemars::{JsonSchema, schema_for}; use serde::Serialize; use serde_json::Value; -use std::collections::HashMap; +use snafu::Snafu; +use std::collections::{HashMap, HashSet}; use std::sync::Arc; use crate::datafusion::{SPICE_DEFAULT_CATALOG, SPICE_DEFAULT_SCHEMA}; use crate::{Runtime, tools::catalog::SpiceToolCatalog}; use super::builtin::catalog::BuiltinToolCatalog; +use super::factory::default_catalog_names; use super::{Tooling, options::SpiceToolsOptions}; use tools::{SpiceModelTool, rename::with_name}; +#[derive(Debug, Snafu)] +enum ToolUtilsError { + #[snafu(display("Failed to create table allowlist from model datasets: {source}"))] + CreateTableAllowlist { source: 
globset::Error }, +} + /// Creates the messages that would be sent and received if a language model were to request the `tool` /// to be called (via an assistant message), with defined `arg`, and the response from running the /// tool (via a tool message) also as a message. @@ -88,21 +96,18 @@ pub fn parameters() -> Option { /// Create a [`ResolvedTableAwareAllowlist`] from a list of dataset patterns. /// -/// Returns `None` if the list is empty. -pub fn create_table_allowlist(datasets: &[String]) -> Option { +/// Returns `Ok(None)` if the list is empty. +pub fn create_table_allowlist( + datasets: &[String], +) -> Result, Box> { if datasets.is_empty() { - return None; + return Ok(None); } - match ResolvedTableAwareAllowlist::with_defaults(SPICE_DEFAULT_CATALOG, SPICE_DEFAULT_SCHEMA) + ResolvedTableAwareAllowlist::with_defaults(SPICE_DEFAULT_CATALOG, SPICE_DEFAULT_SCHEMA) .with_table_patterns(datasets.to_vec()) - { - Ok(allowlist) => Some(allowlist), - Err(e) => { - tracing::warn!("Failed to create table allowlist from model datasets: {e}"); - None - } - } + .map(Some) + .map_err(|source| Box::new(ToolUtilsError::CreateTableAllowlist { source }).into()) } #[must_use] @@ -116,83 +121,240 @@ pub async fn get_tools_with_allowlist( opts: &SpiceToolsOptions, table_allowlist: Option, ) -> Vec> { + let configured_tool_names = configured_tool_names(&rt).await; let all_tools = rt.tools.read().await; let mut tools = vec![]; let mut missing_tools = vec![]; + let mut seen_tool_names = HashSet::new(); - for tt in opts.tools_by_name() { - if let Some((catalog_name, catalog_tool)) = tt.split_once(':') { - if let Some(Tooling::Catalog(catalog)) = all_tools.get(catalog_name) { - let catalog = match ( - catalog.as_any().downcast_ref::(), - table_allowlist.clone(), - ) { - (None, Some(_)) => { - tracing::info!( - "Table allowlist is only applicable to builtin catalog/tools. 
Allowlist will not be applied to '{catalog_name}'" - ); - Arc::clone(catalog) - } - (Some(builtin_catalog), Some(allowlist)) => Arc::new( - builtin_catalog - .clone() - .with_table_allowlist(allowlist.clone()), + if opts.includes_all_available_tools() { + extend_unique_tools( + &mut tools, + &mut seen_tool_names, + all_available_tools( + Arc::clone(&rt), + &all_tools, + &configured_tool_names, + table_allowlist, + ) + .await, + ); + return tools; + } + + if let SpiceToolsOptions::Specific(requested_tools) = opts { + for tt in requested_tools { + match tt.parse::() { + Ok( + SpiceToolsOptions::Auto + | SpiceToolsOptions::All + | SpiceToolsOptions::SearchRegistry, + ) => extend_unique_tools( + &mut tools, + &mut seen_tool_names, + all_available_tools( + Arc::clone(&rt), + &all_tools, + &configured_tool_names, + table_allowlist.clone(), ) - as Arc, - _ => Arc::clone(catalog), - }; - - if let Some(t) = catalog.get(catalog_tool).await { - tools.push(with_name( - &t, - format!("{catalog_name}/{}", t.name()).as_str(), - )); - } else { - tracing::warn!("Tool '{catalog_tool}' is not found in '{catalog_name}'."); - missing_tools.push(tt); - } - } else { - missing_tools.push(tt); - } - } else if let Some(tool) = all_tools.get(tt) { - if let Some(ref allowlist) = table_allowlist - && BuiltinToolCatalog::is_builtin_tool(tt) - { - if let Ok(t) = BuiltinToolCatalog::new(Arc::clone(&rt)) - .with_table_allowlist(allowlist.clone()) - .construct_builtin(tt, None, None, &HashMap::new()) - { - tools.push(t); - } else { - tracing::warn!("Failed to construct tool '{tt}' with table allowlist."); - missing_tools.push(tt); + .await, + ), + Ok(SpiceToolsOptions::Nsql) => { + for tool_name in SpiceToolsOptions::Nsql.tools_by_name() { + match get_tool_by_name( + Arc::clone(&rt), + &all_tools, + tool_name, + table_allowlist.clone(), + ) + .await + { + Some(resolved_tools) => extend_unique_tools( + &mut tools, + &mut seen_tool_names, + resolved_tools, + ), + None => 
missing_tools.push(tool_name.to_string()), + } + } } - } else { - if table_allowlist.is_some() { - tracing::info!( - "Table allowlist is only applicable to builtin catalog/tools. Allowlist will not be applied to '{tt}'" - ); + Ok(SpiceToolsOptions::Disabled) => {} + Ok(SpiceToolsOptions::Specific(_)) | Err(_) => { + match get_tool_by_name(Arc::clone(&rt), &all_tools, tt, table_allowlist.clone()) + .await + { + Some(resolved_tools) => { + extend_unique_tools(&mut tools, &mut seen_tool_names, resolved_tools); + } + None => missing_tools.push(tt.clone()), + } } - tools.extend(tool.tools().await); } - } else { - missing_tools.push(tt); } + + warn_missing_tools(&all_tools, &missing_tools); + return tools; } - if !missing_tools.is_empty() { - let available_tools = all_tools - .keys() - .map(String::as_str) - .collect::>() - .join(", "); + for tt in opts.tools_by_name() { + match get_tool_by_name(Arc::clone(&rt), &all_tools, tt, table_allowlist.clone()).await { + Some(resolved_tools) => { + extend_unique_tools(&mut tools, &mut seen_tool_names, resolved_tools); + } + None => missing_tools.push(tt.to_string()), + } + } - tracing::warn!( - "The following tools were not found in the registry: {}.\nAvailable tools are: {available_tools}.\nFor details, visit https://spiceai.org/docs/features/large-language-models/tools", - missing_tools.join(", ") - ); + warn_missing_tools(&all_tools, &missing_tools); + + tools +} + +async fn all_available_tools( + rt: Arc, + all_tools: &HashMap, + configured_tool_names: &HashSet, + table_allowlist: Option, +) -> Vec> { + let mut tools = vec![]; + let mut seen_tool_names = HashSet::new(); + + for tool_name in SpiceToolsOptions::All.tools_by_name() { + if let Some(resolved_tools) = get_tool_by_name( + Arc::clone(&rt), + all_tools, + tool_name, + table_allowlist.clone(), + ) + .await + { + extend_unique_tools(&mut tools, &mut seen_tool_names, resolved_tools); + } + } + + let default_catalog_names = default_catalog_names(); + let mut tool_entries 
= all_tools.iter().collect::>(); + tool_entries.sort_by(|(left_name, _), (right_name, _)| left_name.cmp(right_name)); + + for (tool_name, tooling) in tool_entries { + if BuiltinToolCatalog::is_builtin_tool(tool_name) { + continue; + } + + if !configured_tool_names.contains(tool_name) { + continue; + } + + if let Tooling::Catalog(catalog) = tooling + && default_catalog_names.contains(&catalog.name()) + { + continue; + } + + extend_unique_tools(&mut tools, &mut seen_tool_names, tooling.tools().await); } tools } + +async fn configured_tool_names(rt: &Arc) -> HashSet { + rt.read_app() + .await + .map(|app| app.tools.iter().map(|tool| tool.name.clone()).collect()) + .unwrap_or_default() +} + +async fn get_tool_by_name( + rt: Arc, + all_tools: &HashMap, + tool_name: &str, + table_allowlist: Option, +) -> Option>> { + if let Some((catalog_name, catalog_tool)) = tool_name.split_once(':') { + let Some(Tooling::Catalog(catalog)) = all_tools.get(catalog_name) else { + return None; + }; + + let catalog = match ( + catalog.as_any().downcast_ref::(), + table_allowlist, + ) { + (None, Some(_)) => { + tracing::info!( + "Table allowlist is only applicable to builtin catalog/tools. 
Allowlist will not be applied to '{catalog_name}'" + ); + Arc::clone(catalog) + } + (Some(builtin_catalog), Some(allowlist)) => { + Arc::new(builtin_catalog.clone().with_table_allowlist(allowlist)) + as Arc + } + _ => Arc::clone(catalog), + }; + + if let Some(t) = catalog.get(catalog_tool).await { + return Some(vec![with_name( + &t, + format!("{catalog_name}/{}", t.name()).as_str(), + )]); + } + + tracing::warn!("Tool '{catalog_tool}' is not found in '{catalog_name}'."); + return None; + } + + let tool = all_tools.get(tool_name)?; + + if let Some(ref allowlist) = table_allowlist + && BuiltinToolCatalog::is_builtin_tool(tool_name) + { + if let Ok(t) = BuiltinToolCatalog::new(Arc::clone(&rt)) + .with_table_allowlist(allowlist.clone()) + .construct_builtin(tool_name, None, None, &HashMap::new()) + { + return Some(vec![t]); + } + + tracing::warn!("Failed to construct tool '{tool_name}' with table allowlist."); + return None; + } + + if table_allowlist.is_some() { + tracing::info!( + "Table allowlist is only applicable to builtin catalog/tools. Allowlist will not be applied to '{tool_name}'" + ); + } + + Some(tool.tools().await) +} + +fn extend_unique_tools( + tools: &mut Vec>, + seen_tool_names: &mut HashSet, + new_tools: Vec>, +) { + for tool in new_tools { + if seen_tool_names.insert(tool.name().to_string()) { + tools.push(tool); + } + } +} + +fn warn_missing_tools(all_tools: &HashMap, missing_tools: &[String]) { + if missing_tools.is_empty() { + return; + } + + let available_tools = all_tools + .keys() + .map(String::as_str) + .collect::>() + .join(", "); + + tracing::warn!( + "The following tools were not found in the registry: {}. Available tools are: {available_tools}. 
For details, visit https://spiceai.org/docs/features/large-language-models/tools", + missing_tools.join(", ") + ); +} diff --git a/crates/search/src/aggregation/reciprocal_rank.rs b/crates/search/src/aggregation/reciprocal_rank.rs index c01d4f50da..1f4779d714 100644 --- a/crates/search/src/aggregation/reciprocal_rank.rs +++ b/crates/search/src/aggregation/reciprocal_rank.rs @@ -12,6 +12,7 @@ limitations under the License. */ use std::collections::{HashMap, HashSet}; +use std::hash::Hash; use std::sync::Arc; use crate::aggregation::from_single_input; @@ -40,11 +41,59 @@ use snafu::ResultExt; /// The underlying score of the search results is not important, only the rank (per stream order). /// The rank, for a given entry (for some primary key `a`) is converted to a score using the formula: /// ```text -/// score_a = 1 / (rank_i + offset) + 1 / (rank_j + offset) + ... +/// score_a = 1 / (rank_i + k) + 1 / (rank_j + k) + ... /// ``` -/// Where `rank_i` is the rank of the i-th stream, and `offset` is a constant (e.g. 60). +/// Where `rank_i` is the rank of the i-th stream, and `k` is a smoothing constant (e.g. 60). pub struct ReciprocalRankFusion; +/// Default RRF smoothing parameter used across Spice hybrid search. 
+pub const DEFAULT_RRF_K: f64 = 60.0; + +const USIZE_TO_F64_CHUNK_BITS: usize = 16; +const USIZE_TO_F64_CHUNK_BASE: f64 = 65_536.0; +const USIZE_TO_F64_CHUNK_MASK: usize = (1usize << USIZE_TO_F64_CHUNK_BITS) - 1; + +#[must_use] +pub fn reciprocal_rank_score(rank: usize, k: f64) -> f64 { + 1.0 / (usize_to_f64(rank) + k) +} + +#[must_use] +pub fn usize_to_f64(value: usize) -> f64 { + let mut remaining = value; + let mut multiplier = 1.0; + let mut converted = 0.0; + + while remaining > 0 { + let chunk_bytes = (remaining & USIZE_TO_F64_CHUNK_MASK).to_le_bytes(); + let chunk = u16::from_le_bytes([chunk_bytes[0], chunk_bytes[1]]); + converted += f64::from(chunk) * multiplier; + remaining >>= USIZE_TO_F64_CHUNK_BITS; + multiplier *= USIZE_TO_F64_CHUNK_BASE; + } + + converted +} + +#[must_use] +pub fn reciprocal_rank_fusion_scores(ranked_lists: I, k: f64) -> HashMap +where + K: Eq + Hash, + L: IntoIterator, + I: IntoIterator, +{ + let mut scores = HashMap::new(); + for ranked_list in ranked_lists { + for (rank_index, key) in ranked_list.into_iter().enumerate() { + scores + .entry(key) + .and_modify(|score| *score += reciprocal_rank_score(rank_index + 1, k)) + .or_insert_with(|| reciprocal_rank_score(rank_index + 1, k)); + } + } + scores +} + #[async_trait] impl CandidateAggregation for ReciprocalRankFusion { async fn aggregate( @@ -142,7 +191,7 @@ impl CandidateAggregation for ReciprocalRankFusion { table_names.as_slice(), primary_key.as_slice(), additional_columns.as_slice(), - 60, + DEFAULT_RRF_K, limit, ) .await @@ -306,13 +355,12 @@ fn are_types_compatible(t1: &DataType, t2: &DataType) -> bool { /// /// This function takes already-registered table names from a SessionContext and builds /// a logical plan that performs reciprocal rank fusion across them. 
-#[expect(clippy::cast_precision_loss)] async fn reciprocal_rank_fusion_plan( ctx: &SessionContext, tables: &[TableReference], primary_key: &[Column], additional_columns: &[Column], - offset: usize, + k: f64, limit: usize, ) -> datafusion::error::Result { // 1) Build CTEs that add explicit rank per table, ranking by SEARCH_SCORE_COLUMN_NAME @@ -364,16 +412,16 @@ async fn reciprocal_rank_fusion_plan( )?; } - // 4) Build the RRF score: SUM(COALESCE(1.0/(rank + offset), 0)) across all tables + // 4) Build the RRF score: SUM(COALESCE(1.0/(rank + k), 0)) across all tables let rrf_score = ranked_plans .iter() .map(|(table_name, _)| { let rank_col = col(Column::new(Some(table_name.clone()), "rank")); - let offset_lit = lit(offset as f64); + let k_lit = lit(k); let score = binary_expr( lit(1.0), Operator::Divide, - binary_expr(rank_col, Operator::Plus, offset_lit), + binary_expr(rank_col, Operator::Plus, k_lit), ); coalesce(vec![score, lit(0.0)]) }) @@ -437,6 +485,40 @@ mod tests { // The logical plan is tested through integration tests and runtime behavior verification. // If snapshot testing is needed, consider using LogicalPlan's display_indent() or explain methods. 
+ #[test] + fn reciprocal_rank_fusion_scores_combines_ranked_lists() { + let scores = reciprocal_rank_fusion_scores( + vec![vec!["sql", "search"], vec!["search", "table_schema"]], + DEFAULT_RRF_K, + ); + + let search_score = scores + .get("search") + .expect("search should be present in fused scores"); + let sql_score = scores + .get("sql") + .expect("sql should be present in fused scores"); + let table_schema_score = scores + .get("table_schema") + .expect("table_schema should be present in fused scores"); + + assert!(search_score > sql_score); + assert!(sql_score > table_schema_score); + } + + #[test] + fn reciprocal_rank_score_decreases_past_u32_max_rank() { + let u32_max_rank = usize::try_from(u32::MAX).expect("u32::MAX should fit in usize"); + let larger_rank = u32_max_rank + .checked_add(1) + .expect("test requires usize wider than u32"); + + assert!( + reciprocal_rank_score(larger_rank, DEFAULT_RRF_K) + < reciprocal_rank_score(u32_max_rank, DEFAULT_RRF_K) + ); + } + #[test] fn test_additional_columns_of_schema() { let schema = Arc::new(Schema::new(vec![ From 903f307920b11971b0baf8bfee2b2e4171616303 Mon Sep 17 00:00:00 2001 From: Phillip LeBlanc <879445+phillipleblanc@users.noreply.github.com> Date: Mon, 4 May 2026 23:59:50 +0900 Subject: [PATCH 2/6] feat: refresh_mode: snapshot + SQLite/Turso WAL flush + Cayenne metastore slice (stacked) (#10651) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix(snapshot): flush SQLite/Turso WAL before snapshotting SQLite and Turso accelerators run with WAL journal mode. The pre-existing snapshot path performed `fs::copy(live_db, temp_copy)` directly, which captured only the durable pages — every uncheckpointed write (typically all writes since the last automatic checkpoint) lived only in the `-wal` sidecar and was silently lost in the upload. This bug was invisible under `refresh_mode: full|append|changes` because the federated source repopulates on bootstrap. 
It surfaces as data loss with `refresh_mode: snapshot` (separate change) where the snapshot is the only data source. Fixes by: 1. Adding `SnapshotEngine::checkpoint_live` (default no-op) — invoked by `SnapshotManager::create_file_snapshot` *before* `fs::copy` while the accelerator write lock is held. 2. New `SqliteSnapshotEngine` overrides `checkpoint_live` to run `PRAGMA wal_checkpoint(TRUNCATE)` against the live file via a short-lived rusqlite connection. WAL mode allows multiple connections; the existing accelerator write lock guarantees no concurrent writers race the checkpoint. 3. `SqliteSnapshotEngine::prepare_for_upload` switches the *copied* file to `journal_mode=DELETE` as defense-in-depth: even if the copy ever ended up with stale `-wal`/`-shm` sidecars next to it, the uploaded file is guaranteed to be self-contained. 4. New `TursoSnapshotEngine` reuses the SqliteSnapshotEngine logic (libsql is on-disk-compatible with SQLite). The `turso` feature now pulls in rusqlite for this purpose. 5. `create_snapshot_engine` now returns the engine-specific impls for `AccelerationEngine::Sqlite` and `::Turso` instead of `DefaultSnapshotEngine`. Closes #10643 * feat(acceleration): add refresh_mode: snapshot for snapshot-only refreshes Adds a new refresh mode that loads accelerated tables only by polling the snapshot store for newer snapshots, never querying the federated source. Use cases: read replicas, expensive/slow sources, air-gapped propagation. Behavior: * `RefreshMode::Snapshot` is a new variant accepted in spicepod `acceleration.refresh_mode` (and is rejected by the validator unless `acceleration.snapshots` is enabled). * On startup, the dataset bootstraps from the snapshot store as today. * The refresh task polls the snapshot store on a configurable interval (default 60 s) via the new `SnapshotManager::download_if_newer(current_local_id, validate_schema)` API. 
The schema validator runs against the snapshot's `metadata.json` *before* download so a mismatched snapshot can never overwrite the primary file. * Per-engine reload is performed via a new `Accelerator::reload_from_snapshot` trait method. DuckDB/SQLite/Turso impls evict pool connections and reopen at the same primary path; the live federated query path serves the prior inode until the swap completes via `SwappableTableProvider` (RwLock-backed `Arc` that caches schema/constraints/table_type). * `INSERT INTO` is rejected with a clear error when the accelerator's refresh_mode is Snapshot. * Reload + swap is serialized against concurrent accelerator writes via the existing accelerator write mutex. * Atomic file write on download (temp + fsync + rename + parent fsync). Integration tests: * `snapshot_refresh::duckdb` — bootstrap → writer change → reader picks up new snapshot → swap → query reflects change. * `snapshot_refresh::sqlite` — same shape; passes thanks to the SQLite WAL flush in the parent commit. * `snapshot_refresh::turso` — same shape; same WAL flush dependency. CI: new `Snapshot Refresh Mode Integration Tests` job in integration.yml runs DuckDB/SQLite/Turso variants against MinIO. Cayenne is intentionally not covered in this commit: portable Cayenne snapshots require a per-dataset metastore export/import refactor (addressed in the next commit). The snapshot_refresh test directory omits Cayenne until then. Depends on the SQLite/Turso WAL flush from the previous commit; without it the SQLite/Turso integration tests cannot pass. Pulls in the upstream pool-eviction APIs from datafusion-contrib/datafusion-table-providers#635 (`invalidate_instance`, `invalidate_file_instance`).
* feat(cayenne): add per-dataset metastore export/import (snapshot foundation) Adds the foundation for portable, per-dataset Cayenne metastore snapshots: a versioned-JSON 'slice' format that captures only one dataset's rows from cayenne_table and its dependent tables, with path columns rewritten to be relative to a configurable anchor. Why: The legacy Cayenne snapshot format archived the entire cayenne.db SQLite file. That approach forced a one-dataset-per-metadata-directory limit (see validate_cayenne_snapshot_consistency) because two snapshots would clobber each other's metastore rows on extract, and produced snapshots that were not portable across nodes whose data directories did not match the writer's absolute paths (#10642). It also raced with the reader process's eager metastore initialization (#10649): SQLite journal sidecars created during init triggered checksum mismatches against the archive on extract. This commit lays the groundwork to fix all three issues by introducing: * `crates/cayenne/src/metastore/snapshot.rs` — the slice format and export/import logic: - `DatasetMetastoreSlice` (versioned JSON, format_version: 1) with one entry per metastore table from EXPECTED_TABLES. - `SliceValue` mirrors MetastoreValue but JSON-friendly (blobs base64-encoded under the 'x' tag). - `export_dataset(metastore, dataset_name, anchor)` selects every row that belongs to the dataset (the cayenne_table row plus all rows in dependent tables matching the same table_id) and emits a slice. Path columns (cayenne_table.path, cayenne_delete_file.path, cayenne_partition.path) are rewritten relative to `anchor` so the slice contains no absolute filesystem paths. - `import_dataset(metastore, slice, anchor)` runs inside a single transaction: deletes any existing rows for the same dataset_name (FK ON DELETE CASCADE removes all dependent rows), then inserts the slice's rows with paths rewritten back to absolute form anchored at `anchor`. 
Unsupported format_version or engine mismatch are rejected up front. * `MetastoreRow::get_value(index) -> MetastoreValue` — a new trait method that returns the raw column value without type coercion, so generic export logic can serialize columns without knowing each column's expected type at compile time. SQLite and Turso impls add a one-line implementation that clones from their existing values Vec. Tests in `crates/cayenne/src/metastore/snapshot.rs::tests`: * `round_trip_preserves_rows_and_relocates_paths` — exports a dataset from one metastore, imports into a fresh metastore at a different anchor, verifies all partitions resolve to the new anchor. * `import_replaces_prior_dataset_rows_wholesale` — confirms the DuckDB-style snapshot-refresh semantic that import wipes any prior rows for the same dataset_name before inserting (no leftover partitions). * `import_leaves_other_datasets_untouched` — confirms slicing is correctly scoped: importing dataset A does not affect dataset B's rows in the same shared metastore. * `rejects_unsupported_format_version` — both wrong format_version and wrong engine identifier are rejected with clear error messages *before* any DML runs. * `json_round_trip` — slice serializes and parses back to an equivalent slice. Follow-up (separate commit / PR): wire `CayenneSnapshotEngine` into `SnapshotManager`'s directory create/extract paths so the new format is actually used by snapshot upload/download. With that follow-up: * the cayenne.db file leaves the archive (replaces #10649), * paths become portable (closes #10642), * the single-dataset-per-metadata-dir validation can be lifted, * Cayenne refresh_mode: snapshot becomes a passing integration test. * feat(cayenne): wire CayenneSnapshotEngine into snapshot pipeline This is Commit 4 of the refresh_mode: snapshot stack. 
Commit 3 added the metastore export/import API; this commit wires it into the SnapshotManager upload/download pipeline so Cayenne snapshots actually use the per-dataset slice format instead of shipping raw cayenne.db. Closes: - spiceai/spiceai#10642 (Cayenne snapshots not portable across data dirs) - spiceai/spiceai#10649 (Cayenne metastore-init races with checksum) Architecture: * Trait extension in runtime-acceleration::snapshot::engine: Adds two new SnapshotEngine methods, each with a no-op default so DuckDB/SQLite/Turso continue to behave exactly as before: - prepare_directory_snapshot(dirs, dataset_name) -> DirectorySnapshotPlan Returns (a) a list of file paths *relative to each source dir* that should be excluded from the tar, and (b) extra in-memory entries to append at the end of the tar. - finalize_directory_snapshot(dirs, dataset_name, extras) Engine-specific post-extract hook. Plus a new SnapshotEngineError::Custom { message } variant and a SnapshotEngineError::from_display() constructor so engines defined outside runtime-acceleration (CayenneSnapshotEngine in the runtime crate) can surface their own rich errors without needing a feature-gated variant in the trait crate. * directory_archive: new archive_directories_with_plan() that takes skip_relative_paths + extras. The original archive_directories() is preserved as a thin wrapper. The new add_directory_to_archive_filtered helper handles the recursive walk while honoring the skip set. The module remains engine-agnostic — the skip predicate is just a HashSet passed in. * SnapshotManager: - create_directory_snapshot now calls prepare_directory_snapshot first, then archive_directories_with_plan with the engine's skip list and extras. - download_to_directories now calls finalize_directory_snapshot after extract. - New with_snapshot_engine(self, Arc<dyn SnapshotEngine>) builder so accelerators can inject a custom engine. * download_snapshot_if_needed and snapshot_before_recreate gain an Option<Arc<dyn SnapshotEngine>> parameter.
DuckDB/SQLite/Turso pass None (default behavior). Cayenne builds and passes a CayenneSnapshotEngine. * Cayenne accelerator (crates/runtime/src/dataaccelerator/cayenne): - New module snapshot_engine.rs implements CayenneSnapshotEngine. Holds an Arc, dataset_name, and data_dir_anchor. prepare_directory_snapshot exports the per-dataset slice via MetadataCatalog::export_dataset_slice and instructs the archiver to skip cayenne.db / -wal / -shm. finalize_directory_snapshot reads the slice JSON back from the extracted directory and calls MetadataCatalog::import_dataset_slice. - The cayenne accelerator's bootstrap path constructs a CayenneSnapshotEngine using its existing get_or_create_catalog machinery and threads it through download_snapshot_if_needed. * cayenne::MetadataCatalog: new export_dataset_slice and import_dataset_slice trait methods (default impls return InvalidOperationNoSource); CayenneCatalog overrides both with dispatch to the underlying SQLite or libsql metastore. * Lifted validate_cayenne_snapshot_consistency's SharedAcceleration restriction. With per-dataset slices, multiple Cayenne datasets sharing a metastore directory can each ship their own snapshot without clobbering each other's metastore rows on extract. The InconsistentSnapshotSettings (mixed enabled/disabled) check stays. * Re-enabled the Cayenne snapshot_refresh integration test (snapshot_refresh::cayenne::snapshot_refresh_cayenne_bootstrap_then_refresh). Tests: - 3 new unit tests in cayenne::snapshot_engine::tests: * create_directory_snapshot_skips_cayenne_db_and_emits_slice * refuses_mismatched_dataset * finalize_missing_slice_returns_clear_error - validate_snapshots tests updated; the test_cayenne_shared_acceleration_with_snapshots_errors test is renamed to ..._now_supported and asserts Ok. - Existing 91 runtime-acceleration::snapshot lib tests still pass. - Cayenne integration test runs against real S3 in CI (same harness as the DuckDB/SQLite/Turso refresh-mode tests). 
* fix(snapshot): bootstrap dataset checkpoint from snapshot metadata When a downloaded snapshot doesn't carry a populated _dataset_checkpoint row (the steady state for Cayenne, where the on-disk archive ships the per-dataset metastore slice JSON instead of the raw cayenne.db), the post-extract Checkpointer::get_schema() returns None and bootstrap was failing with MissingSchema. The metadata.json carries the dataset's verified schema; after import we now materialize that schema into the local checkpoint via Checkpointer::checkpoint(metadata_schema, None) so the spice_sys side matches the snapshot. For DuckDB / SQLite / Turso this branch is unreachable in steady state (their archives ship _dataset_checkpoint already); on a corrupted snapshot the self-heal is harmless because the metadata schema is the same one we'd otherwise have validated against. This unblocks refresh_mode: snapshot for Cayenne and re-enables the snapshot_refresh::cayenne integration test by default. Closes #10658. --- .github/workflows/integration.yml | 50 ++ Cargo.lock | 9 +- Cargo.toml | 2 +- crates/cayenne/src/catalog.rs | 43 ++ crates/cayenne/src/cayenne_catalog.rs | 32 + crates/cayenne/src/lib.rs | 1 + crates/cayenne/src/metastore.rs | 9 + crates/cayenne/src/metastore/snapshot.rs | 708 ++++++++++++++++++ crates/cayenne/src/metastore/sqlite.rs | 9 + crates/cayenne/src/metastore/turso.rs | 9 + crates/runtime-acceleration/Cargo.toml | 5 +- .../src/snapshot/directory_archive.rs | 79 +- .../src/snapshot/engine.rs | 126 +++- .../src/snapshot/engine/duckdb.rs | 1 + .../src/snapshot/engine/sqlite.rs | 253 +++++++ .../src/snapshot/engine/turso.rs | 81 ++ .../runtime-acceleration/src/snapshot/mod.rs | 526 +++++++++++-- crates/runtime/src/accelerated_table/mod.rs | 40 +- .../runtime/src/accelerated_table/refresh.rs | 27 +- .../src/accelerated_table/refresh_task.rs | 267 +++++++ .../accelerated_table/refresh_task_runner.rs | 15 + .../src/accelerated_table/snapshots.rs | 63 ++ 
.../src/component/dataset/acceleration.rs | 12 + .../src/dataaccelerator/cayenne/mod.rs | 114 ++- .../cayenne/snapshot_engine.rs | 460 ++++++++++++ crates/runtime/src/dataaccelerator/duckdb.rs | 59 +- crates/runtime/src/dataaccelerator/mod.rs | 96 +++ .../runtime/src/dataaccelerator/snapshots.rs | 54 +- crates/runtime/src/dataaccelerator/sqlite.rs | 46 +- .../runtime/src/dataaccelerator/swappable.rs | 310 ++++++++ crates/runtime/src/dataaccelerator/turso.rs | 38 + .../datafusion/iceberg_ddl/physical_plans.rs | 1 + crates/runtime/src/datafusion/mod.rs | 248 ++++++ crates/runtime/src/tracing_util.rs | 3 + .../tests/integration_snapshot_refresh.rs | 81 ++ .../runtime/tests/snapshot_refresh/cayenne.rs | 32 + .../runtime/tests/snapshot_refresh/duckdb.rs | 22 + crates/runtime/tests/snapshot_refresh/mod.rs | 599 +++++++++++++++ .../runtime/tests/snapshot_refresh/sqlite.rs | 26 + .../runtime/tests/snapshot_refresh/turso.rs | 31 + crates/spicepod/src/acceleration/mod.rs | 31 + 41 files changed, 4527 insertions(+), 91 deletions(-) create mode 100644 crates/cayenne/src/metastore/snapshot.rs create mode 100644 crates/runtime-acceleration/src/snapshot/engine/sqlite.rs create mode 100644 crates/runtime-acceleration/src/snapshot/engine/turso.rs create mode 100644 crates/runtime/src/dataaccelerator/cayenne/snapshot_engine.rs create mode 100644 crates/runtime/src/dataaccelerator/swappable.rs create mode 100644 crates/runtime/tests/integration_snapshot_refresh.rs create mode 100644 crates/runtime/tests/snapshot_refresh/cayenne.rs create mode 100644 crates/runtime/tests/snapshot_refresh/duckdb.rs create mode 100644 crates/runtime/tests/snapshot_refresh/mod.rs create mode 100644 crates/runtime/tests/snapshot_refresh/sqlite.rs create mode 100644 crates/runtime/tests/snapshot_refresh/turso.rs diff --git a/.github/workflows/integration.yml b/.github/workflows/integration.yml index 7d53ede02c..37d0f9d76f 100644 --- a/.github/workflows/integration.yml +++ 
b/.github/workflows/integration.yml @@ -140,6 +140,7 @@ jobs: cargo nextest archive -p runtime --test integration --features ${FEATURES} --archive-file integration.tar.zst cargo nextest archive -p runtime --test retention_oom --features ${FEATURES} --archive-file retention_oom.tar.zst cargo nextest archive -p runtime --test integration_aws_sdk --features databricks,delta_lake --archive-file integration_aws_sdk.tar.zst + cargo nextest archive -p runtime --test integration_snapshot_refresh --features duckdb,sqlite,turso,snapshots --archive-file integration_snapshot_refresh.tar.zst cargo nextest archive -p aws-sdk-credential-bridge --test credential_provider --archive-file integration_aws_sdk_credential_bridge.tar.zst cargo nextest archive -p runtime-table-partition --test partition_table_provider --archive-file partition_table_test.tar.zst cargo nextest archive -p data_components --test hadoop_catalog_test --archive-file data_components_hadoop.tar.zst @@ -175,6 +176,16 @@ jobs: # Archive is already zstd-compressed; use minimal artifact zip compression. compression-level: 1 + - name: Upload snapshot refresh test archive + if: needs.check_changes.outputs.relevant_changes == 'true' + uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7 + with: + name: integration-snapshot-refresh-test-archive + path: ./integration_snapshot_refresh.tar.zst + retention-days: 3 + # Archive is already zstd-compressed; use minimal artifact zip compression. 
+ compression-level: 1 + - name: Upload AWS SDK credential bridge test archive if: needs.check_changes.outputs.relevant_changes == 'true' uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7 @@ -499,6 +510,45 @@ jobs: run: | INSTA_WORKSPACE_ROOT="${PWD}" CARGO_MANIFEST_DIR="${PWD}" cargo nextest run --workspace-remap "${PWD}" --archive-file ./integration_aws_sdk_credential_bridge_test/integration_aws_sdk_credential_bridge.tar.zst + test-snapshot-refresh: + name: Snapshot Refresh Mode Integration Tests + needs: [build, check_changes] + if: needs.check_changes.outputs.relevant_changes == 'true' + permissions: read-all + runs-on: spiceai-dev-runners + steps: + - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6 + with: + fetch-depth: 1 + + - name: Set up Rust + uses: ./.github/actions/setup-rust + + - name: Download snapshot refresh test archive + uses: actions/download-artifact@3e5f45b2cfb9172054b4087a40e8e0b5a5461e7c # v8 + with: + name: integration-snapshot-refresh-test-archive + path: ./integration_snapshot_refresh_test + + - name: Set up Nextest + uses: ./.github/actions/setup-nextest + + - name: Login to ACR + uses: docker/login-action@4907a6ddec9925e35a0a9e82d7399ccc52663121 # v4.1.0 + if: github.repository == 'spiceai/spiceai' + with: + registry: spiceaitestimages.azurecr.io + username: spiceai-repo-pull + password: ${{ secrets.AZCR_PASSWORD }} + + - name: Run snapshot refresh integration tests + env: + AWS_EC2_METADATA_DISABLED: true + AWS_SNAPSHOT_KEY: ${{ secrets.AWS_ICEBERG_ACCESS_KEY_ID }} + AWS_SNAPSHOT_SECRET: ${{ secrets.AWS_ICEBERG_SECRET_ACCESS_KEY }} + run: | + INSTA_WORKSPACE_ROOT="${PWD}" CARGO_MANIFEST_DIR="${PWD}" cargo nextest run --workspace-remap "${PWD}" --archive-file ./integration_snapshot_refresh_test/integration_snapshot_refresh.tar.zst + test-data-components: name: Data Components Integration Tests needs: [build, check_changes] diff --git a/Cargo.lock b/Cargo.lock index 8e9a9c1702..11611b771c 
100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -6035,7 +6035,7 @@ dependencies = [ [[package]] name = "datafusion-table-providers" version = "0.1.0" -source = "git+https://github.com/datafusion-contrib/datafusion-table-providers.git?rev=915f03870eff972dab671aa3481a3b55a289d2b9#915f03870eff972dab671aa3481a3b55a289d2b9" +source = "git+https://github.com/datafusion-contrib/datafusion-table-providers.git?rev=97ecd0059bd49297b956b6dd51c7047547cc97e0#97ecd0059bd49297b956b6dd51c7047547cc97e0" dependencies = [ "adbc_core", "adbc_driver_manager", @@ -7995,8 +7995,8 @@ dependencies = [ "libc", "log", "rustversion", - "windows-link 0.2.1", - "windows-result 0.4.1", + "windows-link 0.1.3", + "windows-result 0.3.4", ] [[package]] @@ -9081,7 +9081,7 @@ dependencies = [ "js-sys", "log", "wasm-bindgen", - "windows-core 0.62.2", + "windows-core 0.61.2", ] [[package]] @@ -15463,6 +15463,7 @@ dependencies = [ "runtime-object-store", "runtime-parameters", "runtime-secrets", + "rusqlite", "serde", "serde_json", "sha2", diff --git a/Cargo.toml b/Cargo.toml index fd03872e26..f966b26e7f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -417,7 +417,7 @@ datafusion-physical-optimizer = { git = "https://github.com/spiceai/datafusion.g datafusion-spark = { git = "https://github.com/spiceai/datafusion.git", rev = "06e4b624c6073c40c7b2127ce620e281ec1979ae" } # spiceai-52.5 datafusion-substrait = { git = "https://github.com/spiceai/datafusion.git", rev = "06e4b624c6073c40c7b2127ce620e281ec1979ae" } # spiceai-52.5 -datafusion-table-providers = { git = "https://github.com/datafusion-contrib/datafusion-table-providers.git", rev = "915f03870eff972dab671aa3481a3b55a289d2b9" } # spiceai-52 +datafusion-table-providers = { git = "https://github.com/datafusion-contrib/datafusion-table-providers.git", rev = "97ecd0059bd49297b956b6dd51c7047547cc97e0" } # spiceai-52 with invalidate_instance / invalidate_file_instance (datafusion-contrib/datafusion-table-providers#635) ballista-core = { git =
"https://github.com/spiceai/datafusion-ballista.git", rev = "383e165a080d648c313a2530a3a53eae6077fdba" } # spiceai-52.5 ballista-executor = { git = "https://github.com/spiceai/datafusion-ballista.git", rev = "383e165a080d648c313a2530a3a53eae6077fdba" } # spiceai-52.5 diff --git a/crates/cayenne/src/catalog.rs b/crates/cayenne/src/catalog.rs index 02df16e9fd..b22efcc0ec 100644 --- a/crates/cayenne/src/catalog.rs +++ b/crates/cayenne/src/catalog.rs @@ -412,6 +412,49 @@ pub trait MetadataCatalog: Send + Sync { /// /// Returns `Ok(true)` if the table was dropped, `Ok(false)` if the table didn't exist. async fn drop_table(&self, table_name: &str) -> CatalogResult; + + /// Export the metastore rows for `dataset_name` as a portable, versioned + /// slice with path columns rewritten relative to `data_dir_anchor`. + /// + /// Default implementation returns [`CatalogError::InvalidOperationNoSource`]. + /// `CayenneCatalog` overrides this with a real implementation. + /// + /// # Errors + /// + /// Returns an error if the dataset is not present or the underlying metastore + /// query fails. + async fn export_dataset_slice( + &self, + dataset_name: &str, + data_dir_anchor: &std::path::Path, + ) -> CatalogResult { + let _ = (dataset_name, data_dir_anchor); + Err(CatalogError::InvalidOperationNoSource { + message: "export_dataset_slice is not supported by this catalog".to_string(), + }) + } + + /// Atomically import a dataset slice into the metastore, replacing any + /// prior rows for the same `dataset_name`. Path columns are re-anchored + /// at `data_dir_anchor`. + /// + /// Default implementation returns [`CatalogError::InvalidOperationNoSource`]. + /// `CayenneCatalog` overrides this with a real implementation. + /// + /// # Errors + /// + /// Returns an error if the slice format is unsupported, the engine identifier + /// mismatches, or any DML in the underlying transaction fails.
+ async fn import_dataset_slice( + &self, + slice: &crate::metastore::snapshot::DatasetMetastoreSlice, + data_dir_anchor: &std::path::Path, + ) -> CatalogResult<()> { + let _ = (slice, data_dir_anchor); + Err(CatalogError::InvalidOperationNoSource { + message: "import_dataset_slice is not supported by this catalog".to_string(), + }) + } } /// Factory trait for creating catalog instances. diff --git a/crates/cayenne/src/cayenne_catalog.rs b/crates/cayenne/src/cayenne_catalog.rs index 4eaa3b084c..d19176a24b 100644 --- a/crates/cayenne/src/cayenne_catalog.rs +++ b/crates/cayenne/src/cayenne_catalog.rs @@ -1597,6 +1597,38 @@ impl MetadataCatalog for CayenneCatalog { Ok(true) } + + async fn export_dataset_slice( + &self, + dataset_name: &str, + data_dir_anchor: &std::path::Path, + ) -> CatalogResult { + match &self.metastore { + MetastoreImpl::Sqlite(m) => { + crate::metastore::snapshot::export_dataset(m, dataset_name, data_dir_anchor).await + } + #[cfg(feature = "turso")] + MetastoreImpl::Turso(m) => { + crate::metastore::snapshot::export_dataset(m, dataset_name, data_dir_anchor).await + } + } + } + + async fn import_dataset_slice( + &self, + slice: &crate::metastore::snapshot::DatasetMetastoreSlice, + data_dir_anchor: &std::path::Path, + ) -> CatalogResult<()> { + match &self.metastore { + MetastoreImpl::Sqlite(m) => { + crate::metastore::snapshot::import_dataset(m, slice, data_dir_anchor).await + } + #[cfg(feature = "turso")] + MetastoreImpl::Turso(m) => { + crate::metastore::snapshot::import_dataset(m, slice, data_dir_anchor).await + } + } + } } fn is_retryable_write_conflict(error: &CatalogError) -> bool { diff --git a/crates/cayenne/src/lib.rs b/crates/cayenne/src/lib.rs index f302b1c326..119811c073 100644 --- a/crates/cayenne/src/lib.rs +++ b/crates/cayenne/src/lib.rs @@ -70,6 +70,7 @@ pub(crate) mod schema; pub mod stats; pub use catalog::MetadataCatalog; +pub use catalog::{CatalogError, CatalogResult}; pub use catalog_provider::{ CayenneCatalogProvider, 
CayenneCatalogProviderConfig, CayenneSchemaProvider, }; diff --git a/crates/cayenne/src/metastore.rs b/crates/cayenne/src/metastore.rs index 8461cc0e54..f0ea0f93fa 100644 --- a/crates/cayenne/src/metastore.rs +++ b/crates/cayenne/src/metastore.rs @@ -20,6 +20,7 @@ limitations under the License. //! that can be used to store Cayenne metadata. This allows swapping between `SQLite`, //! Turso, or other storage implementations. +pub mod snapshot; pub mod sqlite; #[cfg(feature = "turso")] @@ -281,6 +282,14 @@ impl> From> for MetastoreValue { /// A row returned from a query. pub trait MetastoreRow: Send { + /// Get the raw `MetastoreValue` for a column by index. Used by + /// generic export/import logic that does not know the column types. + /// + /// # Errors + /// + /// Returns an error if the column index is out of bounds. + fn get_value(&self, index: usize) -> CatalogResult; + /// Get an i64 value from the row by column index. /// /// # Errors diff --git a/crates/cayenne/src/metastore/snapshot.rs b/crates/cayenne/src/metastore/snapshot.rs new file mode 100644 index 0000000000..7083e74763 --- /dev/null +++ b/crates/cayenne/src/metastore/snapshot.rs @@ -0,0 +1,708 @@ +/* +Copyright 2026 The Spice.ai OSS Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +//! Per-dataset metastore snapshot serialization. +//! +//! The legacy Cayenne snapshot format archived the entire `cayenne.db` `SQLite` +//! file. That approach forced a one-dataset-per-metadata-directory limitation +//! 
(multiple datasets sharing a metastore would clobber each other on extract) +//! and made snapshots non-portable across nodes whose data directories did not +//! match the writer's absolute paths. +//! +//! This module replaces that with a portable, per-dataset metastore "slice": +//! +//! * **Export**: `export_dataset(metastore, dataset, anchor)` collects every +//! metastore row that belongs to `dataset` (the `cayenne_table` row keyed +//! by `table_name`, plus all rows in dependent tables that reference that +//! `table_id`) and emits a versioned JSON document. Path columns are +//! rewritten to be relative to `anchor` so the slice does not embed +//! filesystem-specific paths. +//! +//! * **Import**: `import_dataset(metastore, slice, anchor)` atomically +//! replaces any local rows for the same `table_name` with the slice's +//! contents inside a single `BEGIN IMMEDIATE` transaction. Path columns +//! are rewritten back to absolute form anchored at the local `anchor`. +//! FK `ON DELETE CASCADE` removes the dataset's prior dependent rows when +//! the existing `cayenne_table` row is deleted. +//! +//! The slice format is **versioned** (`format_version: 1`) so future +//! changes can be detected and rejected with a clear error. + +use std::collections::BTreeMap; +use std::path::Path; + +use base64::Engine as _; +use base64::engine::general_purpose::STANDARD as BASE64; +use serde::{Deserialize, Serialize}; + +use super::{EXPECTED_TABLES, ExecuteParams, MetastoreBackend, MetastoreValue, QueryParams}; +use crate::catalog::{CatalogError, CatalogResult}; + +/// Current slice format version. Incremented on incompatible format changes. +pub const SLICE_FORMAT_VERSION: u32 = 1; + +/// Engine identifier embedded in slices to detect cross-engine misuse. +pub const SLICE_ENGINE: &str = "cayenne"; + +/// JSON-friendly mirror of [`MetastoreValue`]. Blobs are base64-encoded so the +/// document remains valid UTF-8 JSON. 
+#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +#[serde(tag = "t", content = "v")] +pub enum SliceValue { + /// 64-bit signed integer. + #[serde(rename = "i")] + Integer(i64), + /// UTF-8 text. + #[serde(rename = "s")] + Text(String), + /// Boolean. + #[serde(rename = "b")] + Bool(bool), + /// Binary blob, base64-encoded for JSON-friendliness. + #[serde(rename = "x")] + Blob(String), + /// SQL NULL. + #[serde(rename = "n")] + Null, +} + +impl From<&MetastoreValue> for SliceValue { + fn from(v: &MetastoreValue) -> Self { + match v { + MetastoreValue::Integer(i) => SliceValue::Integer(*i), + MetastoreValue::Text(s) => SliceValue::Text(s.clone()), + MetastoreValue::Bool(b) => SliceValue::Bool(*b), + MetastoreValue::Blob(b) => SliceValue::Blob(BASE64.encode(b)), + MetastoreValue::Null => SliceValue::Null, + } + } +} + +impl SliceValue { + /// Convert back to a `MetastoreValue`. + /// + /// # Errors + /// + /// Returns an error if a `Blob` variant contains invalid base64. + pub fn into_metastore_value(self) -> CatalogResult { + Ok(match self { + SliceValue::Integer(i) => MetastoreValue::Integer(i), + SliceValue::Text(s) => MetastoreValue::Text(s), + SliceValue::Bool(b) => MetastoreValue::Bool(b), + SliceValue::Blob(b64) => { + let bytes = BASE64 + .decode(b64.as_bytes()) + .map_err(|e| CatalogError::Database { + message: format!("invalid base64 blob in metastore slice: {e}"), + })?; + MetastoreValue::Blob(bytes) + } + SliceValue::Null => MetastoreValue::Null, + }) + } +} + +/// One row of a slice's per-table contents. Ordered to match the column order +/// in [`EXPECTED_TABLES`]. +pub type SliceRow = Vec; + +/// Versioned, dataset-scoped slice of the Cayenne metastore. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct DatasetMetastoreSlice { + /// Slice format version. Must equal [`SLICE_FORMAT_VERSION`] for this build. + pub format_version: u32, + /// Engine identifier; must equal [`SLICE_ENGINE`] (`"cayenne"`). 
+ pub engine: String, + /// Logical dataset name (matches `cayenne_table.table_name`). + pub dataset_name: String, + /// Wall-clock timestamp (milliseconds since epoch) when the slice was exported. + pub exported_at_ms: i64, + /// Map of metastore table name -> rows. Each row is positional; + /// column order must match the corresponding [`EXPECTED_TABLES`] entry. + pub tables: BTreeMap>, +} + +impl DatasetMetastoreSlice { + /// Marshal to a JSON byte vector suitable for embedding in a snapshot + /// archive. + /// + /// # Errors + /// + /// Propagates JSON serialization errors. + pub fn to_json_bytes(&self) -> Result, serde_json::Error> { + serde_json::to_vec(self) + } + + /// Parse from a JSON byte slice. Validates `format_version` and `engine`. + /// + /// # Errors + /// + /// Returns an error if the JSON is malformed, the format version is + /// unsupported, or the engine identifier mismatches. + pub fn from_json_bytes(bytes: &[u8]) -> CatalogResult { + let slice: Self = serde_json::from_slice(bytes).map_err(|e| CatalogError::Database { + message: format!("failed to parse metastore slice JSON: {e}"), + })?; + if slice.format_version != SLICE_FORMAT_VERSION { + return Err(CatalogError::Database { + message: format!( + "unsupported metastore slice format_version {} (this build understands only {SLICE_FORMAT_VERSION})", + slice.format_version + ), + }); + } + if slice.engine != SLICE_ENGINE { + return Err(CatalogError::Database { + message: format!( + "metastore slice engine mismatch: expected '{SLICE_ENGINE}', got '{}'", + slice.engine + ), + }); + } + Ok(slice) + } +} + +/// Returns the (`path_column_index`, `path_is_relative_column_index`) for tables +/// that store filesystem paths. Returns `None` for tables without path columns. 
+fn path_columns_for_table(table_name: &str) -> Option<(usize, usize)> { + match table_name { + "cayenne_table" | "cayenne_delete_file" => Some((2, 3)), + "cayenne_partition" => Some((5, 6)), + _ => None, + } +} + +/// Returns the column index that holds `table_id` for each metastore table. +/// `cayenne_table` itself stores it at index 0; child tables store it at index 1. +/// +/// Currently informational — wholesale-replace import preserves the slice's +/// own `table_id` values verbatim, so no remap is needed. Kept for the future +/// case where we might want to re-key on import. +#[expect( + dead_code, + reason = "retained for future re-keying on import; see doc above" +)] +fn table_id_column_index(table_name: &str) -> usize { + match table_name { + "cayenne_table" => 0, + _ => 1, + } +} + +/// Rewrite a path to be relative to `anchor`, if it is currently absolute and +/// lives under `anchor`. Returns the original path unchanged otherwise. +/// +/// This is intentionally lenient: paths outside the anchor (which would +/// indicate misconfiguration on the writer) are left untouched and surface +/// later as "file not found" on the reader if the absolute path does not +/// resolve there. We log a warning so the operator notices. +fn make_relative(abs: &str, anchor: &Path) -> String { + let p = Path::new(abs); + if let Ok(rel) = p.strip_prefix(anchor) { + rel.to_string_lossy().into_owned() + } else { + tracing::warn!( + "cayenne metastore export: path {abs:?} is not under anchor {anchor:?}; \ + leaving as-is — slice will not be portable to readers with a different data directory" + ); + abs.to_string() + } +} + +/// Rewrite a (possibly relative) path back to absolute, anchored at `anchor`. +/// Paths that already are absolute are returned unchanged (defensive: handles +/// the lenient case in [`make_relative`]). 
+fn make_absolute(rel_or_abs: &str, anchor: &Path) -> String { + let p = Path::new(rel_or_abs); + if p.is_absolute() { + rel_or_abs.to_string() + } else { + anchor.join(p).to_string_lossy().into_owned() + } +} + +/// Lookup `table_id` for the given dataset, returning `None` if not found. +async fn lookup_table_id( + metastore: &impl MetastoreBackend, + dataset_name: &str, +) -> CatalogResult> { + let rows = metastore + .query( + QueryParams { + sql: "SELECT table_id FROM cayenne_table WHERE table_name = ?", + params: vec![MetastoreValue::Text(dataset_name.to_string())], + }, + |row| row.get_string(0), + ) + .await?; + Ok(rows.into_iter().next()) +} + +/// Export this dataset's rows from the metastore as a versioned slice. +/// +/// Path columns are rewritten relative to `data_dir_anchor` so the resulting +/// slice is portable to readers with a different data directory, provided +/// they re-anchor at their own data directory on import. +/// +/// # Errors +/// +/// Returns an error if the dataset does not exist, or if any underlying +/// metastore query fails. +pub async fn export_dataset( + metastore: &impl MetastoreBackend, + dataset_name: &str, + data_dir_anchor: &Path, +) -> CatalogResult { + let table_id = lookup_table_id(metastore, dataset_name) + .await? 
+ .ok_or_else(|| CatalogError::Database { + message: format!( + "cannot export metastore slice: dataset '{dataset_name}' not found in cayenne_table" + ), + })?; + + let mut tables: BTreeMap> = BTreeMap::new(); + + for expected in EXPECTED_TABLES { + let n_columns = expected.columns.len(); + let (sql, params) = if expected.name == "cayenne_table" { + ( + format!( + "SELECT {} FROM {} WHERE table_name = ?", + expected.columns.join(", "), + expected.name + ), + vec![MetastoreValue::Text(dataset_name.to_string())], + ) + } else { + ( + format!( + "SELECT {} FROM {} WHERE table_id = ?", + expected.columns.join(", "), + expected.name + ), + vec![MetastoreValue::Text(table_id.clone())], + ) + }; + + let path_cols = path_columns_for_table(expected.name); + + let rows: Vec = metastore + .query(QueryParams { sql: &sql, params }, move |row| { + let mut out = Vec::with_capacity(n_columns); + for i in 0..n_columns { + out.push(SliceValue::from(&row.get_value(i)?)); + } + Ok(out) + }) + .await?; + + // Rewrite path columns to be relative to anchor. If `make_relative` + // could not strip the anchor (path is outside `data_dir_anchor`), it + // returns the original absolute path — in that case we leave + // `path_is_relative=false` so the slice stays internally consistent. 
+ let rows: Vec = if let Some((path_idx, rel_idx)) = path_cols { + rows.into_iter() + .map(|mut r| { + if let Some(SliceValue::Text(abs)) = r.get(path_idx).cloned() { + let rel = make_relative(&abs, data_dir_anchor); + let is_relative = rel != abs; + r[path_idx] = SliceValue::Text(rel); + r[rel_idx] = SliceValue::Bool(is_relative); + } + r + }) + .collect() + } else { + rows + }; + + tables.insert(expected.name.to_string(), rows); + } + + Ok(DatasetMetastoreSlice { + format_version: SLICE_FORMAT_VERSION, + engine: SLICE_ENGINE.to_string(), + dataset_name: dataset_name.to_string(), + exported_at_ms: chrono::Utc::now().timestamp_millis(), + tables, + }) +} + +/// Atomically import a dataset slice into the metastore. +/// +/// If `slice.dataset_name` already exists in the local `cayenne_table`, that +/// row is deleted (cascading to all dependent rows) before the slice's rows +/// are inserted. Path columns are re-anchored at `data_dir_anchor`. +/// +/// The entire import runs inside a single `BEGIN IMMEDIATE` transaction; on +/// any error the local metastore is left unchanged. +/// +/// # Errors +/// +/// Returns an error if any DML fails or the slice is internally inconsistent. +pub async fn import_dataset( + metastore: &impl MetastoreBackend, + slice: &DatasetMetastoreSlice, + data_dir_anchor: &Path, +) -> CatalogResult<()> { + if slice.format_version != SLICE_FORMAT_VERSION { + return Err(CatalogError::Database { + message: format!( + "refusing to import metastore slice: unsupported format_version {}", + slice.format_version + ), + }); + } + if slice.engine != SLICE_ENGINE { + return Err(CatalogError::Database { + message: format!( + "refusing to import metastore slice: engine '{}' != '{SLICE_ENGINE}'", + slice.engine + ), + }); + } + + let txn = metastore.begin_transaction().await?; + + // Wholesale-replace any existing rows for this dataset. 
+ txn.execute(ExecuteParams { + sql: "DELETE FROM cayenne_table WHERE table_name = ?", + params: vec![MetastoreValue::Text(slice.dataset_name.clone())], + }) + .await?; + + for expected in EXPECTED_TABLES { + let Some(rows) = slice.tables.get(expected.name) else { + continue; + }; + if rows.is_empty() { + continue; + } + let path_cols = path_columns_for_table(expected.name); + + // Build INSERT statement. We use positional ? placeholders matching + // the EXPECTED_TABLES column order. + let placeholders = vec!["?"; expected.columns.len()].join(", "); + let sql = format!( + "INSERT INTO {} ({}) VALUES ({})", + expected.name, + expected.columns.join(", "), + placeholders + ); + + for row in rows { + if row.len() != expected.columns.len() { + return Err(CatalogError::Database { + message: format!( + "metastore slice row for table {} has {} columns, expected {}", + expected.name, + row.len(), + expected.columns.len() + ), + }); + } + + // Convert SliceValue -> MetastoreValue, applying path rewriting. + let mut params: Vec = Vec::with_capacity(row.len()); + for (i, v) in row.iter().cloned().enumerate() { + let mut mv = v.into_metastore_value()?; + if let Some((path_idx, rel_idx)) = path_cols { + if i == path_idx { + if let MetastoreValue::Text(p) = &mv { + mv = MetastoreValue::Text(make_absolute(p, data_dir_anchor)); + } + } else if i == rel_idx { + // We always re-store as absolute on import; flip the flag + // back to false so the catalog code paths see the same + // shape they always have. 
+ mv = MetastoreValue::Bool(false); + } + } + params.push(mv); + } + + txn.execute(ExecuteParams { sql: &sql, params }).await?; + } + } + + txn.commit().await?; + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::metastore::sqlite::SqliteMetastore; + + async fn fresh_metastore() -> (Arc, tempfile::TempDir) { + let tmp = tempfile::tempdir().expect("tempdir"); + let db_path = tmp.path().join("cayenne.db"); + let metastore = Arc::new(SqliteMetastore::new(format!( + "sqlite://{}", + db_path.display() + ))); + metastore.init_schema().await.expect("init_schema"); + (metastore, tmp) + } + + use std::sync::Arc; + + fn sample_table_row(table_id: &str, table_name: &str, abs_path: &str) -> Vec { + vec![ + MetastoreValue::Text(table_id.to_string()), + MetastoreValue::Text(table_name.to_string()), + MetastoreValue::Text(abs_path.to_string()), + MetastoreValue::Bool(false), + MetastoreValue::Text("{\"fields\":[]}".to_string()), + MetastoreValue::Null, + MetastoreValue::Null, + MetastoreValue::Text(String::new()), + MetastoreValue::Null, + MetastoreValue::Null, + MetastoreValue::Integer(0), + ] + } + + fn sample_partition_row( + partition_id: &str, + table_id: &str, + abs_path: &str, + partition_key: &str, + ) -> Vec { + vec![ + MetastoreValue::Text(partition_id.to_string()), + MetastoreValue::Text(table_id.to_string()), + MetastoreValue::Text("[]".to_string()), + MetastoreValue::Text("[]".to_string()), + MetastoreValue::Text(partition_key.to_string()), + MetastoreValue::Text(abs_path.to_string()), + MetastoreValue::Bool(false), + MetastoreValue::Integer(100), + MetastoreValue::Integer(1024), + ] + } + + async fn insert_dataset( + ms: &SqliteMetastore, + dataset: &str, + anchor: &Path, + partitions: &[(&str, &str, &str)], // (partition_id, partition_key, file) + ) { + let table_id = format!("tid-{dataset}"); + let table_path = anchor + .join(format!("{dataset}.dir")) + .to_string_lossy() + .into_owned(); + ms.execute(ExecuteParams { + sql: "INSERT INTO 
cayenne_table (table_id, table_name, path, path_is_relative, schema_json, primary_key_json, on_conflict_json, current_snapshot_id, partition_column, vortex_config_json, current_sequence_number) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", + params: sample_table_row(&table_id, dataset, &table_path), + }) + .await + .expect("insert table"); + for (pid, pk, file) in partitions { + let abs = anchor.join(file).to_string_lossy().into_owned(); + ms.execute(ExecuteParams { + sql: "INSERT INTO cayenne_partition (partition_id, table_id, partition_columns_json, partition_values_json, partition_key, path, path_is_relative, record_count, file_size_bytes) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)", + params: sample_partition_row(pid, &table_id, &abs, pk), + }) + .await + .expect("insert partition"); + } + } + + #[tokio::test] + async fn round_trip_preserves_rows_and_relocates_paths() { + let (ms_a, tmp_a) = fresh_metastore().await; + let anchor_a = tmp_a.path(); + insert_dataset( + &ms_a, + "trips", + anchor_a, + &[ + ("p1", "k1", "trips/part-001"), + ("p2", "k2", "trips/part-002"), + ], + ) + .await; + + let slice = export_dataset(ms_a.as_ref(), "trips", anchor_a) + .await + .expect("export"); + assert_eq!(slice.format_version, SLICE_FORMAT_VERSION); + assert_eq!(slice.engine, SLICE_ENGINE); + assert_eq!(slice.tables["cayenne_table"].len(), 1); + assert_eq!(slice.tables["cayenne_partition"].len(), 2); + + for row in &slice.tables["cayenne_partition"] { + if let SliceValue::Text(p) = &row[5] { + assert!( + !std::path::Path::new(p).is_absolute(), + "slice partition path should be relative: {p}" + ); + } + } + + let (ms_b, tmp_b) = fresh_metastore().await; + let anchor_b = tmp_b.path(); + import_dataset(ms_b.as_ref(), &slice, anchor_b) + .await + .expect("import"); + + let partitions: Vec<(String, String)> = ms_b + .query( + QueryParams { + sql: "SELECT partition_key, path FROM cayenne_partition WHERE table_id = 'tid-trips'", + params: vec![], + }, + |row| Ok((row.get_string(0)?, 
row.get_string(1)?)), + ) + .await + .expect("query partitions"); + assert_eq!(partitions.len(), 2); + for (_pk, path) in &partitions { + let anchor_str: String = anchor_b.to_string_lossy().to_string(); + assert!( + path.starts_with(&anchor_str), + "path {path} should be under {anchor_str}" + ); + } + } + + #[tokio::test] + async fn import_replaces_prior_dataset_rows_wholesale() { + let (ms, tmp) = fresh_metastore().await; + let anchor = tmp.path(); + insert_dataset( + &ms, + "trips", + anchor, + &[ + ("old1", "k1", "old1"), + ("old2", "k2", "old2"), + ("old3", "k3", "old3"), + ], + ) + .await; + + let mut tables: BTreeMap> = BTreeMap::new(); + tables.insert( + "cayenne_table".to_string(), + vec![ + sample_table_row("tid-trips", "trips", "trips.dir") + .iter() + .map(SliceValue::from) + .collect(), + ], + ); + tables.insert( + "cayenne_partition".to_string(), + vec![ + sample_partition_row("new1", "tid-trips", "new1", "newk") + .iter() + .map(SliceValue::from) + .collect(), + ], + ); + let slice = DatasetMetastoreSlice { + format_version: SLICE_FORMAT_VERSION, + engine: SLICE_ENGINE.to_string(), + dataset_name: "trips".to_string(), + exported_at_ms: 0, + tables, + }; + + import_dataset(ms.as_ref(), &slice, anchor) + .await + .expect("import"); + + let rows: Vec = ms + .query( + QueryParams { + sql: "SELECT partition_id FROM cayenne_partition WHERE table_id = 'tid-trips' ORDER BY partition_id", + params: vec![], + }, + |row| row.get_string(0), + ) + .await + .expect("q"); + assert_eq!(rows, vec!["new1".to_string()]); + } + + #[tokio::test] + async fn import_leaves_other_datasets_untouched() { + let (ms, tmp) = fresh_metastore().await; + let anchor = tmp.path(); + insert_dataset(&ms, "trips", anchor, &[("t1", "k1", "t1")]).await; + insert_dataset(&ms, "riders", anchor, &[("r1", "k1", "r1")]).await; + + let slice = export_dataset(ms.as_ref(), "trips", anchor) + .await + .expect("export"); + import_dataset(ms.as_ref(), &slice, anchor) + .await + .expect("import"); + + 
let riders: Vec = ms + .query( + QueryParams { + sql: "SELECT partition_id FROM cayenne_partition WHERE table_id = 'tid-riders'", + params: vec![], + }, + |row| row.get_string(0), + ) + .await + .expect("q riders"); + assert_eq!(riders, vec!["r1".to_string()]); + } + + #[tokio::test] + async fn rejects_unsupported_format_version() { + let (ms, tmp) = fresh_metastore().await; + let mut slice = DatasetMetastoreSlice { + format_version: 99, + engine: SLICE_ENGINE.to_string(), + dataset_name: "trips".to_string(), + exported_at_ms: 0, + tables: BTreeMap::new(), + }; + let err = import_dataset(ms.as_ref(), &slice, tmp.path()) + .await + .expect_err("should fail"); + assert!(err.to_string().contains("unsupported"), "err={err}"); + + slice.format_version = SLICE_FORMAT_VERSION; + slice.engine = "duckdb".to_string(); + let err = import_dataset(ms.as_ref(), &slice, tmp.path()) + .await + .expect_err("should fail"); + assert!(err.to_string().contains("engine"), "err={err}"); + } + + #[tokio::test] + async fn json_round_trip() { + let (ms, tmp) = fresh_metastore().await; + insert_dataset(&ms, "trips", tmp.path(), &[("p1", "k1", "f1")]).await; + let slice = export_dataset(ms.as_ref(), "trips", tmp.path()) + .await + .expect("export"); + let bytes = slice.to_json_bytes().expect("to_json"); + let parsed = DatasetMetastoreSlice::from_json_bytes(&bytes).expect("from_json"); + assert_eq!(parsed.dataset_name, "trips"); + assert_eq!(parsed.tables.len(), slice.tables.len()); + } +} diff --git a/crates/cayenne/src/metastore/sqlite.rs b/crates/cayenne/src/metastore/sqlite.rs index 4fb99579b7..42e9c2c8f3 100644 --- a/crates/cayenne/src/metastore/sqlite.rs +++ b/crates/cayenne/src/metastore/sqlite.rs @@ -318,6 +318,15 @@ struct SqliteRow { } impl MetastoreRow for SqliteRow { + fn get_value(&self, index: usize) -> CatalogResult { + self.values + .get(index) + .cloned() + .ok_or_else(|| CatalogError::Database { + message: format!("Column index {index} out of bounds"), + }) + } + fn 
get_i64(&self, index: usize) -> CatalogResult { let value = self .values diff --git a/crates/cayenne/src/metastore/turso.rs b/crates/cayenne/src/metastore/turso.rs index 6fc6ae7950..c3fc9ea771 100644 --- a/crates/cayenne/src/metastore/turso.rs +++ b/crates/cayenne/src/metastore/turso.rs @@ -271,6 +271,15 @@ struct TursoRow { } impl MetastoreRow for TursoRow { + fn get_value(&self, index: usize) -> CatalogResult { + self.values + .get(index) + .cloned() + .ok_or_else(|| CatalogError::Database { + message: format!("Column index {index} out of bounds"), + }) + } + fn get_i64(&self, index: usize) -> CatalogResult { let value = self .values diff --git a/crates/runtime-acceleration/Cargo.toml b/crates/runtime-acceleration/Cargo.toml index 8b011b478f..4eb4b7f5f9 100644 --- a/crates/runtime-acceleration/Cargo.toml +++ b/crates/runtime-acceleration/Cargo.toml @@ -30,6 +30,7 @@ telemetry = { path = "../telemetry" } runtime-parameters = { path = "../runtime-parameters" } runtime-object-store = { path = "../runtime-object-store" } runtime-secrets = { path = "../runtime-secrets" } +rusqlite = { workspace = true, optional = true } serde.workspace = true serde_json.workspace = true sha2.workspace = true @@ -47,8 +48,8 @@ uuid = { version = "1", features = ["v7"] } default = [] duckdb = ["dep:duckdb"] snapshots = [] -sqlite = [] -turso = [] +sqlite = ["dep:rusqlite"] +turso = ["dep:rusqlite"] [dev-dependencies] anyhow = { workspace = true } diff --git a/crates/runtime-acceleration/src/snapshot/directory_archive.rs b/crates/runtime-acceleration/src/snapshot/directory_archive.rs index 9811524a06..57aa34ed22 100644 --- a/crates/runtime-acceleration/src/snapshot/directory_archive.rs +++ b/crates/runtime-acceleration/src/snapshot/directory_archive.rs @@ -24,7 +24,7 @@ limitations under the License. 
use sha2::{Digest, Sha256}; use snafu::prelude::*; use std::{ - collections::HashMap, + collections::{HashMap, HashSet}, path::{Component, Path, PathBuf}, sync::{Arc, LazyLock}, }; @@ -115,6 +115,30 @@ type Result = std::result::Result; /// - Files cannot be added to the archive /// - Writing to the output fails pub async fn archive_directories(dirs: &[(PathBuf, String)], writer: W) -> Result +where + W: AsyncWrite + Unpin + Send, +{ + archive_directories_with_plan(dirs, writer, &[], &[]).await +} + +/// Like [`archive_directories`], but allows excluding files (`skip_relative_paths`, +/// matched against each entry's path *relative to its source directory*) and +/// appending extra in-memory entries (`extras`) at the end of the tar. +/// +/// `extras[i]` is added to the archive with `archive_path = extras[i].0` and +/// `bytes = extras[i].1`. The bytes count toward the returned total. +/// +/// # Errors +/// +/// Returns an error if writing the tar archive fails (I/O error on `writer`, +/// missing source directory, or any of the recursive walks/append calls +/// against `dirs` or `extras` fail). 
+pub async fn archive_directories_with_plan( + dirs: &[(PathBuf, String)], + writer: W, + skip_relative_paths: &[PathBuf], + extras: &[(String, Vec)], +) -> Result where W: AsyncWrite + Unpin + Send, { @@ -123,6 +147,8 @@ where use tokio::task::spawn_blocking; let dirs = dirs.to_vec(); + let skip: HashSet = skip_relative_paths.iter().cloned().collect(); + let extras = extras.to_vec(); // Use spawn_blocking since tar operations are synchronous let (total_bytes, tar_data) = spawn_blocking(move || { @@ -159,12 +185,23 @@ where } // Add all files from this directory recursively - add_directory_to_archive(&mut archive, dir_path, archive_prefix).map_err(|e| { - ArchiveError::CreateArchive { + add_directory_to_archive_filtered(&mut archive, dir_path, archive_prefix, &skip) + .map_err(|e| ArchiveError::CreateArchive { path: dir_path.clone(), source: e, - } - })?; + })?; + } + + // Append in-memory extras after the on-disk content. + for (archive_path, bytes) in &extras { + let mut header = tar::Header::new_gnu(); + header.set_size(bytes.len() as u64); + header.set_mode(0o644); + header.set_entry_type(tar::EntryType::Regular); + header.set_mtime(0); + archive + .append_data(&mut header, archive_path.as_str(), bytes.as_slice()) + .map_err(|source| ArchiveError::WriteArchive { source })?; } // Finish the archive @@ -796,11 +833,14 @@ fn extract_with_skip_existing_and_verify( Ok(()) } -/// Recursively add a directory and its contents to a tar archive. -fn add_directory_to_archive( +/// Walks `dir_path` recursively and appends each file to `archive` under +/// `archive_prefix`, skipping any file whose path *relative to `dir_path`* +/// is contained in `skip_relative_paths`. 
+fn add_directory_to_archive_filtered( archive: &mut tar::Builder, dir_path: &Path, archive_prefix: &str, + skip_relative_paths: &HashSet, ) -> std::io::Result<()> { use std::fs; @@ -809,6 +849,7 @@ fn add_directory_to_archive( dir: &Path, base_path: &Path, archive_prefix: &str, + skip_relative_paths: &HashSet, ) -> std::io::Result<()> { if dir.is_dir() { for entry in fs::read_dir(dir)? { @@ -818,6 +859,14 @@ fn add_directory_to_archive( std::io::Error::other(format!("Failed to strip prefix from {}", path.display())) })?; + if skip_relative_paths.contains(relative_path) { + tracing::debug!( + "Skipping {} during archive creation (engine plan)", + path.display() + ); + continue; + } + let archive_path = if archive_prefix.is_empty() { relative_path.to_path_buf() } else { @@ -834,7 +883,13 @@ fn add_directory_to_archive( } if metadata.is_dir() { - visit_dirs(archive, &path, base_path, archive_prefix)?; + visit_dirs( + archive, + &path, + base_path, + archive_prefix, + skip_relative_paths, + )?; } else if metadata.is_file() { archive.append_path_with_name(&path, &archive_path)?; } @@ -843,7 +898,13 @@ fn add_directory_to_archive( Ok(()) } - visit_dirs(archive, dir_path, dir_path, archive_prefix) + visit_dirs( + archive, + dir_path, + dir_path, + archive_prefix, + skip_relative_paths, + ) } #[cfg(test)] diff --git a/crates/runtime-acceleration/src/snapshot/engine.rs b/crates/runtime-acceleration/src/snapshot/engine.rs index f6670376ba..fecb3e626a 100644 --- a/crates/runtime-acceleration/src/snapshot/engine.rs +++ b/crates/runtime-acceleration/src/snapshot/engine.rs @@ -15,6 +15,7 @@ limitations under the License. 
use async_trait::async_trait; use snafu::prelude::*; +use std::collections::HashSet; use std::path::{Path, PathBuf}; use std::sync::Arc; @@ -25,23 +26,71 @@ mod duckdb; #[cfg(feature = "duckdb")] pub use duckdb::DuckDBSnapshotEngine; +#[cfg(any(feature = "sqlite", feature = "turso"))] +mod sqlite; +#[cfg(feature = "sqlite")] +pub use sqlite::SqliteSnapshotEngine; + +#[cfg(feature = "turso")] +mod turso; +#[cfg(feature = "turso")] +pub use turso::TursoSnapshotEngine; + #[derive(Debug, Snafu)] pub enum SnapshotEngineError { #[snafu(display("DuckDB snapshot error: {source}"))] #[cfg(feature = "duckdb")] DuckDB { source: duckdb::DuckDBSnapshotError }, - /// Placeholder variant for when no features are enabled + #[snafu(display("SQLite snapshot error: {source}"))] + #[cfg(any(feature = "sqlite", feature = "turso"))] + Sqlite { source: sqlite::SqliteSnapshotError }, + + /// Placeholder variant for when no snapshot-capable feature is enabled. #[snafu(display( - "No snapshot engine is available. Enable a snapshot engine feature (e.g., 'duckdb')." + "No snapshot engine is available. Enable a snapshot engine feature \ + (e.g., 'duckdb', 'sqlite', or 'turso')." ))] - #[cfg(not(feature = "duckdb"))] + #[cfg(not(any(feature = "duckdb", feature = "sqlite", feature = "turso")))] Generic, + + /// Open-ended variant used by engines that live outside `runtime-acceleration` + /// (e.g. `CayenneSnapshotEngine` in the runtime crate). The owning crate + /// formats its rich error to a string and wraps it here. + #[snafu(display("{message}"))] + Custom { message: String }, +} + +impl SnapshotEngineError { + /// Construct a [`SnapshotEngineError::Custom`] from anything that renders + /// to a string. Convenience for engines defined in downstream crates. + pub fn from_display(message: D) -> Self { + SnapshotEngineError::Custom { + message: message.to_string(), + } + } } /// Trait defining engine-specific snapshot operations. 
#[async_trait] pub trait SnapshotEngine: Send + Sync { + /// Hook invoked on the **live** accelerator file *before* it is copied to a + /// temporary snapshot location. Engines that buffer writes outside the + /// primary file (e.g. SQLite/Turso WAL) should checkpoint here so that the + /// subsequent `fs::copy` produces a self-contained file. + /// + /// Default implementation is a no-op. + /// + /// The caller holds the accelerator's write lock for the duration of this + /// call, so no concurrent writes are in flight. + async fn checkpoint_live( + &self, + _live_path: &Path, + _dataset_name: &str, + ) -> Result<(), SnapshotEngineError> { + Ok(()) + } + /// Prepares a snapshot file for upload. /// For engines that support compaction (e.g., `DuckDB`), this may compact the file. /// For other engines, this returns the source path unchanged. @@ -60,6 +109,72 @@ pub trait SnapshotEngine: Send + Sync { /// Returns whether this engine supports compaction. fn supports_compaction(&self) -> bool; + + /// Hook invoked by `SnapshotManager` *before* archiving a directory-layout + /// snapshot. Returns a [`DirectorySnapshotPlan`] that controls which files + /// are skipped from the source directories and which extra in-memory + /// entries are added to the archive. + /// + /// Default implementation includes everything, adds nothing. + /// + /// `dirs` is `(local_directory, archive_prefix)` pairs as passed to the + /// archive layer. `dataset_name` is the name of the dataset whose snapshot + /// is being created. + async fn prepare_directory_snapshot( + &self, + dirs: &[(PathBuf, String)], + dataset_name: &str, + ) -> Result { + let _ = (dirs, dataset_name); + Ok(DirectorySnapshotPlan::default()) + } + + /// Hook invoked by `SnapshotManager` *after* extracting a directory-layout + /// snapshot. Allows engines to perform engine-specific post-processing + /// (e.g. import a metastore slice that was written into one of the + /// extracted directories at upload time). 
+ /// + /// `dirs` is the same `(local_directory, archive_prefix)` pairs supplied + /// to the download path. The engine should locate any virtual entries it + /// emitted from `prepare_directory_snapshot` by their well-known archive + /// paths within `dirs` (the upload-time `extras` list cannot be passed + /// across the upload → download boundary). + /// + /// Default implementation is a no-op. + async fn finalize_directory_snapshot( + &self, + dirs: &[(PathBuf, String)], + dataset_name: &str, + ) -> Result<(), SnapshotEngineError> { + let _ = (dirs, dataset_name); + Ok(()) + } +} + +/// A virtual entry to be added to a directory-snapshot tar archive that does +/// not come from the on-disk source directories. +#[derive(Debug, Clone)] +pub struct DirectoryArchiveExtra { + /// Path within the tar archive (e.g. `"metastore/slice.json"`). Must not + /// collide with a file produced by walking the source directories. + pub archive_path: String, + /// Raw bytes of the entry. + pub bytes: Vec, +} + +/// Engine-supplied plan that controls how a directory-layout snapshot is +/// archived (creation side) and what extras the corresponding extract-side +/// hook should expect to find. +#[derive(Debug, Clone, Default)] +pub struct DirectorySnapshotPlan { + /// Filenames (relative to each `dirs[i].0`) that must be excluded from + /// the archive. Engines use this to drop files they intend to replace + /// (e.g. Cayenne drops `cayenne.db*` because the metastore is captured + /// as a JSON slice instead). + pub skip_relative_paths: HashSet, + /// Extra in-memory entries to add to the archive after the on-disk + /// directory contents are written. + pub extra_entries: Vec, } /// Default snapshot engine for engines that don't require special preparation. @@ -81,6 +196,7 @@ impl SnapshotEngine for DefaultSnapshotEngine { } /// Creates a snapshot engine for the given acceleration engine. 
+#[must_use] pub fn create_snapshot_engine( engine: &AccelerationEngine, #[cfg(feature = "duckdb")] compaction_enabled: bool, @@ -95,9 +211,9 @@ pub fn create_snapshot_engine( Arc::new(DuckDBSnapshotEngine::new(compaction_enabled)) } #[cfg(feature = "sqlite")] - AccelerationEngine::Sqlite => Arc::new(DefaultSnapshotEngine), + AccelerationEngine::Sqlite => Arc::new(SqliteSnapshotEngine::new()), #[cfg(feature = "turso")] - AccelerationEngine::Turso => Arc::new(DefaultSnapshotEngine), + AccelerationEngine::Turso => Arc::new(TursoSnapshotEngine::new()), AccelerationEngine::Cayenne => Arc::new(DefaultSnapshotEngine), } } diff --git a/crates/runtime-acceleration/src/snapshot/engine/duckdb.rs b/crates/runtime-acceleration/src/snapshot/engine/duckdb.rs index 8f04f667f4..daadf4c0f0 100644 --- a/crates/runtime-acceleration/src/snapshot/engine/duckdb.rs +++ b/crates/runtime-acceleration/src/snapshot/engine/duckdb.rs @@ -50,6 +50,7 @@ pub struct DuckDBSnapshotEngine { } impl DuckDBSnapshotEngine { + #[must_use] pub fn new(compaction_enabled: bool) -> Self { Self { compaction_enabled } } diff --git a/crates/runtime-acceleration/src/snapshot/engine/sqlite.rs b/crates/runtime-acceleration/src/snapshot/engine/sqlite.rs new file mode 100644 index 0000000000..46614d7084 --- /dev/null +++ b/crates/runtime-acceleration/src/snapshot/engine/sqlite.rs @@ -0,0 +1,253 @@ +/* +Copyright 2026 The Spice.ai OSS Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + https://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +//! 
`SQLite`-specific snapshot engine implementation. +//! +//! `SQLite` accelerator databases run in WAL (write-ahead log) journal mode. +//! In WAL mode, writes are buffered into a `-wal` sidecar file and only +//! periodically checkpointed back into the main `.sqlite` file. A naive +//! `fs::copy` of just the main file therefore captures only the durable +//! pages (often just the 4 KB header on a freshly-written DB) and loses +//! every uncheckpointed write. +//! +//! The `SqliteSnapshotEngine` addresses this by: +//! 1. **`checkpoint_live`** — under the accelerator write lock, opens a +//! short-lived rusqlite connection to the live DB and runs a full +//! `wal_checkpoint(TRUNCATE)`. This is safe under WAL mode (multiple +//! connections are supported) and the existing accelerator write lock +//! held by the caller guarantees no other writers race us. +//! 2. **`prepare_for_upload`** — defensively switches the *copied* file +//! to `journal_mode=DELETE` so the uploaded snapshot has no `-wal` +//! sidecar at all and is fully self-contained. + +use async_trait::async_trait; +use snafu::prelude::*; +use std::path::{Path, PathBuf}; + +use super::SnapshotEngine; + +#[derive(Debug, Snafu)] +pub enum SqliteSnapshotError { + #[snafu(display("Failed to open SQLite for snapshot preparation: {path:?}"))] + Connect { + path: PathBuf, + source: rusqlite::Error, + }, + #[snafu(display( + "Failed to checkpoint SQLite WAL for dataset '{dataset}' at {path:?}: {source}" + ))] + Checkpoint { + dataset: String, + path: PathBuf, + source: rusqlite::Error, + }, + #[snafu(display( + "Incomplete WAL checkpoint for dataset '{dataset}' at {path:?}: \ + busy={busy}, log_frames={log_frames}, checkpointed_frames={checkpointed_frames}. \ + Another connection is holding the WAL or not all frames were flushed; \ + snapshotting now would lose data." 
+ ))] + CheckpointIncomplete { + dataset: String, + path: PathBuf, + busy: i64, + log_frames: i64, + checkpointed_frames: i64, + }, + #[snafu(display( + "Failed to switch SQLite copy to journal_mode=DELETE for dataset '{dataset}' at {path:?}: {source}" + ))] + JournalMode { + dataset: String, + path: PathBuf, + source: rusqlite::Error, + }, + #[snafu(display( + "SQLite snapshot preparation task failed unexpectedly for dataset '{dataset}'" + ))] + JoinError { + dataset: String, + source: tokio::task::JoinError, + }, +} + +pub struct SqliteSnapshotEngine; + +impl SqliteSnapshotEngine { + #[must_use] + pub fn new() -> Self { + Self + } +} + +impl Default for SqliteSnapshotEngine { + fn default() -> Self { + Self::new() + } +} + +#[async_trait] +impl SnapshotEngine for SqliteSnapshotEngine { + async fn checkpoint_live( + &self, + live_path: &Path, + dataset_name: &str, + ) -> Result<(), super::SnapshotEngineError> { + let live_path = live_path.to_path_buf(); + let dataset = dataset_name.to_string(); + tokio::task::spawn_blocking(move || { + let conn = rusqlite::Connection::open(&live_path).context(ConnectSnafu { + path: live_path.clone(), + })?; + // wal_checkpoint(TRUNCATE) forces all WAL frames into the main + // database file and truncates the WAL to zero length. This is the + // strongest available checkpoint short of switching journal mode. + // + // The pragma returns one row `(busy, log, checkpointed)`: + // * `busy != 0` means another connection (e.g. a stuck + // read-transaction) is holding the WAL and the truncation + // could not complete. + // * `checkpointed < log` means not every frame was flushed. + // + // Either case would let post-checkpoint writes leak past the copy + // we're about to take, defeating the whole point of the hook. We + // surface them as `CheckpointIncomplete` errors so the caller can either + // retry or fall back rather than silently snapshot a corrupted + // database. 
+ let (busy, log_frames, checkpointed_frames): (i64, i64, i64) = conn + .query_row("PRAGMA wal_checkpoint(TRUNCATE)", [], |row| { + Ok((row.get(0)?, row.get(1)?, row.get(2)?)) + }) + .context(CheckpointSnafu { + dataset: dataset.clone(), + path: live_path.clone(), + })?; + if busy != 0 || checkpointed_frames < log_frames { + return Err(SqliteSnapshotError::CheckpointIncomplete { + dataset: dataset.clone(), + path: live_path, + busy, + log_frames, + checkpointed_frames, + }); + } + Ok::<(), SqliteSnapshotError>(()) + }) + .await + .context(JoinSnafu { + dataset: dataset_name.to_string(), + }) + .map_err(|e| super::SnapshotEngineError::Sqlite { source: e })? + .map_err(|e| super::SnapshotEngineError::Sqlite { source: e }) + } + + async fn prepare_for_upload( + &self, + source_path: &Path, + dataset_name: &str, + ) -> Result<PathBuf, super::SnapshotEngineError> { + // The caller has already done `fs::copy(live_db, source_path)` after + // `checkpoint_live` flushed the WAL. The copy should already be + // self-contained, but as defense-in-depth we switch the copy to + // `journal_mode=DELETE` to guarantee no `-wal`/`-shm` sidecars exist + // adjacent to the file we're about to upload. + let copy_path = source_path.to_path_buf(); + let dataset = dataset_name.to_string(); + let path = tokio::task::spawn_blocking(move || { + let conn = rusqlite::Connection::open(&copy_path).context(ConnectSnafu { + path: copy_path.clone(), + })?; + // Switching to DELETE mode forces a final checkpoint and removes + // any -wal/-shm files. If the copy never had a WAL (because the + // live checkpoint already truncated it), this is a no-op. 
+ conn.query_row("PRAGMA journal_mode=DELETE", [], |_row| Ok(())) + .context(JournalModeSnafu { + dataset, + path: copy_path.clone(), + })?; + Ok::<PathBuf, SqliteSnapshotError>(copy_path) + }) + .await + .context(JoinSnafu { + dataset: dataset_name.to_string(), + }) + .map_err(|e| super::SnapshotEngineError::Sqlite { source: e })? 
+ .map_err(|e| super::SnapshotEngineError::Sqlite { source: e })?; + Ok(path) + } + + fn supports_compaction(&self) -> bool { + false + } +} + +#[cfg(test)] +mod tests { + use super::*; + use rusqlite::Connection; + use tempfile::TempDir; + + fn create_wal_mode_db_with_rows(tmp: &TempDir, rows: &[(i64, &str)]) -> PathBuf { + let db_path = tmp.path().join("live.sqlite"); + let conn = Connection::open(&db_path).expect("open"); + // Force WAL mode and write rows; do NOT checkpoint, so the main + // file remains nearly-empty and the data lives only in the WAL. + conn.query_row("PRAGMA journal_mode=WAL", [], |_| Ok(())) + .expect("wal"); + conn.execute("CREATE TABLE t(id INTEGER PRIMARY KEY, name TEXT)", []) + .expect("create"); + for (id, name) in rows { + conn.execute("INSERT INTO t(id, name) VALUES (?1, ?2)", (id, name)) + .expect("insert"); + } + // Drop without explicit checkpoint -- keep WAL state intact. + drop(conn); + db_path + } + + fn count_rows(path: &Path) -> i64 { + let conn = Connection::open(path).expect("open verify"); + conn.query_row("SELECT COUNT(*) FROM t", [], |row| row.get(0)) + .expect("count") + } + + #[tokio::test] + async fn checkpoint_live_then_copy_captures_all_rows() { + let tmp = TempDir::new().expect("tmp"); + let rows = vec![(1, "a"), (2, "b"), (3, "c")]; + let live = create_wal_mode_db_with_rows(&tmp, &rows); + + let engine = SqliteSnapshotEngine::new(); + engine + .checkpoint_live(&live, "ds") + .await + .expect("checkpoint live"); + + let copy = tmp.path().join("copy.sqlite"); + std::fs::copy(&live, &copy).expect("copy"); + + let final_path = engine + .prepare_for_upload(&copy, "ds") + .await + .expect("prepare"); + + // No -wal or -shm should exist next to the prepared file. 
+ assert!(!final_path.with_extension("sqlite-wal").exists()); + assert!(!final_path.with_extension("sqlite-shm").exists()); + + assert_eq!( + count_rows(&final_path), + i64::try_from(rows.len()).expect("row count fits in i64") + ); + } +} diff --git a/crates/runtime-acceleration/src/snapshot/engine/turso.rs b/crates/runtime-acceleration/src/snapshot/engine/turso.rs new file mode 100644 index 0000000000..c921b4bc2c --- /dev/null +++ b/crates/runtime-acceleration/src/snapshot/engine/turso.rs @@ -0,0 +1,81 @@ +/* +Copyright 2026 The Spice.ai OSS Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + https://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +//! Turso (libsql) snapshot engine implementation. +//! +//! **Status:** WAL flush is currently a no-op for Turso. The first attempt +//! reused [`super::sqlite::SqliteSnapshotEngine`] under the assumption that +//! libsql's on-disk format is byte-compatible with classic `SQLite` for +//! read-only opens via rusqlite. In practice the integration test surfaced +//! `"file is not a database"` from rusqlite when opening a libsql primary +//! file. Tracked in spiceai/spiceai#10657. +//! +//! Until that issue is fixed by routing the checkpoint pragma through a +//! turso/libsql-native connection, snapshot creation for Turso accelerators +//! falls back to the default behavior: `fs::copy` of the live file as-is, +//! with the same WAL-loss caveat that #10643 originally documented. +//! `refresh_mode: snapshot` against Turso is therefore disabled for now +//! (see `tests/snapshot_refresh/turso.rs`). 
+ +use async_trait::async_trait; +use std::path::{Path, PathBuf}; + +use super::SnapshotEngine; + +/// Snapshot engine for Turso accelerators. +/// +/// Currently a no-op (defers WAL flushing to a future libsql-native +/// implementation; see spiceai/spiceai#10657). The struct exists so that +/// `create_snapshot_engine` can return a stable, Turso-specific type and +/// the call sites stay symmetric with the other engine implementations. +pub struct TursoSnapshotEngine; + +impl TursoSnapshotEngine { + #[must_use] + pub fn new() -> Self { + Self + } +} + +impl Default for TursoSnapshotEngine { + fn default() -> Self { + Self::new() + } +} + +#[async_trait] +impl SnapshotEngine for TursoSnapshotEngine { + async fn checkpoint_live( + &self, + _live_path: &Path, + _dataset_name: &str, + ) -> Result<(), super::SnapshotEngineError> { + // No-op: see spiceai/spiceai#10657. Returning Ok here matches + // historical Turso behavior; once the issue is fixed this will route + // through a libsql-native checkpoint. + Ok(()) + } + + async fn prepare_for_upload( + &self, + source_path: &Path, + _dataset_name: &str, + ) -> Result { + // No-op: pass the copy through unchanged. See spiceai/spiceai#10657. + Ok(source_path.to_path_buf()) + } + + fn supports_compaction(&self) -> bool { + false + } +} diff --git a/crates/runtime-acceleration/src/snapshot/mod.rs b/crates/runtime-acceleration/src/snapshot/mod.rs index f4ad7c8688..9522aea8b5 100644 --- a/crates/runtime-acceleration/src/snapshot/mod.rs +++ b/crates/runtime-acceleration/src/snapshot/mod.rs @@ -57,7 +57,7 @@ use crate::dataset_checkpoint::DatasetCheckpointerFactory; mod behavior; pub mod directory_archive; -mod engine; +pub mod engine; pub mod metrics; pub use crate::layout::AccelerationLayout; pub use behavior::{SNAPSHOTS_ENTERPRISE_ONLY_MESSAGE, SnapshotBehavior}; @@ -221,6 +221,8 @@ impl DatasetMetadata { /// Details captured when downloading a snapshot for bootstrapping. 
#[derive(Debug, Clone, PartialEq, Eq)] pub struct SnapshotDownloadInfo { + /// The snapshot id from the snapshot metadata that was downloaded. + pub snapshot_id: u64, pub schema: SchemaRef, pub bytes_downloaded: u64, pub checksum: String, @@ -333,6 +335,10 @@ pub enum SnapshotDownloadError { CheckpointerSchema { source: Box, }, + #[snafu(display("Failed to bootstrap dataset checkpoint from snapshot metadata: {source}"))] + CheckpointerBootstrap { + source: Box, + }, #[snafu(display("Snapshot {path} is missing a schema in its checkpoint"))] MissingSchema { path: String }, #[snafu(display("Snapshot schema for dataset {dataset} is missing in metadata"))] @@ -923,6 +929,15 @@ impl SnapshotManager { self } + /// Replaces the snapshot engine. Used by accelerators (notably Cayenne) + /// that need engine-specific archive create/extract behavior beyond what + /// `create_snapshot_engine` produces from `AccelerationEngine` alone. + #[must_use] + pub fn with_snapshot_engine(mut self, engine: Arc) -> Self { + self.snapshot_engine = engine; + self + } + /// Sets the policy for snapshot creation. #[must_use] pub fn with_snapshots_creation_policy( @@ -944,6 +959,136 @@ impl SnapshotManager { schema_meta.to_schema_ref().ok() } + /// Returns the `current_snapshot_id` from the remote snapshot metadata for this + /// dataset, if any. Returns `None` when there is no metadata, no entry for the + /// dataset, or the dataset has no current snapshot. + /// + /// # Errors + /// + /// Returns an error if reading or parsing the snapshot metadata from the + /// object store fails. Callers (e.g. snapshot-mode refresh) can use this + /// to react to transient object-store failures. 
+ pub async fn remote_current_snapshot_id(&self) -> Result, SnapshotDownloadError> { + let handle = self.load_metadata().await.map_err(|e| match e { + MetadataLoadError::Read { path, source } => { + SnapshotDownloadError::ReadMetadata { path, source } + } + MetadataLoadError::Parse { path, source } => { + SnapshotDownloadError::ParseMetadata { path, source } + } + MetadataLoadError::UnsupportedVersion { path, version } => { + SnapshotDownloadError::UnsupportedMetadataVersion { path, version } + } + })?; + let Some(handle) = handle else { + return Ok(None); + }; + let Some(dataset_entry) = handle.metadata.datasets.get(&self.dataset_name) else { + return Ok(None); + }; + Ok(dataset_entry.current_snapshot_id) + } + + /// Downloads the latest snapshot only if its `snapshot_id` is strictly + /// greater than `current_local_id`. When the remote `current_snapshot_id` + /// is less than or equal to `current_local_id` (matching id, or remote + /// metadata rolled back), returns `Ok(None)` without touching local files. + /// + /// This is the primary entry point for `refresh_mode: snapshot`, which polls + /// the snapshot store on a fixed cadence and only reloads the accelerator + /// when a strictly newer snapshot is available. Snapshot mode never + /// regresses the accelerator to an older snapshot id. + /// + /// # Errors + /// + /// Same errors as [`SnapshotManager::download_latest_snapshot`]. + /// Download the latest snapshot if it is strictly newer than + /// `current_local_id`. The optional `validate_schema` callback is given + /// the snapshot metadata's recorded schema **before** any bytes are + /// downloaded or written to disk; if it returns false a schema + /// mismatch is reported and the accelerator's primary file is left + /// untouched. + /// + /// # Errors + /// + /// Same errors as [`SnapshotManager::download_latest_snapshot`], plus + /// [`SnapshotDownloadError::SchemaMismatch`] when `validate_schema` + /// rejects the metadata schema. 
+ pub async fn download_if_newer( + &self, + current_local_id: Option, + validate_schema: Option<&(dyn Fn(&SchemaRef) -> bool + Send + Sync)>, + ) -> Result, SnapshotDownloadError> { + let Some(remote_id) = self.remote_current_snapshot_id().await? else { + return Ok(None); + }; + match current_local_id { + Some(local_id) if remote_id <= local_id => { + if remote_id < local_id { + tracing::warn!( + dataset = %self.dataset_name, + remote_snapshot_id = remote_id, + local_snapshot_id = local_id, + "snapshot metadata current id is older than the locally loaded snapshot; \ + skipping reload to avoid regression" + ); + } + Ok(None) + } + _ => { + // Inspect the metadata-recorded schema first, before + // touching the file: an incompatible snapshot must never + // overwrite the accelerator's current primary file. If a + // validator is provided we require the remote metadata to + // expose a parseable schema for this dataset — missing or + // malformed schema metadata is treated as a validation + // failure rather than silently skipped. 
+ if let Some(validate) = validate_schema { + let handle = self.load_metadata().await.map_err(|e| match e { + MetadataLoadError::Read { path, source } => { + SnapshotDownloadError::ReadMetadata { path, source } + } + MetadataLoadError::Parse { path, source } => { + SnapshotDownloadError::ParseMetadata { path, source } + } + MetadataLoadError::UnsupportedVersion { path, version } => { + SnapshotDownloadError::UnsupportedMetadataVersion { path, version } + } + })?; + let handle = handle.ok_or_else(|| SnapshotDownloadError::SchemaMismatch { + dataset: self.dataset_name.clone(), + })?; + let dataset_metadata = handle + .metadata + .datasets + .get(&self.dataset_name) + .ok_or_else(|| SnapshotDownloadError::SchemaMismatch { + dataset: self.dataset_name.clone(), + })?; + let metadata_schema = dataset_metadata.current_schema().ok_or_else(|| { + SnapshotDownloadError::SchemaMismatch { + dataset: self.dataset_name.clone(), + } + })?; + let metadata_schema_ref = + metadata_schema.to_schema_ref().map_err(|source| { + SnapshotDownloadError::MetadataSchemaDeserialize { + dataset: self.dataset_name.clone(), + source, + } + })?; + if !validate(&metadata_schema_ref) { + return Err(SnapshotDownloadError::SchemaMismatch { + dataset: self.dataset_name.clone(), + }); + } + } + + self.download_latest_snapshot().await + } + } + } + /// Creates a new snapshot by streaming the local acceleration file to object storage. /// /// For file-based accelerators (`DuckDB`, `SQLite`), this copies and uploads the database file. @@ -1080,6 +1225,15 @@ impl SnapshotManager { destination_location: &ObjectPath, lock_guard: OwnedMutexGuard<()>, ) -> Result<(u64, String), SnapshotUploadError> { + // Step 0: Engine-specific live checkpoint while the lock is held. + // For SQLite/Turso this drains the WAL into the main file so that + // the subsequent `fs::copy` produces a self-contained snapshot. + // Default (no-op) for engines without WAL. 
+ self.snapshot_engine + .checkpoint_live(source_local_path, &self.dataset_name) + .await + .context(PrepareUploadSnafu)?; + // Step 1: Copy the database file locally (lock is held) let temp_copy_path = source_local_path.with_extension("snapshot_tmp"); fs::copy(source_local_path, &temp_copy_path) @@ -1141,7 +1295,20 @@ impl SnapshotManager { destination_location: &ObjectPath, lock_guard: OwnedMutexGuard<()>, ) -> Result<(u64, String), SnapshotUploadError> { - use crate::snapshot::directory_archive::archive_directories; + use crate::snapshot::directory_archive::archive_directories_with_plan; + + // Step 0: Ask the engine for any per-directory skip list / extras. + let plan = self + .snapshot_engine + .prepare_directory_snapshot(dirs, &self.dataset_name) + .await + .map_err(|source| SnapshotUploadError::PrepareUpload { source })?; + let skip_paths: Vec = plan.skip_relative_paths.into_iter().collect(); + let extras: Vec<(String, Vec)> = plan + .extra_entries + .into_iter() + .map(|e| (e.archive_path, e.bytes)) + .collect(); // Step 1: Create a temporary tar archive of all directories let temp_archive_path = std::env::temp_dir().join(format!( @@ -1158,12 +1325,13 @@ impl SnapshotManager { source, })?; - let total_archived = archive_directories(dirs, archive_file) - .await - .map_err(|source| SnapshotUploadError::ArchiveCreate { - path: temp_archive_path.clone(), - source: std::io::Error::other(source.to_string()), - })?; + let total_archived = + archive_directories_with_plan(dirs, archive_file, &skip_paths, &extras) + .await + .map_err(|source| SnapshotUploadError::ArchiveCreate { + path: temp_archive_path.clone(), + source: std::io::Error::other(source.to_string()), + })?; tracing::debug!( "Created tar archive for snapshot. 
dataset={} archive_size={}", @@ -1673,43 +1841,67 @@ impl SnapshotManager { source, })?; - if let Some(schema) = checkpointer + let local_schema = checkpointer .get_schema() .await - .map_err(|source| SnapshotDownloadError::CheckpointerSchema { source })? - { + .map_err(|source| SnapshotDownloadError::CheckpointerSchema { source })?; + let final_schema = if let Some(schema) = local_schema { if schema.as_ref() != metadata_schema.as_ref() { return Err(SnapshotDownloadError::SchemaMismatch { dataset: self.dataset_name.clone(), }); } - - let local_path_display = self - .layout - .primary_path() - .map_or_else(|| "".to_string(), |p| p.display().to_string()); - tracing::info!( - dataset = %self.dataset_name, - snapshot = %entry.snapshot, - size_bytes = actual_size, - sha = %actual_checksum, - "Snapshot restored to {local_path_display}" - ); - Ok(SnapshotDownloadInfo { - schema, - bytes_downloaded: actual_size, - checksum: actual_checksum, - last_updated_at: entry.snapshot_last_updated_at_ms, - }) + schema } else { - tracing::warn!( + // The downloaded snapshot didn't carry a populated + // `_dataset_checkpoint` row. This is the steady-state for + // engines whose archive doesn't ship the spice_sys checkpoint + // alongside the live data — most notably Cayenne, where the + // archive ships the per-dataset metastore slice instead of + // the raw `cayenne.db`. The metadata-recorded schema has + // already been verified to round-trip via `to_schema_ref()`, + // so we can safely materialize it into the checkpoint table + // to bring the local DB in line with the snapshot. + // + // For engines whose snapshot *does* normally include the + // checkpoint row (DuckDB / SQLite / Turso), this branch is + // unreachable in practice; if a corrupted snapshot reaches + // it the self-heal is harmless because we trust the metadata + // schema (it's the same one we just validated above for + // any other branch). + // + // Closes spiceai/spiceai#10658. 
+ tracing::debug!( dataset = %self.dataset_name, snapshot = %entry.snapshot, sha = %entry.snapshot_checksum, - "Snapshot schema not found" + "Bootstrapping dataset checkpoint from snapshot metadata" ); - Err(SnapshotDownloadError::MissingSchema { path: path_display }) - } + checkpointer + .checkpoint(&metadata_schema, None) + .await + .map_err(|source| SnapshotDownloadError::CheckpointerBootstrap { source })?; + metadata_schema + }; + + let local_path_display = self + .layout + .primary_path() + .map_or_else(|| "".to_string(), |p| p.display().to_string()); + tracing::info!( + dataset = %self.dataset_name, + snapshot = %entry.snapshot, + size_bytes = actual_size, + sha = %actual_checksum, + "Snapshot restored to {local_path_display}" + ); + Ok(SnapshotDownloadInfo { + snapshot_id: entry.snapshot_id, + schema: final_schema, + bytes_downloaded: actual_size, + checksum: actual_checksum, + last_updated_at: entry.snapshot_last_updated_at_ms, + }) } /// Downloads a snapshot directly to a single file (for file-based accelerators). @@ -1730,9 +1922,28 @@ impl SnapshotManager { } let mut stream = get_result.into_stream(); - let mut file = fs::File::create(local_path).await.map_err(|source| { + + // Write to a sibling temp file then atomically rename into the primary + // path. This guarantees concurrent readers (e.g. an active accelerator + // running in `refresh_mode: snapshot`) never observe a half-written or + // truncated file: they either see the previous complete snapshot or + // the new complete snapshot. + let temp_path = match local_path.file_name() { + Some(name) => { + let mut tmp_name = std::ffi::OsString::from("."); + tmp_name.push(name); + tmp_name.push(".download.tmp"); + local_path.with_file_name(tmp_name) + } + None => local_path.with_extension("download.tmp"), + }; + // Best-effort cleanup of any leftover temp file from a previous failed + // download; ignore errors (e.g. file does not exist). 
+ let _ = fs::remove_file(&temp_path).await; + + let mut file = fs::File::create(&temp_path).await.map_err(|source| { SnapshotDownloadError::WriteLocal { - path: local_path.clone(), + path: temp_path.clone(), source, } })?; @@ -1744,7 +1955,7 @@ impl SnapshotManager { let chunk = match chunk_result { Ok(chunk) => chunk, Err(source) => { - let _ = fs::remove_file(local_path).await; + let _ = fs::remove_file(&temp_path).await; return Err(SnapshotDownloadError::DownloadBytes { path: path_display.to_string(), source, @@ -1756,26 +1967,110 @@ impl SnapshotManager { hasher.update(&chunk); if let Err(source) = file.write_all(&chunk).await { - let _ = fs::remove_file(local_path).await; + let _ = fs::remove_file(&temp_path).await; return Err(SnapshotDownloadError::WriteLocal { - path: local_path.clone(), + path: temp_path.clone(), source, }); } } if let Err(source) = file.flush().await { - let _ = fs::remove_file(local_path).await; + let _ = fs::remove_file(&temp_path).await; + return Err(SnapshotDownloadError::WriteLocal { + path: temp_path.clone(), + source, + }); + } + // fsync the temp file before rename so the downloaded bytes are + // durable in the temp file. Crash durability of the renamed primary + // path also requires syncing the parent directory after the rename; + // we do that below once the rename has succeeded. + if let Err(source) = file.sync_all().await { + let _ = fs::remove_file(&temp_path).await; return Err(SnapshotDownloadError::WriteLocal { - path: local_path.clone(), + path: temp_path.clone(), source, }); } drop(file); - let actual_checksum = self - .validate_snapshot(entry, actual_size, hasher, local_path, path_display) - .await?; + // Validate before renaming so a corrupt download never replaces the + // current good file at the primary path. 
+ let actual_checksum = match self + .validate_snapshot(entry, actual_size, hasher, &temp_path, path_display) + .await + { + Ok(checksum) => checksum, + Err(e) => { + let _ = fs::remove_file(&temp_path).await; + return Err(e); + } + }; + + // `tokio::fs::rename` defers to `std::fs::rename`, which on every + // tier-1 platform we ship for performs a destination-replacing move: + // + // - Unix: `rename(2)` atomically replaces an existing file at the + // destination on the same filesystem. + // - Windows: `MoveFileExW` is invoked with `MOVEFILE_REPLACE_EXISTING` + // so the destination is replaced when present (subject to the + // usual Windows constraint that no other process/handle holds the + // destination open with sharing flags that disallow replacement). + // + // If the platform-level rename fails with `AlreadyExists` (older or + // unusual Windows configurations where the replace flag was not + // honored), fall back to a swap-via-sidecar dance: rename the + // existing target to a `.old` sidecar (atomic), rename `temp_path` + // into place (atomic), then best-effort delete the sidecar. Unlike + // the direct rename this is not gap-free: between the two renames + // `local_path` has no entry. Keeping that window short matters for + // `refresh_mode: snapshot` because the accelerator's pool may still + // be holding readers open against `local_path` in the gap before + // `reload_from_snapshot` evicts them.
+ if let Err(source) = fs::rename(&temp_path, local_path).await { + if source.kind() == std::io::ErrorKind::AlreadyExists { + let sidecar_path = local_path.with_extension(format!("old.{}", std::process::id())); + if let Err(swap_err) = fs::rename(local_path, &sidecar_path).await { + let _ = fs::remove_file(&temp_path).await; + return Err(SnapshotDownloadError::WriteLocal { + path: local_path.clone(), + source: swap_err, + }); + } + if let Err(retry_err) = fs::rename(&temp_path, local_path).await { + // Restore the original to avoid leaving the dataset + // pointing at a missing file. + let _ = fs::rename(&sidecar_path, local_path).await; + let _ = fs::remove_file(&temp_path).await; + return Err(SnapshotDownloadError::WriteLocal { + path: local_path.clone(), + source: retry_err, + }); + } + // Best-effort cleanup; the sidecar will be reaped on next + // restart if this fails (e.g. another process still has it + // open on Windows). + let _ = fs::remove_file(&sidecar_path).await; + } else { + let _ = fs::remove_file(&temp_path).await; + return Err(SnapshotDownloadError::WriteLocal { + path: local_path.clone(), + source, + }); + } + } + + // Best-effort fsync of the parent directory so the rename's directory + // entry update is durable across a crash. POSIX requires this in + // addition to fsync of the file itself; on platforms where directory + // fsync is not supported (or returns EINVAL/ENOTDIR) we silently fall + // back to relying on the rename's own metadata journaling. 
+ if let Some(parent) = local_path.parent() + && let Ok(dir) = fs::File::open(parent).await + { + let _ = dir.sync_all().await; + } Ok((actual_size, actual_checksum)) } @@ -1906,6 +2201,18 @@ impl SnapshotManager { source: std::io::Error::other(source.to_string()), })?; + if let Err(source) = self + .snapshot_engine + .finalize_directory_snapshot(dirs, &self.dataset_name) + .await + { + let _ = fs::remove_file(&temp_archive_path).await; + return Err(SnapshotDownloadError::ArchiveExtract { + path: temp_archive_path.clone(), + source: std::io::Error::other(format!("engine finalize failed: {source}")), + }); + } + // Cleanup temp archive let _ = fs::remove_file(&temp_archive_path).await; @@ -2733,6 +3040,28 @@ mod tests { )])) } + /// Writes a sample local accelerator file appropriate for the engine. + /// For `SQLite`/`Turso`, creates a real (empty) `SQLite` WAL-mode database + /// so that the engine's `checkpoint_live` hook can open it. For other + /// engines, writes opaque test bytes since no engine-side validation + /// runs against the file pre-snapshot. + fn write_sample_local_db(path: &std::path::Path, engine: &AccelerationEngine) { + match engine { + #[cfg(any(feature = "sqlite", feature = "turso"))] + AccelerationEngine::Sqlite | AccelerationEngine::Turso => { + let conn = rusqlite::Connection::open(path).expect("open sample sqlite db"); + conn.query_row("PRAGMA journal_mode=WAL", [], |_| Ok(())) + .expect("set wal"); + conn.execute("CREATE TABLE sample(id INTEGER PRIMARY KEY)", []) + .expect("create sample table"); + drop(conn); + } + _ => { + std::fs::write(path, b"test snapshot content").expect("write test file"); + } + } + } + /// Builds a `SnapshotManager` for the specified engine type. /// This enables testing snapshot functionality across all accelerator backends. 
fn build_manager_for_engine( @@ -2915,12 +3244,121 @@ mod tests { assert_eq!(info.schema.as_ref(), schema.as_ref()); assert_eq!(info.bytes_downloaded, contents.len() as u64); assert_eq!(info.checksum, checksum); + assert_eq!(info.snapshot_id, 0); let downloaded = fs::read(&local_path) .await .expect("read downloaded snapshot"); assert_eq!(downloaded.as_slice(), contents.as_ref()); } + #[tokio::test] + #[cfg(feature = "duckdb")] + async fn download_if_newer_returns_none_when_local_id_matches() { + let store = Arc::new(InMemory::new()); + let base = Path::from(SNAPSHOT_BASE_PATH); + let layout = SnapshotPathLayout::new(DATASET_NAME, &AccelerationEngine::DuckDB); + let instant = Utc + .with_ymd_and_hms(2025, 1, 2, 3, 4, 5) + .single() + .expect("valid time"); + let location = layout.build_location(&base, instant); + let contents = Bytes::from_static(b"snapshot-bytes"); + store + .put(&location, contents.clone().into()) + .await + .expect("write snapshot"); + let entry = SnapshotEntry { + snapshot_id: 7, + timestamp_ms: instant.timestamp_millis(), + snapshot: snapshot_uri(&location), + snapshot_checksum: compute_sha256_hex(contents.as_ref()), + snapshot_checksum_algorithm: SNAPSHOT_CHECKSUM_ALGORITHM.to_string(), + snapshot_size: contents.len() as u64, + snapshot_engine: None, + snapshot_row_count: None, + snapshot_last_updated_at_ms: None, + }; + let schema = sample_schema(); + let metadata = SnapshotMetadata { + format_version: SNAPSHOT_METADATA_FORMAT_VERSION, + location: SNAPSHOT_URI_PREFIX.to_string(), + last_updated_ms: Utc::now().timestamp_millis(), + datasets: HashMap::from([( + DATASET_NAME.to_string(), + dataset_metadata(&schema, vec![entry], Some(7)), + )]), + }; + let metadata_path = base.child(METADATA_FILE_NAME); + write_metadata(&store, &metadata_path, &metadata).await; + + let temp_dir = TempDir::new().expect("create temp dir"); + let local_path = temp_dir.path().join("snapshot.db"); + let manager = build_manager( + Arc::clone(&store), + 
+ local_path.clone(), + BootstrapOnFailureBehavior::Warn, + &schema, + false, + ); + + let result = manager + .download_if_newer(Some(7), None) + .await + .expect("download_if_newer should succeed"); + assert!(result.is_none(), "matching ids must not download"); + assert!( + !local_path.exists(), + "local file must not be written when nothing is newer" + ); + + // A local id strictly ahead of the remote must NOT trigger a download + // of an *older* remote snapshot (regression-safety): only a strictly + // newer remote snapshot causes a reload. Here the remote current id + // is 7 and the local id we claim is 8, so this must be a no-op. + let result = manager + .download_if_newer(Some(8), None) + .await + .expect("download_if_newer should succeed"); + assert!( + result.is_none(), + "local id ahead of remote must not regress" + ); + assert!( + !local_path.exists(), + "local file must not be written when remote is older" + ); + + // A strictly older local id (6) than the remote (7) should download.
+ let info = manager + .download_if_newer(Some(6), None) + .await + .expect("download_if_newer should succeed") + .expect("expected newer snapshot to be downloaded"); + assert_eq!(info.snapshot_id, 7); + } + + #[tokio::test] + #[cfg(feature = "duckdb")] + async fn download_if_newer_returns_none_when_no_metadata() { + let store = Arc::new(InMemory::new()); + let temp_dir = TempDir::new().expect("create temp dir"); + let local_path = temp_dir.path().join("snapshot.db"); + let schema = sample_schema(); + let manager = build_manager( + Arc::clone(&store), + local_path.clone(), + BootstrapOnFailureBehavior::Warn, + &schema, + false, + ); + let result = manager + .download_if_newer(None, None) + .await + .expect("download_if_newer should succeed"); + assert!(result.is_none()); + assert!(!local_path.exists()); + } + #[tokio::test] async fn download_with_fallback_uses_next_snapshot_on_integrity_failure() { let store = Arc::new(InMemory::new()); @@ -4024,7 +4462,7 @@ mod tests { let store = Arc::new(InMemory::new()); let temp_dir = TempDir::new().expect("create temp dir"); let local_path = temp_dir.path().join("snapshot.db"); - std::fs::write(&local_path, b"test snapshot content").expect("write test file"); + write_sample_local_db(&local_path, engine); let schema = sample_schema(); let manager = build_manager_for_engine( diff --git a/crates/runtime/src/accelerated_table/mod.rs b/crates/runtime/src/accelerated_table/mod.rs index c299a540c1..0adccfd080 100644 --- a/crates/runtime/src/accelerated_table/mod.rs +++ b/crates/runtime/src/accelerated_table/mod.rs @@ -66,7 +66,7 @@ pub mod refresh_task; mod refresh_task_runner; mod retention; pub(crate) mod sink; -mod snapshots; +pub(crate) mod snapshots; mod synchronized_table; mod timestamp_metrics_utils; pub mod write; @@ -338,6 +338,9 @@ pub struct Builder { synchronize_with: Option, initial_load_complete: bool, snapshot_creation_config: Option, + /// Per-dataset state for `RefreshMode::Snapshot`. 
Required when the + /// refresh mode is Snapshot; ignored otherwise. + snapshot_refresh_state: Option, metrics: Option, cpu_runtime: Option, io_runtime: Handle, @@ -386,6 +389,7 @@ impl Builder { initial_load_complete: false, refresh_semaphore: None, snapshot_creation_config: None, + snapshot_refresh_state: None, metrics: None, cpu_runtime: None, io_runtime, @@ -568,6 +572,16 @@ impl Builder { self } + /// Configure per-dataset state for `RefreshMode::Snapshot`. Required when + /// the refresh mode is Snapshot. + pub fn snapshot_refresh_state( + &mut self, + state: Option, + ) -> &mut Self { + self.snapshot_refresh_state = state; + self + } + /// Set the TTL for cache mode pub fn caching_ttl(&mut self, ttl: Option) -> &mut Self { self.caching_ttl = ttl; @@ -739,6 +753,16 @@ impl Builder { Some(start_refresh), ) } + RefreshMode::Snapshot => { + // Snapshot mode is interval-driven and supports manual refresh triggers + // to force a poll of the snapshot store outside the regular cadence. + let (start_refresh, on_start_refresh) = + mpsc::channel::>(1); + ( + refresh::AccelerationRefreshMode::Snapshot(on_start_refresh), + Some(start_refresh), + ) + } }; validate_refresh_data_window(&self.refresh, &self.dataset_name, &self.federated.schema()); @@ -781,6 +805,7 @@ impl Builder { } refresher.with_snapshot_creation_config(self.snapshot_creation_config); + refresher.with_snapshot_refresh_state(self.snapshot_refresh_state); refresher.set_bootstrap_status(self.bootstrap_status); if let Some(ref resource_monitor) = self.resource_monitor { @@ -1444,6 +1469,19 @@ impl TableProvider for AcceleratedTable { input: Arc, overwrite: InsertOp, ) -> datafusion::error::Result> { + // In `refresh_mode: snapshot`, the accelerator is a read-only mirror + // of the snapshot store. Accepting writes here would either be + // silently overwritten by the next snapshot reload (data loss) or + // race with the file replacement performed during refresh. 
Reject + // explicitly so callers fail loudly rather than observing surprising + // behavior. + if self.refresh_mode == RefreshMode::Snapshot { + return Err(datafusion::error::DataFusionError::Execution(format!( + "writes to accelerated table {} are not permitted when refresh_mode is 'snapshot'; the accelerator is driven exclusively from the snapshot store", + self.dataset_name + ))); + } + self.update_last_updated_at(); match &self.write_mode { diff --git a/crates/runtime/src/accelerated_table/refresh.rs b/crates/runtime/src/accelerated_table/refresh.rs index 44cbea9bb1..d164880f93 100644 --- a/crates/runtime/src/accelerated_table/refresh.rs +++ b/crates/runtime/src/accelerated_table/refresh.rs @@ -414,8 +414,12 @@ impl Refresh { }; last_checkpoint.last_checkpoint_time().await.ok().flatten() } + // Snapshot mode is interval-based and does not depend on the + // dataset checkpoint (snapshot poll cadence is governed solely by + // `check_interval`). The first poll happens immediately on startup + // so we pick up any snapshot newer than what is on local disk. // Append and Changes modes are always refreshed since they stream changes from the source table. - RefreshMode::Append | RefreshMode::Changes => { + RefreshMode::Snapshot | RefreshMode::Append | RefreshMode::Changes => { return NextRefresh::WaitFor(Duration::ZERO); } // Caching mode handles refreshes in two ways: @@ -611,6 +615,9 @@ pub enum AccelerationRefreshMode { Append(Receiver>), Changes(ChangesStream), Caching(Receiver>), + /// Snapshot mode: refreshes are driven by polling the snapshot store for + /// snapshots newer than the currently loaded one. 
+ Snapshot(Receiver>), } pub struct Refresher { @@ -628,6 +635,7 @@ pub struct Refresher { refresh_on_startup: RefreshOnStartup, synchronize_with: Option, snapshot_config: Option, + snapshot_refresh_state: Option, snapshot_interval_task: Option>, initial_load_completed: Arc, @@ -692,6 +700,7 @@ impl Refresher { semaphore: None, on_complete_notification: None, snapshot_config: None, + snapshot_refresh_state: None, snapshot_interval_task: None, metrics: None, cpu_runtime, @@ -762,6 +771,16 @@ impl Refresher { self } + /// Configure per-dataset state for `RefreshMode::Snapshot`. Required when + /// the refresh mode is Snapshot. + pub fn with_snapshot_refresh_state( + &mut self, + state: Option, + ) -> &mut Self { + self.snapshot_refresh_state = state; + self + } + /// Set the bootstrap status from dataset initialization. pub fn set_bootstrap_status(&mut self, bootstrap_status: BootstrapStatus) -> &mut Self { self.bootstrap_status = bootstrap_status; @@ -862,7 +881,8 @@ impl Refresher { ( AccelerationRefreshMode::Append(receiver) | AccelerationRefreshMode::Full(receiver) - | AccelerationRefreshMode::Caching(receiver), + | AccelerationRefreshMode::Caching(receiver) + | AccelerationRefreshMode::Snapshot(receiver), _, ) => receiver, (AccelerationRefreshMode::Changes(stream), _) => { @@ -938,6 +958,9 @@ impl Refresher { refresh_task_runner = refresh_task_runner.with_s3_express_acceleration(self.is_s3_express_acceleration); + refresh_task_runner = + refresh_task_runner.with_snapshot_refresh_state(self.snapshot_refresh_state.clone()); + let mut refresh_task_runner = refresh_task_runner.build(); let (start_refresh, mut on_refresh_complete) = refresh_task_runner.start()?; diff --git a/crates/runtime/src/accelerated_table/refresh_task.rs b/crates/runtime/src/accelerated_table/refresh_task.rs index d2907d17e1..a7e3e22af1 100644 --- a/crates/runtime/src/accelerated_table/refresh_task.rs +++ b/crates/runtime/src/accelerated_table/refresh_task.rs @@ -129,6 +129,9 @@ pub struct 
RefreshTaskBuilder { last_updated_at: Arc, /// Whether the acceleration uses S3 Express One Zone storage. is_s3_express_acceleration: bool, + /// State for `refresh_mode: snapshot`. Required when the refresh mode is + /// [`RefreshMode::Snapshot`]; ignored otherwise. + snapshot_refresh_state: Option, } impl RefreshTaskBuilder { @@ -158,6 +161,7 @@ impl RefreshTaskBuilder { on_stream_batch_process_callback: None, last_updated_at: Arc::new(AtomicI64::new(0)), is_s3_express_acceleration: false, + snapshot_refresh_state: None, } } @@ -217,6 +221,16 @@ impl RefreshTaskBuilder { self } + /// Provide the snapshot-refresh state required for `RefreshMode::Snapshot`. + #[must_use] + pub fn with_snapshot_refresh_state( + mut self, + state: Option, + ) -> RefreshTaskBuilder { + self.snapshot_refresh_state = state; + self + } + #[must_use] pub fn build(self) -> RefreshTask { let semaphore = self @@ -267,6 +281,7 @@ impl RefreshTaskBuilder { on_stream_batch_process_callback: self.on_stream_batch_process_callback, last_updated_at: self.last_updated_at, is_s3_express_acceleration: self.is_s3_express_acceleration, + snapshot_refresh_state: self.snapshot_refresh_state, } } } @@ -291,6 +306,9 @@ pub struct RefreshTask { last_updated_at: Arc, /// Whether the acceleration uses S3 Express One Zone storage. is_s3_express_acceleration: bool, + /// Per-dataset state required for `RefreshMode::Snapshot`. `None` for all + /// other refresh modes. 
+ snapshot_refresh_state: Option, } impl std::fmt::Debug for RefreshTask { @@ -479,6 +497,7 @@ impl RefreshTask { RefreshMode::Full | RefreshMode::Append => &metrics::REFRESH_DURATION_MS, RefreshMode::Changes => unreachable!("changes are handled upstream"), RefreshMode::Caching => &metrics::REFRESH_DURATION_MS, + RefreshMode::Snapshot => &metrics::REFRESH_DURATION_MS, }, &dataset_metrics_label_sets, ); @@ -499,6 +518,12 @@ impl RefreshTask { // For caching mode, identify and refresh stale rows based on fetched_at and TTL return self.refresh_stale_cached_rows(refresh).await; } + RefreshMode::Snapshot => { + // For snapshot mode, poll the snapshot store for a newer snapshot + // and reload the accelerator from it. The federated source is + // never queried for refreshes in this mode. + return self.refresh_from_snapshot(refresh).await; + } }; let streaming_data_update = match get_data_update_result { @@ -903,6 +928,236 @@ impl RefreshTask { Ok(()) } + /// Drives `RefreshMode::Snapshot`: poll the snapshot store for a snapshot + /// strictly newer than what is currently loaded; if found, download it + /// (which writes to the accelerator's primary path) and call into the + /// accelerator's `reload_from_snapshot` to swap in a fresh `TableProvider`. + /// + /// The federated source is never queried by this code path. When no newer + /// snapshot is available the call is a no-op (Ready, no swap). + async fn refresh_from_snapshot( + &self, + refresh: &Refresh, + ) -> Result<(), RetryError> { + let _ = refresh; // refresh sql / window are intentionally unused for snapshot mode + + let Some(state) = self.snapshot_refresh_state.clone() else { + // This is a configuration bug: the refresh mode is Snapshot but no + // SnapshotRefreshState was attached. Surface as a permanent error so + // the dataset is marked unhealthy rather than retried indefinitely. 
+ tracing::error!( + dataset = %self.dataset_name, + "refresh_mode: snapshot is configured but no SnapshotRefreshState is available; \ + this indicates a runtime configuration bug." + ); + self.set_refresh_status( + None, + status::ComponentStatus::error_with_message("snapshot refresh failure".to_string()), + ) + .await; + return Err(RetryError::permanent( + super::Error::FailedToRefreshDataset { + source: datafusion::error::DataFusionError::Internal( + "snapshot refresh state missing".to_string(), + ), + }, + )); + }; + + self.set_refresh_status(None, status::ComponentStatus::Refreshing) + .await; + + let start_time = SystemTime::now(); + let current_local_id = state.current_loaded_id(); + + // Take the accelerator write mutex up front so the entire refresh + // (download + provider rebuild + swap) is serialized with other code + // paths that take this mutex. `AcceleratedTable::insert_into` rejects + // writes outright when `refresh_mode: snapshot` is enabled, so this + // mutex's only remaining job here is to serialize concurrent snapshot + // refreshes / cache writes against the swap. The atomic rename inside + // `download_if_newer` independently protects against partial-file + // reads from in-flight queries that hold their own connection refs to + // the prior file inode. + let _write_guard = Arc::clone(&self.accelerator_write_mutex).lock_owned().await; + + // Hand the snapshot manager a schema validator that runs against + // the snapshot metadata's recorded schema **before** the file is + // downloaded or renamed. This guarantees a schema-incompatible + // snapshot can never overwrite the accelerator's primary file. 
+ let live_schema = state.swappable_provider.schema(); + let live_schema_for_validate = Arc::clone(&live_schema); + let validator: Box bool + Send + Sync> = + Box::new(move |candidate: &arrow_schema::SchemaRef| { + schemas_compatible(candidate.as_ref(), live_schema_for_validate.as_ref()) + }); + let download_result = state + .manager + .download_if_newer(current_local_id, Some(validator.as_ref())) + .await; + + let info = match download_result { + Ok(Some(info)) => info, + Ok(None) => { + tracing::debug!( + dataset = %self.dataset_name, + current_snapshot_id = ?current_local_id, + "refresh_mode: snapshot - no newer snapshot available; skipping reload" + ); + let dataset_metrics_label_sets = + self.get_dataset_label_sets(&RefreshMode::Snapshot).await; + for label_set in &dataset_metrics_label_sets { + metrics::REFRESH_DATA_FETCHES_SKIPPED.add(1, label_set); + } + self.set_refresh_status(None, status::ComponentStatus::Ready) + .await; + return Ok(()); + } + Err(e) => { + tracing::warn!( + dataset = %self.dataset_name, + error = %e, + "refresh_mode: snapshot - failed to check/download snapshot" + ); + self.set_refresh_status( + None, + status::ComponentStatus::error_with_message( + "snapshot refresh failure".to_string(), + ), + ) + .await; + return Err(RetryError::transient( + super::Error::FailedToRefreshDataset { + source: datafusion::error::DataFusionError::External(Box::new(e)), + }, + )); + } + }; + + // The snapshot manager already rejected schema-incompatible + // snapshots before download; this is a defense-in-depth check + // against the (rare) case where the metadata's recorded schema + // differed from the schema actually embedded in the downloaded + // file. The downloaded file may have replaced the primary path + // here, but `reload_from_snapshot` is gated below — and a + // schema-mismatch returned here is treated as permanent. 
+ if !schemas_compatible(info.schema.as_ref(), live_schema.as_ref()) { + tracing::error!( + dataset = %self.dataset_name, + snapshot_id = info.snapshot_id, + "refresh_mode: snapshot - downloaded snapshot schema does not match \ + accelerator schema; refusing to swap" + ); + self.set_refresh_status( + None, + status::ComponentStatus::error_with_message("snapshot refresh failure".to_string()), + ) + .await; + return Err(RetryError::permanent( + super::Error::FailedToRefreshDataset { + source: datafusion::error::DataFusionError::Internal( + "snapshot schema mismatch".to_string(), + ), + }, + )); + } + + // The accelerator write mutex was taken above, before the download, + // so the entire reload + swap remains serialized with concurrent + // accelerator writes. + let new_provider = match state + .accelerator + .reload_from_snapshot( + state.source.as_ref(), + state.swappable_provider.current(), + Arc::clone(&state.provider_factory), + ) + .await + { + Ok(p) => p, + Err(e) => { + tracing::error!( + dataset = %self.dataset_name, + snapshot_id = info.snapshot_id, + error = %e, + "refresh_mode: snapshot - accelerator failed to reload from snapshot" + ); + self.set_refresh_status( + None, + status::ComponentStatus::error_with_message( + "snapshot refresh failure".to_string(), + ), + ) + .await; + return Err(RetryError::transient( + super::Error::FailedToRefreshDataset { + source: datafusion::error::DataFusionError::Internal(e.to_string()), + }, + )); + } + }; + + if !schemas_compatible(new_provider.schema().as_ref(), live_schema.as_ref()) { + tracing::error!( + dataset = %self.dataset_name, + snapshot_id = info.snapshot_id, + "refresh_mode: snapshot - reloaded provider schema does not match accelerator \ + schema; refusing to swap" + ); + self.set_refresh_status( + None, + status::ComponentStatus::error_with_message("snapshot refresh failure".to_string()), + ) + .await; + return Err(RetryError::permanent( + super::Error::FailedToRefreshDataset { + source: 
datafusion::error::DataFusionError::Internal( + "reloaded snapshot provider schema mismatch".to_string(), + ), + }, + )); + } + + if let Err(swap_err) = state.swappable_provider.swap(new_provider) { + tracing::error!( + dataset = %self.dataset_name, + snapshot_id = info.snapshot_id, + error = %swap_err, + "refresh_mode: snapshot - swap rejected by SwappableTableProvider" + ); + self.set_refresh_status( + None, + status::ComponentStatus::error_with_message("snapshot refresh failure".to_string()), + ) + .await; + return Err(RetryError::permanent( + super::Error::FailedToRefreshDataset { + source: datafusion::error::DataFusionError::Internal(format!( + "snapshot swap rejected: {swap_err}" + )), + }, + )); + } + state.set_current_loaded_id(info.snapshot_id); + if let Some(updated_at) = info.last_updated_at { + self.last_updated_at + .store(updated_at, std::sync::atomic::Ordering::Release); + } + + if let Ok(elapsed) = util::humantime_elapsed(start_time) { + tracing::info!( + dataset = %self.dataset_name, + snapshot_id = info.snapshot_id, + bytes = info.bytes_downloaded, + "Loaded snapshot in {elapsed}" + ); + } + + self.set_refresh_status(None, status::ComponentStatus::Ready) + .await; + Ok(()) + } + async fn trace_load_completed( &self, start_time: SystemTime, @@ -955,6 +1210,9 @@ impl RefreshTask { RefreshMode::Append => UpdateType::Append, RefreshMode::Changes => unreachable!("changes are handled upstream"), RefreshMode::Caching => UpdateType::Overwrite, + RefreshMode::Snapshot => { + unreachable!("snapshot mode is handled by refresh_from_snapshot") + } }; // If a refresh SQL is explicitly provided for this `RefreshTask` (instead of provided at startup within the @@ -1650,6 +1908,15 @@ impl RefreshTask { } } +/// Returns true when `candidate` is structurally compatible with `expected` +/// for swapping a `TableProvider` under a `SwappableTableProvider`. 
See +/// [`crate::dataaccelerator::swappable::schemas_compatible`] for the precise +/// rules; this is a thin re-export so callers in this module can keep using +/// the unqualified name. +fn schemas_compatible(candidate: &arrow_schema::Schema, expected: &arrow_schema::Schema) -> bool { + crate::dataaccelerator::swappable::schemas_compatible(candidate, expected) +} + #[derive(Debug)] /// Tracks and logs data load progress for a dataset, periodically reporting the number of records received struct DataLoadTracing { diff --git a/crates/runtime/src/accelerated_table/refresh_task_runner.rs b/crates/runtime/src/accelerated_table/refresh_task_runner.rs index 3803f91efb..4909686b9f 100644 --- a/crates/runtime/src/accelerated_table/refresh_task_runner.rs +++ b/crates/runtime/src/accelerated_table/refresh_task_runner.rs @@ -59,6 +59,7 @@ pub struct RefreshTaskRunnerBuilder { last_updated_at: Arc, /// Whether the acceleration uses S3 Express One Zone storage. is_s3_express_acceleration: bool, + snapshot_refresh_state: Option, } impl RefreshTaskRunnerBuilder { @@ -90,6 +91,7 @@ impl RefreshTaskRunnerBuilder { accelerator_write_mutex, last_updated_at: Arc::new(AtomicI64::new(0)), is_s3_express_acceleration: false, + snapshot_refresh_state: None, } } @@ -140,6 +142,16 @@ impl RefreshTaskRunnerBuilder { self } + /// Provide the snapshot-refresh state required for `RefreshMode::Snapshot`. 
+ #[must_use] + pub fn with_snapshot_refresh_state( + mut self, + state: Option, + ) -> Self { + self.snapshot_refresh_state = state; + self + } + #[must_use] pub fn build(self) -> RefreshTaskRunner { let mut refresh_task_builder = RefreshTask::builder( @@ -168,6 +180,9 @@ impl RefreshTaskRunnerBuilder { refresh_task_builder = refresh_task_builder.with_s3_express_acceleration(self.is_s3_express_acceleration); + refresh_task_builder = + refresh_task_builder.with_snapshot_refresh_state(self.snapshot_refresh_state); + let refresh_task = Arc::new(refresh_task_builder.build()); RefreshTaskRunner { diff --git a/crates/runtime/src/accelerated_table/snapshots.rs b/crates/runtime/src/accelerated_table/snapshots.rs index 6789c0d5c5..c02bc17788 100644 --- a/crates/runtime/src/accelerated_table/snapshots.rs +++ b/crates/runtime/src/accelerated_table/snapshots.rs @@ -12,6 +12,10 @@ limitations under the License. */ use crate::accelerated_table::SnapshotCreateTrigger; use crate::accelerated_table::refresh::Refresh; +use crate::dataaccelerator::AccelerationSource; +use crate::dataaccelerator::DataAccelerator; +use crate::dataaccelerator::ReloadProviderFactory; +use crate::dataaccelerator::swappable::SwappableTableProvider; use crate::status::RuntimeStatus; use arrow_schema::Schema; use datafusion::common::TableReference; @@ -21,11 +25,70 @@ use runtime_acceleration::dataset_checkpoint::DatasetCheckpointer; use runtime_acceleration::snapshot::{ForceCreate, SnapshotManager, metrics as snapshot_metrics}; use std::pin::Pin; use std::sync::Arc; +use std::sync::Mutex as StdMutex; use std::sync::atomic::{AtomicBool, AtomicI64, Ordering}; use std::time::Duration; use tokio::sync::{Mutex, RwLock}; use tokio::time::interval; +/// Per-dataset state required to drive `refresh_mode: snapshot`. +/// +/// This bundle is built once during dataset registration and threaded down +/// through `Refresher` -> `RefreshTaskBuilder` -> `RefreshTask`. 
The refresh +/// task uses it on every tick to: +/// 1. Compare the remote `current_snapshot_id` against the locally loaded `current_snapshot_id`. +/// 2. Download and reload only when a strictly newer snapshot is available. +/// 3. Atomically swap the live `TableProvider` via `swappable_provider`. +#[derive(Clone)] +pub struct SnapshotRefreshState { + pub manager: Arc, + pub accelerator: Arc, + pub source: Arc, + pub swappable_provider: Arc, + /// Factory that re-runs `create_accelerator_table` for this dataset to + /// build a fresh provider over the on-disk snapshot file. + pub provider_factory: ReloadProviderFactory, + /// The currently-loaded snapshot id, if any. `None` means no snapshot has + /// been loaded yet for this dataset (e.g. fresh start with no bootstrap). + /// Wrapped in a sync `Mutex` because updates are infrequent (once per + /// successful reload) and the inner `Option` is `Copy` so reads are + /// trivial. Snapshot ids are not constrained: id `0` is a valid first + /// snapshot, so `Option` is the correct representation rather than + /// using a sentinel value. + pub current_snapshot_id: Arc>>, +} + +impl std::fmt::Debug for SnapshotRefreshState { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let id = self.current_snapshot_id.lock().map(|g| *g).unwrap_or(None); + f.debug_struct("SnapshotRefreshState") + .field("current_snapshot_id", &id) + .finish_non_exhaustive() + } +} + +impl SnapshotRefreshState { + /// Returns the currently-loaded snapshot id, or `None` if no snapshot has + /// been loaded yet. + #[must_use] + pub fn current_loaded_id(&self) -> Option { + self.current_snapshot_id + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner) + .as_ref() + .copied() + } + + /// Records `snapshot_id` as the most recently loaded snapshot id.
+ pub fn set_current_loaded_id(&self, snapshot_id: u64) { + let mut guard = self + .current_snapshot_id + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner); + *guard = Some(snapshot_id); + } +} + #[derive(Debug, Clone)] pub struct SnapshotCreationConfig { pub manager: Arc, diff --git a/crates/runtime/src/component/dataset/acceleration.rs b/crates/runtime/src/component/dataset/acceleration.rs index 0cceb5bbb1..64d3668d93 100644 --- a/crates/runtime/src/component/dataset/acceleration.rs +++ b/crates/runtime/src/component/dataset/acceleration.rs @@ -43,6 +43,17 @@ pub enum RefreshMode { Append, Changes, Caching, + /// Reload accelerator data from newer snapshots only; the federated + /// source is never queried for refreshes. + Snapshot, +} + +impl RefreshMode { + /// Returns true if this refresh mode never reads from the federated source. + #[must_use] + pub const fn is_snapshot_only(&self) -> bool { + matches!(self, RefreshMode::Snapshot) + } } impl From for RefreshMode { @@ -52,6 +63,7 @@ impl From for RefreshMode { spicepod_acceleration::RefreshMode::Append => RefreshMode::Append, spicepod_acceleration::RefreshMode::Changes => RefreshMode::Changes, spicepod_acceleration::RefreshMode::Caching => RefreshMode::Caching, + spicepod_acceleration::RefreshMode::Snapshot => RefreshMode::Snapshot, } } } diff --git a/crates/runtime/src/dataaccelerator/cayenne/mod.rs b/crates/runtime/src/dataaccelerator/cayenne/mod.rs index da03696ddd..3a1aeb34e1 100644 --- a/crates/runtime/src/dataaccelerator/cayenne/mod.rs +++ b/crates/runtime/src/dataaccelerator/cayenne/mod.rs @@ -15,6 +15,7 @@ limitations under the License. 
*/ pub mod s3; +pub mod snapshot_engine; use std::any::Any; use std::collections::HashMap; @@ -1000,6 +1001,14 @@ impl DataAccelerator for CayenneAccelerator { snapshot_layout, AccelerationEngine::Cayenne, Arc::new(arrow_schema::Schema::empty()), + // For pre-recreate snapshots we don't have a constructed + // catalog handy (the metastore directory may even be in + // a transient state). Pass None and accept the default + // engine; the resulting snapshot will use the legacy + // archive-cayenne.db path. This is acceptable because + // pre-recreate snapshots are best-effort backups, not + // refresh_mode: snapshot sources. + None, ) .await; @@ -1056,13 +1065,45 @@ impl DataAccelerator for CayenneAccelerator { if let Some(acceleration) = source.acceleration() { let metadata_dir = PathBuf::from(Self::resolve_metadata_dir(Some(acceleration))); - let snapshot_adapter = - runtime_acceleration::snapshot::AccelerationLayout::cayenne(metadata_dir, path_buf); + let snapshot_adapter = runtime_acceleration::snapshot::AccelerationLayout::cayenne( + metadata_dir.clone(), + path_buf.clone(), + ); + // Build a CayenneSnapshotEngine so the snapshot tar uses the + // per-dataset metastore-slice format (no raw cayenne.db file) + // and so `download_latest_snapshot` imports the slice into the + // local metastore as the final extraction step. 
+ let metastore_type = acceleration + .params + .get("cayenne_metastore") + .map_or("sqlite", String::as_str) + .to_string(); + let snapshot_engine = match self + .get_or_create_catalog(&metadata_dir.to_string_lossy(), &metastore_type) + .await + { + Ok(catalog) => Some(Arc::new( + crate::dataaccelerator::cayenne::snapshot_engine::CayenneSnapshotEngine::new( + catalog, + source.name().to_string(), + path_buf.clone(), + ), + ) + as Arc), + Err(err) => { + tracing::warn!( + "Failed to build CayenneSnapshotEngine for snapshot bootstrap, \ + falling back to default engine: {err}" + ); + None + } + }; Ok(download_snapshot_if_needed( acceleration, source, snapshot_adapter, AccelerationEngine::Cayenne, + snapshot_engine, ) .await) } else { @@ -1310,6 +1351,75 @@ impl DataAccelerator for CayenneAccelerator { PARAMETERS } + fn supports_snapshot_reload(&self) -> bool { + true + } + + /// Build a [`CayenneSnapshotEngine`] for this source so the on-disk + /// archive uses the per-dataset metastore-slice format (and the writer + /// skips `cayenne.db` / `-wal` / `-shm`). Returning `None` falls back to + /// the default `SnapshotManager` engine, which will include the raw + /// `cayenne.db` file (and its journal sidecar) — that legacy format + /// breaks `refresh_mode: snapshot` because the reader's local metastore + /// already exists at extract time. 
+ async fn snapshot_engine_for_source( + &self, + source: &dyn AccelerationSource, + ) -> Option> { + let acceleration = source.acceleration()?; + let metadata_dir = PathBuf::from(Self::resolve_metadata_dir(Some(acceleration))); + let metastore_type = acceleration + .params + .get("cayenne_metastore") + .map_or("sqlite", String::as_str) + .to_string(); + let catalog = match self + .get_or_create_catalog(&metadata_dir.to_string_lossy(), &metastore_type) + .await + { + Ok(catalog) => catalog, + Err(err) => { + tracing::warn!( + "Failed to build CayenneSnapshotEngine for snapshot create/extract; \ + falling back to default engine: {err}" + ); + return None; + } + }; + let dir_path = match self.cayenne_data_dir(source) { + Ok(p) => p, + Err(err) => { + tracing::warn!( + "Failed to resolve cayenne data dir for snapshot engine; falling back to default engine: {err}" + ); + return None; + } + }; + Some(Arc::new( + crate::dataaccelerator::cayenne::snapshot_engine::CayenneSnapshotEngine::new( + catalog, + source.name().to_string(), + PathBuf::from(dir_path), + ), + )) + } + + /// Reloads the Cayenne-backed table provider from the snapshot directory + /// that was just restored to the accelerator's primary location. + /// + /// Cayenne uses a per-dataset directory layout; dropping the previous + /// provider releases the cached `Vortex` segment/footer caches, and the + /// factory then reopens the directory tree from disk. 
+ async fn reload_from_snapshot( + &self, + _source: &dyn AccelerationSource, + previous_provider: Arc, + provider_factory: super::ReloadProviderFactory, + ) -> Result, Box> { + drop(previous_provider); + provider_factory().await + } + async fn drop_table( &self, table_name: &str, diff --git a/crates/runtime/src/dataaccelerator/cayenne/snapshot_engine.rs b/crates/runtime/src/dataaccelerator/cayenne/snapshot_engine.rs new file mode 100644 index 0000000000..7a5163d5fa --- /dev/null +++ b/crates/runtime/src/dataaccelerator/cayenne/snapshot_engine.rs @@ -0,0 +1,460 @@ +/* +Copyright 2026 The Spice.ai OSS Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +//! Cayenne-specific snapshot engine. +//! +//! Cayenne stores per-table metadata in a shared SQLite/libSQL database +//! (`cayenne.db`). Shipping that file as part of a Cayenne snapshot is +//! problematic for three reasons: +//! +//! 1. **Path portability** (#10642): `cayenne_table.path`, +//! `cayenne_partition.path` and `cayenne_delete_file.path` store absolute +//! filesystem paths from the writer; readers with a different data +//! directory cannot resolve them. +//! +//! 2. **Multi-dataset clobbering**: `cayenne.db` contains rows for *every* +//! dataset sharing the metadata directory. Two datasets snapshotting the +//! same `cayenne.db` and extracting on a fresh reader would each clobber +//! the other's metastore rows, which is why +//! `validate_cayenne_snapshot_consistency` currently rejects multi-dataset +//! 
metastore directories. +//! +//! 3. **Init race / sidecars** (#10649): the reader's eager metastore +//! initialization opens `cayenne.db`, creating `cayenne.db-wal` / +//! `-shm` sidecars before snapshot extraction runs, breaking the +//! archive's checksum verification. +//! +//! `CayenneSnapshotEngine` fixes all three by **never** archiving +//! `cayenne.db*`. Instead, on the create side it serializes a per-dataset +//! metastore "slice" (versioned JSON, see +//! [`cayenne::metastore::snapshot::DatasetMetastoreSlice`]) and inserts it +//! into the tar at a well-known archive path. On the extract side it reads +//! the slice back and atomically imports it into the local metastore. +//! +//! Path columns in the slice are rewritten relative to the writer's data +//! directory at export time and re-anchored at the reader's data directory +//! on import, making the snapshot portable across nodes with different +//! local layouts. + +use std::path::PathBuf; +use std::sync::Arc; + +use async_trait::async_trait; +use cayenne::MetadataCatalog; +use cayenne::metastore::snapshot::DatasetMetastoreSlice; +use runtime_acceleration::snapshot::engine::{ + DirectoryArchiveExtra, DirectorySnapshotPlan, SnapshotEngine, SnapshotEngineError, +}; +use snafu::{ResultExt, Snafu}; +use tokio::fs; + +/// Well-known archive entry path for a Cayenne dataset's metastore slice. +/// The dataset name is included so multiple per-dataset slices can coexist +/// in the same tar in the (currently unused, but designed-for) future where +/// a snapshot covers more than one dataset. +/// Archive path for the per-dataset metastore slice JSON. +/// +/// Uses the `metadata/` prefix so it lines up with +/// `AccelerationLayout::cayenne`'s metadata-directory mapping. On extract, +/// `download_to_directories` writes it under the local metadata directory as +/// `<metadata_dir>/<dataset_name>.slice.json`.
+fn slice_archive_path(dataset_name: &str) -> String { + format!("metadata/{dataset_name}.slice.json") +} + +/// File names (relative to `metadata_dir`) that must be excluded from the +/// archive. Cayenne always opens the metastore in WAL journal mode, so the +/// `-wal` and `-shm` sidecars may be present alongside `cayenne.db`. +const METASTORE_FILES: &[&str] = &["cayenne.db", "cayenne.db-wal", "cayenne.db-shm"]; + +/// Errors raised by the Cayenne snapshot engine. +#[derive(Debug, Snafu)] +pub enum CayenneSnapshotError { + #[snafu(display("Cayenne metastore export failed for dataset '{dataset}': {source}"))] + Export { + dataset: String, + source: cayenne::CatalogError, + }, + + #[snafu(display("Cayenne metastore import failed for dataset '{dataset}': {source}"))] + Import { + dataset: String, + source: cayenne::CatalogError, + }, + + #[snafu(display("Failed to serialize Cayenne metastore slice for '{dataset}': {source}"))] + Serialize { + dataset: String, + source: serde_json::Error, + }, + + #[snafu(display( + "Cayenne snapshot at {path:?} is missing the per-dataset metastore slice. \ + The snapshot was likely produced by an older Spice that shipped the \ + raw cayenne.db file; that format is no longer supported. \ + Recreate the snapshot from a current writer." + ))] + MissingSlice { path: PathBuf }, + + #[snafu(display("Failed to read metastore slice from {path:?}: {source}"))] + ReadSlice { + path: PathBuf, + source: std::io::Error, + }, +} + +/// Snapshot engine for Cayenne accelerators. +/// +/// Holds an [`Arc`] so it can call +/// [`MetadataCatalog::export_dataset_slice`] / `import_dataset_slice` against +/// the same metastore the accelerator is using at runtime. +pub struct CayenneSnapshotEngine { + /// Cayenne metastore (sqlite or libsql) the engine talks to. + catalog: Arc, + /// Logical dataset name (the value of `cayenne_table.table_name`). 
+ dataset_name: String, + /// Local data directory anchor used to rewrite path columns relative + /// on export and absolute on import. The export-side anchor must contain + /// the absolute paths stored in the metastore as a strict prefix; the + /// import-side anchor is where the new paths will be re-rooted. + data_dir_anchor: PathBuf, +} + +impl CayenneSnapshotEngine { + pub fn new( + catalog: Arc, + dataset_name: impl Into, + data_dir_anchor: PathBuf, + ) -> Self { + Self { + catalog, + dataset_name: dataset_name.into(), + data_dir_anchor, + } + } + + /// Returns the dataset name this engine snapshots. + #[must_use] + pub fn dataset_name(&self) -> &str { + &self.dataset_name + } + + /// Returns the data-dir anchor used for path rewriting. + #[must_use] + pub fn data_dir_anchor(&self) -> &std::path::Path { + &self.data_dir_anchor + } + + /// Convenience: turn a `CayenneSnapshotError` into a + /// `SnapshotEngineError::Generic` (or its closest analog) so the trait + /// signature stays clean. + fn engine_err(err: &CayenneSnapshotError) -> SnapshotEngineError { + // SnapshotEngineError doesn't have a Cayenne variant; surface as a + // generic boxed error via Display (the trait error is non-exhaustive + // at the call site, which renders Display). + SnapshotEngineError::from_display(err.to_string()) + } +} + +#[async_trait] +impl SnapshotEngine for CayenneSnapshotEngine { + async fn prepare_for_upload( + &self, + source_path: &std::path::Path, + _dataset_name: &str, + ) -> Result { + // Cayenne snapshots are directory-layout, not file-layout, so + // prepare_for_upload should never be called on this engine. Keep + // a passthrough for defense. + Ok(source_path.to_path_buf()) + } + + fn supports_compaction(&self) -> bool { + false + } + + async fn prepare_directory_snapshot( + &self, + _dirs: &[(PathBuf, String)], + dataset_name: &str, + ) -> Result { + // Sanity: refuse to snapshot a dataset other than the one we were + // constructed for. 
+ if dataset_name != self.dataset_name { + return Err(SnapshotEngineError::from_display(format!( + "CayenneSnapshotEngine constructed for dataset '{}' but asked to snapshot '{}'", + self.dataset_name, dataset_name + ))); + } + + // 1. Export the per-dataset metastore slice. + let slice = self + .catalog + .export_dataset_slice(&self.dataset_name, &self.data_dir_anchor) + .await + .context(ExportSnafu { + dataset: self.dataset_name.clone(), + }) + .map_err(|e| Self::engine_err(&e))?; + + // 2. Serialize to JSON. + let bytes = slice + .to_json_bytes() + .context(SerializeSnafu { + dataset: self.dataset_name.clone(), + }) + .map_err(|e| Self::engine_err(&e))?; + + // 3. Build a plan: skip the cayenne.db* files, add the slice as an extra. + let skip = METASTORE_FILES.iter().map(PathBuf::from).collect(); + let extras = vec![DirectoryArchiveExtra { + archive_path: slice_archive_path(&self.dataset_name), + bytes, + }]; + + Ok(DirectorySnapshotPlan { + skip_relative_paths: skip, + extra_entries: extras, + }) + } + + async fn finalize_directory_snapshot( + &self, + dirs: &[(PathBuf, String)], + dataset_name: &str, + ) -> Result<(), SnapshotEngineError> { + if dataset_name != self.dataset_name { + return Err(SnapshotEngineError::from_display(format!( + "CayenneSnapshotEngine constructed for dataset '{}' but asked to extract '{}'", + self.dataset_name, dataset_name + ))); + } + + // The archive was extracted via prefix mappings, so the slice + // landed at `<metadata_dir>/<dataset_name>.slice.json` (its + // archive path uses the same `metadata/` prefix that + // `AccelerationLayout::cayenne` configures). + let slice_filename = format!("{dataset_name}.slice.json"); + let metadata_candidates: Vec = dirs + .iter() + .filter(|(_, prefix)| prefix.starts_with("metadata")) + .map(|(target_dir, _)| target_dir.join(&slice_filename)) + .collect(); + // Fallback: search every dir, in case the layout prefix list + // changes shape in the future.
+ let candidate_paths: Vec = if metadata_candidates.is_empty() { + dirs.iter() + .map(|(target_dir, _)| target_dir.join(&slice_filename)) + .collect() + } else { + metadata_candidates + }; + + let mut slice_path: Option = None; + for cand in &candidate_paths { + match fs::try_exists(cand).await { + Ok(true) => { + slice_path = Some(cand.clone()); + break; + } + Ok(false) => {} // try the next candidate + Err(err) => { + // A real I/O error here (permissions, transient failure) + // would otherwise be silently swallowed and surface as a + // misleading `MissingSlice` below. Fail loudly instead. + return Err(SnapshotEngineError::from_display(format!( + "CayenneSnapshotEngine: failed to stat candidate slice path {}: {err}", + cand.display(), + ))); + } + } + } + let slice_path = slice_path.ok_or_else(|| { + Self::engine_err(&CayenneSnapshotError::MissingSlice { + path: candidate_paths + .first() + .cloned() + .unwrap_or_else(|| PathBuf::from(&slice_filename)), + }) + })?; + + let bytes = fs::read(&slice_path) + .await + .context(ReadSliceSnafu { + path: slice_path.clone(), + }) + .map_err(|e| Self::engine_err(&e))?; + let slice = DatasetMetastoreSlice::from_json_bytes(&bytes).map_err(|e| { + Self::engine_err(&CayenneSnapshotError::Import { + dataset: self.dataset_name.clone(), + source: e, + }) + })?; + + self.catalog + .import_dataset_slice(&slice, &self.data_dir_anchor) + .await + .context(ImportSnafu { + dataset: self.dataset_name.clone(), + }) + .map_err(|e| Self::engine_err(&e))?; + + // Best-effort: remove the slice file so it doesn't sit in the data + // directory after import. Its information now lives in the local + // metastore. 
+ let _ = fs::remove_file(&slice_path).await; + + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use cayenne::CayenneCatalog; + use cayenne::metadata::CreateTableOptions; + use std::sync::Arc; + + async fn fresh_catalog(dir: &std::path::Path) -> Arc { + let conn = format!("sqlite://{}/cayenne.db", dir.display()); + let catalog = Arc::new(CayenneCatalog::new(conn).expect("catalog")); + catalog.init().await.expect("init"); + catalog + } + + fn schema() -> Arc { + Arc::new(arrow_schema::Schema::new(vec![arrow_schema::Field::new( + "id", + arrow_schema::DataType::Int64, + false, + )])) + } + + #[tokio::test] + async fn create_directory_snapshot_skips_cayenne_db_and_emits_slice() { + let tmp = tempfile::tempdir().expect("tmp"); + let metadata_dir = tmp.path().join("metadata"); + std::fs::create_dir_all(&metadata_dir).expect("mkdir metadata"); + let data_dir = tmp.path().join("data").join("trips"); + std::fs::create_dir_all(&data_dir).expect("mkdir data"); + + let catalog = fresh_catalog(&metadata_dir).await; + catalog + .create_table(CreateTableOptions { + table_name: "trips".to_string(), + schema: schema(), + primary_key: vec![], + on_conflict: None, + base_path: data_dir.to_string_lossy().into_owned(), + partition_column: None, + vortex_config: cayenne::metadata::VortexConfig::default(), + }) + .await + .expect("create_table"); + + // Pre-populate cayenne.db so it shows up under metadata_dir. + // The catalog's init has already written cayenne.db; nothing more to do. + + let engine = CayenneSnapshotEngine::new( + catalog as Arc, + "trips", + data_dir.clone(), + ); + + let dirs = vec![ + (metadata_dir.clone(), "metadata/".to_string()), + (data_dir.clone(), "data/".to_string()), + ]; + + let plan = engine + .prepare_directory_snapshot(&dirs, "trips") + .await + .expect("prepare_directory_snapshot"); + + // Expect cayenne.db files to be in skip list. 
+ assert!( + plan.skip_relative_paths + .contains(&PathBuf::from("cayenne.db")) + ); + assert!( + plan.skip_relative_paths + .contains(&PathBuf::from("cayenne.db-wal")) + ); + assert!( + plan.skip_relative_paths + .contains(&PathBuf::from("cayenne.db-shm")) + ); + + // Expect exactly one extra entry: the slice JSON. + assert_eq!(plan.extra_entries.len(), 1); + let extra = &plan.extra_entries[0]; + assert_eq!(extra.archive_path, "metadata/trips.slice.json"); + + // Sanity: the JSON parses as a versioned slice. + let slice = + cayenne::metastore::snapshot::DatasetMetastoreSlice::from_json_bytes(&extra.bytes) + .expect("parse slice"); + assert_eq!(slice.dataset_name, "trips"); + } + + #[tokio::test] + async fn refuses_mismatched_dataset() { + let tmp = tempfile::tempdir().expect("tmp"); + let metadata_dir = tmp.path().join("metadata"); + std::fs::create_dir_all(&metadata_dir).expect("mkdir metadata"); + let catalog = fresh_catalog(&metadata_dir).await; + + let engine = CayenneSnapshotEngine::new( + catalog as Arc, + "trips", + tmp.path().to_path_buf(), + ); + + let err = engine + .prepare_directory_snapshot(&[], "riders") + .await + .expect_err("must reject mismatched dataset name"); + assert!(err.to_string().contains("trips")); + assert!(err.to_string().contains("riders")); + } + + #[tokio::test] + async fn finalize_missing_slice_returns_clear_error() { + let tmp = tempfile::tempdir().expect("tmp"); + let metadata_dir = tmp.path().join("metadata"); + std::fs::create_dir_all(&metadata_dir).expect("mkdir metadata"); + let catalog = fresh_catalog(&metadata_dir).await; + + let engine = CayenneSnapshotEngine::new( + catalog as Arc, + "trips", + tmp.path().to_path_buf(), + ); + + // No slice file present in metadata_dir. 
+ let dirs = vec![(metadata_dir.clone(), "metadata/".to_string())]; + let err = engine + .finalize_directory_snapshot(&dirs, "trips") + .await + .expect_err("must error when slice is missing"); + let msg = err.to_string(); + assert!( + msg.contains("missing the per-dataset metastore slice"), + "msg={msg}" + ); + assert!(msg.contains("older Spice"), "msg={msg}"); + } +} diff --git a/crates/runtime/src/dataaccelerator/duckdb.rs b/crates/runtime/src/dataaccelerator/duckdb.rs index 30f86fed1d..27e7456e6f 100644 --- a/crates/runtime/src/dataaccelerator/duckdb.rs +++ b/crates/runtime/src/dataaccelerator/duckdb.rs @@ -56,7 +56,10 @@ use datafusion_table_providers::{ DuckDB, DuckDBSettingsRegistry, DuckDBTableProviderFactory, write::{DuckDBTableWriter, WriteCompletionHandler}, }, - sql::db_connection_pool::duckdbpool::{DuckDbConnectionPool, DuckDbConnectionPoolBuilder}, + sql::db_connection_pool::{ + self as db_connection_pool, + duckdbpool::{DuckDbConnectionPool, DuckDbConnectionPoolBuilder}, + }, }; use duckdb::AccessMode; use itertools::Itertools; @@ -393,6 +396,7 @@ impl DataAccelerator for DuckDBAccelerator { )), AccelerationEngine::DuckDB, Arc::new(arrow_schema::Schema::empty()), + None, ) .await; @@ -411,6 +415,7 @@ impl DataAccelerator for DuckDBAccelerator { source, runtime_acceleration::snapshot::AccelerationLayout::file(PathBuf::from(path)), AccelerationEngine::DuckDB, + None, ) .await; @@ -584,6 +589,58 @@ impl DataAccelerator for DuckDBAccelerator { PARAMETERS } + fn supports_snapshot_reload(&self) -> bool { + true + } + + /// Reloads the `DuckDB`-backed table provider from the snapshot file + /// that was just written to the primary path. + /// + /// Drops the previous provider, evicts the cached connection pool from + /// the upstream `DuckDBTableProviderFactory` registry, and then re-runs + /// the registry factory to build a fresh provider over the on-disk file. 
+ /// The pool eviction is required because the registry caches pool + /// instances by file path; without it, the freshly built provider would + /// reuse the prior pool's open connections — which keep observing the + /// previous file inode — and queries would continue to return stale data + /// even after the file has been atomically replaced on disk. + async fn reload_from_snapshot( + &self, + source: &dyn AccelerationSource, + previous_provider: Arc, + provider_factory: super::ReloadProviderFactory, + ) -> Result, Box> { + // Drop the caller's clone first so the only remaining strong refs to + // the prior pool are the registry entry (which we are about to evict) + // and any in-flight queries (which will drain naturally). + drop(previous_provider); + + // Evict the cached pool. For file mode this matches the path the + // factory keyed on at construction time; for memory mode this falls + // back to the in-memory key. Snapshot reload is only meaningful for + // file mode in practice (memory accelerators cannot be snapshotted), + // but the memory branch is kept for completeness. 
+ let acceleration = + source + .acceleration() + .ok_or_else(|| -> Box { + "acceleration not configured for snapshot reload".into() + })?; + match acceleration.mode { + Mode::File | Mode::FileCreate | Mode::FileUpdate => { + let path = self.duckdb_file_path(source).boxed()?; + self.duckdb_factory.invalidate_file_instance(path).await; + } + Mode::Memory => { + self.duckdb_factory + .invalidate_instance(&db_connection_pool::DbInstanceKey::memory()) + .await; + } + } + + provider_factory().await + } + async fn drop_table( &self, table_name: &str, diff --git a/crates/runtime/src/dataaccelerator/mod.rs b/crates/runtime/src/dataaccelerator/mod.rs index 756b6bc0c7..de779f605a 100644 --- a/crates/runtime/src/dataaccelerator/mod.rs +++ b/crates/runtime/src/dataaccelerator/mod.rs @@ -58,6 +58,7 @@ pub mod turso; pub(crate) mod snapshots; pub mod spice_sys; +pub mod swappable; pub mod upsert_dedup; pub(crate) use snapshots::validate_snapshot_paths; @@ -496,6 +497,91 @@ pub trait DataAccelerator: Send + Sync { async fn shutdown(&self) -> Result<(), Box> { Ok(()) } + + /// Whether this accelerator supports reloading its data from a freshly + /// downloaded snapshot file after [`DataAccelerator::init`] has already + /// produced a [`TableProvider`]. + /// + /// Returning `true` indicates that [`DataAccelerator::reload_from_snapshot`] + /// is implemented; in-memory accelerators (e.g. Arrow) and accelerators + /// without a stable on-disk snapshot format must return `false`. + fn supports_snapshot_reload(&self) -> bool { + false + } + + /// Reload the accelerator from a snapshot file that has already been + /// downloaded and written to the accelerator's primary path on disk + /// (i.e. `acceleration_layout(source).primary_path()`). + /// + /// The runtime guarantees the per-dataset accelerator write mutex is held + /// for the duration of this call. Implementations must: + /// 1. 
Drop or clear any cached engine state (open connections, pool + /// entries, file handles, cached schema views, etc.) holding the + /// previous file open. + /// 2. Invoke `provider_factory` to construct a fresh [`TableProvider`] + /// backed by the now-replaced file at the primary path. + /// + /// `provider_factory` re-runs the same `create_accelerator_table` flow + /// used at startup, so the returned provider has the same logical schema, + /// constraints, and indexes as `previous_provider`. + /// + /// The default implementation rejects the call. File-based accelerators + /// that participate in `refresh_mode: snapshot` must override this. + async fn reload_from_snapshot( + &self, + _source: &dyn AccelerationSource, + _previous_provider: Arc, + _provider_factory: ReloadProviderFactory, + ) -> Result, Box> { + Err(Box::new(SnapshotReloadUnsupported { + engine: self.name(), + })) + } + + /// Optional engine-specific [`SnapshotEngine`] used for snapshot create/extract. + /// + /// Engines that need to customise the on-disk archive contents (e.g. + /// Cayenne, which ships a per-dataset metastore-slice JSON instead of + /// the raw `cayenne.db` file) override this to return their engine. + /// File-based accelerators (`DuckDB` / `SQLite` / `Turso`) return `None` and + /// the default `SnapshotManager` engine selection applies. + async fn snapshot_engine_for_source( + &self, + _source: &dyn AccelerationSource, + ) -> Option> { + None + } +} + +/// Factory that re-runs the `create_accelerator_table` registry flow to +/// produce a fresh [`TableProvider`] for an already-initialized dataset. +/// +/// Used by [`DataAccelerator::reload_from_snapshot`] so engines don't need +/// to re-derive table options, attach databases, write handlers, etc. 
+pub type ReloadProviderFactory = Arc< + dyn Fn() -> std::pin::Pin< + Box< + dyn std::future::Future< + Output = Result< + Arc, + Box, + >, + > + Send, + >, + > + Send + + Sync, +>; + +/// Error returned by the default [`DataAccelerator::reload_from_snapshot`] +/// implementation when an engine does not support snapshot-based reloads. +#[derive(Debug, Snafu)] +#[snafu(display( + "Acceleration engine '{engine}' does not support reloading from a snapshot. \ + `refresh_mode: snapshot` requires a snapshot-capable file-based engine \ + (DuckDB, SQLite, Cayenne, or Turso)." +))] +pub struct SnapshotReloadUnsupported { + pub engine: &'static str, } pub struct AcceleratorExternalTableBuilder { @@ -784,6 +870,16 @@ impl BootstrapStatus { Self::Bootstrapped(info) => info.last_updated_at, } } + + /// The `snapshot_id` of the snapshot that was loaded at bootstrap, if any. + /// `None` when no bootstrap occurred (no snapshot, or snapshots disabled). + #[must_use] + pub const fn loaded_snapshot_id(&self) -> Option { + match self { + Self::None => None, + Self::Bootstrapped(info) => Some(info.snapshot_id), + } + } } #[cfg(test)] diff --git a/crates/runtime/src/dataaccelerator/snapshots.rs b/crates/runtime/src/dataaccelerator/snapshots.rs index 3ce9bd1e17..5d5fc68879 100644 --- a/crates/runtime/src/dataaccelerator/snapshots.rs +++ b/crates/runtime/src/dataaccelerator/snapshots.rs @@ -30,6 +30,7 @@ use crate::{ use runtime_acceleration::snapshot::AccelerationEngine; use runtime_acceleration::snapshot::AccelerationLayout; use runtime_acceleration::snapshot::ForceCreate; +use runtime_acceleration::snapshot::engine::SnapshotEngine; use runtime_acceleration::{ dataset_checkpoint::make_checkpointer_factory, snapshot::{SnapshotBehavior, SnapshotManager, metrics}, @@ -38,11 +39,18 @@ use snafu::{ResultExt, Snafu}; /// Downloads a snapshot if needed for bootstrapping. /// Returns `BootstrapStatus`::`Bootstrapped` if a snapshot was successfully downloaded. 
+/// +/// `engine_override`, when set, replaces the engine that the resulting +/// `SnapshotManager` would otherwise build via +/// `runtime_acceleration::snapshot::engine::create_snapshot_engine`. Used by +/// the Cayenne accelerator to inject a `CayenneSnapshotEngine` that knows +/// how to import a per-dataset metastore slice on extract. pub(super) async fn download_snapshot_if_needed( acceleration: &Acceleration, source: &dyn AccelerationSource, layout: AccelerationLayout, engine: AccelerationEngine, + engine_override: Option>, ) -> BootstrapStatus { if !acceleration.snapshot_behavior.bootstrap_enabled() { return BootstrapStatus::none(); @@ -86,7 +94,10 @@ pub(super) async fn download_snapshot_if_needed( ) .await { - let manager = manager.with_checkpointer_factory(checkpoint_factory); + let mut manager = manager.with_checkpointer_factory(checkpoint_factory); + if let Some(engine_override) = engine_override { + manager = manager.with_snapshot_engine(engine_override); + } let start_time = Instant::now(); match manager.download_latest_snapshot().await { Ok(Some(info)) => { @@ -117,12 +128,15 @@ pub(super) async fn download_snapshot_if_needed( /// /// This is a best-effort operation: if snapshotting fails, a warning is logged and the /// caller proceeds with recreation. +/// +/// `engine_override` parallels [`download_snapshot_if_needed`]. pub(crate) async fn snapshot_before_recreate( acceleration: &Acceleration, dataset_name: &str, layout: AccelerationLayout, engine: AccelerationEngine, schema: Arc, + engine_override: Option>, ) { if !acceleration.snapshot_behavior.create_enabled() { return; @@ -138,6 +152,11 @@ pub(crate) async fn snapshot_before_recreate( else { return; }; + let manager = if let Some(engine_override) = engine_override { + manager.with_snapshot_engine(engine_override) + } else { + manager + }; // If the caller provided an empty schema (e.g. 
during file_create init when the table // provider isn't available yet), try to read the real schema from existing snapshot @@ -318,13 +337,15 @@ pub fn validate_cayenne_snapshot_consistency( ); } - // If all datasets have snapshots enabled and there are multiple, that's an error - if !enabled.is_empty() && disabled.is_empty() && enabled.len() > 1 { - return Err(CayenneSnapshotValidationError::SharedAcceleration { - metadata_dir, - datasets: enabled.join(", "), - }); - } + // Multiple datasets sharing the metadata directory with snapshots all + // enabled is supported: each dataset's snapshot ships a per-dataset + // metastore-slice JSON via `CayenneSnapshotEngine`. That engine is + // wired in by `Cayenne::snapshot_engine_for_source` and threaded + // through both the snapshot-creation pipeline + // (`build_snapshot_creation_config`) and the snapshot-refresh-mode + // pipeline (`build_snapshot_refresh_state`), so per-dataset slices + // never clobber each other on extract. The previous restriction + // (single-dataset-per-metadata-dir) is therefore lifted. } Ok(()) @@ -426,22 +447,17 @@ mod tests { } #[tokio::test] - async fn test_cayenne_shared_acceleration_with_snapshots_errors() { + async fn test_cayenne_shared_acceleration_with_snapshots_now_supported() { + // Multi-dataset shared metastore + snapshots-enabled used to error + // with SharedAcceleration. With per-dataset metastore-slice snapshots + // (`CayenneSnapshotEngine`), this configuration is now supported. let sources: Vec> = vec![ MockSource::cayenne_with_metadata_dir("ds1", "/tmp/meta", true), MockSource::cayenne_with_metadata_dir("ds2", "/tmp/meta", true), ]; - let result = validate_cayenne_snapshot_consistency(&sources); - assert!(result.is_err()); - let err = result.expect_err("expected error"); - assert!( - matches!( - err, - CayenneSnapshotValidationError::SharedAcceleration { .. 
} - ), - "Expected SharedAcceleration error, got: {err}" - ); + validate_cayenne_snapshot_consistency(&sources) + .expect("shared metastore + snapshots is now valid"); } #[tokio::test] diff --git a/crates/runtime/src/dataaccelerator/sqlite.rs b/crates/runtime/src/dataaccelerator/sqlite.rs index c50ca4c4e6..10d07d2aa5 100644 --- a/crates/runtime/src/dataaccelerator/sqlite.rs +++ b/crates/runtime/src/dataaccelerator/sqlite.rs @@ -36,7 +36,8 @@ use datafusion::{ }; use datafusion_table_providers::{ sql::db_connection_pool::{ - dbconnection::sqliteconn::SqliteConnection, sqlitepool::SqliteConnectionPool, + self as db_connection_pool, dbconnection::sqliteconn::SqliteConnection, + sqlitepool::SqliteConnectionPool, }, sqlite::{SqliteTableProviderFactory, write::SqliteTableWriter}, }; @@ -276,6 +277,7 @@ impl DataAccelerator for SqliteAccelerator { )), AccelerationEngine::Sqlite, Arc::new(arrow_schema::Schema::empty()), + None, ) .await; @@ -294,6 +296,7 @@ impl DataAccelerator for SqliteAccelerator { source, runtime_acceleration::snapshot::AccelerationLayout::file(PathBuf::from(path)), AccelerationEngine::Sqlite, + None, ) .await; @@ -394,6 +397,47 @@ impl DataAccelerator for SqliteAccelerator { PARAMETERS } + fn supports_snapshot_reload(&self) -> bool { + true + } + + /// Reloads the SQLite-backed table provider from the snapshot file that + /// was just written to the primary path. + /// + /// Drops the previous provider, evicts the cached connection pool from + /// the upstream `SqliteTableProviderFactory` registry, and then re-runs + /// the registry factory to build a fresh provider over the on-disk file. + /// See the `DuckDB` implementation for the rationale around evicting the + /// pool before rebuilding. 
+ async fn reload_from_snapshot( + &self, + source: &dyn AccelerationSource, + previous_provider: Arc, + provider_factory: super::ReloadProviderFactory, + ) -> Result, Box> { + drop(previous_provider); + + let acceleration = + source + .acceleration() + .ok_or_else(|| -> Box { + "acceleration not configured for snapshot reload".into() + })?; + match acceleration.mode { + Mode::File | Mode::FileCreate | Mode::FileUpdate => { + let path = self.sqlite_file_path(source).boxed()?; + self.sqlite_factory.invalidate_file_instance(path).await; + } + Mode::Memory => { + self.sqlite_factory + .invalidate_instance(&db_connection_pool::DbInstanceKey::memory()) + .await; + } + } + + provider_factory().await + } + async fn drop_table( &self, table_name: &str, diff --git a/crates/runtime/src/dataaccelerator/swappable.rs b/crates/runtime/src/dataaccelerator/swappable.rs new file mode 100644 index 0000000000..10a8006ed7 --- /dev/null +++ b/crates/runtime/src/dataaccelerator/swappable.rs @@ -0,0 +1,310 @@ +/* +Copyright 2026 The Spice.ai OSS Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +//! [`SwappableTableProvider`] is a [`TableProvider`] wrapper whose underlying +//! delegate can be replaced atomically at runtime. +//! +//! It is used by `refresh_mode: snapshot` to allow reloading the accelerator +//! table from a freshly downloaded snapshot file without tearing down the +//! enclosing [`AcceleratedTable`]. Schema, table type, and constraints are +//! 
captured at construction time and assumed to be invariant across reloads; +//! snapshots restore the same logical dataset so this invariant holds by +//! design (and is enforced by snapshot schema validation). + +use std::any::Any; +use std::sync::{Arc, RwLock}; + +use async_trait::async_trait; +use datafusion::arrow::datatypes::{Schema, SchemaRef}; +use datafusion::catalog::{Session, TableProvider}; +use datafusion::common::{Constraints, Statistics}; +use datafusion::error::Result as DFResult; +use datafusion::logical_expr::dml::InsertOp; +use datafusion::logical_expr::{Expr, TableProviderFilterPushDown, TableType}; +use datafusion::physical_plan::ExecutionPlan; +use snafu::Snafu; + +/// Errors returned by [`SwappableTableProvider::swap`]. +#[derive(Debug, Snafu)] +pub enum SwapError { + #[snafu(display( + "swap rejected: candidate schema is incompatible with the cached schema (field count, names, data types, or nullability differ)" + ))] + SchemaMismatch, +} + +/// Returns true when `candidate` is structurally compatible with `expected` +/// for the purposes of swapping a [`TableProvider`] under a +/// [`SwappableTableProvider`]: same number of fields in the same order with +/// matching names, data types, and nullability flags. +/// +/// Per-field metadata and arrow `Schema`-level metadata are intentionally +/// ignored — different engines (e.g. `DuckDB`, `SQLite`, CSV) attach +/// engine-specific metadata for logically identical columns. Nullability is +/// included because a nullable↔non-nullable change is observable to +/// downstream planners (e.g. join key handling, predicate evaluation). 
+#[must_use] +pub fn schemas_compatible(candidate: &Schema, expected: &Schema) -> bool { + if candidate.fields().len() != expected.fields().len() { + return false; + } + candidate + .fields() + .iter() + .zip(expected.fields().iter()) + .all(|(c, e)| { + c.name() == e.name() + && c.data_type() == e.data_type() + && c.is_nullable() == e.is_nullable() + }) +} + +/// A [`TableProvider`] that delegates to an inner provider which may be +/// replaced at runtime via [`SwappableTableProvider::swap`]. +/// +/// All read/write methods (`scan`, `insert_into`, `supports_filters_pushdown`, +/// `statistics`) load the current inner provider on each call. Schema-shaped +/// metadata (`schema`, `table_type`, `constraints`) is captured at construction +/// from the initial inner provider and returned without re-locking; snapshot +/// reloads must preserve these. +pub struct SwappableTableProvider { + inner: RwLock>, + cached_schema: SchemaRef, + cached_table_type: TableType, + cached_constraints: Option, +} + +impl std::fmt::Debug for SwappableTableProvider { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("SwappableTableProvider") + .field("schema", &self.cached_schema) + .field("table_type", &self.cached_table_type) + .finish_non_exhaustive() + } +} + +impl SwappableTableProvider { + /// Wrap `inner`, caching its schema, table type, and constraints. Returns + /// an `Arc` so it can be threaded directly to call sites expecting + /// `Arc`. + #[must_use] + pub fn new(inner: Arc) -> Arc { + let cached_schema = inner.schema(); + let cached_table_type = inner.table_type(); + let cached_constraints = inner.constraints().cloned(); + Arc::new(Self { + inner: RwLock::new(inner), + cached_schema, + cached_table_type, + cached_constraints, + }) + } + + /// Returns the current inner provider. 
+ /// + /// Lock poisoning is recovered transparently via + /// [`std::sync::PoisonError::into_inner`]: a previous panic in another + /// thread holding the lock does not propagate here. + #[must_use] + pub fn current(&self) -> Arc { + Arc::clone( + &self + .inner + .read() + .unwrap_or_else(std::sync::PoisonError::into_inner), + ) + } + + /// Replace the inner provider. Validates that the new provider's schema + /// is compatible with the cached schema (see [`schemas_compatible`]) and + /// returns [`SwapError::SchemaMismatch`] otherwise without mutating + /// state. Production callers should pre-validate too so they can surface + /// dataset-aware error context, but this guard ensures incompatible + /// providers cannot be installed even in release builds. + /// + /// Lock poisoning is recovered transparently via + /// [`std::sync::PoisonError::into_inner`]. + pub fn swap(&self, new_inner: Arc) -> Result<(), SwapError> { + if !schemas_compatible(new_inner.schema().as_ref(), self.cached_schema.as_ref()) { + return Err(SwapError::SchemaMismatch); + } + let mut guard = self + .inner + .write() + .unwrap_or_else(std::sync::PoisonError::into_inner); + *guard = new_inner; + Ok(()) + } +} + +#[async_trait] +impl TableProvider for SwappableTableProvider { + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> SchemaRef { + Arc::clone(&self.cached_schema) + } + + fn constraints(&self) -> Option<&Constraints> { + self.cached_constraints.as_ref() + } + + fn table_type(&self) -> TableType { + self.cached_table_type + } + + async fn scan( + &self, + state: &dyn Session, + projection: Option<&Vec>, + filters: &[Expr], + limit: Option, + ) -> DFResult> { + self.current().scan(state, projection, filters, limit).await + } + + fn supports_filters_pushdown( + &self, + filters: &[&Expr], + ) -> DFResult> { + self.current().supports_filters_pushdown(filters) + } + + fn statistics(&self) -> Option { + self.current().statistics() + } + + async fn insert_into( + &self, + 
state: &dyn Session, + input: Arc, + insert_op: InsertOp, + ) -> DFResult> { + self.current().insert_into(state, input, insert_op).await + } +} + +#[cfg(test)] +mod tests { + use super::*; + use datafusion::arrow::array::Int32Array; + use datafusion::arrow::datatypes::{DataType, Field, Schema}; + use datafusion::arrow::record_batch::RecordBatch; + use datafusion::catalog::MemTable; + + fn mem_provider(value: i32) -> Arc { + let schema = Arc::new(Schema::new(vec![Field::new("v", DataType::Int32, false)])); + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![Arc::new(Int32Array::from(vec![value]))], + ) + .expect("build batch"); + let table = MemTable::try_new(schema, vec![vec![batch]]).expect("build memtable"); + Arc::new(table) + } + + #[test] + fn current_and_swap_replace_inner_provider() { + let initial = mem_provider(1); + let initial_ptr = Arc::as_ptr(&initial).cast::<()>(); + let swappable = SwappableTableProvider::new(initial); + + let observed = swappable.current(); + assert_eq!(Arc::as_ptr(&observed).cast::<()>(), initial_ptr); + + let replacement = mem_provider(2); + let replacement_ptr = Arc::as_ptr(&replacement).cast::<()>(); + swappable + .swap(replacement) + .expect("swap with compatible schema"); + + let observed = swappable.current(); + assert_eq!(Arc::as_ptr(&observed).cast::<()>(), replacement_ptr); + } + + #[test] + fn schema_is_cached_at_construction() { + let initial = mem_provider(1); + let initial_schema = initial.schema(); + let swappable = SwappableTableProvider::new(initial); + + // Even after swapping in a (schema-equivalent) replacement, the wrapper + // returns the same SchemaRef instance it cached at construction. 
+ let replacement = mem_provider(99); + swappable + .swap(replacement) + .expect("swap with compatible schema"); + assert!( + Arc::ptr_eq(&swappable.schema(), &initial_schema), + "swappable.schema() should return the cached SchemaRef instance" + ); + } + + #[test] + fn swap_rejects_schema_mismatch() { + let initial = mem_provider(1); + let swappable = SwappableTableProvider::new(initial); + + let mismatched_schema = Arc::new(Schema::new(vec![Field::new( + "other", + DataType::Utf8, + false, + )])); + let mismatched_batch = RecordBatch::try_new( + Arc::clone(&mismatched_schema), + vec![Arc::new(datafusion::arrow::array::StringArray::from(vec![ + "x", + ]))], + ) + .expect("build batch"); + let mismatched = Arc::new( + MemTable::try_new(mismatched_schema, vec![vec![mismatched_batch]]) + .expect("build memtable"), + ); + let err = swappable + .swap(mismatched) + .expect_err("swap should reject mismatched schema"); + assert!(matches!(err, SwapError::SchemaMismatch)); + + // The cached schema is unchanged and the inner provider is preserved. + let observed = swappable.current(); + assert_eq!(observed.schema().fields().len(), 1); + assert_eq!(observed.schema().field(0).name(), "v"); + } + + #[test] + fn swap_rejects_nullability_change() { + let initial = mem_provider(1); + let swappable = SwappableTableProvider::new(initial); + + // Same name + data type, different nullability — must be rejected. 
+ let nullable_schema = Arc::new(Schema::new(vec![Field::new("v", DataType::Int32, true)])); + let batch = RecordBatch::try_new( + Arc::clone(&nullable_schema), + vec![Arc::new(Int32Array::from(vec![Some(1)]))], + ) + .expect("build batch"); + let nullable = Arc::new( + MemTable::try_new(nullable_schema, vec![vec![batch]]).expect("build memtable"), + ); + let err = swappable + .swap(nullable) + .expect_err("nullability change must be rejected"); + assert!(matches!(err, SwapError::SchemaMismatch)); + } +} diff --git a/crates/runtime/src/dataaccelerator/turso.rs b/crates/runtime/src/dataaccelerator/turso.rs index 13ea6c2ad8..dd139df203 100644 --- a/crates/runtime/src/dataaccelerator/turso.rs +++ b/crates/runtime/src/dataaccelerator/turso.rs @@ -469,6 +469,7 @@ impl DataAccelerator for TursoAccelerator { )), AccelerationEngine::Turso, Arc::new(arrow_schema::Schema::empty()), + None, ) .await; @@ -487,6 +488,7 @@ impl DataAccelerator for TursoAccelerator { source, runtime_acceleration::snapshot::AccelerationLayout::file(PathBuf::from(path)), AccelerationEngine::Turso, + None, ) .await; @@ -692,6 +694,42 @@ impl DataAccelerator for TursoAccelerator { PARAMETERS } + fn supports_snapshot_reload(&self) -> bool { + true + } + + /// Reloads the Turso-backed table provider from the snapshot file that + /// was just written to the primary path. + /// + /// Drops the previous provider, evicts the cached `TursoConnectionPool` + /// from the per-accelerator `pools` map (keyed by the on-disk file path), + /// and then re-runs the registry factory to build a fresh provider over + /// the new on-disk contents. Without the eviction step, the next + /// `provider_factory()` call would re-use the cached pool and its open + /// connections, which can continue to observe the prior file's pages. 
+ async fn reload_from_snapshot( + &self, + source: &dyn AccelerationSource, + previous_provider: Arc, + provider_factory: super::ReloadProviderFactory, + ) -> Result, Box> { + drop(previous_provider); + + let turso_file = self + .turso_file_path(source) + .map_err(|e| -> Box { Box::new(e) })?; + // Drop the cached pool entry so the factory rebuild opens the file + // fresh. Existing `Arc` clones held by the + // previous provider have already been dropped above; once the last + // strong reference is released the pool's connections are closed. + { + let mut pools = self.pools.lock().await; + pools.remove(&turso_file); + } + + provider_factory().await + } + async fn drop_table( &self, table_name: &str, diff --git a/crates/runtime/src/datafusion/iceberg_ddl/physical_plans.rs b/crates/runtime/src/datafusion/iceberg_ddl/physical_plans.rs index 6c683713e0..aaebcbb1d4 100644 --- a/crates/runtime/src/datafusion/iceberg_ddl/physical_plans.rs +++ b/crates/runtime/src/datafusion/iceberg_ddl/physical_plans.rs @@ -1105,6 +1105,7 @@ fn render_refresh_mode(mode: &spicepod::acceleration::RefreshMode) -> &'static s spicepod::acceleration::RefreshMode::Append => "append", spicepod::acceleration::RefreshMode::Changes => "changes", spicepod::acceleration::RefreshMode::Caching => "caching", + spicepod::acceleration::RefreshMode::Snapshot => "snapshot", } } diff --git a/crates/runtime/src/datafusion/mod.rs b/crates/runtime/src/datafusion/mod.rs index e5cf66e022..a6ec0bc5b5 100644 --- a/crates/runtime/src/datafusion/mod.rs +++ b/crates/runtime/src/datafusion/mod.rs @@ -19,6 +19,7 @@ use std::sync::{Arc, OnceLock, RwLock, Weak}; use std::time::Duration; use crate::accelerated_table::refresh::{self, RefreshOverrides}; +use crate::accelerated_table::snapshots::SnapshotRefreshState; use crate::accelerated_table::{ self, AcceleratedTableBuilderError, SnapshotCreateTrigger, SnapshotCreationConfig, }; @@ -28,8 +29,10 @@ use crate::component::access::AccessMode; use 
crate::component::dataset::acceleration::{Acceleration, Engine, Mode, RefreshMode}; use crate::component::dataset::{Dataset, ReadyState}; use crate::component::view::View; +use crate::dataaccelerator::ReloadProviderFactory; use crate::dataaccelerator::spice_sys::OpenOption; use crate::dataaccelerator::spice_sys::dataset_checkpoint::DatasetCheckpoint; +use crate::dataaccelerator::swappable::SwappableTableProvider; use crate::dataaccelerator::{self, BootstrapStatus}; use crate::dataaccelerator::{AcceleratorEngineRegistry, get_acceleration_layout}; use crate::dataconnector::deferred::DeferredConnector; @@ -459,6 +462,36 @@ pub enum Error { ))] UnsupportedAccelerationEngineForSnapshots, + #[snafu(display( + "refresh_mode: snapshot requires snapshot bootstrap to be enabled \ + (set `acceleration.snapshots: enabled` or `bootstrap_only`); \ + `disabled` and `create_only` are not sufficient because the dataset \ + must be able to load from a snapshot." + ))] + SnapshotRefreshModeRequiresSnapshots, + + #[snafu(display( + "refresh_mode: snapshot requires a snapshot-capable file-based engine \ + (DuckDB, SQLite, Cayenne, or Turso); engine '{engine}' is not supported." + ))] + SnapshotRefreshModeUnsupportedEngine { engine: String }, + + #[snafu(display( + "refresh_mode: snapshot requires the accelerator to support snapshot reload, but \ + engine '{engine}' does not implement `reload_from_snapshot`." 
+ ))] + SnapshotRefreshModeReloadUnsupported { engine: String }, + + #[snafu(display("Failed to construct snapshot manager for refresh_mode: snapshot."))] + SnapshotRefreshModeManagerUnavailable, + + #[snafu(display( + "refresh_mode: snapshot could not resolve the accelerator file layout: {source}" + ))] + SnapshotRefreshModeLayoutUnavailable { + source: crate::dataaccelerator::FilePathError, + }, + #[snafu(display("Pre-refresh partition discovery failed for table '{table_name}': {source}"))] PreRefreshPartitionDiscoveryFailed { table_name: String, @@ -561,6 +594,12 @@ fn remap_constraints_to_refresh_schema( const DEFAULT_SNAPSHOT_CREATION_INTERVAL: Duration = Duration::from_mins(10); const DEFAULT_SNAPSHOT_CREATION_BATCHES: i64 = 100; +/// Default polling interval for `refresh_mode: snapshot` when the user does +/// not specify `refresh_check_interval` explicitly. Picked to be much +/// shorter than the default snapshot creation interval so a freshly created +/// snapshot is picked up promptly without aggressive object-store load. +const DEFAULT_SNAPSHOT_REFRESH_CHECK_INTERVAL: Duration = Duration::from_mins(1); + pub enum Table { Accelerated { source: Arc, @@ -1847,6 +1886,30 @@ impl DataFusion { .await .context(UnableToCreateDataAcceleratorSnafu)?; + // For RefreshMode::Snapshot, wrap the accelerator in a SwappableTableProvider + // so the underlying provider can be replaced atomically when a newer snapshot + // is loaded. The snapshot refresh state captures everything `RefreshTask` needs + // to query the snapshot store and rebuild the provider on reload. 
+ let (accelerated_table_provider, snapshot_refresh_state) = + if matches!(refresh_mode, RefreshMode::Snapshot) { + let snapshot_state = build_snapshot_refresh_state( + self, + dataset, + Arc::clone(&refresh_schema), + constraints.clone(), + &acceleration_settings, + Arc::clone(&secrets), + Arc::clone(&accelerated_table_provider), + bootstrap_status.loaded_snapshot_id(), + ) + .await?; + let swappable: Arc = + Arc::clone(&snapshot_state.swappable_provider) as Arc; + (swappable, Some(snapshot_state)) + } else { + (accelerated_table_provider, None) + }; + // If we already have an existing dataset checkpoint table that has been checkpointed, // it means there is data from a previous acceleration and we don't need // to wait for the first refresh to complete to mark it ready. @@ -1897,6 +1960,16 @@ impl DataFusion { } if let Some(check_interval) = dataset.refresh_check_interval() { refresh = refresh.check_interval(check_interval); + } else if matches!(refresh_mode, RefreshMode::Snapshot) { + // Snapshot mode polls the snapshot store for newer snapshots; if the + // user did not configure a polling interval, fall back to a sensible + // default so the dataset stays current without requiring manual config. + tracing::info!( + dataset = %dataset.name, + interval_secs = DEFAULT_SNAPSHOT_REFRESH_CHECK_INTERVAL.as_secs(), + "refresh_mode: snapshot - using default refresh_check_interval" + ); + refresh = refresh.check_interval(DEFAULT_SNAPSHOT_REFRESH_CHECK_INTERVAL); } if let Some(max_jitter) = dataset.refresh_max_jitter() { refresh = refresh.max_jitter(max_jitter); @@ -2063,11 +2136,23 @@ impl DataFusion { if acceleration_settings.snapshot_behavior.create_enabled() { if let Some(ref layout) = acceleration_layout { if layout.is_enabled() { + // Resolve any engine-specific snapshot engine override + // (e.g. CayenneSnapshotEngine) so the upload pipeline + // ships the engine's preferred archive format. 
+ let snapshot_engine_override = match self + .accelerator_engine_registry + .get_accelerator_engine(acceleration_settings.engine) + .await + { + Some(accel) => accel.snapshot_engine_for_source(dataset).await, + None => None, + }; if let Some(snapshot_config) = build_snapshot_creation_config( dataset, &acceleration_settings, refresh_mode, layout.clone(), + snapshot_engine_override, ) .await? { @@ -2087,6 +2172,8 @@ impl DataFusion { } } + accelerated_table_builder.snapshot_refresh_state(snapshot_refresh_state); + // Pass the acceleration layout for size metrics if let Some(layout) = acceleration_layout { accelerated_table_builder.acceleration_layout(layout); @@ -2263,6 +2350,7 @@ impl DataFusion { layout, accel_engine, Arc::clone(&existing_schema), + None, ) .await; } @@ -3448,6 +3536,9 @@ async fn build_snapshot_creation_config( acceleration_settings: &Acceleration, refresh_mode: RefreshMode, acceleration_layout: AccelerationLayout, + snapshot_engine_override: Option< + Arc, + >, ) -> Result> { let is_streaming_refresh = matches!(refresh_mode, RefreshMode::Changes) || (matches!(refresh_mode, RefreshMode::Append) && dataset.time_column.is_none()); @@ -3581,10 +3672,157 @@ async fn build_snapshot_creation_config( .await .map(|sm| { let sm = sm.with_snapshots_creation_policy(acceleration_settings.snapshots_creation_policy); + let sm = if let Some(engine) = snapshot_engine_override { + sm.with_snapshot_engine(engine) + } else { + sm + }; SnapshotCreationConfig::new(Arc::new(sm), snapshot_creation_trigger) })) } +/// Build the per-dataset state required to drive `RefreshMode::Snapshot`. +/// +/// Validates that the configuration is sound (snapshots enabled, supported +/// engine, supported reload), constructs a [`SnapshotManager`] for the +/// dataset, wraps the freshly-created accelerator provider in a +/// [`SwappableTableProvider`], and captures a [`ReloadProviderFactory`] that +/// re-runs `create_accelerator_table` on each reload. 
+#[expect(clippy::too_many_arguments)] +async fn build_snapshot_refresh_state( + df: &DataFusion, + dataset: &Dataset, + refresh_schema: SchemaRef, + constraints: Option, + acceleration_settings: &Acceleration, + secrets: Arc>, + initial_provider: Arc, + bootstrap_loaded_id: Option, +) -> Result { + // 1. snapshots must be enabled. + if !acceleration_settings.snapshot_behavior.bootstrap_enabled() { + return SnapshotRefreshModeRequiresSnapshotsSnafu.fail(); + } + + // 2. engine must be snapshot-capable (file-based with a known layout). + let acceleration_engine = engine_to_acceleration_engine(acceleration_settings.engine) + .ok_or_else(|| Error::SnapshotRefreshModeUnsupportedEngine { + engine: acceleration_settings.engine.to_string(), + })?; + + // 3. accelerator must support reload_from_snapshot. + let accelerator = df + .accelerator_engine_registry + .get_accelerator_engine(acceleration_settings.engine) + .await + .ok_or_else(|| Error::SnapshotRefreshModeUnsupportedEngine { + engine: acceleration_settings.engine.to_string(), + })?; + if !accelerator.supports_snapshot_reload() { + return SnapshotRefreshModeReloadUnsupportedSnafu { + engine: acceleration_settings.engine.to_string(), + } + .fail(); + } + + // 4. obtain a SnapshotManager for this dataset, failing if one cannot be built. + let acceleration_layout = get_acceleration_layout(dataset) + .await + .context(SnapshotRefreshModeLayoutUnavailableSnafu)?; + if !acceleration_layout.is_enabled() { + return Err(Error::SnapshotRefreshModeManagerUnavailable); + } + + let manager = SnapshotManager::try_new( + dataset.name.to_string(), + acceleration_settings.snapshot_behavior.clone(), + acceleration_layout, + acceleration_engine, + ) + .await + .ok_or(Error::SnapshotRefreshModeManagerUnavailable)?; + // Apply any engine-specific snapshot-engine override (e.g. CayenneSnapshotEngine). 
+ let manager = match accelerator.snapshot_engine_for_source(dataset).await { + Some(engine) => manager.with_snapshot_engine(engine), + None => manager, + }; + // Build a checkpointer factory mirroring the bootstrap path so the + // refresh-time `download_latest_snapshot` call can succeed (it requires a + // factory to materialize a checkpoint for restore). + let source_for_checkpointer: Arc = + Arc::new(dataset.clone()); + let snapshot_behavior_for_checkpointer = acceleration_settings.snapshot_behavior.clone(); + let checkpoint_factory = + runtime_acceleration::dataset_checkpoint::make_checkpointer_factory(move || { + let source = Arc::clone(&source_for_checkpointer); + let snapshot_behavior = snapshot_behavior_for_checkpointer.clone(); + async move { + use crate::dataaccelerator::spice_sys::OpenOption; + use crate::dataaccelerator::spice_sys::dataset_checkpoint::DatasetCheckpoint; + use snafu::ResultExt; + DatasetCheckpoint::try_new(source.as_ref(), OpenOption::OpenExisting) + .await + .boxed() + .map(|checkpoint| { + checkpoint + .with_snapshot_behavior(snapshot_behavior) + .to_arc() + }) + } + }); + let manager = manager + .with_snapshots_creation_policy(acceleration_settings.snapshots_creation_policy) + .with_checkpointer_factory(checkpoint_factory); + let manager = Arc::new(manager); + + // 5. clone everything the reload factory needs into 'static state. 
+ let registry = Arc::clone(&df.accelerator_engine_registry); + let dataset_owned = Arc::new(dataset.clone()); + let acceleration_settings_owned = acceleration_settings.clone(); + let ctx_owned = Arc::clone(&df.ctx); + let secrets_for_factory = Arc::clone(&secrets); + let table_name = dataset.name.clone(); + let schema_for_factory = Arc::clone(&refresh_schema); + let constraints_for_factory = constraints; + + let provider_factory: ReloadProviderFactory = Arc::new(move || { + let registry = Arc::clone(&registry); + let dataset_owned = Arc::clone(&dataset_owned); + let acceleration_settings_owned = acceleration_settings_owned.clone(); + let ctx_owned = Arc::clone(&ctx_owned); + let secrets_for_factory = Arc::clone(&secrets_for_factory); + let table_name = table_name.clone(); + let schema_for_factory = Arc::clone(&schema_for_factory); + let constraints_for_factory = constraints_for_factory.clone(); + Box::pin(async move { + registry + .create_accelerator_table( + table_name, + schema_for_factory, + constraints_for_factory.as_ref(), + &acceleration_settings_owned, + secrets_for_factory, + Some(dataset_owned.as_ref()), + ctx_owned, + ) + .await + .map_err(|e| -> Box { Box::new(e) }) + }) + }); + + let swappable_provider = SwappableTableProvider::new(initial_provider); + let current_snapshot_id = std::sync::Arc::new(std::sync::Mutex::new(bootstrap_loaded_id)); + + Ok(SnapshotRefreshState { + manager, + accelerator, + source: Arc::new(dataset.clone()), + swappable_provider, + provider_factory, + current_snapshot_id, + }) +} + #[cfg(test)] mod tests { use arrow::array::Int32Array; @@ -3797,6 +4035,7 @@ mod tests { &acceleration, RefreshMode::Full, AccelerationLayout::file(snapshot_path), + None, ) .await; @@ -3821,6 +4060,7 @@ mod tests { &acceleration, RefreshMode::Append, AccelerationLayout::file(snapshot_path), + None, ) .await; @@ -3858,6 +4098,7 @@ mod tests { &acceleration, RefreshMode::Changes, AccelerationLayout::file(snapshot_path), + None, ) .await; @@ -3895,6
+4136,7 @@ mod tests { &acceleration, RefreshMode::Full, AccelerationLayout::file(snapshot_path), + None, ) .await; @@ -3927,6 +4169,7 @@ mod tests { &acceleration, RefreshMode::Append, AccelerationLayout::file(snapshot_path), + None, ) .await; @@ -3959,6 +4202,7 @@ mod tests { &acceleration, RefreshMode::Append, AccelerationLayout::file(snapshot_path), + None, ) .await; @@ -3991,6 +4235,7 @@ mod tests { &acceleration, RefreshMode::Append, AccelerationLayout::file(snapshot_path), + None, ) .await; @@ -4019,6 +4264,7 @@ mod tests { &acceleration, RefreshMode::Append, AccelerationLayout::file(snapshot_path), + None, ) .await; @@ -4047,6 +4293,7 @@ mod tests { &acceleration, RefreshMode::Full, AccelerationLayout::file(snapshot_path), + None, ) .await; @@ -4078,6 +4325,7 @@ mod tests { &acceleration, RefreshMode::Changes, AccelerationLayout::file(snapshot_path), + None, ) .await; diff --git a/crates/runtime/src/tracing_util.rs b/crates/runtime/src/tracing_util.rs index ab871c8523..dc82467658 100644 --- a/crates/runtime/src/tracing_util.rs +++ b/crates/runtime/src/tracing_util.rs @@ -102,6 +102,9 @@ fn acceleration_info( RefreshMode::Caching => { info.push_str(", caching"); } + RefreshMode::Snapshot => { + info.push_str(", snapshot"); + } } if let Some(refresh_interval) = &acceleration.refresh_check_interval { diff --git a/crates/runtime/tests/integration_snapshot_refresh.rs b/crates/runtime/tests/integration_snapshot_refresh.rs new file mode 100644 index 0000000000..5c4e7f63b8 --- /dev/null +++ b/crates/runtime/tests/integration_snapshot_refresh.rs @@ -0,0 +1,81 @@ +/* +Copyright 2026 The Spice.ai OSS Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +// Two-runtime + S3 + per-engine async fixtures push the type-checker over +// the default 128 query depth limit on stable rustc. +#![recursion_limit = "256"] + +//! Integration tests for `refresh_mode: snapshot`. +//! +//! These tests exercise the full snapshot-refresh flow with two cooperating +//! `Runtime` instances: +//! +//! * a **writer** that owns the federated source, runs `refresh_mode: full`, +//! and is configured to create a new snapshot on every refresh, and +//! * a **reader** that runs `refresh_mode: snapshot`, bootstraps from the +//! shared snapshot store, polls for newer snapshots, and reloads its +//! accelerator under a `SwappableTableProvider` swap. +//! +//! The shared snapshot store is backed by a real S3 bucket (mirroring how +//! customers actually deploy snapshot refresh in production). Tests therefore +//! require AWS credentials — set `AWS_SNAPSHOT_KEY`/`AWS_SNAPSHOT_SECRET`, or +//! configure `AWS_PROFILE` with read/write access to the test bucket. See +//! `snapshot_refresh/mod.rs` for the full setup. Each accelerator that +//! advertises `supports_snapshot_reload()` (`DuckDB`, `SQLite`, `Cayenne`, `Turso`) +//! gets its own end-to-end test. 
+ +mod snapshot_refresh; +mod utils; + +use runtime::Runtime; +use std::sync::Once; +use tracing_subscriber::EnvFilter; + +static INIT_TRACING: Once = Once::new(); + +fn init_tracing(default_level: Option<&str>) { + INIT_TRACING.call_once(|| { + let filter = match (default_level, std::env::var("SPICED_LOG").ok()) { + (_, Some(log)) => EnvFilter::new(log), + (Some(level), None) => EnvFilter::new(level), + _ => EnvFilter::new( + "runtime=DEBUG,runtime_acceleration=DEBUG,datafusion-federation=INFO,info", + ), + }; + + let subscriber = tracing_subscriber::FmtSubscriber::builder() + .with_env_filter(filter) + .with_ansi(true) + .with_test_writer() + .finish(); + let _ = tracing::subscriber::set_global_default(subscriber); + }); +} + +/// Runs `query` on `rt` and collects every result batch. Defined at the crate +/// root so the `snapshot_refresh` submodule can share it instead of re-implementing it. +pub(crate) async fn run_query( + rt: &std::sync::Arc, + query: &str, +) -> Result, anyhow::Error> { + use futures::StreamExt; + let mut result = rt.datafusion().query_builder(query).build().run().await?; + let mut results = Vec::new(); + while let Some(batch) = result.data.next().await { + results.push(batch?); + } + Ok(results) +} diff --git a/crates/runtime/tests/snapshot_refresh/cayenne.rs b/crates/runtime/tests/snapshot_refresh/cayenne.rs new file mode 100644 index 0000000000..10ba24849a --- /dev/null +++ b/crates/runtime/tests/snapshot_refresh/cayenne.rs @@ -0,0 +1,32 @@ +/* +Copyright 2026 The Spice.ai OSS Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License.
+See the License for the specific language governing permissions and +limitations under the License. +*/ + +use super::{EngineKind, run_bootstrap_then_refresh_cycle}; + +// Cayenne snapshot refresh exercises the per-dataset metastore-slice format +// shipped by `CayenneSnapshotEngine`: the writer's snapshot tar contains +// `metadata/.slice.json` instead of the raw `cayenne.db` file, and +// the reader's bootstrap path imports the slice into its local metastore. +// +// The wholesale-replace import correctly rebuilds the cayenne-domain tables +// (`cayenne_table`, `cayenne_partition`, `cayenne_delete_file`); the +// spice_sys `_dataset_checkpoint` schema row is bootstrapped from the +// snapshot metadata by `download_latest_snapshot` (closes +// spiceai/spiceai#10658), so this test now runs by default. +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn snapshot_refresh_cayenne_bootstrap_then_refresh() -> Result<(), anyhow::Error> { + run_bootstrap_then_refresh_cycle("snapshot_refresh_cayenne", EngineKind::Cayenne).await +} diff --git a/crates/runtime/tests/snapshot_refresh/duckdb.rs b/crates/runtime/tests/snapshot_refresh/duckdb.rs new file mode 100644 index 0000000000..28cb529e7a --- /dev/null +++ b/crates/runtime/tests/snapshot_refresh/duckdb.rs @@ -0,0 +1,22 @@ +/* +Copyright 2026 The Spice.ai OSS Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +use super::{EngineKind, run_bootstrap_then_refresh_cycle}; + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn snapshot_refresh_duckdb_bootstrap_then_refresh() -> Result<(), anyhow::Error> { + run_bootstrap_then_refresh_cycle("snapshot_refresh_duckdb", EngineKind::DuckDB).await +} diff --git a/crates/runtime/tests/snapshot_refresh/mod.rs b/crates/runtime/tests/snapshot_refresh/mod.rs new file mode 100644 index 0000000000..8f072011b8 --- /dev/null +++ b/crates/runtime/tests/snapshot_refresh/mod.rs @@ -0,0 +1,599 @@ +/* +Copyright 2026 The Spice.ai OSS Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +//! Shared infrastructure for `refresh_mode: snapshot` integration tests. +//! +//! The shared snapshot store is backed by a real S3 bucket, mirroring the +//! existing `snapshot_integration` tests, because the local filesystem +//! `object_store` backend does not implement conditional `PutMode::Update` +//! and therefore cannot host the per-snapshot `metadata.json` rewrites that +//! these tests need. +//! +//! Required environment for CI: +//! +//! * `AWS_SNAPSHOT_KEY` + `AWS_SNAPSHOT_SECRET`, **or** +//! * `AWS_PROFILE` configured with read/write access to the test bucket. +//! +//! Each test reserves its own UUID-prefixed key range under the shared +//! bucket and cleans up after itself. 
+ +use std::{ + collections::HashMap, + env, + path::PathBuf, + sync::Arc, + time::{Duration, Instant}, +}; + +use anyhow::{Context, Result, anyhow}; +use app::AppBuilder; +use aws_sdk_credential_bridge::{S3CredentialProvider, get_or_init_sdk_config}; +use futures::StreamExt; +use object_store::{ClientOptions, ObjectStore, aws::AmazonS3Builder, path::Path as ObjectPath}; +use runtime::Runtime; +use spicepod::{ + acceleration::{ + Acceleration, Mode, OnConflictBehavior, RefreshMode, RefreshOnStartup, SnapshotBehavior, + SnapshotsCreationPolicy, + }, + component::{ + access::AccessMode, + dataset::Dataset, + snapshot::{BootstrapOnFailureBehavior, Snapshots}, + }, + param::Params, +}; +use tempfile::TempDir; +use tokio::time::{sleep, timeout}; +use uuid::Uuid; + +use crate::utils::{register_test_connectors, runtime_ready_check, test_request_context}; +use crate::{init_tracing, run_query}; + +#[cfg(not(windows))] +mod cayenne; +#[cfg(feature = "duckdb")] +mod duckdb; +#[cfg(feature = "sqlite")] +mod sqlite; +#[cfg(feature = "turso")] +mod turso; + +const DATASET_NAME: &str = "trips"; +const SNAPSHOT_BUCKET: &str = "spiceai-snapshot-integration-tests"; +const SNAPSHOT_REGION: &str = "us-west-2"; + +/// Engine variants exercised by these tests. Each variant is enabled at +/// compile time only when the corresponding feature is on, mirroring the +/// `AccelerationEngine` enum in `runtime-acceleration`. 
+#[derive(Clone, Copy, Debug)] +pub(crate) enum EngineKind { + Cayenne, + #[cfg(feature = "duckdb")] + DuckDB, + #[cfg(feature = "sqlite")] + Sqlite, + #[cfg(feature = "turso")] + Turso, +} + +impl EngineKind { + fn engine_name(self) -> &'static str { + match self { + Self::Cayenne => "cayenne", + #[cfg(feature = "duckdb")] + Self::DuckDB => "duckdb", + #[cfg(feature = "sqlite")] + Self::Sqlite => "sqlite", + #[cfg(feature = "turso")] + Self::Turso => "turso", + } + } + + fn file_extension(self) -> &'static str { + match self { + Self::Cayenne => "cayenne", + #[cfg(feature = "duckdb")] + Self::DuckDB => "duckdb", + #[cfg(feature = "sqlite")] + Self::Sqlite => "sqlite", + #[cfg(feature = "turso")] + Self::Turso => "turso", + } + } +} + +/// Per-test S3 prefix + an `ObjectStore` pointed at it for cleanup. +struct SnapshotS3Context { + store: Arc, + prefix: String, +} + +impl SnapshotS3Context { + async fn new(test_name: &str) -> Result { + let store = build_snapshot_store().await?; + let prefix = format!("{test_name}/{}", Uuid::now_v7()); + Ok(Self { store, prefix }) + } + + fn location_uri(&self) -> String { + format!( + "s3://{SNAPSHOT_BUCKET}/{}/", + self.prefix.trim_end_matches('/') + ) + } + + async fn cleanup(&self) -> Result<()> { + let base = ObjectPath::from(self.prefix.clone()); + let mut stream = self.store.list(Some(&base)); + let mut to_delete = Vec::new(); + while let Some(meta) = stream.next().await { + let meta = meta.context("listing snapshot bucket for cleanup")?; + to_delete.push(meta.location); + } + for loc in to_delete { + let _ = self.store.delete(&loc).await; + } + Ok(()) + } +} + +async fn build_snapshot_store() -> Result> { + let mut builder = AmazonS3Builder::from_env() + .with_bucket_name(SNAPSHOT_BUCKET) + .with_region(SNAPSHOT_REGION) + .with_client_options(ClientOptions::default()); + + if let (Ok(key), Ok(secret)) = ( + env::var("AWS_SNAPSHOT_KEY"), + env::var("AWS_SNAPSHOT_SECRET"), + ) { + builder = builder + .with_access_key_id(key) 
+ .with_secret_access_key(secret); + if let Ok(token) = env::var("AWS_SNAPSHOT_SESSION_TOKEN") { + builder = builder.with_token(token); + } + } else { + let config = get_or_init_sdk_config() + .await + .map_err(|err| anyhow!("Failed to initialize AWS credentials: {err}"))?; + let Some(config) = config else { + return Err(anyhow!( + "AWS credentials are required to run snapshot refresh integration tests. \ + Provide AWS_SNAPSHOT_KEY/AWS_SNAPSHOT_SECRET or configure AWS_PROFILE." + )); + }; + builder = builder.with_credentials(Arc::new( + S3CredentialProvider::from_config(config.as_ref()) + .context("Loading AWS credentials from environment")?, + )); + } + + Ok(Arc::new(builder.build().context( + "Building Amazon S3 object store client for snapshots", + )?)) +} + +/// Holds everything a single integration test needs to drive a writer + +/// reader pair against a shared S3 snapshot store. +struct SnapshotRefreshFixture { + s3: SnapshotS3Context, + _temp_dir: TempDir, + source_csv_path: PathBuf, + writer_local_db: PathBuf, + reader_local_db: PathBuf, + engine: EngineKind, +} + +impl SnapshotRefreshFixture { + async fn new(test_name: &str, engine: EngineKind) -> Result { + let temp_dir = TempDir::new().context("creating temp dir for snapshot refresh test")?; + let source_csv_path = temp_dir.path().join("source.csv"); + let writer_local_db = temp_dir + .path() + .join(format!("writer.{}", engine.file_extension())); + let reader_local_db = temp_dir + .path() + .join(format!("reader.{}", engine.file_extension())); + let s3 = SnapshotS3Context::new(test_name).await?; + Ok(Self { + s3, + _temp_dir: temp_dir, + source_csv_path, + writer_local_db, + reader_local_db, + engine, + }) + } + + fn source_from_uri(&self) -> String { + format!("file://{}", self.source_csv_path.display()) + } + + /// Atomically rewrite the source CSV. 
+ fn write_source(&self, csv: &str) -> Result<()> { + let tmp = self.source_csv_path.with_extension("csv.tmp"); + std::fs::write(&tmp, csv).context("writing temp source csv")?; + std::fs::rename(&tmp, &self.source_csv_path).context("renaming source csv into place")?; + Ok(()) + } + + pub(crate) fn dataset_params() -> HashMap { + HashMap::from([ + ("file_format".to_string(), "csv".to_string()), + ("csv_has_header".to_string(), "true".to_string()), + ]) + } + + /// Engine-specific acceleration params that pin the on-disk location to + /// `local_db_path`. Cayenne is directory-based and uses two distinct + /// param names (`cayenne_file_path` for data, `cayenne_metadata_dir` for + /// the catalog metastore), so we route to a sibling `metadata/` directory + /// to keep writer and reader fully isolated. Other engines are + /// single-file and use `_file`. + fn engine_accel_params(&self, local_db_path: &std::path::Path) -> HashMap { + let mut params = HashMap::new(); + match self.engine { + EngineKind::Cayenne => { + params.insert( + "cayenne_file_path".to_string(), + local_db_path.to_string_lossy().into_owned(), + ); + let metadata_dir = local_db_path.with_extension("metadata"); + params.insert( + "cayenne_metadata_dir".to_string(), + metadata_dir.to_string_lossy().into_owned(), + ); + } + #[cfg(feature = "duckdb")] + EngineKind::DuckDB => { + params.insert( + "duckdb_file".to_string(), + local_db_path.to_string_lossy().into_owned(), + ); + } + #[cfg(feature = "sqlite")] + EngineKind::Sqlite => { + params.insert( + "sqlite_file".to_string(), + local_db_path.to_string_lossy().into_owned(), + ); + } + #[cfg(feature = "turso")] + EngineKind::Turso => { + params.insert( + "turso_file".to_string(), + local_db_path.to_string_lossy().into_owned(), + ); + } + } + params + } + + /// Build the writer dataset: full refresh + create-on-change snapshots. 
+ fn writer_dataset(&self) -> Dataset { + let accel_params = self.engine_accel_params(&self.writer_local_db); + + let mut dataset = Dataset::new(self.source_from_uri(), DATASET_NAME); + dataset.params = Some(Params::from_string_map(Self::dataset_params())); + dataset.acceleration = Some(Acceleration { + enabled: true, + mode: Mode::File, + engine: Some(self.engine.engine_name().to_string()), + params: Some(Params::from_string_map(accel_params)), + refresh_mode: Some(RefreshMode::Full), + // Drive refreshes (and therefore snapshot creation) quickly so + // the test stays in the few-seconds range. + refresh_check_interval: Some("1s".to_string()), + refresh_on_startup: RefreshOnStartup::Auto, + snapshots: SnapshotBehavior::Enabled, + snapshots_creation_policy: SnapshotsCreationPolicy::OnChange, + ..Acceleration::default() + }); + dataset + } + + /// Build the reader dataset: snapshot refresh + bootstrap from the + /// shared snapshot store. + fn reader_dataset(&self) -> Dataset { + let accel_params = self.engine_accel_params(&self.reader_local_db); + + let mut dataset = Dataset::new(self.source_from_uri(), DATASET_NAME); + dataset.params = Some(Params::from_string_map(Self::dataset_params())); + // Mark the reader dataset as `read_write` with a non-empty + // `on_conflict` map so the runtime's access gate accepts the dataset + // (read_write requires either replication or on_conflict). The + // on_conflict configuration itself is never exercised because the + // refresh_mode: snapshot rejection inside `AcceleratedTable::insert_into` + // fires before any write reaches the accelerator. 
+ dataset.access = AccessMode::ReadWrite; + let mut on_conflict = HashMap::new(); + on_conflict.insert("id".to_string(), OnConflictBehavior::Upsert); + dataset.acceleration = Some(Acceleration { + enabled: true, + mode: Mode::File, + engine: Some(self.engine.engine_name().to_string()), + params: Some(Params::from_string_map(accel_params)), + refresh_mode: Some(RefreshMode::Snapshot), + // Poll often so the test does not have to wait for the default + // 1-minute interval to detect the new snapshot. + refresh_check_interval: Some("1s".to_string()), + refresh_on_startup: RefreshOnStartup::Auto, + snapshots: SnapshotBehavior::Enabled, + primary_key: Some("id".to_string()), + on_conflict, + ..Acceleration::default() + }); + dataset + } + + fn snapshots_config(&self) -> Snapshots { + let mut params = HashMap::from([("s3_region".to_string(), SNAPSHOT_REGION.to_string())]); + if env::var("AWS_PROFILE").is_ok() { + params.insert("s3_auth".to_string(), "iam_role".to_string()); + } else { + params.insert("s3_auth".to_string(), "key".to_string()); + params.insert( + "s3_key".to_string(), + "${secrets:AWS_SNAPSHOT_KEY}".to_string(), + ); + params.insert( + "s3_secret".to_string(), + "${secrets:AWS_SNAPSHOT_SECRET}".to_string(), + ); + } + Snapshots { + enabled: true, + location: Some(self.s3.location_uri()), + bootstrap_on_failure_behavior: BootstrapOnFailureBehavior::Warn, + params: Some(Params::from_string_map(params)), + } + } + + /// Read the current `current-snapshot-id` for this dataset from the + /// shared metadata file. Returns `None` if metadata is not yet written. + async fn current_snapshot_id(&self) -> Result> { + let metadata_path = ObjectPath::from(format!("{}/metadata.json", self.s3.prefix)); + let bytes = match self.s3.store.get(&metadata_path).await { + Ok(get) => get + .bytes() + .await + .context("reading metadata.json bytes from snapshot store")?, + Err(object_store::Error::NotFound { .. 
}) => return Ok(None), + Err(e) => return Err(anyhow::Error::from(e).context("getting metadata.json")), + }; + let metadata: serde_json::Value = + serde_json::from_slice(&bytes).context("parsing metadata.json")?; + let Some(dataset_entry) = metadata.get(DATASET_NAME) else { + return Ok(None); + }; + Ok(dataset_entry + .get("current-snapshot-id") + .and_then(serde_json::Value::as_i64)) + } + + async fn wait_for_snapshot_id(&self, minimum_id: i64, max_wait: Duration) -> Result { + let deadline = Instant::now() + max_wait; + loop { + if let Some(id) = self.current_snapshot_id().await? + && id >= minimum_id + { + return Ok(id); + } + if Instant::now() >= deadline { + return Err(anyhow!( + "timed out waiting for snapshot id >= {minimum_id} in s3://{SNAPSHOT_BUCKET}/{}", + self.s3.prefix + )); + } + sleep(Duration::from_millis(250)).await; + } + } +} + +async fn load_runtime(rt: Arc) -> Result<()> { + timeout(Duration::from_secs(120), Arc::clone(&rt).load_components()) + .await + .map_err(|_| anyhow!("Timed out waiting for runtime components to load"))?; + runtime_ready_check(rt.as_ref()).await; + Ok(()) +} + +/// Run a query and return the total number of rows across all batches. We +/// avoid `SELECT count(*)` so the assertion is robust across engines that +/// reject or pushdown-translate the count differently (Turso in particular +/// rejects the federation-pushed-down count form). +async fn count_rows(rt: &Arc, table: &str) -> Result { + let batches = run_query(rt, &format!("SELECT id FROM {table}")) + .await + .with_context(|| format!("counting rows in {table}"))?; + Ok(batches + .iter() + .map(arrow::array::RecordBatch::num_rows) + .sum()) +} + +/// Initial source CSV with three rows. +const INITIAL_CSV: &str = "\ +id,name,score +1,alpha,10 +2,bravo,20 +3,charlie,30 +"; + +/// Mutated CSV with five rows, distinct content from the initial one. Both +/// sets share the same schema so snapshot reload is allowed. 
+const MUTATED_CSV: &str = "\ +id,name,score +1,alpha,10 +2,bravo,20 +3,charlie,30 +4,delta,40 +5,echo,50 +"; + +/// End-to-end scenario: writer creates snapshots; reader bootstraps and +/// follows. Each per-engine `#[tokio::test]` calls this with its engine. +async fn run_bootstrap_then_refresh_cycle(test_name: &str, engine: EngineKind) -> Result<()> { + init_tracing(None); + register_test_connectors().await; + + test_request_context() + .scope(async move { + let fixture = SnapshotRefreshFixture::new(test_name, engine).await?; + // Always attempt cleanup, even on test failure. + let result = run_inner(&fixture).await; + if let Err(cleanup_err) = fixture.s3.cleanup().await { + tracing::warn!( + test = test_name, + error = %cleanup_err, + "snapshot cleanup encountered errors (test result preserved)" + ); + } + result + }) + .await +} + +async fn run_inner(fixture: &SnapshotRefreshFixture) -> Result<()> { + // Drive the in-process variant by calling the writer phase up to the + // first snapshot, then starting the reader and letting the writer + // publish the mutated snapshot. The two helpers below are also called + // (separately, in two different processes) by the dockerized Cayenne + // orchestrator below. 
+ fixture.write_source(INITIAL_CSV)?; + + // ---------------------- start writer ---------------------- + let writer_app = AppBuilder::new(format!("snapshot_writer_{}", fixture.engine.engine_name())) + .with_snapshots(fixture.snapshots_config()) + .with_dataset(fixture.writer_dataset()) + .build(); + let writer = Arc::new(Runtime::builder().with_app(writer_app).build().await); + load_runtime(Arc::clone(&writer)).await?; + + let writer_initial = count_rows(&writer, "trips").await?; + if writer_initial != 3 { + return Err(anyhow!( + "writer should serve the initial 3 rows, got {writer_initial}" + )); + } + + let first_id = fixture + .wait_for_snapshot_id(0, Duration::from_secs(60)) + .await + .context("waiting for writer to produce initial snapshot")?; + if first_id != 0 { + return Err(anyhow!("first snapshot should have id 0, got {first_id}")); + } + + // ---------------------- start reader ---------------------- + let reader_app = AppBuilder::new(format!("snapshot_reader_{}", fixture.engine.engine_name())) + .with_snapshots(fixture.snapshots_config()) + .with_dataset(fixture.reader_dataset()) + .build(); + let reader = Arc::new(Runtime::builder().with_app(reader_app).build().await); + load_runtime(Arc::clone(&reader)).await?; + + let reader_initial = count_rows(&reader, "trips").await?; + if reader_initial != 3 { + return Err(anyhow!( + "reader should bootstrap to 3 rows from snapshot, got {reader_initial}" + )); + } + + assert_insert_rejected(&reader).await?; + + // ---------------------- mutate source --------------------- + fixture.write_source(MUTATED_CSV)?; + + let next_id = fixture + .wait_for_snapshot_id(first_id + 1, Duration::from_secs(60)) + .await + .context("waiting for writer to publish a snapshot for the mutated source")?; + if next_id <= first_id { + return Err(anyhow!( + "snapshot id must advance after source change ({first_id} -> {next_id})" + )); + } + + wait_for_reader_swap(&reader, 5, Duration::from_secs(60)).await?; + 
assert_swap_sanity(&reader).await?; + + reader.shutdown().await; + writer.shutdown().await; + Ok(()) +} + +/// Verify that an INSERT against a snapshot-mode reader is rejected with a +/// snapshot-specific error message. +pub(crate) async fn assert_insert_rejected(reader: &Arc) -> Result<()> { + let insert_err = run_query(reader, "INSERT INTO trips VALUES (99, 'zulu', 999)") + .await + .err() + .ok_or_else(|| anyhow!("INSERT INTO under refresh_mode: snapshot must fail"))?; + let msg = format!("{insert_err}"); + if !msg.contains("snapshot") { + return Err(anyhow!( + "INSERT error must mention snapshot mode, got: {msg}" + )); + } + Ok(()) +} + +/// Poll the reader until its row count reaches `expected_rows`, or fail. +pub(crate) async fn wait_for_reader_swap( + reader: &Arc, + expected_rows: usize, + max_wait: Duration, +) -> Result<()> { + let deadline = Instant::now() + max_wait; + loop { + let observed = count_rows(reader, "trips").await?; + if observed == expected_rows { + return Ok(()); + } + if Instant::now() >= deadline { + return Err(anyhow!( + "reader did not observe expected snapshot in time; last observed row count: {observed} (expected {expected_rows})" + )); + } + sleep(Duration::from_millis(500)).await; + } +} + +/// Sanity-check that both new and original rows are present after swap. 
+pub(crate) async fn assert_swap_sanity(reader: &Arc) -> Result<()> { + let id5 = run_query(reader, "SELECT name FROM trips WHERE id = 5").await?; + let id5_pretty = arrow::util::pretty::pretty_format_batches(&id5) + .map(|fmt| fmt.to_string()) + .context("formatting id=5 row")?; + if !id5_pretty.contains("echo") { + return Err(anyhow!( + "reader should serve the new id=5 row after swap; got:\n{id5_pretty}" + )); + } + let id1 = run_query(reader, "SELECT name FROM trips WHERE id = 1").await?; + let id1_pretty = arrow::util::pretty::pretty_format_batches(&id1) + .map(|fmt| fmt.to_string()) + .context("formatting id=1 row")?; + if !id1_pretty.contains("alpha") { + return Err(anyhow!( + "reader should still serve the original id=1 row after swap; got:\n{id1_pretty}" + )); + } + Ok(()) +} diff --git a/crates/runtime/tests/snapshot_refresh/sqlite.rs b/crates/runtime/tests/snapshot_refresh/sqlite.rs new file mode 100644 index 0000000000..a7fd46565c --- /dev/null +++ b/crates/runtime/tests/snapshot_refresh/sqlite.rs @@ -0,0 +1,26 @@ +/* +Copyright 2026 The Spice.ai OSS Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +use super::{EngineKind, run_bootstrap_then_refresh_cycle}; + +// SQLite acceleration ran in WAL journaling mode without flushing before +// `fs::copy`, so the snapshot uploaded only the page-zero header. Fixed by +// the WAL-checkpoint hook on `SnapshotEngine` (see commit `fix(snapshot): +// flush SQLite/Turso WAL before snapshotting`); spiceai/spiceai#10643. 
+#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn snapshot_refresh_sqlite_bootstrap_then_refresh() -> Result<(), anyhow::Error> { + run_bootstrap_then_refresh_cycle("snapshot_refresh_sqlite", EngineKind::Sqlite).await +} diff --git a/crates/runtime/tests/snapshot_refresh/turso.rs b/crates/runtime/tests/snapshot_refresh/turso.rs new file mode 100644 index 0000000000..2b9c57fe3a --- /dev/null +++ b/crates/runtime/tests/snapshot_refresh/turso.rs @@ -0,0 +1,31 @@ +/* +Copyright 2026 The Spice.ai OSS Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +use super::{EngineKind, run_bootstrap_then_refresh_cycle}; + +// Turso (libsql) WAL flush via rusqlite currently fails with +// "file is not a database" — the libsql on-disk file isn't byte-compatible +// with the rusqlite reader. Tracked in spiceai/spiceai#10657. +// +// The WAL-flush hook (commit 1 of #10651) is therefore a no-op for Turso +// today; data loss can occur on snapshot creation if writes are still in +// the WAL. The integration test stays `#[ignore]`d until #10657 swaps in +// a turso-native checkpoint path.
+#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +#[ignore = "snapshot_refresh for turso requires a turso-native WAL checkpoint; see spiceai/spiceai#10657"] +async fn snapshot_refresh_turso_bootstrap_then_refresh() -> Result<(), anyhow::Error> { + run_bootstrap_then_refresh_cycle("snapshot_refresh_turso", EngineKind::Turso).await +} diff --git a/crates/spicepod/src/acceleration/mod.rs b/crates/spicepod/src/acceleration/mod.rs index 6466f8300d..21641ed059 100644 --- a/crates/spicepod/src/acceleration/mod.rs +++ b/crates/spicepod/src/acceleration/mod.rs @@ -33,6 +33,10 @@ pub enum RefreshMode { Append, Changes, Caching, + /// Refresh exclusively by reloading newer snapshots from the configured + /// snapshot location. The federated source is never queried for refreshes. + /// Requires `snapshots` to be enabled and a snapshot-supporting engine. + Snapshot, } /// Controls the write behavior for accelerated read-write datasets. @@ -541,4 +545,31 @@ mod tests { assert_eq!(accel.mode, mode, "round-trip failed for mode '{s}'"); } } + + #[test] + fn test_deserialize_refresh_mode_snapshot() { + let yaml = "refresh_mode: snapshot"; + let accel: Acceleration = yaml::from_str(yaml).expect("should parse"); + assert_eq!(accel.refresh_mode, Some(RefreshMode::Snapshot)); + } + + #[test] + fn test_deserialize_all_refresh_modes() { + for (yaml_value, expected) in [ + ("full", RefreshMode::Full), + ("append", RefreshMode::Append), + ("changes", RefreshMode::Changes), + ("caching", RefreshMode::Caching), + ("snapshot", RefreshMode::Snapshot), + ] { + let yaml = format!("refresh_mode: {yaml_value}"); + let accel: Acceleration = yaml::from_str(&yaml) + .unwrap_or_else(|_| panic!("should parse refresh_mode '{yaml_value}'")); + assert_eq!( + accel.refresh_mode, + Some(expected), + "unexpected parse for '{yaml_value}'" + ); + } + } } From a43482869748a236ad0d4cd9021b4d326402ff4e Mon Sep 17 00:00:00 2001 From: Sergei Grebnov Date: Mon, 4 May 2026 20:31:29 +0300 Subject: [PATCH 
3/6] Add integration test for HTTP dynamic headers with pagination (#10663) --- crates/runtime/tests/http/mod.rs | 128 ++++++++++++++++++ ...aders_from_subquery_paginated_results.snap | 33 +++++ 2 files changed, 161 insertions(+) create mode 100644 crates/runtime/tests/http/snapshots/integration__http__http_dynamic_request_headers_from_subquery_paginated_results.snap diff --git a/crates/runtime/tests/http/mod.rs b/crates/runtime/tests/http/mod.rs index bcc2cf2706..ce49bec988 100644 --- a/crates/runtime/tests/http/mod.rs +++ b/crates/runtime/tests/http/mod.rs @@ -242,6 +242,42 @@ async fn start_http_server() -> Result< .route( "/data/items.csv", get(|| async { ([("content-type", "text/csv")], ITEMS_CSV) }), + ) + .route( + "/api/metrics-paginated", + get( + |query: axum::extract::Query>| async move { + // Token-based pagination: returns 2 metrics per page, 3 pages total (5 metrics). + static METRICS: &[(&str, f64)] = &[ + ("cpu", 42.0), + ("mem", 78.5), + ("disk", 55.0), + ("net_in", 12.3), + ("net_out", 9.7), + ]; + let page: usize = query + .get("cursor") + .and_then(|v| v.parse().ok()) + .unwrap_or(1); + let items_per_page = 2; + let start = (page - 1) * items_per_page; + let end = std::cmp::min(start + items_per_page, METRICS.len()); + let items: Vec = METRICS[start..end] + .iter() + .map(|(metric, reading)| json!({ "metric": metric, "reading": reading })) + .collect(); + let next_cursor = if end < METRICS.len() { + Value::Number(serde_json::Number::from(page + 1)) + } else { + Value::Null + }; + let body = json!({ + "data": items, + "next_cursor": next_cursor, + }); + ([("content-type", "application/json")], body.to_string()) + }, + ), ); let tcp_listener = TcpListener::bind("127.0.0.1:0").await.map_err(|e| { @@ -1385,3 +1421,95 @@ async fn test_http_oauth2_rejects_partial_configuration() -> Result<(), String> }) .await } + +/// Tests `IN (SELECT ...)` subqueries with **pagination** against a real registered table +/// +/// 1. 
A CSV file (`orgs`) with org IDs +/// 2. An HTTP dataset (`paginated_api`) with header filters and token-based pagination +/// 3. A query that builds JSON headers from the CSV rows and uses +/// `IN (SELECT ...)` to drive dynamic HTTP requests across multiple pages +#[tokio::test] +async fn test_http_dynamic_request_headers_from_subquery_with_pagination() -> Result<(), String> { + let _tracing = init_tracing(Some("integration=debug,info")); + register_test_connectors().await; + + test_request_context() + .scope(async { + let (tx, addr, _) = start_http_server().await?; + tracing::debug!("HTTP test server started at {addr}"); + + // 1. Register both datasets: the S3 CSV lookup table and the paginated HTTP API. + let orgs_dataset = Dataset::new("s3://spiceai-public-datasets/orgs.csv", "orgs"); + + let mut http_dataset = Dataset::new(format!("http://{addr}/api"), "paginated_api"); + http_dataset.params = Some(DatasetParams::from_string_map(HashMap::from([ + ("file_format".to_string(), "json".to_string()), + ( + "allowed_request_paths".to_string(), + "/metrics-paginated".to_string(), + ), + ("request_header_filters".to_string(), "enabled".to_string()), + ( + "request_header_allowlist".to_string(), + "x-org-id".to_string(), + ), + ("max_request_partitions".to_string(), "100".to_string()), + // Token-based pagination config + ("pagination".to_string(), "enabled".to_string()), + ( + "pagination_next_pointer".to_string(), + "/next_cursor".to_string(), + ), + ("pagination_token_param".to_string(), "cursor".to_string()), + ("pagination_data_pointer".to_string(), "/data".to_string()), + ("pagination_max_pages".to_string(), "10".to_string()), + ]))); + + let app = AppBuilder::new("http_dynamic_headers_subquery_paginated_test") + .with_dataset(orgs_dataset) + .with_dataset(http_dataset) + .build(); + let mut rt = load_runtime(app).await?; + + // 2. Build header JSON from CSV rows, use IN (SELECT ...) to drive + // dynamic paginated HTTP requests. 
+ let query = r#" + WITH org_headers AS ( + SELECT '{"x-org-id":"' || org_id || '"}' AS hdr + FROM orgs + ) + SELECT + request_headers, + json_get_str(content, 'metric') AS metric, + json_get_float(content, 'reading') AS reading + FROM paginated_api + WHERE request_path = '/metrics-paginated' + AND request_headers IN (SELECT hdr FROM org_headers) + ORDER BY request_headers, metric + "#; + + run_query_and_check_results( + &mut rt, + "http_dynamic_request_headers_from_subquery_paginated", + query, + false, + Some(Box::new(|result_batches: Vec| { + // Each org should get 5 metrics (3 pages: 2+2+1). + let total_rows: usize = result_batches.iter().map(RecordBatch::num_rows).sum(); + assert!(total_rows > 0, "expected paginated results but got 0 rows"); + let pretty = arrow::util::pretty::pretty_format_batches(&result_batches) + .expect("failed to format batches"); + insta::assert_snapshot!( + "http_dynamic_request_headers_from_subquery_paginated_results", + pretty + ); + })), + ) + .await?; + + tx.send(()) + .map_err(|()| "Failed to send shutdown signal".to_string())?; + Ok(()) + }) + .await +} diff --git a/crates/runtime/tests/http/snapshots/integration__http__http_dynamic_request_headers_from_subquery_paginated_results.snap b/crates/runtime/tests/http/snapshots/integration__http__http_dynamic_request_headers_from_subquery_paginated_results.snap new file mode 100644 index 0000000000..6d39d09428 --- /dev/null +++ b/crates/runtime/tests/http/snapshots/integration__http__http_dynamic_request_headers_from_subquery_paginated_results.snap @@ -0,0 +1,33 @@ +--- +source: crates/runtime/tests/http/mod.rs +expression: pretty +--- ++------------------------+---------+---------+ +| request_headers | metric | reading | ++------------------------+---------+---------+ +| {"x-org-id":"org-001"} | cpu | 42.0 | +| {"x-org-id":"org-001"} | disk | 55.0 | +| {"x-org-id":"org-001"} | mem | 78.5 | +| {"x-org-id":"org-001"} | net_in | 12.3 | +| {"x-org-id":"org-001"} | net_out | 9.7 | +| 
{"x-org-id":"org-002"} | cpu | 42.0 | +| {"x-org-id":"org-002"} | disk | 55.0 | +| {"x-org-id":"org-002"} | mem | 78.5 | +| {"x-org-id":"org-002"} | net_in | 12.3 | +| {"x-org-id":"org-002"} | net_out | 9.7 | +| {"x-org-id":"org-003"} | cpu | 42.0 | +| {"x-org-id":"org-003"} | disk | 55.0 | +| {"x-org-id":"org-003"} | mem | 78.5 | +| {"x-org-id":"org-003"} | net_in | 12.3 | +| {"x-org-id":"org-003"} | net_out | 9.7 | +| {"x-org-id":"org-004"} | cpu | 42.0 | +| {"x-org-id":"org-004"} | disk | 55.0 | +| {"x-org-id":"org-004"} | mem | 78.5 | +| {"x-org-id":"org-004"} | net_in | 12.3 | +| {"x-org-id":"org-004"} | net_out | 9.7 | +| {"x-org-id":"org-005"} | cpu | 42.0 | +| {"x-org-id":"org-005"} | disk | 55.0 | +| {"x-org-id":"org-005"} | mem | 78.5 | +| {"x-org-id":"org-005"} | net_in | 12.3 | +| {"x-org-id":"org-005"} | net_out | 9.7 | ++------------------------+---------+---------+ From 393aef440ff78cc1839b66dc26eb64115f42fb80 Mon Sep 17 00:00:00 2001 From: Luke Kim <80174+lukekim@users.noreply.github.com> Date: Mon, 4 May 2026 11:27:08 -0700 Subject: [PATCH 4/6] Add Delta Lake Azure tenant parameter (#10671) * Add Delta Lake Azure tenant parameter * Fix Delta Lake tenant parameter lint --- .../connector-delta-lake/src/lib.rs | 48 +++++++++++++++++++ tools/spicepodschema/tests/spicepod.all.yaml | 1 + 2 files changed, 49 insertions(+) diff --git a/crates/data-connectors/connector-delta-lake/src/lib.rs b/crates/data-connectors/connector-delta-lake/src/lib.rs index 1a7dde220d..08560c30d9 100644 --- a/crates/data-connectors/connector-delta-lake/src/lib.rs +++ b/crates/data-connectors/connector-delta-lake/src/lib.rs @@ -108,6 +108,9 @@ const PARAMETERS: &[ParameterSpec] = &[ ParameterSpec::component("azure_storage_client_secret") .description("The service principal client secret for accessing the storage account.") .secret(), + ParameterSpec::component("azure_storage_tenant_id") + .description("The service principal tenant id for accessing the storage account.") + 
.secret(), ParameterSpec::component("azure_storage_sas_key") .description("The shared access signature key for accessing the storage account.") .secret(), @@ -295,3 +298,48 @@ pub const CONNECTOR_NAME: &str = "delta_lake"; pub fn factory() -> Arc { DeltaLakeFactory::new_arc() } + +#[cfg(test)] +mod tests { + use super::*; + use runtime::secrets::Secrets; + use secrecy::SecretString; + use tokio::sync::RwLock; + + #[tokio::test] + async fn tenant_id_parameter_is_accepted_and_registered() { + let parameters = Parameters::try_new( + "connector delta_lake", + vec![( + "delta_lake_azure_storage_tenant_id".to_string(), + SecretString::new("tenant-id".to_string().into()), + )], + "delta_lake", + Arc::new(RwLock::new(Secrets::new())), + PARAMETERS, + ) + .await + .expect("tenant id should be accepted for delta_lake"); + + assert_eq!( + parameters.get("azure_storage_tenant_id").expose().ok(), + Some("tenant-id") + ); + + let delta_table_options = parameters.to_secret_map(); + assert_eq!( + delta_table_options + .get("azure_storage_tenant_id") + .map(ExposeSecret::expose_secret), + Some("tenant-id") + ); + + let registry_params = parameters.storage_registry_params(); + let tenant_id = registry_params + .iter() + .find(|(key, _)| key == "tenant_id") + .map(|(_, value)| value.expose_secret()); + + assert_eq!(tenant_id, Some("tenant-id")); + } +} diff --git a/tools/spicepodschema/tests/spicepod.all.yaml b/tools/spicepodschema/tests/spicepod.all.yaml index deced43047..8c5fa6a75f 100644 --- a/tools/spicepodschema/tests/spicepod.all.yaml +++ b/tools/spicepodschema/tests/spicepod.all.yaml @@ -663,6 +663,7 @@ datasets: delta_lake_azure_storage_account_key: key delta_lake_azure_storage_client_id: client-id delta_lake_azure_storage_client_secret: secret + delta_lake_azure_storage_tenant_id: tenant-id delta_lake_azure_storage_sas_key: sas delta_lake_azure_storage_endpoint: endpoint # GCP params From 98377ae90235e59f120e5f86ff467f4b936f82a7 Mon Sep 17 00:00:00 2001 From: Luke Kim 
<80174+lukekim@users.noreply.github.com> Date: Mon, 4 May 2026 12:21:34 -0700 Subject: [PATCH 5/6] Add provider-aware LLM prompt caching (#10645) * Add provider-aware LLM prompt caching * feat: Add prompt cache key support and related tests for OpenAI models * fix: Update prompt tokens details initialization to use `then_some` for clarity * refactor: Update LOCAL_LLM_MAX_SEQS to NonZeroUsize and improve request handling in ResponsesWrapper * docs: Add LLM prompt caching internals * docs: Expand project docs index * feat: Add tests for prompt cache point handling in Bedrock and Databricks * ci: guard Install protoc step with relevant_changes check (#10646) The Install protoc step in the Build Test Binary job referenced the local ./.github/actions/install-protoc action but lacked the relevant_changes conditional guard that gates the surrounding steps. When a PR contains no relevant code changes (e.g. openapi.json or spicepod schema only), actions/checkout is skipped, leaving the workspace empty. The Install protoc step then fails immediately at job setup with "Can't find action.yml... install-protoc". This blocks all non-Rust PRs from completing the integration test workflow. Adding the same `if: needs.check_changes.outputs.relevant_changes == 'true'` guard restores the intended skip-on-no-relevant-changes behavior. 
* fix(benchmarks): redact non-deterministic partition_sizes in explain plan snapshots (#10641) * fix(mistral): refine paged attention support check and update documentation for CUDA backend * fix(docs): format provider mappings table for clarity in llm-prompt-caching documentation * fix(mistral): update LOCAL_LLM_MAX_SEQS initialization for clarity and safety fix(wrapper): improve error message for model parameter parsing * docs(llms): explain local LLM scheduler limit * Fix Anthropic streaming cache usage accounting * feat: use saturating addition for token counts in CompletionUsage --------- Co-authored-by: claudespice Co-authored-by: Sergei Grebnov --- crates/llms/src/anthropic/chat.rs | 59 +++++-- crates/llms/src/anthropic/types.rs | 61 ++++++- crates/llms/src/anthropic/types_stream.rs | 162 ++++++++++++++++-- crates/llms/src/bedrock/chat/mod.rs | 100 +++++++++-- crates/llms/src/chat/mistral.rs | 118 ++++++++++--- crates/llms/src/databricks/mod.rs | 91 +++++++++- crates/llms/src/google/chat.rs | 52 +++++- crates/llms/src/xai/mod.rs | 51 +++++- crates/runtime/src/http/v1/nsql.rs | 11 +- crates/runtime/src/model/chat.rs | 19 ++ crates/runtime/src/model/params/mod.rs | 12 +- crates/runtime/src/model/responses.rs | 60 ++++++- crates/runtime/src/model/wrapper/mod.rs | 6 +- crates/runtime/src/model/wrapper/responses.rs | 73 +++++++- docs/README.md | 118 +++++++++++++ docs/dev/llm-prompt-caching.md | 60 +++++++ 16 files changed, 964 insertions(+), 89 deletions(-) create mode 100644 docs/dev/llm-prompt-caching.md diff --git a/crates/llms/src/anthropic/chat.rs b/crates/llms/src/anthropic/chat.rs index cf76188b82..439815faeb 100644 --- a/crates/llms/src/anthropic/chat.rs +++ b/crates/llms/src/anthropic/chat.rs @@ -33,19 +33,18 @@ use async_openai::types::chat::{ ChatCompletionRequestToolMessageContent, ChatCompletionRequestToolMessageContentPart, ChatCompletionRequestUserMessage, ChatCompletionRequestUserMessageContent, ChatCompletionRequestUserMessageContentPart, 
ChatCompletionResponseMessage, - ChatCompletionResponseStream, ChatCompletionToolChoiceOption, CompletionUsage, - CreateChatCompletionRequest, CreateChatCompletionResponse, FinishReason, FunctionCall, - FunctionName, ReasoningEffort, ResponseFormat, ResponseFormatJsonSchema, Role, - StopConfiguration, ToolChoiceOptions, + ChatCompletionResponseStream, ChatCompletionToolChoiceOption, CreateChatCompletionRequest, + CreateChatCompletionResponse, FinishReason, FunctionCall, FunctionName, ReasoningEffort, + ResponseFormat, ResponseFormatJsonSchema, Role, StopConfiguration, ToolChoiceOptions, }; use serde_json::json; use super::Anthropic; use super::types::{ - AnthropicModelVariant, ContentBlock, ContentParam, MessageCreateParams, MessageCreateResponse, - MessageParam, MessageRole, MetadataParam, ResponseContentBlock, ResponseTextBlock, StopReason, - TextBlockParam, ToolChoiceParam, ToolResultBlockParam, ToolUseBlockParam, default_max_tokens, - tool_from_completion_tools, + AnthropicModelVariant, CacheControlEphemeral, ContentBlock, ContentParam, MessageCreateParams, + MessageCreateResponse, MessageParam, MessageRole, MetadataParam, ResponseContentBlock, + ResponseTextBlock, StopReason, TextBlockParam, ToolChoiceParam, ToolResultBlockParam, + ToolUseBlockParam, default_max_tokens, tool_from_completion_tools, }; use super::types_stream::transform_stream; use async_trait::async_trait; @@ -99,13 +98,7 @@ impl TryFrom for CreateChatCompletionResponse { Ok(CreateChatCompletionResponse { id: value.id, model: value.model.clone(), - usage: Some(CompletionUsage { - prompt_tokens: value.usage.input_tokens, - completion_tokens: value.usage.output_tokens, - total_tokens: value.usage.input_tokens + value.usage.output_tokens, - prompt_tokens_details: None, - completion_tokens_details: None, - }), + usage: Some(value.usage.into()), created: SystemTime::now() .duration_since(SystemTime::UNIX_EPOCH) .map_err(|e| OpenAIError::InvalidArgument(e.to_string()))? 
@@ -360,6 +353,10 @@ impl TryFrom<(AnthropicModelVariant, CreateChatCompletionRequest)> for MessageCr pair: (AnthropicModelVariant, CreateChatCompletionRequest), ) -> Result { let (model, value) = pair; + let cache_control = value + .prompt_cache_key + .as_ref() + .map(|_| CacheControlEphemeral::ephemeral()); let messages = value .messages @@ -395,6 +392,7 @@ impl TryFrom<(AnthropicModelVariant, CreateChatCompletionRequest)> for MessageCr StopConfiguration::StringArray(a) => a, }), system: system_message_from_messages(&value.messages), + cache_control, messages, tool_choice: match value.tool_choice { @@ -495,3 +493,34 @@ fn system_message_from_messages(messages: &[ChatCompletionRequestMessage]) -> Op Some(system_messages.join("\n")) } } + +#[cfg(test)] +mod tests { + use super::*; + use async_openai::types::chat::{ + ChatCompletionRequestUserMessageArgs, CreateChatCompletionRequest, + }; + + #[test] + fn prompt_cache_key_enables_automatic_cache_control() { + let req = CreateChatCompletionRequest { + messages: vec![ + ChatCompletionRequestUserMessageArgs::default() + .content("Use the cached context.") + .build() + .expect("user message should build") + .into(), + ], + prompt_cache_key: Some("schema-context".to_string()), + ..CreateChatCompletionRequest::default() + }; + + let params = MessageCreateParams::try_from(("claude-sonnet-4-6".to_string(), req)) + .expect("anthropic request should convert"); + + assert_eq!( + params.cache_control, + Some(CacheControlEphemeral::ephemeral()) + ); + } +} diff --git a/crates/llms/src/anthropic/types.rs b/crates/llms/src/anthropic/types.rs index ccd67135b8..f64219ae1f 100644 --- a/crates/llms/src/anthropic/types.rs +++ b/crates/llms/src/anthropic/types.rs @@ -16,7 +16,7 @@ limitations under the License. 
use async_openai::{ error::OpenAIError, - types::chat::{ChatCompletionTool, ChatCompletionTools}, + types::chat::{ChatCompletionTool, ChatCompletionTools, CompletionUsage, PromptTokensDetails}, }; use regex::Regex; use serde::{Deserialize, Serialize}; @@ -36,6 +36,8 @@ pub struct MessageCreateParams { #[serde(skip_serializing_if = "Option::is_none")] pub system: Option, #[serde(skip_serializing_if = "Option::is_none")] + pub cache_control: Option, + #[serde(skip_serializing_if = "Option::is_none")] pub temperature: Option, #[serde(skip_serializing_if = "Option::is_none")] pub tool_choice: Option, @@ -398,6 +400,15 @@ pub struct CacheControlEphemeral { pub ttl: Option, } +impl CacheControlEphemeral { + pub fn ephemeral() -> Self { + Self { + control_type: "ephemeral".to_string(), + ttl: None, + } + } +} + #[derive(Clone, Debug, Serialize, Deserialize, PartialEq)] pub enum CacheTtl { #[serde(rename = "5m")] @@ -675,6 +686,29 @@ pub struct Usage { pub service_tier: Option, } +impl From for CompletionUsage { + fn from(usage: Usage) -> Self { + let cache_creation_input_tokens = usage.cache_creation_input_tokens.unwrap_or_default(); + let cache_read_input_tokens = usage.cache_read_input_tokens.unwrap_or_default(); + let prompt_tokens = usage + .input_tokens + .saturating_add(cache_creation_input_tokens) + .saturating_add(cache_read_input_tokens); + let prompt_tokens_details = (cache_read_input_tokens > 0).then_some(PromptTokensDetails { + cached_tokens: Some(cache_read_input_tokens), + audio_tokens: None, + }); + + CompletionUsage { + prompt_tokens, + completion_tokens: usage.output_tokens, + total_tokens: prompt_tokens.saturating_add(usage.output_tokens), + prompt_tokens_details, + completion_tokens_details: None, + } + } +} + #[derive(Clone, Debug, Default, Serialize, Deserialize)] pub struct CacheCreation { #[serde(default)] @@ -701,7 +735,30 @@ pub enum ServiceTier { #[cfg(test)] mod tests { - use super::validate_model_variant; + use super::{Usage, 
validate_model_variant}; + use async_openai::types::chat::CompletionUsage; + + #[test] + fn usage_conversion_saturates_token_totals() { + let usage = CompletionUsage::from(Usage { + input_tokens: u32::MAX, + output_tokens: 1, + cache_creation_input_tokens: Some(1), + cache_read_input_tokens: Some(2), + ..Usage::default() + }); + + assert_eq!(usage.prompt_tokens, u32::MAX); + assert_eq!(usage.completion_tokens, 1); + assert_eq!(usage.total_tokens, u32::MAX); + assert_eq!( + usage + .prompt_tokens_details + .as_ref() + .and_then(|details| details.cached_tokens), + Some(2) + ); + } // Current Anthropic model names to validate. // Based on the models list from https://docs.claude.com/en/docs/about-claude/models/overview, as of 2025-09-28. diff --git a/crates/llms/src/anthropic/types_stream.rs b/crates/llms/src/anthropic/types_stream.rs index 588dda0e57..293e2613f9 100644 --- a/crates/llms/src/anthropic/types_stream.rs +++ b/crates/llms/src/anthropic/types_stream.rs @@ -19,8 +19,9 @@ use async_openai::{ error::{ApiError, OpenAIError}, types::chat::{ ChatChoiceStream, ChatCompletionMessageToolCallChunk, ChatCompletionResponseStream, - ChatCompletionStreamResponseDelta, CompletionUsage, CreateChatCompletionStreamResponse, - FinishReason, FunctionCallStream, FunctionType, Role, + ChatCompletionStreamResponseDelta, CompletionTokensDetails, CompletionUsage, + CreateChatCompletionStreamResponse, FinishReason, FunctionCallStream, FunctionType, + PromptTokensDetails, Role, }, }; use futures::{Stream, StreamExt}; @@ -258,13 +259,7 @@ pub fn transform_stream( }) => { state.role = MessageRole::from_opt(&inner_role); state.id = Some(inner_id); - state.usage = Some(CompletionUsage { - prompt_tokens: inner_usage.input_tokens, - completion_tokens: inner_usage.output_tokens, - total_tokens: inner_usage.input_tokens + inner_usage.output_tokens, - prompt_tokens_details: None, - completion_tokens_details: None, - }); + state.usage = Some(inner_usage.into()); state.model = Some(model); 
Some(create_anthropic_stream_response( &state.id.clone().unwrap_or_default(), @@ -318,9 +313,7 @@ pub fn transform_stream( }) => { // Update usage if let Some(ref mut u) = state.usage { - u.prompt_tokens += inner_usage.input_tokens; - u.completion_tokens += inner_usage.output_tokens; - u.total_tokens += inner_usage.input_tokens + inner_usage.output_tokens; + add_usage_delta(u, inner_usage); } Some(create_anthropic_stream_response( &state.id.clone().unwrap_or_default(), @@ -378,6 +371,71 @@ pub fn transform_stream( Box::pin(transformed_stream) } +fn add_usage_delta(usage: &mut CompletionUsage, delta: Usage) { + let delta = CompletionUsage::from(delta); + + usage.prompt_tokens = usage.prompt_tokens.saturating_add(delta.prompt_tokens); + usage.completion_tokens = usage + .completion_tokens + .saturating_add(delta.completion_tokens); + usage.total_tokens = usage.total_tokens.saturating_add(delta.total_tokens); + usage.prompt_tokens_details = combine_prompt_token_details( + usage.prompt_tokens_details.take(), + delta.prompt_tokens_details, + ); + usage.completion_tokens_details = combine_completion_token_details( + usage.completion_tokens_details.take(), + delta.completion_tokens_details, + ); +} + +fn combine_prompt_token_details( + current: Option, + delta: Option, +) -> Option { + match (current, delta) { + (Some(current), Some(delta)) => Some(PromptTokensDetails { + audio_tokens: combine_opt_u32(current.audio_tokens, delta.audio_tokens), + cached_tokens: combine_opt_u32(current.cached_tokens, delta.cached_tokens), + }), + (Some(current), None) => Some(current), + (None, Some(delta)) => Some(delta), + (None, None) => None, + } +} + +fn combine_completion_token_details( + current: Option, + delta: Option, +) -> Option { + match (current, delta) { + (Some(current), Some(delta)) => Some(CompletionTokensDetails { + accepted_prediction_tokens: combine_opt_u32( + current.accepted_prediction_tokens, + delta.accepted_prediction_tokens, + ), + audio_tokens: 
combine_opt_u32(current.audio_tokens, delta.audio_tokens), + reasoning_tokens: combine_opt_u32(current.reasoning_tokens, delta.reasoning_tokens), + rejected_prediction_tokens: combine_opt_u32( + current.rejected_prediction_tokens, + delta.rejected_prediction_tokens, + ), + }), + (Some(current), None) => Some(current), + (None, Some(delta)) => Some(delta), + (None, None) => None, + } +} + +fn combine_opt_u32(current: Option, delta: Option) -> Option { + match (current, delta) { + (Some(current), Some(delta)) => Some(current.saturating_add(delta)), + (Some(current), None) => Some(current), + (None, Some(delta)) => Some(delta), + (None, None) => None, + } +} + fn format_anthropic_stream_error(error: OpenAIError) -> OpenAIError { let OpenAIError::ApiError(api_error) = error else { return error; @@ -430,3 +488,83 @@ fn create_anthropic_stream_response( crate::streaming_utils::create_stream_response(id, model, choices, usage) } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn usage_delta_accumulates_cache_tokens() { + let mut usage = Usage { + input_tokens: 10, + output_tokens: 1, + cache_read_input_tokens: Some(3), + ..Usage::default() + } + .into(); + + add_usage_delta( + &mut usage, + Usage { + input_tokens: 2, + output_tokens: 4, + cache_creation_input_tokens: Some(5), + cache_read_input_tokens: Some(7), + ..Usage::default() + }, + ); + + assert_eq!(usage.prompt_tokens, 27); + assert_eq!(usage.completion_tokens, 5); + assert_eq!(usage.total_tokens, 32); + assert_eq!( + usage + .prompt_tokens_details + .as_ref() + .and_then(|details| details.cached_tokens), + Some(10) + ); + } + + #[test] + fn usage_delta_saturates_token_counts() { + let mut usage = CompletionUsage { + prompt_tokens: u32::MAX - 1, + completion_tokens: u32::MAX - 1, + total_tokens: u32::MAX - 1, + prompt_tokens_details: Some(PromptTokensDetails { + cached_tokens: Some(u32::MAX - 1), + audio_tokens: Some(u32::MAX - 1), + }), + completion_tokens_details: None, + }; + + add_usage_delta( + 
&mut usage, + Usage { + input_tokens: 2, + output_tokens: 2, + cache_read_input_tokens: Some(2), + ..Usage::default() + }, + ); + + assert_eq!(usage.prompt_tokens, u32::MAX); + assert_eq!(usage.completion_tokens, u32::MAX); + assert_eq!(usage.total_tokens, u32::MAX); + assert_eq!( + usage + .prompt_tokens_details + .as_ref() + .and_then(|details| details.cached_tokens), + Some(u32::MAX) + ); + assert_eq!( + usage + .prompt_tokens_details + .as_ref() + .and_then(|details| details.audio_tokens), + Some(u32::MAX - 1) + ); + } +} diff --git a/crates/llms/src/bedrock/chat/mod.rs b/crates/llms/src/bedrock/chat/mod.rs index 019169b3cf..26969bbd72 100644 --- a/crates/llms/src/bedrock/chat/mod.rs +++ b/crates/llms/src/bedrock/chat/mod.rs @@ -54,13 +54,14 @@ use aws_sdk_bedrockruntime::types::builders::{ }; use aws_sdk_bedrockruntime::types::error::ConverseStreamOutputError; use aws_sdk_bedrockruntime::types::{ - ContentBlock, ContentBlockDelta as ContentBlockDeltaType, ContentBlockDeltaEvent, - ContentBlockStart as ContentBlockStartInner, ContentBlockStartEvent, ConversationRole, - ConverseStreamMetadataEvent, ConverseStreamOutput as ConverseStreamOutputPacket, - GuardrailConfiguration, GuardrailStreamConfiguration, InferenceConfiguration, - JsonSchemaDefinition, Message, MessageStartEvent, MessageStopEvent, OutputConfig, OutputFormat, - OutputFormatStructure, OutputFormatType, SystemContentBlock, ToolResultContentBlock, - ToolResultStatus, ToolUseBlockDelta, ToolUseBlockStart, + CachePointBlock, CachePointType, ContentBlock, ContentBlockDelta as ContentBlockDeltaType, + ContentBlockDeltaEvent, ContentBlockStart as ContentBlockStartInner, ContentBlockStartEvent, + ConversationRole, ConverseStreamMetadataEvent, + ConverseStreamOutput as ConverseStreamOutputPacket, GuardrailConfiguration, + GuardrailStreamConfiguration, InferenceConfiguration, JsonSchemaDefinition, Message, + MessageStartEvent, MessageStopEvent, OutputConfig, OutputFormat, OutputFormatStructure, + 
OutputFormatType, SystemContentBlock, ToolResultContentBlock, ToolResultStatus, + ToolUseBlockDelta, ToolUseBlockStart, }; use aws_smithy_types::Document; use futures::stream::StreamExt; @@ -349,6 +350,27 @@ impl BedrockConverse { .collect() } + fn prompt_cache_point() -> Result { + CachePointBlock::builder() + .r#type(CachePointType::Default) + .build() + } + + fn add_prompt_cache_point( + system: &mut Vec, + messages: &mut [Message], + ) -> Result<(), BuildError> { + let cache_point = Self::prompt_cache_point()?; + if let Some(last_message) = messages.last_mut() { + last_message + .content + .push(ContentBlock::CachePoint(cache_point)); + } else { + system.push(SystemContentBlock::CachePoint(cache_point)); + } + Ok(()) + } + fn output_config( response_format: Option, ) -> Result, OpenAIError> { @@ -415,6 +437,7 @@ impl BedrockConverse { tool_choice, tools, response_format, + prompt_cache_key, .. } = req; @@ -430,9 +453,13 @@ impl BedrockConverse { ) }); - let system = Self::convert_system_messages(system); - let messages = + let mut system = Self::convert_system_messages(system); + let mut messages = Self::convert_non_system_messages(messages).map_err(|e| to_api_error(e.to_string()))?; + if prompt_cache_key.is_some() { + Self::add_prompt_cache_point(&mut system, &mut messages) + .map_err(|e| to_api_error(e.to_string()))?; + } let guardrails: Option = self .guardrail @@ -476,6 +503,7 @@ impl BedrockConverse { tools, tool_choice, response_format, + prompt_cache_key, .. 
} = req; @@ -491,9 +519,13 @@ impl BedrockConverse { ) }); - let system = Self::convert_system_messages(system); - let messages = + let mut system = Self::convert_system_messages(system); + let mut messages = Self::convert_non_system_messages(messages).map_err(|e| to_api_error(e.to_string()))?; + if prompt_cache_key.is_some() { + Self::add_prompt_cache_point(&mut system, &mut messages) + .map_err(|e| to_api_error(e.to_string()))?; + } let guardrails: Option = self .guardrail @@ -827,3 +859,49 @@ impl Chat for BedrockConverse { None } } + +#[cfg(test)] +mod tests { + use super::*; + use async_openai::types::chat::ChatCompletionRequestUserMessageArgs; + + #[test] + fn prompt_cache_point_is_added_to_last_message() { + let messages = vec![ + ChatCompletionRequestUserMessageArgs::default() + .content("Reusable context") + .build() + .expect("user message should build") + .into(), + ]; + let mut system = vec![]; + let mut messages = BedrockConverse::convert_non_system_messages(messages) + .expect("bedrock messages should convert"); + + BedrockConverse::add_prompt_cache_point(&mut system, &mut messages) + .expect("cache point should build"); + + assert!(matches!( + messages + .last() + .and_then(|message| message.content.last()), + Some(ContentBlock::CachePoint(cache_point)) + if cache_point.r#type == CachePointType::Default + )); + } + + #[test] + fn prompt_cache_point_is_added_to_system_when_messages_are_empty() { + let mut system = vec![]; + let mut messages = vec![]; + + BedrockConverse::add_prompt_cache_point(&mut system, &mut messages) + .expect("cache point should build"); + + assert!(matches!( + system.last(), + Some(SystemContentBlock::CachePoint(cache_point)) + if cache_point.r#type == CachePointType::Default + )); + } +} diff --git a/crates/llms/src/chat/mistral.rs b/crates/llms/src/chat/mistral.rs index 4e7a13c9ec..148f80107e 100644 --- a/crates/llms/src/chat/mistral.rs +++ b/crates/llms/src/chat/mistral.rs @@ -60,6 +60,10 @@ use std::{ }; use 
tokio::sync::mpsc::{Receiver, Sender, channel}; +/// Preserve the existing local LLM scheduler concurrency. Paged attention uses +/// the same cap so enabling cache-aware scheduling does not change request parallelism. +const LOCAL_LLM_MAX_SEQS: NonZeroUsize = NonZeroUsize::MIN.saturating_add(4); + pub struct MistralLlama { pipeline: Arc, counter: AtomicUsize, @@ -133,17 +137,33 @@ impl MistralLlama { .and_then(|p| p.as_path().extension()) .and_then(|e| e.to_str()); + let paged_attn_config = Self::paged_attention_config(&device); + let paged_attn_requested = paged_attn_config.is_some(); let pipeline = match extension { - Some("ggml") => { - Self::load_ggml_pipeline(paths, &device, &model_id, chat_template_literal)? - } - Some("gguf") => { - Self::load_gguf_pipeline(paths, &device, &model_id, chat_template_literal)? - } - _ => Self::load_default_pipeline(paths, &device, &model_id, chat_template_literal)?, + Some("ggml") => Self::load_ggml_pipeline( + paths, + &device, + &model_id, + chat_template_literal, + paged_attn_config, + )?, + Some("gguf") => Self::load_gguf_pipeline( + paths, + &device, + &model_id, + chat_template_literal, + paged_attn_config, + )?, + _ => Self::load_default_pipeline( + paths, + &device, + &model_id, + chat_template_literal, + paged_attn_config, + )?, }; - Ok(Self::from_pipeline(pipeline).await) + Ok(Self::from_pipeline(pipeline, paged_attn_requested).await) } /// Create paths object, [`ModelPaths`], to create new [`MistralLlama`]. 
@@ -178,6 +198,7 @@ impl MistralLlama { device: &Device, model_id: &str, chat_template_literal: Option<&str>, + paged_attn_config: Option, ) -> Result>> { let model_parts: Vec<&str> = model_id.split(':').collect(); NormalLoaderBuilder::new( @@ -197,7 +218,7 @@ impl MistralLlama { true, DeviceMapSetting::Auto(AutoDeviceMapParams::default_text()), None, - None, + paged_attn_config, ) .map_err(|e| ChatError::FailedToLoadModel { source: e.into() }) } @@ -207,6 +228,7 @@ impl MistralLlama { device: &Device, model_id: &str, chat_template_literal: Option<&str>, + paged_attn_config: Option, ) -> Result>> { // Note: GGUF supports chat templates in the file, but since GGML/llama.cpp does // not write them into GGUF with their conversions, often it requires user @@ -256,7 +278,7 @@ impl MistralLlama { true, DeviceMapSetting::Auto(AutoDeviceMapParams::default_text()), None, - None, + paged_attn_config, ) .map_err(|e| ChatError::FailedToLoadModel { source: e.into() }) } @@ -266,6 +288,7 @@ impl MistralLlama { device: &Device, model_id: &str, chat_template_literal: Option<&str>, + paged_attn_config: Option, ) -> Result>> { let tokenizer = paths.get_tokenizer_filename().to_string_lossy().to_string(); GGMLLoaderBuilder::new( @@ -286,11 +309,56 @@ impl MistralLlama { true, DeviceMapSetting::Auto(AutoDeviceMapParams::default_text()), None, - None, + paged_attn_config, ) .map_err(|e| ChatError::FailedToLoadModel { source: e.into() }) } + fn paged_attention_config(device: &Device) -> Option { + if matches!(device, Device::Cpu) || !Self::paged_attention_supported() { + return None; + } + + match mistralrs::PagedAttentionMetaBuilder::default().build() { + Ok(config) => Some(config), + Err(e) => { + tracing::warn!("Failed to initialize local LLM paged attention cache: {e}"); + None + } + } + } + + fn paged_attention_supported() -> bool { + cfg!(all(feature = "cuda", target_family = "unix")) + } + + fn default_scheduler_config() -> mistralrs::SchedulerConfig { + 
mistralrs::SchedulerConfig::DefaultScheduler { + method: mistralrs::DefaultSchedulerMethod::Fixed(LOCAL_LLM_MAX_SEQS), + } + } + + async fn scheduler_config( + pipeline: &Arc>, + paged_attn_requested: bool, + ) -> mistralrs::SchedulerConfig { + if paged_attn_requested { + let cache_config = pipeline.lock().await.get_metadata().cache_config.clone(); + if let Some(config) = cache_config { + return mistralrs::SchedulerConfig::PagedAttentionMeta { + max_num_seqs: LOCAL_LLM_MAX_SEQS.get(), + config, + }; + } + + tracing::debug!( + "Paged attention was requested for local LLM caching, but the model did not initialize paged cache metadata. Falling back to the default scheduler." + ); + } + + Self::default_scheduler_config() + } + /// Get the device to use for the model. /// Preference order: CUDA, Metal, CPU. fn get_device() -> Device { @@ -390,6 +458,8 @@ impl MistralLlama { TokenSource::Literal(secret.expose_secret().to_string()) }); + let paged_attn_config = Self::paged_attention_config(&device); + let paged_attn_requested = paged_attn_config.is_some(); let pipeline = loader? 
.load_model_from_hf( model_parts.get(1).map(|&x| x.to_string()), @@ -399,28 +469,22 @@ impl MistralLlama { false, DeviceMapSetting::Auto(device_map_params), None, - None, + paged_attn_config, ) .map_err(|e| ChatError::FailedToLoadModel { source: e.into() })?; - Ok(Self::from_pipeline(pipeline).await) + Ok(Self::from_pipeline(pipeline, paged_attn_requested).await) } - #[expect(clippy::expect_used)] - async fn from_pipeline(p: Arc>) -> Self { + async fn from_pipeline( + pipeline: Arc>, + paged_attn_requested: bool, + ) -> Self { + let scheduler_config = Self::scheduler_config(&pipeline, paged_attn_requested).await; Self { - pipeline: MistralRsBuilder::new( - p, - mistralrs::SchedulerConfig::DefaultScheduler { - method: mistralrs::DefaultSchedulerMethod::Fixed( - NonZeroUsize::new(5).expect("unreachable 5 > 0"), - ), - }, - false, - None, - ) - .build() - .await, + pipeline: MistralRsBuilder::new(pipeline, scheduler_config, false, None) + .build() + .await, counter: AtomicUsize::new(0), } } diff --git a/crates/llms/src/databricks/mod.rs b/crates/llms/src/databricks/mod.rs index 700983b792..04344422e8 100644 --- a/crates/llms/src/databricks/mod.rs +++ b/crates/llms/src/databricks/mod.rs @@ -108,6 +108,53 @@ impl Databricks { req } + fn alter_request_body( + &self, + req: CreateChatCompletionRequest, + stream: bool, + ) -> Result { + let mut req = self.alter_request(req); + let prompt_cache_requested = req.prompt_cache_key.take().is_some(); + req.stream = Some(stream); + + let mut body = + serde_json::to_value(req).map_err(|e| OpenAIError::InvalidArgument(e.to_string()))?; + if prompt_cache_requested { + Self::add_prompt_cache_control(&mut body); + } + Ok(body) + } + + fn add_prompt_cache_control(body: &mut Value) { + let Some(messages) = body.get_mut("messages").and_then(Value::as_array_mut) else { + return; + }; + for message in messages.iter_mut().rev() { + let Some(content) = message.get_mut("content") else { + continue; + }; + + match content { + 
Value::String(text) if !text.is_empty() => { + let text = std::mem::take(text); + *content = json!([{ "type": "text", "text": text, "cache_control": { "type": "ephemeral" } }]); + return; + } + Value::Array(parts) => { + if let Some(Value::Object(part)) = parts + .iter_mut() + .rev() + .find(|part| part.get("text").is_some()) + { + part.insert("cache_control".to_string(), json!({ "type": "ephemeral" })); + return; + } + } + _ => {} + } + } + } + #[must_use] pub fn set_cache( mut self, @@ -275,7 +322,7 @@ impl Chat for Databricks { req: CreateChatCompletionRequest, ) -> Result { // Must use `create_stream_byot` with custom response type to handle Databricks-specific format. - let altered_req = self.alter_request(req); + let altered_req = self.alter_request_body(req, true)?; let stream: std::pin::Pin< Box< dyn futures::Stream< @@ -299,7 +346,7 @@ impl Chat for Databricks { self.client .chat() .path("")? - .create_byot(self.alter_request(req)) + .create_byot(self.alter_request_body(req, false)?) 
.await } } @@ -406,3 +453,43 @@ impl std::fmt::Debug for Databricks { .finish_non_exhaustive() } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn prompt_cache_control_is_added_to_last_text_message() { + let mut body = json!({ + "messages": [ + {"role": "system", "content": "Static instructions"}, + {"role": "user", "content": "Reusable context"} + ] + }); + + Databricks::add_prompt_cache_control(&mut body); + + assert_eq!( + body["messages"][1]["content"][0]["cache_control"], + json!({ "type": "ephemeral" }) + ); + } + + #[test] + fn prompt_cache_control_skips_trailing_non_text_content() { + let mut body = json!({ + "messages": [ + {"role": "user", "content": "Reusable context"}, + {"role": "assistant", "content": [{"type": "tool_use", "id": "call_123"}]} + ] + }); + + Databricks::add_prompt_cache_control(&mut body); + + assert_eq!( + body["messages"][0]["content"][0]["cache_control"], + json!({ "type": "ephemeral" }) + ); + assert!(body["messages"][1]["content"][0]["cache_control"].is_null()); + } +} diff --git a/crates/llms/src/google/chat.rs b/crates/llms/src/google/chat.rs index 8c144eb744..47dec3e568 100644 --- a/crates/llms/src/google/chat.rs +++ b/crates/llms/src/google/chat.rs @@ -34,7 +34,7 @@ use async_trait::async_trait; use futures::Stream; use futures::StreamExt; use google_genai::generate::{GenerateContentRequest, GenerateContentResponse}; -use google_genai::types::{Content, FunctionDeclaration, FunctionResponse, Part}; +use google_genai::types::{CachedContent, Content, FunctionDeclaration, FunctionResponse, Part}; use std::collections::HashMap; use std::pin::Pin; use std::time::SystemTime; @@ -83,9 +83,15 @@ impl Chat for Google { } fn convert_to_google_request(req: CreateChatCompletionRequest) -> GenerateContentRequest { + let CreateChatCompletionRequest { + messages, + prompt_cache_key, + tools, + .. 
+ } = req; let mut contents = Vec::new(); - for message in req.messages { + for message in messages { let content = match message { ChatCompletionRequestMessage::User(msg) => { let text = match msg.content { @@ -204,7 +210,13 @@ fn convert_to_google_request(req: CreateChatCompletionRequest) -> GenerateConten let mut google_req = GenerateContentRequest::new(contents); // Convert tools if present - if let Some(openai_tools) = req.tools { + if let Some(prompt_cache_key) = prompt_cache_key { + google_req = google_req.with_cached_content(CachedContent { + name: Some(prompt_cache_key), + }); + } + + if let Some(openai_tools) = tools { let google_tools: Vec = openai_tools .into_iter() .filter_map(|tool| { @@ -443,3 +455,37 @@ fn convert_google_stream_to_openai( .map_err(|e| openai_api_error(e.to_string())) }) } + +#[cfg(test)] +mod tests { + use super::*; + use async_openai::types::chat::{ + ChatCompletionRequestUserMessageArgs, CreateChatCompletionRequest, + }; + + #[test] + fn forwards_prompt_cache_key_as_cached_content() { + let req = CreateChatCompletionRequest { + messages: vec![ + ChatCompletionRequestUserMessageArgs::default() + .content("What did the cached context say?") + .build() + .expect("user message should build") + .into(), + ], + prompt_cache_key: Some("cachedContents/schema-cache".to_string()), + ..CreateChatCompletionRequest::default() + }; + + let google_req = convert_to_google_request(req); + + assert_eq!( + google_req + .cached_content + .expect("cached content should be set") + .name + .as_deref(), + Some("cachedContents/schema-cache") + ); + } +} diff --git a/crates/llms/src/xai/mod.rs b/crates/llms/src/xai/mod.rs index 63b290934f..f5ff68c906 100644 --- a/crates/llms/src/xai/mod.rs +++ b/crates/llms/src/xai/mod.rs @@ -29,6 +29,7 @@ use async_openai::{ Client, config::OpenAIConfig, error::OpenAIError, + traits::RequestOptionsBuilder, types::chat::{ ChatCompletionMessageToolCalls, ChatCompletionRequestAssistantMessage, 
ChatCompletionRequestAssistantMessageContent, ChatCompletionRequestMessage, @@ -70,8 +71,13 @@ impl Xai { } } - /// Changes to `req` to accomodate xAi not being `OpenAI` compatible. - fn alter_request(&self, mut req: CreateChatCompletionRequest) -> CreateChatCompletionRequest { + /// Changes to `req` to accommodate xAi not being `OpenAI` compatible. + fn alter_request( + &self, + mut req: CreateChatCompletionRequest, + ) -> (CreateChatCompletionRequest, Option) { + let prompt_cache_key = req.prompt_cache_key.take(); + // Use name of xAI model, not spicepod model. req.model.clone_from(&self.model); @@ -122,7 +128,7 @@ impl Xai { } } - req + (req, prompt_cache_key) } } @@ -156,11 +162,13 @@ impl Chat for Xai { .await .map_err(|e| OpenAIError::InvalidArgument(e.to_string()))?; - let stream = self - .client - .chat() - .create_stream(self.alter_request(req)) - .await?; + let (req, prompt_cache_key) = self.alter_request(req); + let mut chat = self.client.chat(); + if let Some(prompt_cache_key) = prompt_cache_key { + chat = chat.header("x-grok-conv-id", prompt_cache_key)?; + } + + let stream = chat.create_stream(req).await?; drop(permit); Ok(Box::pin(stream)) @@ -176,9 +184,34 @@ impl Chat for Xai { .await .map_err(|e| OpenAIError::InvalidArgument(e.to_string()))?; - let resp = self.client.chat().create(self.alter_request(req)).await?; + let (req, prompt_cache_key) = self.alter_request(req); + let mut chat = self.client.chat(); + if let Some(prompt_cache_key) = prompt_cache_key { + chat = chat.header("x-grok-conv-id", prompt_cache_key)?; + } + + let resp = chat.create(req).await?; drop(permit); Ok(resp) } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn prompt_cache_key_is_moved_to_chat_header_value() { + let xai = Xai::new(Some("grok-4.3"), "test-key"); + let req = CreateChatCompletionRequest { + prompt_cache_key: Some("conversation-123".to_string()), + ..CreateChatCompletionRequest::default() + }; + + let (req, prompt_cache_key) = xai.alter_request(req); 
+ + assert_eq!(prompt_cache_key.as_deref(), Some("conversation-123")); + assert!(req.prompt_cache_key.is_none()); + } +} diff --git a/crates/runtime/src/http/v1/nsql.rs b/crates/runtime/src/http/v1/nsql.rs index fe0ec17e6c..a3670a8a52 100644 --- a/crates/runtime/src/http/v1/nsql.rs +++ b/crates/runtime/src/http/v1/nsql.rs @@ -153,6 +153,10 @@ pub struct Request { /// Names of datasets to sample from when constructing model context; this is a sampling hint and does not restrict which tables queries can target. If omitted, all datasets are used. #[serde(skip_serializing_if = "Option::is_none")] pub datasets: Option>, + + /// Stable prompt-cache key forwarded to the configured NSQL model for provider-specific cache handling. + #[serde(skip_serializing_if = "Option::is_none")] + pub prompt_cache_key: Option, } fn default_sample_data_enabled() -> bool { @@ -192,7 +196,8 @@ fn return_sql_only(accept: Option<&TypedHeader>) -> bool { "model": "nql", "stream": false, "sample_data_enabled": true, - "datasets": ["sales_data"] + "datasets": ["sales_data"], + "prompt_cache_key": "sales-dashboard" }) )) ), @@ -314,6 +319,7 @@ pub(crate) async fn handle_nsql_query( model, sample_data_enabled, datasets, + prompt_cache_key, .. 
} = payload; let table_allowlist_opt = match table_allowlist(&model, &rt).await { @@ -417,6 +423,9 @@ pub(crate) async fn handle_nsql_query( req.messages.extend(schema_messages.clone()); req.messages.extend(sample_data_messages.clone()); + if let Some(prompt_cache_key) = &prompt_cache_key { + req.prompt_cache_key = Some(prompt_cache_key.clone()); + } let resp = match nql_model.chat_request(req).instrument(span.clone()).await { Ok(r) => r, diff --git a/crates/runtime/src/model/chat.rs b/crates/runtime/src/model/chat.rs index 83f1190901..ecd39219f1 100644 --- a/crates/runtime/src/model/chat.rs +++ b/crates/runtime/src/model/chat.rs @@ -616,6 +616,25 @@ mod test { ); } + #[test] + fn test_get_openai_request_overrides_with_prompt_cache_key() { + let mut model = Model::new("hf:test_model", "test_model"); + model.params.insert( + "hf_prompt_cache_key".to_string(), + Value::String("schema-context".to_string()), + ); + + let overrides = get_openai_request_overrides(&model, "hf"); + + assert_eq!(overrides.len(), 1); + assert!( + overrides + .iter() + .any(|(key, value)| key == "prompt_cache_key" + && value == &Value::String("schema-context".to_string())) + ); + } + #[test] // Param with takes precedence over the deprecated openai_ prefix. fn test_get_openai_request_overrides_with_model_prefix_and_deprecated() { diff --git a/crates/runtime/src/model/params/mod.rs b/crates/runtime/src/model/params/mod.rs index ad6918054b..9ab1f02c50 100644 --- a/crates/runtime/src/model/params/mod.rs +++ b/crates/runtime/src/model/params/mod.rs @@ -49,8 +49,8 @@ pub fn get_params_spec(source: &ModelSource) -> Option<&'static [ParameterSpec]> } } -pub const PARAM_LEN: usize = 47; -pub const PARAM_WITH_DEPRE_LEN: usize = 48; +pub const PARAM_LEN: usize = 51; +pub const PARAM_WITH_DEPRE_LEN: usize = 52; // Model parameters that are used for openai model provider. Those parameters are supported by other (non-openai) models as well. 
// OpenAI model is prefixed with `openai_`, use separate PARAMETERS constant to avoid confusion with other model providers. @@ -88,6 +88,8 @@ pub const COMMON_MODEL_PARAMETERS: [ParameterSpec; PARAM_LEN] = [ ParameterSpec::runtime("top_p"), ParameterSpec::runtime("tool_choice"), ParameterSpec::runtime("parallel_tool_calls"), + ParameterSpec::runtime("prompt_cache_key"), + ParameterSpec::runtime("prompt_cache_retention"), ParameterSpec::runtime("user"), ParameterSpec::component("frequency_penalty").deprecated("Use 'frequency_penalty' without prefix"), ParameterSpec::component("logit_bias").deprecated("Use 'logit_bias' without prefix"), @@ -109,6 +111,8 @@ pub const COMMON_MODEL_PARAMETERS: [ParameterSpec; PARAM_LEN] = [ ParameterSpec::component("tools").deprecated("Use 'tools' without prefix"), ParameterSpec::component("tool_choice").deprecated("Use 'tool_choice' without prefix"), ParameterSpec::component("parallel_tool_calls").deprecated("Use 'parallel_tool_calls' without prefix"), + ParameterSpec::component("prompt_cache_key").deprecated("Use 'prompt_cache_key' without prefix"), + ParameterSpec::component("prompt_cache_retention").deprecated("Use 'prompt_cache_retention' without prefix"), ParameterSpec::component("user").deprecated("Use 'user' without prefix"), ]; @@ -148,6 +152,8 @@ pub const COMMON_MODEL_PARAMETERS_WITH_DEPRECATED: [ParameterSpec; PARAM_WITH_DE ParameterSpec::component("tools"), ParameterSpec::component("tool_choice"), ParameterSpec::component("parallel_tool_calls"), + ParameterSpec::component("prompt_cache_key"), + ParameterSpec::component("prompt_cache_retention"), ParameterSpec::component("user"), // For model providers that are not OpenAI // The default Override parameters start with `openai_` is deprecated and will be removed in a future release. 
@@ -172,5 +178,7 @@ pub const COMMON_MODEL_PARAMETERS_WITH_DEPRECATED: [ParameterSpec; PARAM_WITH_DE ParameterSpec::runtime("openai_tools").deprecated(DEPRECATED_MESSAGE), ParameterSpec::runtime("openai_tool_choice").deprecated(DEPRECATED_MESSAGE), ParameterSpec::runtime("openai_parallel_tool_calls").deprecated(DEPRECATED_MESSAGE), + ParameterSpec::runtime("openai_prompt_cache_key").deprecated(DEPRECATED_MESSAGE), + ParameterSpec::runtime("openai_prompt_cache_retention").deprecated(DEPRECATED_MESSAGE), ParameterSpec::runtime("openai_user").deprecated(DEPRECATED_MESSAGE), ]; diff --git a/crates/runtime/src/model/responses.rs b/crates/runtime/src/model/responses.rs index 9997e8b51e..c89269b3ac 100644 --- a/crates/runtime/src/model/responses.rs +++ b/crates/runtime/src/model/responses.rs @@ -29,14 +29,19 @@ use llms::responses::Responses; use secrecy::SecretString; use serde_json::Value; use spicepod::component::model::{Model, ModelSource}; -use std::collections::HashMap; use std::str::FromStr; -use std::sync::Arc; +use std::{ + collections::{HashMap, HashSet}, + sync::{Arc, LazyLock}, +}; pub type LLMResponsesModelStore = HashMap>; const DEFAULT_SPICE_TOOL_RECURSION_LIMIT: usize = 10; +static OPENAI_RESPONSES_DEFAULT_PARAM_KEYS: LazyLock> = + LazyLock::new(|| HashSet::from(["prompt_cache_key", "prompt_cache_retention"])); + macro_rules! 
extract_secret { ($params:expr, $key:expr) => { $params.get($key).map(secrecy::ExposeSecret::expose_secret) @@ -167,9 +172,25 @@ fn construct_model( model, component.name.as_str(), system_prompt, + get_openai_responses_request_overrides(component, params.prefix()), ))) } +pub fn get_openai_responses_request_overrides(model: &Model, prefix: &str) -> Vec<(String, Value)> { + let mut request_overrides: HashMap = HashMap::new(); + for &key in OPENAI_RESPONSES_DEFAULT_PARAM_KEYS.iter() { + if let Some(value) = model.params.get(key) { + request_overrides.insert(key.to_string(), value.clone()); + } else if let Some(value) = model.params.get(&format!("{prefix}_{key}")) { + request_overrides.insert(key.to_string(), value.clone()); + } else if let Some(value) = model.params.get(&format!("openai_{key}")) { + request_overrides.insert(key.to_string(), value.clone()); + } + } + + request_overrides.into_iter().collect() +} + fn openai(model_id: Option, params: &Parameters) -> Result, LlmError> { let api_base = params.get("endpoint").expose().ok(); let api_key = params.get("api_key").expose().ok(); @@ -277,3 +298,38 @@ fn xai(model_id: Option<&str>, params: &Parameters) -> Result }; Ok(Arc::new(llms::xai::Xai::new(model_id, api_key)) as Arc) } + +#[cfg(test)] +mod tests { + use super::*; + use spicepod::component::model::Model; + + #[test] + fn test_get_openai_responses_request_overrides_with_prompt_cache() { + let mut model = Model::new("openai:gpt-4o", "test_model"); + model.params.insert( + "prompt_cache_key".to_string(), + Value::String("default-key".to_string()), + ); + model.params.insert( + "openai_prompt_cache_retention".to_string(), + Value::String("24h".to_string()), + ); + + let overrides = get_openai_responses_request_overrides(&model, "openai"); + + assert_eq!(overrides.len(), 2); + assert!( + overrides + .iter() + .any(|(key, value)| key == "prompt_cache_key" + && value == &Value::String("default-key".to_string())) + ); + assert!( + overrides + .iter() + .any(|(key, 
value)| key == "prompt_cache_retention" + && value == &Value::String("24h".to_string())) + ); + } +} diff --git a/crates/runtime/src/model/wrapper/mod.rs b/crates/runtime/src/model/wrapper/mod.rs index 3fee872542..b1cfb474e1 100644 --- a/crates/runtime/src/model/wrapper/mod.rs +++ b/crates/runtime/src/model/wrapper/mod.rs @@ -69,6 +69,7 @@ pub(crate) static OPENAI_DEFAULT_PARAM_KEYS: LazyLock> = "top_p", "tool_choice", "parallel_tool_calls", + "prompt_cache_key", "user", ]) }); @@ -94,7 +95,7 @@ macro_rules! set_default_w_warning { Ok(val) => Some(val), Err(_) => { tracing::warn!( - "Failed to parse `openai_{}` override for model='{}'. Ensure {:?} is of the correct format.", + "Failed to parse `{}` model parameter override for model='{}'. Ensure {:?} is of the correct format.", stringify!($field), $model, $value @@ -265,6 +266,9 @@ impl ChatWrapper { "parallel_tool_calls" => { set_default_w_warning!(req, parallel_tool_calls, value, self.public_name); } + "prompt_cache_key" => { + set_default_w_warning!(req, prompt_cache_key, value, self.public_name); + } "user" => set_default_w_warning!(req, user, value, self.public_name), _ => { tracing::debug!("Ignoring unknown default key: {}", key); diff --git a/crates/runtime/src/model/wrapper/responses.rs b/crates/runtime/src/model/wrapper/responses.rs index 3c550b0146..ee9947a06e 100644 --- a/crates/runtime/src/model/wrapper/responses.rs +++ b/crates/runtime/src/model/wrapper/responses.rs @@ -42,6 +42,26 @@ pub struct ResponsesWrapper { pub public_name: String, pub responses: Arc, pub system_prompt: Option, + pub defaults: Vec<(String, serde_json::Value)>, +} + +macro_rules! set_default_w_warning { + ($req:expr, $field:ident, $value:expr, $model:expr) => { + $req.$field = $req + .$field + .or_else(|| match serde_json::from_value($value.clone()) { + Ok(val) => Some(val), + Err(_) => { + tracing::warn!( + "Failed to parse Responses API `{}` override for model='{}'. 
Ensure {:?} is of the correct format.", + stringify!($field), + $model, + $value + ); + None + } + }) + }; } impl ResponsesWrapper { @@ -49,16 +69,18 @@ impl ResponsesWrapper { responses: Arc, public_name: &str, system_prompt: Option<&str>, + defaults: Vec<(String, serde_json::Value)>, ) -> Self { Self { public_name: public_name.to_string(), responses, system_prompt: system_prompt.map(ToString::to_string), + defaults, } } fn prepare_req(&self, req: CreateResponse) -> CreateResponse { - self.with_system_prompt(req) + self.with_model_defaults(self.with_system_prompt(req)) } /// Injects a system prompt into the instructions field in the request, if it exists. @@ -72,6 +94,21 @@ impl ResponsesWrapper { } req } + + fn with_model_defaults(&self, mut req: CreateResponse) -> CreateResponse { + for (key, value) in &self.defaults { + match key.as_str() { + "prompt_cache_key" => { + set_default_w_warning!(req, prompt_cache_key, value, self.public_name); + } + "prompt_cache_retention" => { + set_default_w_warning!(req, prompt_cache_retention, value, self.public_name); + } + _ => tracing::debug!("Ignoring unknown Responses API default key: {key}"), + } + } + req + } } #[async_trait] @@ -297,7 +334,7 @@ impl Drop for TracedResponseStream { #[cfg(test)] mod tests { use super::*; - use async_openai::types::responses::CreateResponse; + use async_openai::types::responses::{CreateResponse, PromptCacheRetention}; /// Helper to create a [`ResponsesWrapper`] with the given system prompt (no underlying model needed for `with_system_prompt` tests). 
fn wrapper_with_prompt(prompt: Option<&str>) -> ResponsesWrapper { @@ -305,6 +342,7 @@ mod tests { public_name: "test-model".to_string(), responses: Arc::new(NoopResponses), system_prompt: prompt.map(ToString::to_string), + defaults: Vec::new(), } } @@ -395,4 +433,35 @@ mod tests { "Instructions should remain None when neither is set" ); } + + #[test] + fn test_prompt_cache_defaults_preserve_request_values() { + let wrapper = ResponsesWrapper { + public_name: "test-model".to_string(), + responses: Arc::new(NoopResponses), + system_prompt: None, + defaults: vec![ + ( + "prompt_cache_key".to_string(), + serde_json::Value::String("default-key".to_string()), + ), + ( + "prompt_cache_retention".to_string(), + serde_json::Value::String("24h".to_string()), + ), + ], + }; + + let req = CreateResponse { + prompt_cache_key: Some("request-key".to_string()), + ..CreateResponse::default() + }; + let result = wrapper.with_model_defaults(req); + + assert_eq!(result.prompt_cache_key.as_deref(), Some("request-key")); + assert_eq!( + result.prompt_cache_retention, + Some(PromptCacheRetention::Hours24) + ); + } } diff --git a/docs/README.md b/docs/README.md index d55efcc93a..fd6d2ddd8f 100644 --- a/docs/README.md +++ b/docs/README.md @@ -2,11 +2,129 @@ The project docs for contributors and community. For user documentation of the Spice.ai platform, see [spiceai.org/docs](https://spiceai.org/docs). 
+## Core Docs + - [Principles](PRINCIPLES.md) - [Roadmap](ROADMAP.md) - [Distributions](DISTRIBUTIONS.md) +- [Extensibility](EXTENSIBILITY.md) ## Contributing - [CONTRIBUTING.md](../CONTRIBUTING.md) - [Release Process](RELEASE.md) + +## Developer Notes + +- [Cosmos DB](dev/cosmosdb.md) +- [Error Handling](dev/error_handling.md) +- [LLM Prompt Caching](dev/llm-prompt-caching.md) +- [Metrics](dev/metrics.md) +- [Snapshot Tests](dev/snapshot_tests.md) +- [Rust Style Guide](dev/style_guide.md) + +## Standard Operating Procedures + +- [Upgrade mistral.rs](dev/sop-upgrade-mistral-rs.md) +- [Upgrade Text Embeddings Inference](dev/sop-upgrade-tei.md) + +## Criteria + +- [Criteria Principles](criteria/PRINCIPLES.md) +- [Criteria Definitions](criteria/definitions.md) + +### Accelerators + +- [Alpha](criteria/accelerators/alpha.md) +- [Beta](criteria/accelerators/beta.md) +- [Release Candidate](criteria/accelerators/rc.md) +- [Stable](criteria/accelerators/stable.md) + +### Catalogs + +- [Alpha](criteria/catalogs/alpha.md) +- [Beta](criteria/catalogs/beta.md) +- [Release Candidate](criteria/catalogs/rc.md) +- [Stable](criteria/catalogs/stable.md) + +### Connectors + +- [Alpha](criteria/connectors/alpha.md) +- [Beta](criteria/connectors/beta.md) +- [Release Candidate](criteria/connectors/rc.md) +- [Stable](criteria/connectors/stable.md) + +### Embeddings + +- [Alpha](criteria/embeddings/alpha.md) +- [Beta](criteria/embeddings/beta.md) +- [Release Candidate](criteria/embeddings/rc.md) +- [Stable](criteria/embeddings/stable.md) + +### Features + +- [Overview](criteria/features/README.md) +- [Alpha](criteria/features/alpha.md) +- [Beta](criteria/features/beta.md) +- [Release Candidate](criteria/features/rc.md) +- [Stable](criteria/features/stable.md) + +### Models + +- [Grading](criteria/models/grading.md) +- [Alpha](criteria/models/alpha.md) +- [Beta](criteria/models/beta.md) +- [Release Candidate](criteria/models/rc.md) +- [Stable](criteria/models/stable.md) + +## Architecture 
Decisions + +- [001: Use snmalloc as Global Allocator](decisions/001-use-snmalloc-as-global-allocator.md) +- [002: Default Ports](decisions/002-default-ports.md) +- [003: Duration Milliseconds](decisions/003-duration-ms.md) +- [004: Distributed Query Framework](decisions/004-distributed-query-framework.md) +- [005: Ballista Extensions](decisions/005-ballista-extensions.md) +- [006: High-Availability Distributed Query](decisions/006-ha-distributed-query.md) +- [007: Cluster mTLS](decisions/007-cluster-mtls.md) + +## Examples + +- [HTTP Refresh SQL Example](examples/http_refresh_sql_example.md) +- [Turso Acceleration Example](examples/turso_acceleration_example.md) + +## Feature Notes + +- [Databricks Resilience](features/databricks-resilience.md) +- [DuckDB Index Scan Settings](features/duckdb_index_scan_settings.md) +- [GCS Connector](features/gcs-connector.md) +- [Git Connector](features/git-connector.md) +- [Postgres Replication](features/postgres-replication.md) +- [Schema Decomposition](features/schema-decomposition.md) + +## Threat Models + +- [v1.9.2](threat_models/v1.9.2.md) +- [v2.0.0](threat_models/v2.0.0.md) +- [v0.17.4-beta JSON](threat_models/v0.17.4-beta.json) +- [v1.9.1 JSON](threat_models/v1.9.1.json) + +## Release Notes + +Release notes are stored in [release_notes](release_notes/). Use the series directories for older releases: + +- [Alpha](release_notes/alpha/) +- [Beta](release_notes/beta/) +- [Release Candidate](release_notes/rc/) +- [v1.0](release_notes/v1.0/) +- [v1.1](release_notes/v1.1/) +- [v1.2](release_notes/v1.2/) +- [v1.3](release_notes/v1.3/) +- [v1.4](release_notes/v1.4/) +- [v1.5](release_notes/v1.5/) +- [v1.6](release_notes/v1.6/) +- [v1.7](release_notes/v1.7/) +- [v1.8](release_notes/v1.8/) +- [v1.9](release_notes/v1.9/) +- [v1.10](release_notes/v1.10/) + +Recent v1.11 and v2.0 release notes are at the top level of [release_notes](release_notes/). 
diff --git a/docs/dev/llm-prompt-caching.md b/docs/dev/llm-prompt-caching.md new file mode 100644 index 0000000000..754644fc05 --- /dev/null +++ b/docs/dev/llm-prompt-caching.md @@ -0,0 +1,60 @@ +# LLM Prompt Caching + +This note documents how prompt-cache intent flows through Spice's LLM runtime and provider adapters. It is internal maintainer guidance, not a user-facing API reference. + +## Core Model + +`prompt_cache_key` and `prompt_cache_retention` describe cache intent. They do not describe a portable serialized KV-cache format, and Spice should not attempt to store live LLM KV tensors in the Arrow accelerator or hash indexes. KV cache state is model-engine, provider, device, and scheduler specific, so it belongs in the provider or local model engine. + +Provider adapters should map cache intent to the provider-native mechanism when one exists. If a provider does not support explicit prompt caching, preserve provider correctness and request semantics rather than fabricating cache behavior. + +## Runtime Entry Points + +- Chat model defaults are collected in `crates/runtime/src/model/chat.rs` and applied by `crates/runtime/src/model/wrapper/mod.rs`. +- Responses model defaults are collected in `crates/runtime/src/model/responses.rs` and applied by `crates/runtime/src/model/wrapper/responses.rs`. +- Model parameter specs live in `crates/runtime/src/model/params/mod.rs`; update the parameter counts when adding entries. +- `/v1/nsql` accepts `prompt_cache_key` and forwards it to the configured NSQL chat model in `crates/runtime/src/http/v1/nsql.rs`. + +Defaults must not override request-provided values. Keep default parsing on the request path where warnings can identify the model and field that failed to parse. 
+ +## Provider Mappings + +| Provider path | Cache mapping | +| ------------------------------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| OpenAI-compatible chat and Azure chat | Pass `prompt_cache_key` through the OpenAI-compatible chat request field. | +| OpenAI-compatible Responses | Pass `prompt_cache_key` and `prompt_cache_retention` through the Responses request fields. | +| xAI chat | Move `prompt_cache_key` out of the request body and send it as the `x-grok-conv-id` request header. | +| xAI Responses | Leave Responses fields in the request body. | +| Google Gemini | Map `prompt_cache_key` to `GenerateContentRequest.cached_content.name`; callers must provide a valid cached-content resource name. | +| Anthropic | Set top-level ephemeral `cache_control` when cache intent is present and preserve cache usage fields in OpenAI-compatible usage. | +| Bedrock Converse | Append a native `CachePoint` to the last message, or to system content when no messages exist. | +| Databricks hosted Claude | Use BYOT JSON and add Claude-style `cache_control` to the last text content part. | +| Local HuggingFace/file models | Use `mistral-rs` native KV cache and paged-attention scheduling when the backend and pipeline support it. Request-level cache keys are not portable to local tensors. | + +## Usage Accounting + +When providers return cache-token usage, keep totals data-correct: + +- Include provider-reported cache creation and cache read input tokens in `prompt_tokens` and `total_tokens` when the provider's accounting reports them separately from normal input tokens. +- Populate `prompt_tokens_details.cached_tokens` only with cache-read tokens. Cache-creation tokens are prompt work, but they were not read from cache. +- Do not invent cached-token counts when a provider omits them. 
+ +## Local Model Notes + +`mistral-rs` regular KV caching is enabled by keeping `no_kv_cache` false. Paged attention is requested only on supported CUDA Unix backends, then used only if the loaded pipeline exposes cache metadata. If metadata is missing, fall back to the default scheduler and log at debug level. Metal keeps the default scheduler because the current paged-attention path can panic in the underlying Metal kernels. + +Keep scheduler constants non-zero at compile time; do not use `unwrap` or `expect` in the production scheduler path. + +## Maintenance Checklist + +When changing prompt caching behavior: + +- Add or update provider-specific unit tests in `crates/llms`. +- Add runtime wrapper/default extraction tests in `crates/runtime` when model parameters or defaults change. +- Ensure NSQL forwards only cache intent, not provider-specific behavior. +- Keep provider mappings explicit; avoid one-size-fits-all request mutation across providers. +- Do not store live KV tensors in data accelerators, Arrow arrays, or hash indexes. +- Run `cargo fmt --all`. +- Run `cargo test -p llms prompt_cache --features local_llm`. +- Run `cargo test -p runtime prompt_cache --features models`. +- Run `make lint` before resolving review threads or handing off the PR. 
\ No newline at end of file From 5b1b824e0cb8da60687a2c8ff939a0279020e9df Mon Sep 17 00:00:00 2001 From: Sergei Grebnov Date: Mon, 4 May 2026 22:28:50 +0300 Subject: [PATCH 6/6] Fix DuckDB HNSW vector indexes lost after data refresh (#10668) * Fix DuckDB HNSW vector indexes lost after data refresh * Improve * Improve tracing * Fix lint --- crates/runtime/src/accelerated_table/mod.rs | 18 +- .../src/accelerated_table/sink/table.rs | 7 +- crates/runtime/src/dataconnector/mod.rs | 19 +- crates/runtime/src/datafusion/mod.rs | 5 + crates/runtime/src/embeddings/connector.rs | 28 +- .../runtime/src/search/full_text/connector.rs | 12 +- crates/runtime/tests/models/hnsw_index.rs | 437 ++++++++++++++++++ crates/runtime/tests/models/mod.rs | 2 + .../test_data/mega-science-sample.jsonl | 5 + crates/search/src/index/duckdb.rs | 24 +- 10 files changed, 537 insertions(+), 20 deletions(-) create mode 100644 crates/runtime/tests/models/hnsw_index.rs create mode 100644 crates/runtime/tests/models/test_data/mega-science-sample.jsonl diff --git a/crates/runtime/src/accelerated_table/mod.rs b/crates/runtime/src/accelerated_table/mod.rs index 0adccfd080..0ce8a4934e 100644 --- a/crates/runtime/src/accelerated_table/mod.rs +++ b/crates/runtime/src/accelerated_table/mod.rs @@ -448,6 +448,20 @@ impl Builder { self } + /// Returns a clone of the accelerator `Arc`. + #[must_use] + pub fn get_accelerator(&self) -> Arc { + Arc::clone(&self.accelerator) + } + + /// Replace the accelerator provider. + /// + /// This must be called **before** [`build`](Self::build) so that the + /// refresher (created during build) receives the updated provider. + pub fn set_accelerator(&mut self, accelerator: Arc) { + self.accelerator = accelerator; + } + /// Set to only write to the accelerator (not replicate to federated source). /// This is used when `on_conflict` is configured - writes go only to the accelerator. 
pub fn write_to_accelerator_only(&mut self) -> &mut Self { @@ -1130,10 +1144,6 @@ impl AcceleratedTable { &self.accelerator } - pub(crate) fn set_accelerator(&mut self, accelerator: Arc) { - self.accelerator = accelerator; - } - /// Add a child accelerator that should receive cached data when this parent stores new cache entries. /// This is used for localpod caching synchronization. pub async fn add_synchronized_child(&self, child_accelerator: Arc) { diff --git a/crates/runtime/src/accelerated_table/sink/table.rs b/crates/runtime/src/accelerated_table/sink/table.rs index db37889833..cd665cc97d 100644 --- a/crates/runtime/src/accelerated_table/sink/table.rs +++ b/crates/runtime/src/accelerated_table/sink/table.rs @@ -129,7 +129,12 @@ impl TableSink { // Uses IF NOT EXISTS semantics: creates index after overwrite (new table), // no-op after append (index already exists). CDC skips this path entirely. if let Some(indexed) = provider.as_any().downcast_ref::() { - for index in indexed.get_all_indexes() { + let indexes = indexed.get_all_indexes(); + tracing::debug!( + index_names = ?indexes.iter().map(|i| i.name()).collect::>(), + "Running on_write_complete for indexes" + ); + for index in indexes { if let Err(e) = index.on_write_complete().await { tracing::warn!( "TableSink: on_write_complete failed for index '{}': {e}. Index may be stale until next refresh.", diff --git a/crates/runtime/src/dataconnector/mod.rs b/crates/runtime/src/dataconnector/mod.rs index 8e21b12438..709b6ef37d 100644 --- a/crates/runtime/src/dataconnector/mod.rs +++ b/crates/runtime/src/dataconnector/mod.rs @@ -14,7 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ -use crate::accelerated_table::AcceleratedTable; +use crate::accelerated_table::{self, AcceleratedTable}; use crate::component::ComponentInitialization; use crate::component::catalog::Catalog; use crate::component::dataset::Dataset; @@ -657,6 +657,23 @@ pub trait DataConnector: Debug + Send + Sync + 'static { Ok(()) } + /// A hook called **before** the accelerated table is built, giving the + /// connector a chance to wrap or replace the accelerator provider on the + /// [`Builder`](crate::accelerated_table::Builder). + /// + /// Any provider set here will be shared with the [`Refresher`] that is + /// created during [`Builder::build`]. Use this hook instead of + /// [`on_accelerated_table_registration`](Self::on_accelerated_table_registration) + /// when the wrapped provider must be visible to the refresh pipeline + /// (e.g. to recreate indexes after a data refresh). + async fn on_accelerator_setup( + &self, + _dataset: &Dataset, + _builder: &mut accelerated_table::Builder, + ) -> Result<(), Box> { + Ok(()) + } + /// A hook that is called when an accelerated table is registered to the /// `DataFusion` context for this data connector. 
/// diff --git a/crates/runtime/src/datafusion/mod.rs b/crates/runtime/src/datafusion/mod.rs index a6ec0bc5b5..7fa57d15a5 100644 --- a/crates/runtime/src/datafusion/mod.rs +++ b/crates/runtime/src/datafusion/mod.rs @@ -2287,6 +2287,11 @@ impl DataFusion { let is_s3_express_acceleration = false; accelerated_table_builder.s3_express_acceleration(is_s3_express_acceleration); + source + .on_accelerator_setup(dataset, &mut accelerated_table_builder) + .await + .context(AccelerationRegistrationSnafu)?; + accelerated_table_builder .build() .await diff --git a/crates/runtime/src/embeddings/connector.rs b/crates/runtime/src/embeddings/connector.rs index d7ca603f86..2a3bf91f23 100644 --- a/crates/runtime/src/embeddings/connector.rs +++ b/crates/runtime/src/embeddings/connector.rs @@ -13,7 +13,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -use crate::accelerated_table::AcceleratedTable; +use crate::accelerated_table::{self, AcceleratedTable}; use crate::changes::Indexes; use crate::changes::index_change_envelope; use crate::component::ComponentInitialization; @@ -260,13 +260,13 @@ impl DataConnector for EmbeddingConnector { self.inner_connector.metrics_provider() } - async fn on_accelerated_table_registration( + async fn on_accelerator_setup( &self, dataset: &Dataset, - accelerated_table: &mut AcceleratedTable, + builder: &mut accelerated_table::Builder, ) -> Result<(), Box> { self.inner_connector - .on_accelerated_table_registration(dataset, accelerated_table) + .on_accelerator_setup(dataset, builder) .await?; #[cfg(feature = "duckdb")] @@ -276,7 +276,13 @@ impl DataConnector for EmbeddingConnector { return Ok(()); } - let accelerator = accelerated_table.get_accelerator(); + tracing::debug!( + dataset = %dataset.name, + columns = ?embedding_columns.iter().map(|(col, _)| col.as_str()).collect::>(), + "Wrapping accelerator with DuckDB HNSW 
vector indexes" + ); + + let accelerator = builder.get_accelerator(); let indexed_accelerator = crate::embeddings::index::duckdb::wrap_accelerator_with_duckdb_vector_indexes( &dataset.name, @@ -287,12 +293,22 @@ impl DataConnector for EmbeddingConnector { Arc::clone(&self.secrets), ) .await?; - accelerated_table.set_accelerator(indexed_accelerator); + builder.set_accelerator(indexed_accelerator); } Ok(()) } + async fn on_accelerated_table_registration( + &self, + dataset: &Dataset, + accelerated_table: &mut AcceleratedTable, + ) -> Result<(), Box> { + self.inner_connector + .on_accelerated_table_registration(dataset, accelerated_table) + .await + } + fn supports_changes_stream(&self) -> bool { self.inner_connector.supports_changes_stream() } diff --git a/crates/runtime/src/search/full_text/connector.rs b/crates/runtime/src/search/full_text/connector.rs index 80154bc5d2..8117c24b02 100644 --- a/crates/runtime/src/search/full_text/connector.rs +++ b/crates/runtime/src/search/full_text/connector.rs @@ -20,7 +20,7 @@ use runtime_datafusion_index::IndexedTableProvider; use std::any::Any; use std::sync::Arc; -use crate::accelerated_table::AcceleratedTable; +use crate::accelerated_table::{self, AcceleratedTable}; use crate::changes::{Indexes, index_change_envelope}; use crate::component::{ ComponentInitialization, @@ -141,6 +141,16 @@ impl DataConnector for FullTextConnector { self.inner_connector.metrics_provider() } + async fn on_accelerator_setup( + &self, + dataset: &Dataset, + builder: &mut accelerated_table::Builder, + ) -> Result<(), Box> { + self.inner_connector + .on_accelerator_setup(dataset, builder) + .await + } + async fn on_accelerated_table_registration( &self, dataset: &Dataset, diff --git a/crates/runtime/tests/models/hnsw_index.rs b/crates/runtime/tests/models/hnsw_index.rs new file mode 100644 index 0000000000..43860dea33 --- /dev/null +++ b/crates/runtime/tests/models/hnsw_index.rs @@ -0,0 +1,437 @@ +/* +Copyright 2026 The Spice.ai OSS Authors + 
+Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +//! Integration tests verifying `DuckDB` HNSW vector indexes. +//! +//! Uses a native `DuckDbConnectionPool` (post-shutdown) to query `duckdb_indexes()` +//! and confirm that the HNSW index exists on the correct underlying table. + +use std::collections::HashMap; +use std::sync::{Arc, LazyLock}; + +use anyhow::Context as _; +use app::AppBuilder; +use arrow::array::{RecordBatch, StringArray}; +use datafusion::sql::TableReference; +use datafusion_table_providers::sql::db_connection_pool::DbConnectionPool; +use datafusion_table_providers::sql::db_connection_pool::duckdbpool::DuckDbConnectionPool; +use duckdb::AccessMode; +use futures::TryStreamExt; +use runtime::Runtime; +use runtime::auth::EndpointAuth; +use spicepod::acceleration::{Acceleration, Mode, RefreshMode}; +use spicepod::component::dataset::Dataset; +use spicepod::component::embeddings::Embeddings; +use spicepod::param::Params; +use spicepod::semantic::{Column, ColumnLevelEmbeddingConfig}; +use spicepod::vector::VectorStore; +use tokio::sync::Mutex; + +use crate::models::create_api_bindings_config; +use crate::utils::{register_test_connectors, runtime_ready_check, test_request_context}; +use crate::{configure_test_datafusion, init_tracing}; + +/// Serializes HNSW tests because `Runtime::shutdown()` calls `unregister_all()`, +/// which clears the global connector registry and breaks parallel tests. 
+static HNSW_TEST_MUTEX: LazyLock> = LazyLock::new(|| Mutex::new(())); + +fn cleanup_db_path(db_path: &str) { + for suffix in ["", ".wal"] { + let path = format!("{db_path}{suffix}"); + if std::path::Path::new(&path).exists() { + let _ = std::fs::remove_file(&path); + } + } +} + +fn model2vec_embedding() -> Embeddings { + Embeddings::new("model2vec:minishlab/potion-base-2M", "test_embed") +} + +fn hnsw_dataset(name: &str, db_path: &str, refresh_mode: RefreshMode) -> Dataset { + let mut dataset = Dataset::new( + "s3://spiceai-public-datasets/MegaScience/mega-science-small.jsonl", + name, + ); + dataset.params = Some(Params::from_string_map( + vec![("client_timeout".to_string(), "120s".to_string())] + .into_iter() + .collect(), + )); + + let accel_params: HashMap = + HashMap::from([("duckdb_file".to_string(), db_path.to_string())]); + // Don't set HNSW params on acceleration — they go in vectors.params + dataset.acceleration = Some(Acceleration { + enabled: true, + engine: Some("duckdb".to_string()), + mode: Mode::File, + refresh_mode: Some(refresh_mode), + refresh_sql: Some(format!("SELECT * FROM {name} LIMIT 64")), + params: Some(Params::from_string_map(accel_params)), + ..Acceleration::default() + }); + + dataset.vectors = Some(VectorStore { + enabled: true, + engine: Some("duckdb".to_string()), + partition_by: Vec::new(), + params: Some(Params::from_string_map(HashMap::from([ + ("duckdb_distance_metric".to_string(), "cosine".to_string()), + ("duckdb_hnsw_m".to_string(), "8".to_string()), + ("duckdb_hnsw_ef_construction".to_string(), "24".to_string()), + ("duckdb_hnsw_ef_search".to_string(), "12".to_string()), + ]))), + }); + + dataset.columns = vec![ + Column::new("question") + .with_embedding(ColumnLevelEmbeddingConfig::model("test_embed").with_row_id("id")), + ]; + + dataset +} + +/// Start a runtime with the given app, wait for components to load, and return the runtime. 
+async fn start_runtime(app: app::App) -> Arc { + register_test_connectors().await; + configure_test_datafusion(); + + let rt = Arc::new(Runtime::builder().with_app(app).build().await); + + let api_config = create_api_bindings_config(); + let rt_ref = Arc::clone(&rt); + tokio::spawn(async move { + Box::pin(rt_ref.start_servers(api_config, None, EndpointAuth::no_auth())).await + }); + + tokio::select! { + () = tokio::time::sleep(std::time::Duration::from_secs(120)) => { + panic!("Timed out waiting for components to load"); + } + () = Arc::clone(&rt).load_components() => {} + } + + runtime_ready_check(&rt).await; + rt +} + +/// Trigger a manual refresh and wait for it to complete. +async fn refresh_table(rt: &Arc, table_name: &str) -> Result<(), anyhow::Error> { + let notifier = rt + .datafusion() + .refresh_table(&TableReference::from(table_name), None) + .await?; + notifier + .ok_or_else(|| anyhow::anyhow!("No refresh notifier returned for {table_name}"))? + .notified() + .await; + Ok(()) +} + +/// Run a SQL query through the runtime's `DataFusion` context. +async fn execute_sql(rt: &Arc, sql: &str) -> Result, anyhow::Error> { + let mut result = rt.datafusion().query_builder(sql).build().run().await?; + let mut batches = Vec::new(); + while let Some(batch) = futures::StreamExt::next(&mut result.data).await { + batches.push(batch?); + } + Ok(batches) +} + +/// Open a native `DuckDB` connection and return HNSW index info. +/// Must be called after the runtime is fully shut down and dropped. 
+async fn query_native_duckdb_indexes( + db_path: &str, +) -> Result, anyhow::Error> { + let pool = + DuckDbConnectionPool::new_file(db_path, &AccessMode::ReadWrite).expect("valid DuckDB path"); + let conn_dyn = pool.connect().await.expect("valid connection"); + let conn = conn_dyn.as_sync().expect("sync connection"); + + let batches: Vec = conn + .query_arrow( + "SELECT index_name, table_name FROM duckdb_indexes() WHERE index_name LIKE '__spice_vss_%'", + &[], + None, + ) + .expect("index query executes") + .try_collect::>() + .await + .expect("collects results"); + + let mut results = Vec::new(); + for batch in &batches { + let index_names = batch + .column(0) + .as_any() + .downcast_ref::() + .context("index_name column")?; + let table_names = batch + .column(1) + .as_any() + .downcast_ref::() + .context("table_name column")?; + for i in 0..batch.num_rows() { + results.push(( + index_names.value(i).to_string(), + table_names.value(i).to_string(), + )); + } + } + Ok(results) +} + +/// Verifies HNSW index exists after initial load and survives a full (overwrite) refresh. +/// After shutdown, queries the `DuckDB` file directly to confirm the index is on the correct +/// internal data table. +#[tokio::test] +async fn test_hnsw_index_created_after_full_refresh() -> Result<(), anyhow::Error> { + let _test_lock = HNSW_TEST_MUTEX.lock().await; + let _tracing = init_tracing(Some( + "integration_models=debug,runtime=debug,search=debug,info", + )); + + let db_path = "./test_hnsw_refresh.db"; + let ds_name = "hnsw_test_ds"; + + cleanup_db_path(db_path); + + test_request_context() + .scope(async { + let app = AppBuilder::new("hnsw_index_refresh_test") + .with_embedding(model2vec_embedding()) + .with_dataset(hnsw_dataset(ds_name, db_path, RefreshMode::Full)) + .build(); + + let rt = start_runtime(app).await; + + // 1. 
Verify vector search works after initial load + let batches = execute_sql( + &rt, + &format!( + "SELECT id, _score FROM vector_search({ds_name}, 'second') ORDER BY _score DESC LIMIT 4" + ), + ) + .await?; + let total_rows: usize = batches.iter().map(RecordBatch::num_rows).sum(); + anyhow::ensure!( + total_rows == 4, + "Expected 4 rows from initial vector_search, got {total_rows}" + ); + tracing::info!("Initial vector search returned {total_rows} rows"); + + // 2. Trigger a manual full refresh (overwrite) — this destroys and recreates the + // underlying DuckDB table. The HNSW index must be recreated afterward. + refresh_table(&rt, ds_name).await?; + + // 3. Verify vector search STILL works after refresh + let batches = execute_sql( + &rt, + &format!( + "SELECT id, _score FROM vector_search({ds_name}, 'second') ORDER BY _score DESC LIMIT 4" + ), + ) + .await?; + let total_rows: usize = batches.iter().map(RecordBatch::num_rows).sum(); + anyhow::ensure!( + total_rows == 4, + "Expected 4 rows from post-refresh vector_search, got {total_rows}" + ); + tracing::info!("Post-refresh vector search returned {total_rows} rows"); + + // 4. 
Shutdown runtime and verify index via native DuckDB connection + rt.shutdown().await; + drop(rt); + tokio::time::sleep(std::time::Duration::from_secs(15)).await; + + let indexes = query_native_duckdb_indexes(db_path).await?; + tracing::info!("Native DuckDB indexes: {indexes:?}"); + + anyhow::ensure!( + !indexes.is_empty(), + "Expected at least one __spice_vss_ HNSW index in DuckDB file after refresh" + ); + + // The index is expected on an internal data table (__data__), not on the view; + // only the index name is asserted below, the table name is just logged. + for (index_name, table_name) in &indexes { + anyhow::ensure!( + index_name.contains("question_embedding"), + "Index name {index_name} should reference question_embedding column" + ); + tracing::info!( + "Verified HNSW index {index_name} on table {table_name}" + ); + } + + cleanup_db_path(db_path); + Ok(()) + }) + .await +} + +/// Verifies HNSW index is created after initial append refresh. +/// Uses a local JSONL file (with a `created_at` timestamp for `time_column`). +#[tokio::test] +async fn test_hnsw_index_created_after_append_refresh() -> Result<(), anyhow::Error> { + let _test_lock = HNSW_TEST_MUTEX.lock().await; + let _tracing = init_tracing(Some( + "integration_models=debug,runtime=debug,search=debug,info", + )); + + let db_path = "./test_hnsw_append_refresh.db"; + let ds_name = "hnsw_append_ds"; + + cleanup_db_path(db_path); + + let test_data = std::path::Path::new(env!("CARGO_MANIFEST_DIR")) + .join("tests/models/test_data/mega-science-sample.jsonl"); + let source = format!("file://{}", test_data.display()); + + test_request_context() + .scope(async { + let mut dataset = hnsw_dataset(ds_name, db_path, RefreshMode::Append); + dataset.from = source; + dataset.time_column = Some("created_at".to_string()); + dataset.time_format = Some(spicepod::component::dataset::TimeFormat::ISO8601); + dataset.params = None; // Remove client_timeout, not supported for file connector + + let app = AppBuilder::new("hnsw_append_refresh_test") + 
.with_embedding(model2vec_embedding()) + .with_dataset(dataset) + .build(); + + let rt = start_runtime(app).await; + + // Verify vector search works after initial append load + let batches = execute_sql( + &rt, + &format!( + "SELECT id, _score FROM vector_search({ds_name}, 'second') ORDER BY _score DESC LIMIT 4" + ), + ) + .await?; + let total_rows: usize = batches.iter().map(RecordBatch::num_rows).sum(); + anyhow::ensure!( + total_rows == 4, + "Expected 4 rows from vector_search after append refresh, got {total_rows}" + ); + tracing::info!("Append refresh vector search returned {total_rows} rows"); + + // Shutdown runtime and verify index via native DuckDB connection + rt.shutdown().await; + drop(rt); + tokio::time::sleep(std::time::Duration::from_secs(15)).await; + + let indexes = query_native_duckdb_indexes(db_path).await?; + tracing::info!("Native DuckDB indexes after append refresh: {indexes:?}"); + + anyhow::ensure!( + !indexes.is_empty(), + "Expected at least one __spice_vss_ HNSW index after append refresh" + ); + + for (index_name, table_name) in &indexes { + anyhow::ensure!( + index_name.contains("question_embedding"), + "Index name {index_name} should reference question_embedding column" + ); + tracing::info!( + "Verified HNSW index {index_name} on table {table_name}" + ); + } + + cleanup_db_path(db_path); + Ok(()) + }) + .await +} + +/// Verifies HNSW index survives multiple consecutive full refreshes. 
+#[tokio::test] +async fn test_hnsw_index_survives_multiple_refreshes() -> Result<(), anyhow::Error> { + let _test_lock = HNSW_TEST_MUTEX.lock().await; + let _tracing = init_tracing(Some( + "integration_models=debug,runtime=debug,search=debug,info", + )); + + let db_path = "./test_hnsw_multi_refresh.db"; + let ds_name = "hnsw_multi_ds"; + + cleanup_db_path(db_path); + + test_request_context() + .scope(async { + let app = AppBuilder::new("hnsw_multi_refresh_test") + .with_embedding(model2vec_embedding()) + .with_dataset(hnsw_dataset(ds_name, db_path, RefreshMode::Full)) + .build(); + + let rt = start_runtime(app).await; + + // Initial vector search + let batches = execute_sql( + &rt, + &format!("SELECT id FROM vector_search({ds_name}, 'second') LIMIT 4"), + ) + .await?; + let total_rows: usize = batches.iter().map(RecordBatch::num_rows).sum(); + anyhow::ensure!(total_rows == 4, "Initial search failed: {total_rows} rows"); + + // Refresh 3 times + for i in 1..=3 { + refresh_table(&rt, ds_name).await?; + let batches = execute_sql( + &rt, + &format!("SELECT id FROM vector_search({ds_name}, 'second') LIMIT 4"), + ) + .await?; + let total_rows: usize = batches.iter().map(RecordBatch::num_rows).sum(); + anyhow::ensure!( + total_rows == 4, + "Vector search failed after refresh #{i}: {total_rows} rows" + ); + tracing::info!("Refresh #{i}: vector search OK ({total_rows} rows)"); + } + + // Shutdown runtime and verify native DuckDB indexes + rt.shutdown().await; + drop(rt); + tokio::time::sleep(std::time::Duration::from_secs(15)).await; + + let indexes = query_native_duckdb_indexes(db_path).await?; + anyhow::ensure!(!indexes.is_empty(), "Expected HNSW index after 3 refreshes"); + + // There should be exactly one HNSW index (on the latest internal table). + // Old internal tables are dropped when the view is swapped. 
+ anyhow::ensure!( + indexes.len() == 1, + "Expected exactly 1 HNSW index after multiple refreshes, found {}", + indexes.len() + ); + + tracing::info!( + "Verified single HNSW index after 3 refreshes: {:?}", + indexes[0] + ); + + cleanup_db_path(db_path); + + Ok(()) + }) + .await +} diff --git a/crates/runtime/tests/models/mod.rs b/crates/runtime/tests/models/mod.rs index cec2f60779..e3201211e8 100644 --- a/crates/runtime/tests/models/mod.rs +++ b/crates/runtime/tests/models/mod.rs @@ -39,6 +39,8 @@ mod ai_udf; mod bedrock; mod embedding; pub(crate) mod hf; +#[cfg(feature = "duckdb")] +mod hnsw_index; mod local; mod models_http_endpoint; pub(crate) mod openai; diff --git a/crates/runtime/tests/models/test_data/mega-science-sample.jsonl b/crates/runtime/tests/models/test_data/mega-science-sample.jsonl new file mode 100644 index 0000000000..82555c6a69 --- /dev/null +++ b/crates/runtime/tests/models/test_data/mega-science-sample.jsonl @@ -0,0 +1,5 @@ +{"id": 1, "question": "$7-5=$ since $+5=7$.", "answer": "$7-5=2$ since $2+5=7$.", "subject": "math", "reference_answer": "2", "source": "textbook_reasoning", "created_at": "2025-01-01T00:00:00+00:00"} +{"id": 2, "question": "(3):** \nProve that if \\( H \\) is an \\( n \\times n \\) Hadamard matrix, then \\( \\det(H) = \\pm n^{n/2} \\).", "answer": "(3):** \nFrom the previous result, \\( \\det(H)^2 = n^n \\), so:\n\\[\n\\det(H) = \\pm n^{n/2}.\n\\]", "subject": "math", "reference_answer": "$\\det(H) = \\pm n^{n/2}$", "source": "textbook_reasoning", "created_at": "2025-01-01T01:00:00+00:00"} +{"id": 3, "question": "(Problem 12):** \nIf \\(p = \\sqrt{d}\\), then express \\(d\\) in terms of \\(p\\).", "answer": "(Problem 12):** \nIf \\(p = \\sqrt{d}\\), then squaring both sides gives \\(d = p^2\\). \nThus, \\(d = \\boxed{p^2}\\).", "subject": "math", "reference_answer": "p^2", "source": "textbook_reasoning", "created_at": "2025-01-01T02:00:00+00:00"} +{"id": 4, "question": "17. 
Evaluate the function \\( F(x) = 4x \\) at \\( x = 3 \\), \\( x = -4 \\), and \\( x = 0 \\).", "answer": "17. \n- \\( F(3) = 4(3) = \\boxed{12} \\)\n- \\( F(-4) = 4(-4) = \\boxed{-16} \\)\n- \\( F(0) = 4(0) = \\boxed{0} \\)", "subject": "math", "reference_answer": "0", "source": "textbook_reasoning", "created_at": "2025-01-01T03:00:00+00:00"} +{"id": 5, "question": "46.** $63 - 1.03$", "answer": "To compute $63 - 1.03$, subtract the two numbers directly:\n\\[\n63.00 - 1.03 = 61.97\n\\]\nThus, the result is $\\boxed{61.97}$.", "subject": "math", "reference_answer": "61.97", "source": "textbook_reasoning", "created_at": "2025-01-01T04:00:00+00:00"} diff --git a/crates/search/src/index/duckdb.rs b/crates/search/src/index/duckdb.rs index 6d60a92228..2b962e7127 100644 --- a/crates/search/src/index/duckdb.rs +++ b/crates/search/src/index/duckdb.rs @@ -242,13 +242,19 @@ impl DuckDBVectorIndex { conn.execute("SET hnsw_enable_experimental_persistence = true", []) .map_err(to_execution_error)?; let index_name = DuckDBHnswOptions::index_name_for(table_name, &embedding_column); - conn.execute( - &self - .hnsw - .create_index_sql(table_name, &embedding_column, &index_name), - [], - ) - .map_err(to_execution_error)?; + let create_sql = self + .hnsw + .create_index_sql(table_name, &embedding_column, &index_name); + conn.execute(&create_sql, []).map_err(to_execution_error)?; + + tracing::debug!( + table = %table_name, + index = %index_name, + column = %embedding_column, + sql = %create_sql, + "HNSW index created successfully" + ); + Ok(()) } @@ -371,6 +377,10 @@ impl Index for DuckDBVectorIndex { /// init time; DuckDB VSS maintains it automatically on subsequent inserts. async fn on_write_complete(&self) -> DataFusionResult<()> { let Some(ctx) = &self.query_context else { + tracing::debug!( + column = %self.embedded_column, + "on_write_complete skipped: no query context for HNSW index" + ); return Ok(()); }; let index = self.clone();