@@ -2,10 +2,9 @@
import string
from typing import Any, Callable, Dict, List, Optional, Sequence

from llama_index.core.node_parser.interface import NodeParser
from llama_index.core.bridge.pydantic import Field
from llama_index.core.base.embeddings.base import BaseEmbedding
from llama_index.core.bridge.pydantic import Field, SerializeAsAny
from llama_index.core.callbacks.base import CallbackManager
from llama_index.core.node_parser import NodeParser
from llama_index.core.node_parser.interface import NodeParser
from llama_index.core.node_parser.node_utils import (
build_nodes_from_splits,
@@ -64,9 +63,11 @@ class SemanticDoubleMergingSplitterNodeParser(NodeParser):
Semantic double merging text splitter.

Splits a document into Nodes, with each node being a group of semantically related sentences.
Supports either a Spacy model (language-specific) or an embedding model (any language, e.g. a Hugging Face embedding).

Args:
language_config (LanguageConfig): chooses language and spacy language model to be used
language_config (LanguageConfig): language and Spacy model to use with the Spacy backend (ignored if embed_model is set)
embed_model (Optional[BaseEmbedding]): when set, this embedding model is used for similarity instead of Spacy (works with any language)
initial_threshold (float): sets threshold for initializing new chunk
appending_threshold (float): sets threshold for appending new sentences to chunk
merging_threshold (float): sets threshold for merging whole chunks
@@ -79,7 +80,12 @@ class SemanticDoubleMergingSplitterNodeParser(NodeParser):

language_config: LanguageConfig = Field(
default=LanguageConfig(),
description="Config that selects language and spacy model for chunking",
description="Config that selects language and spacy model for chunking (used only when embed_model is None)",
)

embed_model: Optional[SerializeAsAny[BaseEmbedding]] = Field(
default=None,
description="When set, use this embedding model for similarity instead of Spacy (enables any language).",
)

initial_threshold: float = Field(
@@ -141,7 +147,8 @@ def class_name(cls) -> str:
@classmethod
def from_defaults(
cls,
language_config: Optional[LanguageConfig] = LanguageConfig(),
language_config: Optional[LanguageConfig] = None,
embed_model: Optional[BaseEmbedding] = None,
initial_threshold: Optional[float] = 0.6,
appending_threshold: Optional[float] = 0.8,
merging_threshold: Optional[float] = 0.8,
@@ -156,13 +163,13 @@ def from_defaults(
id_func: Optional[Callable[[int, Document], str]] = None,
) -> "SemanticDoubleMergingSplitterNodeParser":
callback_manager = callback_manager or CallbackManager([])

sentence_splitter = sentence_splitter or split_by_sentence_tokenizer()

id_func = id_func or default_id_func

if language_config is None:
language_config = LanguageConfig()
return cls(
language_config=language_config,
embed_model=embed_model,
initial_threshold=initial_threshold,
appending_threshold=appending_threshold,
merging_threshold=merging_threshold,
@@ -177,19 +184,30 @@ def from_defaults(
id_func=id_func,
)

def _similarity(self, text_a: str, text_b: str) -> float:
if self.embed_model is not None:
embeddings = self.embed_model.get_text_embedding_batch([text_a, text_b])
return self.embed_model.similarity(embeddings[0], embeddings[1])
if self.language_config.nlp is None:
self.language_config.load_model()
assert self.language_config.nlp is not None
clean_a = self._clean_text_advanced(text_a)
clean_b = self._clean_text_advanced(text_b)
return self.language_config.nlp(clean_a).similarity(
self.language_config.nlp(clean_b)
)

def _parse_nodes(
self,
nodes: Sequence[BaseNode],
show_progress: bool = False,
**kwargs: Any,
) -> List[BaseNode]:
"""Parse document into nodes."""
# Load model
self.language_config.load_model()

if self.embed_model is None:
self.language_config.load_model()
all_nodes: List[BaseNode] = []
nodes_with_progress = get_tqdm_iterable(nodes, show_progress, "Parsing nodes")

for node in nodes_with_progress:
nodes = self.build_semantic_nodes_from_nodes([node])
all_nodes.extend(nodes)
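
For reference, the embedding branch of _similarity above amounts to batch-embedding the two texts and comparing the resulting vectors with BaseEmbedding.similarity, which defaults to cosine similarity. A minimal standalone sketch of that call sequence, using the MockEmbedding from the tests below as a stand-in for a real embedding model:

from llama_index.core.embeddings.mock_embed_model import MockEmbedding

embed_model = MockEmbedding(embed_dim=4)

# Embed both texts in a single batch call, then compare the two vectors.
# BaseEmbedding.similarity defaults to cosine similarity.
emb_a, emb_b = embed_model.get_text_embedding_batch(["first text", "second text"])
score = embed_model.similarity(emb_a, emb_b)
print(score)  # MockEmbedding maps every text to the same vector, so this should print 1.0
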
@@ -238,29 +256,17 @@ def build_semantic_nodes_from_nodes(

def _create_initial_chunks(self, sentences: List[str]) -> List[str]:
initial_chunks: List[str] = []
chunk = sentences[0] # ""
chunk = sentences[0]
new = True

assert self.language_config.nlp is not None

for sentence in sentences[1:]:
if new:
# check if 2 sentences got anything in common

if (
self.language_config.nlp(
self._clean_text_advanced(chunk)
).similarity(
self.language_config.nlp(self._clean_text_advanced(sentence))
)
< self.initial_threshold
self._similarity(chunk, sentence) < self.initial_threshold
and len(chunk) + len(sentence) + 1 <= self.max_chunk_size
):
# if not then leave first sentence as separate chunk
initial_chunks.append(chunk)
chunk = sentence
continue

chunk_sentences = [chunk]
if len(chunk) + len(sentence) + 1 <= self.max_chunk_size:
chunk_sentences.append(sentence)
@@ -272,70 +278,39 @@ def _create_initial_chunks(self, sentences: List[str]) -> List[str]:
chunk = sentence
continue
last_sentences = self.merging_separator.join(chunk_sentences[-2:])
# new = False

elif (
self.language_config.nlp(
self._clean_text_advanced(last_sentences)
).similarity(
self.language_config.nlp(self._clean_text_advanced(sentence))
)
> self.appending_threshold
self._similarity(last_sentences, sentence) > self.appending_threshold
and len(chunk) + len(sentence) + 1 <= self.max_chunk_size
):
# elif nlp(last_sentences).similarity(nlp(sentence)) > self.threshold:
chunk_sentences.append(sentence)
last_sentences = self.merging_separator.join(chunk_sentences[-2:])
chunk += self.merging_separator + sentence
else:
initial_chunks.append(chunk)
chunk = sentence # ""
chunk = sentence
new = True
initial_chunks.append(chunk)

return initial_chunks

def _merge_initial_chunks(self, initial_chunks: List[str]) -> List[str]:
chunks: List[str] = []
skip = 0
current = initial_chunks[0]

assert self.language_config.nlp is not None

# TODO avoid connecting 1st chunk with 3rd if 2nd one is above some value, or if its length is above some value

for i in range(1, len(initial_chunks)):
# avoid connecting same chunk multiple times
if skip > 0:
skip -= 1
continue

current_nlp = self.language_config.nlp(self._clean_text_advanced(current))

if len(current) >= self.max_chunk_size:
chunks.append(current)
current = initial_chunks[i]

# check if 1st and 2nd chunk should be connected
elif (
current_nlp.similarity(
self.language_config.nlp(
self._clean_text_advanced(initial_chunks[i])
)
)
> self.merging_threshold
self._similarity(current, initial_chunks[i]) > self.merging_threshold
and len(current) + len(initial_chunks[i]) + 1 <= self.max_chunk_size
):
current += self.merging_separator + initial_chunks[i]

# check if 1st and 3rd chunk are similar, if yes then merge 1st, 2nd, 3rd together
elif (
i <= len(initial_chunks) - 2
and current_nlp.similarity(
self.language_config.nlp(
self._clean_text_advanced(initial_chunks[i + 1])
)
)
and self._similarity(current, initial_chunks[i + 1])
> self.merging_threshold
and len(current)
+ len(initial_chunks[i])
@@ -350,15 +325,9 @@ def _merge_initial_chunks(self, initial_chunks: List[str]) -> List[str]:
+ initial_chunks[i + 1]
)
skip = 1

# check if 1st and 4th chunk are similar, if yes then merge 1st, 2nd, 3rd and 4th together
elif (
i < len(initial_chunks) - 2
and current_nlp.similarity(
self.language_config.nlp(
self._clean_text_advanced(initial_chunks[i + 2])
)
)
and self._similarity(current, initial_chunks[i + 2])
> self.merging_threshold
and self.merging_range == 2
and len(current)
@@ -377,11 +346,9 @@ def _merge_initial_chunks(self, initial_chunks: List[str]) -> List[str]:
+ initial_chunks[i + 2]
)
skip = 2

else:
chunks.append(current)
current = initial_chunks[i]

chunks.append(current)
return chunks
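
Taken together, these changes let callers bypass Spacy entirely by passing an embedding model to from_defaults. A hedged usage sketch (MockEmbedding here mirrors the tests below; in practice any BaseEmbedding implementation, such as a Hugging Face or OpenAI embedding, could be passed instead):

from llama_index.core.schema import Document
from llama_index.core.embeddings.mock_embed_model import MockEmbedding
from llama_index.core.node_parser.text.semantic_double_merging_splitter import (
    SemanticDoubleMergingSplitterNodeParser,
)

# With embed_model set, language_config is ignored and no Spacy model is loaded.
splitter = SemanticDoubleMergingSplitterNodeParser.from_defaults(
    embed_model=MockEmbedding(embed_dim=4),
    initial_threshold=0.6,
    appending_threshold=0.8,
    merging_threshold=0.8,
    max_chunk_size=1000,
)

nodes = splitter.get_nodes_from_documents(
    [Document(text="First sentence. Second sentence. Third sentence.")]
)
print(len(nodes), [len(n.get_content()) for n in nodes])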

@@ -1,5 +1,6 @@
import pytest

from llama_index.core.embeddings.mock_embed_model import MockEmbedding
from llama_index.core.node_parser.text.semantic_double_merging_splitter import (
SemanticDoubleMergingSplitterNodeParser,
LanguageConfig,
@@ -106,3 +107,40 @@ def test_chunk_size_3() -> None:
nodes = splitter.get_nodes_from_documents([doc_same])
for node in nodes:
assert len(node.get_content()) < 500


def test_embed_model_path_returns_nodes() -> None:
"""With embed_model set, chunking uses embeddings instead of Spacy (no Spacy required)."""
embed = MockEmbedding(embed_dim=4)
splitter = SemanticDoubleMergingSplitterNodeParser.from_defaults(
embed_model=embed,
initial_threshold=0.6,
appending_threshold=0.8,
merging_threshold=0.8,
max_chunk_size=1000,
)
nodes = splitter.get_nodes_from_documents([doc])
assert len(nodes) >= 1
assert all(len(n.get_content()) > 0 for n in nodes)


def test_embed_model_similarity_in_range() -> None:
"""_similarity with embed_model returns a value in [0, 1] (cosine-like)."""
embed = MockEmbedding(embed_dim=4)
splitter = SemanticDoubleMergingSplitterNodeParser.from_defaults(
embed_model=embed,
)
sim = splitter._similarity("first sentence.", "second sentence.")
assert 0 <= sim <= 1


def test_embed_model_single_sentence_document() -> None:
"""Single-sentence document yields one node when using embed_model."""
single_doc = Document(text="Only one sentence here.")
embed = MockEmbedding(embed_dim=4)
splitter = SemanticDoubleMergingSplitterNodeParser.from_defaults(
embed_model=embed,
)
nodes = splitter.get_nodes_from_documents([single_doc])
assert len(nodes) == 1
assert nodes[0].get_content() == "Only one sentence here."