Skip to content
This repository was archived by the owner on Apr 30, 2026. It is now read-only.

Commit e9d3842

Browse files
committed
feat: add SimilarityFilterBlock for near-duplicate filtering
Add a new SimilarityFilterBlock that removes near-duplicate rows from a Dataset based on text similarity using difflib.SequenceMatcher. Supports configurable similarity threshold and optional group_by column to scope deduplication within groups. Zero new dependencies.
1 parent faa84ef commit e9d3842

3 files changed

Lines changed: 208 additions & 0 deletions

File tree

src/instructlab/sdg/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
"SamplePopulatorBlock",
2727
"SelectorBlock",
2828
"SetToMajorityValueBlock",
29+
"SimilarityFilterBlock",
2930
"FULL_PIPELINES_PACKAGE",
3031
"SIMPLE_PIPELINES_PACKAGE",
3132
"LLAMA_PIPELINES_PKG",
@@ -37,6 +38,7 @@
3738
from .blocks.block import Block, BlockConfigParserError
3839
from .blocks.filterblock import FilterByValueBlock, FilterByValueBlockError
3940
from .blocks.iterblock import IterBlock
41+
from .blocks.similarityfilterblock import SimilarityFilterBlock
4042
from .blocks.llmblock import (
4143
ConditionalLLMBlock,
4244
LLMBlock,
Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
3+
# Standard
4+
from difflib import SequenceMatcher
5+
import logging
6+
7+
# Third Party
8+
import pandas as pd
9+
from datasets import Dataset
10+
11+
# Local
12+
from ..registry import BlockRegistry
13+
from ..utils.pandas import dataset_from_pandas_dataframe
14+
from .block import Block
15+
16+
logger = logging.getLogger(__name__)
17+
18+
19+
def _similarity(a: str, b: str) -> float:
20+
"""Compute similarity ratio between two strings."""
21+
if not a or not b:
22+
return 0.0
23+
return SequenceMatcher(None, a, b).ratio()
24+
25+
26+
def _deduplicate_group(group, col, threshold):
    """Remove near-duplicate rows within a single group.

    Walks the group's rows in order; a row survives only if its text in
    *col* is not similar (ratio strictly above *threshold*) to the text of
    any previously surviving row.

    Returns a list of integer indices to keep.
    """
    survivors = []
    survivor_texts = []

    for label, row in group.iterrows():
        candidate = str(row[col])
        for prior in survivor_texts:
            if _similarity(candidate, prior) > threshold:
                # Too close to an earlier survivor — drop this row.
                break
        else:
            survivors.append(label)
            survivor_texts.append(candidate)

    return survivors
44+
45+
46+
# This is part of the public API.
@BlockRegistry.register("SimilarityFilterBlock")
class SimilarityFilterBlock(Block):
    """Drop near-duplicate rows from a Dataset based on text similarity.

    Similarity is computed with difflib's SequenceMatcher; rows whose text
    is closer than ``threshold`` to a previously kept row are removed.
    """

    def __init__(
        self,
        ctx,
        pipe,
        block_name,
        filter_column,
        threshold=0.85,
        group_by=None,
    ) -> None:
        """
        Initializes a new instance of the SimilarityFilterBlock class.

        Parameters:
        - ctx (PipelineContext): A PipelineContext object containing runtime parameters.
        - pipe (Pipeline): The Pipeline containing this block in its chain.
        - block_name (str): An identifier for this block.
        - filter_column (str): The column containing text to compare for similarity.
        - threshold (float): Similarity ratio (0.0 to 1.0). Rows with similarity
          above this value to any previously kept row are dropped. Default 0.85.
        - group_by (str, optional): Column to group by before deduplication.
          If set, similarity is only compared within each group. Default None.

        Raises:
        - ValueError: If threshold is outside the [0.0, 1.0] range.
        """
        super().__init__(ctx, pipe, block_name)
        # Fail fast on a nonsensical threshold instead of silently keeping
        # everything (>1.0) or dropping all but one row (<0.0).
        if not 0.0 <= threshold <= 1.0:
            raise ValueError(
                f"threshold must be between 0.0 and 1.0, got {threshold}"
            )
        self.filter_column = filter_column
        self.threshold = threshold
        self.group_by = group_by

    def generate(self, samples) -> Dataset:
        """Return *samples* with near-duplicate rows removed.

        Deduplication is scoped to groups of ``group_by`` when that column
        exists; otherwise the whole dataset is compared pairwise in order.
        """
        if len(samples) == 0:
            return samples

        df = samples.to_pandas()
        original_len = len(df)

        if self.group_by and self.group_by not in df.columns:
            # Previously this fell back to global deduplication silently;
            # surface the misconfiguration so pipeline authors can spot it.
            logger.warning(
                "SimilarityFilterBlock '%s': group_by column '%s' not found; "
                "deduplicating across the whole dataset",
                self.block_name,
                self.group_by,
            )

        if self.group_by and self.group_by in df.columns:
            deduped = []
            for _, group in df.groupby(self.group_by):
                kept = _deduplicate_group(group, self.filter_column, self.threshold)
                deduped.append(group.loc[kept])
            result = (
                pd.concat(deduped, ignore_index=True)
                if deduped
                else df.iloc[:0]
            )
        else:
            kept = _deduplicate_group(df, self.filter_column, self.threshold)
            # reset_index keeps this branch consistent with the grouped one
            # (which uses ignore_index=True) and avoids a stray index column
            # when the frame is converted back into a Dataset.
            result = df.loc[kept].reset_index(drop=True)

        removed = original_len - len(result)
        if removed > 0:
            logger.info(
                "SimilarityFilterBlock '%s': removed %d near-duplicates "
                "(threshold=%.2f), %d → %d rows",
                self.block_name,
                removed,
                self.threshold,
                original_len,
                len(result),
            )

        return dataset_from_pandas_dataframe(result)
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
3+
# Standard
4+
from unittest.mock import MagicMock
5+
import unittest
6+
7+
# Third Party
8+
from datasets import Dataset
9+
10+
# First Party
11+
from instructlab.sdg import SimilarityFilterBlock
12+
13+
14+
class TestSimilarityFilterBlock(unittest.TestCase):
    """Behavioral tests for SimilarityFilterBlock's near-duplicate filtering."""

    def setUp(self):
        self.ctx = MagicMock()
        self.ctx.dataset_num_procs = 1
        self.pipe = MagicMock()

    def _make_block(self, filter_column="text", threshold=0.85, group_by=None):
        # Shared factory so each test only states the knobs it cares about.
        return SimilarityFilterBlock(
            self.ctx,
            self.pipe,
            "test_similarity_filter",
            filter_column=filter_column,
            threshold=threshold,
            group_by=group_by,
        )

    def test_keeps_unique_rows(self):
        # Three clearly distinct strings must all survive filtering.
        texts = ["alpha bravo charlie", "delta echo foxtrot", "golf hotel india"]
        filtered = self._make_block().generate(Dataset.from_dict({"text": texts}))
        self.assertEqual(3, len(filtered))

    def test_removes_exact_duplicates(self):
        # Identical rows collapse down to a single survivor.
        dataset = Dataset.from_dict(
            {"text": ["hello world", "hello world", "hello world"]}
        )
        filtered = self._make_block(threshold=0.8).generate(dataset)
        self.assertEqual(1, len(filtered))

    def test_removes_near_duplicates(self):
        rows = [
            "What is photosynthesis and how does it work?",
            "What is photosynthesis and how does it function?",
            "Explain the process of sourdough bread making.",
        ]
        filtered = self._make_block(threshold=0.7).generate(
            Dataset.from_dict({"text": rows})
        )
        # The two photosynthesis questions merge; the bread one stays.
        self.assertEqual(2, len(filtered))

    def test_group_by_isolates_groups(self):
        # Identical text in *different* groups is never compared, so both stay.
        dataset = Dataset.from_dict(
            {
                "text": ["same text here", "same text here"],
                "doc_id": ["doc_a", "doc_b"],
            }
        )
        filtered = self._make_block(threshold=0.8, group_by="doc_id").generate(dataset)
        self.assertEqual(2, len(filtered))

    def test_group_by_deduplicates_within_group(self):
        # Identical text within the *same* group collapses to one row.
        dataset = Dataset.from_dict(
            {
                "text": ["same text here", "same text here"],
                "doc_id": ["doc_a", "doc_a"],
            }
        )
        filtered = self._make_block(threshold=0.8, group_by="doc_id").generate(dataset)
        self.assertEqual(1, len(filtered))

    def test_empty_dataset(self):
        # An empty dataset passes through untouched.
        filtered = self._make_block().generate(Dataset.from_dict({"text": []}))
        self.assertEqual(0, len(filtered))

    def test_low_threshold_more_aggressive(self):
        rows = [
            "What is photosynthesis?",
            "What is the process of photosynthesis?",
            "Explain sourdough bread.",
        ]
        dataset = Dataset.from_dict({"text": rows})
        strict = self._make_block(threshold=0.5).generate(dataset)
        lenient = self._make_block(threshold=0.95).generate(dataset)
        # A lower threshold can only remove as many or more rows.
        self.assertLessEqual(len(strict), len(lenient))

0 commit comments

Comments
 (0)