Skip to content

[ENH] Wire up collection forking for RFE #4309

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions rust/frontend/src/auth/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,9 @@ pub enum AuthzAction {
CreateCollection,
GetOrCreateCollection,
GetCollection,
DeleteCollection,
UpdateCollection,
DeleteCollection,
ForkCollection,
Add,
Delete,
Get,
Expand All @@ -47,8 +48,9 @@ impl Display for AuthzAction {
AuthzAction::CreateCollection => write!(f, "db:create_collection"),
AuthzAction::GetOrCreateCollection => write!(f, "db:get_or_create_collection"),
AuthzAction::GetCollection => write!(f, "collection:get_collection"),
AuthzAction::DeleteCollection => write!(f, "collection:delete_collection"),
AuthzAction::UpdateCollection => write!(f, "collection:update_collection"),
AuthzAction::DeleteCollection => write!(f, "collection:delete_collection"),
AuthzAction::ForkCollection => write!(f, "collection:fork_collection"),
AuthzAction::Add => write!(f, "collection:add"),
AuthzAction::Delete => write!(f, "collection:delete"),
AuthzAction::Get => write!(f, "collection:get"),
Expand Down
74 changes: 43 additions & 31 deletions rust/frontend/src/get_collection_with_segments_provider.rs
Original file line number Diff line number Diff line change
Expand Up @@ -151,43 +151,55 @@ impl CollectionsWithSegmentsProvider {
return Ok(collection_and_segments_with_ttl.collection_and_segments);
}
}
// We acquire a lock to prevent the sysdb from experiencing a thundering herd.
// This can happen when a large number of threads try to get the same collection
// at the same time.
let _guard = self.sysdb_rpc_lock.lock(&collection_id).await;
// Double checked locking pattern to avoid lock contention in the
// happy path when the collection is already cached.
if let Some(collection_and_segments_with_ttl) = self
.collections_with_segments_cache
.get(&collection_id)
.await?
{
if collection_and_segments_with_ttl.expires_at
> SystemTime::now()
.duration_since(UNIX_EPOCH)
.expect("Do not deploy before UNIX epoch")

let collection_and_segments_sysdb = {
// We acquire a lock to prevent the sysdb from experiencing a thundering herd.
// This can happen when a large number of threads try to get the same collection
// at the same time.
let _guard = self.sysdb_rpc_lock.lock(&collection_id).await;
// Double checked locking pattern to avoid lock contention in the
// happy path when the collection is already cached.
if let Some(collection_and_segments_with_ttl) = self
.collections_with_segments_cache
.get(&collection_id)
.await?
{
return Ok(collection_and_segments_with_ttl.collection_and_segments);
if collection_and_segments_with_ttl.expires_at
> SystemTime::now()
.duration_since(UNIX_EPOCH)
.expect("Do not deploy before UNIX epoch")
{
return Ok(collection_and_segments_with_ttl.collection_and_segments);
}
}
}
tracing::info!("Cache miss for collection {}", collection_id);
let collection_and_segments_sysdb = self
.sysdb_client
.get_collection_with_segments(collection_id)
.await?;
let collection_and_segments_sysdb_with_ttl = CollectionAndSegmentsWithTtl {
collection_and_segments: collection_and_segments_sysdb.clone(),
expires_at: SystemTime::now()
.duration_since(UNIX_EPOCH)
.expect("Do not deploy before UNIX epoch")
+ Duration::from_secs(self.cache_ttl_secs as u64), // Cache for 1 minute
tracing::info!("Cache miss for collection {}", collection_id);
self.sysdb_client
.get_collection_with_segments(collection_id)
.await?
};

self.set_collection_with_segments(collection_and_segments_sysdb.clone())
.await;
Ok(collection_and_segments_sysdb)
}

pub(crate) async fn set_collection_with_segments(
&mut self,
collection_and_segments: CollectionAndSegments,
) {
// Insert only if the collection dimension is set.
if collection_and_segments_sysdb.collection.dimension.is_some() {
if collection_and_segments.collection.dimension.is_some() {
let collection_id = collection_and_segments.collection.collection_id;
let collection_and_segments_with_ttl = CollectionAndSegmentsWithTtl {
collection_and_segments,
expires_at: SystemTime::now()
.duration_since(UNIX_EPOCH)
.expect("Do not deploy before UNIX epoch")
+ Duration::from_secs(self.cache_ttl_secs as u64), // Cache for 1 minute
};
self.collections_with_segments_cache
.insert(collection_id, collection_and_segments_sysdb_with_ttl)
.insert(collection_id, collection_and_segments_with_ttl)
.await;
}
Ok(collection_and_segments_sysdb)
}
}
84 changes: 77 additions & 7 deletions rust/frontend/src/impls/service_based_frontend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,14 @@ use chroma_types::{
CreateTenantError, CreateTenantRequest, CreateTenantResponse, DeleteCollectionError,
DeleteCollectionRecordsError, DeleteCollectionRecordsRequest, DeleteCollectionRecordsResponse,
DeleteCollectionRequest, DeleteDatabaseError, DeleteDatabaseRequest, DeleteDatabaseResponse,
GetCollectionError, GetCollectionRequest, GetCollectionResponse, GetCollectionsError,
GetDatabaseError, GetDatabaseRequest, GetDatabaseResponse, GetRequest, GetResponse,
GetTenantError, GetTenantRequest, GetTenantResponse, HealthCheckResponse, HeartbeatError,
HeartbeatResponse, Include, KnnIndex, ListCollectionsRequest, ListCollectionsResponse,
ListDatabasesError, ListDatabasesRequest, ListDatabasesResponse, Operation, OperationRecord,
QueryError, QueryRequest, QueryResponse, ResetError, ResetResponse, Segment, SegmentScope,
SegmentType, SegmentUuid, UpdateCollectionError, UpdateCollectionRecordsError,
ForkCollectionError, ForkCollectionRequest, ForkCollectionResponse, GetCollectionError,
GetCollectionRequest, GetCollectionResponse, GetCollectionsError, GetDatabaseError,
GetDatabaseRequest, GetDatabaseResponse, GetRequest, GetResponse, GetTenantError,
GetTenantRequest, GetTenantResponse, HealthCheckResponse, HeartbeatError, HeartbeatResponse,
Include, KnnIndex, ListCollectionsRequest, ListCollectionsResponse, ListDatabasesError,
ListDatabasesRequest, ListDatabasesResponse, Operation, OperationRecord, QueryError,
QueryRequest, QueryResponse, ResetError, ResetResponse, Segment, SegmentScope, SegmentType,
SegmentUuid, UpdateCollectionError, UpdateCollectionRecordsError,
UpdateCollectionRecordsRequest, UpdateCollectionRecordsResponse, UpdateCollectionRequest,
UpdateCollectionResponse, UpsertCollectionRecordsError, UpsertCollectionRecordsRequest,
UpsertCollectionRecordsResponse, VectorIndexConfiguration, Where,
Expand All @@ -43,6 +44,7 @@ use super::utils::to_records;

#[derive(Debug)]
struct Metrics {
fork_retries_counter: Counter<u64>,
delete_retries_counter: Counter<u64>,
count_retries_counter: Counter<u64>,
query_retries_counter: Counter<u64>,
Expand Down Expand Up @@ -72,11 +74,13 @@ impl ServiceBasedFrontend {
default_knn_index: KnnIndex,
) -> Self {
let meter = global::meter("chroma");
let fork_retries_counter = meter.u64_counter("fork_retries").build();
let delete_retries_counter = meter.u64_counter("delete_retries").build();
let count_retries_counter = meter.u64_counter("count_retries").build();
let query_retries_counter = meter.u64_counter("query_retries").build();
let get_retries_counter = meter.u64_counter("query_retries").build();
let metrics = Arc::new(Metrics {
fork_retries_counter,
delete_retries_counter,
count_retries_counter,
query_retries_counter,
Expand Down Expand Up @@ -542,6 +546,72 @@ impl ServiceBasedFrontend {
Ok(DeleteCollectionRecordsResponse {})
}

/// Fork `source_collection_id` into a freshly minted collection id via the
/// sysdb, then prime the local collection/segment cache with the result.
/// Returns the forked collection. Retry policy is applied by the caller
/// (`fork_collection`); this method performs a single attempt.
pub async fn retryable_fork(
    &mut self,
    request: ForkCollectionRequest,
) -> Result<ForkCollectionResponse, ForkCollectionError> {
    let ForkCollectionRequest {
        source_collection_id,
        target_collection_name,
        ..
    } = request;

    // Freshly minted id for the fork target.
    let new_collection_id = CollectionUuid::new();

    let forked = self
        .sysdb_client
        .fork_collection(
            source_collection_id,
            // TODO: Update this when wiring up log fork
            0,
            0,
            new_collection_id,
            target_collection_name,
        )
        .await?;

    let response = forked.collection.clone();

    // Seed the cache so subsequent reads of the fork skip the sysdb.
    self.collections_with_segments_provider
        .set_collection_with_segments(forked)
        .await;

    Ok(response)
}

/// Fork a collection, retrying transient failures with the provider's
/// backoff policy, and record the number of retries in metrics.
pub async fn fork_collection(
&mut self,
request: ForkCollectionRequest,
) -> Result<ForkCollectionResponse, ForkCollectionError> {
// Shared retry counter; bumped by the retry framework's notify hook below.
let retries = Arc::new(AtomicUsize::new(0));
// Factory closure: each retry attempt gets its own clone of self/request
// so the async block can be 'static and re-invoked.
let fork_to_retry = || {
let mut self_clone = self.clone();
let request_clone = request.clone();
async move { self_clone.retryable_fork(request_clone).await }
};

let res = fork_to_retry
.retry(self.collections_with_segments_provider.get_retry_backoff())
// NOTE: Transport level errors will manifest as unknown errors, and they should also be retried
.when(|e| {
matches!(
e.code(),
ErrorCodes::FailedPrecondition | ErrorCodes::NotFound | ErrorCodes::Unknown
)
})
.notify(|_, _| {
// fetch_add returns the prior value, so the first retry (retried == 0)
// is counted but not logged.
let retried = retries.fetch_add(1, Ordering::Relaxed);
if retried > 0 {
tracing::info!(
"Retrying fork() request for collection {}",
request.source_collection_id
);
}
})
.await;
// Export the total retry count regardless of final success/failure.
self.metrics
.fork_retries_counter
.add(retries.load(Ordering::Relaxed) as u64, &[]);
res
}

pub async fn add(
&mut self,
AddCollectionRecordsRequest {
Expand Down
72 changes: 71 additions & 1 deletion rust/frontend/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ use axum::{
Json, Router, ServiceExt,
};
use chroma_system::System;
use chroma_types::RawWhereFields;
use chroma_types::{
AddCollectionRecordsResponse, ChecklistResponse, Collection, CollectionConfiguration,
CollectionMetadataUpdate, CollectionUuid, CountCollectionsRequest, CountCollectionsResponse,
Expand All @@ -20,6 +19,7 @@ use chroma_types::{
UpdateCollectionConfiguration, UpdateCollectionRecordsResponse, UpdateCollectionResponse,
UpdateMetadata, UpsertCollectionRecordsResponse,
};
use chroma_types::{ForkCollectionResponse, RawWhereFields};
use mdac::{Rule, Scorecard, ScorecardTicket};
use opentelemetry::global;
use opentelemetry::metrics::{Counter, Meter};
Expand Down Expand Up @@ -103,6 +103,7 @@ pub struct Metrics {
get_collection: Counter<u64>,
update_collection: Counter<u64>,
delete_collection: Counter<u64>,
fork_collection: Counter<u64>,
collection_add: Counter<u64>,
collection_update: Counter<u64>,
collection_upsert: Counter<u64>,
Expand Down Expand Up @@ -133,6 +134,7 @@ impl Metrics {
get_collection: meter.u64_counter("get_collection").build(),
update_collection: meter.u64_counter("update_collection").build(),
delete_collection: meter.u64_counter("delete_collection").build(),
fork_collection: meter.u64_counter("fork_collection").build(),
collection_add: meter.u64_counter("collection_add").build(),
collection_update: meter.u64_counter("collection_update").build(),
collection_upsert: meter.u64_counter("collection_upsert").build(),
Expand Down Expand Up @@ -242,6 +244,10 @@ impl FrontendServer {
.put(update_collection)
.delete(delete_collection),
)
.route(
"/api/v2/tenants/{tenant}/databases/{database}/collections/{collection_id}/fork",
post(fork_collection),
)
.route(
"/api/v2/tenants/{tenant}/databases/{database}/collections/{collection_id}/add",
post(collection_add),
Expand Down Expand Up @@ -1106,6 +1112,70 @@ async fn delete_collection(
Ok(Json(UpdateCollectionResponse {}))
}

/// JSON request body for the collection fork endpoint.
#[derive(Deserialize, Serialize, ToSchema, Debug, Clone)]
pub struct ForkCollectionPayload {
/// Name to give the newly forked collection.
pub new_name: String,
}

/// Forks an existing collection.
///
/// HTTP handler: authenticates/authorizes the caller, then delegates to the
/// frontend's fork implementation and returns the forked collection.
#[utoipa::path(
    post,
    path = "/api/v2/tenants/{tenant}/databases/{database}/collections/{collection_id}/fork",
    request_body = ForkCollectionPayload,
    responses(
        (status = 200, description = "Collection forked successfully", body = ForkCollectionResponse),
        (status = 401, description = "Unauthorized", body = ErrorResponse),
        (status = 404, description = "Collection not found", body = ErrorResponse),
        (status = 500, description = "Server error", body = ErrorResponse)
    ),
    params(
        ("tenant" = String, Path, description = "Tenant ID"),
        ("database" = String, Path, description = "Database name"),
        // Fixed copy-paste from the update endpoint: this endpoint forks, not updates.
        ("collection_id" = String, Path, description = "UUID of the collection to fork")
    )
)]
async fn fork_collection(
    headers: HeaderMap,
    Path((tenant, database, collection_id)): Path<(String, String, String)>,
    State(mut server): State<FrontendServer>,
    Json(payload): Json<ForkCollectionPayload>,
) -> Result<Json<ForkCollectionResponse>, ServerError> {
    // Counted before auth so the metric reflects all attempts, authorized or not.
    server.metrics.fork_collection.add(
        1,
        &[
            KeyValue::new("tenant", tenant.clone()),
            KeyValue::new("collection_id", collection_id.clone()),
        ],
    );
    tracing::info!(
        "Forking collection [{collection_id}] in database [{database}] for tenant [{tenant}]"
    );
    server
        .authenticate_and_authorize(
            &headers,
            AuthzAction::ForkCollection,
            AuthzResource {
                tenant: Some(tenant.clone()),
                database: Some(database.clone()),
                collection: Some(collection_id.clone()),
            },
        )
        .await?;
    // Scorecard guard held for the rest of the request to account/limit the op.
    let _guard =
        server.scorecard_request(&["op:fork_collection", format!("tenant:{}", tenant).as_str()]);
    // Path segment is an arbitrary string; reject anything that isn't a valid UUID.
    let collection_id =
        CollectionUuid::from_str(&collection_id).map_err(|_| ValidationError::CollectionId)?;

    let request = chroma_types::ForkCollectionRequest::try_new(
        tenant,
        database,
        collection_id,
        payload.new_name,
    )?;

    Ok(Json(server.frontend.fork_collection(request).await?))
}

#[derive(Serialize, Deserialize, ToSchema, Debug, Clone)]
pub struct AddCollectionRecordsPayload {
ids: Vec<String>,
Expand Down
Loading