import json
import logging
import os
import re
from typing import Generator

from langchain_text_splitters import (
    MarkdownTextSplitter,
    PythonCodeTextSplitter,
    RecursiveCharacterTextSplitter,
    RecursiveJsonSplitter,
    TextSplitter,
)

from ..exceptions import UnsupportedFormatError
from .base_chunker import BaseChunker


class LangChainChunker(BaseChunker):
"""
LangChainChunker is a class designed to split document content into chunks based on the format and
specific chunking criteria. The class leverages various LangChain splitters tailored for different
content formats, ensuring accurate and efficient processing.
Initialization:
---------------
The LangChainChunker is initialized with the following parameters:
- data (dict): A dictionary containing the document's metadata and content. See `BaseChunker` for more
information on the required and optional keys.
Attributes:
-----------
- `max_chunk_size` (int): The maximum allowed size of each chunk in tokens, derived from the `NUM_TOKENS`
environment variable (default is 2048 tokens).
- `token_overlap` (int): The number of overlapping tokens between consecutive chunks, derived from the `TOKEN_OVERLAP`
environment variable (default is 100 tokens).
- `minimum_chunk_size` (int): The minimum required size of each chunk in tokens, derived from the `MIN_CHUNK_SIZE`
environment variable (default is 100 tokens).
- `supported_formats` (dict): A dictionary mapping file extensions to their corresponding content format, used to select
the appropriate text splitter.
Methods:
--------
- `get_chunks()`: Splits the document content into chunks based on the specified format and criteria. The method first
checks if the document's format is supported, then processes the content into chunks, skipping those that don't meet the
minimum size requirement. Finally, it logs the number of chunks created and skipped.
- `_chunk_content()`: Splits the document content into chunks according to the format-specific splitting strategy. The method
identifies the format of the document and chooses the corresponding LangChain splitter (e.g., `MarkdownTextSplitter` for Markdown,
`PythonCodeTextSplitter` for Python code, and `RecursiveCharacterTextSplitter` for other formats). It yields each chunk along with
its token count.
"""
    def __init__(self, data: dict):
        super().__init__(data)
        self.max_chunk_size = int(os.getenv("NUM_TOKENS", "2048"))
        self.minimum_chunk_size = int(os.getenv("MIN_CHUNK_SIZE", "100"))
        self.token_overlap = int(os.getenv("TOKEN_OVERLAP", "100"))
        self.supported_formats = {
            "md": "markdown",
            "txt": "text",
            "html": "html",
            "shtml": "html",
            "htm": "html",
            "py": "python",
            "json": "json",
            "csv": "csv",
            "xml": "xml",
        }

    def get_chunks(self) -> list[dict]:
        """
        Splits the document content into chunks based on the specified format and criteria.

        This method performs the following steps:
        1. Checks if the file format is supported.
        2. Decodes the document bytes into text.
        3. Splits the text into chunks based on the format-specific splitting strategy.

        Returns:
            list[dict]: A list of dictionaries, each representing a chunk of the document.
        """
        chunks = []

        if self.extension not in self.supported_formats:
            raise UnsupportedFormatError(
                f"[LangChainChunker] [{self.filename}] {self.extension} format is not supported."
            )

        logging.info(
            f"[LangChainChunker] [get_chunks] [{self.filename}] Running `get_chunks`."
        )

        blob_data = self.document_bytes

        # Decode the bytes into text (assuming UTF-8 encoding)
        text = self.decode_to_utf8(blob_data)

        # Returns a generator that yields (chunk, token count) tuples when iterated over
        text_chunks = self._chunk_content(text)

        skipped_chunks = 0
        chunk_id = 0
        for text_chunk, num_tokens in text_chunks:
            # Keep only chunks that meet the `minimum_chunk_size` threshold
            if num_tokens >= self.minimum_chunk_size:
                chunk_id += 1
                chunk_dict = self._create_chunk(chunk_id=chunk_id, content=text_chunk)
                chunks.append(chunk_dict)
            else:
                skipped_chunks += 1

        logging.debug(
            f"[LangChainChunker] [get_chunks] [{self.filename}] {len(chunks)} chunk(s) created."
        )
        if skipped_chunks > 0:
            logging.debug(
                f"[LangChainChunker] [get_chunks] [{self.filename}] {skipped_chunks} chunk(s) skipped."
            )

        return chunks

    def _chunk_content(self, content: str) -> Generator[tuple[str, int], None, None]:
        """
        Splits the document content into chunks based on the specified format and criteria.

        The method includes the following steps:
        1. Replaces HTML tables in the content with placeholders to facilitate chunking.
        2. Chooses an appropriate text splitter based on the file's format.
        3. Splits the content into chunks, restoring any original HTML tables after chunking.
        4. Truncates chunks that exceed the maximum token size, ensuring they fit within the limit.

        Args:
            content (str): The content to chunk.

        Yields:
            Generator[tuple[str, int], None, None]:
                A tuple containing the chunked content and the number of tokens in the chunk (i.e. the chunk size).
        """
        # Replace HTML tables with placeholders
        content, placeholders, tables = self._replace_html_tables(content)

        # Split the content according to the file format.
        # `RecursiveJsonSplitter.split_text` expects parsed JSON rather than a raw
        # string, so JSON content is deserialized before splitting.
        splitter = self._choose_splitter()
        if isinstance(splitter, RecursiveJsonSplitter):
            chunks = splitter.split_text(json.loads(content))
        else:
            chunks = splitter.split_text(content)
        # Restore the HTML tables in place of the placeholders
        chunks = self._restore_original_tables(chunks, placeholders, tables)

        for chunked_content in chunks:
            # Check the number of tokens in the chunk
            chunk_size = self.token_estimator.estimate_tokens(chunked_content)
            if chunk_size > self.max_chunk_size:
                logging.warning(
                    f"[LangChainChunker] [_chunk_content] [{self.filename}] Truncating {chunk_size}-token chunk to fit within {self.max_chunk_size} tokens."
                )
                # If the token count exceeds `max_chunk_size`, truncate the chunk until it fits
                chunked_content = self._truncate_chunk(chunked_content)

            yield chunked_content, chunk_size

    def _replace_html_tables(self, content: str) -> tuple[str, list[str], list[str]]:
        """
        Replaces HTML tables in the content with placeholders.

        Args:
            content (str): The document content.

        Returns:
            tuple[str, list[str], list[str]]: The content with placeholders, the list of placeholders,
                and the list of original tables.
        """
        table_pattern = r"(<table[\s\S]*?</table>)"
        tables = re.findall(table_pattern, content, re.IGNORECASE)
        placeholders = [f"__TABLE_{i}__" for i in range(len(tables))]
        for placeholder, table in zip(placeholders, tables):
            # Replace each table in the content with its placeholder
            content = content.replace(table, placeholder)
        return content, placeholders, tables

    def _restore_original_tables(
        self, chunks: list[str], placeholders: list[str], tables: list[str]
    ) -> list[str]:
        """
        Restores original tables in the chunks from placeholders.

        Args:
            chunks (list[str]): The list of text chunks.
            placeholders (list[str]): The list of table placeholders.
            tables (list[str]): The list of original tables.

        Returns:
            list[str]: The list of chunks with original tables restored.
        """
        for placeholder, table in zip(placeholders, tables):
            chunks = [chunk.replace(placeholder, table) for chunk in chunks]
        return chunks

    def _choose_splitter(self) -> TextSplitter | RecursiveJsonSplitter:
        """
        Chooses the appropriate splitter based on the file format.

        Returns:
            TextSplitter | RecursiveJsonSplitter:
                The splitter to use for chunking, a subclass of `TextSplitter`.
                For JSON files, a `RecursiveJsonSplitter` is returned instead.
        """
        file_format = self.supported_formats[
            self.extension
        ]  # format support is validated in `get_chunks`
        if file_format == "markdown":
            return MarkdownTextSplitter.from_tiktoken_encoder(
                chunk_size=self.max_chunk_size, chunk_overlap=self.token_overlap
            )
        elif file_format == "python":
            return PythonCodeTextSplitter.from_tiktoken_encoder(
                chunk_size=self.max_chunk_size, chunk_overlap=self.token_overlap
            )
        elif file_format == "json":
            return RecursiveJsonSplitter(
                # As the JSON splitter doesn't have a token estimator, we multiply
                # max_chunk_size by 4 (the average number of characters per token)
                max_chunk_size=self.max_chunk_size * 4
            )
        else:
            # For generic text, prefer splitting on sentence endings and word breaks
            sentence_endings = [".", "!", "?"]
            word_breaks = [" ", "\n", "\t"]
            return RecursiveCharacterTextSplitter.from_tiktoken_encoder(
                separators=sentence_endings + word_breaks,
                chunk_size=self.max_chunk_size,
                chunk_overlap=self.token_overlap,
            )
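
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only). LangChainChunker is normally driven by the
# ingestion pipeline through `BaseChunker`; the exact keys expected in `data`
# (the "fileName" and "documentBytes" names below) and the env-var values are
# assumptions for illustration, not something this module defines.
#
#   os.environ["NUM_TOKENS"] = "2048"      # max chunk size in tokens
#   os.environ["MIN_CHUNK_SIZE"] = "100"   # chunks below this size are skipped
#
#   data = {
#       "fileName": "example.md",
#       "documentBytes": b"# Title\n\nSome markdown content...",
#   }
#   chunker = LangChainChunker(data)
#   for chunk in chunker.get_chunks():
#       print(chunk)
# ---------------------------------------------------------------------------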