-
Notifications
You must be signed in to change notification settings - Fork 1.7k
feat: implement syntax-aware code chunking with Tree-sitter #434
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 10 commits
b34fb57
e6306c4
66f42c4
73fa53d
67a5b34
c025575
fb299fc
17243fa
78b01ce
0900aea
4215e41
a2e0b4d
40bab77
0000098
12b2a4e
b58ff0f
a1df9aa
e20482c
a216766
c51b99d
dcf795e
daa8e62
5776401
15a7e7f
76b5a33
fb914c0
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,320 @@ | ||
| from __future__ import annotations | ||
|
|
||
| from dataclasses import dataclass | ||
| import importlib | ||
| import logging | ||
| from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple | ||
|
|
||
| from adalflow.core.component import DataComponent | ||
| from adalflow.components.data_process import TextSplitter | ||
| from adalflow.core.types import Document | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
||
# Words that mark a Tree-sitter node as a definition worth chunking on its own.
# `_iter_definition_like_nodes` splits a node's `type` string on underscores
# (e.g. "function_definition" -> {"function", "definition"}) and checks each
# word against this tuple, so matching is whole-word, never substring.
_DEFINITION_TYPE_KEYWORDS = (
    "function",
    "method",
    "class",
    "interface",
    "struct",
    "enum",
    "trait",
    "impl",
    "module",
    "namespace",
    "type",
)
|
|
||
|
|
||
# Maps a lowercase file extension (without the leading dot) to the language
# name expected by tree_sitter_languages.get_parser().  Extensions absent from
# this table fall back to trying the raw extension as the language name (see
# TreeSitterCodeSplitter._get_language_name_candidates).
_EXT_TO_LANGUAGE: Dict[str, str] = {
    "py": "python",
    "js": "javascript",
    "jsx": "javascript",
    "ts": "typescript",
    "tsx": "tsx",
    "java": "java",
    "c": "c",
    "h": "c",
    "cpp": "cpp",
    "hpp": "cpp",
    "cc": "cpp",
    "cs": "c_sharp",
    "go": "go",
    "rs": "rust",
    "php": "php",
    "rb": "ruby",
    "swift": "swift",
    "kt": "kotlin",
    "kts": "kotlin",
    "scala": "scala",
    "lua": "lua",
    "sh": "bash",
    "bash": "bash",
    "html": "html",
    "css": "css",
    "json": "json",
    "yml": "yaml",
    "yaml": "yaml",
    "toml": "toml",
    "md": "markdown",
}
|
|
||
|
|
||
@dataclass(frozen=True)
class CodeSplitterConfig:
    """Immutable tuning parameters for :class:`TreeSitterCodeSplitter`."""

    # Maximum number of lines per emitted chunk before overlap splitting kicks in.
    chunk_size_lines: int = 200
    # Number of trailing lines repeated at the start of the next chunk.
    chunk_overlap_lines: int = 20
    # Chunks with fewer lines than this are skipped (or trigger the fallback).
    min_chunk_lines: int = 5
    # When False, split_document() returns the input document untouched.
    enabled: bool = True
|
|
||
|
|
||
| def _safe_import_tree_sitter() -> Optional[Callable[..., Any]]: | ||
| """Safely import and return the `get_parser` function from tree_sitter_languages.""" | ||
| module_candidates = [ | ||
| "tree_sitter_languages", # module name used by tree-sitter-languages on most installs | ||
| ] | ||
|
|
||
| for module_name in module_candidates: | ||
| try: | ||
| mod = importlib.import_module(module_name) | ||
| get_parser = getattr(mod, "get_parser", None) | ||
| if callable(get_parser): | ||
| return get_parser | ||
| except ImportError: | ||
| continue | ||
|
|
||
| return None | ||
|
|
||
|
|
||
def _iter_definition_like_nodes(root_node: Any) -> Iterable[Any]:
    """Yield named direct children of *root_node* whose node type looks like a definition."""
    keywords = set(_DEFINITION_TYPE_KEYWORDS)
    for node in getattr(root_node, "children", []):
        if not getattr(node, "is_named", False):
            continue
        # Tokenize e.g. "function_definition" -> {"function", "definition"} so
        # keyword matching never fires on substrings of a node type.
        type_words = set(getattr(node, "type", "").lower().replace("_", " ").split())
        if type_words & keywords:
            yield node
