Skip to content

Commit 4f06210

Browse files
authored
[CLN] Make collection api object wrap model. (#2230)
## Description of changes *Summarize the changes made by this PR.* - Improvements & Bug fixes - This is cleanup borrowed from #1491. It makes Collection.py wrap a model instead of duplicating data into Collection.py and implicitly treating it as a serializable model. This is a cleaner separation of concerns. We have the model object, and an API wrapper used to manipulate it via the API. - New functionality - None ## Test plan *How are these changes tested?* - [x] Tests pass locally with `pytest` for python, `yarn test` for js, `cargo test` for rust ## Documentation Changes *Are all docstrings for user-facing APIs updated if required? Do we need to make documentation changes in the [docs repository](https://github.com/chroma-core/docs)?*
1 parent 960de33 commit 4f06210

File tree

5 files changed

+116
-58
lines changed

5 files changed

+116
-58
lines changed

chromadb/api/fastapi.py

Lines changed: 31 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
from typing import Optional, cast, Tuple
44
from typing import Sequence
55
from uuid import UUID
6-
76
import requests
87
from overrides import override
98

@@ -41,6 +40,7 @@
4140
)
4241
from chromadb.telemetry.product import ProductTelemetryClient
4342
from urllib.parse import urlparse, urlunparse, quote
43+
from chromadb.types import Collection as CollectionModel
4444

4545
logger = logging.getLogger(__name__)
4646

