Skip to content

[ENH] Wire up collection forking for RFE #4309

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: sicheng/04-16-_enh_implement_collection_forking_in_sysdb
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions rust/frontend/src/auth/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,9 @@ pub enum AuthzAction {
CreateCollection,
GetOrCreateCollection,
GetCollection,
DeleteCollection,
UpdateCollection,
DeleteCollection,
ForkCollection,
Add,
Delete,
Get,
Expand All @@ -47,8 +48,9 @@ impl Display for AuthzAction {
AuthzAction::CreateCollection => write!(f, "db:create_collection"),
AuthzAction::GetOrCreateCollection => write!(f, "db:get_or_create_collection"),
AuthzAction::GetCollection => write!(f, "collection:get_collection"),
AuthzAction::DeleteCollection => write!(f, "collection:delete_collection"),
AuthzAction::UpdateCollection => write!(f, "collection:update_collection"),
AuthzAction::DeleteCollection => write!(f, "collection:delete_collection"),
AuthzAction::ForkCollection => write!(f, "collection:fork_collection"),
AuthzAction::Add => write!(f, "collection:add"),
AuthzAction::Delete => write!(f, "collection:delete"),
AuthzAction::Get => write!(f, "collection:get"),
Expand Down
74 changes: 43 additions & 31 deletions rust/frontend/src/get_collection_with_segments_provider.rs
Original file line number Diff line number Diff line change
Expand Up @@ -151,43 +151,55 @@ impl CollectionsWithSegmentsProvider {
return Ok(collection_and_segments_with_ttl.collection_and_segments);
}
}
// We acquire a lock to prevent the sysdb from experiencing a thundering herd.
// This can happen when a large number of threads try to get the same collection
// at the same time.
let _guard = self.sysdb_rpc_lock.lock(&collection_id).await;
// Double checked locking pattern to avoid lock contention in the
// happy path when the collection is already cached.
if let Some(collection_and_segments_with_ttl) = self
.collections_with_segments_cache
.get(&collection_id)
.await?
{
if collection_and_segments_with_ttl.expires_at
> SystemTime::now()
.duration_since(UNIX_EPOCH)
.expect("Do not deploy before UNIX epoch")

let collection_and_segments_sysdb = {
// We acquire a lock to prevent the sysdb from experiencing a thundering herd.
// This can happen when a large number of threads try to get the same collection
// at the same time.
let _guard = self.sysdb_rpc_lock.lock(&collection_id).await;
// Double checked locking pattern to avoid lock contention in the
// happy path when the collection is already cached.
if let Some(collection_and_segments_with_ttl) = self
.collections_with_segments_cache
.get(&collection_id)
.await?
{
return Ok(collection_and_segments_with_ttl.collection_and_segments);
if collection_and_segments_with_ttl.expires_at
> SystemTime::now()
.duration_since(UNIX_EPOCH)
.expect("Do not deploy before UNIX epoch")
{
return Ok(collection_and_segments_with_ttl.collection_and_segments);
}
}
}
tracing::info!("Cache miss for collection {}", collection_id);
let collection_and_segments_sysdb = self
.sysdb_client
.get_collection_with_segments(collection_id)
.await?;
let collection_and_segments_sysdb_with_ttl = CollectionAndSegmentsWithTtl {
collection_and_segments: collection_and_segments_sysdb.clone(),
expires_at: SystemTime::now()
.duration_since(UNIX_EPOCH)
.expect("Do not deploy before UNIX epoch")
+ Duration::from_secs(self.cache_ttl_secs as u64), // Cache for 1 minute
tracing::info!("Cache miss for collection {}", collection_id);
self.sysdb_client
.get_collection_with_segments(collection_id)
.await?
};

self.set_collection_with_segments(collection_and_segments_sysdb.clone())
.await;
Ok(collection_and_segments_sysdb)
}

pub(crate) async fn set_collection_with_segments(
&mut self,
collection_and_segments: CollectionAndSegments,
) {
// Insert only if the collection dimension is set.
if collection_and_segments_sysdb.collection.dimension.is_some() {
if collection_and_segments.collection.dimension.is_some() {
let collection_id = collection_and_segments.collection.collection_id;
let collection_and_segments_with_ttl = CollectionAndSegmentsWithTtl {
collection_and_segments,
expires_at: SystemTime::now()
.duration_since(UNIX_EPOCH)
.expect("Do not deploy before UNIX epoch")
+ Duration::from_secs(self.cache_ttl_secs as u64), // Cache for 1 minute
};
self.collections_with_segments_cache
.insert(collection_id, collection_and_segments_sysdb_with_ttl)
.insert(collection_id, collection_and_segments_with_ttl)
.await;
}
Ok(collection_and_segments_sysdb)
}
}
42 changes: 35 additions & 7 deletions rust/frontend/src/impls/service_based_frontend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,14 @@ use chroma_types::{
CreateTenantError, CreateTenantRequest, CreateTenantResponse, DeleteCollectionError,
DeleteCollectionRecordsError, DeleteCollectionRecordsRequest, DeleteCollectionRecordsResponse,
DeleteCollectionRequest, DeleteDatabaseError, DeleteDatabaseRequest, DeleteDatabaseResponse,
GetCollectionError, GetCollectionRequest, GetCollectionResponse, GetCollectionsError,
GetDatabaseError, GetDatabaseRequest, GetDatabaseResponse, GetRequest, GetResponse,
GetTenantError, GetTenantRequest, GetTenantResponse, HealthCheckResponse, HeartbeatError,
HeartbeatResponse, Include, KnnIndex, ListCollectionsRequest, ListCollectionsResponse,
ListDatabasesError, ListDatabasesRequest, ListDatabasesResponse, Operation, OperationRecord,
QueryError, QueryRequest, QueryResponse, ResetError, ResetResponse, Segment, SegmentScope,
SegmentType, SegmentUuid, UpdateCollectionError, UpdateCollectionRecordsError,
ForkCollectionError, ForkCollectionRequest, ForkCollectionResponse, GetCollectionError,
GetCollectionRequest, GetCollectionResponse, GetCollectionsError, GetDatabaseError,
GetDatabaseRequest, GetDatabaseResponse, GetRequest, GetResponse, GetTenantError,
GetTenantRequest, GetTenantResponse, HealthCheckResponse, HeartbeatError, HeartbeatResponse,
Include, KnnIndex, ListCollectionsRequest, ListCollectionsResponse, ListDatabasesError,
ListDatabasesRequest, ListDatabasesResponse, Operation, OperationRecord, QueryError,
QueryRequest, QueryResponse, ResetError, ResetResponse, Segment, SegmentScope, SegmentType,
SegmentUuid, UpdateCollectionError, UpdateCollectionRecordsError,
UpdateCollectionRecordsRequest, UpdateCollectionRecordsResponse, UpdateCollectionRequest,
UpdateCollectionResponse, UpsertCollectionRecordsError, UpsertCollectionRecordsRequest,
UpsertCollectionRecordsResponse, VectorIndexConfiguration, Where,
Expand Down Expand Up @@ -542,6 +543,33 @@ impl ServiceBasedFrontend {
Ok(DeleteCollectionRecordsResponse {})
}

/// Forks a source collection into a brand-new collection.
///
/// A fresh target id is generated here (the client never supplies it), the
/// sysdb performs the actual fork, and the resulting collection + segments
/// are pushed into the local cache before the collection is returned.
pub async fn fork_collection(
    &mut self,
    request: ForkCollectionRequest,
) -> Result<ForkCollectionResponse, ForkCollectionError> {
    let ForkCollectionRequest {
        source_collection_id,
        target_collection_name,
        ..
    } = request;

    // The fork target always gets a newly minted id.
    let new_collection_id = CollectionUuid::new();

    let forked = self
        .sysdb_client
        .fork_collection(
            source_collection_id,
            new_collection_id,
            target_collection_name,
        )
        .await?;

    // Keep a copy of the collection to return; the full bundle is consumed
    // by the cache update below.
    let response = forked.collection.clone();
    self.collections_with_segments_provider
        .set_collection_with_segments(forked)
        .await;

    Ok(response)
}

pub async fn add(
&mut self,
AddCollectionRecordsRequest {
Expand Down
72 changes: 71 additions & 1 deletion rust/frontend/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ use axum::{
Json, Router, ServiceExt,
};
use chroma_system::System;
use chroma_types::RawWhereFields;
use chroma_types::{
AddCollectionRecordsResponse, ChecklistResponse, Collection, CollectionConfiguration,
CollectionMetadataUpdate, CollectionUuid, CountCollectionsRequest, CountCollectionsResponse,
Expand All @@ -20,6 +19,7 @@ use chroma_types::{
UpdateCollectionConfiguration, UpdateCollectionRecordsResponse, UpdateCollectionResponse,
UpdateMetadata, UpsertCollectionRecordsResponse,
};
use chroma_types::{ForkCollectionResponse, RawWhereFields};
use mdac::{Rule, Scorecard, ScorecardTicket};
use opentelemetry::global;
use opentelemetry::metrics::{Counter, Meter};
Expand Down Expand Up @@ -103,6 +103,7 @@ pub struct Metrics {
get_collection: Counter<u64>,
update_collection: Counter<u64>,
delete_collection: Counter<u64>,
fork_collection: Counter<u64>,
collection_add: Counter<u64>,
collection_update: Counter<u64>,
collection_upsert: Counter<u64>,
Expand Down Expand Up @@ -133,6 +134,7 @@ impl Metrics {
get_collection: meter.u64_counter("get_collection").build(),
update_collection: meter.u64_counter("update_collection").build(),
delete_collection: meter.u64_counter("delete_collection").build(),
fork_collection: meter.u64_counter("fork_collection").build(),
collection_add: meter.u64_counter("collection_add").build(),
collection_update: meter.u64_counter("collection_update").build(),
collection_upsert: meter.u64_counter("collection_upsert").build(),
Expand Down Expand Up @@ -242,6 +244,10 @@ impl FrontendServer {
.put(update_collection)
.delete(delete_collection),
)
.route(
"/api/v2/tenants/{tenant}/databases/{database}/collections/{collection_id}/fork",
post(fork_collection),
)
.route(
"/api/v2/tenants/{tenant}/databases/{database}/collections/{collection_id}/add",
post(collection_add),
Expand Down Expand Up @@ -1106,6 +1112,70 @@ async fn delete_collection(
Ok(Json(UpdateCollectionResponse {}))
}

/// Request body for the collection fork endpoint.
#[derive(Deserialize, Serialize, ToSchema, Debug, Clone)]
pub struct ForkCollectionPayload {
    /// Name to give the newly forked collection.
    pub new_name: String,
}

/// Forks an existing collection.
///
/// HTTP handler: authenticates/authorizes the caller, validates the
/// collection id, then delegates to the frontend's `fork_collection`.
#[utoipa::path(
    post,
    path = "/api/v2/tenants/{tenant}/databases/{database}/collections/{collection_id}/fork",
    request_body = ForkCollectionPayload,
    responses(
        (status = 200, description = "Collection forked successfully", body = ForkCollectionResponse),
        (status = 401, description = "Unauthorized", body = ErrorResponse),
        (status = 404, description = "Collection not found", body = ErrorResponse),
        (status = 500, description = "Server error", body = ErrorResponse)
    ),
    params(
        ("tenant" = String, Path, description = "Tenant ID"),
        ("database" = String, Path, description = "Database name"),
        // Fixed copy-paste from the update handler: this is the fork source.
        ("collection_id" = String, Path, description = "UUID of the collection to fork")
    )
)]
async fn fork_collection(
    headers: HeaderMap,
    Path((tenant, database, collection_id)): Path<(String, String, String)>,
    State(mut server): State<FrontendServer>,
    Json(payload): Json<ForkCollectionPayload>,
) -> Result<Json<ForkCollectionResponse>, ServerError> {
    // Count the request up front so auth/validation failures are still metered.
    server.metrics.fork_collection.add(
        1,
        &[
            KeyValue::new("tenant", tenant.clone()),
            KeyValue::new("collection_id", collection_id.clone()),
        ],
    );
    tracing::info!(
        "Forking collection [{collection_id}] in database [{database}] for tenant [{tenant}]"
    );
    server
        .authenticate_and_authorize(
            &headers,
            AuthzAction::ForkCollection,
            AuthzResource {
                tenant: Some(tenant.clone()),
                database: Some(database.clone()),
                collection: Some(collection_id.clone()),
            },
        )
        .await?;
    // Scorecard guard (rate limiting / quota); held for the rest of the request.
    let _guard =
        server.scorecard_request(&["op:fork_collection", format!("tenant:{}", tenant).as_str()]);
    let collection_id =
        CollectionUuid::from_str(&collection_id).map_err(|_| ValidationError::CollectionId)?;

    let request = chroma_types::ForkCollectionRequest::try_new(
        tenant,
        database,
        collection_id,
        payload.new_name,
    )?;

    Ok(Json(server.frontend.fork_collection(request).await?))
}

#[derive(Serialize, Deserialize, ToSchema, Debug, Clone)]
pub struct AddCollectionRecordsPayload {
ids: Vec<String>,
Expand Down
73 changes: 72 additions & 1 deletion rust/sysdb/src/sysdb.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ use chroma_types::{
};
use chroma_types::{
Collection, CollectionConversionError, CollectionUuid, FlushCompactionResponse,
FlushCompactionResponseConversionError, Segment, SegmentConversionError, SegmentScope, Tenant,
FlushCompactionResponseConversionError, ForkCollectionError, Segment, SegmentConversionError,
SegmentScope, Tenant,
};
use std::collections::HashMap;
use std::fmt::Debug;
Expand Down Expand Up @@ -314,6 +315,27 @@ impl SysDb {
}
}

/// Forks `source_collection_id` into a new collection identified by
/// `target_collection_id` and named `target_collection_name`.
///
/// Only the gRPC-backed sysdb supports forking; the local backends
/// (sqlite, test) report `ForkCollectionError::Local`.
pub async fn fork_collection(
    &mut self,
    source_collection_id: CollectionUuid,
    target_collection_id: CollectionUuid,
    target_collection_name: String,
) -> Result<CollectionAndSegments, ForkCollectionError> {
    match self {
        SysDb::Grpc(grpc_sys_db) => {
            grpc_sys_db
                .fork_collection(
                    source_collection_id,
                    target_collection_id,
                    target_collection_name,
                )
                .await
        }
        // Forking is not implemented for non-distributed deployments.
        SysDb::Sqlite(_) | SysDb::Test(_) => Err(ForkCollectionError::Local),
    }
}

pub async fn get_collections_to_gc(
&mut self,
cutoff_time: Option<SystemTime>,
Expand Down Expand Up @@ -896,6 +918,55 @@ impl GrpcSysDb {
Ok(())
}

/// Issues the `ForkCollection` RPC and unpacks the response into a
/// `CollectionAndSegments`.
///
/// gRPC status codes are mapped onto domain errors (`AlreadyExists`,
/// `NotFound`, otherwise `Internal`). The response must carry exactly one
/// segment per scope (metadata, record, vector); duplicates or missing
/// scopes are rejected.
pub async fn fork_collection(
    &mut self,
    source_collection_id: CollectionUuid,
    target_collection_id: CollectionUuid,
    target_collection_name: String,
) -> Result<CollectionAndSegments, ForkCollectionError> {
    let request = chroma_proto::ForkCollectionRequest {
        source_collection_id: source_collection_id.0.to_string(),
        target_collection_id: target_collection_id.0.to_string(),
        target_collection_name: target_collection_name.clone(),
    };
    let response = self
        .client
        .fork_collection(request)
        .await
        .map_err(|err| match err.code() {
            Code::AlreadyExists => ForkCollectionError::AlreadyExists(target_collection_name),
            Code::NotFound => ForkCollectionError::NotFound(source_collection_id.0.to_string()),
            _ => ForkCollectionError::Internal(err.into()),
        })?
        .into_inner();

    // Index segments by scope. If the map shrinks relative to the raw list,
    // the server sent two segments sharing a scope, which is invalid.
    let total_segments = response.segments.len();
    let mut by_scope: HashMap<_, _> = response
        .segments
        .into_iter()
        .map(|segment| (segment.scope(), segment))
        .collect();
    if by_scope.len() != total_segments {
        return Err(ForkCollectionError::DuplicateSegment);
    }

    Ok(CollectionAndSegments {
        collection: response
            .collection
            .ok_or_else(|| ForkCollectionError::Field("collection".to_string()))?
            .try_into()?,
        metadata_segment: by_scope
            .remove(&chroma_proto::SegmentScope::Metadata)
            .ok_or_else(|| ForkCollectionError::Field("metadata".to_string()))?
            .try_into()?,
        record_segment: by_scope
            .remove(&chroma_proto::SegmentScope::Record)
            .ok_or_else(|| ForkCollectionError::Field("record".to_string()))?
            .try_into()?,
        vector_segment: by_scope
            .remove(&chroma_proto::SegmentScope::Vector)
            .ok_or_else(|| ForkCollectionError::Field("vector".to_string()))?
            .try_into()?,
    })
}

pub async fn get_collections_to_gc(
&mut self,
cutoff_time: Option<SystemTime>,
Expand Down
Loading
Loading