From 4e679c0bbe9b93ff540c0de5e3d2f74d3802b9de Mon Sep 17 00:00:00 2001 From: Luke Kim <80174+lukekim@users.noreply.github.com> Date: Mon, 20 Apr 2026 21:10:13 -0700 Subject: [PATCH 1/4] feat: Add Azure Cosmos DB (NoSQL) data connector (RC) (#10392) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * feat: Add Azure Cosmos DB (NoSQL) data connector (alpha) Adds a first-pass, read-only Azure Cosmos DB NoSQL / Core SQL API data connector built on the azure_data_cosmos 0.30 SDK. - New data_components::cosmosdb module with CosmosDBClient, CosmosDBTableProvider, and CosmosDBExec. - Cross-partition scan via Cosmos SQL (defaults to SELECT * FROM c). - Arrow schema inferred from a sample of documents; system fields (_rid/_self/_etag/_attachments/_ts) are stripped. - Key-based auth (connection string or account endpoint + key). - New 'cosmosdb' Cargo feature wired through data_components, runtime, and spiced (included in the default spiced distribution). - Makefile lint target + SPICED_DATA_FEATURES entry. - Docs stub and initial RC criteria table row. Alpha quality: no filter/projection push-down, no writes, no change feed, no Microsoft Entra ID auth yet. * refactor: Clean up code formatting and improve readability in Cosmos DB modules * fix: Simplify endpoint assignment in CosmosDBClient implementation * fix(cosmosdb): address review comments for Cosmos DB connector - Drop `.required()` from `database` param; falls back to path segment. - Validate `schema_infer_max_records`; warn-and-default on 0/non-integer. - Use prefixed param names (`cosmosdb_*`) in the auth error message. - Split path parsing into `parse_database_and_container` and reject empty db/container segments with a descriptive message. - Fix `infer_schema` doc comment to match actual empty-sample behavior. - Add unit tests for `parse_database_and_container`, `strip_system_fields`, and `infer_schema`. - Docs: use unprefixed `query` / `schema_infer_max_records` (runtime params). * fix(cosmosdb): address follow-up review comments - parse_database_and_container error now references the user-facing 'cosmosdb_database' parameter name. - execute() wraps Cosmos query errors with Error::QueryFailed so failures include the database/container context (matches fetch_samples). - Add unit tests for decode_batch covering multi-doc decoding, projection, empty input, and missing-field-fills-null behavior. * feat(cosmosdb): bring connector to RC quality Closes the remaining gaps against docs/criteria/connectors/rc.md for the Azure Cosmos DB (NoSQL) connector: - Connection resilience: per-account concurrency semaphore, exponential/fibonacci backoff, Retry-After / x-ms-retry-after-ms handling, permanent-error (401/403/404) detection that latches the connector disabled. New CosmosResilienceConfig in crates/data_components/src/cosmosdb/resilience.rs with 16 unit tests. - inflight_operations metric gauge via MetricsProvider, keyed off the account endpoint so datasets sharing an account aggregate cleanly. - New runtime parameters: max_concurrent_requests, http_max_retries, backoff_method, disable_on_permanent_error. - unsupported_type_action plumbing — all-null inferred columns (DataType::Null) are warn-dropped by default, with user-configurable Error/Ignore/String alternatives. - Integration-test scaffold at crates/runtime/tests/cosmosdb/ with an offline smoke test plus two #[ignore]'d live tests that read COSMOSDB_CONNECTION_STRING (or ENDPOINT+KEY) from the environment. - Cookbook recipe at examples/cosmosdb-connector/ (README + spicepod.yaml + queries.sql). - Docs: status flip to RC, JSON -> Arrow type mapping table, RC exceptions call-out, resilience parameter reference, integration-test run instructions. - Criteria tables: add Cosmos DB rows to alpha.md / beta.md (quality + feature matrix); flip rc.md row to DRI @lukekim. Test coverage: - 32 unit tests in data_components::cosmosdb (resilience, schema, provider incl. 1024-column max-width test, unsupported-type actions). - 12 unit tests in runtime::dataconnector::cosmosdb (shared_semaphore, shared_disabled_flag, path parsing). - 1 offline smoke + 2 ignored live tests in crates/runtime/tests/cosmosdb. * refactor(cosmosdb): simplify post-review nits - Redact CosmosDBCredential Debug impl (was deriving Debug on an enum holding raw account keys / connection strings; would surface secrets in any tracing::debug! or panic). - Drop CosmosDBMetrics wrapper struct; store the Arc inflight counter directly on CosmosDB and on CosmosDBMetricsProvider. Removes one level of Arc indirection. - Use azure_core::http::headers::X_MS_RETRY_AFTER_MS constant instead of hand-rolling HeaderName::from_static for the Cosmos-specific retry-ms header. - Swap Ordering::SeqCst -> Acquire/Release on the disabled latch; a single AtomicBool only needs acquire/release pairing, and the change is free on x86 but saves a fence on ARM. - Avoid allocating String endpoint on the streaming scan hot path; pass client.endpoint() (&str) directly to handle_stream_error. - Add eviction-note docstring to COSMOS_CONCURRENCY_LIMITS and COSMOS_DISABLED_FLAGS, matching the git-connector statics. - Trim narrating / RC-pointing doc comments from PARAMETERS, CosmosDBTableProviderConfig, and apply_unsupported_type_action. * perf(cosmosdb): decode JSON docs directly via arrow-json decoder Replace decode_batch's NDJSON round-trip (serialize each Value to bytes, push newline, re-parse through ReaderBuilder::build) with the serde-aware path: ReaderBuilder::build_decoder() + decoder.serialize(&docs) + decoder.flush(). Avoids the per-batch serialize -> parse overhead and matches the pattern used in crates/data_components/src/s3_vectors/ (list_provider.rs, query_provider.rs). Addresses Copilot review feedback on PR #10392. * docs(cosmosdb): clarify retry scope (schema vs streaming) http_max_retries / backoff_method apply to the schema-inference pass at dataset registration; mid-stream pager errors propagate directly because a FeedPager cannot be rewound once rows have been emitted. The permanent-error latch still fires on both paths. Addresses Copilot review feedback on PR #10392. * style(cosmosdb): apply cargo fmt * docs(cosmosdb): align doc comments with RC quality - Update crate-level mod.rs doc to state RC quality and reference docs - Update provider.rs doc comment to RC quality with retry-scope limitation - Clarify inflight_operations metric is per-dataset (shared concurrency budget is enforced via endpoint-keyed COSMOS_CONCURRENCY_LIMITS map) - Update http_max_retries and backoff_method descriptions to state they apply to the schema-inference sampling pass only * docs(cosmosdb): clarify inflight_operations gauge is per-dataset scope * refactor(cosmosdb): replace CosmosDBClient wrapper with build_container_client Each CosmosDBTableProvider is pinned to a single (database, container) pair, so building a ContainerClient once at connector setup and reusing it — rather than holding a CosmosDBClient wrapper and re-deriving the ContainerClient on every scan — simplifies the API. - Delete the CosmosDBClient struct and its re-export - Add build_container_client(credential, database, container) free function that returns (ContainerClient, endpoint) - Store ContainerClient + endpoint directly on CosmosDBTableProvider and CosmosDBExec - Update fetch_samples and execute to use the pre-built ContainerClient (no per-call .container_client() indirection) * docs(cosmosdb): clarify example comment for max_concurrent_requests=8 The comment said 'tighten' the per-account concurrency budget, but the value (8) raises it above the default (4). Update the comment to say 'raise' and explicitly note the default, so the direction matches the demonstrated value. * Lint * Lint * Lint * fix(cosmosdb): normalize endpoint + fix Fibonacci comment indexing - Normalize the account endpoint (trim trailing '/', lowercase) before returning it as the resilience key, so the per-account concurrency budget is shared across datasets configured with benign URL formatting differences (trailing slash / casing). - Fix the Fibonacci backoff test comment: factors follow F(attempt+2) with F(1)=F(2)=1, not F(attempt+1). The asserted sequence 1, 2, 3, 5, 8 already matched the implementation; only the comment was off by one. * Lint * Lint * Lint * fix(cosmosdb): remove unused schema_override; fix schema-pinning docs Spicepod's dataset.columns is semantic-only and has no type field, so the docs/example claims about pinning Cosmos schema via 'columns:' were misleading — schema pinning is not actually supported by this connector today. This commit aligns the code and docs with reality: - Remove the unused schema_override field + with_schema_override from CosmosDBTableProviderConfig; simplify try_new to always infer schema - Rewrite the 'pinned schema' example to a resilience-tuning example (widens schema_infer_max_records to stabilize inference instead) - Update docs/dev/cosmosdb.md type map and 'What's supported' to remove columns:-based pinning claims and point to schema_infer_max_records - Clarify inflight_operations metric description: it counts operations holding a concurrency permit (incremented once per operation, held across retry backoff sleeps) rather than strictly in-flight HTTP requests * Lint * Lint --------- Co-authored-by: Viktor Yershov --- .gitignore | 1 + Cargo.lock | 27 +- Makefile | 3 +- bin/spiced/Cargo.toml | 2 + crates/data_components/Cargo.toml | 6 + crates/data_components/src/cosmosdb/client.rs | 134 ++++ crates/data_components/src/cosmosdb/mod.rs | 117 +++ .../data_components/src/cosmosdb/provider.rs | 722 ++++++++++++++++++ .../src/cosmosdb/resilience.rs | 620 +++++++++++++++ crates/data_components/src/cosmosdb/schema.rs | 185 +++++ crates/data_components/src/lib.rs | 2 + crates/runtime/Cargo.toml | 1 + crates/runtime/src/dataconnector/cosmosdb.rs | 591 ++++++++++++++ crates/runtime/src/dataconnector/mod.rs | 2 + crates/runtime/tests/cosmosdb/mod.rs | 226 ++++++ crates/runtime/tests/integration.rs | 2 + docs/criteria/connectors/alpha.md | 1 + docs/criteria/connectors/beta.md | 2 + docs/criteria/connectors/rc.md | 2 + docs/dev/cosmosdb.md | 162 ++++ examples/cosmosdb-connector/README.md | 142 ++++ examples/cosmosdb-connector/queries.sql | 35 + examples/cosmosdb-connector/spicepod.yaml | 38 + 23 files changed, 3018 insertions(+), 5 deletions(-) create mode 100644 crates/data_components/src/cosmosdb/client.rs create mode 100644 crates/data_components/src/cosmosdb/mod.rs create mode 100644 crates/data_components/src/cosmosdb/provider.rs create mode 100644 crates/data_components/src/cosmosdb/resilience.rs create mode 100644 crates/data_components/src/cosmosdb/schema.rs create mode 100644 crates/runtime/src/dataconnector/cosmosdb.rs create mode 100644 crates/runtime/tests/cosmosdb/mod.rs create mode 100644 docs/dev/cosmosdb.md create mode 100644 examples/cosmosdb-connector/README.md create mode 100644 examples/cosmosdb-connector/queries.sql create mode 100644 examples/cosmosdb-connector/spicepod.yaml diff --git a/.gitignore b/.gitignore index fe066a3f87..95e5e19496 100644 --- a/.gitignore +++ b/.gitignore @@ -16,6 +16,7 @@ Makefile.local *.parquet spicepod.yaml !test/tpc-bench/tpch-spicepod/spicepod.yaml +!examples/cosmosdb-connector/spicepod.yaml target/ diff --git a/Cargo.lock b/Cargo.lock index 883c22fb6a..f0f115af48 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2195,10 +2195,12 @@ dependencies = [ "azure_core_macros", "bytes", "futures", + "hmac 0.12.1", "pin-project", "rustc_version", "serde", "serde_json", + "sha2 0.10.9", "tracing", "typespec", "typespec_client_core", @@ -2216,6 +2218,22 @@ dependencies = [ "tracing", ] +[[package]] +name = "azure_data_cosmos" +version = "0.30.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "196ab882f9a566826713e52ab6bd3a744c4942699f20970565194a81a268ce93" +dependencies = [ + "async-lock", + "async-trait", + "azure_core 0.31.0", + "futures", + "serde", + "serde_json", + "tracing", + "url", +] + [[package]] name = "azure_storage" version = "0.21.0" @@ -5255,6 +5273,8 @@ dependencies = [ "aws-sdk-dynamodbstreams", "aws-smithy-async", "aws-smithy-types", + "azure_core 0.31.0", + "azure_data_cosmos", "base64 0.22.1", "bb8", "bb8-oracle", @@ -7508,7 +7528,7 @@ checksum = "0ce92ff622d6dadf7349484f42c93271a0d49b7cc4d466a936405bacbe10aa78" dependencies = [ "cfg-if", "rustix 1.1.4", - "windows-sys 0.52.0", + "windows-sys 0.59.0", ] [[package]] @@ -8239,7 +8259,7 @@ dependencies = [ "libc", "log", "rustversion", - "windows-link 0.1.3", + "windows-link 0.2.1", "windows-result 0.4.1", ] @@ -16510,7 +16530,7 @@ dependencies = [ "errno", "libc", "linux-raw-sys 0.4.15", - "windows-sys 0.52.0", + "windows-sys 0.59.0", ] [[package]] @@ -18517,7 +18537,6 @@ dependencies = [ "cfg-if", "libc", "psm", - "windows-sys 0.52.0", "windows-sys 0.59.0", ] diff --git a/Makefile b/Makefile index 628872c132..2c1c893eca 100644 --- a/Makefile +++ b/Makefile @@ -207,6 +207,7 @@ check-rust-features: cargo check $(CARGO_PROFILE) --no-default-features --features delta_lake cargo check $(CARGO_PROFILE) --no-default-features --features dremio cargo check $(CARGO_PROFILE) --no-default-features --features clickhouse + cargo check $(CARGO_PROFILE) --no-default-features --features cosmosdb cargo check $(CARGO_PROFILE) --no-default-features --features debezium cargo check $(CARGO_PROFILE) --no-default-features --features runtime/openapi cargo check $(CARGO_PROFILE) --no-default-features --features dynamodb @@ -253,7 +254,7 @@ display-deps: # Default install includes models. Use -data suffix variants to build without models. # Data-only features (default features minus models) # Note: postgres-accel enables the PostgreSQL data accelerator (separate from postgres connector) -SPICED_DATA_FEATURES := duckdb,postgres,postgres-accel,sqlite,mysql,flightsql,delta_lake,databricks,dremio,clickhouse,sharepoint,snapshots,snowflake,spark,ftp,sftp,debezium,kafka,anonymous_telemetry,mssql,dynamodb,imap,alloc-snmalloc,oracle,runtime/s3_vectors,mongodb,iceberg-write,turso,smb,pingora,scylladb +SPICED_DATA_FEATURES := duckdb,postgres,postgres-accel,sqlite,mysql,flightsql,delta_lake,databricks,dremio,clickhouse,cosmosdb,sharepoint,snapshots,snowflake,spark,ftp,sftp,debezium,kafka,anonymous_telemetry,mssql,dynamodb,imap,alloc-snmalloc,oracle,runtime/s3_vectors,mongodb,iceberg-write,turso,smb,pingora,scylladb .PHONY: install install: build diff --git a/bin/spiced/Cargo.toml b/bin/spiced/Cargo.toml index b0cb038072..bf73171469 100644 --- a/bin/spiced/Cargo.toml +++ b/bin/spiced/Cargo.toml @@ -74,6 +74,7 @@ alloc-system = [] anonymous_telemetry = ["telemetry/anonymous_telemetry"] aws-secrets-manager = ["runtime/aws-secrets-manager"] clickhouse = ["connector-clickhouse", "runtime/clickhouse"] +cosmosdb = ["runtime/cosmosdb"] cuda = ["runtime/cuda"] databricks = ["connector-databricks", "runtime/databricks"] debezium = ["runtime/debezium"] @@ -88,6 +89,7 @@ default = [ "databricks", "dremio", "clickhouse", + "cosmosdb", "sharepoint", "snapshots", "snowflake", diff --git a/crates/data_components/Cargo.toml b/crates/data_components/Cargo.toml index d20021ef61..d7ba0cc89e 100644 --- a/crates/data_components/Cargo.toml +++ b/crates/data_components/Cargo.toml @@ -25,6 +25,11 @@ aws-sdk-credential-bridge = { path = "../aws-sdk-credential-bridge" } aws-sdk-dynamodb = { workspace = true, optional = true } aws-sdk-dynamodbstreams = { workspace = true, optional = true } aws-smithy-async = { workspace = true, optional = true } +azure_core = { version = "0.31.0", optional = true } +azure_data_cosmos = { version = "0.30.0", default-features = false, features = [ + "hmac_rust", + "key_auth", +], optional = true } base64.workspace = true bb8 = { workspace = true, optional = true } bb8-oracle = { version = "0.3", features = ["chrono"], optional = true } @@ -121,6 +126,7 @@ rdkafka = { workspace = true, features = ["cmake-build"], optional = true } bench = [] # Feature for benchmarking that exposes internal functions adbc = ["datafusion-table-providers/adbc-federation"] clickhouse = ["dep:clickhouse-rs", "datafusion-table-providers/federation"] +cosmosdb = ["dep:azure_core", "dep:azure_data_cosmos"] databricks = [ "delta_lake", "spark_connect", diff --git a/crates/data_components/src/cosmosdb/client.rs b/crates/data_components/src/cosmosdb/client.rs new file mode 100644 index 0000000000..7709a7cd3a --- /dev/null +++ b/crates/data_components/src/cosmosdb/client.rs @@ -0,0 +1,134 @@ +/* +Copyright 2024-2026 The Spice.ai OSS Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +//! Build a [`ContainerClient`] for a specific `(database, container)` from a +//! user-supplied credential. Each [`CosmosDBTableProvider`] is pinned to one +//! container, so we construct the `ContainerClient` once at connector setup +//! and reuse it for schema inference and every subsequent scan. +//! +//! [`CosmosDBTableProvider`]: super::provider::CosmosDBTableProvider + +use std::sync::Arc; + +use azure_core::credentials::Secret; +use azure_data_cosmos::{ConnectionString, CosmosClient, clients::ContainerClient}; +use snafu::ResultExt; + +use super::{BuildClientSnafu, Error, InvalidConnectionStringSnafu}; + +/// Credential used to build a Cosmos client. +/// +/// Carries account keys / full connection strings; the manual `Debug` below +/// redacts both so tracing / panic dumps never surface them. +#[derive(Clone)] +pub enum CosmosDBCredential { + /// An `AccountEndpoint=https://...;AccountKey=...;` connection string. + ConnectionString(String), + /// Explicit account endpoint URL plus primary/secondary key. + Key { endpoint: String, key: String }, +} + +impl std::fmt::Debug for CosmosDBCredential { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::ConnectionString(_) => f + .debug_tuple("ConnectionString") + .field(&"") + .finish(), + Self::Key { endpoint, .. } => f + .debug_struct("Key") + .field("endpoint", endpoint) + .field("key", &"") + .finish(), + } + } +} + +/// Build a [`ContainerClient`] for the given `(database, container)` pair, +/// returning the account endpoint alongside it (needed for resilience keying +/// and error messages). +/// +/// # Errors +/// Returns an error if the credential is malformed or the underlying Azure +/// SDK client cannot be constructed. +pub fn build_container_client( + credential: CosmosDBCredential, + database: &str, + container: &str, +) -> Result<(ContainerClient, Arc), Error> { + let (client, endpoint) = match credential { + CosmosDBCredential::ConnectionString(conn_str) => { + let parsed: ConnectionString = conn_str + .parse() + .map_err(boxed_err) + .context(InvalidConnectionStringSnafu)?; + let endpoint = parsed.account_endpoint; + + let client = CosmosClient::with_connection_string(Secret::from(conn_str), None) + .map_err(boxed_err) + .context(BuildClientSnafu { + endpoint: endpoint.clone(), + })?; + + (client, endpoint) + } + CosmosDBCredential::Key { endpoint, key } => { + let client = CosmosClient::with_key(&endpoint, Secret::from(key), None) + .map_err(boxed_err) + .context(BuildClientSnafu { + endpoint: endpoint.clone(), + })?; + + (client, endpoint) + } + }; + + let container_client = client.database_client(database).container_client(container); + + Ok((container_client, Arc::from(normalize_endpoint(&endpoint)))) +} + +/// Normalize a Cosmos DB account endpoint so benign URL-formatting differences +/// (trailing slash, casing) don't split the shared per-account concurrency +/// budget across datasets that target the same account. +fn normalize_endpoint(endpoint: &str) -> String { + endpoint.trim().trim_end_matches('/').to_ascii_lowercase() +} + +fn boxed_err(e: E) -> Box +where + E: std::error::Error + Send + Sync + 'static, +{ + Box::new(e) +} + +#[cfg(test)] +mod tests { + use super::normalize_endpoint; + + #[test] + fn normalize_endpoint_collapses_benign_variants() { + let canonical = "https://myaccount.documents.azure.com:443"; + for variant in [ + "https://myaccount.documents.azure.com:443", + "https://myaccount.documents.azure.com:443/", + "https://MYACCOUNT.documents.azure.com:443/", + " https://myaccount.documents.azure.com:443/ ", + ] { + assert_eq!(normalize_endpoint(variant), canonical, "input: {variant:?}"); + } + } +} diff --git a/crates/data_components/src/cosmosdb/mod.rs b/crates/data_components/src/cosmosdb/mod.rs new file mode 100644 index 0000000000..1ebb684db6 --- /dev/null +++ b/crates/data_components/src/cosmosdb/mod.rs @@ -0,0 +1,117 @@ +/* +Copyright 2024-2026 The Spice.ai OSS Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +//! Azure Cosmos DB (`NoSQL` / Core SQL API) data connector components. +//! +//! Exposes a [`provider::CosmosDBTableProvider`] built on top of the +//! `azure_data_cosmos` crate. Documents are fetched via a Cosmos SQL query +//! (`SELECT * FROM c` by default) and projected into Arrow `RecordBatch`es. +//! +//! The current connector release targets *RC* quality: read-only, +//! cross-partition scan, schema inference from a sample of documents, and no +//! filter push-down yet. See `docs/criteria/connectors/rc.md` for the full +//! Cosmos DB row and `docs/dev/cosmosdb.md` for the type map and limitations. + +pub mod client; +pub mod provider; +pub mod resilience; +pub mod schema; + +use snafu::Snafu; + +pub use client::{CosmosDBCredential, build_container_client}; +pub use provider::CosmosDBTableProvider; +pub use resilience::{ + BackoffMethod, CosmosResilienceConfig, DEFAULT_MAX_CONCURRENT_REQUESTS, DEFAULT_MAX_RETRIES, + ResilienceError, +}; + +pub type Result = std::result::Result; + +/// Default SQL query used when no custom `query` is provided. Selects every +/// property from the root alias `c`, which is the canonical cross-partition +/// scan in Cosmos DB `NoSQL`. +pub const DEFAULT_QUERY: &str = "SELECT * FROM c"; + +/// Default sample size used for schema inference when no explicit value is +/// provided. Kept intentionally small to minimize Request Unit (RU) usage on +/// initial dataset registration. +pub const DEFAULT_SCHEMA_INFER_MAX_RECORDS: usize = 100; + +#[derive(Debug, Snafu)] +#[snafu(visibility(pub))] +pub enum Error { + #[snafu(display( + "Failed to build the Azure Cosmos DB client for account {endpoint}: {source}" + ))] + BuildClient { + endpoint: String, + source: Box, + }, + + #[snafu(display( + "Invalid Azure Cosmos DB connection string. Ensure the connection string was copied directly from the Azure portal: {source}" + ))] + InvalidConnectionString { + source: Box, + }, + + #[snafu(display( + "Azure Cosmos DB requires either 'connection_string' or both 'account_endpoint' and 'account_key' to be set." + ))] + MissingCredentials, + + #[snafu(display( + "Failed to query Azure Cosmos DB container '{container}' in database '{database}': {source}" + ))] + QueryFailed { + database: String, + container: String, + source: Box, + }, + + #[snafu(display( + "Azure Cosmos DB container '{container}' in database '{database}' returned no documents to infer schema from. \ + Ensure the container is populated, or pin a schema explicitly via the dataset `columns` configuration." + ))] + EmptyContainer { database: String, container: String }, + + #[snafu(display("Failed to infer Arrow schema from Cosmos DB documents: {source}"))] + SchemaInference { source: arrow::error::ArrowError }, + + #[snafu(display("Failed to decode Cosmos DB JSON document into Arrow: {source}"))] + JsonDecode { source: arrow::error::ArrowError }, + + #[snafu(display( + "Invalid dataset path '{path}'. Azure Cosmos DB dataset paths must be of the form 'database.container' or 'database/container'." + ))] + InvalidDatasetPath { path: String }, + + #[snafu(display( + "The Azure Cosmos DB connector at '{endpoint}' is disabled after a permanent error (401/403/404). Fix the credentials or grants, then restart Spice." + ))] + ConnectorDisabled { endpoint: String }, + + #[snafu(display( + "Column '{column}' in Azure Cosmos DB dataset '{database}.{container}' has an unsupported Arrow data type ({data_type}). Set the dataset's `unsupported_type_action` parameter to `warn`, `ignore`, or `string` to proceed." + ))] + UnsupportedColumn { + database: String, + container: String, + column: String, + data_type: String, + }, +} diff --git a/crates/data_components/src/cosmosdb/provider.rs b/crates/data_components/src/cosmosdb/provider.rs new file mode 100644 index 0000000000..a101386d2e --- /dev/null +++ b/crates/data_components/src/cosmosdb/provider.rs @@ -0,0 +1,722 @@ +/* +Copyright 2024-2026 The Spice.ai OSS Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +//! Arrow [`TableProvider`] implementation for an Azure Cosmos DB `NoSQL` +//! container. +//! +//! The provider executes a single Cosmos SQL query (defaulting to +//! `SELECT * FROM c`) against the configured container, infers an Arrow +//! schema from a sample of documents on first access, and streams the full +//! result set into record batches via `arrow::json::ReaderBuilder`. +//! +//! This is an RC-quality implementation with the following current +//! limitations: +//! * Read-only (no INSERT / UPDATE / DELETE). +//! * Cross-partition scan only — no filter or projection push-down. +//! * Schema inferred from a sample; pin the schema via the dataset +//! `columns:` spicepod property when stability is required. +//! * Retries/backoff apply to the schema-inference pass only; mid-stream +//! pager errors during scan execution propagate directly. +//! * Cosmos DB Rust SDK 0.30 has limited cross-partition capabilities; see +//! the module-level documentation. + +use std::any::Any; +use std::sync::Arc; +use std::sync::atomic::{AtomicU64, Ordering}; + +use arrow::array::RecordBatch; +use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; +use arrow::json::ReaderBuilder; +use async_trait::async_trait; +use datafusion::catalog::Session; +use datafusion::common::{Result as DataFusionResult, project_schema}; +use datafusion::datasource::{TableProvider, TableType}; +use datafusion::error::DataFusionError; +use datafusion::execution::{SendableRecordBatchStream, TaskContext}; +use datafusion::physical_expr::EquivalenceProperties; +use datafusion::physical_plan::stream::RecordBatchReceiverStream; +use datafusion::physical_plan::{ + DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, PlanProperties, + execution_plan::{Boundedness, EmissionType}, +}; +use datafusion::prelude::Expr; +use datafusion_table_providers::UnsupportedTypeAction; +use futures::StreamExt; +use serde_json::Value; +use snafu::ResultExt; + +use azure_data_cosmos::clients::ContainerClient; + +use super::resilience::{CosmosResilienceConfig, ResilienceError, run_with_resilience}; +use super::schema::{infer_schema, strip_system_fields}; +use super::{DEFAULT_SCHEMA_INFER_MAX_RECORDS, EmptyContainerSnafu, Error, JsonDecodeSnafu}; + +/// Number of documents emitted per `RecordBatch` when streaming results. +const STREAM_BATCH_SIZE: usize = 1024; + +/// Configuration for a single Cosmos DB dataset. +#[derive(Debug, Clone)] +pub struct CosmosDBTableProviderConfig { + pub database: String, + pub container: String, + /// Cosmos SQL query to execute. Defaults to `SELECT * FROM c`. + pub query: String, + /// Number of documents sampled when inferring the schema. + pub schema_infer_max_records: usize, + /// How to handle columns whose type Cosmos DB cannot represent (e.g. + /// all-null samples that Arrow's JSON inference returns as + /// [`DataType::Null`]). Defaults to [`UnsupportedTypeAction::Warn`]. + pub unsupported_type_action: UnsupportedTypeAction, + pub resilience: CosmosResilienceConfig, +} + +impl CosmosDBTableProviderConfig { + #[must_use] + pub fn new( + database: impl Into, + container: impl Into, + query: impl Into, + ) -> Self { + Self { + database: database.into(), + container: container.into(), + query: query.into(), + schema_infer_max_records: DEFAULT_SCHEMA_INFER_MAX_RECORDS, + unsupported_type_action: UnsupportedTypeAction::Warn, + resilience: CosmosResilienceConfig::default(), + } + } + + #[must_use] + pub fn with_schema_infer_max_records(mut self, n: usize) -> Self { + self.schema_infer_max_records = n; + self + } + + #[must_use] + pub fn with_resilience(mut self, resilience: CosmosResilienceConfig) -> Self { + self.resilience = resilience; + self + } + + #[must_use] + pub fn with_unsupported_type_action(mut self, action: UnsupportedTypeAction) -> Self { + self.unsupported_type_action = action; + self + } +} + +/// Arrow [`TableProvider`] backed by an Azure Cosmos DB container. +#[derive(Clone)] +pub struct CosmosDBTableProvider { + container_client: ContainerClient, + endpoint: Arc, + config: Arc, + schema: SchemaRef, +} + +impl std::fmt::Debug for CosmosDBTableProvider { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("CosmosDBTableProvider") + .field("endpoint", &self.endpoint) + .field("config", &self.config) + .field("schema", &self.schema) + .finish_non_exhaustive() + } +} + +impl CosmosDBTableProvider { + /// Build a new table provider, inferring the schema by sampling a batch + /// of documents from the container. + /// + /// `container_client` is pre-built for the `(database, container)` pair + /// carried on `config`; `endpoint` is the Cosmos account endpoint used for + /// resilience keying and error messages. + /// + /// # Errors + /// Returns an error if the sample query fails or the container is empty. + pub async fn try_new( + container_client: ContainerClient, + endpoint: Arc, + config: CosmosDBTableProviderConfig, + ) -> Result { + let samples = fetch_samples( + &container_client, + &endpoint, + &config.database, + &config.container, + &config.query, + config.schema_infer_max_records, + &config.resilience, + ) + .await?; + + if samples.is_empty() { + return EmptyContainerSnafu { + database: config.database.clone(), + container: config.container.clone(), + } + .fail(); + } + + let inferred = infer_schema(&samples)?; + let schema = apply_unsupported_type_action( + &inferred, + config.unsupported_type_action, + &config.database, + &config.container, + )?; + + Ok(Self { + container_client, + endpoint, + config: Arc::new(config), + schema, + }) + } + + #[must_use] + pub fn schema_ref(&self) -> SchemaRef { + Arc::clone(&self.schema) + } +} + +/// Apply the configured [`UnsupportedTypeAction`] to an inferred schema. +/// +/// For Cosmos DB, the only type Arrow's JSON inference can produce that +/// downstream query engines may refuse is [`DataType::Null`] — it appears when +/// every sampled document has `null` for a field. +fn apply_unsupported_type_action( + inferred: &SchemaRef, + action: UnsupportedTypeAction, + database: &str, + container: &str, +) -> Result { + if !schema_has_unsupported_columns(inferred) { + return Ok(Arc::clone(inferred)); + } + + let mut kept: Vec> = Vec::with_capacity(inferred.fields().len()); + for field in inferred.fields() { + if is_unsupported_cosmos_field(field) { + match action { + UnsupportedTypeAction::Error => { + return Err(Error::UnsupportedColumn { + database: database.to_string(), + container: container.to_string(), + column: field.name().clone(), + data_type: format!("{:?}", field.data_type()), + }); + } + UnsupportedTypeAction::Warn => { + tracing::warn!( + database = %database, + container = %container, + column = %field.name(), + data_type = %format!("{:?}", field.data_type()), + "Dropping column '{}' from Cosmos DB dataset {database}.{container}: Arrow inferred an unsupported data type ({:?}). All sampled documents were null for this field — populate the field or pin a schema via `columns:` to override.", + field.name(), + field.data_type() + ); + } + UnsupportedTypeAction::Ignore => { + // Silently drop the column. + } + UnsupportedTypeAction::String => { + kept.push(Arc::new(Field::new( + field.name(), + DataType::Utf8, + field.is_nullable(), + ))); + } + } + } else { + kept.push(Arc::::clone(field)); + } + } + + Ok(Arc::new(Schema::new(kept))) +} + +fn is_unsupported_cosmos_field(field: &Arc) -> bool { + matches!(field.data_type(), DataType::Null) +} + +fn schema_has_unsupported_columns(schema: &SchemaRef) -> bool { + schema.fields().iter().any(is_unsupported_cosmos_field) +} + +/// Sample up to `limit` documents from the container for schema inference. +/// +/// Wrapped by [`run_with_resilience`] so the whole sampling operation is +/// retried (with fresh pager construction) on transient errors, bounded by +/// the configured retry budget. +async fn fetch_samples( + container_client: &ContainerClient, + endpoint: &str, + database: &str, + container: &str, + query: &str, + limit: usize, + resilience: &CosmosResilienceConfig, +) -> Result, Error> { + run_with_resilience(resilience, endpoint, || async { + let mut pager = container_client.query_items::(query, (), None)?; + let mut samples = Vec::with_capacity(limit.min(1024)); + while samples.len() < limit { + match pager.next().await { + Some(Ok(doc)) => samples.push(strip_system_fields(doc)), + Some(Err(e)) => return Err(e), + None => break, + } + } + Ok(samples) + }) + .await + .map_err(|e| match e { + ResilienceError::Disabled => Error::ConnectorDisabled { + endpoint: endpoint.to_string(), + }, + ResilienceError::Request(source) => Error::QueryFailed { + database: database.to_string(), + container: container.to_string(), + source: Box::new(source), + }, + }) +} + +#[async_trait] +impl TableProvider for CosmosDBTableProvider { + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> SchemaRef { + Arc::clone(&self.schema) + } + + fn table_type(&self) -> TableType { + TableType::Base + } + + async fn scan( + &self, + _state: &dyn Session, + projection: Option<&Vec>, + _filters: &[Expr], + _limit: Option, + ) -> DataFusionResult> { + let projected_schema = project_schema(&self.schema, projection)?; + + Ok(Arc::new(CosmosDBExec::new( + self.container_client.clone(), + Arc::clone(&self.endpoint), + Arc::clone(&self.config), + Arc::clone(&self.schema), + projected_schema, + projection.cloned(), + ))) + } +} + +/// [`ExecutionPlan`] that streams documents from a Cosmos DB container and +/// converts them into Arrow record batches. +struct CosmosDBExec { + container_client: ContainerClient, + endpoint: Arc, + config: Arc, + /// Full (un-projected) schema used when decoding JSON. + full_schema: SchemaRef, + /// Schema presented to `DataFusion` after projection. + projected_schema: SchemaRef, + projection: Option>, + properties: PlanProperties, +} + +impl CosmosDBExec { + fn new( + container_client: ContainerClient, + endpoint: Arc, + config: Arc, + full_schema: SchemaRef, + projected_schema: SchemaRef, + projection: Option>, + ) -> Self { + let properties = PlanProperties::new( + EquivalenceProperties::new(Arc::clone(&projected_schema)), + Partitioning::UnknownPartitioning(1), + EmissionType::Incremental, + Boundedness::Bounded, + ); + Self { + container_client, + endpoint, + config, + full_schema, + projected_schema, + projection, + properties, + } + } +} + +impl std::fmt::Debug for CosmosDBExec { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + f.debug_struct("CosmosDBExec") + .field("database", &self.config.database) + .field("container", &self.config.container) + .field("query", &self.config.query) + .finish_non_exhaustive() + } +} + +impl DisplayAs for CosmosDBExec { + fn fmt_as(&self, _t: DisplayFormatType, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!( + f, + "CosmosDBExec: database={}, container={}, query={}", + self.config.database, self.config.container, self.config.query + ) + } +} + +impl ExecutionPlan for CosmosDBExec { + fn name(&self) -> &'static str { + "CosmosDBExec" + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> SchemaRef { + Arc::clone(&self.projected_schema) + } + + fn properties(&self) -> &PlanProperties { + &self.properties + } + + fn children(&self) -> Vec<&Arc> { + vec![] + } + + fn with_new_children( + self: Arc, + _children: Vec>, + ) -> DataFusionResult> { + Ok(self) + } + + fn execute( + &self, + _partition: usize, + _context: Arc, + ) -> DataFusionResult { + let mut builder = RecordBatchReceiverStream::builder(Arc::clone(&self.projected_schema), 2); + let tx = builder.tx(); + + let container_client = self.container_client.clone(); + let endpoint = Arc::clone(&self.endpoint); + let config = Arc::clone(&self.config); + let full_schema = Arc::clone(&self.full_schema); + let projection = self.projection.clone(); + + builder.spawn(async move { + if config.resilience.disabled.load(Ordering::Acquire) { + return Err(to_df_error(Error::ConnectorDisabled { + endpoint: endpoint.to_string(), + })); + } + + // Permit + inflight guard are held as `_`-bindings so they release + // automatically when the async block returns — including on + // cancellation or receiver-drop mid-stream. + let _permit = match &config.resilience.semaphore { + Some(s) => Some( + Arc::::clone(s) + .acquire_owned() + .await + .map_err(|_| { + to_df_error(Error::ConnectorDisabled { + endpoint: endpoint.to_string(), + }) + })?, + ), + None => None, + }; + let _inflight = crate::cosmosdb::resilience::InflightGuard::enter( + Arc::::clone(&config.resilience.inflight), + ); + + let handle_stream_error = |resilience: &CosmosResilienceConfig, + endpoint: &str, + err: azure_core::Error| + -> DataFusionError { + if crate::cosmosdb::resilience::is_permanent_error(&err) + && resilience.disable_on_permanent_error + { + resilience.disabled.store(true, Ordering::Release); + tracing::error!( + endpoint = %endpoint, + "Permanent error from Azure Cosmos DB; disabling connector. {err}" + ); + } + to_df_error(Error::QueryFailed { + database: config.database.clone(), + container: config.container.clone(), + source: Box::new(err), + }) + }; + + let mut pager = container_client + .query_items::(config.query.as_str(), (), None) + .map_err(|e| handle_stream_error(&config.resilience, &endpoint, e))?; + + let mut buffer: Vec = Vec::with_capacity(STREAM_BATCH_SIZE); + + while let Some(item) = pager.next().await { + let doc = + item.map_err(|e| handle_stream_error(&config.resilience, &endpoint, e))?; + + buffer.push(strip_system_fields(doc)); + + if buffer.len() >= STREAM_BATCH_SIZE { + let batch = decode_batch(&buffer, &full_schema, projection.as_deref()) + .map_err(to_df_error)?; + buffer.clear(); + if tx.send(Ok(batch)).await.is_err() { + // Receiver dropped; stop scanning. + return Ok(()); + } + } + } + + if !buffer.is_empty() { + let batch = decode_batch(&buffer, &full_schema, projection.as_deref()) + .map_err(to_df_error)?; + let _ = tx.send(Ok(batch)).await; + } + + Ok::<_, DataFusionError>(()) + }); + + Ok(builder.build()) + } +} + +fn decode_batch( + docs: &[Value], + full_schema: &SchemaRef, + projection: Option<&[usize]>, +) -> Result { + // Hand the Value slice directly to arrow-json's serde-aware decoder, + // avoiding the NDJSON serialize -> parse round-trip. + let mut decoder = ReaderBuilder::new(Arc::clone(full_schema)) + .build_decoder() + .context(JsonDecodeSnafu)?; + + if !docs.is_empty() { + decoder.serialize(docs).context(JsonDecodeSnafu)?; + } + + let full_batch = decoder + .flush() + .context(JsonDecodeSnafu)? + .unwrap_or_else(|| RecordBatch::new_empty(Arc::clone(full_schema))); + + if let Some(indices) = projection { + full_batch.project(indices).context(JsonDecodeSnafu) + } else { + Ok(full_batch) + } +} + +fn to_df_error(e: Error) -> DataFusionError { + DataFusionError::External(Box::new(e) as Box) +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::{Array, Int64Array, StringArray}; + use arrow::datatypes::{DataType, Field, Schema}; + use serde_json::json; + + fn sample_schema() -> SchemaRef { + Arc::new(Schema::new(vec![ + Field::new("id", DataType::Utf8, true), + Field::new("count", DataType::Int64, true), + ])) + } + + #[test] + fn decode_batch_decodes_multiple_documents() { + let schema = sample_schema(); + let docs = vec![ + json!({"id": "a", "count": 1}), + json!({"id": "b", "count": 2}), + json!({"id": "c", "count": 3}), + ]; + let batch = decode_batch(&docs, &schema, None).expect("decode_batch failed"); + assert_eq!(batch.num_rows(), 3); + assert_eq!(batch.num_columns(), 2); + + let id_col = batch + .column(0) + .as_any() + .downcast_ref::() + .expect("column 0 should be StringArray"); + assert_eq!(id_col.value(0), "a"); + assert_eq!(id_col.value(2), "c"); + + let count_col = batch + .column(1) + .as_any() + .downcast_ref::() + .expect("column 1 should be Int64Array"); + assert_eq!(count_col.value(0), 1); + assert_eq!(count_col.value(2), 3); + } + + #[test] + fn decode_batch_applies_projection() { + let schema = sample_schema(); + let docs = vec![json!({"id": "a", "count": 1})]; + // Project only the second column (`count`). + let batch = + decode_batch(&docs, &schema, Some(&[1])).expect("decode_batch with projection failed"); + assert_eq!(batch.num_rows(), 1); + assert_eq!(batch.num_columns(), 1); + assert_eq!(batch.schema().field(0).name(), "count"); + } + + #[test] + fn decode_batch_handles_empty_input() { + let schema = sample_schema(); + let batch = decode_batch(&[], &schema, None).expect("decode_batch on empty input failed"); + assert_eq!(batch.num_rows(), 0); + assert_eq!(batch.num_columns(), 2); + } + + #[test] + fn decode_batch_fills_missing_fields_with_null() { + // Cosmos documents are schemaless — some docs may omit fields the + // inferred schema includes. Those cells must surface as nulls. + let schema = sample_schema(); + let docs = vec![json!({"id": "a"}), json!({"id": "b", "count": 2})]; + let batch = decode_batch(&docs, &schema, None).expect("decode_batch failed"); + let count_col = batch + .column(1) + .as_any() + .downcast_ref::() + .expect("column 1 should be Int64Array"); + assert!(count_col.is_null(0)); + assert_eq!(count_col.value(1), 2); + } + + /// Beta criterion: the connector must handle datasets whose column count + /// matches the source limit. Cosmos DB has no formal column cap, but + /// production tenants routinely store 1024+ top-level fields. Build a + /// synthetic schema + document of that size and verify end-to-end + /// JSON-to-Arrow decoding does not OOM or regress. + #[test] + fn decode_batch_handles_wide_schema() { + const COLS: usize = 1024; + let fields: Vec = (0..COLS) + .map(|i| Field::new(format!("col_{i}"), DataType::Int64, true)) + .collect(); + let schema = Arc::new(Schema::new(fields)); + + let mut obj = serde_json::Map::with_capacity(COLS); + for i in 0..COLS { + obj.insert(format!("col_{i}"), json!(i64::try_from(i).unwrap_or(0))); + } + let docs = vec![Value::Object(obj)]; + + let batch = decode_batch(&docs, &schema, None).expect("decode_batch on wide schema failed"); + assert_eq!(batch.num_rows(), 1); + assert_eq!(batch.num_columns(), COLS); + let mid = batch + .column(COLS / 2) + .as_any() + .downcast_ref::() + .expect("middle column should be Int64Array"); + assert_eq!(mid.value(0), i64::try_from(COLS / 2).unwrap_or(0)); + } + + fn schema_with_null_column() -> SchemaRef { + Arc::new(Schema::new(vec![ + Field::new("id", DataType::Utf8, true), + Field::new("always_null", DataType::Null, true), + ])) + } + + #[test] + fn unsupported_type_action_warn_drops_null_columns() { + let schema = schema_with_null_column(); + let projected = + apply_unsupported_type_action(&schema, UnsupportedTypeAction::Warn, "db", "container") + .expect("Warn action should succeed"); + assert_eq!(projected.fields().len(), 1); + assert_eq!(projected.field(0).name(), "id"); + } + + #[test] + fn unsupported_type_action_ignore_drops_silently() { + let schema = schema_with_null_column(); + let projected = apply_unsupported_type_action( + &schema, + UnsupportedTypeAction::Ignore, + "db", + "container", + ) + .expect("Ignore action should succeed"); + assert_eq!(projected.fields().len(), 1); + } + + #[test] + fn unsupported_type_action_string_coerces_to_utf8() { + let schema = schema_with_null_column(); + let projected = apply_unsupported_type_action( + &schema, + UnsupportedTypeAction::String, + "db", + "container", + ) + .expect("String action should succeed"); + assert_eq!(projected.fields().len(), 2); + assert_eq!(projected.field(1).data_type(), &DataType::Utf8); + } + + #[test] + fn unsupported_type_action_error_surfaces_to_caller() { + let schema = schema_with_null_column(); + let err = + apply_unsupported_type_action(&schema, UnsupportedTypeAction::Error, "db", "container") + .expect_err("Error action should fail on unsupported column"); + assert!(matches!(err, Error::UnsupportedColumn { .. })); + } + + #[test] + fn unsupported_type_action_is_noop_on_clean_schema() { + let schema = sample_schema(); + let projected = + apply_unsupported_type_action(&schema, UnsupportedTypeAction::Error, "db", "container") + .expect("Error action on clean schema should succeed"); + assert!(Arc::ptr_eq(&schema, &projected)); + } +} diff --git a/crates/data_components/src/cosmosdb/resilience.rs b/crates/data_components/src/cosmosdb/resilience.rs new file mode 100644 index 0000000000..5dc735b44e --- /dev/null +++ b/crates/data_components/src/cosmosdb/resilience.rs @@ -0,0 +1,620 @@ +/* +Copyright 2024-2026 The Spice.ai OSS Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +//! Connection-resilience primitives for the Azure Cosmos DB connector. +//! +//! The Cosmos DB SDK explicitly disables `typespec`'s retry pipeline +//! (`azure_data_cosmos::clients::cosmos_client` sets `RetryOptions::none()`), +//! so the connector owns retry, concurrency limiting, and permanent-error +//! detection itself. This matches the pattern used by the Git connector in +//! `crates/data_components/src/git.rs` and satisfies the RC +//! "Connection Resilience" gate in +//! `docs/criteria/connectors/rc.md`. + +use std::sync::{ + Arc, + atomic::{AtomicBool, AtomicU64, Ordering}, +}; +use std::time::Duration; + +use azure_core::error::ErrorKind; +use azure_core::http::headers::{HeaderName, Headers, X_MS_RETRY_AFTER_MS}; +use tokio::sync::Semaphore; + +/// Default upper bound on in-flight Cosmos DB requests per account endpoint. +pub const DEFAULT_MAX_CONCURRENT_REQUESTS: usize = 4; + +/// Default number of retries before a transient error is surfaced. +pub const DEFAULT_MAX_RETRIES: u32 = 3; + +const RETRY_INITIAL_BACKOFF: Duration = Duration::from_millis(500); +const RETRY_MAX_BACKOFF: Duration = Duration::from_secs(30); + +/// Standard `Retry-After` header. typespec's header registry keeps it as a +/// standard header name but does not expose a `pub const`, so construct it +/// locally. Cosmos uses the integer-seconds form in practice. +const RETRY_AFTER_HEADER: HeaderName = HeaderName::from_static("retry-after"); + +/// Backoff strategy for retries on transient Cosmos DB errors. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum BackoffMethod { + Exponential, + Fibonacci, +} + +impl BackoffMethod { + /// Parse a user-supplied string. Accepted values are + /// `"exponential"` and `"fibonacci"` (case-insensitive). + /// + /// # Errors + /// Returns a human-readable message describing the invalid value. + pub fn parse(value: &str) -> Result { + match value.to_ascii_lowercase().as_str() { + "exponential" => Ok(Self::Exponential), + "fibonacci" => Ok(Self::Fibonacci), + other => Err(format!( + "invalid backoff_method '{other}'. Expected 'exponential' or 'fibonacci'." + )), + } + } +} + +/// RAII guard that increments an in-flight counter on construction and +/// decrements it on drop. Cancellation-safe: if the surrounding future is +/// dropped before completion, the counter still returns to its prior value. +pub struct InflightGuard { + counter: Arc, +} + +impl InflightGuard { + pub fn enter(counter: Arc) -> Self { + counter.fetch_add(1, Ordering::Relaxed); + Self { counter } + } +} + +impl Drop for InflightGuard { + fn drop(&mut self) { + self.counter.fetch_sub(1, Ordering::Relaxed); + } +} + +/// Configuration used to tune retry, concurrency, and permanent-error +/// behavior. Produced by the runtime factory from user-facing parameters. +#[derive(Debug, Clone)] +pub struct CosmosResilienceConfig { + pub max_retries: u32, + pub backoff: BackoffMethod, + /// Bounds the number of concurrent Cosmos DB requests. `None` disables + /// concurrency limiting. + pub semaphore: Option>, + /// When true, a 401/403/404 response latches the connector as disabled. + pub disable_on_permanent_error: bool, + /// Counter driving the `inflight_operations` metric gauge. + pub inflight: Arc, + /// Shared latch inspected before every request; once set to `true`, all + /// subsequent operations on the same account endpoint short-circuit. + pub disabled: Arc, +} + +impl Default for CosmosResilienceConfig { + fn default() -> Self { + Self { + max_retries: DEFAULT_MAX_RETRIES, + backoff: BackoffMethod::Exponential, + semaphore: None, + disable_on_permanent_error: true, + inflight: Arc::new(AtomicU64::new(0)), + disabled: Arc::new(AtomicBool::new(false)), + } + } +} + +/// Result returned by [`run_with_resilience`]. +#[derive(Debug)] +pub enum ResilienceError { + /// The connector is latched in a disabled state from a prior permanent + /// error. Callers should map this to a domain error surface. + Disabled, + /// The underlying SDK surfaced an error that either is non-retryable or + /// exhausted the retry budget. + Request(azure_core::Error), +} + +impl std::fmt::Display for ResilienceError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Disabled => f.write_str("Azure Cosmos DB connector has been disabled"), + Self::Request(e) => write!(f, "{e}"), + } + } +} + +impl std::error::Error for ResilienceError { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + match self { + Self::Disabled => None, + Self::Request(e) => Some(e), + } + } +} + +/// Classify an SDK error as permanent (authn/authz/not-found) vs. transient. +#[must_use] +pub fn is_permanent_error(err: &azure_core::Error) -> bool { + matches!(err.http_status().map(u16::from), Some(401 | 403 | 404)) +} + +/// Extract a `Retry-After` delay from an error's raw HTTP response headers, +/// if any. Honors both the standard `Retry-After` (seconds) and the +/// Cosmos-specific `x-ms-retry-after-ms` header. +#[must_use] +pub fn retry_after_from_error(err: &azure_core::Error) -> Option { + if let ErrorKind::HttpResponse { + raw_response: Some(response), + .. + } = err.kind() + { + return retry_after_from_headers(response.headers()); + } + None +} + +fn retry_after_from_headers(headers: &Headers) -> Option { + if let Some(ms_str) = headers.get_optional_str(&X_MS_RETRY_AFTER_MS) + && let Ok(ms) = ms_str.parse::() + { + return Some(Duration::from_millis(ms)); + } + if let Some(secs_str) = headers.get_optional_str(&RETRY_AFTER_HEADER) + && let Ok(secs) = secs_str.parse::() + { + return Some(Duration::from_secs(secs)); + } + None +} + +/// Compute the backoff delay for the given attempt under the configured +/// method, capped at [`RETRY_MAX_BACKOFF`]. +#[must_use] +pub fn backoff_delay(method: BackoffMethod, attempt: u32) -> Duration { + let factor_u64: u64 = match method { + BackoffMethod::Exponential => 2u64.saturating_pow(attempt), + BackoffMethod::Fibonacci => { + let (mut a, mut b) = (1u64, 1u64); + for _ in 0..attempt { + let next = a.saturating_add(b); + a = b; + b = next; + } + b + } + }; + let factor = u32::try_from(factor_u64).unwrap_or(u32::MAX); + RETRY_INITIAL_BACKOFF + .saturating_mul(factor) + .min(RETRY_MAX_BACKOFF) +} + +/// Execute `operation` with concurrency limiting, retry on transient errors, +/// permanent-error detection, and in-flight tracking. +/// +/// `operation` is invoked once per attempt to produce a fresh future; it must +/// therefore be idempotent (or, more precisely, safe to re-issue from scratch). +/// The semaphore permit and [`InflightGuard`] are held for the lifetime of all +/// attempts. +/// +/// # Errors +/// Returns [`ResilienceError::Disabled`] if the shared disabled flag is set, +/// or [`ResilienceError::Request`] once retries are exhausted or on a +/// permanent error. +pub async fn run_with_resilience( + config: &CosmosResilienceConfig, + endpoint: &str, + operation: F, +) -> Result +where + F: Fn() -> Fut, + Fut: std::future::Future>, +{ + if config.disabled.load(Ordering::Acquire) { + return Err(ResilienceError::Disabled); + } + + let _permit = match &config.semaphore { + Some(s) => Some( + Arc::::clone(s) + .acquire_owned() + .await + .map_err(|_| ResilienceError::Disabled)?, + ), + None => None, + }; + + let _inflight = InflightGuard::enter(Arc::::clone(&config.inflight)); + + let mut attempt: u32 = 0; + loop { + match operation().await { + Ok(v) => return Ok(v), + Err(err) => { + let is_perm = is_permanent_error(&err); + + if is_perm && config.disable_on_permanent_error { + config.disabled.store(true, Ordering::Release); + tracing::error!( + endpoint = %endpoint, + "Permanent error from Azure Cosmos DB; disabling connector. {err}" + ); + return Err(ResilienceError::Request(err)); + } + + if is_perm || attempt >= config.max_retries { + return Err(ResilienceError::Request(err)); + } + + let backoff = backoff_delay(config.backoff, attempt); + let retry_after = retry_after_from_error(&err); + let delay = retry_after.map_or(backoff, |ra| ra.max(backoff)); + + tracing::warn!( + endpoint = %endpoint, + attempt = attempt + 1, + max_retries = config.max_retries, + delay_ms = u64::try_from(delay.as_millis()).unwrap_or(u64::MAX), + "Transient error from Azure Cosmos DB, retrying. {err}" + ); + tokio::time::sleep(delay).await; + attempt += 1; + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::sync::atomic::AtomicUsize; + + #[test] + fn backoff_method_parse_accepts_canonical_values() { + assert_eq!( + BackoffMethod::parse("exponential").expect("should parse 'exponential'"), + BackoffMethod::Exponential + ); + assert_eq!( + BackoffMethod::parse("Fibonacci").expect("should parse 'Fibonacci'"), + BackoffMethod::Fibonacci + ); + } + + #[test] + fn backoff_method_parse_rejects_unknown_values() { + let err = BackoffMethod::parse("linear").expect_err("'linear' should be rejected"); + assert!(err.contains("invalid backoff_method")); + } + + #[test] + fn backoff_delay_exponential_doubles_and_caps() { + assert_eq!( + backoff_delay(BackoffMethod::Exponential, 0), + RETRY_INITIAL_BACKOFF + ); + assert_eq!( + backoff_delay(BackoffMethod::Exponential, 1), + RETRY_INITIAL_BACKOFF * 2 + ); + assert_eq!( + backoff_delay(BackoffMethod::Exponential, 2), + RETRY_INITIAL_BACKOFF * 4 + ); + // Large attempt saturates at the cap. + assert_eq!( + backoff_delay(BackoffMethod::Exponential, 100), + RETRY_MAX_BACKOFF + ); + } + + #[test] + fn backoff_delay_fibonacci_grows_as_expected() { + // Factors follow F(attempt+2) with the conventional Fibonacci indexing + // F(1)=F(2)=1 → attempts 0, 1, 2, 3, 4 map to 1, 2, 3, 5, 8, ... + assert_eq!( + backoff_delay(BackoffMethod::Fibonacci, 0), + RETRY_INITIAL_BACKOFF + ); + assert_eq!( + backoff_delay(BackoffMethod::Fibonacci, 1), + RETRY_INITIAL_BACKOFF * 2 + ); + assert_eq!( + backoff_delay(BackoffMethod::Fibonacci, 2), + RETRY_INITIAL_BACKOFF * 3 + ); + assert_eq!( + backoff_delay(BackoffMethod::Fibonacci, 3), + RETRY_INITIAL_BACKOFF * 5 + ); + } + + #[test] + fn is_permanent_error_flags_auth_and_not_found() { + let auth_err = azure_core::Error::new( + ErrorKind::HttpResponse { + status: azure_core::http::StatusCode::Unauthorized, + error_code: None, + raw_response: None, + }, + std::io::Error::other("401"), + ); + assert!(is_permanent_error(&auth_err)); + + let forbidden = azure_core::Error::new( + ErrorKind::HttpResponse { + status: azure_core::http::StatusCode::Forbidden, + error_code: None, + raw_response: None, + }, + std::io::Error::other("403"), + ); + assert!(is_permanent_error(&forbidden)); + + let not_found = azure_core::Error::new( + ErrorKind::HttpResponse { + status: azure_core::http::StatusCode::NotFound, + error_code: None, + raw_response: None, + }, + std::io::Error::other("404"), + ); + assert!(is_permanent_error(¬_found)); + } + + #[test] + fn is_permanent_error_skips_transient_statuses() { + let throttled = azure_core::Error::new( + ErrorKind::HttpResponse { + status: azure_core::http::StatusCode::TooManyRequests, + error_code: None, + raw_response: None, + }, + std::io::Error::other("429"), + ); + assert!(!is_permanent_error(&throttled)); + + let server_error = azure_core::Error::new( + ErrorKind::HttpResponse { + status: azure_core::http::StatusCode::InternalServerError, + error_code: None, + raw_response: None, + }, + std::io::Error::other("500"), + ); + assert!(!is_permanent_error(&server_error)); + + let io_err = azure_core::Error::new(ErrorKind::Io, std::io::Error::other("io")); + assert!(!is_permanent_error(&io_err)); + } + + fn make_error_with_headers( + status: azure_core::http::StatusCode, + headers: Vec<(HeaderName, String)>, + ) -> azure_core::Error { + use azure_core::http::response::RawResponse; + + let mut hs = Headers::new(); + for (name, value) in headers { + hs.insert(name, value); + } + let raw = RawResponse::from_bytes(status, hs, Vec::::new()); + azure_core::Error::new( + ErrorKind::HttpResponse { + status, + error_code: None, + raw_response: Some(Box::new(raw)), + }, + std::io::Error::other("sdk"), + ) + } + + #[test] + fn retry_after_prefers_millisecond_header() { + let err = make_error_with_headers( + azure_core::http::StatusCode::TooManyRequests, + vec![(X_MS_RETRY_AFTER_MS, "250".into())], + ); + assert_eq!( + retry_after_from_error(&err), + Some(Duration::from_millis(250)) + ); + } + + #[test] + fn retry_after_falls_back_to_seconds_header() { + let err = make_error_with_headers( + azure_core::http::StatusCode::ServiceUnavailable, + vec![(RETRY_AFTER_HEADER, "3".into())], + ); + assert_eq!(retry_after_from_error(&err), Some(Duration::from_secs(3))); + } + + #[test] + fn retry_after_is_none_when_header_missing() { + let err = make_error_with_headers(azure_core::http::StatusCode::TooManyRequests, vec![]); + assert_eq!(retry_after_from_error(&err), None); + } + + #[test] + fn inflight_guard_increments_and_decrements() { + let counter = Arc::new(AtomicU64::new(0)); + { + let _guard = InflightGuard::enter(Arc::::clone(&counter)); + assert_eq!(counter.load(Ordering::Relaxed), 1); + } + assert_eq!(counter.load(Ordering::Relaxed), 0); + } + + #[test] + fn inflight_guard_tracks_nested_entries() { + let counter = Arc::new(AtomicU64::new(0)); + let g1 = InflightGuard::enter(Arc::::clone(&counter)); + let g2 = InflightGuard::enter(Arc::::clone(&counter)); + assert_eq!(counter.load(Ordering::Relaxed), 2); + drop(g2); + assert_eq!(counter.load(Ordering::Relaxed), 1); + drop(g1); + assert_eq!(counter.load(Ordering::Relaxed), 0); + } + + #[tokio::test] + async fn run_with_resilience_short_circuits_when_disabled() { + let config = CosmosResilienceConfig::default(); + config.disabled.store(true, Ordering::Release); + let attempts = Arc::new(AtomicUsize::new(0)); + let attempts_clone = Arc::::clone(&attempts); + let result: Result<(), _> = run_with_resilience(&config, "https://x", || { + let a = Arc::::clone(&attempts_clone); + async move { + a.fetch_add(1, Ordering::Relaxed); + Ok(()) + } + }) + .await; + assert!(matches!(result, Err(ResilienceError::Disabled))); + assert_eq!(attempts.load(Ordering::Relaxed), 0); + } + + #[tokio::test] + async fn run_with_resilience_retries_transient_then_succeeds() { + let config = CosmosResilienceConfig { + max_retries: 3, + backoff: BackoffMethod::Exponential, + ..CosmosResilienceConfig::default() + }; + let attempts = Arc::new(AtomicUsize::new(0)); + let attempts_clone = Arc::::clone(&attempts); + let result: Result = tokio::time::timeout( + Duration::from_secs(30), + run_with_resilience(&config, "https://x", || { + let a = Arc::::clone(&attempts_clone); + async move { + let n = a.fetch_add(1, Ordering::Relaxed); + if n < 2 { + Err(azure_core::Error::new( + ErrorKind::HttpResponse { + status: azure_core::http::StatusCode::TooManyRequests, + error_code: None, + raw_response: None, + }, + std::io::Error::other("429"), + )) + } else { + Ok(42) + } + } + }), + ) + .await + .expect("future did not time out"); + assert_eq!(result.expect("operation should succeed after retries"), 42); + assert_eq!(attempts.load(Ordering::Relaxed), 3); + } + + #[tokio::test] + async fn run_with_resilience_surfaces_after_max_retries() { + let config = CosmosResilienceConfig { + max_retries: 2, + backoff: BackoffMethod::Exponential, + ..CosmosResilienceConfig::default() + }; + let attempts = Arc::new(AtomicUsize::new(0)); + let attempts_clone = Arc::::clone(&attempts); + let result: Result = tokio::time::timeout( + Duration::from_secs(30), + run_with_resilience(&config, "https://x", || { + let a = Arc::::clone(&attempts_clone); + async move { + a.fetch_add(1, Ordering::Relaxed); + Err(azure_core::Error::new( + ErrorKind::HttpResponse { + status: azure_core::http::StatusCode::InternalServerError, + error_code: None, + raw_response: None, + }, + std::io::Error::other("500"), + )) + } + }), + ) + .await + .expect("future did not time out"); + assert!(matches!(result, Err(ResilienceError::Request(_)))); + // max_retries=2 means: initial attempt + 2 retries = 3 total calls. + assert_eq!(attempts.load(Ordering::Relaxed), 3); + } + + #[tokio::test] + async fn run_with_resilience_latches_disabled_on_permanent_error() { + let config = CosmosResilienceConfig { + max_retries: 3, + disable_on_permanent_error: true, + ..CosmosResilienceConfig::default() + }; + let attempts = Arc::new(AtomicUsize::new(0)); + let attempts_clone = Arc::::clone(&attempts); + let result: Result = run_with_resilience(&config, "https://x", || { + let a = Arc::::clone(&attempts_clone); + async move { + a.fetch_add(1, Ordering::Relaxed); + Err(azure_core::Error::new( + ErrorKind::HttpResponse { + status: azure_core::http::StatusCode::Forbidden, + error_code: None, + raw_response: None, + }, + std::io::Error::other("403"), + )) + } + }) + .await; + assert!(matches!(result, Err(ResilienceError::Request(_)))); + // Permanent errors short-circuit without retrying. + assert_eq!(attempts.load(Ordering::Relaxed), 1); + assert!(config.disabled.load(Ordering::Acquire)); + } + + #[tokio::test] + async fn run_with_resilience_does_not_latch_when_disable_off() { + let config = CosmosResilienceConfig { + max_retries: 3, + disable_on_permanent_error: false, + ..CosmosResilienceConfig::default() + }; + let result: Result = run_with_resilience(&config, "https://x", || async { + Err(azure_core::Error::new( + ErrorKind::HttpResponse { + status: azure_core::http::StatusCode::Unauthorized, + error_code: None, + raw_response: None, + }, + std::io::Error::other("401"), + )) + }) + .await; + assert!(matches!(result, Err(ResilienceError::Request(_)))); + assert!(!config.disabled.load(Ordering::Acquire)); + } +} diff --git a/crates/data_components/src/cosmosdb/schema.rs b/crates/data_components/src/cosmosdb/schema.rs new file mode 100644 index 0000000000..dbf560ed6e --- /dev/null +++ b/crates/data_components/src/cosmosdb/schema.rs @@ -0,0 +1,185 @@ +/* +Copyright 2024-2026 The Spice.ai OSS Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +//! Schema inference for Cosmos DB documents. +//! +//! Cosmos DB `NoSQL` does not expose a static schema. The first `N` documents +//! returned from the configured query are sampled and handed to Arrow's +//! [`infer_json_schema_from_iterator`] to derive a best-effort Arrow schema. +//! +//! Cosmos always stamps every document with `_rid`, `_self`, `_etag`, +//! `_attachments`, and `_ts` system fields. These are stripped from the +//! sample set before inference to avoid polluting downstream tables with +//! metadata columns the user almost never wants to query. + +use arrow::datatypes::{Schema, SchemaRef}; +use arrow::json::reader::infer_json_schema_from_iterator; +use serde_json::Value; +use snafu::ResultExt; +use std::sync::Arc; + +use super::{Error, SchemaInferenceSnafu}; + +/// System fields stamped on every Cosmos document. Stripped prior to schema +/// inference so they never become user-visible columns. +const COSMOS_SYSTEM_FIELDS: &[&str] = &["_rid", "_self", "_etag", "_attachments", "_ts"]; + +/// Strip Cosmos DB-internal system fields from a top-level JSON object. Any +/// non-object value is returned unchanged. +#[must_use] +pub fn strip_system_fields(value: Value) -> Value { + match value { + Value::Object(mut map) => { + for field in COSMOS_SYSTEM_FIELDS { + map.remove(*field); + } + Value::Object(map) + } + other => other, + } +} + +/// Infer an Arrow schema from a slice of sampled Cosmos documents. Callers +/// are expected to have already run the values through [`strip_system_fields`]. +/// +/// Returns an empty [`Schema`] when `samples` is empty — callers should map +/// that condition to a user-facing `EmptyContainer` error rather than letting +/// it propagate as a successful inference. +/// +/// # Errors +/// Returns an error if Arrow's JSON schema inference fails. +pub fn infer_schema(samples: &[Value]) -> Result { + if samples.is_empty() { + return Ok(Arc::new(Schema::empty())); + } + + let schema = infer_json_schema_from_iterator( + samples + .iter() + .map(Result::<_, arrow::error::ArrowError>::Ok), + ) + .context(SchemaInferenceSnafu)?; + + Ok(Arc::new(schema)) +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::datatypes::DataType; + use serde_json::json; + + #[test] + fn strip_system_fields_removes_all_known_system_fields() { + let doc = json!({ + "_rid": "rid", + "_self": "self", + "_etag": "etag", + "_attachments": "attachments", + "_ts": 1234, + "id": "doc1", + "payload": "keep me", + }); + let stripped = strip_system_fields(doc); + let obj = stripped + .as_object() + .expect("result should be a JSON object"); + assert!(!obj.contains_key("_rid")); + assert!(!obj.contains_key("_self")); + assert!(!obj.contains_key("_etag")); + assert!(!obj.contains_key("_attachments")); + assert!(!obj.contains_key("_ts")); + assert_eq!(obj.get("id").and_then(Value::as_str), Some("doc1")); + assert_eq!(obj.get("payload").and_then(Value::as_str), Some("keep me")); + } + + #[test] + fn strip_system_fields_ignores_non_object_values() { + assert_eq!(strip_system_fields(json!(null)), json!(null)); + assert_eq!(strip_system_fields(json!("string")), json!("string")); + assert_eq!(strip_system_fields(json!(42)), json!(42)); + assert_eq!(strip_system_fields(json!([1, 2, 3])), json!([1, 2, 3])); + } + + #[test] + fn strip_system_fields_does_not_strip_nested_occurrences() { + // System field stripping only applies at the top level — nested + // objects (user-controlled payloads) are preserved unchanged. + let doc = json!({ + "_rid": "top_rid", + "nested": {"_rid": "nested_rid", "value": 1}, + }); + let stripped = strip_system_fields(doc); + let obj = stripped + .as_object() + .expect("result should be a JSON object"); + assert!(!obj.contains_key("_rid")); + let nested = obj + .get("nested") + .and_then(Value::as_object) + .expect("nested field should be an object"); + assert_eq!( + nested.get("_rid").and_then(Value::as_str), + Some("nested_rid") + ); + } + + #[test] + fn infer_schema_returns_empty_schema_for_empty_sample() { + let schema = infer_schema(&[]).expect("infer_schema on empty input should succeed"); + assert_eq!(schema.fields().len(), 0); + } + + #[test] + fn infer_schema_produces_fields_from_sample_documents() { + let samples = vec![ + json!({"id": "1", "count": 10}), + json!({"id": "2", "count": 20}), + ]; + let schema = infer_schema(&samples).expect("infer_schema should succeed"); + let id_field = schema + .field_with_name("id") + .expect("schema should have 'id' field"); + let count_field = schema + .field_with_name("count") + .expect("schema should have 'count' field"); + assert_eq!(id_field.data_type(), &DataType::Utf8); + assert!(matches!( + count_field.data_type(), + DataType::Int64 | DataType::Float64 + )); + } + + #[test] + fn infer_schema_unions_mixed_documents() { + // Documents with different field sets should merge into a single + // schema that contains the union of fields. + let samples = vec![ + json!({"id": "1", "only_in_first": "x"}), + json!({"id": "2", "only_in_second": 42}), + ]; + let schema = infer_schema(&samples).expect("infer_schema should succeed"); + schema + .field_with_name("id") + .expect("schema should have 'id' field"); + schema + .field_with_name("only_in_first") + .expect("schema should have 'only_in_first' field"); + schema + .field_with_name("only_in_second") + .expect("schema should have 'only_in_second' field"); + } +} diff --git a/crates/data_components/src/lib.rs b/crates/data_components/src/lib.rs index e52129a962..8b28b23f7a 100644 --- a/crates/data_components/src/lib.rs +++ b/crates/data_components/src/lib.rs @@ -23,6 +23,8 @@ use datafusion::{catalog::CatalogProvider, datasource::TableProvider, sql::Table pub mod arrow; #[cfg(feature = "clickhouse")] pub mod clickhouse; +#[cfg(feature = "cosmosdb")] +pub mod cosmosdb; #[cfg(feature = "databricks")] pub mod databricks; #[cfg(feature = "debezium")] diff --git a/crates/runtime/Cargo.toml b/crates/runtime/Cargo.toml index 88933d9dcb..a55d092e87 100644 --- a/crates/runtime/Cargo.toml +++ b/crates/runtime/Cargo.toml @@ -265,6 +265,7 @@ adbc = ["dep:adbc_core", "dep:adbc_driver_manager", "dep:blake3", "dep:sha2", "d aws-secrets-manager = ["dep:aws-sdk-sts", "runtime-secrets/aws-secrets-manager"] bedrock = [] clickhouse = ["db_connection_pool/clickhouse", "data_components/clickhouse"] +cosmosdb = ["data_components/cosmosdb"] cuda = ["llms/cuda"] databricks = ["data_components/databricks"] debezium = ["kafka", "data_components/debezium"] diff --git a/crates/runtime/src/dataconnector/cosmosdb.rs b/crates/runtime/src/dataconnector/cosmosdb.rs new file mode 100644 index 0000000000..c446d0be23 --- /dev/null +++ b/crates/runtime/src/dataconnector/cosmosdb.rs @@ -0,0 +1,591 @@ +/* +Copyright 2024-2026 The Spice.ai OSS Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +//! Azure Cosmos DB (`NoSQL` / Core SQL API) data connector. +//! +//! Read-only scan with schema inferred from a sample of documents, backed by +//! RC-level connection resilience (concurrency limiting, retry with backoff, +//! permanent-error detection) and an `inflight_operations` metric gauge. + +use std::any::Any; +use std::collections::HashMap; +use std::future::Future; +use std::pin::Pin; +use std::sync::atomic::{AtomicBool, AtomicU64, Ordering}; +use std::sync::{Arc, LazyLock, Mutex}; + +use async_trait::async_trait; +use data_components::cosmosdb::{ + BackoffMethod, CosmosDBCredential, CosmosDBTableProvider, CosmosResilienceConfig, + DEFAULT_MAX_CONCURRENT_REQUESTS, DEFAULT_MAX_RETRIES, DEFAULT_QUERY, + DEFAULT_SCHEMA_INFER_MAX_RECORDS, build_container_client, + provider::CosmosDBTableProviderConfig, +}; +use datafusion::datasource::TableProvider; +use datafusion_table_providers::UnsupportedTypeAction as DFUnsupportedTypeAction; +use opentelemetry::KeyValue; +use tokio::sync::Semaphore; + +use super::{ + ConnectorComponent, ConnectorParams, DataConnector, DataConnectorError, DataConnectorFactory, + ParameterSpec, Parameters, +}; +use crate::component::ComponentType; +use crate::component::dataset::Dataset; +use crate::component::metrics::{MetricSpec, MetricType, MetricsProvider, ObserveMetricCallback}; + +const CONNECTOR_NAME: &str = "cosmosdb"; + +/// Semaphore paired with the numeric limit it was constructed with, so +/// mismatches across datasets targeting the same Cosmos account can be +/// detected and surfaced as a warning. +type SemaphoreEntry = (Arc, usize); + +/// Per-account-endpoint concurrency semaphores. Datasets that hit the same +/// Cosmos account share a single concurrency budget, matching the per-account +/// rate-limit model of Cosmos DB. +/// +/// Entries are never evicted during the runtime's lifetime: each slot holds an +/// `Arc` + `usize` (~40 bytes on 64-bit platforms), and typical +/// deployments configure a bounded set of accounts. Workloads that +/// dynamically materialize many distinct Cosmos accounts should treat this as +/// a known upper bound on memory use. +static COSMOS_CONCURRENCY_LIMITS: LazyLock>> = + LazyLock::new(|| Mutex::new(HashMap::new())); + +/// Per-account-endpoint disabled-state flags. A permanent error (401/403/404) +/// observed by one dataset latches the connector for every dataset pointing +/// at the same account. Same memory footprint and eviction trade-off as +/// `COSMOS_CONCURRENCY_LIMITS` above. +static COSMOS_DISABLED_FLAGS: LazyLock>>> = + LazyLock::new(|| Mutex::new(HashMap::new())); + +fn shared_semaphore(endpoint: &str, max_concurrent: usize) -> Arc { + let mut guard = COSMOS_CONCURRENCY_LIMITS + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner); + if let Some((semaphore, existing_max)) = guard.get(endpoint) { + if *existing_max != max_concurrent { + tracing::warn!( + endpoint = %endpoint, + existing_max, + requested_max = max_concurrent, + "Multiple datasets target the same Cosmos DB account with different max_concurrent_requests values. Keeping the first-seen limit ({existing_max})." + ); + } + Arc::::clone(semaphore) + } else { + let semaphore = Arc::new(Semaphore::new(max_concurrent)); + guard.insert( + endpoint.to_string(), + (Arc::::clone(&semaphore), max_concurrent), + ); + semaphore + } +} + +fn shared_disabled_flag(endpoint: &str) -> Arc { + let mut guard = COSMOS_DISABLED_FLAGS + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner); + Arc::::clone( + guard + .entry(endpoint.to_string()) + .or_insert_with(|| Arc::new(AtomicBool::new(false))), + ) +} + +const COSMOSDB_METRICS: &[MetricSpec] = + &[ + MetricSpec::new("inflight_operations", MetricType::ObservableGaugeU64) + .description("Azure Cosmos DB operations currently holding a concurrency permit — incremented once per operation and held across retry backoff sleeps (not a pure in-flight-HTTP counter)") + .auto_register(), + ]; + +#[derive(Debug, Clone)] +struct CosmosDBMetricsProvider { + inflight_operations: Arc, +} + +impl MetricsProvider for CosmosDBMetricsProvider { + fn component_type(&self) -> ComponentType { + ComponentType::Dataset + } + + fn component_name(&self) -> &'static str { + CONNECTOR_NAME + } + + fn available_metrics(&self) -> &'static [MetricSpec] { + COSMOSDB_METRICS + } + + fn callback_to_observe_metric( + &self, + metric: &MetricSpec, + attributes: Vec, + ) -> Option { + match metric.name { + "inflight_operations" => { + let counter = Arc::::clone(&self.inflight_operations); + Some(ObserveMetricCallback::U64(Box::new(move |observer| { + observer.observe(counter.load(Ordering::Relaxed), &attributes); + }))) + } + _ => None, + } + } +} + +#[derive(Debug)] +pub struct CosmosDB { + params: Parameters, + /// Drives the `inflight_operations` metric gauge. Instantiated per + /// connector (one per dataset), so the exported value reflects in-flight + /// operations for that dataset rather than a shared per-account budget — + /// the shared concurrency budget itself is enforced via the endpoint-keyed + /// `COSMOS_CONCURRENCY_LIMITS` map, not this counter. + inflight_operations: Arc, + unsupported_type_action: Option, +} + +#[derive(Default, Debug, Copy, Clone)] +pub struct CosmosDBFactory {} + +impl CosmosDBFactory { + #[must_use] + pub fn new() -> Self { + Self {} + } + + #[must_use] + pub fn new_arc() -> Arc { + Arc::new(Self {}) as Arc + } +} + +const PARAMETERS: &[ParameterSpec] = &[ + ParameterSpec::component("account_endpoint") + .description("The Azure Cosmos DB account endpoint URL, e.g. 'https://my-account.documents.azure.com:443/'.") + .secret(), + ParameterSpec::component("account_key") + .description("The Azure Cosmos DB account primary or secondary key.") + .secret(), + ParameterSpec::component("connection_string") + .description("An Azure Cosmos DB connection string (AccountEndpoint=...;AccountKey=...). Takes precedence over account_endpoint/account_key if set.") + .secret(), + ParameterSpec::component("database") + .description("The Cosmos DB database name. Defaults to the first segment of the dataset `from:` path ('database.container')."), + ParameterSpec::runtime("query") + .description("Cosmos SQL query used to scan the container. Defaults to 'SELECT * FROM c'.") + .default(DEFAULT_QUERY), + ParameterSpec::runtime("schema_infer_max_records") + .description("Number of documents sampled during schema inference. Larger samples produce a more precise schema at the cost of additional RU consumption on dataset registration.") + .default("100"), + + ParameterSpec::runtime("max_concurrent_requests") + .description("Maximum number of concurrent Azure Cosmos DB requests per account endpoint, shared across all datasets pointing at the same account.") + .default("4"), + ParameterSpec::runtime("http_max_retries") + .description("Maximum number of retries for transient errors (429, 5xx, network) during the schema-inference sampling pass at dataset registration. Retries use the configured backoff strategy and honor Retry-After headers. Mid-stream pager errors during scan execution are not retried.") + .default("3"), + ParameterSpec::runtime("backoff_method") + .description("Backoff strategy between schema-inference sampling retries on transient errors. 'exponential' doubles the delay each attempt; 'fibonacci' follows the Fibonacci sequence.") + .one_of(&["exponential", "fibonacci"]) + .default("exponential"), + ParameterSpec::runtime("disable_on_permanent_error") + .description("When true, a permanent error (401/403/404) from Azure Cosmos DB latches the connector into a disabled state and short-circuits subsequent requests until Spice is restarted.") + .default("true") + .is_boolean(), +]; + +impl DataConnectorFactory for CosmosDBFactory { + fn as_any(&self) -> &dyn Any { + self + } + + fn create( + &self, + params: ConnectorParams, + ) -> Pin + Send>> { + let unsupported_type_action = params.unsupported_type_action; + Box::pin(async move { + let conn = CosmosDB { + params: params.parameters, + inflight_operations: Arc::new(AtomicU64::new(0)), + unsupported_type_action, + }; + Ok(Arc::new(conn) as Arc) + }) + } + + fn prefix(&self) -> &'static str { + CONNECTOR_NAME + } + + fn parameters(&self) -> &'static [ParameterSpec] { + PARAMETERS + } + + fn supports_unsupported_type_action(&self) -> bool { + true + } +} + +impl CosmosDB { + fn build_credential( + &self, + dataset: &Dataset, + ) -> Result { + if let Some(conn_str) = self.params.get("connection_string").expose().ok() { + return Ok(CosmosDBCredential::ConnectionString(conn_str.to_string())); + } + + let endpoint = self.params.get("account_endpoint").expose().ok(); + let key = self.params.get("account_key").expose().ok(); + + match (endpoint, key) { + (Some(endpoint), Some(key)) => Ok(CosmosDBCredential::Key { + endpoint: endpoint.to_string(), + key: key.to_string(), + }), + _ => Err(DataConnectorError::InvalidConfigurationNoSource { + dataconnector: CONNECTOR_NAME.to_string(), + connector_component: ConnectorComponent::from(dataset), + message: "Azure Cosmos DB requires either 'cosmosdb_connection_string' or both 'cosmosdb_account_endpoint' and 'cosmosdb_account_key'.".to_string(), + }), + } + } + + /// Materialize a resilience config from validated parameters. Per-endpoint + /// semaphore and disabled flag are shared across datasets that target the + /// same Cosmos account. + fn build_resilience(&self, endpoint: &str) -> CosmosResilienceConfig { + let max_concurrent_requests = self + .params + .get("max_concurrent_requests") + .expose() + .ok() + .and_then(|v| v.parse::().ok()) + .unwrap_or(DEFAULT_MAX_CONCURRENT_REQUESTS) + .max(1); + + let max_retries = self + .params + .get("http_max_retries") + .expose() + .ok() + .and_then(|v| v.parse::().ok()) + .unwrap_or(DEFAULT_MAX_RETRIES); + + let backoff_value = self + .params + .get("backoff_method") + .expose() + .ok() + .unwrap_or("exponential"); + let backoff = BackoffMethod::parse(backoff_value).unwrap_or_else(|message| { + tracing::warn!("{message}; falling back to 'exponential'."); + BackoffMethod::Exponential + }); + + let disable_on_permanent_error = self + .params + .get("disable_on_permanent_error") + .expose() + .ok() + .and_then(|v| v.parse::().ok()) + .unwrap_or(true); + + let semaphore = shared_semaphore(endpoint, max_concurrent_requests); + let disabled = shared_disabled_flag(endpoint); + + CosmosResilienceConfig { + max_retries, + backoff, + semaphore: Some(semaphore), + disable_on_permanent_error, + inflight: Arc::::clone(&self.inflight_operations), + disabled, + } + } +} + +/// Pure parsing helper for [`resolve_database_and_container`]. Split out so +/// it can be exercised in unit tests without constructing a full [`Dataset`]. +fn parse_database_and_container( + path: &str, + database_param: Option<&str>, +) -> Result<(String, String), String> { + // Accept either `database.container` or `database/container`, or just the + // container when `database` is explicitly set. + let (db_from_path, container) = if let Some((db, container)) = path.split_once('.') { + (Some(db.to_string()), container.to_string()) + } else if let Some((db, container)) = path.split_once('/') { + (Some(db.to_string()), container.to_string()) + } else { + (None, path.to_string()) + }; + + let database = match (database_param, db_from_path) { + (Some(d), _) => d.to_string(), + (None, Some(d)) => d, + (None, None) => { + return Err(format!( + "Could not determine Cosmos DB database from dataset path '{path}'. Expected 'database.container' or set the 'cosmosdb_database' parameter." + )); + } + }; + + if database.is_empty() { + return Err(format!( + "Could not determine Cosmos DB database from dataset path '{path}'. Expected 'database.container' or set the 'cosmosdb_database' parameter." + )); + } + + if container.is_empty() { + return Err(format!( + "Could not determine Cosmos DB container from dataset path '{path}'." + )); + } + + Ok((database, container)) +} + +/// Parse `database.container` / `database/container` from the dataset path. +/// If the configured `database` parameter is set, it overrides the database +/// segment and the path is treated as just the container name. +fn resolve_database_and_container( + dataset: &Dataset, + database_param: Option<&str>, +) -> Result<(String, String), DataConnectorError> { + parse_database_and_container(dataset.path(), database_param).map_err(|message| { + DataConnectorError::InvalidConfigurationNoSource { + dataconnector: CONNECTOR_NAME.to_string(), + connector_component: ConnectorComponent::from(dataset), + message, + } + }) +} + +#[async_trait] +impl DataConnector for CosmosDB { + fn as_any(&self) -> &dyn Any { + self + } + + async fn read_provider( + &self, + dataset: &Dataset, + ) -> Result, DataConnectorError> { + let credential = self.build_credential(dataset)?; + + let database_param = self.params.get("database").expose().ok(); + let (database, container) = resolve_database_and_container(dataset, database_param)?; + + let (container_client, endpoint) = + build_container_client(credential, &database, &container).map_err(|e| { + DataConnectorError::UnableToGetReadProvider { + dataconnector: CONNECTOR_NAME.to_string(), + connector_component: ConnectorComponent::from(dataset), + source: Box::new(e), + } + })?; + + let query = self + .params + .get("query") + .expose() + .ok() + .unwrap_or(DEFAULT_QUERY) + .to_string(); + + let schema_infer_max_records = match self + .params + .get("schema_infer_max_records") + .expose() + .ok() + { + Some(value) => match value.parse::() { + Ok(0) => { + tracing::warn!( + "Ignoring invalid schema_infer_max_records value '0' for dataset {}; using default value {}.", + dataset.name, + DEFAULT_SCHEMA_INFER_MAX_RECORDS + ); + DEFAULT_SCHEMA_INFER_MAX_RECORDS + } + Ok(v) => v, + Err(_) => { + tracing::warn!( + "Ignoring invalid schema_infer_max_records value '{}' for dataset {}; expected a positive integer, using default value {}.", + value, + dataset.name, + DEFAULT_SCHEMA_INFER_MAX_RECORDS + ); + DEFAULT_SCHEMA_INFER_MAX_RECORDS + } + }, + None => DEFAULT_SCHEMA_INFER_MAX_RECORDS, + }; + + let resilience = self.build_resilience(&endpoint); + + let mut config = CosmosDBTableProviderConfig::new(database, container, query) + .with_schema_infer_max_records(schema_infer_max_records) + .with_resilience(resilience); + + if let Some(action) = self.unsupported_type_action { + config = config.with_unsupported_type_action(action); + } + + let provider = CosmosDBTableProvider::try_new(container_client, endpoint, config) + .await + .map_err(|e| DataConnectorError::UnableToGetReadProvider { + dataconnector: CONNECTOR_NAME.to_string(), + connector_component: ConnectorComponent::from(dataset), + source: Box::new(e), + })?; + + Ok(Arc::new(provider)) + } + + fn metrics_provider(&self) -> Option> { + Some(Arc::new(CosmosDBMetricsProvider { + inflight_operations: Arc::::clone(&self.inflight_operations), + })) + } +} + +register_data_connector!("cosmosdb", CosmosDBFactory); + +#[cfg(test)] +mod tests { + use super::{parse_database_and_container, shared_disabled_flag, shared_semaphore}; + use std::sync::atomic::Ordering; + + #[test] + fn parses_dot_delimited_path() { + let (db, container) = parse_database_and_container("mydb.mycontainer", None) + .expect("dot-delimited path should parse"); + assert_eq!(db, "mydb"); + assert_eq!(container, "mycontainer"); + } + + #[test] + fn parses_slash_delimited_path() { + let (db, container) = parse_database_and_container("mydb/mycontainer", None) + .expect("slash-delimited path should parse"); + assert_eq!(db, "mydb"); + assert_eq!(container, "mycontainer"); + } + + #[test] + fn uses_database_param_when_path_is_container_only() { + let (db, container) = parse_database_and_container("mycontainer", Some("explicit_db")) + .expect("container-only path with explicit db should parse"); + assert_eq!(db, "explicit_db"); + assert_eq!(container, "mycontainer"); + } + + #[test] + fn database_param_overrides_path_segment() { + let (db, container) = + parse_database_and_container("path_db.mycontainer", Some("override_db")) + .expect("db param should override path segment"); + assert_eq!(db, "override_db"); + assert_eq!(container, "mycontainer"); + } + + #[test] + fn errors_when_no_database_can_be_determined() { + let err = parse_database_and_container("just_container", None) + .expect_err("missing db should be an error"); + assert!(err.contains("Could not determine Cosmos DB database")); + } + + #[test] + fn errors_on_empty_container_segment() { + let err = parse_database_and_container("mydb.", None) + .expect_err("empty container segment should be an error"); + assert!(err.contains("Could not determine Cosmos DB container")); + + let err = parse_database_and_container("mydb/", None) + .expect_err("empty container segment should be an error"); + assert!(err.contains("Could not determine Cosmos DB container")); + } + + #[test] + fn errors_on_empty_database_segment() { + let err = parse_database_and_container(".mycontainer", None) + .expect_err("empty database segment should be an error"); + assert!(err.contains("Could not determine Cosmos DB database")); + + let err = parse_database_and_container("/mycontainer", None) + .expect_err("empty database segment should be an error"); + assert!(err.contains("Could not determine Cosmos DB database")); + } + + #[test] + fn dot_takes_precedence_over_slash() { + // Documents current behavior: the first `.` wins even when a `/` is + // also present. Cosmos DB names do not legally contain `.`, so this + // mainly matters for malformed input. + let (db, container) = + parse_database_and_container("a/b.c", None).expect("dot takes precedence over slash"); + assert_eq!(db, "a/b"); + assert_eq!(container, "c"); + } + + #[test] + fn multiple_dots_split_at_first() { + let (db, container) = + parse_database_and_container("a.b.c", None).expect("multiple dots split at first"); + assert_eq!(db, "a"); + assert_eq!(container, "b.c"); + } + + #[test] + fn shared_semaphore_returns_same_instance_for_same_endpoint() { + // Use a unique endpoint per test to avoid cross-test interference + // through the process-wide `COSMOS_CONCURRENCY_LIMITS` map. + let endpoint = "https://shared-semaphore-same-endpoint.documents.azure.com:443/"; + let sem_a = shared_semaphore(endpoint, 4); + let sem_b = shared_semaphore(endpoint, 4); + assert!(std::sync::Arc::ptr_eq(&sem_a, &sem_b)); + } + + #[test] + fn shared_semaphore_keeps_first_seen_limit_on_mismatch() { + let endpoint = "https://shared-semaphore-mismatch.documents.azure.com:443/"; + let sem_a = shared_semaphore(endpoint, 4); + // A conflicting request should be resolved in favor of the first-seen + // limit rather than silently bumping or panicking. + let sem_b = shared_semaphore(endpoint, 16); + assert!(std::sync::Arc::ptr_eq(&sem_a, &sem_b)); + assert_eq!(sem_a.available_permits(), 4); + } + + #[test] + fn shared_disabled_flag_shares_state_across_lookups() { + let endpoint = "https://shared-disabled-flag.documents.azure.com:443/"; + let flag_a = shared_disabled_flag(endpoint); + let flag_b = shared_disabled_flag(endpoint); + assert!(std::sync::Arc::ptr_eq(&flag_a, &flag_b)); + flag_a.store(true, Ordering::SeqCst); + assert!(flag_b.load(Ordering::SeqCst)); + } +} diff --git a/crates/runtime/src/dataconnector/mod.rs b/crates/runtime/src/dataconnector/mod.rs index cd044faf5e..5e4fce7712 100644 --- a/crates/runtime/src/dataconnector/mod.rs +++ b/crates/runtime/src/dataconnector/mod.rs @@ -150,6 +150,8 @@ macro_rules! register_data_connector { pub mod abfs; #[cfg(feature = "adbc")] pub mod adbc; +#[cfg(feature = "cosmosdb")] +pub mod cosmosdb; #[cfg(feature = "debezium")] pub mod debezium; #[cfg(feature = "dynamodb")] diff --git a/crates/runtime/tests/cosmosdb/mod.rs b/crates/runtime/tests/cosmosdb/mod.rs new file mode 100644 index 0000000000..cb4a3d2b16 --- /dev/null +++ b/crates/runtime/tests/cosmosdb/mod.rs @@ -0,0 +1,226 @@ +/* +Copyright 2024-2026 The Spice.ai OSS Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +//! Azure Cosmos DB (`NoSQL`) connector integration tests. +//! +//! Tests that touch a live Cosmos account are marked `#[ignore]` and run with +//! `cargo test --features cosmosdb -- --ignored cosmosdb_live`. They read +//! credentials from the environment: +//! +//! * `COSMOSDB_CONNECTION_STRING` — full Azure connection string (preferred), OR +//! * `COSMOSDB_ACCOUNT_ENDPOINT` + `COSMOSDB_ACCOUNT_KEY` — discrete pieces. +//! * `COSMOSDB_INTEGRATION_DATABASE` (default `spice-integration`) +//! * `COSMOSDB_INTEGRATION_CONTAINER` (default `documents`) +//! +//! Tests that exercise only connector registration / parameter plumbing are +//! not ignored and run in CI. +//! +//! The Azure Cosmos emulator (`mcr.microsoft.com/cosmosdb/linux/azure-cosmos-emulator`) +//! is intentionally NOT used here: its 3+ GB image and 3–5 minute cold-start +//! exceeds the budgets of the shared runner and `docker/mod.rs` +//! `CONTAINER_SEMAPHORE`. A future on-demand CI job can add it behind a +//! `cosmosdb-emulator` feature flag. + +#![allow(dead_code, clippy::allow_attributes)] + +use std::collections::HashMap; +use std::env; +use std::sync::Arc; + +use app::AppBuilder; +use runtime::Runtime; +use spicepod::{component::dataset::Dataset, param::Params}; + +use crate::{configure_test_datafusion, init_tracing, utils::test_request_context}; + +const DEFAULT_DATABASE: &str = "spice-integration"; +const DEFAULT_CONTAINER: &str = "documents"; + +/// Credentials + destination for live Cosmos integration tests. `None` if the +/// required env vars are unset, which lets `#[ignore]`-gated tests skip +/// cleanly when run without real Cosmos. +struct LiveConfig { + params: HashMap, + database: String, + container: String, +} + +fn live_config_from_env() -> Option { + let database = + env::var("COSMOSDB_INTEGRATION_DATABASE").unwrap_or_else(|_| DEFAULT_DATABASE.to_string()); + let container = env::var("COSMOSDB_INTEGRATION_CONTAINER") + .unwrap_or_else(|_| DEFAULT_CONTAINER.to_string()); + + let mut params: HashMap = HashMap::new(); + if env::var("COSMOSDB_CONNECTION_STRING").is_ok() { + params.insert( + "cosmosdb_connection_string".to_string(), + "${ env:COSMOSDB_CONNECTION_STRING }".to_string(), + ); + } else if env::var("COSMOSDB_ACCOUNT_ENDPOINT").is_ok() + && env::var("COSMOSDB_ACCOUNT_KEY").is_ok() + { + params.insert( + "cosmosdb_account_endpoint".to_string(), + "${ env:COSMOSDB_ACCOUNT_ENDPOINT }".to_string(), + ); + params.insert( + "cosmosdb_account_key".to_string(), + "${ env:COSMOSDB_ACCOUNT_KEY }".to_string(), + ); + } else { + return None; + } + + Some(LiveConfig { + params, + database, + container, + }) +} + +fn make_live_dataset(name: &str, config: &LiveConfig) -> Dataset { + let from = format!("cosmosdb:{}.{}", config.database, config.container); + let mut dataset = Dataset::new(from, name.to_string()); + dataset.params = Some(Params::from_string_map(config.params.clone())); + dataset +} + +/// Smoke test: the Cosmos DB connector must be reachable via the runtime's +/// factory registry and accept its parameter spec. Offline — no HTTP call — +/// so it can run in CI without credentials. +#[tokio::test] +async fn cosmosdb_connector_factory_is_registered() -> Result<(), anyhow::Error> { + let _tracing = init_tracing(Some("integration=info,info")); + + test_request_context() + .scope(async { + // Building the Runtime forces the `linkme` distributed slice to + // evaluate, which registers every compiled-in connector including + // `cosmosdb` via `register_data_connector!`. If the cosmosdb crate + // fails to link or the factory panics during registration, this + // test surfaces it without needing live credentials. + configure_test_datafusion(); + let _rt = Runtime::builder() + .with_app(AppBuilder::new("cosmosdb_smoke").build()) + .build() + .await; + Ok::<_, anyhow::Error>(()) + }) + .await +} + +/// Live test: SELECT against a real Cosmos account. Requires the env vars +/// documented at the top of this module. Skipped by default. +#[tokio::test] +#[ignore = "requires live Cosmos credentials (COSMOSDB_CONNECTION_STRING or COSMOSDB_ACCOUNT_ENDPOINT+KEY)"] +async fn cosmosdb_live_select_returns_rows() -> Result<(), anyhow::Error> { + let _tracing = init_tracing(Some("integration=debug,info")); + + let Some(config) = live_config_from_env() else { + panic!( + "cosmosdb_live_select_returns_rows: set COSMOSDB_CONNECTION_STRING (or \ + COSMOSDB_ACCOUNT_ENDPOINT + COSMOSDB_ACCOUNT_KEY) to run this test." + ); + }; + + test_request_context() + .scope(async { + let dataset = make_live_dataset("cosmos_live", &config); + let app = AppBuilder::new("cosmosdb_live") + .with_dataset(dataset) + .build(); + + configure_test_datafusion(); + let rt = Runtime::builder().with_app(app).build().await; + let cloned_rt = Arc::new(rt.clone()); + + tokio::select! { + () = tokio::time::sleep(std::time::Duration::from_secs(60)) => { + return Err(anyhow::anyhow!("Timed out waiting for Cosmos DB dataset to load")); + } + () = cloned_rt.load_components() => {} + } + + // Issue a SELECT — the dataset must be queryable end-to-end. We + // don't snapshot the rows because the container contents are + // operator-controlled; verifying the query completes is enough. + let df = rt + .datafusion() + .ctx + .sql("SELECT COUNT(*) as n FROM cosmos_live") + .await?; + let _batches = df.collect().await?; + + Ok::<_, anyhow::Error>(()) + }) + .await +} + +/// Live test: the resilience layer must let a SELECT succeed even when the +/// underlying account is lightly-loaded. Running it repeatedly exercises the +/// shared per-endpoint concurrency budget. +#[tokio::test] +#[ignore = "requires live Cosmos credentials"] +async fn cosmosdb_live_repeated_queries_share_budget() -> Result<(), anyhow::Error> { + let _tracing = init_tracing(Some("integration=info,info")); + + let Some(config) = live_config_from_env() else { + panic!("set Cosmos DB credentials to run this test"); + }; + + test_request_context() + .scope(async { + let dataset = make_live_dataset("cosmos_live_rep", &config); + let app = AppBuilder::new("cosmosdb_live_rep") + .with_dataset(dataset) + .build(); + + configure_test_datafusion(); + let rt = Runtime::builder().with_app(app).build().await; + let cloned_rt = Arc::new(rt.clone()); + + tokio::select! { + () = tokio::time::sleep(std::time::Duration::from_secs(60)) => { + return Err(anyhow::anyhow!("Timed out waiting for Cosmos DB dataset to load")); + } + () = cloned_rt.load_components() => {} + } + + // Three concurrent scans — the per-account semaphore limits + // in-flight operations to `max_concurrent_requests` (default 4), + // so this should always complete without error. + let mut handles = Vec::new(); + for _ in 0..3 { + let rt_clone = rt.clone(); + handles.push(tokio::spawn(async move { + let df = rt_clone + .datafusion() + .ctx + .sql("SELECT 1 FROM cosmos_live_rep LIMIT 1") + .await?; + let _ = df.collect().await?; + Ok::<_, anyhow::Error>(()) + })); + } + for handle in handles { + handle.await??; + } + + Ok::<_, anyhow::Error>(()) + }) + .await +} diff --git a/crates/runtime/tests/integration.rs b/crates/runtime/tests/integration.rs index cde1c4d836..d26c9d92ec 100644 --- a/crates/runtime/tests/integration.rs +++ b/crates/runtime/tests/integration.rs @@ -37,6 +37,8 @@ mod cayenne_catalog_ddl; mod clickbench; mod cluster; mod cors; +#[cfg(feature = "cosmosdb")] +mod cosmosdb; #[cfg(all(feature = "delta_lake", feature = "databricks"))] mod databricks_delta; #[cfg(all(feature = "delta_lake", feature = "databricks"))] diff --git a/docs/criteria/connectors/alpha.md b/docs/criteria/connectors/alpha.md index 3cdf9529a9..6fe3da28d9 100644 --- a/docs/criteria/connectors/alpha.md +++ b/docs/criteria/connectors/alpha.md @@ -11,6 +11,7 @@ All criteria must be met for the connector to be considered Alpha. As Alpha sign | ADBC | ➖ | | | Azure BlobFS | ➖ | | | Clickhouse | ➖ | | +| Cosmos DB (NoSQL) | ✅ | @lukekim | | Databricks (mode: delta_lake) | ✅ | @Sevenannn | | Databricks (mode: spark_connect) | ✅ | @Sevenannn | | Databricks (mode: sql_warehouse) | ➖ | | diff --git a/docs/criteria/connectors/beta.md b/docs/criteria/connectors/beta.md index 68d5c6d299..ce99830822 100644 --- a/docs/criteria/connectors/beta.md +++ b/docs/criteria/connectors/beta.md @@ -11,6 +11,7 @@ All criteria must be met for the connector to be considered Beta, with exception | ADBC | ➖ | | | Azure BlobFS | ➖ | | | Clickhouse | ➖ | | +| Cosmos DB (NoSQL) | ✅ | @lukekim | | Databricks (mode: delta_lake) | ✅ | @Sevenannn | | Databricks (mode: spark_connect) | ✅ | @Sevenannn | | Databricks (mode: sql_warehouse) | ➖ | | @@ -62,6 +63,7 @@ This table defines the required features and/or tests for each connector: | ADBC | ➖ | ➖ | ➖ | ➖ | ➖ | ☑️ | | Azure BlobFS | ✅ (1) | ➖ | ➖ | ➖ | ✅ | ☑️ | | Clickhouse | ✅ (100) | ➖ | ☑️ | ➖ | ✅ | ➖ | +| Cosmos DB (NoSQL) | ➖ | ➖ | ➖ | ➖ | ➖ | ☑️ | | Databricks (mode: delta_lake) | ✅ (1) | ➖ | ☑️ | ➖ | ✅ | ➖ | | Databricks (mode: spark_connect) | ✅ (100) | ➖ | ☑️ | ➖ | ➖ | ➖ | | Databricks (mode: sql_warehouse) | ➖ | ➖ | ☑️ | ➖ | ✅ | ➖ | diff --git a/docs/criteria/connectors/rc.md b/docs/criteria/connectors/rc.md index fd1162bc89..2fbcedfbdd 100644 --- a/docs/criteria/connectors/rc.md +++ b/docs/criteria/connectors/rc.md @@ -11,6 +11,7 @@ All criteria must be met for the connector to be considered [RC](../definitions. | ADBC | ➖ | | | Azure BlobFS | ➖ | | | Clickhouse | ➖ | | +| Cosmos DB (NoSQL) | ✅ | @lukekim | | Databricks (mode: delta_lake) | ✅ | @Sevenannn | | Databricks (mode: spark_connect) | ➖ | | | Databricks (mode: sql_warehouse) | ➖ | | @@ -62,6 +63,7 @@ This table defines the required features and/or tests for each connector: | ADBC | ➖ | ➖ | ☑️ | ➖ | ➖ | ☑️ | | Azure BlobFS | ✅ (1) | ✅ (1) | ☑️ | ➖ | ✅ | ☑️ | | Clickhouse | ✅ (100) | ✅ (100) | ✅ | ✅ | ✅ | ✅ | +| Cosmos DB (NoSQL) | ➖ | ➖ | ➖ | ➖ | ➖ | ☑️ | | Databricks (mode: delta_lake) | ✅ (1) | ✅ (1) | ☑️ | ✅ | ✅ | ✅ | | Databricks (mode: spark_connect) | ✅ (100) | ✅ (100) | ✅ | ✅ | ✅ | ✅ | | Databricks (mode: sql_warehouse) | ➖ | ➖ | ✅ | ✅ | ✅ | ✅ | diff --git a/docs/dev/cosmosdb.md b/docs/dev/cosmosdb.md new file mode 100644 index 0000000000..77047b1e8e --- /dev/null +++ b/docs/dev/cosmosdb.md @@ -0,0 +1,162 @@ +# Azure Cosmos DB (NoSQL / Core SQL) Data Connector + +Status: **RC** — read-only scan with RC-level connection resilience. + +## Configuration + +```yaml +datasets: + - from: cosmosdb:mydb.mycontainer + name: my_table + params: + # Option A — connection string (takes precedence) + cosmosdb_connection_string: ${secrets:cosmosdb_conn} + + # Option B — explicit endpoint + key + cosmosdb_account_endpoint: https://my-account.documents.azure.com:443/ + cosmosdb_account_key: ${secrets:cosmosdb_key} + + # Optional: override database (otherwise taken from `from:` path) + cosmosdb_database: mydb + # Optional: custom Cosmos SQL query (defaults to `SELECT * FROM c`) + query: SELECT * FROM c + # Optional: sample size for schema inference (default 100) + schema_infer_max_records: "100" + + # Optional resilience tuning (defaults shown) + max_concurrent_requests: "4" + http_max_retries: "3" + backoff_method: exponential # or "fibonacci" + disable_on_permanent_error: "true" +``` + +The dataset path accepts `database.container`, `database/container`, or just +`container` when `cosmosdb_database` is set explicitly. + +## Authentication + +Key-based authentication only, via either a full Cosmos DB connection string +or an explicit `AccountEndpoint` + `AccountKey` pair. Microsoft Entra ID / +managed identity support is tracked as a post-RC enhancement. + +## What's supported + +- Read-only (`SELECT`) scans via Cosmos SQL. +- Cross-partition query by default. +- Arrow schema inferred from a sample of documents (system fields + `_rid`, `_self`, `_etag`, `_attachments`, `_ts` are stripped). Schema + pinning is not currently supported — widen `schema_infer_max_records` + instead to stabilize inference when optional fields are sparse. +- Standard Spice acceleration (DuckDB / SQLite / Arrow in-memory) on top of + the connector. +- Connection resilience: per-account concurrency semaphore, bounded retries + with configurable backoff, `Retry-After` / `x-ms-retry-after-ms` handling, + permanent-error (401/403/404) detection that latches the connector disabled. +- `inflight_operations` metric gauge, exported via the runtime metrics + endpoint for dashboards. This gauge is scoped per dataset connector + instance, not as an account-wide aggregate across all datasets using the + same Cosmos account endpoint. +- `unsupported_type_action` plumbing — all-null sampled fields (inferred as + `DataType::Null`) are warn-and-dropped by default. + +## JSON → Arrow type mapping + +Cosmos stores documents as JSON. The connector samples up to +`schema_infer_max_records` documents and hands them to Arrow's JSON inference: + +| Cosmos / JSON value | Arrow data type | Notes | +| --------------------------- | --------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------- | +| `"abc"` | `Utf8` | | +| integer (`42`, `-7`, ...) | `Int64` | JSON numbers without fractional part infer as `Int64`; widens to `Float64` if any sampled doc contains a decimal. | +| floating (`3.14`, `1.0e9`) | `Float64` | | +| `true` / `false` | `Boolean` | | +| object `{ ... }` | `Struct` | Nested objects are preserved as structs. | +| array `[ ... ]` | `List` | The element type is inferred from the first non-null item; heterogeneous arrays may surface as `Utf8` or require a wider sample to disambiguate. | +| all-null in sample | `Null` | Warn-dropped by default (`unsupported_type_action=warn`). Set `unsupported_type_action=string` to coerce to `Utf8`, or widen the sample so real values appear. | +| System fields (`_rid`, ...) | stripped | Never appear in the dataset schema. | + +Cosmos does not emit `Date`, `Time`, `Timestamp`, `Decimal`, or `Binary` +natively — they round-trip as strings and should be handled with `CAST` at +query time. + +## RC exceptions + +Per `docs/criteria/connectors/rc.md` row 66, the following tracks are +intentionally out-of-scope for this connector's RC: + +| Criterion | Status | Reason | +| ----------------- | ------ | ------------------------------------------------------------------------------------------- | +| TPC-H / TPC-DS | ➖ | Cosmos DB's SQL surface does not cover TPC workloads; exempt per the per-connector matrix. | +| Federation | ➖ | Cosmos SQL does not support joins across containers; no filter or projection push-down yet. | +| Data Correctness | ➖ | No TPC harness, so no correctness diff against a native CLI. | +| Streaming | ➖ | No change-feed support yet; `RefreshMode::Changes` is not wired. | +| Schema Inference | ☑️ | Inferred from a sample of documents — Cosmos DB has no native schema. | + +## What's not yet supported (post-RC tracking) + +- Filter / projection / limit push-down into Cosmos DB. +- Write (`INSERT` / `UPDATE` / `DELETE`). +- Change feed streaming (`RefreshMode::Changes`). +- Microsoft Entra ID / managed identity authentication. +- Fine-grained partition-key routing. + +## Resilience parameters + +These parameters satisfy the "Connection Resilience" section of +`docs/criteria/connectors/rc.md`. + +| Parameter | Default | Description | +| ---------------------------- | ------------- | ------------------------------------------------------------------------------------------------------- | +| `max_concurrent_requests` | `4` | Upper bound on in-flight requests per account endpoint. Shared across datasets targeting the same account. | +| `http_max_retries` | `3` | Maximum retries for transient errors (429, 5xx, network). | +| `backoff_method` | `exponential` | Backoff strategy: `exponential` (500ms × 2ⁿ, capped 30s) or `fibonacci` (500ms × Fₙ, capped 30s). | +| `disable_on_permanent_error` | `true` | Latch the connector disabled on 401/403/404 to avoid a thundering herd of failed requests. | + +Retries honor both the standard `Retry-After` header and the Cosmos-specific +`x-ms-retry-after-ms` header. The effective delay is `max(retry_after, backoff)`. + +**Retry scope:** `http_max_retries` / `backoff_method` apply to the schema +inference pass that runs at dataset registration. Errors surfaced *during* a +streaming scan propagate immediately to the caller — a `FeedPager` cannot be +safely rewound once rows have been emitted, so mid-stream retry would risk +duplicating output. Spice's dataset refresh layer handles retry at the query +boundary. The permanent-error latch (`disable_on_permanent_error`) still +applies on both paths, so a 401/403/404 from any request disables the +connector account-wide. + +The `inflight_operations` metric is automatically registered and reports the +current number of Cosmos requests holding a concurrency permit. + +## Integration tests + +Unit-level coverage lives in `crates/data_components/src/cosmosdb/` (32 tests +at time of RC) and `crates/runtime/src/dataconnector/cosmosdb.rs`. + +End-to-end tests against a live Cosmos account live at +`crates/runtime/tests/cosmosdb/`. Live tests are `#[ignore]`'d by default; +set `COSMOSDB_CONNECTION_STRING` (or `COSMOSDB_ACCOUNT_ENDPOINT` + `COSMOSDB_ACCOUNT_KEY`), +optionally `COSMOSDB_INTEGRATION_DATABASE` / `COSMOSDB_INTEGRATION_CONTAINER`, +then run: + +```bash +cargo test --features cosmosdb -p runtime --test integration -- --ignored cosmosdb_live +``` + +The Azure Cosmos emulator image (`mcr.microsoft.com/cosmosdb/linux/azure-cosmos-emulator`) +is not used in CI — its 3+ GB size and 3–5 minute cold-start exceeds the +shared runner's budgets. A future `cosmosdb-emulator` feature flag can add +on-demand emulator tests. + +## Feature flag + +Built into the default `spiced` distribution; also available as the +`cosmosdb` Cargo feature for custom builds: + +```bash +SPICED_CUSTOM_FEATURES="cosmosdb" make build-runtime +``` + +## Cookbook recipe + +See [`examples/cosmosdb-connector/`](../../examples/cosmosdb-connector/) for +a copy-pasteable Spicepod that connects to Cosmos DB. diff --git a/examples/cosmosdb-connector/README.md b/examples/cosmosdb-connector/README.md new file mode 100644 index 0000000000..fe6510c9de --- /dev/null +++ b/examples/cosmosdb-connector/README.md @@ -0,0 +1,142 @@ +# Azure Cosmos DB Connector Example + +This example demonstrates the Azure Cosmos DB (NoSQL) data connector, which +lets you query Cosmos DB containers as Spice datasets. + +## Prerequisites + +- Spice CLI installed (`spice` command available). +- An Azure Cosmos DB account (NoSQL / Core SQL API). The [free tier][free] + is sufficient for trying this example. +- A container with some documents to query — see "Seeding sample data" below + if you need to populate one. + +[free]: https://learn.microsoft.com/azure/cosmos-db/try-free + +## Running the example + +1. Export your Cosmos DB connection string (copy it from the Azure portal, + "Keys" blade under your account): + + ```bash + export COSMOSDB_CONNECTION_STRING="AccountEndpoint=https://.documents.azure.com:443/;AccountKey=;" + ``` + +2. Navigate to this directory: + + ```bash + cd examples/cosmosdb-connector + ``` + +3. Edit `spicepod.yaml` to point the `from:` at your `database.container`. + +4. Start Spice: + + ```bash + spice run + ``` + +5. In another terminal, connect to the SQL REPL and try a query: + + ```bash + spice sql + ``` + + ```sql + SELECT COUNT(*) FROM products; + SELECT * FROM products LIMIT 5; + ``` + +## Datasets + +`spicepod.yaml` declares three datasets that cover the common use cases: + +### `products` + +Full scan over a single container using the default query +(`SELECT * FROM c`). Schema is inferred from the first 100 documents. + +### `active_orders` + +Custom Cosmos SQL query with a WHERE clause. Useful when the container is +large and you only want a subset surfaced as a dataset. + +### `products_pinned` + +Dataset with an explicit Arrow schema via `columns:`. Use this pattern when +the sample-based inference disagrees with your production schema (e.g. an +optional field that is null in the first N documents but not in general). + +## Parameters used in this example + +### Authentication + +- `cosmosdb_connection_string`: Full Azure connection string (takes + precedence over explicit endpoint + key). Recommended for local dev. +- `cosmosdb_account_endpoint` + `cosmosdb_account_key`: Discrete pieces, + useful when the endpoint and key are stored separately (e.g. Key Vault). + +### Data shape + +- `cosmosdb_database`: Overrides the database name parsed from the `from:` + path. Leave unset when the path is `database.container`. +- `query`: Custom Cosmos SQL query. Defaults to `SELECT * FROM c`. +- `schema_infer_max_records`: Number of documents sampled during schema + inference. Default `100`. Larger samples produce a more precise schema at + the cost of additional RU consumption on dataset registration. + +### Resilience tuning + +- `max_concurrent_requests`: Per-account concurrency budget. Default `4`. +- `http_max_retries`: Retries for transient errors. Default `3`. +- `backoff_method`: `exponential` (default) or `fibonacci`. +- `disable_on_permanent_error`: Default `true`. Latches the connector + disabled on 401/403/404 to prevent a thundering herd of failed requests. + Set to `"false"` during development if you'd rather see every failure. + +## Seeding sample data + +If you need a container to query, the snippet below creates the `products` +container used in `spicepod.yaml` and populates it with a few rows. Requires +the Azure CLI and the `az cosmosdb sql` extension. + +```bash +# Replace with your account + resource group +ACCOUNT=your-cosmos-account +RG=your-resource-group +DB=store +CONTAINER=products + +az cosmosdb sql database create --account-name "$ACCOUNT" --resource-group "$RG" --name "$DB" +az cosmosdb sql container create --account-name "$ACCOUNT" --resource-group "$RG" \ + --database-name "$DB" --name "$CONTAINER" --partition-key-path /id + +# Insert a few documents via the Data Explorer, the REST API, or the SDK of your choice. +``` + +## Troubleshooting + +### Authentication failed (401 / 403) + +The connector latches disabled after a 401/403/404. Fix the credentials or +grants in `spicepod.yaml`, then restart `spice run`. Set +`disable_on_permanent_error: "false"` only if you want the connector to +keep retrying every failure. + +### `EmptyContainer` error on dataset load + +Schema is inferred from the sample. If your query returns zero documents at +load time, the connector cannot produce a schema. Either populate the +container, widen the `query`, or pin a schema via `columns:`. + +### High RU consumption on load + +Each dataset registration samples up to `schema_infer_max_records` +documents. Lower that value — or pin a schema — if you want to avoid the +upfront RU cost. + +## Learn more + +- [Cosmos DB Connector Reference](../../docs/dev/cosmosdb.md) +- [RC Release Criteria](../../docs/criteria/connectors/rc.md) +- [Spice Documentation](https://docs.spice.ai) diff --git a/examples/cosmosdb-connector/queries.sql b/examples/cosmosdb-connector/queries.sql new file mode 100644 index 0000000000..48c76ec904 --- /dev/null +++ b/examples/cosmosdb-connector/queries.sql @@ -0,0 +1,35 @@ +-- Azure Cosmos DB connector: example queries for the spicepod in this directory. +-- Replace column names to match your container's schema. + +-- Count rows across the full container. +SELECT COUNT(*) AS total FROM products; + +-- Preview the first few documents. +SELECT * FROM products LIMIT 10; + +-- Project a single column across all rows. +SELECT id, name, price FROM products ORDER BY price DESC LIMIT 10; + +-- Simple aggregation. +SELECT + category, + COUNT(*) AS count, + AVG(price) AS avg_price, + MAX(price) AS max_price +FROM products +GROUP BY category +ORDER BY count DESC; + +-- Custom-query dataset. +SELECT COUNT(*) AS active_orders FROM active_orders; + +-- Join across two Cosmos-backed datasets. Spice federates the join in the +-- local DataFusion engine — Cosmos DB itself does not support joins across +-- containers. +SELECT + o.id AS order_id, + p.name AS product_name, + p.price AS unit_price +FROM active_orders o +JOIN products p ON o.product_id = p.id +LIMIT 50; diff --git a/examples/cosmosdb-connector/spicepod.yaml b/examples/cosmosdb-connector/spicepod.yaml new file mode 100644 index 0000000000..4a3331160d --- /dev/null +++ b/examples/cosmosdb-connector/spicepod.yaml @@ -0,0 +1,38 @@ +version: v2 +kind: Spicepod +name: cosmosdb-connector-example + +datasets: + # Full-scan dataset using the default query (`SELECT * FROM c`). + # Schema is inferred from the first 100 documents. + - from: cosmosdb:store.products + name: products + description: All products, schema inferred from the first 100 documents. + params: + cosmosdb_connection_string: ${ env:COSMOSDB_CONNECTION_STRING } + + # Custom Cosmos SQL query — useful when the container is large and only a + # subset should be surfaced as a dataset. + - from: cosmosdb:store.orders + name: active_orders + description: Orders whose status is 'active'. + params: + cosmosdb_connection_string: ${ env:COSMOSDB_CONNECTION_STRING } + query: "SELECT * FROM c WHERE c.status = 'active'" + # Sample more documents for a more stable schema inference. + schema_infer_max_records: "500" + + # Resilience-tuned dataset — raises the per-account concurrency budget + # (default 4) and widens the retry envelope for datasets expected to run + # under heavy load. + - from: cosmosdb:store.products + name: products_heavy_load + description: Products, tuned for heavy-load query patterns. + params: + cosmosdb_connection_string: ${ env:COSMOSDB_CONNECTION_STRING } + max_concurrent_requests: "8" + http_max_retries: "5" + backoff_method: fibonacci + # Widen the sample to stabilize inference when optional fields are null + # in the first 100 documents but not in general. + schema_infer_max_records: "500" From 924691c85f41022efe26577c167d55a535e6dc0a Mon Sep 17 00:00:00 2001 From: Luke Kim <80174+lukekim@users.noreply.github.com> Date: Mon, 20 Apr 2026 21:45:55 -0700 Subject: [PATCH 2/4] feat(datafusion): flatten_json_properties + json_tree UDTFs (#10406) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * feat(datafusion): add flatten_json_properties and json_tree UDTFs (M1) M1 skeleton of the `flatten_json_properties` table function from #10399 — recursively walks a JSON-Schema-shaped document's `properties` tree and emits one row per field with path, parent_path, name, description, type, required, format, enum_values, and metadata columns. Also adds `json_tree`, a schema-agnostic recursive JSON walker modeled on DuckDB/SQLite's function of the same name (cols: key, value, type, atom, id, parent, fullkey, path) so users have a generic alternative when their input isn't JSON-Schema-shaped. Both are experimental and gated behind `flatten-json-properties` and `json-tree` Cargo features (off by default). M1 accepts only literal JSON string arguments; per-row LATERAL invocation with a column reference lands in M2 alongside `$ref` / `allOf` / `oneOf` / `anyOf` resolution, `items.properties`, `additionalProperties` maps, the options struct, cycle detection, and metrics. Refs #10399 * feat(datafusion): complete M2-M4 for flatten_json_properties + json_tree M1 shipped a `properties`-only skeleton behind a feature flag. This commit lands the rest of the milestones for both functions. M2 — Full shape coverage: - `items.properties` — arrays of objects; leaves emit at `array.field`. - `additionalProperties` — typed maps; `type = "map"` and children at `map.child`. - `allOf` / `oneOf` / `anyOf` — fields merged across branches with first- declaration dedupe; `required` is union across branches. - Local `$ref` resolution (JSON Pointer syntax, including `~0` / `~1` escapes) with an active-ref set for cycle detection — cycles yield a `kind=cycle` metric, no stack overflow. - External `$ref` URIs — surfaced as `type = "ref"` rows with the URI captured in `metadata`. Never dereferenced (no network / file IO). - Options surface (named args on both UDTF and planning path): `max_depth`, `max_rows`, `max_bytes`, `dialect`, `include_internal`, `path_style` (`dot` or `json-pointer`). - OpenTelemetry counters: `flatten_json_properties_invocations_total`, `_rows_emitted_total`, `_errors_total{kind}` where kind ∈ {parse, depth_exceeded, row_cap_hit, cycle, input_too_large}. Same set for `json_tree` (with applicable kinds). - Scalar UDF companion registered under the same name, returning `List>` — gives per-row / LATERAL semantics via `UNNEST(flatten_json_properties(s.body))`. `json_tree` brought to parity: max_depth / max_bytes options, scalar UDF variant, cycle-independent depth cap, metrics. M3 — UX + perf: - Cookbook recipe at `examples/flatten-json-properties/` with a worked spicepod.yaml (dataset → view via UNNEST → DuckDB acceleration → column-level embeddings → vector_search) plus a 3-document sample. - Bench harness at `crates/runtime/benches/flatten_json_properties.rs` with Criterion groups for flat-schema fan-out, nested depth, and a 1k-schema catalog simulation. M4 — Release decision: - Feature flags dropped. Both UDTFs + UDFs register unconditionally on every build. Default behavior change vs M1: `include_internal` is now `false` (spec default), so container rows (`object` / `array` / `map`) are suppressed unless the caller opts in. 32 unit tests covering the full shape matrix, ref resolution, cycle termination, option parsing, limit tripping, path-style variants, scalar UDF per-row dispatch with NULLs, and UDTF plan integration. Refs #10399 * refactor(datafusion/udtf): simplify walker per /simplify review - Replace hand-rolled `resolve_local_ref` with `serde_json::Value::pointer`. - Delete `collect_effective_owned` and the `Cow<'static>` lifetime-laundering dance; everything walked lives under the walker's `&'a Value` root, so `&'a Value` suffices. Removes two identical recursion paths and the deep target clone on every `$ref` resolution. - Drop the dead `depth` parameter from `collect_effective`. - Hoist `property_fields` / `tree_fields` into static `LazyLock` handles so the schema isn't reallocated on every call. - Extract `build_tree_arrays` in `json_tree` so `rows_to_batch` and the scalar-UDF struct-array builder share one implementation. - Borrow-not-clone for `HashSet<&str>` required / seen_names in the walker. - Strip WHAT-style comments and task-references from the bench. * fix(datafusion/udtf): address PR review feedback - Update copyright headers to 2024-2026 across the new UDTF files. - Tighten scalar UDF signatures (`flatten_json_properties` / `json_tree`) to accept Utf8 / LargeUtf8 / Utf8View; normalize via `cast` so non-Utf8 string columns no longer panic in `as_string_array`. - Cap combinator / `$ref` expansion in `collect_effective` by threading a ref-depth counter through recursion; prevents pathological chains from bypassing `max_depth` / exhausting the stack. - Clarify `dialect` option semantics in docs: currently only tags invocation metrics; OpenAPI-specific walker behavior is future scope. - `compute_type` no longer treats non-object `properties` / non-object-or- array `items` as `object` / `array`. - Collapse duplicate-row emission in `handle_field`: recurse once on the original `spec` so `walk_schema`'s `seen_names` de-duplicates fields across allOf/oneOf/anyOf / `$ref` branches. - Document single-node-only scan for both UDTFs (cluster mode requires a `UdtfArgs` proto variant + codec, tracked as follow-up). - Fix three branch-local clippy `collapsible_if` errors and annotate `emit_row`'s argument count. * fix(datafusion/udtf): address second round of PR review - `json_tree`: root row now emits `path = NULL` (field is nullable) to match DuckDB / SQLite `json_tree` semantics; children still carry the parent fullkey as `path`. - `json_tree`: array element rows now set `key = idx.to_string()` so consumers can distinguish array siblings (previously NULL). - `flatten_json_properties`: container fields with no walkable children (array of primitives, map of primitives, empty object) are now emitted as leaf rows in `include_internal = false` mode, so the field still appears in output. - Deny `flatten_json_properties` / `json_tree` scalar UDFs for federation pushdown; add them to the existing `deny_list_blocks_spice_builtins` test so regressions are caught. README double-pipe comment was a false positive (the file already uses single `|` with `\|` escapes inside cells). * fix(datafusion/udtf): address round-3 PR review - `json_tree`: add `max_rows` option (default 1,000,000) so bounded `max_bytes` input can't explode into unbounded row counts. Walker records `row_cap_hit` metric when hit and truncates cleanly. - `json_tree`: clarify module-level docs — named options are UDTF-form only; the scalar UDF takes just the JSON argument with default caps. - Both scalar UDFs now truncate deterministically at `i32::MAX` flattened rows (with a `row_cap_hit` metric) instead of returning a query-level `Execution` error on `List` offset overflow. Preserves the "never a query-level error" contract. Not addressed: re-raised comments on `DataSourceExec` / cluster-mode `UdtfExec` wrapping — documented as follow-up scope in the prior commit; wrapping requires a new `UdtfArgs` proto variant + codec. * style: cargo fmt line-wrap in flatten_json_properties scalar UDF * fix(datafusion/udtf): bracket-quote JSON-path keys with hyphens SQLite / DuckDB `json_tree` path shorthand only accepts identifier-style keys; anything else must be bracket-quoted so consumers can re-parse the `fullkey`. Previously a key like `has-hyphen` was rendered as `$.a-b`, which isn't a valid shorthand. Now forces bracket-quoting for keys with any non-identifier character, and extends the existing special-character test to cover hyphens. * fix(datafusion/udtf): switch scalar UDFs to LargeList> Copilot flagged that i32 ListArray offsets could silently truncate results when the flattened row count across a batch exceeds i32::MAX (only a metric signal was emitted). Silent incomplete results risk query correctness. Switching to LargeList (i64 offsets) makes overflow effectively impossible with no behavior change — UNNEST works transparently on both variants. Drops the `max_flattened_rows` truncation path entirely. * style(datafusion/udtf): fix pedantic clippy + fmt errors CI's `make lint-rust` uses `clippy::pedantic + clippy::allow_attributes + clippy::unwrap_used + clippy::expect_used`, which surfaced: - `#[allow(clippy::too_many_arguments)]` → `#[expect(...)]` with reason (lint 1.81+ requires explicit expect for cleared warnings). - `doc_markdown`: backtick-wrap `UInt`, `Bool`, `Utf8`, numeric defaults, `DuckDB`, `SQLite`, `DoS`, `DataFusion`, `OpenAPI` in module docs. - `single_match_else` + `match_like_matches_macro`: rewrite the `serde_json::from_str` match as `let Ok(root) = ... else { ... }`. - `.unwrap()` on `key.chars().next()` in `escape_object_key` → `is_some_and`. - `name.to_string()` on `&String` → `name.clone()`. - `all_rows.len() as i64` → `i64::try_from(...).unwrap_or(i64::MAX)` (walker caps bound the count well under i64::MAX; saturate instead of unwrap since the lint config bans `.unwrap()`/`.expect()`). * fix(datafusion/udtf): type-union ordering + fail-loud on offset overflow - `compute_type`: when `"type"` is an array (JSON-Schema nullable syntax, e.g. `["null", "string"]`), pick the first non-null entry so optional fields classify as their real type instead of `"null"`. Falls back to `"null"` only when it's the sole type. Extended test coverage. - Both scalar UDFs: `i64::try_from(row_count)` now returns a `DataFusionError::Execution` on overflow instead of saturating to `i64::MAX`. Saturation would silently misalign `LargeList` offsets; erroring surfaces the (unreachable-in-practice) condition loudly. * fix(datafusion/udtf): cross-walk cycle detection + batch row cap - `walk_schema` now persists `$ref` insertion in `visited_refs` for the duration of the tree-walk recursion, not just for a single `collect_effective` pass. Fixes a leak where schemas like `{$defs: {Node: {properties: {next: {$ref: #/$defs/Node}}}}, properties: {root: {$ref: #/$defs/Node}}}` could descend past the first resolution boundary. Tightened `local_ref_cycle_terminates` to assert stopping at `root.next`. - Both scalar UDFs now error on `DataFusionError::Execution` if the accumulated cross-batch row count exceeds `SCALAR_BATCH_MAX_ROWS` (10M). Per-document caps bound single rows, but a wide batch could previously reach `number_rows * max_rows` in memory before returning. * fix(udtf): pass projection to MemorySourceConfig in json_properties and json_tree Both UDTFs were ignoring the projection parameter in scan(), causing a schema mismatch error when selecting specific columns (e.g. SELECT path, name, type FROM flatten_json_properties(...)). Pass projection.cloned() to MemorySourceConfig::try_new() so DataFusion can push column pruning down into the scan. * fix: format MemorySourceConfig initialization for better readability * Tests + Lint * fix(tests): improve error handling and assertions in JSON property tests * fix(tests): update projection comments for clarity in JSON schema tests --------- Co-authored-by: Viktor Yershov --- crates/runtime/Cargo.toml | 4 + .../benches/flatten_json_properties.rs | 98 ++ crates/runtime/src/datafusion/mod.rs | 1 + crates/runtime/src/datafusion/udf.rs | 19 + .../src/datafusion/udtf/json_properties.rs | 1533 +++++++++++++++++ .../runtime/src/datafusion/udtf/json_tree.rs | 819 +++++++++ crates/runtime/src/datafusion/udtf/mod.rs | 18 + examples/flatten-json-properties/README.md | 172 ++ .../sample_schemas.json | 78 + 9 files changed, 2742 insertions(+) create mode 100644 crates/runtime/benches/flatten_json_properties.rs create mode 100644 crates/runtime/src/datafusion/udtf/json_properties.rs create mode 100644 crates/runtime/src/datafusion/udtf/json_tree.rs create mode 100644 crates/runtime/src/datafusion/udtf/mod.rs create mode 100644 examples/flatten-json-properties/README.md create mode 100644 examples/flatten-json-properties/sample_schemas.json diff --git a/crates/runtime/Cargo.toml b/crates/runtime/Cargo.toml index a55d092e87..5566bb3f5b 100644 --- a/crates/runtime/Cargo.toml +++ b/crates/runtime/Cargo.toml @@ -374,3 +374,7 @@ vortex-datafusion.workspace = true [[bench]] harness = false name = "prepared_statement" + +[[bench]] +harness = false +name = "flatten_json_properties" diff --git a/crates/runtime/benches/flatten_json_properties.rs b/crates/runtime/benches/flatten_json_properties.rs new file mode 100644 index 0000000000..1790ded828 --- /dev/null +++ b/crates/runtime/benches/flatten_json_properties.rs @@ -0,0 +1,98 @@ +#![allow(clippy::expect_used)] + +//! Benchmarks for `flatten_json_properties`. +//! +//! Exercises the walker in isolation (no `DataFusion` plumbing) so regressions +//! attributable to the walker itself surface without noise from query planning +//! or Arrow I/O. `bench_catalog_simulation` approximates the typical +//! materialization shape — 1k schemas × 50 fields per schema. + +use std::hint::black_box; + +use criterion::{BenchmarkId, Criterion, Throughput, criterion_group, criterion_main}; +use runtime::datafusion::udtf::json_properties::{FlattenOptions, flatten_with_options}; + +fn synthetic_schema(num_fields: usize) -> String { + // One flat object with `num_fields` primitive properties. Representative of + // a wide data-product schema where most fields are leaves. + let mut props = String::from("{"); + for i in 0..num_fields { + if i > 0 { + props.push(','); + } + props.push_str(&format!( + r#""field_{i}":{{"type":"string","description":"Field {i}","format":"text"}}"# + )); + } + props.push('}'); + format!(r#"{{"properties":{props}}}"#) +} + +fn nested_schema(depth: usize) -> String { + // Deeply nested single-chain schema. Exercises the recursion path. + let mut inner = String::from(r#"{"type":"string"}"#); + for _ in 0..depth { + inner = format!(r#"{{"type":"object","properties":{{"n":{inner}}}}}"#); + } + format!(r#"{{"properties":{{"root":{inner}}}}}"#) +} + +fn bench_flat_schemas(c: &mut Criterion) { + let opts = FlattenOptions { + include_internal: true, + ..FlattenOptions::default() + }; + let mut group = c.benchmark_group("flatten_json_properties/flat"); + for fields in [16usize, 128, 512] { + let doc = synthetic_schema(fields); + group.throughput(Throughput::Elements(fields as u64)); + group.bench_with_input(BenchmarkId::new("fields", fields), &doc, |b, doc| { + b.iter(|| { + let rows = flatten_with_options(black_box(doc), &opts); + black_box(rows); + }); + }); + } + group.finish(); +} + +fn bench_nested_schemas(c: &mut Criterion) { + let opts = FlattenOptions { + include_internal: true, + max_depth: 32, + ..FlattenOptions::default() + }; + let mut group = c.benchmark_group("flatten_json_properties/nested"); + for depth in [4usize, 8, 16] { + let doc = nested_schema(depth); + group.throughput(Throughput::Elements(depth as u64)); + group.bench_with_input(BenchmarkId::new("depth", depth), &doc, |b, doc| { + b.iter(|| { + let rows = flatten_with_options(black_box(doc), &opts); + black_box(rows); + }); + }); + } + group.finish(); +} + +fn bench_catalog_simulation(c: &mut Criterion) { + let opts = FlattenOptions::default(); + let doc = synthetic_schema(50); + c.bench_function("flatten_json_properties/catalog_1k_schemas", |b| { + b.iter(|| { + for _ in 0..1000 { + let rows = flatten_with_options(black_box(&doc), &opts); + black_box(rows); + } + }); + }); +} + +criterion_group!( + benches, + bench_flat_schemas, + bench_nested_schemas, + bench_catalog_simulation +); +criterion_main!(benches); diff --git a/crates/runtime/src/datafusion/mod.rs b/crates/runtime/src/datafusion/mod.rs index 6cc30d52cc..3f0880c5ae 100644 --- a/crates/runtime/src/datafusion/mod.rs +++ b/crates/runtime/src/datafusion/mod.rs @@ -131,6 +131,7 @@ pub mod secrets_context_extension; pub mod sort_columns; pub(crate) mod sql_validator; pub mod udf; +pub mod udtf; pub const SPICE_DEFAULT_CATALOG: &str = "spice"; pub const SPICE_RUNTIME_SCHEMA: &str = "runtime"; diff --git a/crates/runtime/src/datafusion/udf.rs b/crates/runtime/src/datafusion/udf.rs index 064fe4798b..fbd46930fd 100644 --- a/crates/runtime/src/datafusion/udf.rs +++ b/crates/runtime/src/datafusion/udf.rs @@ -17,6 +17,10 @@ limitations under the License. use std::collections::HashSet; use std::sync::{Arc, LazyLock}; +use crate::datafusion::udtf::json_properties::{ + FLATTEN_JSON_PROPERTIES_UDTF_NAME, FlattenJsonPropertiesScalar, FlattenJsonPropertiesTableFunc, +}; +use crate::datafusion::udtf::json_tree::{JSON_TREE_UDTF_NAME, JsonTreeScalar, JsonTreeTableFunc}; use crate::embeddings::udtf::{VECTOR_SEARCH_UDTF_NAME, VectorSearchTableFunc}; use crate::search::full_text::udtf::{TEXT_SEARCH_UDTF_NAME, TextSearchTableFunc}; use crate::search::rrf; @@ -80,6 +84,17 @@ pub async fn register_udfs(runtime: &crate::Runtime) { Arc::new(rrf::ReciprocalRankFusion::from_ctx(ctx)), ); + // `flatten_json_properties` / `json_tree` — JSON-Schema and generic JSON + // shredders. Registered as both UDTF (FROM-clause, literal input) and + // ScalarUDF returning `List>` (per-row / LATERAL via UNNEST). + ctx.register_udtf( + FLATTEN_JSON_PROPERTIES_UDTF_NAME, + Arc::new(FlattenJsonPropertiesTableFunc::new()), + ); + ctx.register_udf(FlattenJsonPropertiesScalar::new().into()); + ctx.register_udtf(JSON_TREE_UDTF_NAME, Arc::new(JsonTreeTableFunc::new())); + ctx.register_udf(JsonTreeScalar::new().into()); + #[cfg(feature = "models")] { ctx.register_udf(embed::Embed::new(runtime.embeds()).into()); @@ -101,6 +116,8 @@ static DENY_SPICE_SPECIFIC_FUNCTIONS: LazyLock = LazyLock::new( #[cfg(feature = "models")] AI_UDF_NAME, DIGEST_UDF_NAME, + FLATTEN_JSON_PROPERTIES_UDTF_NAME, + JSON_TREE_UDTF_NAME, ]; FunctionSupport::new( @@ -191,6 +208,8 @@ mod tests { spice_udf(Bucket::new()), spice_udf(Truncate::new()), Arc::new(INSTANCE.clone()), + spice_udf(FlattenJsonPropertiesScalar::new()), + spice_udf(JsonTreeScalar::new()), ]; for udf in spice_udfs { diff --git a/crates/runtime/src/datafusion/udtf/json_properties.rs b/crates/runtime/src/datafusion/udtf/json_properties.rs new file mode 100644 index 0000000000..da74063089 --- /dev/null +++ b/crates/runtime/src/datafusion/udtf/json_properties.rs @@ -0,0 +1,1533 @@ +/* +Copyright 2024-2026 The Spice.ai OSS Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +//! `flatten_json_properties` — decompose a JSON-Schema-shaped document into one +//! row per field. See issue #10399 for the full specification. +//! +//! ```text +//! flatten_json_properties(input Utf8 [, options...]) -> TABLE( +//! path Utf8, +//! parent_path Utf8, +//! name Utf8, +//! description Utf8, +//! type Utf8, +//! required Boolean, +//! format Utf8, +//! enum_values List, +//! metadata Utf8 +//! ) +//! ``` +//! +//! Two entry points are registered: +//! +//! - **UDTF** (`register_udtf`) — accepts a literal JSON string and any number +//! of named options. Use in the `FROM` clause: +//! `SELECT * FROM flatten_json_properties('{...}')`. +//! - **Scalar UDF** (`register_udf`) — accepts a `Utf8` column and returns +//! `List>`. Use with `UNNEST` for per-row / LATERAL semantics: +//! `FROM schemas s, UNNEST(flatten_json_properties(s.body)) AS a`. +//! +//! The walker handles: +//! - `properties` recursion (object → nested objects). +//! - `items.properties` (arrays of objects; leaves appear at `array.field`). +//! - `additionalProperties` maps (the map field emits `type = "map"`, children +//! appear at `map.child`). +//! - `allOf`, `oneOf`, `anyOf` merge — fields from every branch are emitted; +//! duplicate names across branches are deduped. +//! - Local `$ref` pointers (`#/$defs/*`, `#/definitions/*`, `#/properties/*`) +//! with cycle detection. +//! - External `$ref` URIs — emitted as `type = "ref"`, never dereferenced (no IO). +//! +//! Options (passed as named arguments): +//! - `max_depth` (`UInt`, default `32`) — walk stops past this depth. +//! - `max_rows` (`UInt`, default `100_000`) — per-document row cap. +//! - `max_bytes` (`UInt`, default `8_388_608`) — input size limit. +//! - `dialect` (`Utf8`, `"json-schema"` | `"openapi"`, default `"json-schema"`) — +//! tags invocation metrics so operators can split `openapi` traffic from +//! `json-schema` traffic. The walker does not currently vary its behavior +//! based on dialect; `OpenAPI`-specific handling (e.g. `nullable: true`) is +//! future scope tracked with the rest of this UDTF. +//! - `include_internal` (`Bool`, default `false`) — include container rows. +//! - `path_style` (`Utf8`, `"dot"` | `"json-pointer"`, default `"dot"`). +//! +//! Telemetry: the walker emits OpenTelemetry counters +//! `flatten_json_properties_invocations_total`, +//! `flatten_json_properties_rows_emitted_total`, and +//! `flatten_json_properties_errors_total{kind}`. Malformed input or a hit +//! depth / row / size limit emits an error-kind metric and yields zero or a +//! truncated-but-valid batch — never a query-level error. + +use std::any::Any; +use std::collections::HashSet; +use std::fmt::{Debug, Formatter}; +use std::sync::{Arc, LazyLock}; + +use arrow::array::{ + Array, ArrayRef, BooleanBuilder, LargeListArray, ListBuilder, StringBuilder, StructArray, + as_string_array, +}; +use arrow::buffer::{OffsetBuffer, ScalarBuffer}; +use arrow::compute::kernels::cast::cast; +use arrow_schema::{DataType, Field, Fields, Schema, SchemaRef}; +use async_trait::async_trait; +use datafusion::arrow::record_batch::RecordBatch; +use datafusion::catalog::{Session, TableFunctionImpl, TableProvider}; +use datafusion::common::Result as DataFusionResult; +use datafusion::datasource::TableType; +use datafusion::error::DataFusionError; +use datafusion::logical_expr::{ + ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility, +}; +use datafusion::physical_plan::ExecutionPlan; +use datafusion::prelude::Expr; +use datafusion::scalar::ScalarValue; +use datafusion_datasource::memory::MemorySourceConfig; +use datafusion_datasource::source::DataSourceExec; +use opentelemetry::KeyValue; +use opentelemetry::global; +use opentelemetry::metrics::{Counter, Meter}; +use serde_json::Value; + +pub const FLATTEN_JSON_PROPERTIES_UDTF_NAME: &str = "flatten_json_properties"; + +/// Default caps. Configurable per-call via named args. +const DEFAULT_MAX_DEPTH: usize = 32; +const DEFAULT_MAX_ROWS: usize = 100_000; +const DEFAULT_MAX_BYTES: usize = 8 * 1024 * 1024; + +/// Scalar UDF ceiling across a single evaluated batch. Per-document caps +/// already bound individual rows, but a wide input batch could still +/// accumulate `number_rows * max_rows` entries in memory. Error out loudly +/// past this watermark so operators see the condition rather than OOM. +const SCALAR_BATCH_MAX_ROWS: usize = 10_000_000; + +// -------- Metrics -------- + +static METER: LazyLock = LazyLock::new(|| global::meter("flatten_json_properties")); + +static INVOCATIONS: LazyLock> = LazyLock::new(|| { + METER + .u64_counter("flatten_json_properties_invocations_total") + .with_description("Invocations of flatten_json_properties, labelled by dialect.") + .build() +}); + +static ROWS_EMITTED: LazyLock> = LazyLock::new(|| { + METER + .u64_counter("flatten_json_properties_rows_emitted_total") + .with_description("Total rows emitted by flatten_json_properties.") + .build() +}); + +static ERRORS: LazyLock> = LazyLock::new(|| { + METER + .u64_counter("flatten_json_properties_errors_total") + .with_description( + "Errors inside flatten_json_properties, labelled by kind \ + (parse|depth_exceeded|row_cap_hit|cycle|input_too_large).", + ) + .build() +}); + +fn record_error(kind: &'static str) { + ERRORS.add(1, &[KeyValue::new("kind", kind)]); +} + +// -------- Output schema -------- + +static PROPERTY_FIELDS: LazyLock = LazyLock::new(|| { + let enum_item = Arc::new(Field::new("item", DataType::Utf8, true)); + Fields::from(vec![ + Field::new("path", DataType::Utf8, false), + Field::new("parent_path", DataType::Utf8, false), + Field::new("name", DataType::Utf8, false), + Field::new("description", DataType::Utf8, true), + Field::new("type", DataType::Utf8, false), + Field::new("required", DataType::Boolean, false), + Field::new("format", DataType::Utf8, true), + Field::new("enum_values", DataType::List(enum_item), true), + Field::new("metadata", DataType::Utf8, true), + ]) +}); + +static OUTPUT_SCHEMA: LazyLock = + LazyLock::new(|| Arc::new(Schema::new(PROPERTY_FIELDS.clone()))); + +/// Return type of the scalar UDF form. Uses `LargeList` (i64 offsets) +/// instead of `List` so a large batch can't overflow the offset range and +/// silently drop rows. `UNNEST` works on both variants, so the change is +/// transparent to downstream SQL. +static ROW_LIST_TYPE: LazyLock = LazyLock::new(|| { + DataType::LargeList(Arc::new(Field::new( + "item", + DataType::Struct(PROPERTY_FIELDS.clone()), + true, + ))) +}); + +// -------- Row + Options -------- + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct PropertyRow { + pub path: String, + pub parent_path: String, + pub name: String, + pub description: Option, + pub type_name: String, + pub required: bool, + pub format: Option, + pub enum_values: Option>, + pub metadata: Option, +} + +/// Dialect tag carried through options. Currently only affects the metric +/// label on `flatten_json_properties_invocations_total`; walker behavior does +/// not yet diverge. Retained so callers (and metrics) can distinguish traffic +/// when dialect-specific behavior lands later. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum Dialect { + JsonSchema, + OpenApi, +} + +impl Dialect { + fn label(self) -> &'static str { + match self { + Self::JsonSchema => "json-schema", + Self::OpenApi => "openapi", + } + } + fn parse(s: &str) -> Option { + match s { + "json-schema" | "jsonschema" => Some(Self::JsonSchema), + "openapi" => Some(Self::OpenApi), + _ => None, + } + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum PathStyle { + Dot, + JsonPointer, +} + +impl PathStyle { + fn parse(s: &str) -> Option { + match s { + "dot" => Some(Self::Dot), + "json-pointer" | "jsonpointer" => Some(Self::JsonPointer), + _ => None, + } + } +} + +#[derive(Debug, Clone)] +pub struct FlattenOptions { + pub max_depth: usize, + pub max_rows: usize, + pub max_bytes: usize, + pub dialect: Dialect, + pub include_internal: bool, + pub path_style: PathStyle, +} + +impl Default for FlattenOptions { + fn default() -> Self { + Self { + max_depth: DEFAULT_MAX_DEPTH, + max_rows: DEFAULT_MAX_ROWS, + max_bytes: DEFAULT_MAX_BYTES, + dialect: Dialect::JsonSchema, + include_internal: false, + path_style: PathStyle::Dot, + } + } +} + +// -------- Public entry points -------- + +/// Walk with default options. See [`flatten_with_options`] for configurable caps. +#[must_use] +pub fn flatten(input: &str) -> Vec { + flatten_with_options(input, &FlattenOptions::default()) +} + +/// Walk a JSON-Schema-shaped document and return one [`PropertyRow`] per +/// reachable field. Never errors: returns an empty / truncated `Vec` for +/// malformed input or caps being hit, emitting the corresponding metric. +#[must_use] +pub fn flatten_with_options(input: &str, opts: &FlattenOptions) -> Vec { + INVOCATIONS.add(1, &[KeyValue::new("dialect", opts.dialect.label())]); + + if input.len() > opts.max_bytes { + record_error("input_too_large"); + return Vec::new(); + } + + let Ok(root) = serde_json::from_str::(input) else { + record_error("parse"); + return Vec::new(); + }; + + let mut walker = Walker::new(&root, opts); + // Capture the root lifetime as a free variable so `walk_schema` sees it as + // `&'a Value` — letting ref resolution return `&'a Value` without cloning. + let root_ref: &Value = &root; + walker.walk_schema(root_ref, "", 0); + ROWS_EMITTED.add(walker.rows.len() as u64, &[]); + walker.rows +} + +// -------- Walker -------- + +struct Walker<'a> { + root: &'a Value, + opts: &'a FlattenOptions, + rows: Vec, + /// Active `$ref` pointers on the walk stack, for cycle detection. + visited_refs: HashSet, + depth_cap_hit: bool, + row_cap_hit: bool, +} + +impl<'a> Walker<'a> { + fn new(root: &'a Value, opts: &'a FlattenOptions) -> Self { + Self { + root, + opts, + rows: Vec::new(), + visited_refs: HashSet::new(), + depth_cap_hit: false, + row_cap_hit: false, + } + } + + fn walk_schema(&mut self, schema: &'a Value, parent_path: &str, depth: usize) { + if self.check_caps(depth) { + return; + } + let effective = self.effective_schemas(schema); + + // `collect_effective` handles cycles during a single resolution + // pass, but once control returns here we recurse into the resolved + // schema's own children — any `$ref` back to this node would look + // "fresh" to the next `collect_effective` call. Re-insert the ref + // (if there was one) so the whole walk-chain sees it. + let chain_ref: Option = schema + .get("$ref") + .and_then(Value::as_str) + .filter(|r| is_local_ref(r)) + .and_then(|r| self.visited_refs.insert(r.to_owned()).then(|| r.to_owned())); + + let required: HashSet<&str> = effective + .iter() + .flat_map(|s| { + s.get("required") + .and_then(Value::as_array) + .into_iter() + .flatten() + }) + .filter_map(Value::as_str) + .collect(); + + let mut seen_names: HashSet<&str> = HashSet::new(); + for eff in &effective { + if let Some(properties) = eff.get("properties").and_then(Value::as_object) { + for (name, spec) in properties { + if !seen_names.insert(name.as_str()) { + continue; + } + self.handle_field( + name, + spec, + parent_path, + required.contains(name.as_str()), + depth, + ); + if self.row_cap_hit { + if let Some(r) = chain_ref { + self.visited_refs.remove(&r); + } + return; + } + } + } + } + + if let Some(r) = chain_ref { + self.visited_refs.remove(&r); + } + } + + fn handle_field( + &mut self, + name: &str, + spec: &'a Value, + parent_path: &str, + required: bool, + depth: usize, + ) { + let path = make_path(parent_path, name, self.opts.path_style); + let effective_specs = self.effective_schemas(spec); + + let type_name = effective_specs + .iter() + .map(|s| compute_type(s)) + .find(|t| t != "unknown") + .unwrap_or_else(|| "unknown".to_owned()); + + let is_container = matches!(type_name.as_str(), "object" | "array" | "map"); + let emit_container_now = !is_container || self.opts.include_internal; + if emit_container_now { + self.emit_row( + spec, + &effective_specs, + &path, + parent_path, + name, + &type_name, + required, + ); + if self.row_cap_hit { + return; + } + } + + // Recurse once on the original `spec`; `walk_schema` re-expands + // effective branches internally with a single `seen_names` set, so + // overlapping `properties` across allOf/oneOf/anyOf / $ref branches + // are de-duplicated rather than emitted once per branch. + let rows_before = self.rows.len(); + match type_name.as_str() { + "object" => { + self.walk_schema(spec, &path, depth + 1); + } + "array" => { + // Items may itself be typed / composite; reuse walk_schema at + // the same path so leaves appear as `array.child` rather than + // `array[].child`. Look across effective branches so `items` + // declared under a combinator is still found, but use the + // first matching `items` as the single recursion point. + if let Some(items) = effective_specs + .iter() + .find_map(|s| s.get("items")) + .filter(|v| v.is_object()) + { + self.walk_schema(items, &path, depth + 1); + } + } + "map" => { + if let Some(ap) = effective_specs + .iter() + .find_map(|s| s.get("additionalProperties")) + .filter(|v| v.is_object()) + { + self.walk_schema(ap, &path, depth + 1); + } + } + _ => {} + } + + // Leaf-only mode would otherwise drop container fields whose children + // are primitives (array of strings, map of ints, empty object). If + // the recursion produced nothing and we haven't already emitted the + // container, surface it now so the field still appears in the output. + if is_container && !emit_container_now && self.rows.len() == rows_before { + self.emit_row( + spec, + &effective_specs, + &path, + parent_path, + name, + &type_name, + required, + ); + } + } + + #[expect( + clippy::too_many_arguments, + reason = "internal helper threads per-row metadata; splitting into a struct adds noise without clarity" + )] + fn emit_row( + &mut self, + raw_spec: &'a Value, + effective: &[&'a Value], + path: &str, + parent_path: &str, + name: &str, + type_name: &str, + required: bool, + ) { + // `effective` contains the raw_spec when no $ref was followed, and only + // the resolved target(s) when one was. The `or_else` arm preserves + // description / format / enum annotations declared alongside a $ref + // (JSON Schema 2020-12 lets them coexist; earlier drafts ignored them). + let description = first_str(effective, "description") + .or_else(|| raw_spec.get("description").and_then(Value::as_str)) + .map(ToOwned::to_owned); + + let format = first_str(effective, "format") + .or_else(|| raw_spec.get("format").and_then(Value::as_str)) + .map(ToOwned::to_owned); + + let enum_values = effective + .iter() + .find_map(|s| s.get("enum").and_then(Value::as_array)) + .or_else(|| raw_spec.get("enum").and_then(Value::as_array)) + .map(|arr| { + arr.iter() + .map(|v| match v { + Value::String(s) => s.clone(), + _ => v.to_string(), + }) + .collect::>() + }); + + self.rows.push(PropertyRow { + path: path.to_owned(), + parent_path: parent_path.to_owned(), + name: name.to_owned(), + description, + type_name: type_name.to_owned(), + required, + format, + enum_values, + metadata: Some(raw_spec.to_string()), + }); + + if self.rows.len() >= self.opts.max_rows { + self.row_cap_hit = true; + record_error("row_cap_hit"); + } + } + + /// Resolve `$ref`, `allOf`, `oneOf`, `anyOf` into the list of contributing + /// schemas. External and unresolvable refs pass through as-is so callers + /// can still read shape metadata from them. + fn effective_schemas(&mut self, schema: &'a Value) -> Vec<&'a Value> { + let mut out = Vec::new(); + self.collect_effective(schema, &mut out, 0); + if out.is_empty() { + out.push(schema); + } + out + } + + /// `ref_depth` tracks how deep we've recursed through `$ref` and + /// `allOf`/`oneOf`/`anyOf` expansion at a single schema node. Capped at + /// `opts.max_depth` so pathological combinator / ref chains can't blow the + /// stack or iterate unboundedly (`DoS`). + fn collect_effective(&mut self, schema: &'a Value, out: &mut Vec<&'a Value>, ref_depth: usize) { + if ref_depth > self.opts.max_depth { + if !self.depth_cap_hit { + self.depth_cap_hit = true; + record_error("depth_exceeded"); + } + out.push(schema); + return; + } + if let Some(ref_str) = schema.get("$ref").and_then(Value::as_str) { + if is_local_ref(ref_str) { + if self.visited_refs.contains(ref_str) { + record_error("cycle"); + return; + } + // Copy out `self.root: &'a Value` (references are Copy) so the + // returned `Option<&'a Value>` survives past `self`'s borrow. + let root: &'a Value = self.root; + if let Some(target) = root.pointer(ref_str.trim_start_matches('#')) { + self.visited_refs.insert(ref_str.to_owned()); + self.collect_effective(target, out, ref_depth + 1); + self.visited_refs.remove(ref_str); + return; + } + // Local ref that doesn't resolve — fall through and treat the + // schema itself as the contribution. + } + // External ref — `compute_type` will classify as `ref`; the URI + // remains in `metadata`. Never dereferenced. + out.push(schema); + return; + } + + out.push(schema); + for comb in ["allOf", "oneOf", "anyOf"] { + if let Some(arr) = schema.get(comb).and_then(Value::as_array) { + for entry in arr { + self.collect_effective(entry, out, ref_depth + 1); + } + } + } + } + + fn check_caps(&mut self, depth: usize) -> bool { + if depth > self.opts.max_depth { + if !self.depth_cap_hit { + self.depth_cap_hit = true; + record_error("depth_exceeded"); + } + return true; + } + if self.rows.len() >= self.opts.max_rows { + if !self.row_cap_hit { + self.row_cap_hit = true; + record_error("row_cap_hit"); + } + return true; + } + false + } +} + +// -------- Helpers -------- + +fn first_str<'a>(schemas: &[&'a Value], key: &str) -> Option<&'a str> { + schemas + .iter() + .find_map(|s| s.get(key).and_then(Value::as_str)) +} + +fn is_local_ref(ref_str: &str) -> bool { + ref_str.starts_with('#') +} + +/// Classify a schema node into one of the emitted type labels. +fn compute_type(spec: &Value) -> String { + // External $ref → "ref" + if let Some(ref_str) = spec.get("$ref").and_then(Value::as_str) + && !is_local_ref(ref_str) + { + return "ref".to_owned(); + } + // Explicit additionalProperties without own properties → map. + let has_ap = spec + .get("additionalProperties") + .is_some_and(Value::is_object); + // Require `properties` / `items` to be well-formed before treating the node + // as an object/array. A non-object `properties` or a non-object/array + // `items` shouldn't silently flip the type. + let has_props = spec.get("properties").is_some_and(Value::is_object); + if has_ap && !has_props { + return "map".to_owned(); + } + match spec.get("type") { + Some(Value::String(s)) => s.clone(), + Some(Value::Array(arr)) => { + // Type unions with `"null"` express optional/nullable in JSON + // Schema; the "real" type is the first non-null entry. Only fall + // back to `"null"` (or `"unknown"`) when no other type is present. + let strs: Vec<&str> = arr.iter().filter_map(Value::as_str).collect(); + strs.iter() + .find(|t| **t != "null") + .copied() + .or_else(|| strs.first().copied()) + .unwrap_or("unknown") + .to_owned() + } + _ => { + if has_props { + "object".to_owned() + } else if spec + .get("items") + .is_some_and(|v| v.is_object() || v.is_array()) + { + "array".to_owned() + } else if let Some(first_enum) = spec + .get("enum") + .and_then(Value::as_array) + .and_then(|a| a.first()) + { + match first_enum { + Value::String(_) => "string", + Value::Bool(_) => "boolean", + Value::Number(n) if n.is_i64() || n.is_u64() => "integer", + Value::Number(_) => "number", + Value::Null => "null", + _ => "unknown", + } + .to_owned() + } else { + "unknown".to_owned() + } + } + } +} + +fn make_path(parent: &str, name: &str, style: PathStyle) -> String { + match style { + PathStyle::Dot => { + if parent.is_empty() { + name.to_owned() + } else { + format!("{parent}.{name}") + } + } + PathStyle::JsonPointer => { + // Escape `/` and `~` per RFC 6901. + let escaped = name.replace('~', "~0").replace('/', "~1"); + if parent.is_empty() { + format!("/{escaped}") + } else { + format!("{parent}/{escaped}") + } + } + } +} + +// -------- UDTF -------- + +#[derive(Clone, Default)] +pub struct FlattenJsonPropertiesTableFunc; + +impl FlattenJsonPropertiesTableFunc { + #[must_use] + pub fn new() -> Self { + Self + } +} + +impl Debug for FlattenJsonPropertiesTableFunc { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.debug_struct("FlattenJsonPropertiesTableFunc").finish() + } +} + +impl TableFunctionImpl for FlattenJsonPropertiesTableFunc { + fn call(&self, exprs: &[Expr]) -> DataFusionResult> { + let parsed = parse_udtf_args(exprs)?; + let rows = parsed + .input + .as_deref() + .map(|s| flatten_with_options(s, &parsed.options)) + .unwrap_or_default(); + Ok(Arc::new(FlattenJsonPropertiesTable { + schema: Arc::clone(&OUTPUT_SCHEMA), + rows, + })) + } +} + +struct ParsedUdtfArgs { + input: Option, + options: FlattenOptions, +} + +fn parse_udtf_args(exprs: &[Expr]) -> DataFusionResult { + let mut positional = exprs.iter(); + let mut options = FlattenOptions::default(); + + let first = positional.next().ok_or_else(|| { + DataFusionError::Plan(format!( + "{FLATTEN_JSON_PROPERTIES_UDTF_NAME}() requires a JSON string argument." + )) + })?; + + let input = literal_string(first).map_err(|e| { + DataFusionError::NotImplemented(format!( + "{FLATTEN_JSON_PROPERTIES_UDTF_NAME}() currently supports a literal JSON string as the \ + first argument. For per-row / LATERAL invocation, use \ + `UNNEST({FLATTEN_JSON_PROPERTIES_UDTF_NAME}())`. Details: {e}" + )) + })?; + + for arg in positional { + let (name, value) = named_arg(arg).ok_or_else(|| { + DataFusionError::Plan(format!( + "Arguments after the JSON string must be named, e.g. `max_depth => 32`. Got: {arg:?}." + )) + })?; + apply_named_option(&name, value, &mut options)?; + } + + Ok(ParsedUdtfArgs { input, options }) +} + +/// Extract a Utf8/LargeUtf8 string literal. Returns `Ok(None)` for NULL. +fn literal_string(expr: &Expr) -> Result, String> { + match expr { + Expr::Literal(ScalarValue::Utf8(v) | ScalarValue::LargeUtf8(v), _) => Ok(v.clone()), + Expr::Literal(ScalarValue::Null, _) => Ok(None), + other => Err(format!("expected Utf8, got {other:?}")), + } +} + +/// Recognise a `name => value` named-argument expression. `DataFusion` surfaces +/// these as a literal tagged with `spice.parameter_name` metadata. +fn named_arg(expr: &Expr) -> Option<(String, &ScalarValue)> { + if let Expr::Literal(scalar, Some(meta)) = expr + && let Some(name) = meta.inner().get("spice.parameter_name") + { + return Some((name.clone(), scalar)); + } + None +} + +fn apply_named_option( + name: &str, + value: &ScalarValue, + opts: &mut FlattenOptions, +) -> DataFusionResult<()> { + match name { + "max_depth" => opts.max_depth = parse_usize(name, value)?, + "max_rows" => opts.max_rows = parse_usize(name, value)?, + "max_bytes" => opts.max_bytes = parse_usize(name, value)?, + "dialect" => { + let s = parse_utf8(name, value)?; + opts.dialect = Dialect::parse(&s).ok_or_else(|| { + DataFusionError::Plan(format!( + "Unknown dialect '{s}'. Expected 'json-schema' or 'openapi'." + )) + })?; + } + "include_internal" => opts.include_internal = parse_bool(name, value)?, + "path_style" => { + let s = parse_utf8(name, value)?; + opts.path_style = PathStyle::parse(&s).ok_or_else(|| { + DataFusionError::Plan(format!( + "Unknown path_style '{s}'. Expected 'dot' or 'json-pointer'." + )) + })?; + } + other => { + return Err(DataFusionError::Plan(format!( + "Unknown option '{other}'. Supported: max_depth, max_rows, max_bytes, dialect, \ + include_internal, path_style." + ))); + } + } + Ok(()) +} + +fn parse_usize(name: &str, v: &ScalarValue) -> DataFusionResult { + let n: i64 = match v { + ScalarValue::Int8(Some(n)) => i64::from(*n), + ScalarValue::Int16(Some(n)) => i64::from(*n), + ScalarValue::Int32(Some(n)) => i64::from(*n), + ScalarValue::Int64(Some(n)) => *n, + ScalarValue::UInt8(Some(n)) => i64::from(*n), + ScalarValue::UInt16(Some(n)) => i64::from(*n), + ScalarValue::UInt32(Some(n)) => i64::from(*n), + ScalarValue::UInt64(Some(n)) => i64::try_from(*n) + .map_err(|_| DataFusionError::Plan(format!("{name} must fit in i64, got {n}")))?, + other => { + return Err(DataFusionError::Plan(format!( + "{name} must be an integer, got {other:?}" + ))); + } + }; + usize::try_from(n) + .map_err(|_| DataFusionError::Plan(format!("{name} must be non-negative, got {n}"))) +} + +fn parse_bool(name: &str, v: &ScalarValue) -> DataFusionResult { + match v { + ScalarValue::Boolean(Some(b)) => Ok(*b), + other => Err(DataFusionError::Plan(format!( + "{name} must be a boolean, got {other:?}" + ))), + } +} + +fn parse_utf8(name: &str, v: &ScalarValue) -> DataFusionResult { + match v { + ScalarValue::Utf8(Some(s)) | ScalarValue::LargeUtf8(Some(s)) => Ok(s.clone()), + other => Err(DataFusionError::Plan(format!( + "{name} must be a string, got {other:?}" + ))), + } +} + +#[derive(Debug)] +pub struct FlattenJsonPropertiesTable { + schema: SchemaRef, + rows: Vec, +} + +#[async_trait] +impl TableProvider for FlattenJsonPropertiesTable { + fn as_any(&self) -> &dyn Any { + self + } + fn schema(&self) -> SchemaRef { + Arc::clone(&self.schema) + } + fn table_type(&self) -> TableType { + TableType::Base + } + async fn scan( + &self, + _state: &dyn Session, + projection: Option<&Vec>, + _filters: &[Expr], + _limit: Option, + ) -> DataFusionResult> { + // Single-node only: a bare `DataSourceExec(MemorySourceConfig)` is + // rejected by `EnsureSupportedFileScan` in cluster mode. Distributed + // support requires a dedicated `UdtfArgs` proto variant + codec so + // remote executors can re-invoke the walker; that's follow-up scope. + let batch = rows_to_batch(&self.rows, Arc::clone(&self.schema))?; + let src = MemorySourceConfig::try_new( + &[vec![batch]], + Arc::clone(&self.schema), + projection.cloned(), + )?; + Ok(Arc::new(DataSourceExec::new(Arc::new(src)))) + } +} + +fn rows_to_batch(rows: &[PropertyRow], schema: SchemaRef) -> DataFusionResult { + let (arrays, _) = build_property_arrays(rows); + RecordBatch::try_new(schema, arrays).map_err(|e| DataFusionError::ArrowError(Box::new(e), None)) +} + +fn build_property_arrays(rows: &[PropertyRow]) -> (Vec, usize) { + let mut path = StringBuilder::with_capacity(rows.len(), rows.len() * 16); + let mut parent_path = StringBuilder::with_capacity(rows.len(), rows.len() * 8); + let mut name = StringBuilder::with_capacity(rows.len(), rows.len() * 8); + let mut description = StringBuilder::with_capacity(rows.len(), rows.len() * 32); + let mut type_name = StringBuilder::with_capacity(rows.len(), rows.len() * 4); + let mut required = BooleanBuilder::with_capacity(rows.len()); + let mut format = StringBuilder::with_capacity(rows.len(), 0); + let mut metadata = StringBuilder::with_capacity(rows.len(), rows.len() * 64); + let mut enum_values = ListBuilder::new(StringBuilder::new()); + + for row in rows { + path.append_value(&row.path); + parent_path.append_value(&row.parent_path); + name.append_value(&row.name); + match &row.description { + Some(v) => description.append_value(v), + None => description.append_null(), + } + type_name.append_value(&row.type_name); + required.append_value(row.required); + match &row.format { + Some(v) => format.append_value(v), + None => format.append_null(), + } + match &row.metadata { + Some(v) => metadata.append_value(v), + None => metadata.append_null(), + } + match &row.enum_values { + Some(vs) => { + for v in vs { + enum_values.values().append_value(v); + } + enum_values.append(true); + } + None => enum_values.append(false), + } + } + + let arrays: Vec = vec![ + Arc::new(path.finish()), + Arc::new(parent_path.finish()), + Arc::new(name.finish()), + Arc::new(description.finish()), + Arc::new(type_name.finish()), + Arc::new(required.finish()), + Arc::new(format.finish()), + Arc::new(enum_values.finish()), + Arc::new(metadata.finish()), + ]; + (arrays, rows.len()) +} + +// -------- ScalarUDF variant -------- +// +// Exposes the same walker as a scalar that returns `List>` per row. +// Composes with `UNNEST` to give per-row / LATERAL semantics. + +#[derive(Debug, Clone)] +pub struct FlattenJsonPropertiesScalar { + signature: Signature, +} + +impl Default for FlattenJsonPropertiesScalar { + fn default() -> Self { + Self::new() + } +} + +impl FlattenJsonPropertiesScalar { + #[must_use] + pub fn new() -> Self { + Self { + signature: Signature::one_of( + vec![ + TypeSignature::Exact(vec![DataType::Utf8]), + TypeSignature::Exact(vec![DataType::LargeUtf8]), + TypeSignature::Exact(vec![DataType::Utf8View]), + ], + Volatility::Immutable, + ), + } + } +} + +impl PartialEq for FlattenJsonPropertiesScalar { + fn eq(&self, _other: &Self) -> bool { + true + } +} + +impl Eq for FlattenJsonPropertiesScalar {} + +impl std::hash::Hash for FlattenJsonPropertiesScalar { + fn hash(&self, state: &mut H) { + self.name().hash(state); + } +} + +impl ScalarUDFImpl for FlattenJsonPropertiesScalar { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + FLATTEN_JSON_PROPERTIES_UDTF_NAME + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> DataFusionResult { + Ok(ROW_LIST_TYPE.clone()) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DataFusionResult { + let input_col = args + .args + .first() + .ok_or_else(|| { + DataFusionError::Plan(format!( + "{FLATTEN_JSON_PROPERTIES_UDTF_NAME}() requires a JSON string argument." + )) + })? + .clone(); + + // Named args are stripped of their metadata when the scalar form is + // invoked; users who want non-default options should use the UDTF form. + let opts = FlattenOptions::default(); + + let array = input_col.into_array(args.number_rows)?; + // Signature restricts input to Utf8/LargeUtf8/Utf8View; normalize to + // Utf8 so `as_string_array` below always succeeds. + let normalized = if matches!(array.data_type(), DataType::Utf8) { + array + } else { + cast(&array, &DataType::Utf8) + .map_err(|e| DataFusionError::ArrowError(Box::new(e), None))? + }; + let strings = as_string_array(&normalized); + + // Collect all rows into a single flat vec, with offsets delineating + // which span belongs to which input row. NULL inputs produce empty + // (but non-NULL) list slots so the output row count matches input. + // + // Using `LargeListArray` (i64 offsets) so a large evaluated batch + // cannot overflow and silently drop rows. Per-document caps inside + // `flatten_with_options` still bound memory use. + let mut all_rows: Vec = Vec::new(); + let mut offsets: Vec = Vec::with_capacity(strings.len() + 1); + offsets.push(0); + + for idx in 0..strings.len() { + if !strings.is_null(idx) { + let rows = flatten_with_options(strings.value(idx), &opts); + all_rows.extend(rows); + if all_rows.len() > SCALAR_BATCH_MAX_ROWS { + record_error("batch_cap_hit"); + return Err(DataFusionError::Execution(format!( + "{FLATTEN_JSON_PROPERTIES_UDTF_NAME}(): batch produced more than {SCALAR_BATCH_MAX_ROWS} flattened rows; lower `max_rows` or split the input." + ))); + } + } + // Walker caps bound the row count well under `i64::MAX`, but if + // somehow they didn't, silently saturating would misalign list + // offsets. Fail loud instead so the condition is visible. + let len = i64::try_from(all_rows.len()).map_err(|_| { + DataFusionError::Execution(format!( + "{FLATTEN_JSON_PROPERTIES_UDTF_NAME}(): flattened row count exceeds LargeList i64 offset range." + )) + })?; + offsets.push(len); + } + + let (struct_arrays, _) = build_property_arrays(&all_rows); + let struct_array = StructArray::new(PROPERTY_FIELDS.clone(), struct_arrays, None); + let list_array = LargeListArray::new( + Arc::new(Field::new( + "item", + DataType::Struct(PROPERTY_FIELDS.clone()), + true, + )), + OffsetBuffer::new(ScalarBuffer::from(offsets)), + Arc::new(struct_array), + None, + ); + Ok(ColumnarValue::Array(Arc::new(list_array))) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn by_path(rows: &[PropertyRow]) -> std::collections::HashMap<&str, &PropertyRow> { + rows.iter().map(|r| (r.path.as_str(), r)).collect() + } + + fn with_internal() -> FlattenOptions { + FlattenOptions { + include_internal: true, + ..FlattenOptions::default() + } + } + + #[test] + fn leaves_only_by_default() { + let json = r#"{ + "properties": { + "user": { + "type": "object", + "properties": { + "name": {"type": "string"} + } + } + } + }"#; + let rows = flatten(json); + // "user" is a container; by default containers are not emitted. + let paths: Vec<_> = rows.iter().map(|r| r.path.as_str()).collect(); + assert_eq!(paths, vec!["user.name"]); + } + + #[test] + fn include_internal_emits_containers() { + let json = r#"{ + "properties": { + "user": { + "type": "object", + "properties": { + "name": {"type": "string"} + } + } + } + }"#; + let rows = flatten_with_options(json, &with_internal()); + let by = by_path(&rows); + assert_eq!(by["user"].type_name, "object"); + assert_eq!(by["user.name"].type_name, "string"); + } + + #[test] + fn flat_primitives_with_required() { + let json = r#"{ + "properties": { + "name": {"type": "string", "description": "User's full name"}, + "age": {"type": "integer"} + }, + "required": ["name"] + }"#; + let rows = flatten(json); + let by = by_path(&rows); + assert!(by["name"].required); + assert_eq!(by["name"].description.as_deref(), Some("User's full name")); + assert!(!by["age"].required); + } + + #[test] + fn items_properties_of_object_arrays() { + let json = r#"{ + "properties": { + "orders": { + "type": "array", + "items": { + "type": "object", + "properties": { + "id": {"type": "integer"}, + "name": {"type": "string"} + }, + "required": ["id"] + } + } + } + }"#; + let rows = flatten(json); + let by = by_path(&rows); + assert_eq!(by["orders.id"].type_name, "integer"); + assert!(by["orders.id"].required); + assert_eq!(by["orders.name"].type_name, "string"); + // Array container itself is not emitted by default. + assert!(!by.contains_key("orders")); + } + + #[test] + fn additional_properties_map() { + let json = r#"{ + "properties": { + "labels": { + "type": "object", + "additionalProperties": { + "type": "object", + "properties": {"value": {"type": "string"}} + } + } + } + }"#; + let rows = flatten_with_options(json, &with_internal()); + let by = by_path(&rows); + assert_eq!(by["labels"].type_name, "map"); + // Child properties under additionalProperties are emitted at labels.value. + assert_eq!(by["labels.value"].type_name, "string"); + } + + #[test] + fn all_of_merges_fields() { + let json = r#"{ + "properties": { + "user": { + "allOf": [ + {"properties": {"name": {"type": "string"}}, + "required": ["name"]}, + {"properties": {"age": {"type": "integer"}}} + ] + } + } + }"#; + let rows = flatten(json); + let by = by_path(&rows); + assert!(by["user.name"].required); + assert_eq!(by["user.age"].type_name, "integer"); + } + + #[test] + fn nullable_type_union_picks_non_null() { + // JSON Schema expresses nullable fields as `"type": ["null", "string"]` + // (or any ordering). Pick the first non-null type so the output row + // reflects the real type rather than `"null"`. + let json = r#"{ + "properties": { + "leading_null": {"type": ["null", "string"]}, + "trailing_null": {"type": ["integer", "null"]}, + "all_null": {"type": ["null"]} + } + }"#; + let rows = flatten(json); + let by = by_path(&rows); + assert_eq!(by["leading_null"].type_name, "string"); + assert_eq!(by["trailing_null"].type_name, "integer"); + assert_eq!(by["all_null"].type_name, "null"); + } + + #[test] + fn one_of_any_of_union_fields() { + let json = r#"{ + "properties": { + "payload": { + "oneOf": [ + {"properties": {"text": {"type": "string"}}}, + {"properties": {"count": {"type": "integer"}}} + ] + } + } + }"#; + let rows = flatten(json); + let by = by_path(&rows); + assert!(by.contains_key("payload.text")); + assert!(by.contains_key("payload.count")); + } + + #[test] + fn local_ref_resolves() { + let json = r##"{ + "$defs": { + "Address": {"type": "object", "properties": {"street": {"type": "string"}}} + }, + "properties": { + "home": {"$ref": "#/$defs/Address"} + } + }"##; + let rows = flatten(json); + let by = by_path(&rows); + assert_eq!(by["home.street"].type_name, "string"); + } + + #[test] + fn local_ref_cycle_terminates() { + let json = r##"{ + "$defs": { + "Node": { + "type": "object", + "properties": { + "next": {"$ref": "#/$defs/Node"} + } + } + }, + "properties": { + "root": {"$ref": "#/$defs/Node"} + } + }"##; + let rows = flatten(json); + let by = by_path(&rows); + // First resolution of Node happens at `root`; the second hop into + // `root.next` must recognise it's re-entering the same `$ref` chain + // and stop without a third level of expansion. + assert!(by.contains_key("root.next")); + assert!(!by.contains_key("root.next.next")); + } + + #[test] + fn external_ref_emits_ref_type_row() { + let json = r#"{ + "properties": { + "ext": {"$ref": "https://example.com/schema.json"} + } + }"#; + let rows = flatten_with_options(json, &with_internal()); + let by = by_path(&rows); + assert_eq!(by["ext"].type_name, "ref"); + let meta: serde_json::Value = + serde_json::from_str(by["ext"].metadata.as_ref().expect("ref metadata present")) + .expect("metadata parses as JSON"); + assert_eq!(meta["$ref"], "https://example.com/schema.json"); + } + + #[test] + fn enum_and_format_are_captured() { + let json = r#"{ + "properties": { + "status": {"type": "string", "enum": ["active", "pending"]}, + "created_at": {"type": "string", "format": "date-time"} + } + }"#; + let rows = flatten(json); + let by = by_path(&rows); + assert_eq!( + by["status"].enum_values.as_deref(), + Some(&["active".to_string(), "pending".to_string()][..]) + ); + assert_eq!(by["created_at"].format.as_deref(), Some("date-time")); + } + + #[test] + fn malformed_input_yields_zero_rows() { + assert!(flatten("not json").is_empty()); + assert!(flatten("{broken").is_empty()); + assert!(flatten("").is_empty()); + } + + #[test] + fn oversized_input_is_rejected_without_parsing() { + let opts = FlattenOptions { + max_bytes: 32, + ..FlattenOptions::default() + }; + let big = r#"{"properties": {"a": {"type": "string"}, "b": {"type": "integer"}}}"#; + assert!(big.len() > 32); + assert!(flatten_with_options(big, &opts).is_empty()); + } + + #[test] + fn max_depth_truncates_walk() { + let opts = FlattenOptions { + max_depth: 2, + include_internal: true, + ..FlattenOptions::default() + }; + let json = r#"{ + "properties": { + "a": {"type": "object", "properties": { + "b": {"type": "object", "properties": { + "c": {"type": "object", "properties": { + "d": {"type": "string"} + }} + }} + }} + } + }"#; + let rows = flatten_with_options(json, &opts); + // We saw up to depth 2 (a.b); "a.b.c" lives at depth 3 which is capped. + // Exact path set depends on when the cap trips, so we only assert that + // the deepest path ("a.b.c.d") is absent. + assert!(!rows.iter().any(|r| r.path == "a.b.c.d")); + } + + #[test] + fn max_rows_caps_output() { + let opts = FlattenOptions { + max_rows: 2, + ..FlattenOptions::default() + }; + let json = r#"{ + "properties": { + "a": {"type": "string"}, + "b": {"type": "string"}, + "c": {"type": "string"}, + "d": {"type": "string"} + } + }"#; + let rows = flatten_with_options(json, &opts); + assert_eq!(rows.len(), 2); + } + + #[test] + fn json_pointer_path_style() { + let opts = FlattenOptions { + path_style: PathStyle::JsonPointer, + ..FlattenOptions::default() + }; + let json = r#"{ + "properties": { + "user": {"type": "object", "properties": { + "name": {"type": "string"} + }} + } + }"#; + let rows = flatten_with_options(json, &opts); + let paths: Vec<_> = rows.iter().map(|r| r.path.as_str()).collect(); + assert_eq!(paths, vec!["/user/name"]); + } + + #[test] + fn documents_without_properties_yield_zero_rows() { + assert!(flatten(r#"{"foo": "bar"}"#).is_empty()); + assert!(flatten(r#"{"properties": {}}"#).is_empty()); + assert!(flatten(r"[1, 2, 3]").is_empty()); + } + + #[tokio::test] + async fn udtf_emits_schema_and_batch() { + use datafusion::prelude::SessionContext; + let ctx = SessionContext::new(); + let func = FlattenJsonPropertiesTableFunc::new(); + let provider = func + .call(&[Expr::Literal( + ScalarValue::Utf8(Some( + r#"{"properties":{"a":{"type":"string"}}}"#.to_string(), + )), + None, + )]) + .expect("call succeeds for literal"); + + let schema = provider.schema(); + assert_eq!(schema.fields().len(), 9); + + let state = ctx.state(); + let plan = provider.scan(&state, None, &[], None).await.expect("scan"); + let results = datafusion::physical_plan::collect(plan, ctx.task_ctx()) + .await + .expect("collect"); + assert_eq!(results[0].num_rows(), 1); + } + + #[test] + fn udtf_rejects_non_literal_first_arg() { + use datafusion::common::Column; + let func = FlattenJsonPropertiesTableFunc::new(); + let err = func + .call(&[Expr::Column(Column::new_unqualified("body"))]) + .expect_err("column argument must be rejected"); + assert!(err.to_string().contains("UNNEST")); + } + + #[test] + fn scalar_udf_return_type_is_large_list_of_struct() { + let udf = FlattenJsonPropertiesScalar::new(); + let ty = udf.return_type(&[DataType::Utf8]).expect("return type"); + match ty { + DataType::LargeList(field) => { + assert!(matches!(field.data_type(), DataType::Struct(_))); + } + other => panic!("expected LargeList, got {other:?}"), + } + } + + #[test] + fn scalar_udf_invokes_per_row() { + use arrow::array::StringArray; + + let udf = FlattenJsonPropertiesScalar::new(); + let input = Arc::new(StringArray::from(vec![ + Some(r#"{"properties":{"a":{"type":"string"}}}"#), + Some(r#"{"properties":{"b":{"type":"integer"},"c":{"type":"boolean"}}}"#), + None, + ])) as ArrayRef; + + let arg_field = Arc::new(Field::new("body", DataType::Utf8, true)); + let return_field = Arc::new(Field::new("result", ROW_LIST_TYPE.clone(), true)); + + let result = udf + .invoke_with_args(ScalarFunctionArgs { + args: vec![ColumnarValue::Array(input)], + arg_fields: vec![arg_field], + number_rows: 3, + return_field, + config_options: Arc::new(datafusion::config::ConfigOptions::default()), + }) + .expect("invoke succeeds"); + + let arr = match result { + ColumnarValue::Array(a) => a, + other @ ColumnarValue::Scalar(_) => panic!("expected array, got {other:?}"), + }; + let list = arr + .as_any() + .downcast_ref::() + .expect("large-list array"); + assert_eq!(list.len(), 3); + // Row 0 has 1 flattened property; row 1 has 2; row 2 is NULL-valued but + // still emits an (empty) list slot per row. + assert_eq!(list.value(0).len(), 1); + assert_eq!(list.value(1).len(), 2); + } + + #[tokio::test] + async fn scan_with_projection_returns_only_requested_columns() { + use datafusion::prelude::SessionContext; + let ctx = SessionContext::new(); + let func = FlattenJsonPropertiesTableFunc::new(); + let provider = func + .call(&[Expr::Literal( + ScalarValue::Utf8(Some( + r#"{"properties":{"a":{"type":"string"}}}"#.to_string(), + )), + None, + )]) + .expect("call succeeds"); + + // Full schema has 9 columns (path, parent_path, name, description, + // type, required, format, enum_values, metadata); request only + // columns 0 (path) and 4 (type). + let projection = vec![0usize, 4]; + let state = ctx.state(); + let plan = provider + .scan(&state, Some(&projection), &[], None) + .await + .expect("scan with projection"); + let results = datafusion::physical_plan::collect(plan, ctx.task_ctx()) + .await + .expect("collect"); + assert_eq!(results[0].num_columns(), 2, "expected 2 projected columns"); + assert_eq!(results[0].schema().field(0).name(), "path"); + assert_eq!(results[0].schema().field(1).name(), "type"); + } +} diff --git a/crates/runtime/src/datafusion/udtf/json_tree.rs b/crates/runtime/src/datafusion/udtf/json_tree.rs new file mode 100644 index 0000000000..07267a7f23 --- /dev/null +++ b/crates/runtime/src/datafusion/udtf/json_tree.rs @@ -0,0 +1,819 @@ +/* +Copyright 2024-2026 The Spice.ai OSS Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +//! `json_tree` — recursive depth-first walk of an arbitrary JSON document. +//! +//! Schema-agnostic sibling of `flatten_json_properties`. Mirrors `DuckDB` / +//! `SQLite`'s table function of the same name: one row per node (interior and +//! leaf), in depth-first order, with JSON-Path addresses and a parent pointer. +//! +//! ```text +//! json_tree(input Utf8 [, max_depth => UInt, max_rows => UInt, max_bytes => UInt]) -> TABLE( +//! key Utf8, +//! value Utf8, +//! type Utf8, +//! atom Utf8, +//! id Int64, +//! parent Int64, +//! fullkey Utf8, +//! path Utf8 +//! ) +//! ``` +//! +//! Registered twice: +//! - As a UDTF for `SELECT * FROM json_tree('{...}')`. Named options +//! (`max_depth`, `max_rows`, `max_bytes`) are only accepted in this form. +//! - As a scalar UDF returning `List>` for per-row / +//! `LATERAL json_tree(s.body)` usage via `UNNEST`. The scalar form takes +//! only the JSON argument and always runs with default caps. + +use std::any::Any; +use std::fmt::{Debug, Formatter}; +use std::sync::{Arc, LazyLock}; + +use arrow::array::{ + Array, ArrayRef, Int64Builder, LargeListArray, StringBuilder, StructArray, as_string_array, +}; +use arrow::buffer::{OffsetBuffer, ScalarBuffer}; +use arrow::compute::kernels::cast::cast; +use arrow_schema::{DataType, Field, Fields, Schema, SchemaRef}; +use async_trait::async_trait; +use datafusion::arrow::record_batch::RecordBatch; +use datafusion::catalog::{Session, TableFunctionImpl, TableProvider}; +use datafusion::common::Result as DataFusionResult; +use datafusion::datasource::TableType; +use datafusion::error::DataFusionError; +use datafusion::logical_expr::{ + ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility, +}; +use datafusion::physical_plan::ExecutionPlan; +use datafusion::prelude::Expr; +use datafusion::scalar::ScalarValue; +use datafusion_datasource::memory::MemorySourceConfig; +use datafusion_datasource::source::DataSourceExec; +use opentelemetry::KeyValue; +use opentelemetry::global; +use opentelemetry::metrics::{Counter, Meter}; +use serde_json::Value; + +pub const JSON_TREE_UDTF_NAME: &str = "json_tree"; + +const DEFAULT_MAX_DEPTH: usize = 64; +const DEFAULT_MAX_ROWS: usize = 1_000_000; +const DEFAULT_MAX_BYTES: usize = 8 * 1024 * 1024; + +/// Scalar UDF ceiling across a single evaluated batch. Per-document caps +/// already bound individual rows, but a wide input batch could still +/// accumulate `number_rows * max_rows` entries in memory. Error out loudly +/// past this watermark so operators see the condition rather than OOM. +const SCALAR_BATCH_MAX_ROWS: usize = 10_000_000; + +// -------- Metrics -------- + +static METER: LazyLock = LazyLock::new(|| global::meter("json_tree")); + +static INVOCATIONS: LazyLock> = LazyLock::new(|| { + METER + .u64_counter("json_tree_invocations_total") + .with_description("Invocations of json_tree.") + .build() +}); + +static ROWS_EMITTED: LazyLock> = LazyLock::new(|| { + METER + .u64_counter("json_tree_rows_emitted_total") + .with_description("Rows emitted by json_tree.") + .build() +}); + +static ERRORS: LazyLock> = LazyLock::new(|| { + METER + .u64_counter("json_tree_errors_total") + .with_description( + "Errors inside json_tree, labelled by kind (parse|depth_exceeded|input_too_large|row_cap_hit).", + ) + .build() +}); + +fn record_error(kind: &'static str) { + ERRORS.add(1, &[KeyValue::new("kind", kind)]); +} + +// -------- Options + Output schema -------- + +#[derive(Debug, Clone)] +pub struct JsonTreeOptions { + pub max_depth: usize, + pub max_rows: usize, + pub max_bytes: usize, +} + +impl Default for JsonTreeOptions { + fn default() -> Self { + Self { + max_depth: DEFAULT_MAX_DEPTH, + max_rows: DEFAULT_MAX_ROWS, + max_bytes: DEFAULT_MAX_BYTES, + } + } +} + +static TREE_FIELDS: LazyLock = LazyLock::new(|| { + Fields::from(vec![ + Field::new("key", DataType::Utf8, true), + Field::new("value", DataType::Utf8, true), + Field::new("type", DataType::Utf8, false), + Field::new("atom", DataType::Utf8, true), + Field::new("id", DataType::Int64, false), + Field::new("parent", DataType::Int64, true), + Field::new("fullkey", DataType::Utf8, false), + // `path` is the parent JSON-Path. The root row has no parent, so it's + // emitted as NULL (matches DuckDB / SQLite `json_tree` semantics). + Field::new("path", DataType::Utf8, true), + ]) +}); + +static OUTPUT_SCHEMA: LazyLock = + LazyLock::new(|| Arc::new(Schema::new(TREE_FIELDS.clone()))); + +/// Return type of the scalar UDF form. Uses `LargeList` (i64 offsets) +/// instead of `List` so a large batch can't overflow the offset range and +/// silently drop rows. `UNNEST` works on both variants, so the change is +/// transparent to downstream SQL. +static ROW_LIST_TYPE: LazyLock = LazyLock::new(|| { + DataType::LargeList(Arc::new(Field::new( + "item", + DataType::Struct(TREE_FIELDS.clone()), + true, + ))) +}); + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct TreeRow { + pub key: Option, + pub value: Option, + pub type_name: String, + pub atom: Option, + pub id: i64, + pub parent: Option, + pub fullkey: String, + pub path: Option, +} + +// -------- Public entry points -------- + +#[must_use] +pub fn json_tree(input: &str) -> Vec { + json_tree_with_options(input, &JsonTreeOptions::default()) +} + +#[must_use] +pub fn json_tree_with_options(input: &str, opts: &JsonTreeOptions) -> Vec { + INVOCATIONS.add(1, &[]); + + if input.len() > opts.max_bytes { + record_error("input_too_large"); + return Vec::new(); + } + + let Ok(root) = serde_json::from_str::(input) else { + record_error("parse"); + return Vec::new(); + }; + let mut ctx = WalkCtx { + rows: Vec::new(), + next_id: 0, + depth_cap_hit: false, + row_cap_hit: false, + }; + visit(&root, None, None, "$", None, 0, opts, &mut ctx); + ROWS_EMITTED.add(ctx.rows.len() as u64, &[]); + ctx.rows +} + +struct WalkCtx { + rows: Vec, + next_id: i64, + depth_cap_hit: bool, + row_cap_hit: bool, +} + +#[expect( + clippy::too_many_arguments, + reason = "walker threads per-node state; collapsing into a struct adds indirection without clarity" +)] +fn visit( + node: &Value, + key: Option, + parent: Option, + fullkey: &str, + path: Option<&str>, + depth: usize, + opts: &JsonTreeOptions, + ctx: &mut WalkCtx, +) { + if ctx.row_cap_hit { + return; + } + if depth > opts.max_depth { + if !ctx.depth_cap_hit { + ctx.depth_cap_hit = true; + record_error("depth_exceeded"); + } + return; + } + if ctx.rows.len() >= opts.max_rows { + if !ctx.row_cap_hit { + ctx.row_cap_hit = true; + record_error("row_cap_hit"); + } + return; + } + let id = ctx.next_id; + ctx.next_id += 1; + + ctx.rows.push(TreeRow { + key, + value: Some(node.to_string()), + type_name: type_of(node).to_owned(), + atom: atom_of(node), + id, + parent, + fullkey: fullkey.to_owned(), + path: path.map(ToOwned::to_owned), + }); + + match node { + Value::Object(map) => { + for (child_key, child) in map { + if ctx.row_cap_hit { + return; + } + let child_fullkey = format!("{fullkey}.{}", escape_object_key(child_key)); + visit( + child, + Some(child_key.clone()), + Some(id), + &child_fullkey, + Some(fullkey), + depth + 1, + opts, + ctx, + ); + } + } + Value::Array(items) => { + for (idx, child) in items.iter().enumerate() { + if ctx.row_cap_hit { + return; + } + let child_fullkey = format!("{fullkey}[{idx}]"); + // DuckDB / SQLite `json_tree` sets `key` to the array index as + // a string so consumers can distinguish array siblings. + visit( + child, + Some(idx.to_string()), + Some(id), + &child_fullkey, + Some(fullkey), + depth + 1, + opts, + ctx, + ); + } + } + _ => {} + } +} + +fn type_of(v: &Value) -> &'static str { + match v { + Value::Null => "null", + Value::Bool(_) => "boolean", + Value::Number(n) if n.is_i64() || n.is_u64() => "integer", + Value::Number(_) => "real", + Value::String(_) => "string", + Value::Array(_) => "array", + Value::Object(_) => "object", + } +} + +fn atom_of(v: &Value) -> Option { + match v { + Value::Null => Some("null".to_owned()), + Value::Bool(b) => Some(b.to_string()), + Value::Number(n) => Some(n.to_string()), + Value::String(s) => Some(s.clone()), + Value::Array(_) | Value::Object(_) => None, + } +} + +fn escape_object_key(key: &str) -> String { + // SQLite / DuckDB JSON-path shorthand (`$.a.b`) accepts identifier-style + // keys only — anything else, including hyphens, must be bracket-quoted so + // consumers can re-parse the `fullkey`. + let first = key.chars().next(); + let simple = first.is_some_and(|c| !c.is_ascii_digit()) + && key.chars().all(|c| c.is_ascii_alphanumeric() || c == '_'); + if simple { + key.to_owned() + } else { + format!("[{}]", serde_json::Value::String(key.to_owned())) + } +} + +// -------- UDTF -------- + +#[derive(Clone, Default)] +pub struct JsonTreeTableFunc; + +impl JsonTreeTableFunc { + #[must_use] + pub fn new() -> Self { + Self + } +} + +impl Debug for JsonTreeTableFunc { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.debug_struct("JsonTreeTableFunc").finish() + } +} + +impl TableFunctionImpl for JsonTreeTableFunc { + fn call(&self, exprs: &[Expr]) -> DataFusionResult> { + let (input, opts) = parse_udtf_args(exprs)?; + let rows = input + .as_deref() + .map(|s| json_tree_with_options(s, &opts)) + .unwrap_or_default(); + Ok(Arc::new(JsonTreeTable { + schema: Arc::clone(&OUTPUT_SCHEMA), + rows, + })) + } +} + +fn parse_udtf_args(exprs: &[Expr]) -> DataFusionResult<(Option, JsonTreeOptions)> { + let mut iter = exprs.iter(); + let first = iter.next().ok_or_else(|| { + DataFusionError::Plan(format!( + "{JSON_TREE_UDTF_NAME}() requires a JSON string argument." + )) + })?; + let input = match first { + Expr::Literal(ScalarValue::Utf8(v) | ScalarValue::LargeUtf8(v), _) => v.clone(), + Expr::Literal(ScalarValue::Null, _) => None, + other => { + return Err(DataFusionError::NotImplemented(format!( + "{JSON_TREE_UDTF_NAME}() currently supports a literal JSON string as the first \ + argument. For per-row / LATERAL invocation, use \ + `UNNEST({JSON_TREE_UDTF_NAME}())`. Got: {other:?}." + ))); + } + }; + + let mut opts = JsonTreeOptions::default(); + for arg in iter { + if let Expr::Literal(scalar, Some(meta)) = arg + && let Some(name) = meta.inner().get("spice.parameter_name") + { + let name = name.clone(); + match name.as_str() { + "max_depth" => opts.max_depth = parse_usize(&name, scalar)?, + "max_rows" => opts.max_rows = parse_usize(&name, scalar)?, + "max_bytes" => opts.max_bytes = parse_usize(&name, scalar)?, + other => { + return Err(DataFusionError::Plan(format!( + "Unknown option '{other}'. Supported: max_depth, max_rows, max_bytes." + ))); + } + } + continue; + } + return Err(DataFusionError::Plan(format!( + "Arguments after the JSON string must be named, e.g. `max_depth => 64`. Got: {arg:?}." + ))); + } + + Ok((input, opts)) +} + +fn parse_usize(name: &str, v: &ScalarValue) -> DataFusionResult { + let n: i64 = match v { + ScalarValue::Int8(Some(n)) => i64::from(*n), + ScalarValue::Int16(Some(n)) => i64::from(*n), + ScalarValue::Int32(Some(n)) => i64::from(*n), + ScalarValue::Int64(Some(n)) => *n, + ScalarValue::UInt8(Some(n)) => i64::from(*n), + ScalarValue::UInt16(Some(n)) => i64::from(*n), + ScalarValue::UInt32(Some(n)) => i64::from(*n), + ScalarValue::UInt64(Some(n)) => i64::try_from(*n) + .map_err(|_| DataFusionError::Plan(format!("{name} must fit in i64, got {n}")))?, + other => { + return Err(DataFusionError::Plan(format!( + "{name} must be an integer, got {other:?}" + ))); + } + }; + usize::try_from(n) + .map_err(|_| DataFusionError::Plan(format!("{name} must be non-negative, got {n}"))) +} + +#[derive(Debug)] +pub struct JsonTreeTable { + schema: SchemaRef, + rows: Vec, +} + +#[async_trait] +impl TableProvider for JsonTreeTable { + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> SchemaRef { + Arc::clone(&self.schema) + } + + fn table_type(&self) -> TableType { + TableType::Base + } + + async fn scan( + &self, + _state: &dyn Session, + projection: Option<&Vec>, + _filters: &[Expr], + _limit: Option, + ) -> DataFusionResult> { + // Single-node only: a bare `DataSourceExec(MemorySourceConfig)` is + // rejected by `EnsureSupportedFileScan` in cluster mode. Distributed + // support requires a dedicated `UdtfArgs` proto variant + codec so + // remote executors can re-invoke the walker; that's follow-up scope. + let batch = rows_to_batch(&self.rows, Arc::clone(&self.schema))?; + let src = MemorySourceConfig::try_new( + &[vec![batch]], + Arc::clone(&self.schema), + projection.cloned(), + )?; + Ok(Arc::new(DataSourceExec::new(Arc::new(src)))) + } +} + +fn rows_to_batch(rows: &[TreeRow], schema: SchemaRef) -> DataFusionResult { + RecordBatch::try_new(schema, build_tree_arrays(rows)) + .map_err(|e| DataFusionError::ArrowError(Box::new(e), None)) +} + +fn build_tree_arrays(rows: &[TreeRow]) -> Vec { + let mut key = StringBuilder::with_capacity(rows.len(), rows.len() * 8); + let mut value = StringBuilder::with_capacity(rows.len(), rows.len() * 32); + let mut type_name = StringBuilder::with_capacity(rows.len(), rows.len() * 4); + let mut atom = StringBuilder::with_capacity(rows.len(), rows.len() * 8); + let mut id = Int64Builder::with_capacity(rows.len()); + let mut parent = Int64Builder::with_capacity(rows.len()); + let mut fullkey = StringBuilder::with_capacity(rows.len(), rows.len() * 16); + let mut path = StringBuilder::with_capacity(rows.len(), rows.len() * 16); + + for row in rows { + match &row.key { + Some(v) => key.append_value(v), + None => key.append_null(), + } + match &row.value { + Some(v) => value.append_value(v), + None => value.append_null(), + } + type_name.append_value(&row.type_name); + match &row.atom { + Some(v) => atom.append_value(v), + None => atom.append_null(), + } + id.append_value(row.id); + match row.parent { + Some(p) => parent.append_value(p), + None => parent.append_null(), + } + fullkey.append_value(&row.fullkey); + match &row.path { + Some(v) => path.append_value(v), + None => path.append_null(), + } + } + + vec![ + Arc::new(key.finish()), + Arc::new(value.finish()), + Arc::new(type_name.finish()), + Arc::new(atom.finish()), + Arc::new(id.finish()), + Arc::new(parent.finish()), + Arc::new(fullkey.finish()), + Arc::new(path.finish()), + ] +} + +// -------- Scalar UDF -------- + +#[derive(Debug, Clone)] +pub struct JsonTreeScalar { + signature: Signature, +} + +impl Default for JsonTreeScalar { + fn default() -> Self { + Self::new() + } +} + +impl JsonTreeScalar { + #[must_use] + pub fn new() -> Self { + Self { + signature: Signature::one_of( + vec![ + TypeSignature::Exact(vec![DataType::Utf8]), + TypeSignature::Exact(vec![DataType::LargeUtf8]), + TypeSignature::Exact(vec![DataType::Utf8View]), + ], + Volatility::Immutable, + ), + } + } +} + +impl PartialEq for JsonTreeScalar { + fn eq(&self, _other: &Self) -> bool { + true + } +} + +impl Eq for JsonTreeScalar {} + +impl std::hash::Hash for JsonTreeScalar { + fn hash(&self, state: &mut H) { + self.name().hash(state); + } +} + +impl ScalarUDFImpl for JsonTreeScalar { + fn as_any(&self) -> &dyn Any { + self + } + fn name(&self) -> &str { + JSON_TREE_UDTF_NAME + } + fn signature(&self) -> &Signature { + &self.signature + } + fn return_type(&self, _arg_types: &[DataType]) -> DataFusionResult { + Ok(ROW_LIST_TYPE.clone()) + } + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DataFusionResult { + let input_col = args + .args + .first() + .ok_or_else(|| { + DataFusionError::Plan(format!( + "{JSON_TREE_UDTF_NAME}() requires a JSON string argument." + )) + })? + .clone(); + + let opts = JsonTreeOptions::default(); + let array = input_col.into_array(args.number_rows)?; + // Signature restricts input to Utf8/LargeUtf8/Utf8View; normalize to + // Utf8 so `as_string_array` below always succeeds. + let normalized = if matches!(array.data_type(), DataType::Utf8) { + array + } else { + cast(&array, &DataType::Utf8) + .map_err(|e| DataFusionError::ArrowError(Box::new(e), None))? + }; + let strings = as_string_array(&normalized); + + // Using `LargeListArray` (i64 offsets) so a large evaluated batch + // cannot overflow and silently drop rows. Per-document caps inside + // `json_tree_with_options` still bound memory use. + let mut all_rows: Vec = Vec::new(); + let mut offsets: Vec = Vec::with_capacity(strings.len() + 1); + offsets.push(0); + + for idx in 0..strings.len() { + if !strings.is_null(idx) { + let rows = json_tree_with_options(strings.value(idx), &opts); + all_rows.extend(rows); + if all_rows.len() > SCALAR_BATCH_MAX_ROWS { + record_error("batch_cap_hit"); + return Err(DataFusionError::Execution(format!( + "{JSON_TREE_UDTF_NAME}(): batch produced more than {SCALAR_BATCH_MAX_ROWS} rows; lower `max_rows` or split the input." + ))); + } + } + // Walker caps bound the row count well under `i64::MAX`, but if + // somehow they didn't, silently saturating would misalign list + // offsets. Fail loud instead so the condition is visible. + let len = i64::try_from(all_rows.len()).map_err(|_| { + DataFusionError::Execution(format!( + "{JSON_TREE_UDTF_NAME}(): flattened row count exceeds LargeList i64 offset range." + )) + })?; + offsets.push(len); + } + + let struct_array = + StructArray::new(TREE_FIELDS.clone(), build_tree_arrays(&all_rows), None); + let list_array = LargeListArray::new( + Arc::new(Field::new( + "item", + DataType::Struct(TREE_FIELDS.clone()), + true, + )), + OffsetBuffer::new(ScalarBuffer::from(offsets)), + Arc::new(struct_array), + None, + ); + Ok(ColumnarValue::Array(Arc::new(list_array))) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn by_fullkey(rows: &[TreeRow]) -> std::collections::HashMap<&str, &TreeRow> { + rows.iter().map(|r| (r.fullkey.as_str(), r)).collect() + } + + #[test] + fn scalar_root_emits_single_row() { + let rows = json_tree("42"); + assert_eq!(rows.len(), 1); + assert_eq!(rows[0].type_name, "integer"); + assert_eq!(rows[0].atom.as_deref(), Some("42")); + assert_eq!(rows[0].fullkey, "$"); + assert!(rows[0].parent.is_none()); + assert!(rows[0].key.is_none()); + } + + #[test] + fn root_object_is_interior_and_children_reference_it() { + let rows = json_tree(r#"{"a": 1, "b": "two"}"#); + assert_eq!(rows.len(), 3); + let by = by_fullkey(&rows); + assert_eq!(by["$.a"].parent, Some(0)); + assert_eq!(by["$.a"].type_name, "integer"); + assert_eq!(by["$.b"].type_name, "string"); + } + + #[test] + fn arrays_index_paths_numerically() { + let rows = json_tree(r#"{"xs": [10, 20, 30]}"#); + let by = by_fullkey(&rows); + assert_eq!(by["$.xs"].type_name, "array"); + assert_eq!(by["$.xs[0]"].atom.as_deref(), Some("10")); + assert_eq!(by["$.xs[2]"].atom.as_deref(), Some("30")); + } + + #[test] + fn depth_first_order_and_monotonic_ids() { + let rows = json_tree(r#"{"a": {"b": 1}, "c": 2}"#); + let ids: Vec = rows.iter().map(|r| r.id).collect(); + assert_eq!(ids, vec![0, 1, 2, 3]); + let fullkeys: Vec<&str> = rows.iter().map(|r| r.fullkey.as_str()).collect(); + assert_eq!(fullkeys, vec!["$", "$.a", "$.a.b", "$.c"]); + } + + #[test] + fn keys_with_special_characters_are_quoted() { + let rows = json_tree(r#"{"with space": 1, "has-hyphen": 2, "_ok": 3, "plain": 4}"#); + let fullkeys: Vec<&str> = rows.iter().map(|r| r.fullkey.as_str()).collect(); + // Space and hyphen both force bracket-quoting so consumers can re-parse. + assert!(fullkeys.contains(&r#"$.["with space"]"#)); + assert!(fullkeys.contains(&r#"$.["has-hyphen"]"#)); + // Identifier-safe keys stay in shorthand form. + assert!(fullkeys.contains(&"$._ok")); + assert!(fullkeys.contains(&"$.plain")); + } + + #[test] + fn malformed_input_yields_zero_rows() { + assert!(json_tree("not json").is_empty()); + assert!(json_tree("").is_empty()); + } + + #[test] + fn deeply_nested_terminates_at_max_depth() { + const NESTING: usize = DEFAULT_MAX_DEPTH + 20; + let mut doc = String::from("0"); + for _ in 0..NESTING { + doc = format!("[{doc}]"); + } + let rows = json_tree(&doc); + assert!(!rows.is_empty()); + assert!(rows.len() <= DEFAULT_MAX_DEPTH + 1); + } + + #[test] + fn max_depth_option_is_honoured() { + let opts = JsonTreeOptions { + max_depth: 2, + ..Default::default() + }; + // depth 0 → root object, depth 1 → "a", depth 2 → "a.b", depth 3 → stop. + let rows = json_tree_with_options(r#"{"a": {"b": {"c": 1}}}"#, &opts); + let fullkeys: Vec<_> = rows.iter().map(|r| r.fullkey.as_str()).collect(); + assert!(fullkeys.contains(&"$.a.b")); + assert!(!fullkeys.contains(&"$.a.b.c")); + } + + #[test] + fn max_rows_caps_output() { + // 50 elements but cap of 10 → 10 rows total (cap includes root). + let doc = "[".to_string() + &(0..49).map(|_| "0,").collect::() + "0]"; + let opts = JsonTreeOptions { + max_rows: 10, + ..Default::default() + }; + let rows = json_tree_with_options(&doc, &opts); + assert_eq!(rows.len(), 10); + } + + #[test] + fn max_bytes_rejects_oversized_input() { + let opts = JsonTreeOptions { + max_bytes: 4, + ..Default::default() + }; + assert!(json_tree_with_options(r#"{"a": 1}"#, &opts).is_empty()); + } + + #[tokio::test] + async fn udtf_table_provider_roundtrips() { + use datafusion::prelude::SessionContext; + let ctx = SessionContext::new(); + let func = JsonTreeTableFunc::new(); + let provider = func + .call(&[Expr::Literal( + ScalarValue::Utf8(Some(r#"{"a": [1, 2]}"#.to_string())), + None, + )]) + .expect("call succeeds"); + let state = ctx.state(); + let plan = provider.scan(&state, None, &[], None).await.expect("scan"); + let results = datafusion::physical_plan::collect(plan, ctx.task_ctx()) + .await + .expect("collect"); + // root object + array + 2 ints = 4 rows. + assert_eq!(results[0].num_rows(), 4); + } + + #[test] + fn scalar_udf_return_type() { + let udf = JsonTreeScalar::new(); + let ty = udf.return_type(&[DataType::Utf8]).expect("return type"); + assert!(matches!(ty, DataType::LargeList(_))); + } + + #[tokio::test] + async fn scan_with_projection_returns_only_requested_columns() { + use datafusion::prelude::SessionContext; + let ctx = SessionContext::new(); + let func = JsonTreeTableFunc::new(); + let provider = func + .call(&[Expr::Literal( + ScalarValue::Utf8(Some(r#"{"a": [1, 2]}"#.to_string())), + None, + )]) + .expect("call succeeds"); + + // Full schema has 8 columns; request only columns 0 (key), 2 (type), 6 (fullkey). + let projection = vec![0usize, 2, 6]; + let state = ctx.state(); + let plan = provider + .scan(&state, Some(&projection), &[], None) + .await + .expect("scan with projection"); + let results = datafusion::physical_plan::collect(plan, ctx.task_ctx()) + .await + .expect("collect"); + assert_eq!(results[0].num_columns(), 3, "expected 3 projected columns"); + assert_eq!(results[0].schema().field(0).name(), "key"); + assert_eq!(results[0].schema().field(1).name(), "type"); + assert_eq!(results[0].schema().field(2).name(), "fullkey"); + } +} diff --git a/crates/runtime/src/datafusion/udtf/mod.rs b/crates/runtime/src/datafusion/udtf/mod.rs new file mode 100644 index 0000000000..ee2e55ff94 --- /dev/null +++ b/crates/runtime/src/datafusion/udtf/mod.rs @@ -0,0 +1,18 @@ +/* +Copyright 2024-2026 The Spice.ai OSS Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +pub mod json_properties; +pub mod json_tree; diff --git a/examples/flatten-json-properties/README.md b/examples/flatten-json-properties/README.md new file mode 100644 index 0000000000..dbb48fd66c --- /dev/null +++ b/examples/flatten-json-properties/README.md @@ -0,0 +1,172 @@ +# `flatten_json_properties` — Searchable Attributes Index + +Turn a dataset of JSON-Schema-shaped documents into a flat, embeddable, +vector-searchable attributes index — entirely from `spicepod.yaml`, with no +pre-processing service. + +This recipe is the worked example from issue +[#10399](https://github.com/spiceai/spiceai/issues/10399). + +## What you'll build + +A Spicepod that: + +1. Ingests a catalog of JSON-Schema documents (one row per schema). +2. Defines a view that calls `flatten_json_properties(body)` on each schema, + producing one row per field across all schemas. +3. Accelerates that view into DuckDB for sub-second query latency. +4. Embeds the per-field `description` column so you can run vector / hybrid + search across field descriptions. + +Query results look like: + +```sql +-- Find fields that mention "customer demographics" across every schema. +SELECT schema_id, path, name, description +FROM vector_search('attributes', 'customer demographics signal', 10); +``` + +## Files + +- `spicepod.yaml` — the Spicepod manifest. +- `sample_schemas.json` — three sample JSON-Schema documents, enough to see + the view populate without standing up an API. + +## Running locally + +1. Start Spice in this directory: + + ```bash + spice run + ``` + +2. In another terminal, connect to the SQL REPL: + + ```bash + spice sql + ``` + +3. Inspect the raw catalog: + + ```sql + SELECT id, title FROM schemas; + ``` + +4. Inspect the flattened attributes view: + + ```sql + SELECT schema_id, path, name, type, description + FROM attributes + ORDER BY schema_id, path + LIMIT 20; + ``` + +5. Try a vector search over field descriptions: + + ```sql + SELECT schema_id, path, name, description + FROM vector_search('attributes', 'customer email', 5); + ``` + +## How it works + +### 1. Ingest the catalog + +The `schemas` dataset reads `sample_schemas.json` as-is, exposing `id`, +`title`, and `body` (the raw schema document) as columns. + +### 2. Flatten with `flatten_json_properties` + +`flatten_json_properties(body)` walks each schema's `properties` tree and +emits one row per field with these columns: + +| column | type | description | +|---------------|--------------|----------------------------------------------------| +| `path` | Utf8 | Dotted path, e.g. `user.address.street` | +| `parent_path` | Utf8 | Everything but the leaf | +| `name` | Utf8 | Leaf field name | +| `description` | Utf8 | From the field's `description` annotation | +| `type` | Utf8 | `string`, `integer`, `object`, `array`, `map`, `ref`, … | +| `required` | Boolean | Inferred from the ancestor's `required:[...]` | +| `format` | Utf8 | e.g. `date-time`, `uuid` | +| `enum_values` | List | Present when the field declares `enum` | +| `metadata` | Utf8 | Full field spec JSON — query with `json_get(metadata, '$.x-custom')` | + +The function handles `items.properties` (arrays of objects), +`additionalProperties` maps, `allOf` / `oneOf` / `anyOf` merge, and local +`$ref` pointers with cycle detection. External `$ref` URIs are emitted as a +row with `type = 'ref'` and are never dereferenced. + +### 3. Per-row LATERAL via UNNEST + +The view uses the scalar form of `flatten_json_properties` combined with +`UNNEST`, which gives row-level evaluation: + +```sql +SELECT s.id AS schema_id, a.* +FROM schemas s, + UNNEST(flatten_json_properties(s.body)) AS a +``` + +The UDTF form (`FROM flatten_json_properties('{...}')`) exists for ad-hoc +testing with literal inputs. + +### 4. Acceleration + embeddings + +`acceleration.enabled: true` materializes the view into DuckDB on a refresh +schedule. The `description` column has an `embeddings:` block so each row's +description is embedded once and stored next to it, making vector search a +single-hop lookup. + +### 5. Options + +Pass named options to tune the walker: + +```sql +SELECT * +FROM flatten_json_properties( + '{"properties": {"a": {"type": "string"}}}', + max_depth => 16, + max_rows => 10000, + include_internal => true, -- also emit object/array/map rows + path_style => 'json-pointer', + dialect => 'json-schema' +); +``` + +| option | type | default | notes | +|--------------------|--------|---------------|-----------------------------------------------| +| `max_depth` | UInt | `32` | walk stops past this depth | +| `max_rows` | UInt | `100000` | per-document row cap | +| `max_bytes` | UInt | `8_388_608` | input size limit (8 MiB) | +| `dialect` | Utf8 | `json-schema` | `json-schema` \| `openapi` | +| `include_internal` | Bool | `false` | emit container rows (`object`, `array`, `map`) | +| `path_style` | Utf8 | `dot` | `dot` (`a.b.c`) \| `json-pointer` (`/a/b/c`) | + +### 6. `json_tree` — generic alternative + +If your input isn't JSON-Schema-shaped, reach for `json_tree`. It's a +schema-agnostic recursive walker with DuckDB/SQLite-compatible output (cols +`key`, `value`, `type`, `atom`, `id`, `parent`, `fullkey`, `path`): + +```sql +SELECT key, type, atom, fullkey +FROM json_tree('{"a": [1, 2], "b": {"c": "hi"}}'); +``` + +## Telemetry + +The walker emits the following OpenTelemetry counters, scraped by any +configured metrics exporter: + +- `flatten_json_properties_invocations_total{dialect}` +- `flatten_json_properties_rows_emitted_total` +- `flatten_json_properties_errors_total{kind}` where + `kind ∈ {parse, depth_exceeded, row_cap_hit, cycle, input_too_large}` +- `json_tree_invocations_total` / `json_tree_rows_emitted_total` / + `json_tree_errors_total{kind}` + +## See also + +- Issue: https://github.com/spiceai/spiceai/issues/10399 +- PR: https://github.com/spiceai/spiceai/pull/10406 diff --git a/examples/flatten-json-properties/sample_schemas.json b/examples/flatten-json-properties/sample_schemas.json new file mode 100644 index 0000000000..453845e9ee --- /dev/null +++ b/examples/flatten-json-properties/sample_schemas.json @@ -0,0 +1,78 @@ +[ + { + "id": "user-v1", + "title": "User", + "body": { + "$schema": "https://json-schema.org/draft/2020-12/schema", + "title": "User", + "type": "object", + "required": ["id", "email"], + "properties": { + "id": {"type": "string", "format": "uuid", "description": "Stable user identifier"}, + "email": {"type": "string", "format": "email", "description": "Primary contact email"}, + "name": {"type": "string", "description": "Full legal name"}, + "created_at": {"type": "string", "format": "date-time", "description": "When the account was created"}, + "status": {"type": "string", "enum": ["active", "pending", "disabled"], "description": "Account lifecycle status"}, + "preferences": { + "type": "object", + "additionalProperties": {"type": "string"}, + "description": "Arbitrary user preferences as key-value pairs" + } + } + } + }, + { + "id": "order-v1", + "title": "Order", + "body": { + "$schema": "https://json-schema.org/draft/2020-12/schema", + "title": "Order", + "type": "object", + "required": ["id", "customer_id", "line_items"], + "properties": { + "id": {"type": "string", "format": "uuid", "description": "Order identifier"}, + "customer_id": {"type": "string", "format": "uuid", "description": "Reference to the User that placed this order"}, + "total_cents": {"type": "integer", "description": "Order grand total in cents"}, + "line_items": { + "type": "array", + "description": "Line items on this order", + "items": { + "type": "object", + "required": ["sku"], + "properties": { + "sku": {"type": "string", "description": "Stock keeping unit identifier"}, + "quantity": {"type": "integer", "description": "Units ordered"}, + "unit_price_cents": {"type": "integer", "description": "Per-unit price at time of order"} + } + } + } + } + } + }, + { + "id": "address-v1", + "title": "Address", + "body": { + "$schema": "https://json-schema.org/draft/2020-12/schema", + "title": "Address", + "$defs": { + "Country": { + "type": "object", + "properties": { + "code": {"type": "string", "description": "ISO 3166-1 alpha-2 country code"}, + "name": {"type": "string", "description": "Country name in English"} + } + } + }, + "type": "object", + "properties": { + "street": {"type": "string", "description": "Street address, first line"}, + "street_2": {"type": "string", "description": "Street address, second line"}, + "city": {"type": "string", "description": "City or municipality"}, + "region": {"type": "string", "description": "State, province, or administrative region"}, + "postal": {"type": "string", "description": "Postal or ZIP code"}, + "country": {"$ref": "#/$defs/Country", "description": "Country reference"} + } + } + } +] From 53746444d63034a3c132f28cb2c92ade49daf8ea Mon Sep 17 00:00:00 2001 From: Luke Kim <80174+lukekim@users.noreply.github.com> Date: Mon, 20 Apr 2026 22:04:31 -0700 Subject: [PATCH 3/4] Harden /v1/tools and /v1/nsql against unauthenticated / LLM-driven SQL (#10365) * Harden /v1/tools and /v1/nsql against unauthenticated / LLM-driven SQL Addresses threat model items #50 and #51 (docs/threat_models/v2.0.0.md): - Add strict read-only SQL validator (validate_sql_query_read_only) that rejects every DDL/DML/COPY/non-prepared Statement node regardless of per-catalog writability. - Plumb a read_only flag through QueryBuilder/Query and apply the validator at all three plan execution sites (local, Ballista, async). - Default the built-in `sql` tool to read-only; operators may opt in via SqlTool::allow_writes(). LLM tool-use can no longer mutate data through the sql tool. - Run LLM-generated SQL from /v1/nsql under the read-only validator so prompt-injection-driven writes cannot reach writable catalogs. - Gate /v1/tools/* behind a require_auth_configured middleware: when runtime.auth is not set, these routes return 401 rather than invoking tool.call anonymously with attacker-controlled bodies. - Record the new mitigations in the v2.0.0 threat model. * refactor: clarify read-only SQL validation comments and enhance documentation for DDL/DML restrictions * Refactor authentication error response to use JSON format and add SQL tool descriptions for read-only and writable modes * Fix collapsible_if clippy lint in read-only validation path * Reject write-capable extension nodes in read-only validator Spice's planner can represent DDL/DML as LogicalPlan::Extension nodes (DdlExtensionNode, DmlExtensionNode, DistributedCayenne{Insert,Update, Delete,Merge}Node, CayenneMergeNode). The previous read-only validator only matched Ddl/Dml/Copy/Statement and would have let those plan shapes through, defeating the read-only guarantee on /v1/tools/sql and /v1/nsql. - Add Extension arm to validate_sql_query_read_only that denies any node whose UserDefinedLogicalNodeCore::name matches a curated list of write-capable extension names. - Test the deny mechanism with a stub UserDefinedLogicalNode and verify a non-write extension name is still allowed. - Add an integration test that exercises Spice's create_logical_plan wrapper end-to-end (cfg(not(windows))). - Reflect the PREPARE/EXECUTE/DEALLOCATE rejection in the SqlTool read-only description so LLM/tool-selection logic knows the posture. - Replace the PR-contextual 'Unverified in this review' phrasing in the threat model with the durable 'Unverified mitigation'. * Bypass SQL results cache for read-only query paths When ctx.read_only is set (e.g. the /v1/tools/sql read-only tool and the /v1/nsql LLM SQL path), both the SQL-keyed and plan-keyed results-cache lookups are now skipped inside get_plan_or_cached, and the returned RequestCacheManager is forced to CacheDisabled. Previously, a cache hit from a prior writable execution could short-circuit validate_sql_query_read_only, letting a cached result produced by a write-capable plan (e.g. LogicalPlan::Extension nodes like DmlExtension or DistributedCayenneInsert) be served on a read-only surface. Also move WRITE_CAPABLE_EXTENSION_NAMES into the cache crate as the single source of truth, and extend cache_is_enabled_for_plan to reject write-capable LogicalPlan::Extension nodes. Defense-in-depth: even on writable paths, write-capable extension plans are now never cached or populated in the results cache. * fix: flatten write-capable extension check to match guard in cache eligibility Removes one level of nesting as requested in review. --------- Co-authored-by: Viktor Yershov --- crates/cache/src/lib.rs | 30 ++ crates/runtime/src/datafusion/query.rs | 42 ++- .../runtime/src/datafusion/query/builder.rs | 19 ++ crates/runtime/src/datafusion/query/cache.rs | 97 ++++-- .../runtime/src/datafusion/sql_validator.rs | 319 ++++++++++++++++++ crates/runtime/src/http/routes.rs | 45 ++- crates/runtime/src/http/v1/nsql.rs | 9 +- crates/runtime/src/tools/builtin/sql.rs | 46 ++- docs/threat_models/v2.0.0.md | 310 +++++++++++++++++ 9 files changed, 875 insertions(+), 42 deletions(-) create mode 100644 docs/threat_models/v2.0.0.md diff --git a/crates/cache/src/lib.rs b/crates/cache/src/lib.rs index 6bc9cde51c..00b8359bef 100644 --- a/crates/cache/src/lib.rs +++ b/crates/cache/src/lib.rs @@ -59,6 +59,31 @@ pub use utils::filter_transient_error_responses; pub use utils::get_logical_plan_input_tables; pub use utils::to_cached_record_batch_stream; +/// Stable [`datafusion::logical_expr::UserDefinedLogicalNodeCore::name`] values for +/// every Spice logical-plan extension node that performs (or dispatches) a write, +/// a schema mutation, or any other side-effect that must not be reachable via a +/// read-only SQL path and must not be served from or populated into the SQL +/// results cache. +/// +/// Keep this list in sync with: +/// - `datafusion_ddl::DdlExtensionNode` → `"DdlExtension"` +/// - `datafusion_dml::DmlExtensionNode` → `"DmlExtension"` +/// - `cayenne::ddl::logical_nodes::CayenneMergeNode` → `"CayenneMerge"` +/// - `runtime::datafusion::cayenne_ddl::logical_nodes::DistributedCayenne{Insert,Update,Delete}Node` +/// → `"CayenneInsert"` / `"CayenneUpdate"` / `"CayenneDelete"` (they reuse the +/// non-distributed names by design) +/// - `runtime::datafusion::cayenne_ddl::logical_nodes::DistributedCayenneMergeNode` +/// → `"DistributedCayenneMerge"` +pub const WRITE_CAPABLE_EXTENSION_NAMES: &[&str] = &[ + "DdlExtension", + "DmlExtension", + "CayenneInsert", + "CayenneUpdate", + "CayenneDelete", + "CayenneMerge", + "DistributedCayenneMerge", +]; + use crate::result::embeddings::CachedEmbeddingResult; #[derive(Debug, Snafu)] @@ -551,6 +576,11 @@ impl QueryResultsCacheProvider { | LogicalPlan::Dml(..) | LogicalPlan::Copy { .. } | LogicalPlan::Statement(..) => return false, + LogicalPlan::Extension(ext) + if WRITE_CAPABLE_EXTENSION_NAMES.contains(&ext.node.name()) => + { + return false; + } _ => {} } diff --git a/crates/runtime/src/datafusion/query.rs b/crates/runtime/src/datafusion/query.rs index d9546e7739..aec0160e6b 100644 --- a/crates/runtime/src/datafusion/query.rs +++ b/crates/runtime/src/datafusion/query.rs @@ -81,7 +81,9 @@ use super::{ use super::managed_runtime; use crate::datafusion::{ - DataFusion, query::cache::RequestCacheManager, sql_validator::validate_sql_query_operations, + DataFusion, + query::cache::RequestCacheManager, + sql_validator::{validate_sql_query_operations, validate_sql_query_read_only}, }; use managed_runtime::ManagedRuntimeError; use opentelemetry::KeyValue; @@ -190,6 +192,11 @@ pub struct Query { df: Arc, sql: QueryMethod, tracker: Option, + /// When true, the validator additionally rejects DDL, DML, COPY, or any + /// `LogicalPlan::Statement` node (including PREPARE/EXECUTE/DEALLOCATE), + /// regardless of per-catalog writability. Set via [`QueryBuilder::read_only`]; + /// used by `/v1/tools/sql` and `/v1/nsql` to contain LLM-generated SQL. + read_only: bool, } macro_rules! handle_error { @@ -310,7 +317,9 @@ impl Query { sql, parameters, .. } => { // Use the existing get_plan_or_cached which handles all cache control, - // stale-while-revalidate, and query tracking + // stale-while-revalidate, and query tracking. `read_only` is + // threaded through so cached results cannot bypass + // `validate_sql_query_read_only` below. match Query::get_plan_or_cached( &self.df, &session, @@ -318,6 +327,7 @@ impl Query { sql, parameters.clone(), tracker, + self.read_only, ) .await? { @@ -389,6 +399,12 @@ impl Query { let e = find_datafusion_root(e); return Err(Error::UnableToExecuteQuery { source: e }); } + if self.read_only + && let Err(e) = validate_sql_query_read_only(&plan) + { + let e = find_datafusion_root(e); + return Err(Error::UnableToExecuteQuery { source: e }); + } // Get the schema from the logical plan let schema = Arc::new(plan.schema().as_arrow().clone()); @@ -589,6 +605,7 @@ impl Query { sql, parameters.clone(), tracker, + ctx.read_only, ) .await? { @@ -619,6 +636,19 @@ impl Query { ) } + if ctx.read_only + && let Err(e) = validate_sql_query_read_only(&plan) + { + let e = find_datafusion_root(e); + handle_error!( + tracker, + &request_context, + ErrorCode::QueryPlanningError, + e, + UnableToExecuteQuery + ) + } + // Proactively invalidate cached query state for tables affected by // DML mutations (INSERT, DELETE, UPDATE). // - results cache must be cleared so repeated SQL does not replay @@ -878,6 +908,7 @@ impl Query { df: Arc::clone(df), sql: QueryMethod::Plan(Box::new(plan.clone())), tracker: None, + read_only: false, } } @@ -930,6 +961,13 @@ impl Query { self.handle_schema_error(&request_context, &e); return Err(e); } + if self.read_only + && let Err(e) = validate_sql_query_read_only(&plan) + { + let e = find_datafusion_root(e); + self.handle_schema_error(&request_context, &e); + return Err(e); + } let dataset_schema = plan.schema().as_arrow().clone(); let parameter_schema = parameter_schema_for_plan(&plan)?; diff --git a/crates/runtime/src/datafusion/query/builder.rs b/crates/runtime/src/datafusion/query/builder.rs index 999620fc98..370a73cdd4 100644 --- a/crates/runtime/src/datafusion/query/builder.rs +++ b/crates/runtime/src/datafusion/query/builder.rs @@ -31,6 +31,7 @@ pub struct QueryBuilder<'a> { parameters: Option, table_allowlist: Option, query_id: Uuid, + read_only: bool, } impl<'a> QueryBuilder<'a> { @@ -41,6 +42,7 @@ impl<'a> QueryBuilder<'a> { parameters: None, query_id: Uuid::new_v4(), table_allowlist: None, + read_only: false, } } @@ -62,6 +64,22 @@ impl<'a> QueryBuilder<'a> { self } + /// Enforce read-only SQL execution. + /// + /// When enabled, the planned query is additionally checked with + /// [`crate::datafusion::sql_validator::validate_sql_query_read_only`] and rejected if it + /// contains any DDL, DML, COPY, or `LogicalPlan::Statement` node (including + /// `PREPARE`/`EXECUTE`/`DEALLOCATE`) — regardless of whether the target + /// catalogs/datasets are individually marked writable. + /// + /// Used by surfaces that execute SQL on behalf of an LLM or unauthenticated caller + /// (the built-in `sql` tool, `/v1/nsql`). + #[must_use] + pub fn read_only(mut self, read_only: bool) -> Self { + self.read_only = read_only; + self + } + #[must_use] pub fn build(self) -> Query { let sql: Arc = self.sql.into(); @@ -91,6 +109,7 @@ impl<'a> QueryBuilder<'a> { table_allowlist: self.table_allowlist, }, tracker, + read_only: self.read_only, } } } diff --git a/crates/runtime/src/datafusion/query/cache.rs b/crates/runtime/src/datafusion/query/cache.rs index 9da92c7202..3d5adee69f 100644 --- a/crates/runtime/src/datafusion/query/cache.rs +++ b/crates/runtime/src/datafusion/query/cache.rs @@ -96,6 +96,19 @@ enum CacheResult { impl Query { /// Returns a `LogicalPlan` if the result is not cached and needs to be executed, otherwise returns a cached `QueryResult`. + /// + /// When `read_only` is true, both the SQL-keyed and plan-keyed results-cache + /// lookups are skipped, and the returned [`RequestCacheManager`] is forced to + /// [`CacheStatus::CacheDisabled`]. This is required because the read-only + /// contract (enforced by [`super::validate_sql_query_read_only`]) runs on the + /// [`LogicalPlan`] *after* `get_plan_or_cached` returns — a cache hit would + /// otherwise short-circuit validation and let write-capable plans served from + /// a prior cache-store bypass the read-only guarantee on `/v1/tools/sql` and + /// `/v1/nsql`. The existing cache-eligibility check + /// ([`cache::QueryResultsCacheProvider::cache_is_enabled_for_plan`]) only + /// filters the classic DDL/DML/Copy/Statement plan variants and does not + /// cover Spice's write-capable [`LogicalPlan::Extension`] nodes (e.g. + /// `DmlExtension`, `DistributedCayenneInsert`). pub(super) async fn get_plan_or_cached( df: &Arc, session: &SessionState, @@ -103,6 +116,7 @@ impl Query { sql: &str, parameters: Option, tracker: Option, + read_only: bool, ) -> super::Result { let cache_control = request_context.cache_control(); let sql_cache_key = CacheKey::Query(sql, parameters.as_ref()); @@ -117,25 +131,32 @@ impl Query { _ => sql_cache_key, }; - // Try to get cached results from SQL or client key + // Try to get cached results from SQL or client key. When `read_only` is + // true, skip the cache lookup entirely so read-only validation always + // gets a chance to run on the freshly-planned query. let CacheResponse { tracker, raw_key: sql_or_client_raw_key, .. - } = match Self::try_get_cached_result( - df, - &request_context, - tracker, - &sql_or_user_cache_key, - sql, - ) - .await? - { - CacheResponse { - result: CacheResult::Hit(result), - .. - } => return Ok(PlanOrCached::Cached(result)), - response => response, + } = if read_only { + CacheResponse::from(CacheResult::MissOrSkipped, CacheStatus::CacheDisabled) + .with_query_tracker(tracker) + } else { + match Self::try_get_cached_result( + df, + &request_context, + tracker, + &sql_or_user_cache_key, + sql, + ) + .await? + { + CacheResponse { + result: CacheResult::Hit(result), + .. + } => return Ok(PlanOrCached::Cached(result)), + response => response, + } }; let sql_raw_cache_key = sql_cache_key.as_raw_key(Self::plan_hasher(df)); @@ -154,26 +175,32 @@ impl Query { } }; - // Try to get cached results from plan + // Try to get cached results from plan (skipped for read-only, same + // reasoning as the SQL-keyed lookup above). let CacheResponse { mut tracker, raw_key: plan_raw_cache_key, status, .. - } = match Self::try_get_cached_result( - df, - &request_context, - tracker, - &CacheKey::LogicalPlan(&plan), - sql, - ) - .await? - { - CacheResponse { - result: CacheResult::Hit(result), - .. - } => return Ok(PlanOrCached::Cached(result)), - response => response, + } = if read_only { + CacheResponse::from(CacheResult::MissOrSkipped, CacheStatus::CacheDisabled) + .with_query_tracker(tracker) + } else { + match Self::try_get_cached_result( + df, + &request_context, + tracker, + &CacheKey::LogicalPlan(&plan), + sql, + ) + .await? + { + CacheResponse { + result: CacheResult::Hit(result), + .. + } => return Ok(PlanOrCached::Cached(result)), + response => response, + } }; let request_raw_cache_key = match request_context.cache_control() { @@ -185,7 +212,15 @@ impl Query { } .unwrap_or(sql_raw_cache_key); - let cache_status = Self::should_cache_results(df, &plan, status); + // Read-only requests must also not populate the results cache — the + // plan has not yet been validated at this point, and we don't want a + // writable surface's cached output to leak back through a read-only + // caller on a later identical query. + let cache_status = if read_only { + CacheStatus::CacheDisabled + } else { + Self::should_cache_results(df, &plan, status) + }; tracker = tracker.map(|t| t.results_cache_hit(false)); Ok(PlanOrCached::Plan( diff --git a/crates/runtime/src/datafusion/sql_validator.rs b/crates/runtime/src/datafusion/sql_validator.rs index 4ba58bf987..2b0ffb848d 100644 --- a/crates/runtime/src/datafusion/sql_validator.rs +++ b/crates/runtime/src/datafusion/sql_validator.rs @@ -24,6 +24,12 @@ use datafusion::{ use crate::datafusion::DataFusion; +// Re-export the single-source-of-truth list of write-capable extension node +// names. The `cache` crate owns this list so that both the read-only validator +// below and the SQL results-cache eligibility check in `cache` stay in lockstep +// — any write-capable extension must be non-cacheable AND blocked by read-only. +pub(super) use cache::WRITE_CAPABLE_EXTENSION_NAMES; + /// Validates that a logical plan only performs allowed operations on datasets. /// /// Reads (SELECT queries) are allowed on all tables. @@ -215,6 +221,58 @@ fn validate_ddl_operation( ) } +/// Strict read-only validator. +/// +/// Rejects any plan containing DDL, DML, COPY, or any `LogicalPlan::Statement` node +/// (including `PREPARE` / `EXECUTE` / `DEALLOCATE`). `EXECUTE` can indirectly invoke a +/// prepared DDL/DML statement and `PREPARE` / `DEALLOCATE` mutate session state, so all +/// three are disallowed on surfaces that must guarantee read-only execution — notably +/// the built-in `sql` tool and the LLM-generated SQL path in `/v1/nsql`. +/// +/// Spice's planner can also represent DDL/DML as [`LogicalPlan::Extension`] nodes +/// (for example, `DdlExtensionNode` from `datafusion-ddl`, `DmlExtensionNode` from +/// `datafusion-dml`, and the `DistributedCayenne{Insert,Update,Delete,Merge}` / +/// `CayenneMerge` distributed DML nodes). Those nodes are matched here by their +/// stable [`UserDefinedLogicalNodeCore::name`] so that write-capable plans produced +/// by Spice's custom planner cannot bypass the read-only guarantee. Any new +/// write-capable extension node type MUST be added to [`WRITE_CAPABLE_EXTENSION_NAMES`]. +/// +/// # Returns +/// * `Ok(())` if the plan contains only read operations. +/// * `Err(DataFusionError)` if the plan contains any write, schema-mutating, or +/// session-mutating operation. +pub fn validate_sql_query_read_only(plan: &LogicalPlan) -> Result<(), DataFusionError> { + plan.apply_with_subqueries(|node| match node { + LogicalPlan::Ddl(ddl) => plan_err!( + "DDL operation '{}' is not allowed in read-only SQL context.", + ddl.name() + ), + LogicalPlan::Dml(dml) => plan_err!( + "{} operations are not allowed in read-only SQL context.", + dml.name() + ), + LogicalPlan::Copy(_) => { + plan_err!("COPY operations are not allowed in read-only SQL context.") + } + LogicalPlan::Statement(stmt) => plan_err!( + "Statement '{}' is not allowed in read-only SQL context.", + stmt.name() + ), + LogicalPlan::Extension(ext) => { + let name = ext.node.name(); + if WRITE_CAPABLE_EXTENSION_NAMES.contains(&name) { + plan_err!( + "Write-capable extension plan '{name}' is not allowed in read-only SQL context." + ) + } else { + Ok(TreeNodeRecursion::Continue) + } + } + _ => Ok(TreeNodeRecursion::Continue), + })?; + Ok(()) +} + #[cfg(test)] mod tests { use crate::{ @@ -835,4 +893,265 @@ mod tests { "INSERT should be allowed on table in writable default catalog" ); } + + /// [`validate_sql_query_read_only`] must allow SELECT but reject every class of + /// write/schema-mutating plan, independent of per-catalog writability. This is the + /// contract that the built-in `sql` tool and `/v1/nsql` rely on to contain + /// LLM-generated SQL. + #[tokio::test] + async fn test_read_only_validator_allows_select() { + let df = create_test_datafusion(); + + let plan = df + .ctx + .state() + .create_logical_plan("SELECT * FROM tbl_writable") + .await + .expect("plan should be created"); + + validate_sql_query_read_only(&plan).expect("SELECT must be allowed in read-only context"); + } + + #[tokio::test] + async fn test_read_only_validator_rejects_insert_on_writable_dataset() { + let df = create_test_datafusion(); + + let plan = df + .ctx + .state() + .create_logical_plan("INSERT INTO tbl_writable VALUES (1, 'foo', 42.0)") + .await + .expect("plan should be created"); + + let err = validate_sql_query_read_only(&plan) + .expect_err("INSERT must be rejected in read-only context"); + assert!( + err.to_string().contains("read-only"), + "error should cite read-only context, got: {err}" + ); + } + + #[tokio::test] + async fn test_read_only_validator_rejects_delete_on_writable_dataset() { + let df = create_test_datafusion(); + + let plan = df + .ctx + .state() + .create_logical_plan("DELETE FROM tbl_writable WHERE id = 1") + .await + .expect("plan should be created"); + + validate_sql_query_read_only(&plan) + .expect_err("DELETE must be rejected in read-only context"); + } + + #[tokio::test] + async fn test_read_only_validator_rejects_ddl() { + let df = create_test_datafusion(); + + let plan = df + .ctx + .state() + .create_logical_plan("DROP TABLE IF EXISTS tbl_writable") + .await + .expect("plan should be created"); + + validate_sql_query_read_only(&plan).expect_err("DDL must be rejected in read-only context"); + } + + #[tokio::test] + async fn test_read_only_validator_rejects_copy() { + let df = create_test_datafusion(); + + let plan = df + .ctx + .state() + .create_logical_plan("COPY tbl_writable TO '/tmp/out.parquet'") + .await + .expect("plan should be created"); + + let err = validate_sql_query_read_only(&plan) + .expect_err("COPY must be rejected in read-only context"); + assert!( + err.to_string().contains("COPY"), + "error should cite COPY, got: {err}" + ); + } + + /// `PREPARE` mutates session state and the prepared statement could later be + /// `EXECUTE`d to run DDL/DML. The strict read-only validator must reject it. + #[tokio::test] + async fn test_read_only_validator_rejects_prepare() { + let df = create_test_datafusion(); + + let plan = df + .ctx + .state() + .create_logical_plan("PREPARE my_plan AS SELECT * FROM tbl_writable") + .await + .expect("plan should be created"); + + validate_sql_query_read_only(&plan) + .expect_err("PREPARE must be rejected in read-only context"); + } + + /// `EXECUTE` can invoke a prepared DDL/DML statement and must therefore be + /// rejected by the strict read-only validator. + #[tokio::test] + async fn test_read_only_validator_rejects_execute() { + let df = create_test_datafusion(); + + // PREPARE first to get a prepared plan on the session, then verify EXECUTE + // is rejected. PREPARE itself is also rejected, so run it through the + // non-strict path by bypassing the validator. + df.ctx + .sql("PREPARE my_plan AS SELECT 1") + .await + .expect("prepare should succeed"); + + let plan = df + .ctx + .state() + .create_logical_plan("EXECUTE my_plan") + .await + .expect("plan should be created"); + + validate_sql_query_read_only(&plan) + .expect_err("EXECUTE must be rejected in read-only context"); + } + + #[tokio::test] + async fn test_read_only_validator_rejects_deallocate() { + let df = create_test_datafusion(); + + df.ctx + .sql("PREPARE my_plan AS SELECT 1") + .await + .expect("prepare should succeed"); + + let plan = df + .ctx + .state() + .create_logical_plan("DEALLOCATE my_plan") + .await + .expect("plan should be created"); + + validate_sql_query_read_only(&plan) + .expect_err("DEALLOCATE must be rejected in read-only context"); + } + + /// Spice's custom planner represents DDL/DML as [`LogicalPlan::Extension`] + /// nodes (e.g. `DdlExtensionNode`, `DmlExtensionNode`, `DistributedCayenne*Node`). + /// The strict read-only validator must reject those by + /// [`UserDefinedLogicalNodeCore::name`] so a write-capable plan cannot + /// bypass the check by being wrapped in an extension node. + /// + /// Constructing a real `DmlExtensionNode` requires a full catalog-handler + /// wiring, so this test uses a minimal stub extension node whose `.name()` + /// matches one of the names in [`WRITE_CAPABLE_EXTENSION_NAMES`] to + /// exercise the name-based deny directly. + #[tokio::test] + async fn test_read_only_validator_rejects_write_capable_extension_node() { + use datafusion::{ + common::{DFSchema, DFSchemaRef}, + logical_expr::{Expr, Extension, LogicalPlan, UserDefinedLogicalNodeCore}, + }; + use std::cmp::Ordering; + use std::fmt; + + #[derive(Debug, Clone, PartialEq, Eq, Hash)] + struct StubWriteExtension { + schema: DFSchemaRef, + name: &'static str, + } + + impl PartialOrd for StubWriteExtension { + fn partial_cmp(&self, other: &Self) -> Option { + self.name.partial_cmp(other.name) + } + } + + impl UserDefinedLogicalNodeCore for StubWriteExtension { + fn name(&self) -> &'static str { + self.name + } + fn inputs(&self) -> Vec<&LogicalPlan> { + vec![] + } + fn schema(&self) -> &DFSchemaRef { + &self.schema + } + fn expressions(&self) -> Vec { + vec![] + } + fn fmt_for_explain(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "StubWriteExtension({})", self.name) + } + fn with_exprs_and_inputs( + &self, + _exprs: Vec, + _inputs: Vec, + ) -> Result { + Ok(self.clone()) + } + } + + let schema: DFSchemaRef = Arc::new(DFSchema::empty()); + + for banned_name in super::WRITE_CAPABLE_EXTENSION_NAMES { + let plan = LogicalPlan::Extension(Extension { + node: Arc::new(StubWriteExtension { + schema: Arc::clone(&schema), + name: banned_name, + }), + }); + let err = validate_sql_query_read_only(&plan) + .err() + .unwrap_or_else(|| { + panic!("extension '{banned_name}' must be rejected in read-only context") + }); + let err_msg = err.to_string(); + assert!( + err_msg.contains(banned_name) && err_msg.contains("read-only"), + "error should cite '{banned_name}' and read-only, got: {err}" + ); + } + + // A benign (non-write) extension name must still be allowed so + // read-only optimizer extensions such as `IndexTableScanNode` and + // `DuckDBAggregatePushdownNode` are not blocked. + let plan = LogicalPlan::Extension(Extension { + node: Arc::new(StubWriteExtension { + schema, + name: "IndexTableScanNode", + }), + }); + validate_sql_query_read_only(&plan) + .expect("non-write extension must be allowed in read-only context"); + } + + /// Integration check: DDL/DML produced through Spice's planner wrapper + /// ([`DataFusion::create_logical_plan`]) — rather than `DataFusion`'s raw + /// planner — must still be rejected. This covers the statement-level + /// rewrites (`plan_distributed_dml`, `plan_create_table`, etc.) that can + /// emit `LogicalPlan::Extension` instead of `LogicalPlan::{Ddl,Dml}`. + #[cfg(not(windows))] + #[tokio::test] + async fn test_read_only_validator_rejects_dml_via_spice_planner() { + let df = create_test_datafusion(); + let session = df.ctx.state(); + + // INSERT is dispatched through Spice's `plan_distributed_dml`. Without a + // distributed cluster it returns a standard `LogicalPlan::Dml`, still + // covered by the read-only arm — but this exercises the Spice wrapper + // rather than `SessionState::create_logical_plan` directly. + let plan = df + .create_logical_plan(&session, "INSERT INTO tbl_writable VALUES (1, 'foo', 42.0)") + .await + .expect("plan should be created via Spice planner"); + + validate_sql_query_read_only(&plan) + .expect_err("INSERT via Spice planner must be rejected in read-only context"); + } } diff --git a/crates/runtime/src/http/routes.rs b/crates/runtime/src/http/routes.rs index 2bdd1e873c..4ffc3e6ca5 100644 --- a/crates/runtime/src/http/routes.rs +++ b/crates/runtime/src/http/routes.rs @@ -263,6 +263,23 @@ pub(crate) fn routes( } if cfg!(feature = "models") { + // Tool invocation routes require authentication to be configured on the runtime. + // `/v1/tools/{name}` forwards the raw request body to `tool.call`, which for + // built-in tools like `sql` and `websearch` is equivalent to arbitrary query / + // egress. When no `runtime.auth` provider is attached the request would be + // anonymous, so we refuse these routes at the edge with a 401 rather than + // relying on each tool to enforce its own safety posture. Configure + // `runtime.auth.api_key` (or any future provider) to re-enable this surface. + let tools_auth_required = auth_layer.is_some(); + let tools_router = Router::new() + .route("/v1/tools", get(v1::tools::list)) + .route("/v1/tools/{*name}", post(v1::tools::post)) + // Deprecated, use /v1/tools/:name instead + .route("/v1/tool/{name}", post(v1::tools::post)) + .route_layer(middleware::from_fn(move |req, next| { + require_auth_configured(tools_auth_required, req, next) + })); + authenticated_router = authenticated_router .route("/v1/models", get(v1::models::get)) .route("/v1/models/{name}/predict", get(v1::inference::get)) @@ -278,10 +295,7 @@ pub(crate) fn routes( ) .route("/v1/embeddings", post(v1::embeddings::post)) .route("/v1/search", post(v1::search::post)) - .route("/v1/tools", get(v1::tools::list)) - .route("/v1/tools/{*name}", post(v1::tools::post)) - // Deprecated, use /v1/tools/:name instead - .route("/v1/tool/{name}", post(v1::tools::post)) + .merge(tools_router) .route("/v1/workers", get(v1::workers::get)) .layer(Extension(Arc::clone(&rt.completion_llms))) .layer(Extension(Arc::clone(&rt.models))) @@ -481,3 +495,26 @@ async fn check_shutdown( next.run(req).await } + +/// Reject a request with 401 unless the runtime has an authentication provider attached. +/// +/// Used to gate routes whose behavior is unsafe anonymously (`/v1/tools/*`: the raw +/// request body is handed to `tool.call`, which for built-ins like `sql` and +/// `websearch` is equivalent to arbitrary query / outbound fetch). +async fn require_auth_configured( + auth_configured: bool, + req: axum::http::Request, + next: Next, +) -> axum::response::Response { + if auth_configured { + return next.run(req).await; + } + + ( + http::StatusCode::UNAUTHORIZED, + axum::Json(serde_json::json!({ + "message": "Tool invocation (/v1/tools/*) requires `runtime.auth` to be configured. Configure an API key provider in your Spicepod (see https://spiceai.org/docs/reference/runtime#auth) and retry with credentials." + })), + ) + .into_response() +} diff --git a/crates/runtime/src/http/v1/nsql.rs b/crates/runtime/src/http/v1/nsql.rs index de1caaecea..fe0ec17e6c 100644 --- a/crates/runtime/src/http/v1/nsql.rs +++ b/crates/runtime/src/http/v1/nsql.rs @@ -438,9 +438,14 @@ pub(crate) async fn handle_nsql_query( tracing::debug!("Running query:\n{cleaned_query}"); - // Run the SQL with table allowlist enforcement + // Run the SQL with table allowlist enforcement. LLM-generated SQL is + // always executed in read-only mode: the runtime rejects any plan that + // contains DDL, DML, COPY, or a `LogicalPlan::Statement` node (including + // PREPARE/EXECUTE/DEALLOCATE) regardless of per-catalog writability, + // which mitigates model-mediated SQL injection on `/v1/nsql`. let query_result = { - let mut builder = QueryBuilder::new(&cleaned_query, Arc::clone(&df)); + let mut builder = + QueryBuilder::new(&cleaned_query, Arc::clone(&df)).read_only(true); if let Some(ref allowlist) = table_allowlist_opt { builder = builder.allow_tables(allowlist.clone()); } diff --git a/crates/runtime/src/tools/builtin/sql.rs b/crates/runtime/src/tools/builtin/sql.rs index 7e66499c24..86615d8a7d 100644 --- a/crates/runtime/src/tools/builtin/sql.rs +++ b/crates/runtime/src/tools/builtin/sql.rs @@ -35,12 +35,28 @@ pub struct SqlToolParams { /// The SQL query to run. Double quote all select columns and never select columns ending in '_embedding'. The `table_catalog` is 'spice'. Always use it in the query query: String, } + +/// Default description advertised to LLMs / tool selection when the `sql` tool +/// is in its read-only posture (the default). +const DEFAULT_READ_ONLY_DESCRIPTION: &str = "Run a read-only SQL query on the data source. Columns with capitals must be quoted. When needed quote each part of catalog.schema.table: \"catalog\".\"schema\".\"table\". Avoid 'SELECT *', and columns with `_offset` or `_embedding` suffix. DDL and write statements (INSERT/UPDATE/DELETE/COPY/CREATE/DROP) are rejected, as are session-mutating statements (PREPARE/EXECUTE/DEALLOCATE)."; + +/// Default description advertised to LLMs / tool selection when the operator +/// has opted the tool into writable mode via [`SqlTool::allow_writes`]. +const DEFAULT_WRITABLE_DESCRIPTION: &str = "Run an SQL query on the data source. Columns with capitals must be quoted. When needed quote each part of catalog.schema.table: \"catalog\".\"schema\".\"table\". Avoid 'SELECT *', and columns with `_offset` or `_embedding` suffix. This tool accepts write statements (INSERT/UPDATE/DELETE/DDL); use with caution."; + pub struct SqlTool { name: String, description: String, df: Arc, allowed_tables: Option, + /// When true (the default), the tool rejects any DDL, DML, COPY, or + /// `LogicalPlan::Statement` plan (including PREPARE/EXECUTE/DEALLOCATE) at + /// execution time. This prevents LLM- or caller-supplied SQL from mutating data + /// or session state via the `/v1/tools/sql` surface, even when a referenced + /// catalog/dataset is configured writable. Operators that need write access from + /// a tool should configure a distinct writable tool rather than flipping this flag. + read_only: bool, } impl SqlTool { @@ -54,9 +70,33 @@ impl SqlTool { Self { df, name: name.unwrap_or("sql").to_string(), - description: description.unwrap_or("Run an SQL query on the data source. Columns with capitals must be quoted. When needed quote each part of catalog.schema.table: \"catalog\".\"schema\".\"table\". Avoid 'SELECT *', and columns with `_offset` or `_embedding` suffix.").to_string(), - allowed_tables + description: description + .unwrap_or(DEFAULT_READ_ONLY_DESCRIPTION) + .to_string(), + allowed_tables, + read_only: true, + } + } + + /// Allow write statements (INSERT/UPDATE/DELETE/DDL). Defaults to off. + /// + /// This is an escape hatch for operators who have deliberately configured a + /// separate writable tool and understand that any LLM with tool-use access will + /// then be able to mutate the targeted catalog/dataset without per-call + /// confirmation. Leave the default in place unless that trade-off is acceptable. + /// + /// If the tool is still using the default read-only description, it is swapped + /// for the writable default so LLM/tool-selection logic is not misled by a + /// stale "read-only" advertisement. Operator-supplied descriptions are left + /// untouched — callers overriding the description are responsible for keeping + /// it accurate. + #[must_use] + pub fn allow_writes(mut self) -> Self { + self.read_only = false; + if self.description == DEFAULT_READ_ONLY_DESCRIPTION { + self.description = DEFAULT_WRITABLE_DESCRIPTION.to_string(); } + self } } @@ -79,7 +119,7 @@ impl SpiceModelTool for SqlTool { let tool_use_result: Result> = async { let req: SqlToolParams = serde_json::from_str(arg)?; - let mut query_builder = self.df.query_builder(&req.query); + let mut query_builder = self.df.query_builder(&req.query).read_only(self.read_only); if let Some(ref allowlist) = self.allowed_tables { query_builder = query_builder.allow_tables(allowlist.clone()); } diff --git a/docs/threat_models/v2.0.0.md b/docs/threat_models/v2.0.0.md new file mode 100644 index 0000000000..1d45d51778 --- /dev/null +++ b/docs/threat_models/v2.0.0.md @@ -0,0 +1,310 @@ +# Threat Model for Spice.ai OSS v2.0.0 + +*Owner*: Phillip LeBlanc +*Description*: Spice is a portable, open-source runtime for fast, last-mile SQL query, search, and AI inference, written in Rust. + +## Overview + +Any entity that can send requests to a Spice instance is a potential threat actor. The main risk is Spice doing *more* than the runtime was explicitly configured to do. + +In the Spice.ai threat model, the `spicepod.yaml` is the **Root of Trust**. The runtime's job is to faithfully execute that configuration and enforce the boundaries it defines. If the runtime fails to enforce those boundaries, that is a vulnerability in Spice. + +v2.0 expands the threat surface with DDL support (CREATE/DROP TABLE), async query materialization to object storage, distributed ingestion across executor nodes, the DuckLake catalog, the GCS data connector, Cayenne RC improvements (staged WAL writes, S3 Express One Zone), a fully distributed query architecture with active-active HA schedulers, object-store-based state coordination, ExpandSecret RPC for secret forwarding, bidirectional gRPC control streams, URL Tables (opt-in SQL-level SSRF surface), SMB/NFS data connectors, and acceleration snapshot management APIs. + +## Threat Actors + +Since Spice acts as a middleware or "sidecar," the threat actors are primarily defined by their network position relative to Spice. + +- **The "Lateral" Attacker**: An attacker who has compromised a different service in the same Kubernetes cluster or VPC. They have network access to Spice's ports but no filesystem access. They want to use Spice as a proxy to reach upstream databases (Postgres/MySQL) that are firewalled off from the rest of the network, or to steal the credentials Spice holds in memory. Without an API key, they should not be able to perform any actions with Spice or leak any information. +- **The Malicious Client**: A valid user or service authorized to query Spice (has an API key), but wants to do more than allowed. They are trying to escalate their privilege beyond what they are supposed to be able to do in Spice. They have access to `Dataset A` but are crafting SQL inputs to try and read `Dataset B` (which exists in the source but not exposed to them) or write data. This is the most important threat actor and must be tightly constrained by the configuration. +- **The Rogue Executor** (new in v2.0): A compromised or malicious Ballista executor node in a distributed cluster. With distributed ingestion, executors now read from data sources and write to accelerated tables. A rogue executor could exfiltrate data during ingestion, inject poisoned data into accelerated tables, or intercept secrets forwarded via the ExpandSecret RPC. +- **The Storage-Adjacent Attacker** (new in v2.0): An attacker with access to object storage buckets (S3/GCS) used for async query results, Cayenne Vortex files, DuckLake data, acceleration snapshots, scheduler registry records, partition metadata, or job state. They could tamper with or read materialized data and cluster coordination state. + +## Assets + +Spice is responsible for protecting: + +- **Upstream credentials**: The connection strings, passwords, or keys Spice uses to talk to backend databases (e.g., Postgres, Dremio, Snowflake) and cloud providers (AWS, GCP). +- **Upstream Data Integrity**: The data sitting in the backend. Spice must not accidentally allow deletion or modification outside of configured write paths (Iceberg, Cayenne DDL). +- **Data Confidentiality (Scope)**: Ensuring data from `Table A` is not leaked when querying `Table B`. +- **Accelerated Data Integrity** (expanded in v2.0): Cayenne Vortex files, WAL files, DuckLake metadata, and async query results must not be tampered with. +- **Schema Integrity** (new in v2.0): DDL operations (CREATE/DROP TABLE) must only affect authorized catalogs. +- **Cluster Coordination State** (new in v2.0): Scheduler registry records, partition assignments, and job state stored in object storage must not be tampered with. Compromise of this state could redirect queries, partition data to malicious nodes, or disrupt cluster operations. +- **Secret Confidentiality in Cluster** (new in v2.0): Secrets forwarded to executors via the ExpandSecret RPC must remain confidential in transit and not be persisted on executor nodes. + +## Threat Surface & Vectors + +The threat surface is the **Network API** (HTTP, Flight, FlightSQL, OpenTelemetry, MCP, Iceberg REST, AI/model/search endpoints), the **SQL Engine Logic** (including the DDL/DML/FlightSQL planner crates), the **Cluster Control Plane** (gRPC on internal port 50052 via `ClusterService` and `SchedulerGrpcServer`), **Object Storage** (S3/GCS) and **file://-backed cluster state** (scheduler registry, partition metadata, and job state), **Outbound Telemetry** (OTEL exporters and optional Prometheus scrape listener), and the **Spicepod Root of Trust**. + +### Port Layout + +| Port | Visibility | Services | mTLS / Auth | +| ----- | --------------- | ------------------------------------------------------------------------------------------------------------------------------------ | -------------------------------------------------------------- | +| 50051 | Public | `FlightServiceServer` (user queries, FlightSQL), `OtelService` | Optional TLS; auth via `api_key_auth` / header forwarding | +| 8090 | Public | HTTP API (REST queries, health, status, models, search, MCP, async queries, Iceberg REST) | Optional TLS; auth via `runtime-auth` layer | +| 9090 | Public (opt-in) | Prometheus metrics — only bound when a metrics bind address is configured; served on a separate listener that bypasses HTTP API auth | Optional TLS; no API-key auth | +| 50052 | Internal | `SchedulerGrpcServer`, `ClusterService` (full RPC set, see below) | **Required** (opt-out only via `--allow-insecure-connections`) | + +The main vectors for exploitation are: + +- **Isolation breakout ("table hopping")**: An attacker sends a query for a valid dataset, but uses SQL injection or engine flaws to reference a dataset not exposed in the `spicepod.yaml`, or uses a "Pass-through" vulnerability to execute commands directly on the backend. +- **Read-Only Bypass (Integrity Violation)**: An attacker constructs a query that results in a state change in the backend database. +- **DDL Abuse** (new in v2.0): An attacker uses CREATE TABLE to exhaust storage or create tables pointing to attacker-controlled locations, or uses DROP TABLE to destroy data in Iceberg or Cayenne catalogs. +- **Authentication Bypass (API Surface)**: If `spicepod.yaml` mandates an API Key (`api_key_auth`), the runtime must enforce it globally. +- **Information Leakage via Errors**: Spice encounters an error and returns raw error messages containing credentials or infrastructure details. +- **Function-Based SSRF (Server-Side Request Forgery)**: SQL engine functions that reach out to the network could be abused to make Spice initiate connections to internal servers. +- **Async Query Result Exposure** (new in v2.0): Materialized query results (Arrow IPC chunks with 12h default TTL) in object storage could be accessed by unauthorized parties if bucket ACLs are overly permissive. +- **Distributed Ingestion Tampering** (new in v2.0): A rogue executor could inject malicious data during distributed ingestion or exfiltrate source data. +- **WAL Poisoning** (new in v2.0): An attacker with filesystem access could modify Cayenne WAL files to inject data on the next flush. +- **DuckLake Metadata Manipulation** (new in v2.0): Access to the DuckLake DuckDB metadata file could allow redirecting table references to malicious Parquet files. +- **Scheduler Registry Poisoning** (new in v2.0): An attacker with object storage write access could inject fake scheduler records at `{prefix}/schedulers/{id}.json`, redirecting executors to a malicious scheduler. +- **Secret Forwarding Interception** (new in v2.0): The ExpandSecret RPC transmits secrets from scheduler to executor. Without mTLS, an attacker could intercept credentials in transit. +- **Control Stream Hijacking** (new in v2.0): The bidirectional gRPC control stream carries privileged commands (UpdatePartitions, CancelTasks, PollNow). A man-in-the-middle could inject commands to reassign data or abort queries. +- **Partition Metadata Tampering** (new in v2.0): Partition assignments stored at `{prefix}/accelerations/partitions/{table}.json` could be modified to redirect data reads to attacker-controlled executors. +- **Shuffle Data Exposure** (new in v2.0): Intermediate shuffle data in executor memory, local disk, or object storage could be read by a compromised executor or adjacent attacker. +- **URL Tables SSRF** (new in v2.0): When URL Tables are enabled (`runtime.params.url_tables: enabled`), any authenticated user can issue SQL queries against arbitrary URLs (S3, Azure Blob, HTTP/HTTPS), using Spice's ambient cloud credentials. This is an **authenticated-user SSRF**, not just a misconfiguration risk — disabled by default. +- **SMB Internal Network Access** (new in v2.0): The SMB connector accepts `user`, `pass`, `port`, and `client_timeout` parameters. If an attacker can influence configuration, datasets could point to internal file shares and leak SMB credentials via errors/logs. +- **NFS Internal Network Access** (new in v2.0): The NFS connector has no application-level authentication (it exposes only `client_timeout` and listing-table parameters) and relies entirely on host/network-level controls. Datasets pointing at internal NFS exports are bounded only by network reachability and NFS export permissions. +- **Acceleration Snapshot Rollback** (new in v2.0): The snapshot management API allows rolling back to previous data snapshots. A malicious client could revert to stale data. Current snapshot listing now reports verified vs. unverified state (size/existence checks), but those checks do not cover cryptographic integrity. +- **Prometheus Metrics Reconnaissance** (new in v2.0): When the Prometheus metrics listener is enabled, it runs on a separate TCP/TLS server that bypasses the HTTP API auth layer. Scraping reveals query patterns, dataset sizes, error rates, and cluster topology. Metrics are **disabled by default** and only exposed when an operator configures a metrics bind address. +- **Outbound Telemetry Exfiltration** (new in v2.0): Anonymous telemetry is **enabled by default** and posts to `telemetry.spiceai.io` unless explicitly disabled. The local `task_history` table additionally retains chat prompts, tool calls, and SQL text, accessible via SQL to any principal with access to `spice.task_history`. +- **AI/Tool Endpoint Abuse** (new in v2.0): `/v1/models`, `/v1/chat/completions`, `/v1/responses`, `/v1/embeddings`, `/v1/search`, `/v1/tools`, `/v1/nsql`, `/v1/evals/*`, and MCP SSE are live on the authenticated HTTP router. The built-in `sql` tool can be auto-invoked by the LLM (no user confirmation), `/v1/nsql` executes LLM-generated SQL, and `/v1/tools/{name}` forwards the raw request body straight to the tool. See the "AI / Tool / Search Surface" section below. + +## Trust Assumptions + +- The host/K8s environment and `spicepod.yaml` are controlled by a trusted operator. +- Attackers **do not** have direct filesystem or ConfigMap edit access. **An attacker with filesystem access to modify `spicepod.yaml` is already considered to have full control of the environment.** +- Denial-of-service via expensive queries is treated as "expected database behavior" and handled by infra controls, not Spice itself. +- Cluster nodes (scheduler and executors) authenticate via built-in mTLS on internal port 50052. mTLS is required by default; the `--allow-insecure-connections` flag is available for development/testing only and should never be used in production. Rogue executor injection is mitigated primarily by mTLS certificate verification and secondarily by network isolation. The `spice cluster tls init` / `spice cluster tls add` CLI commands are convenience tooling for generating **development** PKI; production deployments are expected to use an operator-managed CA. +- Object storage (S3/GCS) access is protected by cloud IAM policies managed by the operator. Cluster state stored in object storage *or* in a local `file://` directory (scheduler registry, partition metadata, job state) relies on the same IAM/filesystem controls for integrity. +- Secrets forwarded via ExpandSecret RPC are protected by mTLS in transit and are not persisted on executor nodes. + +## New Threats in v2.0 + +| # | Threat | Severity | STRIDE Type | +| --- | ------------------------------------------------------------------ | -------- | --------------------------- | +| 31 | DDL operations allow schema manipulation (CREATE/DROP TABLE) | High | Tampering | +| 32 | Async query results readable from object storage | Medium | Information Disclosure | +| 33 | Distributed ingestion allows unauthorized data writes | High | Tampering | +| 34 | Cayenne WAL poisoning or corruption | Medium | Tampering | +| 35 | DuckLake catalog metadata manipulation | Medium | Tampering | +| 36 | GCS credential exposure via data connector | Medium | Information Disclosure | +| 37 | Executor data exfiltration via distributed ingestion | High | Information Disclosure | +| 38 | Async query result tampering in object storage | Medium | Tampering | +| 39 | Scheduler registry poisoning via object storage | High | Spoofing | +| 40 | Secret forwarding interception via ExpandSecret RPC | High | Information Disclosure | +| 41 | Control stream hijacking | High | Tampering | +| 42 | Partition metadata tampering in object storage | High | Tampering | +| 43 | Shuffle data exposure in executor memory or storage | Medium | Information Disclosure | +| 44 | URL Tables enable SQL-level SSRF | High | Elevation of Privilege | +| 45 | SMB connector exposes internal SMB shares (with credentials) | Medium | Information Disclosure | +| 45a | NFS connector exposes internal NFS exports (no app-level auth) | Medium | Information Disclosure | +| 46 | Acceleration snapshot rollback serves stale data | Medium | Tampering | +| 47 | GetAppDefinition RPC exposes full Spicepod configuration | High | Information Disclosure | +| 48 | Unauthenticated Prometheus metrics exposure | Low | Information Disclosure | +| 49 | LLM auto-invokes tools without user confirmation | High | Elevation of Privilege | +| 50 | `/v1/nsql` executes LLM-generated SQL | High | Tampering / Info Disclosure | +| 51 | `/v1/tools/{name}` exposes direct tool invocation | High | Elevation of Privilege | +| 52 | MCP stdio transport allows local process execution | High | Elevation of Privilege | +| 53 | Search filter/column injection into planner | Medium | DoS / Tampering | +| 54 | Evals persist inputs/outputs into `spice.evals.*` | Medium | Info Disclosure | +| 55 | Eval endpoint cost-amplification | Medium | Denial of Service | +| 56 | Anonymous-by-default API when `runtime.auth` unset | High | Spoofing | +| 57 | No JWT/OIDC provider for external auth | Medium | Elevation of Privilege | +| 58 | Databricks auth header principal confusion | Medium | Spoofing | +| 59 | gRPC auth middleware constructed but not applied | Medium | Elevation of Privilege | +| 60 | No global inbound rate limit on high-cost AI routes | Medium | Denial of Service | +| 61 | Iceberg discovery API enumerates all visible tables | Medium | Info Disclosure | +| 62 | Iceberg route-prefix inconsistency | Medium | Tampering | +| 63 | Runtime catalog visibility is only authZ boundary for Iceberg REST | Medium | Elevation of Privilege | +| 64 | Anonymous telemetry enabled by default to telemetry.spiceai.io | Medium | Info Disclosure | +| 65 | `task_history` retains chat/tool/SQL content queryable via SQL | High | Info Disclosure | +| 66 | Misconfigured OTEL endpoint enables data egress | Medium | Info Disclosure | +| 67 | OTEL exporter has no auth-header support | Low | Denial of Service | +| 68 | Cron workers run with ambient runtime credentials | High | Elevation of Privilege | +| 69 | Eval results persist sensitive inputs/outputs indefinitely | Medium | Info Disclosure | + +## Updated Threats from v1.9.1 + +| # | Threat | Change | +| --- | --------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| 1 | Tampering with file-based accelerators | Expanded to cover Cayenne WAL files and DuckLake metadata | +| 14 | Malicious SQL / SQL injection leading to ACE | Updated mitigation: access-mode/catalog-capability DDL gating (Iceberg + Cayenne), SQL identifier quoting, DataFusion error sanitization, prepared statements (scoped) | +| 26 | Unsigned acceleration snapshots or artifacts | Expanded to cover async query results, DuckLake Parquet, and cluster state artifacts | +| 27 | Rogue cluster node injection | Updated for built-in mTLS, ExpandSecret RPC, control streams, and scheduler registry | +| 12 | Data source details leaked via errors/logging | Updated mitigation: token redaction implemented, /v1/spicepods returns summaries | + +## Distributed Query Architecture + +v2.0 introduces a fully distributed query model: + +- **Active-active HA**: Multiple scheduler nodes register scheduler records at `schedulers/{scheduler_id}.json` and cluster metadata at `metadata/cluster.json` in the cluster state store (object storage or `file://`). Default heartbeat TTL is 30 seconds. Any scheduler can serve queries; stale schedulers are detected and removed automatically. +- **Partition-aware distribution**: Accelerated tables are partitioned across executors. The scheduler uses greedy set cover to select executors that can serve required partitions with minimal data movement. Partition metadata is stored under `accelerations/partitions/` via `object_store_occ`. +- **Control streams**: Executor-initiated bidirectional gRPC streams carry messages in both directions. **Scheduler → executor** messages include `UpdatePartitions`, `CancelTasks`, and `PollNow` (privileged commands). **Executor → scheduler** messages include `Heartbeat` (sent every 10s) and periodic metrics reports. Network connections are always initiated executor→scheduler. +- **ClusterService gRPC**: A dedicated protobuf service on internal port 50052. The full RPC surface includes `GetAppDefinition`, `ExpandSecret`, `GetSchedulers`, `GetTaskHistory`, `GetMetrics`, `ControlStream`, and `AllocateInitialPartitions`. `GetAppDefinition` and `ExpandSecret` were previously Flight Actions on the public port (50051) and were moved to the internal mTLS-protected port as a critical security improvement. +- **Secret forwarding**: Executors resolve secrets via the ExpandSecret RPC to the scheduler. Secrets are never stored locally on executor nodes. +- **Async queries**: The `/v1/queries` API submits queries for background execution. Job metadata is written to `jobs/{job_id}.json` and result chunks to `jobs/{job_id}/chunk_N.arrow` in the state store, with configurable TTL (default 12h). Query IDs are UUIDv7-derived, timestamp-prefixed identifiers (Databricks-style), not raw UUID strings. +- **Cluster state backends**: All cluster coordination state uses ETag-based optimistic concurrency control (OCC). Supported backends are object storage (S3/GCS) and — as of post-v2.0 — a local `file://` directory. Local file-backed state is a distinct threat surface (filesystem permissions replace cloud IAM). +- **mTLS**: Built-in mTLS is required by default for all cluster communication on internal port 50052 (`--node-mtls-ca-certificate-file`, `--node-mtls-certificate-file`, `--node-mtls-key-file`). The `--allow-insecure-connections` flag allows opting out for development only. CLI tooling (`spice cluster tls init`, `spice cluster tls add`) generates **development** certificates; production deployments must use operator-managed PKI. +- **Authentication boundary**: External client authentication uses API key auth (`api_key_auth` in spicepod) and header-forwarded credentials via `runtime-auth`. Internal cluster authentication uses mTLS. These are distinct mechanisms with different trust models. The `--cluster-api-key` flag was removed in favor of mTLS-only cluster auth. + +## DDL and DML Enforcement + +DDL (CREATE/DROP TABLE) and DML (INSERT/UPDATE/DELETE) are **not a simple static allowlist**. They are gated two ways: + +1. **Catalog access mode**: The SQL validator only permits DDL on catalogs marked `access: read_write_create`, and only permits UPDATE/DELETE on catalogs marked as writable. Each catalog's access mode is set at load time. +2. **Backend capability**: Only backends that implement the DDL/DML analyzer rules can accept those statements. Currently, **Iceberg** and **Cayenne** are the DDL-capable backends; UPDATE is narrower (writable Cayenne-backed catalogs only). + +Since post-v2.0, DDL/DML/FlightSQL planner logic has been extracted into dedicated crates (`datafusion-ddl`, `datafusion-dml`, `datafusion-flightsql`). These crates formalize the attack surface but do not change the gating rules. + +Prepared statements are accepted by the SQL validator, but **should not be treated as a universal injection mitigation** — they cover parameterized literal values, not dynamic identifiers or SQL composed by tools/LLMs. + +## Security Hardening Mitigations (Implemented) + +The following security-relevant mitigations have been implemented and are tracked as controls: + +- **SQL identifier validation/quoting** in component loading and `top_n_sample` helpers (mitigates SQL injection via table/column names) +- **Token redaction** in debug/error output for credential-bearing fields (mitigates credential leakage in logs) +- **Recursion depth limits** for DynamoDB and S3 Vectors (mitigates DoS via deeply nested data) +- **Spicepod summary API** — `/v1/spicepods` returns summaries instead of full configuration (mitigates connection detail leakage) +- **Input sanitization** for `top_n_sample` `order_by` clause (mitigates SQL injection in the sampling tool) +- **Prepared statements** — parameterized queries mitigate injection of literal values in federated execution (see scope note above) +- **Isolated refresh runtime** — refresh tasks run on a dedicated Tokio runtime, preventing them from blocking the query API +- **DataFusion error sanitization** — raw DataFusion errors are scrubbed before being surfaced externally +- **Function registry deny list** — honored in accelerated-table filter pushdown so denied UDFs cannot be smuggled through predicates +- **GetAppDefinition/ExpandSecret moved to internal port** — no longer reachable on public Flight port (50051) +- **Metrics listener off by default** — Prometheus is only bound when explicitly configured +- **`/v1/tools/*` requires configured auth** — the tool invocation surface returns 401 unless `runtime.auth` is configured, closing the unauthenticated-tool-call path that made `#51` equivalent to arbitrary SQL when combined with `#56` (mitigates `#51`) +- **Built-in `sql` tool is read-only by default** — `SqlTool` runs through a strict read-only validator that rejects DDL/DML/COPY/Statement regardless of per-catalog writability; operators must explicitly opt in to writes (mitigates `#51` write-path) +- **`/v1/nsql` executes LLM-generated SQL read-only** — all model-produced SQL is run through the same read-only validator before execution on `QueryBuilder`, so prompt-injection-driven writes cannot reach writable catalogs (mitigates `#50` write-path) + +> **Unverified mitigation**: the previously-listed "path traversal prevention in tar extraction" mitigation is not clearly present in Spice-owned snapshot archive code. Treat snapshot archives from untrusted sources as hostile until a traversal-rejection check is verified. + +## AI / Tool / Search Surface (Follow-up Modeling) + +The authenticated HTTP router exposes an AI inference surface that directly executes SQL, calls LLM providers, and invokes tools on behalf of the caller. Unless API-key auth is configured, **all of these routes are reachable anonymously**. + +### Route Inventory + +| Route | Effect | Attacker-Controllable Input | +| ------------------------------------------------- | ----------------------------------------------------------------------------------------------------------------------------------- | --------------------------- | +| `POST /v1/chat/completions` | Forwards chat request to configured model; prompt + captured output logged to task history | Full request body | +| `POST /v1/responses` | OpenAI-compatible responses API; forwards to model | Full request body | +| `POST /v1/embeddings` | Generates embeddings via configured embedding model | Text payloads | +| `POST /v1/nsql` | **LLM generates SQL from NL, then runtime executes it** via `QueryBuilder::run` | NL query; dataset selection | +| `GET /v1/models`, `GET /v1/models/{name}/predict` | Lists models; runs local model inference | Model name, inputs | +| `POST /v1/search` | Vector + keyword search; attacker-supplied `text`, `where_cond`, `additional_columns`, `keywords` flow into planner | Search body | +| `GET /v1/tools`, `POST /v1/tools/{name}` | **Raw request body passed directly to tool `.call`** | Entire body | +| `POST /v1/evals/{name}` | Triggers eval run; executes model calls + dataset SQL; persists input/actual/expected to `spice.evals.runs` / `spice.evals.results` | Eval inputs | +| `GET /v1/evals`, `GET /v1/workers` | Metadata listing | n/a | +| `GET /v1/mcp/sse` | MCP server over SSE (feature-gated) | MCP client messages | +| `POST /v1/mcp/sse` | MCP server over SSE (feature-gated) | MCP client messages | +| `GET /v1/catalogs`, `GET /v1/datasets` | Lists catalogs/datasets visible to the runtime | n/a | + +### Built-in Tools (LLM-Invocable Without Confirmation) + +`ToolUsingChat` / `ToolUsingResponses` auto-invoke tools returned by the LLM with no user approval step. The built-in tool catalog includes: + +- `sql` — executes caller-provided SQL via `QueryBuilder` (read-only by default; writes require explicit opt-in via `SqlTool::allow_writes()`). +- `sample_distinct_columns`, `random_sample`, `top_n_sample` — read dataset samples. +- `search` — runs vector/keyword search across configured datasets. +- `websearch` — outbound HTTP to a web-search provider (egress amplifier). +- `list_datasets`, `table_schema`, `get_readiness` — metadata enumeration. +- MCP `stdio` client — can spawn local processes on the runtime host when an MCP tool is configured with a stdio transport (config-time risk; not attacker-triggerable at request time, but promotes Spicepod tampering to local code exec on the runtime host). + +### Threats (AI / Tool / Search) + +| # | Threat | Severity | STRIDE | +| --- | ------------------------------------------------------------------------------------------------------------------------------- | -------- | ----------------------------- | +| 49 | LLM auto-invokes tools without user confirmation (including `sql`), enabling prompt-injection-driven SQL or websearch egress | High | EoP / Tampering | +| 50 | `POST /v1/nsql` executes LLM-generated SQL against runtime datasets | High | Tampering / Info Disclosure | +| 51 | `POST /v1/tools/{name}` exposes direct unauthenticated tool invocation (including `sql` tool) when auth is not configured | High | EoP / Tampering | +| 52 | MCP stdio transport allows operator-configured commands to run as the Spice process (Spicepod tampering → local code execution) | High | EoP | +| 53 | Search `where_cond` / `keywords` / `additional_columns` injection into search planner (DoS + possible filter bypass) | Medium | DoS / Tampering | +| 54 | `/v1/evals/{name}` persists user-supplied inputs and model responses into `spice.evals.*` — long-lived sensitive data retention | Medium | Info Disclosure / Repudiation | +| 55 | Eval endpoint triggers unbounded model calls — cost amplification DoS | Medium | DoS | + +## runtime-auth Layer (Follow-up Modeling) + +- **Default is anonymous**: if `runtime.auth` is not set in the Spicepod, `EndpointAuth` constructs a no-auth provider and all `/v1/*` routes are reachable without credentials. Only `GET /health` and `GET /v1/ready` are intentionally unauthenticated. +- **Only provider implemented is API-key auth**: the `X-API-Key` HTTP header, the Flight token header, and gRPC metadata `api-key`. There is currently no built-in JWT, OIDC, basic, or mTLS provider for *external* auth. mTLS is only used for internal cluster auth on port 50052. +- **Flight auth is always mounted** as a `BasicAuthLayer`, with the handshake path explicitly exempted; the HTTP auth layer, by contrast, is an **optional middleware** that is only mounted when a verifier is configured. +- **gRPC auth trait exists but is currently unhooked in runtime server wiring** — the `EndpointAuth::grpc` provider is constructed but no callsite applies it as middleware. Any future gRPC surface outside `FlightServiceServer` may be unauthenticated by default. +- **Credential passthrough is narrow, not generic**: only the Databricks-specific `spice-databricks-auth` request header is parsed into `DatabricksAuthExtension` and consumed by the Databricks U2M token provider. There is no generic user-header → connector forwarding mechanism. +- **Rate limiting**: `runtime-rate-control` exists and is wired into *outbound* LLM providers (e.g., OpenAI) and the Flight *write* path. There is no global inbound HTTP rate limit for high-cost routes (`/v1/chat/completions`, `/v1/nsql`, `/v1/evals/*`, `/v1/search`). + +### Threats (runtime-auth) + +| # | Threat | Severity | STRIDE | +| --- | ----------------------------------------------------------------------------------------------------------------------------------------- | -------- | -------------------------- | +| 56 | Anonymous-by-default API surface when `runtime.auth` is not configured | High | Spoofing / Info Disclosure | +| 57 | No JWT/OIDC provider — deployments without a reverse proxy cannot do per-user identity or claims-based authZ | Medium | EoP | +| 58 | Databricks auth header extension creates a user-controlled principal confusion between datasets backed by different Databricks workspaces | Medium | Spoofing | +| 59 | gRPC auth middleware is constructed but not applied — future gRPC endpoints risk defaulting to unauthenticated | Medium | EoP | +| 60 | No global inbound rate limit on high-cost AI routes enables cost-amplification DoS | Medium | DoS | + +## Iceberg REST Catalog Server (Follow-up Modeling) + +Spice exposes an **in-process, discovery-only Iceberg REST-style API** over the authenticated HTTP router: + +- Exposed ops: `config`, `list namespaces`, `get/head namespace`, `list tables`, `get/head table metadata`. +- **No `create`, `drop`, `commit`, or write ops are registered**. Iceberg writes still happen through the data-connector path, not this API. +- Authorization: runtime DataFusion catalog/schema/table visibility is used as the access boundary; there is no separate warehouse-IAM check inside these handlers. +- Bridge scope: only catalogs already registered in the runtime context — attacker cannot specify arbitrary remote catalogs in the request payload. +- Route registration uses `/v1/config` and `/v1/namespaces`; the handler docstrings describe `/v1/iceberg/...`. This **route-prefix inconsistency should be verified** (either the routes are mounted under `/v1/` at the root Iceberg namespace, or the docstrings are stale). + +### Threats (Iceberg REST) + +| # | Threat | Severity | STRIDE | +| --- | --------------------------------------------------------------------------------------------------------------------------------- | -------- | --------------------- | +| 61 | Iceberg discovery API enumerates every catalog/namespace/table visible to the runtime identity, bypassing any per-user scoping | Medium | Info Disclosure | +| 62 | Route-prefix inconsistency (`/v1/namespaces` vs documented `/v1/iceberg/...`) risks surprise exposure or unintended URL shadowing | Medium | Tampering | +| 63 | Runtime catalog visibility is the only authZ boundary — there is no per-table ACL hook for Iceberg REST responses | Medium | EoP / Info Disclosure | + +## Outbound Telemetry (Follow-up Modeling) + +- **Anonymous telemetry is ENABLED BY DEFAULT** unless explicitly disabled (`--disable-telemetry` or runtime telemetry config off). The default remote endpoint is `telemetry.spiceai.io`. Operators running Spice in air-gapped or sensitive environments must explicitly turn this off. +- **Runtime metrics pipeline** (when the metrics server is enabled) combines a Prometheus exporter, a Spice `spice_metrics` Arrow exporter, and an optional OTEL push reader. +- **OTEL push exporter config** currently exposes: `enabled`, `endpoint`, `push_interval`, and a metrics whitelist. **There is no `headers` field in the current OTEL exporter config on trunk** — an earlier draft of this model incorrectly said custom auth headers had landed. Without headers, authenticated OTEL collectors cannot be used without a sidecar proxy. +- **Task-history tracing exporter** captures span attributes including `input` and `captured_output` for chat/tool/SQL paths, and writes them to the runtime `task_history` table (accessible via SQL). Chat prompts, tool calls, and SQL text are therefore retained locally by default. +- Optional **Zipkin export** for task-history traces is available via the runtime tracing config. + +### Threats (Outbound Telemetry) + +| # | Threat | Severity | STRIDE | +| --- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -------- | ----------------------------- | +| 64 | Anonymous telemetry enabled by default sends runtime usage data to `telemetry.spiceai.io` without explicit operator opt-in | Medium | Info Disclosure / Repudiation | +| 65 | `task_history` retains chat prompts, tool calls, and SQL text in a queryable table — any principal with SQL access to `spice.task_history` can read other users' prompts | High | Info Disclosure | +| 66 | Misconfigured OTEL exporter endpoint redirects metrics (and, if extended to spans, query/prompt content) to attacker-controlled collectors | Medium | Info Disclosure | +| 67 | OTEL exporter currently has no auth-header support, preventing use of authenticated collectors without a sidecar (availability/integrity concern, not a direct exfil vector) | Low | DoS | + +## Workers / Evals (Follow-up Modeling) + +- **Workers are runtime components** (cron-scheduled model/SQL tasks loaded from the Spicepod), *not* Ballista executors. They are triggered by an in-process scheduler. +- **Worker identity is the runtime/operator identity** — cron-triggered `WorkerPromptTask` and `WorkerSqlTask` run with ambient runtime secrets and full dataset visibility, regardless of who authored the Spicepod entry. +- **`/v1/workers` is listing-only** — there is no HTTP trigger endpoint, so workers cannot be invoked ad-hoc by external callers. Compromise vector is Spicepod tampering. +- **`/v1/evals/{name}` is an HTTP trigger** and executes model calls plus dataset SQL as part of the eval run. Inputs, model outputs, and expected values are persisted to `spice.evals.runs` and `spice.evals.results` (including `input`, `actual`, `expected` columns), creating a long-lived record of potentially sensitive data. + +### Threats (Workers / Evals) + +| # | Threat | Severity | STRIDE | +| --- | ---------------------------------------------------------------------------------------------------------------------------------------- | -------- | ----------------------------- | +| 68 | Cron worker prompt/SQL tasks run with ambient runtime credentials — Spicepod tampering promotes to full data plane + LLM provider access | High | EoP / Tampering | +| 69 | Eval runs persist user inputs and model outputs into `spice.evals.*` tables indefinitely | Medium | Info Disclosure / Repudiation | + +## Post-v2.0 Deltas (Informational) + +These trunk changes post-date the initial v2.0.0 threat model and should be folded into the next revision: + +- `datafusion-ddl`, `datafusion-dml`, `datafusion-flightsql` extracted into dedicated crates. +- `file://` state_location support for the async-queries scheduler (new cluster-state backend). +- Control-plane heartbeat behavior changed to not block when all slots are acquired. +- DataFusion error sanitization landed on the public error path. +- Snapshot compaction removed; explicit snapshot existence/size verification added. +- Sort pushdown and additional predicate pushdown across providers (expands what runs in federated backends). + +> **Correction from prior draft**: OTEL metrics exporter auth-header support is *not* on trunk — the OTEL exporter config currently has no `headers` field. This line has been removed from the deltas list. + +## Full Threat Model + +See [v2.0.0.json](v2.0.0.json) for the complete STRIDE threat model in OWASP Threat Dragon format. From b9e71a9ea35dd5f4fc3722af861aa0be11e34ef1 Mon Sep 17 00:00:00 2001 From: Luke Kim <80174+lukekim@users.noreply.github.com> Date: Mon, 20 Apr 2026 22:51:51 -0700 Subject: [PATCH 4/4] feat(embeddings): multi-vector embeddings with MaxSim + late-interaction (#10408) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * feat(embeddings): multi-vector embeddings with MaxSim + late-interaction Extends column-level embeddings to accept list-of-string columns and produces one embedding vector per list element, stored as List> per row. Enables tag/synonym/attribute-style columns to participate in vector search without users having to pre-flatten into a separate table. Search adds two new scoring modes alongside the existing chunked-scalar path: - Single-query × multi-vector (MaxSim / Mean / Sum): per-row score is max/mean/sum over the list element cosines. Default is MaxSim (ColBERT-style). _match returns the element that produced the top cosine. - Multi-query × multi-vector (late-interaction): SUM_{q in Q} MAX_{d in D} cos(q, d). Opt in by passing an array query, e.g. vector_search(tbl, ['foo','bar'], col). Config surface (both ColumnLevelEmbeddingConfig and the legacy ColumnEmbeddingConfig) gains two new fields: - aggregation: max|mean|sum (default max); rejected on scalar columns. - max_elements_per_row: default 32, hard cap 1024; excess elements are dropped with a tracing::warn. Mode is auto-detected from the column Arrow type: Utf8/Utf8View/LargeUtf8 → Scalar, List/LargeList → ListMulti. Chunking is rejected on list columns; multi-vector options on scalar columns error with a clear message. Implementation: - EmbeddingInputMode { Scalar, ListMulti } threaded through EmbeddingTable; resolve_input_mode enforces all validation rules. - decompose_list_of_strings handles ListArray, LargeListArray and all three Utf8 element types; respects max_elements_per_row truncation. - get_vectors_per_list_element (async + sync) embeds via one model call, respecting null-row (→ empty output list) and null/empty-element (→ null vector) semantics. - base_table_has_embedding_column relaxes the offset-column requirement when the source column is list-typed; multi-vector uses element index as the implicit offset. - ChunkedNonIndexVectorGeneration grows a VectorScanMode with three variants; search() dispatches to search_chunked_scalar, search_list_multi, or search_late_interaction. Aggregation uses Expr::AggregateFunction.partition_by windowing. Late-interaction unions per-query subplans tagged with q_idx and does a two-step aggregate (pk, q_idx → MAX; pk → SUM). - VectorSearchTableFuncArgs gains a queries: Vec alongside the existing query: String. parse_query_arg accepts either a Utf8 literal or a make_array(...) expression; to_expr round-trips both forms. Dispatcher errors when multi-query is paired with a scalar column. - Telemetry track_vector_search gets multi_vector and multi_vector_aggregation KeyValue dims when applicable. Accelerator compatibility: multi-vector output Arrow shape is identical to the chunked-scalar path's output, so Arrow, Cayenne, and DuckDB round-trip transparently. Turso serializes nested lists to JSON TEXT today (turso.rs:581-583); SQLite inherits via datafusion-table-providers. The compat matrix is documented at the head of EmbeddingInputMode. A native typed side-table for SQLite/Turso is a future optimization that would benefit chunked-scalar equally. Tests: 41 new unit tests across embeddings::table (25), embeddings::execution_plan (12), and embeddings::udtf::parser_tests (4) cover type detection, input-mode resolution, list decomposition with null/empty/truncation edge cases, multi-vector list-array construction, end-to-end per-element embedding via a mock embedder, and the make_array query parser. * Fix + Lint * Lint * fix(embeddings): address review comments on multi-vector PR - build_multi_vector_list_array validates each embedding's length against vector_length before appending, returning a structured error instead of letting the FixedSizeListBuilder panic on mismatch. - decompose_generic_list hoists value_offsets() outside the loop and resolves the string-array variant once via a three-way downcast, removing per-element dynamic dispatch. Introduces a generic build_rows helper parameterised by a closure. - parse_query_arg rejects make_array(...) with more than VECTOR_SEARCH_MAX_QUERIES (32) elements to prevent late-interaction plans from blowing up on unbounded input. - to_expr derives the single-query literal from args.queries.first() so the single- and multi-query branches stay consistent. * fix(embeddings): round 2 of review fixes — format, missing-column error, grammar - cargo fmt fixup on decompose_generic_list after the string-variant refactor. - try_new now returns Error::EmbeddingColumnNotInSchema when a configured embedding source column is missing from the base schema, instead of silently dropping the column. Misconfiguration fails fast during table construction. - Grammar: "Cannot use it create an embeddings" → "Cannot use it to create embeddings" in the base_table_has_embedding_column warning. * Lint * feat: Implement JSON schema decomposition for HTTP connector * fix(vector): unify aggregation handling for ChunkedScalar and LateInteraction modes * Lint * fix(tests): improve variable naming for clarity in embedding tests * fix(clippy): replace unwrap with expect and Arc::clone in embedding tests --------- Co-authored-by: Viktor Yershov --- .../datafusion/codec/spice_logical_codec.rs | 2 + crates/runtime/src/embeddings/connector.rs | 2 + .../runtime/src/embeddings/execution_plan.rs | 588 +++++++++++++++- crates/runtime/src/embeddings/table.rs | 645 +++++++++++++++++- crates/runtime/src/embeddings/udtf.rs | 229 ++++++- crates/runtime/src/search/candidate/vector.rs | 475 ++++++++++++- .../src/search/candidate/vector_udtf.rs | 1 + crates/runtime/src/search/rrf.rs | 1 + crates/runtime/src/search/search_engine.rs | 38 +- crates/runtime/src/view.rs | 2 + crates/runtime/tests/models/hf.rs | 2 + crates/runtime/tests/models/openai.rs | 10 + crates/runtime/tests/models/s3_vectors.rs | 4 + crates/runtime/tests/models/search.rs | 4 + crates/spicepod/src/component/embeddings.rs | 47 ++ crates/spicepod/src/semantic.rs | 26 +- 16 files changed, 1970 insertions(+), 106 deletions(-) diff --git a/crates/runtime/src/cluster/datafusion/codec/spice_logical_codec.rs b/crates/runtime/src/cluster/datafusion/codec/spice_logical_codec.rs index d8fd87016b..e621633e60 100644 --- a/crates/runtime/src/cluster/datafusion/codec/spice_logical_codec.rs +++ b/crates/runtime/src/cluster/datafusion/codec/spice_logical_codec.rs @@ -117,6 +117,7 @@ impl SpiceLogicalCodec { ); let exprs = VectorSearchTableFunc::to_expr(&VectorSearchTableFuncArgs { tbl: SqlTableReference::parse_str(&vector_args.table), + queries: vec![vector_args.query.clone()], query: vector_args.query, column: vector_args.column, limit: vector_args.limit.map(Self::limit_from_u64).transpose()?, @@ -164,6 +165,7 @@ impl SpiceLogicalCodec { }; let vector_exprs = VectorSearchTableFunc::to_expr(&VectorSearchTableFuncArgs { tbl: SqlTableReference::parse_str(&args.table), + queries: vec![args.query.clone()], query: args.query.clone(), column: args.column.clone(), limit: args.limit.map(Self::limit_from_u64).transpose()?, diff --git a/crates/runtime/src/embeddings/connector.rs b/crates/runtime/src/embeddings/connector.rs index 73fa80ac7c..e3eb54c6c2 100644 --- a/crates/runtime/src/embeddings/connector.rs +++ b/crates/runtime/src/embeddings/connector.rs @@ -122,6 +122,8 @@ impl EmbeddingConnector { chunking: e.chunking.clone(), primary_keys: e.row_ids.clone(), vector_size: e.vector_size, + aggregation: e.aggregation, + max_elements_per_row: e.max_elements_per_row, }) }) .collect_vec(); diff --git a/crates/runtime/src/embeddings/execution_plan.rs b/crates/runtime/src/embeddings/execution_plan.rs index 0ce7797610..6637ae7078 100644 --- a/crates/runtime/src/embeddings/execution_plan.rs +++ b/crates/runtime/src/embeddings/execution_plan.rs @@ -15,8 +15,9 @@ limitations under the License. */ use arrow::array::{ - Array, ArrayRef, FixedSizeListArray, FixedSizeListBuilder, LargeStringArray, ListArray, - PrimitiveBuilder, RecordBatch, StringArray, StringViewArray, + Array, ArrayRef, FixedSizeListArray, FixedSizeListBuilder, GenericListArray, LargeListArray, + LargeStringArray, ListArray, OffsetSizeTrait, PrimitiveBuilder, RecordBatch, StringArray, + StringViewArray, }; use arrow::buffer::OffsetBuffer; use arrow::datatypes::{DataType, Field, Float32Type, Int32Type, SchemaRef}; @@ -41,7 +42,7 @@ use snafu::ResultExt; use std::collections::HashMap; use std::{any::Any, sync::Arc, thread}; -use super::table::EmbeddingColumnConfig; +use super::table::{EmbeddingColumnConfig, EmbeddingInputMode}; use crate::model::EmbeddingModelStore; use crate::{embedding_col, offset_col}; use rayon::ThreadPool; @@ -259,6 +260,7 @@ pub(crate) async fn compute_additional_embedding_columns( let EmbeddingColumnConfig { model_name, chunker: chunker_opt, + input_mode, .. } = cfg; tracing::trace!("Embedding column '{col}' with model {model_name}"); @@ -276,6 +278,38 @@ pub(crate) async fn compute_additional_embedding_columns( continue; }; + // Multi-vector path (list-of-strings source): decompose the + // list column into per-row strings, embed each, and rebuild as + // `List>`. No offsets column is emitted. + if let EmbeddingInputMode::ListMulti { + max_elements_per_row, + .. + } = *input_mode + { + let Some(rows) = decompose_list_of_strings(raw_data, max_elements_per_row) else { + tracing::warn!( + "Expected a list-of-strings column for '{col}' in multi-vector mode, got {}", + raw_data.data_type() + ); + continue; + }; + + let list_array = if model.supports_sync_embeddings() { + let task_model = Arc::clone(model); + let vector_size = cfg.vector_size; + task::spawn_blocking(move || { + get_vectors_per_list_element_in_process(rows, &task_model, vector_size) + }) + .await?? + } else { + get_vectors_per_list_element(rows, &**model, cfg.vector_size).await? + }; + + tracing::trace!("Successfully embedded column '{col}' in multi-vector mode"); + embed_arrays.insert(embedding_col!(col), Arc::new(list_array) as ArrayRef); + continue; + } + let Some(arr_iter) = convert_string_arrow_to_iterator!(raw_data) else { tracing::warn!( "Expected 'StringArray', 'StringViewArray' or 'LargeStringArray' for column '{}', but got {}", @@ -473,6 +507,286 @@ pub(super) fn get_vectors_in_process( Ok(builder.finish()) } +// ===== Multi-vector (list-of-strings) embedding path ===== +// +// For list-typed source columns, each list element gets its own +// embedding. The output Arrow shape is `List>`: +// one outer list per row, inner element is one vector per source-list +// element. See `EmbeddingInputMode::ListMulti` in `table.rs` for the +// configuration contract, and `EmbeddingTable::embedding_fields` for +// the schema. +// +// The search path at query time (M3) unnests these into per-element +// rows and aggregates per-row scores via MaxSim / mean / sum. + +/// Per-row decomposition of a list-of-strings column. +/// - Outer `None`: source row was null (produces an empty output list). +/// - Inner `None`: element was null or empty (produces a null vector). +/// - Inner `Some(s)`: element is a valid non-empty string to embed. +type DecomposedListOfStrings = Vec>>>; + +/// Decompose a list-of-strings Arrow column into owned per-row vecs, +/// applying `max_elements_per_row` truncation. Returns `None` for +/// unsupported input shapes (caller should emit a descriptive error). +fn decompose_list_of_strings( + arr: &ArrayRef, + max_elements_per_row: usize, +) -> Option { + if let Some(list) = arr.as_any().downcast_ref::() { + return decompose_generic_list(list, max_elements_per_row); + } + if let Some(list) = arr.as_any().downcast_ref::() { + return decompose_generic_list(list, max_elements_per_row); + } + None +} + +fn decompose_generic_list( + list: &GenericListArray, + max_elements_per_row: usize, +) -> Option { + let values = list.values(); + let offsets = list.value_offsets(); + let values_any = values.as_any(); + + if let Some(arr) = values_any.downcast_ref::() { + Some(build_rows(list, offsets, max_elements_per_row, |j| { + if arr.is_null(j) { + None + } else { + let v = arr.value(j); + if v.is_empty() { + None + } else { + Some(v.to_string()) + } + } + })) + } else if let Some(arr) = values_any.downcast_ref::() { + Some(build_rows(list, offsets, max_elements_per_row, |j| { + if arr.is_null(j) { + None + } else { + let v = arr.value(j); + if v.is_empty() { + None + } else { + Some(v.to_string()) + } + } + })) + } else { + values_any.downcast_ref::().map(|arr| { + build_rows(list, offsets, max_elements_per_row, |j| { + if arr.is_null(j) { + None + } else { + let v = arr.value(j); + if v.is_empty() { + None + } else { + Some(v.to_string()) + } + } + }) + }) + } +} + +fn build_rows( + list: &GenericListArray, + offsets: &[O], + max_elements_per_row: usize, + get_str: impl Fn(usize) -> Option, +) -> DecomposedListOfStrings { + let mut rows = Vec::with_capacity(list.len()); + let mut any_truncated = false; + + for i in 0..list.len() { + if list.is_null(i) { + rows.push(None); + continue; + } + let start = offsets[i].as_usize(); + let end = offsets[i + 1].as_usize(); + let raw_len = end.saturating_sub(start); + let effective_end = if raw_len > max_elements_per_row { + any_truncated = true; + start + max_elements_per_row + } else { + end + }; + let mut row = Vec::with_capacity(effective_end - start); + for j in start..effective_end { + row.push(get_str(j)); + } + rows.push(Some(row)); + } + + if any_truncated { + tracing::warn!( + "Multi-vector column truncated to max_elements_per_row={max_elements_per_row}; excess elements dropped." + ); + } + rows +} + +/// Flatten a decomposed row set into: the strings to send to the +/// embedding model, a per-row length (for outer list offsets), and a +/// per-position validity flag (`true` = an embedding slot, `false` = a +/// null-vector slot). +fn flatten_rows_for_embedding( + rows: DecomposedListOfStrings, +) -> (Vec, Vec>, Vec) { + let mut flat: Vec = Vec::new(); + let mut validity: Vec> = Vec::with_capacity(rows.len()); + let mut lengths: Vec = Vec::with_capacity(rows.len()); + + for row_opt in rows { + match row_opt { + None => { + lengths.push(0); + validity.push(Vec::new()); + } + Some(elements) => { + let mut v = Vec::with_capacity(elements.len()); + for elem in elements { + match elem { + Some(s) => { + flat.push(s); + v.push(true); + } + None => v.push(false), + } + } + lengths.push(v.len()); + validity.push(v); + } + } + } + + (flat, validity, lengths) +} + +/// Build the final `List>` output array from the +/// flattened embeddings plus per-row length and per-position validity. +#[expect(clippy::cast_sign_loss)] +fn build_multi_vector_list_array( + positions_validity: &[Vec], + row_lengths: &[usize], + embedded: &[Vec], + vector_length: i32, +) -> Result> { + let total_elements: usize = row_lengths.iter().sum(); + + let mut inner_builder = FixedSizeListBuilder::with_capacity( + PrimitiveBuilder::::with_capacity(total_elements * (vector_length as usize)), + vector_length, + total_elements, + ) + .with_field(Arc::new(Field::new("item", DataType::Float32, false))); + + let expected_slots: usize = positions_validity + .iter() + .flat_map(|v| v.iter()) + .filter(|&&valid| valid) + .count(); + if embedded.len() != expected_slots { + return Err(format!( + "embedding count mismatch: expected {expected_slots} vectors but got {}", + embedded.len() + ) + .into()); + } + + let expected_dims = vector_length as usize; + let mut embed_ptr: usize = 0; + for validity in positions_validity { + for &valid in validity { + if valid { + let vec = &embedded[embed_ptr]; + if vec.len() != expected_dims { + return Err(format!( + "embedding vector length mismatch at index {embed_ptr}: expected {expected_dims} dimensions but got {}", + vec.len() + ) + .into()); + } + inner_builder.values().append_slice(vec); + inner_builder.append(true); + embed_ptr += 1; + } else { + // The inner Float32 field is declared non-nullable; the + // outer FixedSizeList slot encodes nullness. Append + // placeholder zeros and mark only the parent slot null. + for _ in 0..expected_dims { + inner_builder.values().append_value(0.0); + } + inner_builder.append(false); + } + } + } + + let inner_field = Arc::new(Field::new_fixed_size_list( + "item", + Field::new("item", DataType::Float32, false), + vector_length, + true, + )); + + let offsets = OffsetBuffer::::from_lengths(row_lengths.iter().copied()); + + Ok(ListArray::try_new( + inner_field, + offsets, + Arc::new(inner_builder.finish()), + None, + )?) +} + +/// Embed each element of a list-of-strings column via the async +/// [`Embed`] trait. Produces `List>`. +pub(super) async fn get_vectors_per_list_element( + rows: DecomposedListOfStrings, + model: &dyn Embed, + vector_length: i32, +) -> Result> { + let (flat, validity, lengths) = flatten_rows_for_embedding(rows); + + let embedded: Vec> = if flat.is_empty() { + Vec::new() + } else { + model.embed(EmbeddingInput::StringArray(flat)).await? + }; + + build_multi_vector_list_array(&validity, &lengths, &embedded, vector_length) +} + +/// Sync counterpart to [`get_vectors_per_list_element`], using rayon +/// for in-process embedding models. +pub(super) fn get_vectors_per_list_element_in_process( + rows: DecomposedListOfStrings, + model: &Arc, + vector_length: i32, +) -> Result> { + let (flat, validity, lengths) = flatten_rows_for_embedding(rows); + + let embedded: Vec> = if flat.is_empty() { + Vec::new() + } else { + let pool = build_embedding_pool(model.parallelism())?; + let batches = pool.install(|| { + flat.into_par_iter() + .chunks(32) + .map(|chunk| model.embed_sync(EmbeddingInput::StringArray(chunk))) + .collect::, _>>() + })?; + batches.into_iter().flatten().collect() + }; + + build_multi_vector_list_array(&validity, &lengths, &embedded, vector_length) +} + /// Embed a [`StringArray`] using the provided [`Embed`] model and [`Chunker`]. The output is a [`ListArray`], /// where each input [`String`] gets chunked and embedded into a [`FixedSizeListArray`]. /// @@ -745,4 +1059,272 @@ mod tests { Ok(()) } + + // ===== M2: multi-vector (per-list-element) embedding ===== + + use crate::embeddings::execution_plan::{ + DecomposedListOfStrings, build_multi_vector_list_array, decompose_list_of_strings, + flatten_rows_for_embedding, get_vectors_per_list_element, + }; + use arrow::array::{ + ArrayRef, FixedSizeListArray, LargeListBuilder, ListBuilder, StringBuilder, + }; + use std::sync::Arc; + + fn mk_list_array(rows: &[Option>>]) -> ArrayRef { + let mut builder = ListBuilder::new(StringBuilder::new()); + for row in rows { + match row { + None => builder.append(false), + Some(elements) => { + for e in elements { + match e { + Some(s) => builder.values().append_value(s), + None => builder.values().append_null(), + } + } + builder.append(true); + } + } + } + Arc::new(builder.finish()) + } + + fn mk_large_list_array(rows: &[Option>>]) -> ArrayRef { + let mut builder = LargeListBuilder::new(StringBuilder::new()); + for row in rows { + match row { + None => builder.append(false), + Some(elements) => { + for e in elements { + match e { + Some(s) => builder.values().append_value(s), + None => builder.values().append_null(), + } + } + builder.append(true); + } + } + } + Arc::new(builder.finish()) + } + + #[test] + fn test_decompose_list_of_strings_basic() { + let arr = mk_list_array(&[ + Some(vec![Some("red"), Some("round")]), + Some(vec![Some("blue")]), + ]); + let rows = decompose_list_of_strings(&arr, 32).expect("list-of-strings supported"); + assert_eq!(rows.len(), 2); + assert_eq!( + rows[0].as_ref().expect("row 0 should be Some"), + &vec![Some("red".to_string()), Some("round".to_string())] + ); + assert_eq!( + rows[1].as_ref().expect("row 1 should be Some"), + &vec![Some("blue".to_string())] + ); + } + + #[test] + fn test_decompose_list_of_strings_null_row() { + let arr = mk_list_array(&[Some(vec![Some("a")]), None, Some(vec![Some("b")])]); + let rows = decompose_list_of_strings(&arr, 32).expect("ok"); + assert!(rows[0].is_some()); + assert!(rows[1].is_none()); + assert!(rows[2].is_some()); + } + + #[test] + fn test_decompose_list_of_strings_null_and_empty_element() { + // Empty string and null element should both become None inside + // the row so the embedder isn't asked to embed them. + let arr = mk_list_array(&[Some(vec![Some("x"), Some(""), None, Some("y")])]); + let rows = decompose_list_of_strings(&arr, 32).expect("ok"); + let row = rows[0].as_ref().expect("row 0 should be Some"); + assert_eq!( + row, + &vec![Some("x".to_string()), None, None, Some("y".to_string())] + ); + } + + #[test] + fn test_decompose_list_of_strings_truncates_to_cap() { + let arr = mk_list_array(&[Some(vec![ + Some("a"), + Some("b"), + Some("c"), + Some("d"), + Some("e"), + ])]); + let rows = decompose_list_of_strings(&arr, 2).expect("ok"); + let row = rows[0].as_ref().expect("row 0 should be Some"); + assert_eq!(row.len(), 2); + assert_eq!(row[0], Some("a".to_string())); + assert_eq!(row[1], Some("b".to_string())); + } + + #[test] + fn test_decompose_large_list_of_strings() { + // LargeList must also be supported. + let arr = mk_large_list_array(&[Some(vec![Some("red"), Some("green")])]); + let rows = decompose_list_of_strings(&arr, 32).expect("LargeList supported"); + assert_eq!( + rows[0].as_ref().expect("row 0 should be Some"), + &vec![Some("red".to_string()), Some("green".to_string())] + ); + } + + #[test] + fn test_decompose_unsupported_type() { + let arr: ArrayRef = Arc::new(arrow::array::Int32Array::from(vec![1, 2, 3])); + let res = decompose_list_of_strings(&arr, 32); + assert!(res.is_none()); + } + + #[test] + fn test_decompose_list_of_non_strings_rejected() { + // A `List` column must be rejected at the decomposition + // boundary, not silently treated as an all-null multi-vector + // column. + use arrow::array::{Int32Array, ListArray}; + use arrow::buffer::OffsetBuffer; + use arrow_schema::{DataType, Field}; + let values = Int32Array::from(vec![Some(1), Some(2), Some(3)]); + let offsets = OffsetBuffer::::from_lengths([3usize]); + let field = Arc::new(Field::new("item", DataType::Int32, true)); + let list: ArrayRef = Arc::new(ListArray::new(field, offsets, Arc::new(values), None)); + assert!(decompose_list_of_strings(&list, 32).is_none()); + } + + #[test] + fn test_flatten_rows_for_embedding_preserves_validity() { + let rows: DecomposedListOfStrings = vec![ + Some(vec![Some("a".to_string()), None]), + None, + Some(vec![Some("b".to_string())]), + ]; + let (flat, validity, lengths) = flatten_rows_for_embedding(rows); + + assert_eq!(flat, vec!["a".to_string(), "b".to_string()]); + assert_eq!(validity, vec![vec![true, false], vec![], vec![true]]); + assert_eq!(lengths, vec![2, 0, 1]); + } + + #[test] + fn test_build_multi_vector_list_array_shapes() + -> Result<(), Box> { + let validity = vec![vec![true, false], vec![], vec![true]]; + let lengths = vec![2, 0, 1]; + let embedded = vec![vec![0.1, 0.2], vec![0.9, 0.8]]; + let out = build_multi_vector_list_array(&validity, &lengths, &embedded, 2)?; + + assert_eq!(out.len(), 3); + // Row 0: two elements, second is null + let out_row0 = out.value(0); + let fsl0 = out_row0 + .as_any() + .downcast_ref::() + .expect("row 0 should be FixedSizeListArray"); + assert_eq!(fsl0.len(), 2); + assert!(!fsl0.is_null(0)); + assert!(fsl0.is_null(1)); + + // Row 1: empty list + let out_row1 = out.value(1); + assert_eq!(out_row1.len(), 0); + + // Row 2: single element + let out_row2 = out.value(2); + let fsl2 = out_row2 + .as_any() + .downcast_ref::() + .expect("row 2 should be FixedSizeListArray"); + assert_eq!(fsl2.len(), 1); + assert!(!fsl2.is_null(0)); + + Ok(()) + } + + #[tokio::test] + async fn test_get_vectors_per_list_element_basic() + -> Result<(), Box> { + let rows: DecomposedListOfStrings = vec![ + Some(vec![Some("red".to_string()), Some("round".to_string())]), + Some(vec![Some("blue".to_string())]), + ]; + let model = MockEmbedder::default() + .with_pair("red", vec![0.1, 0.2]) + .with_pair("round", vec![0.3, 0.4]) + .with_pair("blue", vec![0.5, 0.6]); + let out = get_vectors_per_list_element(rows, &model, 2).await?; + + assert_eq!(out.len(), 2); + let out_row0 = out.value(0); + let fsl0 = out_row0 + .as_any() + .downcast_ref::() + .expect("row 0 should be FixedSizeListArray"); + assert_eq!(fsl0.len(), 2); + let v0 = fsl0.value(0); + let p0 = v0.as_primitive::(); + assert_eq!(p0.value(0), 0.1); + assert_eq!(p0.value(1), 0.2); + let v1 = fsl0.value(1); + let p1 = v1.as_primitive::(); + assert_eq!(p1.value(0), 0.3); + assert_eq!(p1.value(1), 0.4); + + let out_row1 = out.value(1); + let fsl1 = out_row1 + .as_any() + .downcast_ref::() + .expect("row 1 should be FixedSizeListArray"); + assert_eq!(fsl1.len(), 1); + let v2 = fsl1.value(0); + let p2 = v2.as_primitive::(); + assert_eq!(p2.value(0), 0.5); + assert_eq!(p2.value(1), 0.6); + + Ok(()) + } + + #[tokio::test] + async fn test_get_vectors_per_list_element_mixed_nulls() + -> Result<(), Box> { + // Null source row → empty output list. + // Null/empty string element → null vector in output. + let rows: DecomposedListOfStrings = vec![ + None, + Some(vec![ + Some("hello".to_string()), + None, + Some("world".to_string()), + ]), + Some(vec![]), + ]; + let model = MockEmbedder::default() + .with_pair("hello", vec![1.0, 2.0]) + .with_pair("world", vec![3.0, 4.0]); + let out = get_vectors_per_list_element(rows, &model, 2).await?; + + assert_eq!(out.len(), 3); + // Row 0: null source → empty list + assert_eq!(out.value(0).len(), 0); + // Row 1: 3 elements, middle is null + let out_row1 = out.value(1); + let fsl1 = out_row1 + .as_any() + .downcast_ref::() + .expect("row 1 should be FixedSizeListArray"); + assert_eq!(fsl1.len(), 3); + assert!(!fsl1.is_null(0)); + assert!(fsl1.is_null(1)); + assert!(!fsl1.is_null(2)); + // Row 2: empty input list → empty output list + assert_eq!(out.value(2).len(), 0); + + Ok(()) + } } diff --git a/crates/runtime/src/embeddings/table.rs b/crates/runtime/src/embeddings/table.rs index 6e4e9e20f8..64336f86c8 100644 --- a/crates/runtime/src/embeddings/table.rs +++ b/crates/runtime/src/embeddings/table.rs @@ -40,7 +40,10 @@ use crate::embeddings::construct_chunker; use crate::embeddings::execution_plan::EmbeddingTableExec; use crate::model::EmbeddingModelStore; use crate::{embedding_col, offset_col}; -use spicepod::component::embeddings::ColumnEmbeddingConfig; +use spicepod::component::embeddings::{ + ColumnEmbeddingConfig, EmbeddingAggregation, MULTI_VECTOR_MAX_ELEMENTS_DEFAULT, + MULTI_VECTOR_MAX_ELEMENTS_HARD_CAP, +}; use tokio::sync::RwLock; use super::common::{is_valid_embedding_type, is_valid_offset_type, vector_length}; @@ -48,10 +51,29 @@ use super::common::{is_valid_embedding_type, is_valid_offset_type, vector_length #[derive(Debug, Snafu)] pub enum Error { #[snafu(display( - "Column '{column}' has an unsupported data type for embedding. Only string types are allowed. For details, visit: https://spiceai.org/docs/components/embeddings", + "Column '{column}' has an unsupported data type for embedding. Supported types are string (`Utf8`, `Utf8View`, `LargeUtf8`) and list-of-string (`List`, `LargeList`). For details, visit: https://spiceai.org/docs/components/embeddings", ))] InvalidColumnType { column: String, data_type: DataType }, + #[snafu(display( + "Column '{column}' is configured for multi-vector embedding (list-typed) but also has chunking enabled. Chunking only applies to scalar string columns. Remove the chunking configuration for multi-vector columns." + ))] + MultiVectorChunkingNotSupported { column: String }, + + #[snafu(display( + "Column '{column}' was configured with `aggregation` or `max_elements_per_row`, but its type '{data_type}' is not a list-typed column. These options only apply to multi-vector (list-typed) columns." + ))] + MultiVectorOptionsOnScalar { column: String, data_type: DataType }, + + #[snafu(display( + "Column '{column}': `max_elements_per_row` must be between 1 and {cap}, got {value}." + ))] + MaxElementsPerRowOutOfRange { + column: String, + value: usize, + cap: usize, + }, + #[snafu(display( "The dataset is configured with an embedding model '{model}' to embed column '{column}', but the model '{model}' is not defined in Spicepod (as an 'embeddings') or failed to load.\nFor details, visit: https://spiceai.org/docs/components/embeddings" ))] @@ -65,6 +87,11 @@ pub enum Error { row_id_column: String, valid_columns: String, }, + + #[snafu(display( + "The dataset is configured with an embedding for column '{column}', but '{column}' is not present in the dataset schema. Verify the column configuration and try again.\nFor details, visit: https://spiceai.org/docs/components/embeddings" + ))] + EmbeddingColumnNotInSchema { column: String }, } /// An [`EmbeddingTable`] is a [`TableProvider`] where some columns are augmented with associated embedding columns @@ -86,6 +113,63 @@ impl std::fmt::Debug for EmbeddingTable { } } +/// Internal classifier for the source column's Arrow type. +#[derive(Clone, Debug, PartialEq, Eq)] +enum SourceShape { + /// `Utf8` / `Utf8View` / `LargeUtf8` — carries the concrete type for error messages. + Scalar(DataType), + /// `List` / `LargeList` (and `Utf8View`/`LargeUtf8` element variants). + ListOfString, +} + +/// Compatibility matrix for the multi-vector output type +/// (`List>`) across the accelerator engines Spice +/// supports. This shape is identical to what the chunked-scalar path has +/// produced since its introduction; multi-vector columns inherit that +/// behavior. +/// +/// | Accelerator | Storage | Notes | +/// |-------------|---------------------------------------------|------------------------------| +/// | Arrow | Native Arrow in-memory | Transparent. | +/// | Cayenne | Native Arrow persistence | Transparent. | +/// | `DuckDB` | Native `FLOAT[D][]` | Transparent. | +/// | `SQLite` | JSON-serialized `TEXT` (via table-providers)| Functional; JSON overhead. | +/// | Turso | JSON-serialized `TEXT` | See `turso.rs:581-583`. | +/// | `PostgreSQL` | Not yet supported | Out of scope this milestone. | +/// +/// `SQLite` / Turso JSON serialization is lossy in type fidelity (everything +/// round-trips as TEXT) but functionally correct. A proper side-table +/// strategy (`___mv(pk, elem_idx, vector)`) is a future +/// optimization for those accelerators; the current behavior is the same +/// the chunked-scalar path has shipped with. +/// +/// Shape of the source column being embedded. +/// +/// `Scalar` — the source column is a single string per row (`Utf8` / +/// `Utf8View` / `LargeUtf8`). One embedding vector is produced per row, +/// optionally doubly-nested if chunking is enabled. +/// +/// `ListMulti` — the source column is a list of strings per row +/// (`List` / `LargeList` and their `Utf8View` / `LargeUtf8` +/// variants). One embedding vector is produced per list element; at +/// query time per-element similarities are aggregated into a single +/// per-row score via `aggregation`. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum EmbeddingInputMode { + Scalar, + ListMulti { + aggregation: EmbeddingAggregation, + max_elements_per_row: usize, + }, +} + +impl EmbeddingInputMode { + #[must_use] + pub fn is_list_multi(&self) -> bool { + matches!(self, Self::ListMulti { .. }) + } +} + #[derive(Clone)] pub struct EmbeddingColumnConfig { /// The name of the embedding model to use for this column. @@ -100,6 +184,10 @@ pub struct EmbeddingColumnConfig { // If None, either no chunking is needed, or [`in_base_table`] is true. pub chunker: Option>, + + /// Shape of the source column. Determines the output Arrow type and + /// whether the search path uses MaxSim-over-elements. + pub input_mode: EmbeddingInputMode, } impl std::fmt::Debug for EmbeddingColumnConfig { @@ -108,6 +196,7 @@ impl std::fmt::Debug for EmbeddingColumnConfig { .field("model_name", &self.model_name) .field("vector_size", &self.vector_size) .field("in_base_table", &self.in_base_table) + .field("input_mode", &self.input_mode) .finish_non_exhaustive() } } @@ -202,9 +291,11 @@ impl EmbeddingTable { let mut embedded_columns: HashMap = HashMap::new(); for (column, config) in embed_columns { - let model = config.model; + let model = config.model.clone(); let chunking_config_opt = embed_chunker_config.get(&column); + let source_shape = Self::detect_source_shape(&column, &base_schema)?; + if Self::base_table_has_embedding_column(&base_schema, &column) { tracing::debug!( "Column '{column}' has needed embeddings in base table. Will not augment." @@ -227,6 +318,15 @@ impl EmbeddingTable { continue; }; + // For precomputed embeddings, resolve the mode based on + // the source column's shape. If the source column isn't + // present (unusual — we got here via the embedding + // column existing), default to Scalar. + let input_mode = match source_shape { + Some(shape) => Self::resolve_input_mode(&column, shape, &config)?, + None => EmbeddingInputMode::Scalar, + }; + embedded_columns.insert( column, EmbeddingColumnConfig { @@ -234,6 +334,7 @@ impl EmbeddingTable { vector_size: vector_length, in_base_table: true, chunker: None, // Don't need chunking since it is done in base table. + input_mode, }, ); } else { @@ -241,7 +342,12 @@ impl EmbeddingTable { "Column '{column}' does not have needed embeddings in base table. Will augment with model {model}." ); - Self::verify_column_type_supported(&column, &base_schema)?; + // Source shape is required when we're computing + // embeddings — we can't embed a column we can't read. + let Some(shape) = source_shape else { + return EmbeddingColumnNotInSchemaSnafu { column }.fail(); + }; + let input_mode = Self::resolve_input_mode(&column, shape, &config)?; let Some(vector_length) = Self::embedding_size_from_models(&model, &embedding_models).await @@ -271,6 +377,7 @@ impl EmbeddingTable { vector_size: vector_length, in_base_table: false, chunker, + input_mode, }, ); } @@ -287,15 +394,18 @@ impl EmbeddingTable { /// For a base table with column, c, we expect: /// - `c` to be in the base schema. /// - `c_embedding` to be in the base schema. It needs to have a type compatible with [`Self::embedding_fields`]. - /// - If `c_embedding` has a doubly-nested list type, `c_offsets` should also be in the base schema. It should be a `List[FixedSizeList[Int32, 2]]`. + /// - If `c_embedding` has a doubly-nested list type AND the source column `c` is scalar-typed + /// (a chunked scalar embedding), `c_offsets` must also be in the base schema as + /// `List[FixedSizeList[Int32, 2]]`. For multi-vector embeddings (`c` is list-typed), no + /// offsets column is required — element index is the implicit offset. fn base_table_has_embedding_column(base_schema: &SchemaRef, column: &str) -> bool { // Check if the base column exists - if base_schema.column_with_name(column).is_none() { + let Some((_, source_field)) = base_schema.column_with_name(column) else { tracing::warn!( - "Column '{column}' does not exist in the base table. Cannot use it create an embeddings" + "Column '{column}' does not exist in the base table. Cannot use it to create embeddings" ); return false; - } + }; // Check if the embedding column exists and has a valid data type let Some((_, embedding_field)) = @@ -308,12 +418,40 @@ impl EmbeddingTable { return false; } - // If embedding is doubly nested, also check for the offsets column - if let DataType::List(inner) - | DataType::LargeList(inner) - | DataType::FixedSizeList(inner, _) = embedding_field.data_type() - && let DataType::FixedSizeList(_, _) = inner.data_type() - { + // If the source column is list-of-string, this is multi-vector + // mode: no sibling offsets column is required. + let source_is_list_of_string = matches!( + source_field.data_type(), + DataType::List(inner) | DataType::LargeList(inner) + if matches!( + inner.data_type(), + DataType::Utf8 | DataType::Utf8View | DataType::LargeUtf8 + ) + ); + + // Multi-vector mode must have a doubly-nested embedding column + // (`List>` or similar). Otherwise treating a + // scalar embedding as precomputed leads to UNNEST planning errors + // downstream. + let embedding_is_doubly_nested = matches!( + embedding_field.data_type(), + DataType::List(inner) + | DataType::LargeList(inner) + | DataType::FixedSizeList(inner, _) + if matches!(inner.data_type(), DataType::FixedSizeList(_, _)) + ); + + if source_is_list_of_string && !embedding_is_doubly_nested { + tracing::warn!( + "Column '{column}' is list-typed (multi-vector) but the precomputed embedding column '{}' is not doubly-nested (`List>`). Will recompute embeddings.", + embedding_col!(column).as_str() + ); + return false; + } + + // Otherwise, if the embedding is doubly nested (chunked scalar), + // require the offsets column too. + if !source_is_list_of_string && embedding_is_doubly_nested { let Some((_, offsets_field)) = base_schema.column_with_name(offset_col!(column).as_str()) else { @@ -402,6 +540,40 @@ impl EmbeddingTable { }) } + /// Returns true if the column's embedding is produced in multi-vector + /// mode (source column is list-typed, one embedding per list + /// element). Multi-vector and chunked outputs share the same + /// doubly-nested Arrow shape, but the search path aggregates them + /// differently (multi-vector: max over list elements; chunked: max + /// over chunks of one scalar string). + #[must_use] + pub fn is_multi_vector(&self, column: &str) -> bool { + self.embedded_columns + .get(column) + .is_some_and(|cfg| cfg.input_mode.is_list_multi()) + } + + /// Returns the aggregation strategy configured for a multi-vector + /// column, or `None` if the column is scalar. + #[must_use] + pub fn multi_vector_aggregation(&self, column: &str) -> Option { + self.embedded_columns + .get(column) + .and_then(|cfg| match cfg.input_mode { + EmbeddingInputMode::ListMulti { aggregation, .. } => Some(aggregation), + EmbeddingInputMode::Scalar => None, + }) + } + + /// Returns true when the column's output Arrow type is + /// doubly-nested (`List>`): either because the + /// scalar source is chunked, or because the source is list-typed + /// (multi-vector). Both use the same UNNEST-based search path. + #[must_use] + pub fn has_nested_embedding_output(&self, column: &str) -> bool { + self.is_chunked(column) || self.is_multi_vector(column) + } + /// Get the names of the columns that are augmented with embeddings. #[must_use] pub fn get_embedding_columns(&self) -> Vec { @@ -424,21 +596,84 @@ impl EmbeddingTable { vector_length(embedding_field.data_type()) } - fn verify_column_type_supported(column: &str, base_schema: &SchemaRef) -> Result<(), Error> { - if let Some((_, field)) = base_schema.column_with_name(column) { - let data_type = field.data_type(); - if !matches!( - data_type, - DataType::Utf8 | DataType::Utf8View | DataType::LargeUtf8 - ) { - return InvalidColumnTypeSnafu { - column: column.to_string(), - data_type: data_type.clone(), + /// Shape of the source column, if it exists and has a supported + /// type. Returns `None` when the column isn't in the base schema + /// (caller handles that case); errors when the column exists but has + /// an unsupported type. + fn detect_source_shape( + column: &str, + base_schema: &SchemaRef, + ) -> Result, Error> { + let Some((_, field)) = base_schema.column_with_name(column) else { + return Ok(None); + }; + let data_type = field.data_type(); + match data_type { + DataType::Utf8 | DataType::Utf8View | DataType::LargeUtf8 => { + Ok(Some(SourceShape::Scalar(data_type.clone()))) + } + DataType::List(inner) | DataType::LargeList(inner) + if matches!( + inner.data_type(), + DataType::Utf8 | DataType::Utf8View | DataType::LargeUtf8 + ) => + { + Ok(Some(SourceShape::ListOfString)) + } + _ => InvalidColumnTypeSnafu { + column: column.to_string(), + data_type: data_type.clone(), + } + .fail(), + } + } + + /// Resolve the effective [`EmbeddingInputMode`] for a column given + /// its detected source shape and the user-provided configuration. + /// Enforces validation rules: list-typed multi-vector options only + /// apply to list columns; chunking is incompatible with multi-vector; + /// `max_elements_per_row` is bounds-checked. + fn resolve_input_mode( + column: &str, + shape: SourceShape, + config: &ColumnEmbeddingConfig, + ) -> Result { + match shape { + SourceShape::Scalar(data_type) => { + if config.aggregation.is_some() || config.max_elements_per_row.is_some() { + return MultiVectorOptionsOnScalarSnafu { + column: column.to_string(), + data_type, + } + .fail(); } - .fail(); + Ok(EmbeddingInputMode::Scalar) + } + SourceShape::ListOfString => { + if config.chunking.as_ref().is_some_and(|c| c.enabled) { + return MultiVectorChunkingNotSupportedSnafu { + column: column.to_string(), + } + .fail(); + } + let aggregation = config.aggregation.unwrap_or_default(); + let cap = config + .max_elements_per_row + .unwrap_or(MULTI_VECTOR_MAX_ELEMENTS_DEFAULT); + if cap == 0 || cap > MULTI_VECTOR_MAX_ELEMENTS_HARD_CAP { + return MaxElementsPerRowOutOfRangeSnafu { + column: column.to_string(), + value: cap, + cap: MULTI_VECTOR_MAX_ELEMENTS_HARD_CAP, + } + .fail(); + } + Ok(EmbeddingInputMode::ListMulti { + aggregation, + max_elements_per_row: cap, + }) } } - Ok(()) } async fn embedding_size_from_models( @@ -508,8 +743,10 @@ impl EmbeddingTable { return vec![]; } - if cfg.chunker.is_some() { - vec![ + match (cfg.input_mode, cfg.chunker.is_some()) { + // Scalar + chunked: doubly nested embedding + offsets + // (character offsets of each chunk into the source string). + (EmbeddingInputMode::Scalar, true) => vec![ Arc::new(Field::new_list( embedding_col!(field.name()), Field::new_fixed_size_list( @@ -530,14 +767,33 @@ impl EmbeddingTable { ), false, )), - ] - } else { - vec![Arc::new(Field::new_fixed_size_list( + ], + // Scalar + unchunked: one vector per row. + (EmbeddingInputMode::Scalar, false) => vec![Arc::new(Field::new_fixed_size_list( embedding_col!(field.name()), Field::new("item", DataType::Float32, true), cfg.vector_size, true, - ))] + ))], + // Multi-vector: one vector per list element. No offsets — + // element index serves as the implicit offset into the + // source list at query time. + // + // The inner FixedSizeList is nullable so that null strings + // inside the source list produce null vectors in the output, + // preserving index correspondence with the source column. + // The outer list is non-null: a null source row maps to an + // empty output list. + (EmbeddingInputMode::ListMulti { .. }, _) => vec![Arc::new(Field::new_list( + embedding_col!(field.name()), + Field::new_fixed_size_list( + "item", + Field::new("item", DataType::Float32, false), + cfg.vector_size, + true, + ), + false, + ))], } } } @@ -850,6 +1106,24 @@ mod tests { )); } + #[test] + fn test_list_source_with_scalar_embedding_rejected() { + // Multi-vector source (`List`) paired with a singly-nested + // (scalar) embedding column is a shape mismatch: the runtime + // cannot UNNEST stored vectors per-element from a + // `FixedSizeList` alone. + assert!(!EmbeddingTable::base_table_has_embedding_column( + &Arc::new(Schema::new(vec![ + field("c", DataType::List(field("item", DataType::Utf8))), + field( + "c_embedding", + DataType::FixedSizeList(field("item", DataType::Float32), 4), + ), + ])), + "c" + )); + } + #[tokio::test] async fn test_invalid_row_id_column_rejected() { let schema = Arc::new(Schema::new(vec![ @@ -868,6 +1142,8 @@ mod tests { primary_keys: Some(vec!["n_regionkey".to_string()]), chunking: None, vector_size: None, + aggregation: None, + max_elements_per_row: None, }]; let result = @@ -905,6 +1181,8 @@ mod tests { primary_keys: Some(vec!["id".to_string()]), chunking: None, vector_size: None, + aggregation: None, + max_elements_per_row: None, }]; let result = @@ -938,6 +1216,8 @@ mod tests { primary_keys: None, chunking: None, vector_size: None, + aggregation: None, + max_elements_per_row: None, }]; let result = @@ -969,6 +1249,8 @@ mod tests { primary_keys: Some(vec!["id".to_string(), "nonexistent".to_string()]), chunking: None, vector_size: None, + aggregation: None, + max_elements_per_row: None, }]; let result = @@ -983,4 +1265,303 @@ mod tests { "Error should mention the invalid column, got: {err_msg}" ); } + + // ===== M1: multi-vector configuration ===== + + fn list_of_utf8(name: &str) -> FieldRef { + Arc::new(Field::new_list( + name, + Field::new("item", DataType::Utf8, true), + true, + )) + } + + #[test] + fn test_detect_source_shape_scalar_utf8() { + let schema = Arc::new(Schema::new(vec![field("c", DataType::Utf8)])); + let shape = EmbeddingTable::detect_source_shape("c", &schema).expect("ok"); + assert_eq!(shape, Some(SourceShape::Scalar(DataType::Utf8))); + } + + #[test] + fn test_detect_source_shape_list_of_utf8() { + let schema = Arc::new(Schema::new(vec![list_of_utf8("tags")])); + let shape = EmbeddingTable::detect_source_shape("tags", &schema).expect("ok"); + assert_eq!(shape, Some(SourceShape::ListOfString)); + } + + #[test] + fn test_detect_source_shape_unsupported() { + let schema = Arc::new(Schema::new(vec![field("c", DataType::Int32)])); + let err = EmbeddingTable::detect_source_shape("c", &schema) + .expect_err("expected unsupported type"); + assert!(matches!(err, Error::InvalidColumnType { .. })); + } + + #[test] + fn test_detect_source_shape_missing_column() { + let schema = Arc::new(Schema::new(vec![field("other", DataType::Utf8)])); + let shape = EmbeddingTable::detect_source_shape("c", &schema).expect("ok"); + assert_eq!(shape, None); + } + + #[test] + fn test_resolve_input_mode_scalar_default() { + let cfg = ColumnEmbeddingConfig { + column: "c".to_string(), + model: "m".to_string(), + primary_keys: None, + chunking: None, + vector_size: None, + aggregation: None, + max_elements_per_row: None, + }; + let mode = + EmbeddingTable::resolve_input_mode("c", SourceShape::Scalar(DataType::Utf8), &cfg) + .expect("ok"); + assert_eq!(mode, EmbeddingInputMode::Scalar); + } + + #[test] + fn test_resolve_input_mode_scalar_rejects_multi_vector_options() { + let cfg = ColumnEmbeddingConfig { + column: "c".to_string(), + model: "m".to_string(), + primary_keys: None, + chunking: None, + vector_size: None, + aggregation: Some(EmbeddingAggregation::Max), + max_elements_per_row: None, + }; + let err = + EmbeddingTable::resolve_input_mode("c", SourceShape::Scalar(DataType::Utf8), &cfg) + .expect_err("expected rejection"); + assert!(matches!(err, Error::MultiVectorOptionsOnScalar { .. })); + } + + #[test] + fn test_resolve_input_mode_list_defaults_max_and_cap_32() { + let cfg = ColumnEmbeddingConfig { + column: "tags".to_string(), + model: "m".to_string(), + primary_keys: None, + chunking: None, + vector_size: None, + aggregation: None, + max_elements_per_row: None, + }; + let mode = EmbeddingTable::resolve_input_mode("tags", SourceShape::ListOfString, &cfg) + .expect("ok"); + match mode { + EmbeddingInputMode::ListMulti { + aggregation, + max_elements_per_row, + } => { + assert_eq!(aggregation, EmbeddingAggregation::Max); + assert_eq!(max_elements_per_row, MULTI_VECTOR_MAX_ELEMENTS_DEFAULT); + } + EmbeddingInputMode::Scalar => panic!("expected ListMulti"), + } + } + + #[test] + fn test_resolve_input_mode_list_honors_aggregation_override() { + let cfg = ColumnEmbeddingConfig { + column: "tags".to_string(), + model: "m".to_string(), + primary_keys: None, + chunking: None, + vector_size: None, + aggregation: Some(EmbeddingAggregation::Mean), + max_elements_per_row: Some(64), + }; + let mode = EmbeddingTable::resolve_input_mode("tags", SourceShape::ListOfString, &cfg) + .expect("ok"); + assert_eq!( + mode, + EmbeddingInputMode::ListMulti { + aggregation: EmbeddingAggregation::Mean, + max_elements_per_row: 64, + } + ); + } + + #[test] + fn test_resolve_input_mode_list_rejects_chunking() { + let cfg = ColumnEmbeddingConfig { + column: "tags".to_string(), + model: "m".to_string(), + primary_keys: None, + chunking: Some(spicepod::component::embeddings::EmbeddingChunkConfig { + enabled: true, + target_chunk_size: 256, + overlap_size: 0, + trim_whitespace: false, + }), + vector_size: None, + aggregation: None, + max_elements_per_row: None, + }; + let err = EmbeddingTable::resolve_input_mode("tags", SourceShape::ListOfString, &cfg) + .expect_err("expected chunking rejection"); + assert!(matches!(err, Error::MultiVectorChunkingNotSupported { .. })); + } + + #[test] + fn test_resolve_input_mode_list_rejects_zero_cap() { + let cfg = ColumnEmbeddingConfig { + column: "tags".to_string(), + model: "m".to_string(), + primary_keys: None, + chunking: None, + vector_size: None, + aggregation: None, + max_elements_per_row: Some(0), + }; + let err = EmbeddingTable::resolve_input_mode("tags", SourceShape::ListOfString, &cfg) + .expect_err("expected cap rejection"); + assert!(matches!(err, Error::MaxElementsPerRowOutOfRange { .. })); + } + + #[test] + fn test_resolve_input_mode_list_rejects_cap_above_hard_cap() { + let cfg = ColumnEmbeddingConfig { + column: "tags".to_string(), + model: "m".to_string(), + primary_keys: None, + chunking: None, + vector_size: None, + aggregation: None, + max_elements_per_row: Some(MULTI_VECTOR_MAX_ELEMENTS_HARD_CAP + 1), + }; + let err = EmbeddingTable::resolve_input_mode("tags", SourceShape::ListOfString, &cfg) + .expect_err("expected cap rejection"); + assert!(matches!(err, Error::MaxElementsPerRowOutOfRange { .. })); + } + + // ===== M4: accelerator schema compatibility ===== + // + // Multi-vector output is `List>` — the + // identical Arrow shape the chunked-scalar path has always emitted + // (just without the sibling `_offset` column). Arrow, Cayenne, and + // DuckDB accelerators already round-trip this shape via the chunked + // code path, so multi-vector inherits that compatibility with no + // additional accelerator changes. SQLite / Turso nested- + // FixedSizeList support remains fragile and is addressed by the + // M7 side-table strategy. + + #[test] + fn test_embedding_fields_multi_vector_schema_matches_chunked_minus_offset() { + // A multi-vector column should produce exactly one output field: + // `_embedding: List>`. The chunked path + // adds an `_offset` sibling; multi-vector does not. + let tags_field = Arc::new(Field::new_list( + "tags", + Field::new("item", DataType::Utf8, true), + true, + )); + let base_schema = Arc::new(Schema::new(vec![Arc::clone(&tags_field)])); + + let embedded_columns = HashMap::from([( + "tags".to_string(), + EmbeddingColumnConfig { + model_name: "m".to_string(), + vector_size: 4, + in_base_table: false, + chunker: None, + input_mode: EmbeddingInputMode::ListMulti { + aggregation: EmbeddingAggregation::Max, + max_elements_per_row: 32, + }, + }, + )]); + + let base_table: Arc = Arc::new( + datafusion::catalog::MemTable::try_new(base_schema, vec![vec![]]) + .expect("valid schema"), + ); + + let table = EmbeddingTable { + base_table, + embedded_columns, + embedding_models: Arc::new(RwLock::new(HashMap::new())), + }; + + let fields = table.embedding_fields(&tags_field); + assert_eq!(fields.len(), 1, "multi-vector produces no offset column"); + let emb = &fields[0]; + assert_eq!(emb.name(), "tags_embedding"); + // Expect List> + let DataType::List(inner) = emb.data_type() else { + panic!("expected List, got {:?}", emb.data_type()); + }; + let DataType::FixedSizeList(leaf, size) = inner.data_type() else { + panic!("expected inner FixedSizeList, got {:?}", inner.data_type()); + }; + assert_eq!(*size, 4); + assert_eq!(leaf.data_type(), &DataType::Float32); + } + + #[test] + fn test_has_nested_embedding_output_list_multi() { + let tags_field = Arc::new(Field::new_list( + "tags", + Field::new("item", DataType::Utf8, true), + true, + )); + let base_schema = Arc::new(Schema::new(vec![tags_field])); + let embedded_columns = HashMap::from([( + "tags".to_string(), + EmbeddingColumnConfig { + model_name: "m".to_string(), + vector_size: 4, + in_base_table: false, + chunker: None, + input_mode: EmbeddingInputMode::ListMulti { + aggregation: EmbeddingAggregation::Max, + max_elements_per_row: 32, + }, + }, + )]); + let base_table: Arc = Arc::new( + datafusion::catalog::MemTable::try_new(base_schema, vec![vec![]]) + .expect("valid schema"), + ); + let table = EmbeddingTable { + base_table, + embedded_columns, + embedding_models: Arc::new(RwLock::new(HashMap::new())), + }; + + assert!(table.is_multi_vector("tags")); + assert!(!table.is_chunked("tags")); + // has_nested_embedding_output covers either mode — this is what + // the search dispatcher keys off of to pick the UNNEST path. + assert!(table.has_nested_embedding_output("tags")); + assert_eq!( + table.multi_vector_aggregation("tags"), + Some(EmbeddingAggregation::Max) + ); + } + + #[test] + fn test_base_table_has_embedding_list_multi_no_offset_required() { + // Source is List; no offsets column should be required. + let schema = Arc::new(Schema::new(vec![ + list_of_utf8("tags"), + Arc::new(Field::new_list( + "tags_embedding", + Field::new_fixed_size_list( + "item", + Field::new("item", DataType::Float32, false), + 4, + false, + ), + false, + )), + ])); + assert!(EmbeddingTable::base_table_has_embedding_column( + &schema, "tags" + )); + } } diff --git a/crates/runtime/src/embeddings/udtf.rs b/crates/runtime/src/embeddings/udtf.rs index 7c68ffbdcb..7ef15d4eaa 100644 --- a/crates/runtime/src/embeddings/udtf.rs +++ b/crates/runtime/src/embeddings/udtf.rs @@ -96,6 +96,12 @@ use tokio::sync::RwLock; pub static VECTOR_SEARCH_UDTF_NAME: &str = "vector_search"; +/// Upper bound on the number of query strings accepted by `vector_search` when +/// invoked in late-interaction (multi-query) mode. Each query produces its own +/// subplan that is `UNIONed` together, so unbounded arrays can blow up the +/// logical plan size and runtime work. +const VECTOR_SEARCH_MAX_QUERIES: usize = 32; + /// Creates a `UserDefined` signature that allows named parameters (like `rank_weight => X`) /// to pass through for RRF (Reciprocal Rank Fusion) operations. /// @@ -123,7 +129,16 @@ pub static VECTOR_SEARCH_SIGNATURE: LazyLock = LazyLock::new(|| { #[derive(Debug, PartialEq, Clone)] pub struct VectorSearchTableFuncArgs { pub tbl: TableReference, + /// Primary query string. For single-string queries this is the only + /// query; for multi-string (late-interaction) queries it is the + /// first element of `queries` and retained here for backward + /// compatibility with existing consumers that read `query` directly. pub query: String, + /// All query strings. Always contains at least one entry (mirroring + /// `query` for single-string mode). Length > 1 triggers the + /// late-interaction search path when paired with a multi-vector + /// column. + pub queries: Vec, pub column: Option, pub limit: Option, @@ -236,10 +251,27 @@ impl VectorSearchTableFunc { impl VectorSearchTableFunc { #[must_use] pub fn to_expr(args: &VectorSearchTableFuncArgs) -> Vec { - let mut expr = vec![ - Expr::Column(to_column_expr(&args.tbl)), - Expr::Literal(ScalarValue::Utf8(Some(args.query.clone())), None), - ]; + // Multi-query searches round-trip as a `make_array(...)` call; + // single-query stays as a bare Utf8 literal for backwards + // compatibility with pre-multi-query consumers. + let query_expr = if args.queries.len() > 1 { + let make_array = datafusion::functions_nested::make_array::make_array_udf(); + Expr::ScalarFunction(ScalarFunction::new_udf( + make_array, + args.queries + .iter() + .map(|q| Expr::Literal(ScalarValue::Utf8(Some(q.clone())), None)) + .collect(), + )) + } else { + let q = args + .queries + .first() + .cloned() + .unwrap_or_else(|| args.query.clone()); + Expr::Literal(ScalarValue::Utf8(Some(q)), None) + }; + let mut expr = vec![Expr::Column(to_column_expr(&args.tbl)), query_expr]; if let Some(col) = args.column.as_ref() { expr.push(Expr::Column(Column::new_unqualified(col))); @@ -259,6 +291,46 @@ impl VectorSearchTableFunc { expr } + /// Parse the query argument of `vector_search(tbl, , ...)`. + /// Accepts either a single Utf8 string literal, or a `make_array(...)` + /// (i.e. SQL `[...]` / `ARRAY[...]`) whose elements are all Utf8 + /// literals. Returns a non-empty `Vec`. + fn parse_query_arg(query: Option<&Expr>) -> DataFusionResult> { + match query { + Some(Expr::Literal(ScalarValue::Utf8(Some(q)), None)) => Ok(vec![q.clone()]), + Some(Expr::ScalarFunction(ScalarFunction { func, args })) + if func.name().eq_ignore_ascii_case("make_array") => + { + if args.is_empty() { + return Err(DataFusionError::Plan( + "Multi-query array must contain at least one query string.".to_string(), + )); + } + if args.len() > VECTOR_SEARCH_MAX_QUERIES { + return Err(DataFusionError::Plan(format!( + "Multi-query array is limited to {VECTOR_SEARCH_MAX_QUERIES} query strings, got {}.", + args.len() + ))); + } + let mut out = Vec::with_capacity(args.len()); + for a in args { + match a { + Expr::Literal(ScalarValue::Utf8(Some(s)), _) => out.push(s.clone()), + other => { + return Err(DataFusionError::Plan(format!( + "Multi-query array elements must be string literals, got {other:?}." + ))); + } + } + } + Ok(out) + } + other => Err(DataFusionError::Plan(format!( + "Second argument must be a query string or array of query strings, but got {other:?}." + ))), + } + } + fn parse_args(args: &[Expr]) -> DataFusionResult { // Filter out passthrough parameters (those with spice.parameter_name metadata) // These are meant for table functions like RRF, not for vector_search itself @@ -276,11 +348,13 @@ impl VectorSearchTableFunc { let tbl_ref = table_ref_from_column_expr(c); let query = args.next(); - let Some(Expr::Literal(ScalarValue::Utf8(Some(q)), None)) = query else { - return Err(DataFusionError::Plan(format!( - "Second argument must be a query string, but got {query:?}." - ))); - }; + let queries = Self::parse_query_arg(query)?; + // `q` is used in downstream error messages + back-compat field. + let q = queries.first().cloned().ok_or_else(|| { + DataFusionError::Plan( + "Invalid arguments: vector_search query argument must contain at least one query value.".to_string(), + ) + })?; let (column, limit, include_score) = match (args.next(), args.next(), args.next()) { // No arguments, provides defaults @@ -346,7 +420,8 @@ impl VectorSearchTableFunc { tbl: tbl_ref .resolve(SPICE_DEFAULT_CATALOG, SPICE_DEFAULT_SCHEMA) .into(), - query: q.clone(), + query: q, + queries, column, limit: limit.map(|l| usize::try_from(l).unwrap_or(usize::MAX)), include_score, @@ -483,7 +558,11 @@ impl TableFunctionImpl for VectorSearchTableFunc { })?; let (col, _) = args.get_column_and_config(&embedding_table_provider.embedded_columns)?; - if embedding_table_provider.is_chunked(col.as_str()) { + // Both chunked-scalar and multi-vector (list-typed) columns use + // the same UNNEST-based non-indexed search path, but with + // different scan modes. + let is_multi_vector = embedding_table_provider.is_multi_vector(col.as_str()); + if embedding_table_provider.is_chunked(col.as_str()) || is_multi_vector { let state = df.ctx.state(); let Some(embed_udf) = state.scalar_functions().get(EMBED_UDF_NAME) else { return Err(DataFusionError::Plan(format!( @@ -492,7 +571,16 @@ impl TableFunctionImpl for VectorSearchTableFunc { }; // Unsafe: worse case is metric without dimensions. - let dimensions = unsafe { RequestContext::current_sync().to_dimensions() }; + let mut dimensions = unsafe { RequestContext::current_sync().to_dimensions() }; + if is_multi_vector { + dimensions.push(opentelemetry::KeyValue::new("multi_vector", true)); + if let Some(agg) = embedding_table_provider.multi_vector_aggregation(col.as_str()) { + dimensions.push(opentelemetry::KeyValue::new( + "multi_vector_aggregation", + agg.to_string(), + )); + } + } telemetry::track_vector_search(&dimensions); let pks = self .explicit_pks @@ -500,17 +588,56 @@ impl TableFunctionImpl for VectorSearchTableFunc { .cloned() .or_else(|| get_primary_keys(&table_provider).ok()); - let table = ChunkedNonIndexVectorGeneration::new( - &table_provider, - &args.tbl, - embed_udf, - embedding_table_provider - .get_embedding_model_used_by(&col) - .unwrap_or_default(), - pks.unwrap_or_default(), - &col, - ) - .search(args.query)?; + let model_name = embedding_table_provider + .get_embedding_model_used_by(&col) + .unwrap_or_default(); + let pks_vec = pks.unwrap_or_default(); + + let table = if is_multi_vector { + if args.queries.len() > 1 { + // Multi-query × multi-vector → ColBERT-style + // late-interaction: `SUM_{q in Q} MAX_{d in D} cos(q, d)`. + ChunkedNonIndexVectorGeneration::new_late_interaction( + &table_provider, + &args.tbl, + embed_udf, + model_name, + pks_vec, + &col, + args.queries.clone(), + ) + .search(args.query)? + } else { + let aggregation = embedding_table_provider + .multi_vector_aggregation(col.as_str()) + .unwrap_or_default(); + ChunkedNonIndexVectorGeneration::new_list_multi( + &table_provider, + &args.tbl, + embed_udf, + model_name, + pks_vec, + &col, + aggregation, + ) + .search(args.query)? + } + } else { + if args.queries.len() > 1 { + return Err(DataFusionError::Plan(format!( + "Multi-query `vector_search(tbl, [q1, q2, ...], col)` requires a multi-vector (list-typed) column; column '{col}' is scalar." + ))); + } + ChunkedNonIndexVectorGeneration::new( + &table_provider, + &args.tbl, + embed_udf, + model_name, + pks_vec, + &col, + ) + .search(args.query)? + }; return alias_value_to_match(Arc::clone(&table)); } @@ -765,3 +892,59 @@ fn alias_value_to_match( .collect::>(); Ok(Arc::new(ViewTable::new(bldr.project(cols)?.build()?, None))) } + +#[cfg(test)] +mod parser_tests { + use super::VectorSearchTableFunc; + use datafusion::prelude::Expr; + use datafusion::scalar::ScalarValue; + use datafusion_expr::expr::ScalarFunction; + use std::sync::Arc; + + fn lit_utf8(s: &str) -> Expr { + Expr::Literal(ScalarValue::Utf8(Some(s.to_string())), None) + } + + #[test] + fn test_parse_query_arg_single_string() { + let q = lit_utf8("hello"); + let out = VectorSearchTableFunc::parse_query_arg(Some(&q)).expect("ok"); + assert_eq!(out, vec!["hello".to_string()]); + } + + #[test] + fn test_parse_query_arg_make_array() { + use datafusion::functions_nested::make_array::make_array_udf; + let make_array = make_array_udf(); + let q = Expr::ScalarFunction(ScalarFunction::new_udf( + Arc::clone(&make_array), + vec![lit_utf8("red"), lit_utf8("round")], + )); + let out = VectorSearchTableFunc::parse_query_arg(Some(&q)).expect("ok"); + assert_eq!(out, vec!["red".to_string(), "round".to_string()]); + } + + #[test] + fn test_parse_query_arg_make_array_non_string_element_rejected() { + use datafusion::functions_nested::make_array::make_array_udf; + let make_array = make_array_udf(); + let q = Expr::ScalarFunction(ScalarFunction::new_udf( + Arc::clone(&make_array), + vec![ + lit_utf8("red"), + Expr::Literal(ScalarValue::Int32(Some(42)), None), + ], + )); + let err = VectorSearchTableFunc::parse_query_arg(Some(&q)).expect_err("expected rejection"); + assert!(err.to_string().contains("must be string literals")); + } + + #[test] + fn test_parse_query_arg_empty_make_array_rejected() { + use datafusion::functions_nested::make_array::make_array_udf; + let make_array = make_array_udf(); + let q = Expr::ScalarFunction(ScalarFunction::new_udf(Arc::clone(&make_array), vec![])); + let err = VectorSearchTableFunc::parse_query_arg(Some(&q)).expect_err("expected rejection"); + assert!(err.to_string().contains("at least one query string")); + } +} diff --git a/crates/runtime/src/search/candidate/vector.rs b/crates/runtime/src/search/candidate/vector.rs index 1bc3e7c465..bd384203d7 100644 --- a/crates/runtime/src/search/candidate/vector.rs +++ b/crates/runtime/src/search/candidate/vector.rs @@ -21,6 +21,7 @@ use datafusion::common::{Column, UnnestOptions}; use datafusion::datasource::{DefaultTableSource, ViewTable}; use datafusion::error::DataFusionError; use datafusion::functions::math::isnan; +use datafusion::functions_aggregate::expr_fn::{avg, first_value, max, sum}; use datafusion::functions_window::expr_fn::row_number; use datafusion::prelude::{array_element, substring}; use datafusion::sql::TableReference; @@ -32,14 +33,49 @@ use datafusion_expr::{ use runtime_datafusion_udfs::cosine_distance; use search::generation::CandidateGeneration; use search::{SEARCH_SCORE_COLUMN_NAME, SEARCH_VALUE_COLUMN_NAME}; +use spicepod::component::embeddings::EmbeddingAggregation; use std::sync::Arc; // Distance column name for the vector search query. // static VECTOR_DISTANCE_COLUMN_NAME: &str = "dist"; // Surrogate unique identifier name to use when no primary keys are provided. static VSS_TEMP_GEN_ID_COLUMN: &str = "vss_temp_gen_id"; +// Alias used for the source-list column while unnesting ListMulti rows; +// after unnest this column holds one matched element per row. +static LIST_MULTI_MATCH_ELEMENT_ALIAS: &str = "_match_element"; -/// A [`CandidateGeneration`] for datasets that have a chunked embedding column, but aren't using a vector index. +/// Scan mode for the non-indexed vector generation. +/// +/// `ChunkedScalar` — the source column is a single string per row that +/// has been chunked at ingest time; an `_offset` column carries +/// character offsets back into the source string. +/// +/// `ListMulti` — the source column is a list of strings; each list +/// element was embedded in M2. Single-query `MaxSim` / Mean / Sum over +/// stored elements per row. +/// +/// `LateInteraction` — both the query and the stored column are +/// multi-element. Scoring is `SUM_{q in Q} MAX_{d in D} cos(q, d)` +/// (ColBERT-style). The query strings are carried on the generator and +/// each is embedded via the `embed` UDF at query time. +#[derive(Clone, Debug)] +pub enum VectorScanMode { + ChunkedScalar, + ListMulti { + aggregation: EmbeddingAggregation, + }, + LateInteraction { + /// All query strings. The first is the "primary" query used for + /// the `query` field on `VectorSearchTableFuncArgs`; all are + /// embedded via `embed(literal, model)` in the SQL. + queries: Vec, + }, +} + +/// A [`CandidateGeneration`] for datasets whose embedding column is +/// doubly-nested (`List>`) and that do not use a +/// native vector index. Handles both chunked-scalar (content split +/// into chunks at ingest) and list-of-string multi-vector inputs. pub struct ChunkedNonIndexVectorGeneration { table_provider: Arc, tbl: TableReference, @@ -47,9 +83,12 @@ pub struct ChunkedNonIndexVectorGeneration { embed: Arc, primary_keys: Vec, embedding_column: String, + mode: VectorScanMode, } impl ChunkedNonIndexVectorGeneration { + /// Chunked-scalar mode constructor. Preserves the pre-multi-vector + /// public API so existing call sites don't need to change. pub fn new( table_provider: &Arc, tbl: &TableReference, @@ -57,6 +96,72 @@ impl ChunkedNonIndexVectorGeneration { model: String, primary_keys: Vec, embedding_column: &str, + ) -> Self { + Self::with_mode( + table_provider, + tbl, + embed, + model, + primary_keys, + embedding_column, + VectorScanMode::ChunkedScalar, + ) + } + + /// Multi-vector (list-of-strings) mode constructor. + pub fn new_list_multi( + table_provider: &Arc, + tbl: &TableReference, + embed: &Arc, + model: String, + primary_keys: Vec, + embedding_column: &str, + aggregation: EmbeddingAggregation, + ) -> Self { + Self::with_mode( + table_provider, + tbl, + embed, + model, + primary_keys, + embedding_column, + VectorScanMode::ListMulti { aggregation }, + ) + } + + /// Late-interaction (multi-query × multi-element) constructor. + /// `queries` must be non-empty. If it contains a single element the + /// generator falls back to single-query `MaxSim` over the stored + /// elements (semantically equivalent to the `ListMulti` path with + /// `aggregation = max`). + pub fn new_late_interaction( + table_provider: &Arc, + tbl: &TableReference, + embed: &Arc, + model: String, + primary_keys: Vec, + embedding_column: &str, + queries: Vec, + ) -> Self { + Self::with_mode( + table_provider, + tbl, + embed, + model, + primary_keys, + embedding_column, + VectorScanMode::LateInteraction { queries }, + ) + } + + fn with_mode( + table_provider: &Arc, + tbl: &TableReference, + embed: &Arc, + model: String, + primary_keys: Vec, + embedding_column: &str, + mode: VectorScanMode, ) -> Self { Self { table_provider: Arc::clone(table_provider), @@ -65,6 +170,7 @@ impl ChunkedNonIndexVectorGeneration { embed: Arc::clone(embed), primary_keys, embedding_column: embedding_column.to_string(), + mode, } } @@ -111,26 +217,12 @@ impl ChunkedNonIndexVectorGeneration { lp = lp.filter(f)?; } + let (project_cols, unnest_cols) = self.unnest_projection_and_columns(); + lp = lp - .project( - [ - self.primary_keys.iter().map(ident).collect(), - vec![ - ident(self.embedding_column.clone()), - ident(offset_col!(self.embedding_column)).alias("offset"), - ident(embedding_col!(self.embedding_column.clone())), - ], - ] - .concat(), - )? + .project([self.primary_keys.iter().map(ident).collect(), project_cols].concat())? // Note: `datafusion_expr::builder::unnest` does not work for complex queries - .unnest_columns_with_options( - vec![ - Column::new_unqualified("offset"), - Column::new_unqualified(embedding_col!(self.embedding_column.clone())), - ], - UnnestOptions::new(), - )?; + .unnest_columns_with_options(unnest_cols, UnnestOptions::new())?; // Compute score let mut cols = lp @@ -155,6 +247,46 @@ impl ChunkedNonIndexVectorGeneration { } } + /// Build the per-row projection list and the set of columns that + /// should be `UNNEST`ed, based on the scan mode. + /// + /// For `ChunkedScalar`: keeps the source string intact (the chunked + /// embeddings and offset arrays are unnested in lockstep and the + /// matched substring is extracted post-unnest via `substring(src, + /// offset_start, offset_length)`). + /// + /// For `ListMulti`: unnests both the source list-of-strings and the + /// embedding list together so each resulting row carries a scalar + /// string paired with its vector. + fn unnest_projection_and_columns(&self) -> (Vec, Vec) { + match &self.mode { + VectorScanMode::ChunkedScalar => ( + vec![ + ident(self.embedding_column.clone()), + ident(offset_col!(self.embedding_column)).alias("offset"), + ident(embedding_col!(self.embedding_column.clone())), + ], + vec![ + Column::new_unqualified("offset"), + Column::new_unqualified(embedding_col!(self.embedding_column.clone())), + ], + ), + // Both multi-vector modes unnest the source list and + // embedding list together, pairing a scalar string with its + // vector on each unnested row. + VectorScanMode::ListMulti { .. } | VectorScanMode::LateInteraction { .. } => ( + vec![ + ident(self.embedding_column.clone()).alias(LIST_MULTI_MATCH_ELEMENT_ALIAS), + ident(embedding_col!(self.embedding_column.clone())), + ], + vec![ + Column::new_unqualified(LIST_MULTI_MATCH_ELEMENT_ALIAS), + Column::new_unqualified(embedding_col!(self.embedding_column.clone())), + ], + ), + } + } + /// Intermediate result of vector search on chunk-based table that do not have existing primary key(s). /// /// We use an additional surrogate temp table and a generated primary key. @@ -186,22 +318,12 @@ impl ChunkedNonIndexVectorGeneration { // This is just the table with all the additional columns we may want to join on let additional_lp = lp.clone().alias("additional")?.build()?; - // Process the embedding column and offsets + // Process the embedding column and offsets / list elements + let (project_cols, unnest_cols) = self.unnest_projection_and_columns(); let mut base_lp = lp - .project(vec![ - ident(self.embedding_column.clone()), - ident(offset_col!(self.embedding_column)).alias("offset"), - ident(embedding_col!(self.embedding_column)), - col(VSS_TEMP_GEN_ID_COLUMN), - ])? + .project([project_cols, vec![col(VSS_TEMP_GEN_ID_COLUMN)]].concat())? // Note: `datafusion_expr::builder::unnest` does not work for complex queries - .unnest_columns_with_options( - vec![ - Column::new_unqualified("offset"), - Column::new_unqualified(embedding_col!(self.embedding_column.clone())), - ], - UnnestOptions::new(), - )?; + .unnest_columns_with_options(unnest_cols, UnnestOptions::new())?; // Compute score let mut cols = base_lp @@ -219,11 +341,77 @@ impl ChunkedNonIndexVectorGeneration { additional_lp, )) } + + /// Build the aggregate expression used to roll per-element scores up + /// into a single score per primary key. Applied as a window function + /// partitioned by pk, so a sibling `row_number() = 1` filter selects + /// the best-matching element in the same step. + fn aggregate_score_expr(&self, pks: &[String]) -> Result { + let partition: Vec = pks.iter().map(col).collect(); + let score_arg = col(SEARCH_SCORE_COLUMN_NAME); + + let aggregation = match &self.mode { + VectorScanMode::ListMulti { aggregation } => *aggregation, + // ChunkedScalar uses Max (single chunk wins per row). + // Late-interaction builds its own aggregation pipeline and + // does not go through this helper; `Max` is an inert default. + VectorScanMode::ChunkedScalar | VectorScanMode::LateInteraction { .. } => { + EmbeddingAggregation::Max + } + }; + + let agg = match aggregation { + EmbeddingAggregation::Max => max(score_arg), + EmbeddingAggregation::Mean => avg(score_arg), + EmbeddingAggregation::Sum => sum(score_arg), + }; + + Ok(agg + .partition_by(partition) + .build()? + .alias(AGG_SCORE_COLUMN_NAME)) + } } +// Internal alias for the aggregated per-pk score; distinct from +// `SEARCH_SCORE_COLUMN_NAME` so we can rewrite that column after +// aggregation without colliding with the per-element score. +static AGG_SCORE_COLUMN_NAME: &str = "_agg_score"; + #[async_trait::async_trait] impl CandidateGeneration for ChunkedNonIndexVectorGeneration { fn search(&self, query: String) -> Result, DataFusionError> { + match &self.mode { + VectorScanMode::ChunkedScalar => self.search_chunked_scalar(query), + VectorScanMode::ListMulti { .. } => self.search_list_multi(query), + VectorScanMode::LateInteraction { queries } => { + // Argument `query` carries the primary query string; the + // full set lives on the generator. Single-element + // late-interaction collapses to `search_list_multi`. + let qs = queries.clone(); + if qs.len() <= 1 { + self.search_list_multi(query) + } else { + self.search_late_interaction(&qs) + } + } + } + } + + fn value_derived_from(&self) -> String { + self.embedding_column.clone() + } + + fn value_projection_name(&self) -> String { + SEARCH_VALUE_COLUMN_NAME.to_string() + } +} + +impl ChunkedNonIndexVectorGeneration { + fn search_chunked_scalar( + &self, + query: String, + ) -> Result, DataFusionError> { let (pks, score_table, additional_table) = self.score_cte_sql(&self.table_provider, query, &[])?; @@ -309,11 +497,224 @@ impl CandidateGeneration for ChunkedNonIndexVectorGeneration { Ok(Arc::new(ViewTable::new(plan.build()?, None))) } - fn value_derived_from(&self) -> String { - self.embedding_column.clone() + /// Multi-vector (`ListMulti`) search: per-list-element cosine similarity + /// rolled up per primary key with the configured aggregation + /// (`max` / `mean` / `sum`). The `_match` column is the source list + /// element that produced the best per-element score. + fn search_list_multi(&self, query: String) -> Result, DataFusionError> { + let (pks, score_table, additional_table) = + self.score_cte_sql(&self.table_provider, query, &[])?; + + // Project primary keys + per-element score + matched element. + // Both `_score` and `_match_element` are scalar columns after + // UNNEST in `score_cte_sql`. + let mut plan = LogicalPlanBuilder::new(score_table) + .project( + [ + pks.iter().map(ident).collect(), + vec![ + col(SEARCH_SCORE_COLUMN_NAME), + col(LIST_MULTI_MATCH_ELEMENT_ALIAS), + ], + ] + .concat(), + )? + .filter( + LogicalExpr::ScalarFunction(ScalarFunction::new_udf( + isnan(), + vec![ident(SEARCH_SCORE_COLUMN_NAME)], + )) + .is_false(), + )?; + + let final_additional_columns: Vec<_> = self + .table_provider + .schema() + .fields() + .iter() + .filter_map(|f| { + if self.primary_keys.contains(f.name()) { + None + } else { + Some(ident(f.name().clone())) + } + }) + .collect(); + + // Two sibling window functions: one to pick the argmax element + // (`chunk_rank = 1`), one to compute the aggregated per-pk score. + let rank_window = row_number() + .partition_by(pks.iter().map(col).collect()) + .order_by(vec![col(SEARCH_SCORE_COLUMN_NAME).sort(false, false)]) + .build()? + .alias("chunk_rank"); + + let agg_window = self.aggregate_score_expr(&pks)?; + + plan = plan + .window(vec![rank_window, agg_window])? + .alias("rank")? + .filter(col("chunk_rank").eq(lit(1)))? + .sort(vec![ + LogicalExpr::Column(Column::new(Some("rank"), AGG_SCORE_COLUMN_NAME)) + .sort(false, false), + ])? + .join( + additional_table, + JoinType::Left, + pks.iter() + .map(|pk| (Column::from_name(pk), Column::from_name(pk))) + .collect(), + None, + )? + .project( + [ + final_additional_columns, + self.primary_keys + .iter() + .map(|pk| Column::new(Some("rank"), pk).into()) + .collect::>(), + vec![ + col(LIST_MULTI_MATCH_ELEMENT_ALIAS).alias(SEARCH_VALUE_COLUMN_NAME), + // Surface the aggregated score under the + // canonical `_score` name that downstream RRF / + // consumers expect. + col(AGG_SCORE_COLUMN_NAME).alias(SEARCH_SCORE_COLUMN_NAME), + ], + ] + .concat(), + )?; + + Ok(Arc::new(ViewTable::new(plan.build()?, None))) } - fn value_projection_name(&self) -> String { - SEARCH_VALUE_COLUMN_NAME.to_string() + /// Late-interaction (ColBERT-style) search over a multi-vector + /// column with a multi-string query. + /// + /// Scoring: for each query string `q_k`, compute the best per-row + /// cosine similarity against any stored element (`MaxSim`). Sum those + /// per-query bests into the final row score. + /// + /// Implementation: one sub-plan per query (reusing `score_cte_sql`'s + /// UNNEST + per-element cosine), tagged with `q_idx`, unioned; + /// then a two-step aggregate collapses per-query bests to a single + /// per-primary-key row. + fn search_late_interaction( + &self, + queries: &[String], + ) -> Result, DataFusionError> { + // Reuse the first query to grab the canonical primary-key set + // and the additional-columns plan for the final join. + let (pks, _primary_score_table, additional_table) = + self.score_cte_sql(&self.table_provider, queries[0].clone(), &[])?; + + // Build one tagged sub-plan per query. + let mut subplans: Vec = Vec::with_capacity(queries.len()); + for (idx, q) in queries.iter().enumerate() { + let (_, score_table, _) = self.score_cte_sql(&self.table_provider, q.clone(), &[])?; + let idx_i64 = i64::try_from(idx).unwrap_or(i64::MAX); + let subplan = LogicalPlanBuilder::new(score_table) + .project( + [ + pks.iter().map(ident).collect(), + vec![ + col(SEARCH_SCORE_COLUMN_NAME), + col(LIST_MULTI_MATCH_ELEMENT_ALIAS), + lit(idx_i64).alias("q_idx"), + ], + ] + .concat(), + )? + .filter( + LogicalExpr::ScalarFunction(ScalarFunction::new_udf( + isnan(), + vec![ident(SEARCH_SCORE_COLUMN_NAME)], + )) + .is_false(), + )? + .build()?; + subplans.push(subplan); + } + + // UNION ALL the per-query sub-plans. + let first = subplans.remove(0); + let mut unioned = LogicalPlanBuilder::new(first); + for p in subplans { + unioned = unioned.union(p)?; + } + + // Step 1: per (pk, q_idx) — MAX cosine (best stored element for + // this query) and FIRST_VALUE(match element ordered by score). + let step1_group: Vec = + [pks.iter().map(col).collect(), vec![col("q_idx")]].concat(); + let per_query_best_col = "per_query_best"; + let per_query_match_col = "per_query_match"; + + let per_query_sort = vec![col(SEARCH_SCORE_COLUMN_NAME).sort(false, false)]; + let step1 = unioned + .aggregate( + step1_group, + vec![ + max(col(SEARCH_SCORE_COLUMN_NAME)).alias(per_query_best_col), + first_value(col(LIST_MULTI_MATCH_ELEMENT_ALIAS), per_query_sort) + .alias(per_query_match_col), + ], + )? + .alias("per_query")?; + + // Step 2: per pk — SUM per-query bests (late-interaction total); + // pick the match element from whichever query scored highest. + let step2_group: Vec = pks.iter().map(col).collect(); + let match_sort = vec![col(per_query_best_col).sort(false, false)]; + let aggregated = step1 + .aggregate( + step2_group, + vec![ + sum(col(per_query_best_col)).alias(SEARCH_SCORE_COLUMN_NAME), + first_value(col(per_query_match_col), match_sort) + .alias(SEARCH_VALUE_COLUMN_NAME), + ], + )? + .alias("agg")?; + + // Assemble final columns: additional columns from the base + // table + primary keys + _match + _score, ordered by score. + let final_additional_columns: Vec<_> = self + .table_provider + .schema() + .fields() + .iter() + .filter_map(|f| { + if self.primary_keys.contains(f.name()) { + None + } else { + Some(ident(f.name().clone())) + } + }) + .collect(); + + let plan = aggregated + .sort(vec![col(SEARCH_SCORE_COLUMN_NAME).sort(false, false)])? + .join( + additional_table, + JoinType::Left, + pks.iter() + .map(|pk| (Column::from_name(pk), Column::from_name(pk))) + .collect(), + None, + )? + .project( + [ + final_additional_columns, + self.primary_keys + .iter() + .map(|pk| Column::new(Some("agg"), pk).into()) + .collect::>(), + vec![col(SEARCH_VALUE_COLUMN_NAME), col(SEARCH_SCORE_COLUMN_NAME)], + ] + .concat(), + )?; + + Ok(Arc::new(ViewTable::new(plan.build()?, None))) } } diff --git a/crates/runtime/src/search/candidate/vector_udtf.rs b/crates/runtime/src/search/candidate/vector_udtf.rs index b46a81f4e2..104c044bb4 100644 --- a/crates/runtime/src/search/candidate/vector_udtf.rs +++ b/crates/runtime/src/search/candidate/vector_udtf.rs @@ -54,6 +54,7 @@ impl CandidateGeneration for VectorUDTFGeneration { fn search(&self, query: String) -> Result, DataFusionError> { let udtf_args = VectorSearchTableFunc::to_expr(&VectorSearchTableFuncArgs { tbl: self.tbl.clone(), + queries: vec![query.clone()], query, column: Some(self.embedding_column.clone()), limit: None, diff --git a/crates/runtime/src/search/rrf.rs b/crates/runtime/src/search/rrf.rs index c067d00e57..16f60d7b8c 100644 --- a/crates/runtime/src/search/rrf.rs +++ b/crates/runtime/src/search/rrf.rs @@ -981,6 +981,7 @@ mod tests { vector_size: 64, in_base_table: true, chunker: None, + input_mode: crate::embeddings::table::EmbeddingInputMode::Scalar, }, ); diff --git a/crates/runtime/src/search/search_engine.rs b/crates/runtime/src/search/search_engine.rs index 33f581fc18..8c71fa2534 100644 --- a/crates/runtime/src/search/search_engine.rs +++ b/crates/runtime/src/search/search_engine.rs @@ -149,8 +149,11 @@ impl SearchEngine { }); }; - // Use UDTF for non-chunked `EmbeddingTable`. - if !embedding_table.is_chunked(embedding_column) { + // Use UDTF for scalar non-chunked `EmbeddingTable`. Both + // chunked-scalar and multi-vector (list-typed) sources need + // the UNNEST-based non-indexed search path. + let is_multi_vector = embedding_table.is_multi_vector(embedding_column); + if !embedding_table.is_chunked(embedding_column) && !is_multi_vector { return Ok(Arc::new(VectorUDTFGeneration::new( &self.df, tbl, @@ -175,14 +178,29 @@ impl SearchEngine { }); }; - Ok(Arc::new(ChunkedNonIndexVectorGeneration::new( - &table_provider, - tbl, - embed_udf, - model_name, - primary_keys.to_vec(), - embedding_column, - ))) + if is_multi_vector { + let aggregation = embedding_table + .multi_vector_aggregation(embedding_column) + .unwrap_or_default(); + Ok(Arc::new(ChunkedNonIndexVectorGeneration::new_list_multi( + &table_provider, + tbl, + embed_udf, + model_name, + primary_keys.to_vec(), + embedding_column, + aggregation, + ))) + } else { + Ok(Arc::new(ChunkedNonIndexVectorGeneration::new( + &table_provider, + tbl, + embed_udf, + model_name, + primary_keys.to_vec(), + embedding_column, + ))) + } } } diff --git a/crates/runtime/src/view.rs b/crates/runtime/src/view.rs index 230d959415..ea83155cec 100644 --- a/crates/runtime/src/view.rs +++ b/crates/runtime/src/view.rs @@ -138,6 +138,8 @@ pub(crate) async fn prepare_view( primary_keys: emb.row_ids.clone(), chunking: emb.chunking.clone(), vector_size: emb.vector_size, + aggregation: emb.aggregation, + max_elements_per_row: emb.max_elements_per_row, }) }) .collect(), diff --git a/crates/runtime/tests/models/hf.rs b/crates/runtime/tests/models/hf.rs index dcfe43715f..dc020e7aa9 100644 --- a/crates/runtime/tests/models/hf.rs +++ b/crates/runtime/tests/models/hf.rs @@ -84,6 +84,8 @@ mod nsql { row_ids: None, chunking: None, vector_size: None, + aggregation: None, + max_elements_per_row: None, }], description: None, full_text_search: None, diff --git a/crates/runtime/tests/models/openai.rs b/crates/runtime/tests/models/openai.rs index 417abc87a5..d786cc3ae4 100644 --- a/crates/runtime/tests/models/openai.rs +++ b/crates/runtime/tests/models/openai.rs @@ -75,6 +75,8 @@ mod nsql { row_ids: None, chunking: None, vector_size: None, + aggregation: None, + max_elements_per_row: None, }], description: None, full_text_search: None, @@ -179,6 +181,8 @@ mod search { chunking: None, row_ids: Some(vec!["id".to_string()]), vector_size: None, + aggregation: None, + max_elements_per_row: None, }], description: None, full_text_search: None, @@ -236,6 +240,8 @@ mod search { row_ids: None, chunking: None, vector_size: None, + aggregation: None, + max_elements_per_row: None, }], description: None, full_text_search: None, @@ -254,6 +260,8 @@ mod search { trim_whitespace: false, }), vector_size: None, + aggregation: None, + max_elements_per_row: None, }], description: None, full_text_search: None, @@ -607,6 +615,8 @@ async fn openai_test_chat_messages() -> Result<(), anyhow::Error> { row_ids: Some(vec!["i_item_sk".to_string()]), chunking: None, vector_size: None, + aggregation: None, + max_elements_per_row: None, }], description: None, full_text_search: None, diff --git a/crates/runtime/tests/models/s3_vectors.rs b/crates/runtime/tests/models/s3_vectors.rs index 756c931271..e85e7d7651 100644 --- a/crates/runtime/tests/models/s3_vectors.rs +++ b/crates/runtime/tests/models/s3_vectors.rs @@ -825,6 +825,8 @@ pub(crate) mod search { chunking: None, row_ids: Some(vec!["id".to_string()]), vector_size: None, + aggregation: None, + max_elements_per_row: None, }])]; let app = AppBuilder::new("search_app") @@ -1025,6 +1027,8 @@ pub fn get_package_delivery_dataset( chunking: None, row_ids: Some(vec!["event.id".to_string()]), vector_size: None, + aggregation: None, + max_elements_per_row: None, }]), vectors_filterable_col("message.status"), vectors_filterable_col("event.created"), diff --git a/crates/runtime/tests/models/search.rs b/crates/runtime/tests/models/search.rs index 77e364bffc..3d011e32dd 100644 --- a/crates/runtime/tests/models/search.rs +++ b/crates/runtime/tests/models/search.rs @@ -533,6 +533,8 @@ pub(crate) fn item_tpcds_dataset_w_embeddings( row_ids: primary_keys, chunking, vector_size: None, + aggregation: None, + max_elements_per_row: None, }], description: None, full_text_search: None, @@ -574,6 +576,8 @@ pub(crate) fn catalog_page_tpcds_dataset_w_embeddings( row_ids: primary_keys, chunking, vector_size: None, + aggregation: None, + max_elements_per_row: None, }], description: None, full_text_search: None, diff --git a/crates/spicepod/src/component/embeddings.rs b/crates/spicepod/src/component/embeddings.rs index 39e142bc8b..37635621af 100644 --- a/crates/spicepod/src/component/embeddings.rs +++ b/crates/spicepod/src/component/embeddings.rs @@ -223,6 +223,41 @@ impl Display for EmbeddingPrefix { } } +/// Aggregation strategy applied when a multi-vector (list-typed) column +/// is queried. Each list element produces its own embedding; at query +/// time the per-element similarities are combined into a single per-row +/// score using this aggregation. +/// +/// `Max` is the ColBERT-style `MaxSim` default — a row scores as high as +/// its best-matching element. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)] +#[cfg_attr(feature = "schemars", derive(JsonSchema))] +#[serde(rename_all = "lowercase")] +pub enum EmbeddingAggregation { + #[default] + Max, + Mean, + Sum, +} + +impl Display for EmbeddingAggregation { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Max => write!(f, "max"), + Self::Mean => write!(f, "mean"), + Self::Sum => write!(f, "sum"), + } + } +} + +/// Hard cap on multi-vector list elements embedded per row. Beyond this +/// limit, excess elements are dropped with a warning. See +/// `ColumnEmbeddingConfig::max_elements_per_row`. +pub const MULTI_VECTOR_MAX_ELEMENTS_HARD_CAP: usize = 1024; + +/// Default cap if none specified on a multi-vector column. +pub const MULTI_VECTOR_MAX_ELEMENTS_DEFAULT: usize = 32; + #[derive(Debug, Default, Clone, Serialize, Deserialize, PartialEq)] #[cfg_attr(feature = "schemars", derive(JsonSchema))] pub struct EmbeddingChunkConfig { @@ -280,4 +315,16 @@ pub struct ColumnEmbeddingConfig { #[serde(skip_serializing_if = "Option::is_none")] pub vector_size: Option, + + /// Aggregation strategy for multi-vector embeddings. Only meaningful + /// when the underlying column is list-typed (`List` / + /// `LargeList`). Defaults to `max` (ColBERT-style `MaxSim`). + #[serde(default, skip_serializing_if = "Option::is_none")] + pub aggregation: Option, + + /// Maximum number of list elements embedded per row for multi-vector + /// columns. Defaults to `32`; hard-capped at `1024`. Excess elements + /// are dropped with a warning log. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub max_elements_per_row: Option, } diff --git a/crates/spicepod/src/semantic.rs b/crates/spicepod/src/semantic.rs index 2393463549..0961b86ba5 100644 --- a/crates/spicepod/src/semantic.rs +++ b/crates/spicepod/src/semantic.rs @@ -22,7 +22,7 @@ use schemars::JsonSchema; use serde::{Deserialize, Serialize, de::Error}; use serde_json::Value; -use crate::component::embeddings::EmbeddingChunkConfig; +use crate::component::embeddings::{EmbeddingAggregation, EmbeddingChunkConfig}; #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] #[cfg_attr(feature = "schemars", derive(JsonSchema))] @@ -133,6 +133,16 @@ pub struct ColumnLevelEmbeddingConfig { #[serde(skip_serializing_if = "Option::is_none")] pub vector_size: Option, + + /// Aggregation strategy for multi-vector embeddings. Only meaningful + /// when the underlying column is list-typed. Defaults to `max`. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub aggregation: Option, + + /// Maximum number of list elements embedded per row for multi-vector + /// columns. Defaults to `32`; hard-capped at `1024`. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub max_elements_per_row: Option, } impl ColumnLevelEmbeddingConfig { @@ -143,9 +153,23 @@ impl ColumnLevelEmbeddingConfig { chunking: None, row_ids: None, vector_size: None, + aggregation: None, + max_elements_per_row: None, } } + #[must_use] + pub fn with_aggregation(mut self, aggregation: EmbeddingAggregation) -> Self { + self.aggregation = Some(aggregation); + self + } + + #[must_use] + pub fn with_max_elements_per_row(mut self, n: usize) -> Self { + self.max_elements_per_row = Some(n); + self + } + #[must_use] pub fn chunking(mut self, chunking: EmbeddingChunkConfig) -> Self { self.chunking = Some(chunking);