danielfrey63 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
|
|
||
| def _split_lines_with_overlap( | ||
| lines: List[str], *, chunk_size_lines: int, chunk_overlap_lines: int | ||
| ) -> List[Tuple[List[str], int]]: | ||
| if chunk_size_lines <= 0: | ||
| return [(lines, 0)] | ||
|
|
||
| overlap = max(0, min(chunk_overlap_lines, chunk_size_lines - 1)) | ||
| chunks: List[Tuple[List[str], int]] = [] | ||
| start = 0 | ||
| n = len(lines) | ||
|
|
||
| while start < n: | ||
| end = min(n, start + chunk_size_lines) | ||
| chunks.append((lines[start:end], start)) | ||
| if end >= n: | ||
| break | ||
| start = end - overlap | ||
|
|
||
| return chunks | ||
|
|
||
|
|
||
| def _slice_text_by_bytes_preencoded(text_bytes: bytes, start_byte: int, end_byte: int) -> str: | ||
| return text_bytes[start_byte:end_byte].decode("utf-8", errors="replace") | ||
|
|
||
|
|
||
| def _byte_offset_to_line_preencoded(text_bytes: bytes, byte_offset: int) -> int: | ||
| prefix = text_bytes[:max(0, byte_offset)] | ||
| return prefix.count(b"\n") + 1 | ||
|
|
||
|
|
||
class TreeSitterCodeSplitter:
    """Split source-code documents into syntax-aware chunks via Tree-sitter.

    Documents flagged with ``meta_data["is_code"]`` are parsed with a
    language-specific Tree-sitter grammar; top-level definition-like nodes
    (functions, classes, ...) become individual chunks.  Whenever parsing is
    unavailable or yields nothing usable, a plain line-based split with
    overlap is used instead.
    """

    def __init__(
        self,
        *,
        chunk_size_lines: int = 200,
        chunk_overlap_lines: int = 20,
        min_chunk_lines: int = 5,
        enabled: bool = True,
    ) -> None:
        self.config = CodeSplitterConfig(
            chunk_size_lines=chunk_size_lines,
            chunk_overlap_lines=chunk_overlap_lines,
            min_chunk_lines=min_chunk_lines,
            enabled=enabled,
        )
        # None when tree-sitter-languages is not installed; see is_available().
        self._get_parser = _safe_import_tree_sitter()

    def is_available(self) -> bool:
        """Return True when the Tree-sitter language bundle could be imported."""
        return self._get_parser is not None

    def split_document(self, doc: Document) -> List[Document]:
        """Split *doc* into chunk documents; non-code documents pass through unchanged."""
        if not self.config.enabled:
            return [doc]

        meta = getattr(doc, "meta_data", {}) or {}
        if not meta.get("is_code"):
            return [doc]

        file_type = (meta.get("type") or "").lower().lstrip(".")
        return self._split_code_text(doc.text or "", meta, file_type)

    def _get_language_name_candidates(self, file_type: str) -> List[str]:
        """Return Tree-sitter language names to try for *file_type*, best guess first."""
        mapped = _EXT_TO_LANGUAGE.get(file_type)
        candidates: List[str] = []
        if mapped:
            candidates.append(mapped)
        # Fall back to the raw extension; several grammars use it as their name.
        if file_type and file_type not in candidates:
            candidates.append(file_type)
        return candidates

    def _try_get_parser(self, file_type: str) -> Any:
        """Return a parser for *file_type*, or None when no grammar is usable."""
        if self._get_parser is None:
            return None

        for name in self._get_language_name_candidates(file_type):
            try:
                return self._get_parser(name)
            except Exception as e:
                logger.debug("Failed to get parser for language '%s': %s", name, e)
                continue
        return None

    def _split_code_text(self, text: str, meta: Dict[str, Any], file_type: str) -> List[Document]:
        """Chunk *text* by definition nodes, falling back to line splitting on any failure."""
        parser = self._try_get_parser(file_type)
        if parser is None:
            return self._fallback_line_split(text, meta)

        text_bytes = text.encode("utf-8", errors="replace")
        try:
            tree = parser.parse(text_bytes)
        except Exception:
            return self._fallback_line_split(text, meta)

        root = getattr(tree, "root_node", None)
        if root is None:
            return self._fallback_line_split(text, meta)

        nodes = list(_iter_definition_like_nodes(root))
        if not nodes:
            return self._fallback_line_split(text, meta)

        # Collect (snippet, 1-based start line) pairs for every definition node.
        pieces: List[Tuple[str, int]] = []
        for node in nodes:
            try:
                # Plain attribute access: getattr with no default raised the
                # same AttributeError anyway, so spell it directly.
                start_b = int(node.start_byte)
                end_b = int(node.end_byte)
            except (AttributeError, ValueError, TypeError) as e:
                logger.debug("Could not process a tree-sitter node for file type '%s': %s", file_type, e)
                continue
            snippet = _slice_text_by_bytes_preencoded(text_bytes, start_b, end_b)
            start_line = _byte_offset_to_line_preencoded(text_bytes, start_b)
            pieces.append((snippet, start_line))

        if not pieces:
            return self._fallback_line_split(text, meta)

        docs: List[Document] = []
        for snippet, start_line in pieces:
            snippet_lines = snippet.splitlines(True)
            # Skip trivially small definitions (e.g. one-line stubs).
            if len(snippet_lines) < self.config.min_chunk_lines:
                continue

            if len(snippet_lines) <= self.config.chunk_size_lines:
                docs.append(self._make_chunk_doc(snippet, meta, start_line))
                continue

            # NOTE(review): oversized definitions are split by raw lines here.
            # A future improvement could recurse into child definition nodes
            # (e.g. methods of a large class) to stay syntax-aware at depth.
            for sub, sub_start_idx in _split_lines_with_overlap(
                snippet_lines,
                chunk_size_lines=self.config.chunk_size_lines,
                chunk_overlap_lines=self.config.chunk_overlap_lines,
            ):
                sub_text = "".join(sub)
                docs.append(self._make_chunk_doc(sub_text, meta, start_line + sub_start_idx))

        if not docs:
            return self._fallback_line_split(text, meta)
        return self._add_chunk_metadata(docs)

    def _add_chunk_metadata(self, docs: List[Document]) -> List[Document]:
        """Stamp each chunk with its position and the total chunk count (in place)."""
        for i, d in enumerate(docs):
            d.meta_data["chunk_index"] = i
            d.meta_data["chunk_total"] = len(docs)
        return docs

    def _fallback_line_split(self, text: str, meta: Dict[str, Any]) -> List[Document]:
        """Line-based splitting with overlap, used when syntax-aware splitting fails."""
        lines = text.splitlines(True)
        docs: List[Document] = []
        for sub, start_idx in _split_lines_with_overlap(
            lines,
            chunk_size_lines=self.config.chunk_size_lines,
            chunk_overlap_lines=self.config.chunk_overlap_lines,
        ):
            sub_text = "".join(sub)
            if len(sub) < self.config.min_chunk_lines:
                continue
            start_line = 1 + start_idx
            docs.append(self._make_chunk_doc(sub_text, meta, start_line))

        if not docs:
            # Fix: route the single-document fallback through
            # _add_chunk_metadata as well, so every returned chunk carries
            # chunk_index/chunk_total consistently with the other paths.
            docs = [Document(text=text, meta_data=dict(meta))]
        return self._add_chunk_metadata(docs)

    def _make_chunk_doc(self, chunk_text: str, meta: Dict[str, Any], start_line: int) -> Document:
        """Build a chunk Document copying *meta* and recording its 1-based start line."""
        new_meta = dict(meta)
        new_meta["chunk_start_line"] = start_line
        file_path = new_meta.get("file_path")
        if file_path:
            new_meta["title"] = str(file_path)
        return Document(text=chunk_text, meta_data=new_meta)