@@ -209,7 +209,16 @@ def list_collections(
209209
json_collections = json.loads(resp.text)
210210
collections = []
211211
for json_collection in json_collections:
212-
collections.append(Collection(self, **json_collection))
212+
model = CollectionModel(
213+
id=json_collection["id"],
214+
name=json_collection["name"],
215+
metadata=json_collection["metadata"],
216+
dimension=json_collection["dimension"],
217+
tenant=json_collection["tenant"],
218+
database=json_collection["database"],
219+
version=json_collection["version"],
220+
)
221+
collections.append(Collection(self, model=model))
213222

214223
return collections
215224

@@ -254,13 +263,20 @@ def create_collection(
254263
)
255264
raise_chroma_error(resp)
256265
resp_json = json.loads(resp.text)
257-
return Collection(
258-
client=self,
266+
model = CollectionModel(
259267
id=resp_json["id"],
260268
name=resp_json["name"],
269+
metadata=resp_json["metadata"],
270+
dimension=resp_json["dimension"],
271+
tenant=resp_json["tenant"],
272+
database=resp_json["database"],
273+
version=resp_json["version"],
274+
)
275+
return Collection(
276+
client=self,
277+
model=model,
261278
embedding_function=embedding_function,
262279
data_loader=data_loader,
263-
metadata=resp_json["metadata"],
264280
)
265281

266282
@trace_method("FastAPI.get_collection", OpenTelemetryGranularity.OPERATION)
@@ -288,13 +304,20 @@ def get_collection(
288304
)
289305
raise_chroma_error(resp)
290306
resp_json = json.loads(resp.text)
307+
model = CollectionModel(
308+
id=resp_json["id"],
309+
name=resp_json["name"],
310+
metadata=resp_json["metadata"],
311+
dimension=resp_json["dimension"],
312+
tenant=resp_json["tenant"],
313+
database=resp_json["database"],
314+
version=resp_json["version"],
315+
)
291316
return Collection(
292317
client=self,
293-
name=resp_json["name"],
294-
id=resp_json["id"],
318+
model=model,
295319
embedding_function=embedding_function,
296320
data_loader=data_loader,
297-
metadata=resp_json["metadata"],
298321
)
299322

300323
@trace_method(

chromadb/api/models/Collection.py

Lines changed: 66 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,7 @@
1-
from typing import TYPE_CHECKING, Optional, Tuple, Any, Union
2-
1+
from typing import TYPE_CHECKING, Optional, Tuple, Any, Union, cast
32
import numpy as np
4-
from pydantic import BaseModel, PrivateAttr
5-
63
from uuid import UUID
74
import chromadb.utils.embedding_functions as ef
8-
95
from chromadb.api.types import (
106
URI,
117
CollectionMetadata,
@@ -46,6 +42,12 @@
4642
validate_embeddings,
4743
validate_embedding_function,
4844
)
45+
46+
# TODO: We should rename the types in chromadb.types to be Models where
47+
# appropriate. This will help to distinguish between manipulation objects
48+
# which are essentially API views, and the actual data models which are
49+
# stored / retrieved / transmitted.
50+
from chromadb.types import Collection as CollectionModel
4951
import logging
5052

5153
logger = logging.getLogger(__name__)
@@ -54,33 +56,25 @@
5456
from chromadb.api import ServerAPI
5557

5658

57-
class Collection(BaseModel):
58-
name: str
59-
id: UUID
60-
metadata: Optional[CollectionMetadata] = None
61-
tenant: Optional[str] = None
62-
database: Optional[str] = None
63-
_client: "ServerAPI" = PrivateAttr()
64-
_embedding_function: Optional[EmbeddingFunction[Embeddable]] = PrivateAttr()
65-
_data_loader: Optional[DataLoader[Loadable]] = PrivateAttr()
59+
class Collection:
60+
_model: CollectionModel
61+
_client: "ServerAPI"
62+
_embedding_function: Optional[EmbeddingFunction[Embeddable]]
63+
_data_loader: Optional[DataLoader[Loadable]]
6664

6765
def __init__(
6866
self,
6967
client: "ServerAPI",
70-
name: str,
71-
id: UUID,
68+
model: CollectionModel,
7269
embedding_function: Optional[
7370
EmbeddingFunction[Embeddable]
7471
] = ef.DefaultEmbeddingFunction(), # type: ignore
7572
data_loader: Optional[DataLoader[Loadable]] = None,
76-
tenant: Optional[str] = None,
77-
database: Optional[str] = None,
78-
metadata: Optional[CollectionMetadata] = None,
7973
):
80-
super().__init__(
81-
name=name, metadata=metadata, id=id, tenant=tenant, database=database
82-
)
74+
"""Initializes a new instance of the Collection class."""
75+
8376
self._client = client
77+
self._model = model
8478

8579
# Check to make sure the embedding function has the right signature, as defined by the EmbeddingFunction protocol
8680
if embedding_function is not None:
@@ -92,6 +86,51 @@ def __init__(
9286
def __repr__(self) -> str:
9387
return f"Collection(name={self.name})"
9488

89+
# Expose the model properties as read-only properties on the Collection class
90+
91+
@property
92+
def id(self) -> UUID:
93+
return self._model["id"]
94+
95+
@property
96+
def name(self) -> str:
97+
return self._model["name"]
98+
99+
@property
100+
def metadata(self) -> CollectionMetadata:
101+
return cast(CollectionMetadata, self._model["metadata"])
102+
103+
@property
104+
def tenant(self) -> str:
105+
return self._model["tenant"]
106+
107+
@property
108+
def database(self) -> str:
109+
return self._model["database"]
110+
111+
def __eq__(self, other: object) -> bool:
112+
if not isinstance(other, Collection):
113+
return False
114+
id_match = self.id == other.id
115+
name_match = self.name == other.name
116+
metadata_match = self.metadata == other.metadata
117+
tenant_match = self.tenant == other.tenant
118+
database_match = self.database == other.database
119+
embedding_function_match = self._embedding_function == other._embedding_function
120+
data_loader_match = self._data_loader == other._data_loader
121+
return (
122+
id_match
123+
and name_match
124+
and metadata_match
125+
and tenant_match
126+
and database_match
127+
and embedding_function_match
128+
and data_loader_match
129+
)
130+
131+
def get_model(self) -> CollectionModel:
132+
return self._model
133+
95134
def count(self) -> int:
96135
"""The total number of embeddings added to the database
97136
@@ -385,11 +424,14 @@ def modify(
385424
"Changing the distance function of a collection once it is created is not supported currently."
386425
)
387426

427+
# Note there is a race condition here where the metadata can be updated
428+
# but another thread sees the cached local metadata.
429+
# TODO: fixme
388430
self._client._modify(id=self.id, new_name=name, new_metadata=metadata)
389431
if name:
390-
self.name = name
432+
self._model["name"] = name
391433
if metadata:
392-
self.metadata = metadata
434+
self._model["metadata"] = metadata
393435

394436
def update(
395437
self,

chromadb/api/segment.py

Lines changed: 3 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -201,13 +201,9 @@ def create_collection(
201201

202202
return Collection(
203203
client=self,
204-
id=coll["id"],
205-
name=name,
206-
metadata=coll["metadata"], # type: ignore
204+
model=coll,
207205
embedding_function=embedding_function,
208206
data_loader=data_loader,
209-
tenant=tenant,
210-
database=database,
211207
)
212208

213209
@trace_method(
@@ -260,13 +256,9 @@ def get_collection(
260256
if existing:
261257
return Collection(
262258
client=self,
263-
id=existing[0]["id"],
264-
name=existing[0]["name"],
265-
metadata=existing[0]["metadata"], # type: ignore
259+
model=existing[0],
266260
embedding_function=embedding_function,
267261
data_loader=data_loader,
268-
tenant=existing[0]["tenant"],
269-
database=existing[0]["database"],
270262
)
271263
else:
272264
raise ValueError(f"Collection {name} does not exist.")
@@ -288,11 +280,7 @@ def list_collections(
288280
collections.append(
289281
Collection(
290282
client=self,
291-
id=db_collection["id"],
292-
name=db_collection["name"],
293-
metadata=db_collection["metadata"], # type: ignore
294-
tenant=db_collection["tenant"],
295-
database=db_collection["database"],
283+
model=db_collection,
296284
)
297285
)
298286
return collections

chromadb/server/fastapi/__init__.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,12 @@
11
from typing import Any, Callable, cast, Dict, List, Sequence, Optional, Tuple
22
import fastapi
33
import orjson
4-
54
from anyio import (
65
to_thread,
76
CapacityLimiter,
87
)
98
from fastapi import FastAPI as _FastAPI, Response, Request
109
from fastapi.responses import JSONResponse, ORJSONResponse
11-
1210
from fastapi.middleware.cors import CORSMiddleware
1311
from fastapi.routing import APIRoute
1412
from fastapi import HTTPException, status
@@ -43,9 +41,7 @@
4341
UpdateEmbedding,
4442
)
4543
from starlette.datastructures import Headers
46-
4744
import logging
48-
4945
from chromadb.telemetry.product.events import ServerStartEvent
5046
from chromadb.utils.fastapi import fastapi_json_response, string_to_uuid as _uuid
5147
from chromadb.telemetry.opentelemetry.fastapi import instrument_fastapi
@@ -56,6 +52,7 @@
5652
OpenTelemetryGranularity,
5753
trace_method,
5854
)
55+
from chromadb.types import Collection as CollectionModel
5956

6057
logger = logging.getLogger(__name__)
6158

@@ -92,7 +89,7 @@ async def check_http_version_middleware(
9289
return await call_next(request)
9390

9491

95-
class ChromaAPIRouter(fastapi.APIRouter):
92+
class ChromaAPIRouter(fastapi.APIRouter): # type: ignore
9693
# A simple subclass of fastapi's APIRouter which treats URLs with a
9794
# trailing "/" the same as URLs without. Docs will only contain URLs
9895
# without trailing "/"s.
@@ -491,7 +488,7 @@ async def list_collections(
491488
offset: Optional[int] = None,
492489
tenant: str = DEFAULT_TENANT,
493490
database: str = DEFAULT_DATABASE,
494-
) -> Sequence[Collection]:
491+
) -> Sequence[CollectionModel]:
495492
(
496493
maybe_tenant,
497494
maybe_database,
@@ -507,7 +504,7 @@ async def list_collections(
507504
if maybe_database:
508505
database = maybe_database
509506

510-
return cast(
507+
api_collections = cast(
511508
Sequence[Collection],
512509
await to_thread.run_sync(
513510
self._api.list_collections,
@@ -519,6 +516,8 @@ async def list_collections(
519516
),
520517
)
521518

519+
return [c.get_model() for c in api_collections]
520+
522521
@trace_method("FastAPI.count_collections", OpenTelemetryGranularity.OPERATION)
523522
async def count_collections(
524523
self,
@@ -557,7 +556,7 @@ async def create_collection(
557556
request: Request,
558557
tenant: str = DEFAULT_TENANT,
559558
database: str = DEFAULT_DATABASE,
560-
) -> Collection:
559+
) -> CollectionModel:
561560
def process_create_collection(
562561
request: Request, tenant: str, database: str, raw_body: bytes
563562
) -> Collection:
@@ -586,7 +585,7 @@ def process_create_collection(
586585
database=database,
587586
)
588587

589-
return cast(
588+
api_collection = cast(
590589
Collection,
591590
await to_thread.run_sync(
592591
process_create_collection,
@@ -597,6 +596,7 @@ def process_create_collection(
597596
limiter=self._capacity_limiter,
598597
),
599598
)
599+
return api_collection.get_model()
600600

601601
@trace_method("FastAPI.get_collection", OpenTelemetryGranularity.OPERATION)
602602
async def get_collection(
@@ -605,7 +605,7 @@ async def get_collection(
605605
collection_name: str,
606606
tenant: str = DEFAULT_TENANT,
607607
database: str = DEFAULT_DATABASE,
608-
) -> Collection:
608+
) -> CollectionModel:
609609
(
610610
maybe_tenant,
611611
maybe_database,
@@ -621,7 +621,7 @@ async def get_collection(
621621
if maybe_database:
622622
database = maybe_database
623623

624-
return cast(
624+
api_collection = cast(
625625
Collection,
626626
await to_thread.run_sync(
627627
self._api.get_collection,
@@ -634,6 +634,7 @@ async def get_collection(
634634
limiter=self._capacity_limiter,
635635
),
636636
)
637+
return api_collection.get_model()
637638

638639
@trace_method("FastAPI.update_collection", OpenTelemetryGranularity.OPERATION)
639640
async def update_collection(

0 commit comments

Comments
 (0)