Skip to content

Commit ba0097d

Browse files
authored
feat: Support numpy-based Chroma embeddings (#177)
Closes #176
1 parent b290ea6 commit ba0097d

4 files changed

Lines changed: 29 additions & 5 deletions

File tree

chroma_dp/__init__.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import numpy as np
2+
13
try:
24
import chromadb # noqa: F401
35
except ImportError:
@@ -19,7 +21,7 @@
1921
)
2022

2123
from chromadb.api.types import Embedding
22-
from pydantic import BaseModel, Field
24+
from pydantic import BaseModel, Field, ConfigDict
2325

2426
C = TypeVar("C")
2527

@@ -31,17 +33,22 @@ class ResourceFeature(BaseModel, Generic[C]):
3133

3234
Metadata = Dict[str, Union[str, int, float, bool]]
3335

36+
EmbeddingWrapper = Union[Embedding, np.ndarray]
37+
3438

3539
class EmbeddableResource(BaseModel):
40+
model_config = ConfigDict(arbitrary_types_allowed=True)
3641
id: Optional[str] = Field(None, description="Document ID")
3742
metadata: Optional[Metadata] = Field(None, description="Document metadata")
38-
embedding: Optional[Embedding] = Field(None, description="Document embedding")
43+
embedding: Optional[EmbeddingWrapper] = Field(
44+
None, description="Document embedding"
45+
)
3946

4047
@staticmethod
4148
def resource_features() -> Sequence[ResourceFeature]:
4249
return [
43-
ResourceFeature[Embedding](
44-
feature_name="embedding", feature_type=Embedding
50+
ResourceFeature[EmbeddingWrapper](
51+
feature_name="embedding", feature_type=EmbeddingWrapper
4552
),
4653
ResourceFeature[Metadata](feature_name="metadata", feature_type=Metadata),
4754
ResourceFeature[str](feature_name="id", feature_type=str),
@@ -58,6 +65,13 @@ def resource_features() -> Sequence[ResourceFeature]:
5865
*super().resource_features(),
5966
]
6067

68+
def model_dump(self, **kwargs):
    """Dump the model to a dict, replacing a NumPy embedding with a plain list.

    All keyword arguments are forwarded unchanged to the parent
    ``model_dump``.

    Returns:
        The dict produced by the parent ``model_dump``. If it holds an
        ``"embedding"`` entry that is a ``np.ndarray``, that entry is
        replaced with the result of ``ndarray.tolist()``.
    """
    data = super().model_dump(**kwargs)
    # The key can be absent when the caller narrows the dump (e.g. via
    # ``include``/``exclude``), so look it up without assuming it exists.
    embedding = data.get("embedding")
    if isinstance(embedding, np.ndarray):
        # Conversion happens after the parent dump, on the dumped value.
        data["embedding"] = embedding.tolist()
    return data
74+
6175

6276
D = TypeVar("D", bound=EmbeddableResource, contravariant=True)
6377

chroma_dp/chroma/chroma_import.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import numpy as np
12
import orjson as json
23
import sys
34
import uuid
@@ -28,13 +29,19 @@ def add_to_col(
2829
ef: EmbeddingFunction = None,
2930
) -> None:
3031
try:
32+
if "embeddings" in batch and len(batch["embeddings"]) > 0:
33+
batch["embeddings"] = [
34+
e.tolist() if isinstance(e, np.ndarray) else e
35+
for e in batch["embeddings"]
36+
]
3137
if ef:
3238
batch["embeddings"] = ef(batch["documents"])
3339
if upsert:
3440
col.upsert(**batch)
3541
else:
3642
col.add(**batch)
3743
except Exception as e:
44+
print(e, file=sys.stderr)
3845
raise e
3946

4047

chroma_dp/utils/chroma.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from typing import Optional, Dict, Any, List, cast
44
from urllib.parse import urlparse, parse_qs
55
import chromadb
6+
import numpy as np
67
from chromadb import ClientAPI, GetResult
78
from chromadb.api.models.Collection import Collection
89

@@ -223,6 +224,8 @@ def remap_features(
223224
_doc = in_dict[doc_feature]
224225
_embed = in_dict[embed_feature] if embed_feature else None
225226
_id = in_dict[id_feature] if id_feature else None
227+
if not isinstance(_embed, np.ndarray):
228+
_embed = np.array(_embed)
226229
return EmbeddableTextResource(
227230
text_chunk=_doc, embedding=_embed, metadata=_meta, id=_id
228231
)

docs/index.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ cdp ds-get "hf://tazarov/chroma-qna?split=train" | cdp import "http://localhost:
6767
**Importing from a directory with PDF files into Local Persisted Chroma DB:**
6868

6969
```bash
70-
cdp ds-get sample-data/papers/ | grep "2401.02412.pdf" | head -1 | cdp chunk -s 500 | cdp embed --ef default | cdp import "file://chroma-data/my-pdfs" --upsert --create
70+
cdp imp pdf sample-data/papers/ | grep "2401.02412.pdf" | head -1 | cdp chunk -s 500 | cdp embed --ef default | cdp import "file://chroma-data/my-pdfs" --upsert --create
7171
```
7272

7373
!!! note

0 commit comments

Comments
 (0)