[DRAFT] Chunker #2260

Draft · wants to merge 1 commit into main
14 changes: 11 additions & 3 deletions chromadb/api/types.py
@@ -144,7 +144,7 @@ def maybe_cast_one_to_many_image(target: OneOrMany[Image]) -> Images:

Embeddable = Union[Documents, Images]
D = TypeVar("D", bound=Embeddable, contravariant=True)

+C = TypeVar("C", bound=Embeddable)  # C is for chunkable

Loadable = List[Optional[Image]]
L = TypeVar("L", covariant=True, bound=Loadable)
@@ -197,8 +197,10 @@ def __call__(self: EmbeddingFunction[D], input: D) -> Embeddings:

setattr(cls, "__call__", __call__)

-    def embed_with_retries(self, input: D, **retry_kwargs: Dict) -> Embeddings:
-        return retry(**retry_kwargs)(self.__call__)(input)
+    def embed_with_retries(
+        self, input: D, **retry_kwargs: Dict[Any, Any]
+    ) -> Embeddings:
+        return retry(**retry_kwargs)(self.__call__)(input)  # type: ignore


def validate_embedding_function(
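For context, the reworked embed_with_retries above simply forwards its keyword arguments to the retry decorator, so retry behavior can be tuned per call. A minimal usage sketch, assuming the decorator is tenacity's retry (as the surrounding module suggests) and a hypothetical MyEmbeddingFunction implementing EmbeddingFunction[Documents]:

from tenacity import stop_after_attempt, wait_exponential

ef = MyEmbeddingFunction()  # hypothetical EmbeddingFunction[Documents]
embeddings = ef.embed_with_retries(
    ["some document"],
    stop=stop_after_attempt(3),           # give up after three attempts
    wait=wait_exponential(multiplier=1),  # back off exponentially between tries
)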
@@ -222,6 +224,12 @@ def __call__(self, uris: URIs) -> L:
...


+class Chunker(Protocol[C]):
+    # A chunker splits each item in a list of items into one or more chunks
+    def __call__(self, input: C, **kwargs: Any) -> List[C]:
+        ...


def validate_ids(ids: IDs) -> IDs:
"""Validates ids to ensure it is a list of strings"""
if not isinstance(ids, list):
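For orientation: any callable satisfying the Chunker protocol above can stand in for the default implementation. A minimal sketch of a conforming chunker, not part of this PR's diff; the class name and the splitting rule are illustrative only:

from typing import Any, List

from chromadb.api.types import Chunker, Documents


class ParagraphChunker(Chunker[Documents]):
    # Splits every document on blank lines; returns one list of chunks
    # per input document, as the protocol requires.
    def __call__(self, input: Documents, **kwargs: Any) -> List[Documents]:
        return [doc.split("\n\n") for doc in input]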
Empty file.
196 changes: 196 additions & 0 deletions chromadb/utils/chunkers/default_chunker.py
@@ -0,0 +1,196 @@
import re
from typing import Iterable, Literal, Optional, List, Union, Any
from chromadb.api.types import Chunker, Document, Documents

import logging

logger = logging.getLogger(__name__)


class DefaultTextChunker(Chunker[Documents]):
    """Recursively split each document on a prioritized list of separators,
    then merge the pieces into chunks no longer than max_chunk_size characters
    where possible, repeating up to chunk_overlap characters between
    adjacent chunks."""

    def __init__(self, max_chunk_size: int = 1024, chunk_overlap: int = 0):
        self.max_chunk_size = max_chunk_size
        self.chunk_overlap = chunk_overlap

def _split_text_with_regex(
self,
text: str,
separator: str,
keep_separator: Union[bool, Literal["start", "end"]],
) -> List[str]:
# Now that we have the separator, split the text
if separator:
if keep_separator:
# The parentheses in the pattern keep the delimiters in the result.
_splits = re.split(f"({separator})", text)
splits = (
(
[
_splits[i] + _splits[i + 1]
for i in range(0, len(_splits) - 1, 2)
]
)
if keep_separator == "end"
else (
[_splits[i] + _splits[i + 1] for i in range(1, len(_splits), 2)]
)
)
if len(_splits) % 2 == 0:
splits += _splits[-1:]
splits = (
(splits + [_splits[-1]])
if keep_separator == "end"
else ([_splits[0]] + splits)
)
else:
splits = re.split(separator, text)
else:
splits = list(text)
return [s for s in splits if s != ""]
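    # Illustrative trace (comment not in the diff): with separator=r"\." and
    # keep_separator="end", "a. b. c." splits into ["a.", " b.", " c."]; each
    # piece keeps its trailing separator, and empty strings are dropped.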

def _join_docs(self, docs: List[str], separator: str) -> Optional[str]:
text = separator.join(docs)
text = text.strip()
if text == "":
return None
else:
return text

def _merge_splits(
self,
splits: Iterable[str],
separator: str,
max_chunk_size: int,
chunk_overlap: int,
) -> List[str]:
        # We now want to combine these smaller pieces into medium-sized
        # chunks that fit within max_chunk_size.
separator_len = len(separator)

docs = []
current_doc: List[str] = []
total = 0
for d in splits:
_len = len(d)
if (
total + _len + (separator_len if len(current_doc) > 0 else 0)
> max_chunk_size
):
if total > max_chunk_size:
logger.warning(
f"Created a chunk of size {total}, "
f"which is longer than the specified {max_chunk_size}"
)
if len(current_doc) > 0:
doc = self._join_docs(current_doc, separator)
if doc is not None:
docs.append(doc)
                # Keep popping splits off the front while:
                # - the accumulated total still exceeds chunk_overlap, or
                # - adding the next split would exceed max_chunk_size and
                #   there is still something left to pop
while total > chunk_overlap or (
total + _len + (separator_len if len(current_doc) > 0 else 0)
> max_chunk_size
and total > 0
):
total -= len(current_doc[0]) + (
separator_len if len(current_doc) > 1 else 0
)
current_doc = current_doc[1:]
current_doc.append(d)
total += _len + (separator_len if len(current_doc) > 1 else 0)
doc = self._join_docs(current_doc, separator)
if doc is not None:
docs.append(doc)
return docs
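    # Illustrative trace (comment not in the diff): merging ["aaa", "bbb", "ccc"]
    # with separator="", max_chunk_size=7, chunk_overlap=3 yields
    # ["aaabbb", "bbbccc"]; "bbb" is carried over as the 3-character overlap.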

def _split_document(
self,
document: Document,
separators: List[str],
max_chunk_size: int,
chunk_overlap: int,
keep_separator: Union[bool, Literal["start", "end"]],
) -> Documents:
"""Split incoming text and return chunks."""
final_chunks = []
# Get appropriate separator to use
separator = separators[-1]
new_separators = []
for i, _s in enumerate(separators):
_separator = re.escape(_s)
if _s == "":
separator = _s
break
if re.search(_separator, document):
separator = _s
new_separators = separators[i + 1 :]
break

_separator = re.escape(separator)
splits = self._split_text_with_regex(document, _separator, keep_separator)

        # Merge the pieces, recursively re-splitting any piece that is still too long.
_good_splits = []
_separator = "" if keep_separator else separator
for s in splits:
if len(s) < max_chunk_size:
_good_splits.append(s)
else:
if _good_splits:
merged_text = self._merge_splits(
splits=_good_splits,
separator=_separator,
max_chunk_size=max_chunk_size,
chunk_overlap=chunk_overlap,
)
final_chunks.extend(merged_text)
_good_splits = []
if not new_separators:
final_chunks.append(s)
else:
other_info = self._split_document(
document=s,
separators=new_separators,
max_chunk_size=max_chunk_size,
chunk_overlap=chunk_overlap,
keep_separator=keep_separator,
)
final_chunks.extend(other_info)
if _good_splits:
merged_text = self._merge_splits(
splits=_good_splits,
separator=_separator,
max_chunk_size=max_chunk_size,
chunk_overlap=chunk_overlap,
)
final_chunks.extend(merged_text)
return final_chunks
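    # Summary (comment not in the diff): the first separator that occurs in the
    # document is used for the initial split; any piece still longer than
    # max_chunk_size is recursively re-split using the remaining separators.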

def __call__(
self,
input: Documents,
**kwargs: Any,
) -> List[Documents]:
max_chunk_size = kwargs.get("max_chunk_size", None)
chunk_overlap = kwargs.get("chunk_overlap", None)
separators = kwargs.get("separators", None)

if max_chunk_size is None:
max_chunk_size = self.max_chunk_size
if chunk_overlap is None:
chunk_overlap = self.chunk_overlap

if separators is None:
separators = ["\n\n", "\n", ".", " ", ""]

return [
self._split_document(
document=doc,
separators=separators,
max_chunk_size=max_chunk_size,
chunk_overlap=chunk_overlap,
keep_separator="end",
)
for doc in input
]
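
End to end, a usage sketch (values illustrative, not from this PR): the chunker is instantiated with defaults and called on a list of documents; per-call overrides for max_chunk_size, chunk_overlap, and separators can be passed as keyword arguments, as __call__ above shows.

from chromadb.utils.chunkers.default_chunker import DefaultTextChunker

chunker = DefaultTextChunker(max_chunk_size=64, chunk_overlap=8)
docs = ["First paragraph.\n\nSecond paragraph with a few more words in it."]

chunks_per_doc = chunker(docs)              # one list of chunks per input document
smaller = chunker(docs, max_chunk_size=32)  # per-call override via kwargs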