Skip to content

Commit 02e9cc9

Browse files
fix(sqlite): deserialize JSON columns (0xPlaygrounds#1797)
1 parent 97cabc7 commit 02e9cc9

1 file changed

Lines changed: 222 additions & 0 deletions

File tree

crates/rig-sqlite/src/lib.rs

Lines changed: 222 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1928,12 +1928,50 @@ fn sqlite_text_value(
19281928
Ok(serde_json::Value::String(value.to_string()))
19291929
}
19301930

1931+
/// Reports whether a declared SQLite column type names the `JSON` type.
///
/// Only the first whitespace-separated token of `column_type` is inspected, and it
/// is compared ASCII case-insensitively. `"JSON"` and `"json NOT NULL"` therefore
/// match, while `"TEXT"`, `"JSONB"`, `"TEXT JSON"`, and an empty declaration do not.
fn sqlite_column_declares_json(column_type: &str) -> bool {
    column_type
        .split_whitespace()
        .next()
        .is_some_and(|token| token.eq_ignore_ascii_case("JSON"))
}
1937+
1938+
fn sqlite_json_text_value(
1939+
index: usize,
1940+
value_type: Type,
1941+
column: &Column,
1942+
value: &[u8],
1943+
) -> rusqlite::Result<serde_json::Value> {
1944+
let value = std::str::from_utf8(value).map_err(|e| {
1945+
sqlite_column_value_error(
1946+
index,
1947+
value_type,
1948+
column,
1949+
format!("invalid UTF-8 JSON text: {e}"),
1950+
)
1951+
})?;
1952+
1953+
serde_json::from_str(value).map_err(|e| {
1954+
sqlite_column_value_error(index, value_type, column, format!("invalid JSON text: {e}"))
1955+
})
1956+
}
1957+
19311958
fn sqlite_column_value_to_json(
19321959
index: usize,
19331960
column: &Column,
19341961
value: ValueRef<'_>,
19351962
) -> rusqlite::Result<serde_json::Value> {
19361963
let value_type = value.data_type();
1964+
1965+
if sqlite_column_declares_json(column.col_type) {
1966+
return match value {
1967+
ValueRef::Null => Ok(serde_json::Value::Null),
1968+
ValueRef::Text(value) => sqlite_json_text_value(index, value_type, column, value),
1969+
ValueRef::Integer(value) => Ok(serde_json::Value::Number(value.into())),
1970+
ValueRef::Real(value) => sqlite_number_value(index, value_type, column, value),
1971+
ValueRef::Blob(value) => sqlite_json_text_value(index, value_type, column, value),
1972+
};
1973+
}
1974+
19371975
let column_affinity = SqliteColumnAffinity::from_column_type(column.col_type);
19381976

19391977
match (column_affinity, value) {
@@ -2273,6 +2311,66 @@ mod tests {
22732311
]
22742312
}
22752313

2314+
/// Text stored in a column declared `JSON` must come back as the parsed JSON
/// object, not as a JSON string wrapping the raw text.
#[test]
fn json_column_text_decodes_to_json_object() -> anyhow::Result<()> {
    let column = Column::new("metadata", "JSON");
    let value = sqlite_column_value_to_json(
        0,
        &column,
        ValueRef::Text(br#"{"knowledge_doc_id":361,"knowledge_id":1,"user_id":1}"#),
    )?;

    let expected = serde_json::json!({
        "knowledge_doc_id": 361,
        "knowledge_id": 1,
        "user_id": 1
    });
    anyhow::ensure!(
        value == expected,
        "JSON column text should decode to a JSON object, got {value:?}"
    );

    Ok(())
}
2335+
2336+
/// Counterpart to the `JSON` column case: the same JSON-looking bytes in a column
/// declared `TEXT` must be returned unchanged as a JSON string.
#[test]
fn text_column_json_looking_text_stays_string() -> anyhow::Result<()> {
    let column = Column::new("metadata", "TEXT");
    let value = sqlite_column_value_to_json(
        0,
        &column,
        ValueRef::Text(br#"{"knowledge_doc_id":361,"knowledge_id":1,"user_id":1}"#),
    )?;

    // `json!` applied to a string literal yields `Value::String`, i.e. the raw text.
    let expected =
        serde_json::json!(r#"{"knowledge_doc_id":361,"knowledge_id":1,"user_id":1}"#);
    anyhow::ensure!(
        value == expected,
        "TEXT column should preserve JSON-looking text as a string, got {value:?}"
    );

    Ok(())
}
2354+
2355+
/// Text that is not valid JSON in a column declared `JSON` must fail with
/// `rusqlite::Error::FromSqlConversionFailure` carrying column index `0` and
/// `Type::Text`, rather than silently falling back to a string.
#[test]
fn json_column_invalid_text_returns_conversion_error() -> anyhow::Result<()> {
    let column = Column::new("metadata", "JSON");
    let err = match sqlite_column_value_to_json(0, &column, ValueRef::Text(b"not json")) {
        Ok(value) => anyhow::bail!("invalid JSON column text should fail, got {value:?}"),
        Err(err) => err,
    };

    anyhow::ensure!(
        matches!(
            err,
            rusqlite::Error::FromSqlConversionFailure(0, Type::Text, _)
        ),
        "invalid JSON column text should return a conversion error, got {err}"
    );

    Ok(())
}
2373+
22762374
fn filter_error<T: std::fmt::Debug>(
22772375
result: Result<T, FilterError>,
22782376
context: &str,
@@ -3230,6 +3328,50 @@ mod tests {
32303328
Ok(())
32313329
}
32323330

3331+
#[tokio::test]
3332+
async fn live_json_column_structured_metadata_round_trips_in_top_n() -> anyhow::Result<()> {
3333+
let metadata = StructuredMetadata {
3334+
user_id: 1,
3335+
knowledge_id: 1,
3336+
knowledge_doc_id: 361,
3337+
};
3338+
let index = live_structured_json_metadata_test_index(
3339+
"live_json_column_structured_metadata_round_trips_in_top_n",
3340+
vec![structured_json_metadata_row(
3341+
"structured",
3342+
metadata.clone(),
3343+
"metadata document",
3344+
vec![1.0, 0.0],
3345+
)],
3346+
)
3347+
.await?;
3348+
3349+
let req = VectorSearchRequest::<SqliteSearchFilter>::builder()
3350+
.query("needle")
3351+
.samples(1)
3352+
.build();
3353+
let results = index
3354+
.top_n::<StructuredJsonMetadataDocument>(req.clone())
3355+
.await?;
3356+
3357+
let Some((_, id, doc)) = results.first() else {
3358+
anyhow::bail!("expected structured JSON metadata document result");
3359+
};
3360+
anyhow::ensure!(id == "structured", "unexpected id: {id}");
3361+
anyhow::ensure!(
3362+
doc.metadata == metadata,
3363+
"JSON column should deserialize into structured metadata: {doc:?}"
3364+
);
3365+
3366+
let id_results = index.top_n_ids(req).await?;
3367+
anyhow::ensure!(
3368+
id_results.first().is_some_and(|(_, id)| id == "structured"),
3369+
"top_n_ids should still return the structured metadata document id: {id_results:?}"
3370+
);
3371+
3372+
Ok(())
3373+
}
3374+
32333375
#[tokio::test]
32343376
async fn live_text_affinity_metadata_filters_during_candidate_search() -> anyhow::Result<()> {
32353377
let index = live_common_type_test_index(
@@ -4140,6 +4282,22 @@ mod tests {
41404282
Ok(vector_store.index(model))
41414283
}
41424284

4285+
async fn live_structured_json_metadata_test_index(
4286+
name: &str,
4287+
rows: Vec<(StructuredJsonMetadataDocument, OneOrMany<Embedding>)>,
4288+
) -> anyhow::Result<SqliteVectorIndex<TestEmbeddingModel, StructuredJsonMetadataDocument>> {
4289+
register_sqlite_vec_extension();
4290+
4291+
let conn = Connection::open(format!("file:{name}?mode=memory")).await?;
4292+
let model = TestEmbeddingModel;
4293+
let vector_store: SqliteVectorStore<_, StructuredJsonMetadataDocument> =
4294+
SqliteVectorStore::new(conn, &model).await?;
4295+
4296+
vector_store.add_rows(rows).await?;
4297+
4298+
Ok(vector_store.index(model))
4299+
}
4300+
41434301
fn row(
41444302
id: impl Into<String>,
41454303
category: impl Into<String>,
@@ -4207,6 +4365,27 @@ mod tests {
42074365
)
42084366
}
42094367

4368+
fn structured_json_metadata_row(
4369+
id: impl Into<String>,
4370+
metadata: StructuredMetadata,
4371+
title: impl Into<String>,
4372+
vec: Vec<f64>,
4373+
) -> (StructuredJsonMetadataDocument, OneOrMany<Embedding>) {
4374+
let document = StructuredJsonMetadataDocument {
4375+
id: id.into(),
4376+
metadata,
4377+
title: title.into(),
4378+
};
4379+
4380+
(
4381+
document.clone(),
4382+
OneOrMany::one(Embedding {
4383+
document: document.title.clone(),
4384+
vec,
4385+
}),
4386+
)
4387+
}
4388+
42104389
fn reordered_id_row(
42114390
id: impl Into<String>,
42124391
title: impl Into<String>,
@@ -4675,6 +4854,49 @@ mod tests {
46754854
}
46764855
}
46774856

4857+
#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
4858+
struct StructuredMetadata {
4859+
user_id: i64,
4860+
knowledge_id: i64,
4861+
knowledge_doc_id: i64,
4862+
}
4863+
4864+
#[derive(Clone, Debug, Deserialize, Serialize)]
4865+
struct StructuredJsonMetadataDocument {
4866+
id: String,
4867+
metadata: StructuredMetadata,
4868+
title: String,
4869+
}
4870+
4871+
impl SqliteVectorStoreTable for StructuredJsonMetadataDocument {
4872+
fn name() -> &'static str {
4873+
"live_structured_json_metadata_test_documents"
4874+
}
4875+
4876+
fn schema() -> Vec<Column> {
4877+
vec![
4878+
Column::new("id", "TEXT PRIMARY KEY"),
4879+
Column::new("metadata", "JSON"),
4880+
Column::new("title", "TEXT"),
4881+
]
4882+
}
4883+
4884+
fn id(&self) -> String {
4885+
self.id.clone()
4886+
}
4887+
4888+
fn column_values(&self) -> Vec<(&'static str, Box<dyn ColumnValue>)> {
4889+
vec![
4890+
("id", Box::new(self.id.clone())),
4891+
(
4892+
"metadata",
4893+
Box::new(serde_json::to_string(&self.metadata).unwrap_or_default()),
4894+
),
4895+
("title", Box::new(self.title.clone())),
4896+
]
4897+
}
4898+
}
4899+
46784900
#[derive(Clone, Debug, Deserialize, Serialize)]
46794901
struct TypedTestDocument {
46804902
id: i64,

0 commit comments

Comments
 (0)