Skip to content

Commit 205c06c

Browse files
committed
anton/chunker
1 parent 4f06210 commit 205c06c

File tree

3 files changed

+207
-3
lines changed

3 files changed

+207
-3
lines changed

chromadb/api/types.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ def maybe_cast_one_to_many_image(target: OneOrMany[Image]) -> Images:
144144

145145
Embeddable = Union[Documents, Images]
146146
D = TypeVar("D", bound=Embeddable, contravariant=True)
147-
147+
C = TypeVar("C", bound=Embeddable) # C is for chunkable
148148

149149
Loadable = List[Optional[Image]]
150150
L = TypeVar("L", covariant=True, bound=Loadable)
@@ -197,8 +197,10 @@ def __call__(self: EmbeddingFunction[D], input: D) -> Embeddings:
197197

198198
setattr(cls, "__call__", __call__)
199199

200-
def embed_with_retries(self, input: D, **retry_kwargs: Dict) -> Embeddings:
201-
return retry(**retry_kwargs)(self.__call__)(input)
200+
def embed_with_retries(
201+
self, input: D, **retry_kwargs: Dict[Any, Any]
202+
) -> Embeddings:
203+
return retry(**retry_kwargs)(self.__call__)(input) # type: ignore
202204

203205

204206
def validate_embedding_function(
@@ -222,6 +224,12 @@ def __call__(self, uris: URIs) -> L:
222224
...
223225

224226

227+
class Chunker(Protocol[C]):
    """Structural (duck-typed) interface for chunkers.

    A chunker splits each item in a list of items into one or more chunks.
    ``C`` is bound to ``Embeddable`` (see the ``C = TypeVar(...)`` declaration
    above), so a chunker for ``Documents`` takes a list of documents and
    returns one list of chunks per input item.
    """

    # A chunker splits each item in a list of items into one or more chunks
    def __call__(self, input: C, **kwargs: Any) -> List[C]:
        ...
231+
232+
225233
def validate_ids(ids: IDs) -> IDs:
226234
"""Validates ids to ensure it is a list of strings"""
227235
if not isinstance(ids, list):

chromadb/utils/chunkers/__init__.py

Whitespace-only changes.
Lines changed: 196 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,196 @@
1+
import re
2+
from typing import Iterable, Literal, Optional, List, Union, Any
3+
from chromadb.api.types import Chunker, Document, Documents
4+
5+
import logging
6+
7+
logger = logging.getLogger(__name__)
8+
9+
10+
class DefaultTextChunker(Chunker[Documents]):
    """Recursively split documents into chunks of at most ``max_chunk_size``
    characters.

    Splitting tries a list of separators from coarsest to finest (by default
    paragraph breaks, then line breaks, then sentences, then words, then
    single characters), recursing with the finer separators on any piece that
    is still too large.  Adjacent small pieces are then greedily merged back
    together, optionally repeating up to ``chunk_overlap`` trailing characters
    of one chunk at the start of the next.

    NOTE(review): this mirrors the recursive character splitting strategy
    popularized by LangChain's ``RecursiveCharacterTextSplitter``.
    """

    def __init__(self, max_chunk_size: int = 1024, chunk_overlap: int = 0):
        # max_chunk_size: soft upper bound (in characters) on each chunk.
        # A single indivisible piece longer than this is kept whole and a
        # warning is logged in _merge_splits.
        # chunk_overlap: number of characters allowed to repeat between
        # consecutive chunks (0 disables overlap).
        self.max_chunk_size = max_chunk_size
        self.chunk_overlap = chunk_overlap

    def _split_text_with_regex(
        self,
        text: str,
        separator: str,
        keep_separator: Union[bool, Literal["start", "end"]],
    ) -> List[str]:
        """Split ``text`` on the regex ``separator``.

        ``keep_separator`` controls where the matched separator text ends up:
        ``"end"`` attaches it to the preceding piece, any other truthy value
        attaches it to the following piece, and a falsy value discards it.
        An empty ``separator`` splits the text into single characters.
        Empty strings are always dropped from the result.
        """
        # Now that we have the separator, split the text
        if separator:
            if keep_separator:
                # The parentheses in the pattern keep the delimiters in the result.
                # re.split with a capturing group yields
                # [piece, sep, piece, sep, ..., piece].
                _splits = re.split(f"({separator})", text)
                # Pair each piece with its neighboring separator: for "end",
                # pair (piece, following sep); otherwise (preceding sep, piece).
                splits = (
                    (
                        [
                            _splits[i] + _splits[i + 1]
                            for i in range(0, len(_splits) - 1, 2)
                        ]
                    )
                    if keep_separator == "end"
                    else (
                        [_splits[i] + _splits[i + 1] for i in range(1, len(_splits), 2)]
                    )
                )
                # Defensive: re.split with one capture group normally returns
                # an odd-length list; handle an even-length result anyway.
                if len(_splits) % 2 == 0:
                    splits += _splits[-1:]
                # The pairing loop above leaves one unpaired element: the last
                # piece for "end", the first piece otherwise.  Re-attach it.
                splits = (
                    (splits + [_splits[-1]])
                    if keep_separator == "end"
                    else ([_splits[0]] + splits)
                )
            else:
                splits = re.split(separator, text)
        else:
            # No separator at all: split into individual characters.
            splits = list(text)
        return [s for s in splits if s != ""]

    def _join_docs(self, docs: List[str], separator: str) -> Optional[str]:
        """Join pieces with ``separator`` and strip; return None if the
        result is empty (caller drops whitespace-only chunks)."""
        text = separator.join(docs)
        text = text.strip()
        if text == "":
            return None
        else:
            return text

    def _merge_splits(
        self,
        splits: Iterable[str],
        separator: str,
        max_chunk_size: int,
        chunk_overlap: int,
    ) -> List[str]:
        """Greedily pack small ``splits`` into chunks of at most
        ``max_chunk_size`` characters, carrying up to ``chunk_overlap``
        characters of each finished chunk into the next one.
        """
        # We now want to combine these smaller pieces into medium size
        # chunks to send to the LLM.
        separator_len = len(separator)

        docs = []
        # current_doc holds the pieces of the chunk being assembled;
        # total tracks its joined length (pieces + separators between them).
        current_doc: List[str] = []
        total = 0
        for d in splits:
            _len = len(d)
            # Would adding this piece (plus a separator, if the chunk is
            # non-empty) overflow the chunk?
            if (
                total + _len + (separator_len if len(current_doc) > 0 else 0)
                > max_chunk_size
            ):
                # A single indivisible piece may already exceed the limit;
                # warn but keep it whole.
                if total > max_chunk_size:
                    logger.warning(
                        f"Created a chunk of size {total}, "
                        f"which is longer than the specified {max_chunk_size}"
                    )
                if len(current_doc) > 0:
                    doc = self._join_docs(current_doc, separator)
                    if doc is not None:
                        docs.append(doc)
                    # Keep on popping if:
                    # - we have a larger chunk than in the chunk overlap
                    # - or if we still have any chunks and the length is long
                    # (the "and total > 0" guard makes current_doc[0] safe:
                    # an empty current_doc implies total == 0).
                    while total > chunk_overlap or (
                        total + _len + (separator_len if len(current_doc) > 0 else 0)
                        > max_chunk_size
                        and total > 0
                    ):
                        # Drop the oldest piece (and the separator it needed,
                        # if any piece remains after it).
                        total -= len(current_doc[0]) + (
                            separator_len if len(current_doc) > 1 else 0
                        )
                        current_doc = current_doc[1:]
            current_doc.append(d)
            total += _len + (separator_len if len(current_doc) > 1 else 0)
        # Flush whatever is left in the final chunk.
        doc = self._join_docs(current_doc, separator)
        if doc is not None:
            docs.append(doc)
        return docs

    def _split_document(
        self,
        document: Document,
        separators: List[str],
        max_chunk_size: int,
        chunk_overlap: int,
        keep_separator: Union[bool, Literal["start", "end"]],
    ) -> Documents:
        """Split incoming text and return chunks."""
        final_chunks = []
        # Get appropriate separator to use: the first separator in the list
        # that actually occurs in the document.  The remaining (finer)
        # separators are kept for recursive splitting of oversize pieces.
        separator = separators[-1]
        new_separators = []
        for i, _s in enumerate(separators):
            _separator = re.escape(_s)
            if _s == "":
                # Empty separator = character-level split; nothing finer.
                separator = _s
                break
            if re.search(_separator, document):
                separator = _s
                new_separators = separators[i + 1 :]
                break

        _separator = re.escape(separator)
        splits = self._split_text_with_regex(document, _separator, keep_separator)

        # Now go merging things, recursively splitting longer texts.
        # When separators are kept attached to the pieces, joining uses the
        # empty string so separators are not duplicated.
        _good_splits = []
        _separator = "" if keep_separator else separator
        for s in splits:
            if len(s) < max_chunk_size:
                _good_splits.append(s)
            else:
                # Flush the accumulated small pieces before handling the
                # oversize one, so chunk order is preserved.
                if _good_splits:
                    merged_text = self._merge_splits(
                        splits=_good_splits,
                        separator=_separator,
                        max_chunk_size=max_chunk_size,
                        chunk_overlap=chunk_overlap,
                    )
                    final_chunks.extend(merged_text)
                    _good_splits = []
                if not new_separators:
                    # No finer separator available: keep the oversize piece.
                    final_chunks.append(s)
                else:
                    # Recurse with the remaining, finer separators.
                    other_info = self._split_document(
                        document=s,
                        separators=new_separators,
                        max_chunk_size=max_chunk_size,
                        chunk_overlap=chunk_overlap,
                        keep_separator=keep_separator,
                    )
                    final_chunks.extend(other_info)
        if _good_splits:
            merged_text = self._merge_splits(
                splits=_good_splits,
                separator=_separator,
                max_chunk_size=max_chunk_size,
                chunk_overlap=chunk_overlap,
            )
            final_chunks.extend(merged_text)
        return final_chunks

    def __call__(
        self,
        input: Documents,
        **kwargs: Any,
    ) -> List[Documents]:
        """Chunk each document in ``input``; returns one list of chunks per
        input document.

        Recognized ``kwargs`` (each falls back to the instance default /
        built-in default when omitted or None): ``max_chunk_size``,
        ``chunk_overlap``, ``separators``.  Other keys are ignored.
        """
        max_chunk_size = kwargs.get("max_chunk_size", None)
        chunk_overlap = kwargs.get("chunk_overlap", None)
        separators = kwargs.get("separators", None)

        if max_chunk_size is None:
            max_chunk_size = self.max_chunk_size
        if chunk_overlap is None:
            chunk_overlap = self.chunk_overlap

        # Coarse-to-fine: paragraphs, lines, sentences, words, characters.
        if separators is None:
            separators = ["\n\n", "\n", ".", " ", ""]

        return [
            self._split_document(
                document=doc,
                separators=separators,
                max_chunk_size=max_chunk_size,
                chunk_overlap=chunk_overlap,
                keep_separator="end",
            )
            for doc in input
        ]

0 commit comments

Comments
 (0)