Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 18 additions & 3 deletions crates/vector-store/src/db_index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,12 @@ use crate::DbEmbedding;
use crate::IndexMetadata;
use crate::KeyspaceName;
use crate::Percentage;
use crate::PrimaryKey;
use crate::Progress;
use crate::TableName;
use crate::internals::Internals;
use crate::internals::InternalsExt;
use crate::invariant_key::InvariantKey;
use crate::node_state::Event;
use crate::node_state::NodeState;
use crate::node_state::NodeStateExt;
Expand Down Expand Up @@ -451,6 +453,15 @@ impl Statements {
.collect_vec(),
);

anyhow::ensure!(
primary_key_columns.len() <= InvariantKey::MAX_COLUMNS,
"table {}.{} has {} primary key columns, but at most {} are supported",
metadata.keyspace_name,
metadata.table_name,
primary_key_columns.len(),
InvariantKey::MAX_COLUMNS,
);

let table_columns = Arc::new(
table
.columns
Expand Down Expand Up @@ -752,7 +763,11 @@ impl Statements {
else {
return None;
};
let primary_key = primary_key.into();
let primary_key = PrimaryKey::from(
InvariantKey::try_new(primary_key)
.inspect_err(|err| debug!("range_scan_stream: {err}"))
.ok()?,
);

Some(DbEmbedding {
primary_key,
Expand Down Expand Up @@ -825,8 +840,8 @@ impl Consumer for CdcConsumer {
"CDC error: primary key column {column} value should exist"
))
})
.collect::<anyhow::Result<Vec<_>>>()?
.into();
.collect::<anyhow::Result<Vec<_>>>()?;
let primary_key = PrimaryKey::from(InvariantKey::try_new(primary_key)?);

const HUNDREDS_NANOS_TO_MICROS: u64 = 10;
let timestamp = (self.0.gregorian_epoch
Expand Down
10 changes: 7 additions & 3 deletions crates/vector-store/src/httproutes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -647,16 +647,20 @@ async fn post_index_ann(
let primary_keys: anyhow::Result<_> = primary_keys
.iter()
.map(|primary_key| {
if primary_key.0.len() != primary_key_columns.len() {
if primary_key.len() != primary_key_columns.len() {
bail!(
"wrong size of a primary key: {}, {}",
primary_key_columns.len(),
primary_key.0.len()
primary_key.len()
);
}
Ok(primary_key)
})
.map_ok(|primary_key| primary_key.0[idx_column].clone())
.map_ok(|primary_key| {
primary_key
.get(idx_column)
.expect("primary key index out of bounds after length check")
})
.map_ok(try_to_json)
.map(|primary_key| primary_key.flatten())
.collect();
Expand Down
83 changes: 50 additions & 33 deletions crates/vector-store/src/index/usearch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -944,15 +944,15 @@ fn cql_cmp(lhs: &CqlValue, rhs: &CqlValue) -> Option<std::cmp::Ordering> {

/// Lexicographically compare tuple values.
/// Returns the ordering of the first non-equal pair, or Equal if all pairs are equal.
fn cql_cmp_tuple<'a>(
primary_key: &'a PrimaryKey,
primary_key_value: impl Fn(&'a PrimaryKey, &ColumnName) -> Option<&'a CqlValue>,
fn cql_cmp_tuple(
primary_key: &PrimaryKey,
primary_key_value: impl Fn(&PrimaryKey, &ColumnName) -> Option<CqlValue>,
lhs: &[ColumnName],
rhs: &[CqlValue],
) -> Option<std::cmp::Ordering> {
for (col, rhs_val) in lhs.iter().zip(rhs.iter()) {
let lhs_val = primary_key_value(primary_key, col)?;
match cql_cmp(lhs_val, rhs_val)? {
match cql_cmp(&lhs_val, rhs_val)? {
std::cmp::Ordering::Equal => continue,
other => return Some(other),
}
Expand All @@ -971,17 +971,17 @@ fn filtered_ann(
) {
fn annotate<F>(f: F) -> F
where
F: for<'a, 'b> Fn(&'a PrimaryKey, &'b ColumnName) -> Option<&'a CqlValue>,
F: Fn(&PrimaryKey, &ColumnName) -> Option<CqlValue>,
{
f
}

let primary_key_value = annotate(
|primary_key: &PrimaryKey, name: &ColumnName| -> Option<&CqlValue> {
|primary_key: &PrimaryKey, name: &ColumnName| -> Option<CqlValue> {
primary_key_columns
.iter()
.position(|key_column| key_column == name)
.and_then(move |idx| primary_key.0.get(idx))
.and_then(move |idx| primary_key.get(idx))
},
);

Expand All @@ -993,27 +993,29 @@ fn filtered_ann(
.restrictions
.iter()
.all(|restriction| match restriction {
Restriction::Eq { lhs, rhs } => primary_key_value(&primary_key, lhs) == Some(rhs),
Restriction::Eq { lhs, rhs } => {
primary_key_value(&primary_key, lhs).as_ref() == Some(rhs)
}
Restriction::In { lhs, rhs } => {
let value = primary_key_value(&primary_key, lhs);
rhs.iter().any(|rhs| value == Some(rhs))
rhs.iter().any(|rhs| value.as_ref() == Some(rhs))
}
Restriction::Lt { lhs, rhs } => primary_key_value(&primary_key, lhs)
.and_then(|value| cql_cmp(value, rhs))
.and_then(|value| cql_cmp(&value, rhs))
.is_some_and(|ord| ord.is_lt()),
Restriction::Lte { lhs, rhs } => primary_key_value(&primary_key, lhs)
.and_then(|value| cql_cmp(value, rhs))
.and_then(|value| cql_cmp(&value, rhs))
.is_some_and(|ord| ord.is_le()),
Restriction::Gt { lhs, rhs } => primary_key_value(&primary_key, lhs)
.and_then(|value| cql_cmp(value, rhs))
.and_then(|value| cql_cmp(&value, rhs))
.is_some_and(|ord| ord.is_gt()),
Restriction::Gte { lhs, rhs } => primary_key_value(&primary_key, lhs)
.and_then(|value| cql_cmp(value, rhs))
.and_then(|value| cql_cmp(&value, rhs))
.is_some_and(|ord| ord.is_ge()),
Restriction::EqTuple { lhs, rhs } => lhs
.iter()
.zip(rhs.iter())
.all(|(lhs, rhs)| primary_key_value(&primary_key, lhs) == Some(rhs)),
.all(|(lhs, rhs)| primary_key_value(&primary_key, lhs).as_ref() == Some(rhs)),
Restriction::InTuple { lhs, rhs } => {
let values: Vec<_> = lhs
.iter()
Expand All @@ -1023,7 +1025,7 @@ fn filtered_ann(
values
.iter()
.zip(rhs.iter())
.all(|(value, rhs)| value == &Some(rhs))
.all(|(value, rhs)| value.as_ref() == Some(rhs))
})
}
Restriction::LtTuple { lhs, rhs } => {
Expand Down Expand Up @@ -1132,6 +1134,7 @@ mod tests {
use crate::ExpansionSearch;
use crate::IndexId;
use crate::index::IndexExt;
use crate::invariant_key::InvariantKey;
use crate::memory;
use scylla::value::CqlValue;
use std::num::NonZeroUsize;
Expand All @@ -1156,7 +1159,7 @@ mod tests {
let id = worker * adds_per_worker + offset;
actor
.add(
vec![CqlValue::Int(id as i32)].into(),
InvariantKey::new(vec![CqlValue::Int(id as i32)]).into(),
vec![0.0f32; dimensions.get()].into(),
None,
)
Expand Down Expand Up @@ -1215,21 +1218,22 @@ mod tests {

actor
.add(
vec![CqlValue::Int(1), CqlValue::Text("one".to_string())].into(),
InvariantKey::new(vec![CqlValue::Int(1), CqlValue::Text("one".to_string())]).into(),
vec![1., 1., 1.].into(),
None,
)
.await;
actor
.add(
vec![CqlValue::Int(2), CqlValue::Text("two".to_string())].into(),
InvariantKey::new(vec![CqlValue::Int(2), CqlValue::Text("two".to_string())]).into(),
vec![2., -2., 2.].into(),
None,
)
.await;
actor
.add(
vec![CqlValue::Int(3), CqlValue::Text("three".to_string())].into(),
InvariantKey::new(vec![CqlValue::Int(3), CqlValue::Text("three".to_string())])
.into(),
vec![3., 3., 3.].into(),
None,
)
Expand All @@ -1254,18 +1258,20 @@ mod tests {
assert_eq!(distances.len(), 1);
assert_eq!(
primary_keys.first().unwrap(),
&vec![CqlValue::Int(2), CqlValue::Text("two".to_string())].into(),
&InvariantKey::new(vec![CqlValue::Int(2), CqlValue::Text("two".to_string())]).into(),
);

actor
.remove(
vec![CqlValue::Int(3), CqlValue::Text("three".to_string())].into(),
InvariantKey::new(vec![CqlValue::Int(3), CqlValue::Text("three".to_string())])
.into(),
None,
)
.await;
actor
.add(
vec![CqlValue::Int(3), CqlValue::Text("three".to_string())].into(),
InvariantKey::new(vec![CqlValue::Int(3), CqlValue::Text("three".to_string())])
.into(),
vec![2.1, -2.1, 2.1].into(),
None,
)
Expand All @@ -1282,7 +1288,8 @@ mod tests {
.0
.first()
.unwrap()
!= &vec![CqlValue::Int(3), CqlValue::Text("three".to_string())].into()
!= &InvariantKey::new(vec![CqlValue::Int(3), CqlValue::Text("three".to_string())])
.into()
{
task::yield_now().await;
}
Expand All @@ -1292,7 +1299,8 @@ mod tests {

actor
.remove(
vec![CqlValue::Int(3), CqlValue::Text("three".to_string())].into(),
InvariantKey::new(vec![CqlValue::Int(3), CqlValue::Text("three".to_string())])
.into(),
None,
)
.await;
Expand All @@ -1316,7 +1324,7 @@ mod tests {
assert_eq!(distances.len(), 1);
assert_eq!(
primary_keys.first().unwrap(),
&vec![CqlValue::Int(2), CqlValue::Text("two".to_string())].into(),
&InvariantKey::new(vec![CqlValue::Int(2), CqlValue::Text("two".to_string())]).into(),
);
}

Expand Down Expand Up @@ -1351,7 +1359,11 @@ mod tests {
memory_rx
});
actor
.add(vec![CqlValue::Int(1)].into(), vec![1., 1., 1.].into(), None)
.add(
InvariantKey::new(vec![CqlValue::Int(1)]).into(),
vec![1., 1., 1.].into(),
None,
)
.await;
let mut memory_rx = memory_respond.await.unwrap();
assert_eq!(actor.count().await.unwrap(), 0);
Expand All @@ -1361,7 +1373,11 @@ mod tests {
_ = tx.send(Allocate::Can);
});
actor
.add(vec![CqlValue::Int(1)].into(), vec![1., 1., 1.].into(), None)
.add(
InvariantKey::new(vec![CqlValue::Int(1)]).into(),
vec![1., 1., 1.].into(),
None,
)
.await;
memory_respond.await.unwrap();

Expand Down Expand Up @@ -1506,21 +1522,22 @@ mod tests {

mod cql_cmp_tuple_tests {
use super::super::{ColumnName, PrimaryKey, cql_cmp_tuple};
use crate::invariant_key::InvariantKey;
use scylla::value::CqlValue;
use std::cmp::Ordering;

fn make_primary_key(values: Vec<CqlValue>) -> PrimaryKey {
values.into()
InvariantKey::new(values).into()
}

fn primary_key_value_fn<'a>(
columns: &'a [ColumnName],
) -> impl Fn(&'a PrimaryKey, &ColumnName) -> Option<&'a CqlValue> {
move |pk: &'a PrimaryKey, name: &ColumnName| {
fn primary_key_value_fn(
columns: &[ColumnName],
) -> impl Fn(&PrimaryKey, &ColumnName) -> Option<CqlValue> + use<'_> {
move |pk: &PrimaryKey, name: &ColumnName| {
columns
.iter()
.position(|col| col == name)
.and_then(|idx| pk.0.get(idx))
.and_then(|idx| pk.get(idx))
}
}

Expand Down
Loading