|
|
||
|
|
||
class CodeAwareSplitter(DataComponent):
    """Route documents to a syntax-aware code splitter or a plain text splitter.

    Documents flagged with ``meta_data["is_code"]`` are handled by the
    Tree-sitter based splitter; everything else is delegated to the wrapped
    ``TextSplitter``.
    """

    def __init__(
        self,
        *,
        text_splitter: TextSplitter,
        code_splitter: TreeSitterCodeSplitter,
    ) -> None:
        super().__init__()
        self._text_splitter = text_splitter
        self._code_splitter = code_splitter

    def __call__(self, documents: Sequence[Document]) -> Sequence[Document]:
        """Split every document, dispatching on its ``is_code`` metadata flag."""
        output: List[Document] = []
        for doc in documents:
            meta = getattr(doc, "meta_data", {}) or {}
            file_path = meta.get("file_path") or meta.get("title") or "<unknown>"
            is_code = bool(meta.get("is_code"))
            logger.info("Splitting document: %s (is_code=%s)", file_path, is_code)
            if is_code:
                chunks = self._code_splitter.split_document(doc)
                logger.info("Split result: %s -> %d chunks (code)", file_path, len(chunks))
            else:
                logger.info("TextSplitter start: %s", file_path)
                chunks = list(self._text_splitter([doc]))
                logger.info("TextSplitter result: %s -> %d chunks", file_path, len(chunks))
            # Fix: extend once for both branches instead of duplicating the call.
            output.extend(chunks)
        return output

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the component configuration into a plain dictionary."""
        return {
            "text_splitter": self._text_splitter.to_dict() if hasattr(self._text_splitter, "to_dict") else None,
            "code_splitter_config": {
                "chunk_size_lines": self._code_splitter.config.chunk_size_lines,
                "chunk_overlap_lines": self._code_splitter.config.chunk_overlap_lines,
                "min_chunk_lines": self._code_splitter.config.min_chunk_lines,
                "enabled": self._code_splitter.config.enabled,
            },
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "CodeAwareSplitter":
        """Rebuild a :class:`CodeAwareSplitter` from the mapping produced by :meth:`to_dict`."""
        # Fix: TextSplitter is already imported at module level; the redundant
        # function-local re-import has been removed.
        text_splitter_data = data.get("text_splitter")
        text_splitter = TextSplitter.from_dict(text_splitter_data) if text_splitter_data else TextSplitter()
        code_config = data.get("code_splitter_config", {})
        code_splitter = TreeSitterCodeSplitter(**code_config)
        return cls(text_splitter=text_splitter, code_splitter=code_splitter)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
To improve maintainability, you can also import
asdicthere. It can be used later to simplify the serialization of theCodeSplitterConfigdataclass into a dictionary.