Commit b9e71a9

lukekim and krinart authored
feat(embeddings): multi-vector embeddings with MaxSim + late-interaction (spiceai#10408)
* feat(embeddings): multi-vector embeddings with MaxSim + late-interaction

  Extends column-level embeddings to accept list-of-string columns and produces one embedding vector per list element, stored as List<FixedSizeList<F32, D>> per row. Enables tag/synonym/attribute-style columns to participate in vector search without users having to pre-flatten into a separate table.

  Search adds two new scoring modes alongside the existing chunked-scalar path:
  - Single-query × multi-vector (MaxSim / Mean / Sum): per-row score is max/mean/sum over the list element cosines. Default is MaxSim (ColBERT-style). _match returns the element that produced the top cosine.
  - Multi-query × multi-vector (late-interaction): SUM_{q in Q} MAX_{d in D} cos(q, d). Opt in by passing an array query, e.g. vector_search(tbl, ['foo','bar'], col).

  Config surface (both ColumnLevelEmbeddingConfig and the legacy ColumnEmbeddingConfig) gains two new fields:
  - aggregation: max|mean|sum (default max); rejected on scalar columns.
  - max_elements_per_row: default 32, hard cap 1024; excess elements are dropped with a tracing::warn.

  Mode is auto-detected from the column Arrow type: Utf8/Utf8View/LargeUtf8 → Scalar, List<Utf8>/LargeList<Utf8> → ListMulti. Chunking is rejected on list columns; multi-vector options on scalar columns error with a clear message.

  Implementation:
  - EmbeddingInputMode { Scalar, ListMulti } threaded through EmbeddingTable; resolve_input_mode enforces all validation rules.
  - decompose_list_of_strings handles ListArray, LargeListArray and all three Utf8 element types; respects max_elements_per_row truncation.
  - get_vectors_per_list_element (async + sync) embeds via one model call, respecting null-row (→ empty output list) and null/empty-element (→ null vector) semantics.
  - base_table_has_embedding_column relaxes the offset-column requirement when the source column is list-typed; multi-vector uses element index as the implicit offset.
  - ChunkedNonIndexVectorGeneration grows a VectorScanMode with three variants; search() dispatches to search_chunked_scalar, search_list_multi, or search_late_interaction. Aggregation uses Expr::AggregateFunction.partition_by windowing. Late-interaction unions per-query subplans tagged with q_idx and does a two-step aggregate (pk, q_idx → MAX; pk → SUM).
  - VectorSearchTableFuncArgs gains a queries: Vec<String> alongside the existing query: String. parse_query_arg accepts either a Utf8 literal or a make_array(...) expression; to_expr round-trips both forms. Dispatcher errors when multi-query is paired with a scalar column.
  - Telemetry track_vector_search gets multi_vector and multi_vector_aggregation KeyValue dims when applicable.

  Accelerator compatibility: multi-vector output Arrow shape is identical to the chunked-scalar path's output, so Arrow, Cayenne, and DuckDB round-trip transparently. Turso serializes nested lists to JSON TEXT today (turso.rs:581-583); SQLite inherits via datafusion-table-providers. The compat matrix is documented at the head of EmbeddingInputMode. A native typed side-table for SQLite/Turso is a future optimization that would benefit chunked-scalar equally.

  Tests: 41 new unit tests across embeddings::table (25), embeddings::execution_plan (12), and embeddings::udtf::parser_tests (4) cover type detection, input-mode resolution, list decomposition with null/empty/truncation edge cases, multi-vector list-array construction, end-to-end per-element embedding via a mock embedder, and the make_array query parser.

* Fix + Lint

* Lint

* fix(embeddings): address review comments on multi-vector PR
  - build_multi_vector_list_array validates each embedding's length against vector_length before appending, returning a structured error instead of letting the FixedSizeListBuilder panic on mismatch.
  - decompose_generic_list hoists value_offsets() outside the loop and resolves the string-array variant once via a three-way downcast, removing per-element dynamic dispatch. Introduces a generic build_rows helper parameterised by a closure.
  - parse_query_arg rejects make_array(...) with more than VECTOR_SEARCH_MAX_QUERIES (32) elements to prevent late-interaction plans from blowing up on unbounded input.
  - to_expr derives the single-query literal from args.queries.first() so the single- and multi-query branches stay consistent.

* fix(embeddings): round 2 of review fixes: format, missing-column error, grammar
  - cargo fmt fixup on decompose_generic_list after the string-variant refactor.
  - try_new now returns Error::EmbeddingColumnNotInSchema when a configured embedding source column is missing from the base schema, instead of silently dropping the column. Misconfiguration fails fast during table construction.
  - Grammar: "Cannot use it create an embeddings" → "Cannot use it to create embeddings" in the base_table_has_embedding_column warning.

* Lint

* feat: Implement JSON schema decomposition for HTTP connector

* fix(vector): unify aggregation handling for ChunkedScalar and LateInteraction modes

* Lint

* fix(tests): improve variable naming for clarity in embedding tests

* fix(clippy): replace unwrap with expect and Arc::clone in embedding tests

---------

Co-authored-by: Viktor Yershov <[email protected]>
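
The two scoring modes described in the commit message can be illustrated with a small standalone sketch. This is not the code from the commit, which builds DataFusion plans; the names `Aggregation`, `cosine`, `score_row`, and `late_interaction` are hypothetical, and treating a zero-norm vector as cosine 0.0 and an empty row as `None` are assumptions made for this example only.

```rust
/// Aggregation over per-element cosines; variants mirror the commit's max|mean|sum.
/// (Hypothetical type for illustration, not the enum used in the commit.)
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum Aggregation {
    Max,
    Mean,
    Sum,
}

/// Cosine similarity of two vectors. Assumption: a zero-norm input yields 0.0.
pub fn cosine(a: &[f32], b: &[f32]) -> f32 {
    let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
    let norm_a = a.iter().map(|x| x * x).sum::<f32>().sqrt();
    let norm_b = b.iter().map(|x| x * x).sum::<f32>().sqrt();
    if norm_a == 0.0 || norm_b == 0.0 {
        0.0
    } else {
        dot / (norm_a * norm_b)
    }
}

/// Single-query x multi-vector: returns (row score, index of the element with
/// the top cosine). Returns None for a row with no element vectors.
pub fn score_row(
    query: &[f32],
    elements: &[Vec<f32>],
    agg: Aggregation,
) -> Option<(f32, usize)> {
    if elements.is_empty() {
        return None;
    }
    let cosines: Vec<f32> = elements.iter().map(|e| cosine(query, e)).collect();
    // Track the best element regardless of aggregation, like the _match column.
    let (best_idx, best) = cosines
        .iter()
        .copied()
        .enumerate()
        .fold((0usize, f32::NEG_INFINITY), |acc, (i, c)| {
            if c > acc.1 { (i, c) } else { acc }
        });
    let score = match agg {
        Aggregation::Max => best,
        Aggregation::Sum => cosines.iter().sum::<f32>(),
        Aggregation::Mean => cosines.iter().sum::<f32>() / cosines.len() as f32,
    };
    Some((score, best_idx))
}

/// Multi-query x multi-vector: SUM over queries of MAX over elements of cos(q, d).
pub fn late_interaction(queries: &[Vec<f32>], elements: &[Vec<f32>]) -> Option<f32> {
    if elements.is_empty() {
        return None;
    }
    let total = queries
        .iter()
        .map(|q| {
            elements
                .iter()
                .map(|d| cosine(q, d))
                .fold(f32::NEG_INFINITY, f32::max)
        })
        .sum::<f32>();
    Some(total)
}
```

With a query `[1, 0]` and element vectors `[1, 0]` and `[0, 1]`, the per-element cosines are 1 and 0, so Max and Sum both score the row 1.0 and Mean scores it 0.5. The late-interaction score grows with the number of queries that find a good element, which is why the commit bounds the query count.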
1 parent 5374644 commit b9e71a9

16 files changed

Lines changed: 1970 additions & 106 deletions

crates/runtime/src/cluster/datafusion/codec/spice_logical_codec.rs

Lines changed: 2 additions & 0 deletions
@@ -117,6 +117,7 @@ impl SpiceLogicalCodec {
 );
 let exprs = VectorSearchTableFunc::to_expr(&VectorSearchTableFuncArgs {
     tbl: SqlTableReference::parse_str(&vector_args.table),
+    queries: vec![vector_args.query.clone()],
     query: vector_args.query,
     column: vector_args.column,
     limit: vector_args.limit.map(Self::limit_from_u64).transpose()?,
@@ -164,6 +165,7 @@ impl SpiceLogicalCodec {
 };
 let vector_exprs = VectorSearchTableFunc::to_expr(&VectorSearchTableFuncArgs {
     tbl: SqlTableReference::parse_str(&args.table),
+    queries: vec![args.query.clone()],
     query: args.query.clone(),
     column: args.column.clone(),
     limit: args.limit.map(Self::limit_from_u64).transpose()?,
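
Both hunks populate the new `queries` field from the existing single `query`, so codec-decoded plans stay single-query. A minimal sketch of that invariant follows, using a hypothetical `SearchArgs` struct rather than the real `VectorSearchTableFuncArgs` (whose other fields are omitted here). The 32-query limit comes from the commit message; rejecting an empty query list is an assumption of this sketch.

```rust
/// Upper bound on array-query length, as stated in the commit message.
pub const VECTOR_SEARCH_MAX_QUERIES: usize = 32;

/// Hypothetical, simplified stand-in for the search arguments.
#[derive(Debug, Clone, PartialEq)]
pub struct SearchArgs {
    /// Legacy single-query field; kept equal to `queries[0]`.
    pub query: String,
    /// All query strings; length 1 for the single-query form.
    pub queries: Vec<String>,
}

impl SearchArgs {
    /// Single-query form: `queries` is a one-element copy of `query`,
    /// matching what the codec hunks above do.
    pub fn single(query: impl Into<String>) -> Self {
        let query = query.into();
        Self {
            queries: vec![query.clone()],
            query,
        }
    }

    /// Multi-query form: bounded by VECTOR_SEARCH_MAX_QUERIES, with the
    /// single-query field derived from the first element.
    pub fn multi(queries: Vec<String>) -> Result<Self, String> {
        if queries.len() > VECTOR_SEARCH_MAX_QUERIES {
            return Err(format!(
                "too many queries: {} (max {})",
                queries.len(),
                VECTOR_SEARCH_MAX_QUERIES
            ));
        }
        // Assumption: an empty query array is rejected in this sketch.
        let query = queries
            .first()
            .cloned()
            .ok_or_else(|| "at least one query is required".to_string())?;
        Ok(Self { query, queries })
    }

    /// True when late-interaction scoring would apply.
    pub fn is_multi_query(&self) -> bool {
        self.queries.len() > 1
    }
}
```

Deriving `query` from `queries.first()` in one place keeps the two fields from drifting apart, which is the consistency concern the review-fix bullet about `to_expr` describes.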

crates/runtime/src/embeddings/connector.rs

Lines changed: 2 additions & 0 deletions
@@ -122,6 +122,8 @@ impl EmbeddingConnector {
 chunking: e.chunking.clone(),
 primary_keys: e.row_ids.clone(),
 vector_size: e.vector_size,
+aggregation: e.aggregation,
+max_elements_per_row: e.max_elements_per_row,
 })
 })
 .collect_vec();
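
The `max_elements_per_row` value threaded through here controls how many list elements are embedded per row. The sketch below shows one way the default of 32 and the hard cap of 1024 from the commit message could interact with truncation. The function names are hypothetical, and clamping an oversized setting to the cap (rather than rejecting it) is an assumption; the commit message does not say which the real code does.

```rust
/// Default and hard cap, as stated in the commit message.
pub const DEFAULT_MAX_ELEMENTS_PER_ROW: usize = 32;
pub const HARD_CAP_MAX_ELEMENTS_PER_ROW: usize = 1024;

/// Resolve the effective per-row limit from an optional configured value.
/// Assumption: values above the hard cap are clamped, not rejected.
pub fn effective_limit(configured: Option<usize>) -> usize {
    configured
        .unwrap_or(DEFAULT_MAX_ELEMENTS_PER_ROW)
        .min(HARD_CAP_MAX_ELEMENTS_PER_ROW)
}

/// Keep at most `limit` leading elements of a row and report how many were
/// dropped, so the caller can emit a warning when the count is non-zero.
pub fn truncate_row<T>(elements: &[T], limit: usize) -> (&[T], usize) {
    let kept = elements.len().min(limit);
    (&elements[..kept], elements.len() - kept)
}
```

Returning the dropped count separately leaves the logging decision to the caller, which fits the commit's description of dropping excess elements with a `tracing::warn`.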
