Skip to content

[ENH] Wire up collection forking for python #4314

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: sicheng/04-17-_enh_wire_up_collection_forking_for_client
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions chromadb/api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -636,6 +636,16 @@ def _modify(
) -> None:
pass

@abstractmethod
def _fork(
self,
collection_id: UUID,
new_name: str,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> CollectionModel:
pass

@abstractmethod
@override
def _count(
Expand Down
10 changes: 10 additions & 0 deletions chromadb/api/async_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -630,6 +630,16 @@ async def _modify(
) -> None:
pass

@abstractmethod
async def _fork(
self,
collection_id: UUID,
new_name: str,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> CollectionModel:
pass

@abstractmethod
@override
async def _count(
Expand Down
17 changes: 17 additions & 0 deletions chromadb/api/async_fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,23 @@ async def _modify(
},
)

@trace_method("AsyncFastAPI._fork", OpenTelemetryGranularity.OPERATION)
@override
async def _fork(
self,
collection_id: UUID,
new_name: str,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> CollectionModel:
resp_json = await self._make_request(
"post",
f"/tenants/{tenant}/databases/{database}/collections/{collection_id}/fork",
json={"new_name": new_name},
)
model = CollectionModel.from_json(resp_json)
return model

@trace_method("AsyncFastAPI.delete_collection", OpenTelemetryGranularity.OPERATION)
@override
async def delete_collection(
Expand Down
18 changes: 18 additions & 0 deletions chromadb/api/fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,24 @@ def _modify(
},
)

@trace_method("FastAPI._fork", OpenTelemetryGranularity.OPERATION)
@override
def _fork(
self,
collection_id: UUID,
new_name: str,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> CollectionModel:
"""Forks a collection"""
resp_json = self._make_request(
"post",
f"/tenants/{tenant}/databases/{database}/collections/{collection_id}/fork",
json={"new_name": new_name},
)
model = CollectionModel.from_json(resp_json)
return model

@trace_method("FastAPI.delete_collection", OpenTelemetryGranularity.OPERATION)
@override
def delete_collection(
Expand Down
26 changes: 26 additions & 0 deletions chromadb/api/models/AsyncCollection.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,32 @@ async def modify(

self._update_model_after_modify_success(name, metadata, configuration)

async def fork(
self,
new_name: str,
) -> "AsyncCollection":
"""Fork the current collection under a new name. The returning collection should contain identical data to the current collection.
This is an experimental API that only works for Hosted Chroma for now.
Args:
new_name: The name of the new collection.
Returns:
Collection: A new collection with the specified name and containing identical data to the current collection.
"""
model = await self._client._fork(
collection_id=self.id,
new_name=new_name,
tenant=self.tenant,
database=self.database,
)
return AsyncCollection(
client=self._client,
model=model,
embedding_function=self._embedding_function,
data_loader=self._data_loader
)

async def update(
self,
ids: OneOrMany[ID],
Expand Down
26 changes: 26 additions & 0 deletions chromadb/api/models/Collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,32 @@ def modify(

self._update_model_after_modify_success(name, metadata, configuration)

def fork(
self,
new_name: str,
) -> "Collection":
"""Fork the current collection under a new name. The returning collection should contain identical data to the current collection.
This is an experimental API that only works for Hosted Chroma for now.
Args:
new_name: The name of the new collection.
Returns:
Collection: A new collection with the specified name and containing identical data to the current collection.
"""
model = self._client._fork(
collection_id=self.id,
new_name=new_name,
tenant=self.tenant,
database=self.database,
)
return Collection(
client=self._client,
model=model,
embedding_function=self._embedding_function,
data_loader=self._data_loader
)

def update(
self,
ids: OneOrMany[ID],
Expand Down
10 changes: 10 additions & 0 deletions chromadb/api/rust.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,16 @@ def _modify(
str(id), new_name, new_metadata, new_configuration_json_str
)

@override
def _fork(
self,
collection_id: UUID,
new_name: str,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> CollectionModel:
raise NotImplementedError("Collection forking is not implemented for Local Chroma")

@override
def _count(
self,
Expand Down
10 changes: 10 additions & 0 deletions chromadb/api/segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,6 +405,16 @@ def _modify(
elif new_configuration:
self._sysdb.update_collection(id, configuration=new_configuration)

@override
def _fork(
self,
collection_id: UUID,
new_name: str,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> CollectionModel:
raise NotImplementedError("Collection forking is not implemented for SegmentAPI")

@trace_method("SegmentAPI.delete_collection", OpenTelemetryGranularity.OPERATION)
@override
@rate_limit
Expand Down
Loading