Skip to content

Commit 9f4cb0b

Browse files
feat: add fuzzy deduplication post-processing
1 parent b956ef6 commit 9f4cb0b

8 files changed

Lines changed: 262 additions & 7 deletions

File tree

README.md

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,29 @@ Merge output behavior with multiple datasets:
173173
- Default (`run` with `execution_params.merge: true`, or `merge` without `--output-root`): each dataset is merged to its own `<dataset.output_dir>/merged`.
174174
- Shared root (`merge --output-root ...`): one merged subdirectory is created per dataset under the root.
175175

176+
### Fuzzy deduplication (optional)
177+
178+
After merging, MMIRAGE can drop near-duplicate rows using character n-gram MinHash + LSH. This is CPU-only and uses the lightweight `datasketch` package.
179+
180+
Install the optional extra:
181+
182+
```bash
183+
pip install -e '.[dedup]'
184+
```
185+
186+
Enable in your YAML config:
187+
188+
```yaml
189+
deduplication_params:
190+
enabled: true
191+
text_field: text
192+
threshold: 0.85 # Jaccard similarity threshold
193+
num_perm: 128 # MinHash signature size
194+
shingle_size: 5 # character n-gram size
195+
```
196+
197+
Dedup runs as part of `mmirage merge --config <cfg>` and as part of `mmirage run` when `execution_params.merge: true`. With `enabled: false` (default) the dedup module is not imported and there is no overhead.
198+
176199
### Multimodal: Processing images with VLMs
177200

178201
MMIRAGE supports multimodal processing with vision-language models:

configs/config_comprehensive.yaml

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,30 @@ execution_params:
185185
settle_time_seconds: 60
186186

187187

188+
# ============================================================================
189+
# DEDUPLICATION PARAMETERS (optional)
190+
# ============================================================================
191+
# Optional fuzzy deduplication applied after merging shards.
192+
# Uses character n-gram MinHash + LSH (via the `datasketch` package).
193+
# Install with: pip install -e '.[dedup]'
194+
195+
deduplication_params:
196+
# Set to true to enable fuzzy dedup; default false (no overhead).
197+
enabled: false
198+
199+
# Column name to deduplicate on
200+
text_field: text
201+
202+
# Jaccard similarity threshold above which rows are duplicates (0.0–1.0)
203+
threshold: 0.85
204+
205+
# Number of MinHash permutations (signature size)
206+
num_perm: 128
207+
208+
# Character n-gram size for shingling
209+
shingle_size: 5
210+
211+
188212
# ============================================================================
189213
# USAGE EXAMPLES
190214
# ============================================================================

pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,9 @@ dev = [
4848
"ipykernel",
4949
"pytest",
5050
]
51+
dedup = [
52+
"datasketch>=1.6.0",
53+
]
5154

5255
[project.scripts]
5356
mmirage = "mmirage.cli:main"

src/mmirage/config/config.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,9 @@ class ExecutionParams:
6666
def __post_init__(self):
6767
"""Validate execution parameters."""
6868
if self.mode not in ("local", "slurm"):
69-
raise ValueError(f"Invalid execution mode: {self.mode!r}. Must be 'local' or 'slurm'.")
69+
raise ValueError(
70+
f"Invalid execution mode: {self.mode!r}. Must be 'local' or 'slurm'."
71+
)
7072
if self.mode == "slurm" and not self.account:
7173
raise ValueError("account is required when mode='slurm'")
7274
if self.max_retries < 0:
@@ -97,6 +99,25 @@ class ProcessingParams:
9799
remove_columns: bool = False
98100

99101

102+
@dataclass
103+
class DeduplicationParams:
104+
"""Configuration for fuzzy deduplication post-processing.
105+
106+
Attributes:
107+
enabled: Whether deduplication is enabled. Defaults to False.
108+
text_field: Column name containing text to deduplicate.
109+
threshold: Jaccard similarity threshold above which rows are duplicates.
110+
num_perm: Number of MinHash permutations (signature size).
111+
shingle_size: Character n-gram size for shingling.
112+
"""
113+
114+
enabled: bool = False
115+
text_field: str = "text"
116+
threshold: float = 0.85
117+
num_perm: int = 128
118+
shingle_size: int = 5
119+
120+
100121
@dataclass
101122
class MMirageConfig:
102123
"""Main configuration class for MMIRAGE pipeline.
@@ -110,9 +131,13 @@ class MMirageConfig:
110131
loading_params: Parameters for loading input datasets.
111132
processing_params: Parameters for processing dataset samples.
112133
execution_params: Parameters for executing the pipeline (local/SLURM).
134+
deduplication_params: Parameters for post-merge fuzzy deduplication.
113135
"""
114136

115137
processors: List[BaseProcessorConfig]
116138
loading_params: LoadingParams
117139
processing_params: ProcessingParams
118140
execution_params: ExecutionParams = field(default_factory=ExecutionParams)
141+
deduplication_params: DeduplicationParams = field(
142+
default_factory=DeduplicationParams
143+
)
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
"""Post-processing modules for MMIRAGE pipeline."""
Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
"""Fuzzy deduplication for MMIRAGE datasets.
2+
3+
Uses character n-gram MinHash + LSH to identify near-duplicate text samples
4+
and drop them in a streaming "first-seen wins" pass.
5+
"""
6+
7+
from __future__ import annotations
8+
9+
import logging
10+
from typing import Iterable, List, Set
11+
12+
from datasets import Dataset
13+
14+
from mmirage.config.config import DeduplicationParams
15+
16+
logger = logging.getLogger(__name__)
17+
18+
19+
def _check_dependencies() -> None:
20+
try:
21+
import datasketch # noqa: F401
22+
except ImportError as e:
23+
raise ImportError(
24+
"Deduplication requires `datasketch`. "
25+
"Install with: pip install 'mmirage[dedup]'"
26+
) from e
27+
28+
29+
def _shingles(text: str, k: int) -> Set[bytes]:
30+
text = " ".join(text.lower().split())
31+
if len(text) < k:
32+
return {text.encode("utf-8")}
33+
return {text[i : i + k].encode("utf-8") for i in range(len(text) - k + 1)}
34+
35+
36+
def deduplicate(dataset: Dataset, params: DeduplicationParams) -> Dataset:
37+
"""Remove near-duplicate samples from a dataset using char-ngram MinHash + LSH.
38+
39+
Algorithm:
40+
1. For each row, build the set of character n-grams of size `shingle_size`.
41+
2. Compute a MinHash signature with `num_perm` permutations.
42+
3. Query an LSH index built so far. If any near-duplicate is already
43+
indexed (Jaccard similarity above `threshold`), drop this row.
44+
4. Otherwise, insert the signature and keep the row.
45+
46+
Args:
47+
dataset: HuggingFace Dataset to deduplicate.
48+
params: Deduplication configuration parameters.
49+
50+
Returns:
51+
Filtered Dataset with near-duplicates removed.
52+
"""
53+
_check_dependencies()
54+
from datasketch import MinHash, MinHashLSH
55+
56+
n = len(dataset)
57+
if n <= 1:
58+
logger.info("Dataset has %d row(s), skipping deduplication.", n)
59+
return dataset
60+
61+
if params.text_field not in dataset.column_names:
62+
raise ValueError(
63+
f"Text field {params.text_field!r} not in dataset columns: "
64+
f"{dataset.column_names}"
65+
)
66+
67+
lsh = MinHashLSH(threshold=params.threshold, num_perm=params.num_perm)
68+
keep: List[int] = []
69+
texts: Iterable = dataset[params.text_field]
70+
71+
for i, raw in enumerate(texts):
72+
text = raw if isinstance(raw, str) else str(raw)
73+
m = MinHash(num_perm=params.num_perm)
74+
for s in _shingles(text, params.shingle_size):
75+
m.update(s)
76+
if not lsh.query(m):
77+
lsh.insert(str(i), m)
78+
keep.append(i)
79+
80+
n_removed = n - len(keep)
81+
logger.info(
82+
"Fuzzy dedup: %d → %d rows (%d duplicates removed).",
83+
n,
84+
len(keep),
85+
n_removed,
86+
)
87+
88+
return dataset.select(keep)

src/mmirage/merge_shards.py

Lines changed: 49 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
from datasets import Dataset, DatasetDict, concatenate_datasets, load_from_disk
99

10-
from mmirage.config.config import MMirageConfig
10+
from mmirage.config.config import DeduplicationParams, MMirageConfig
1111
from mmirage.core.loader.base import DatasetLike
1212
from mmirage.shard_utils import (
1313
_count_rows,
@@ -52,6 +52,17 @@ def _merge_datasetdict(shard_dsets: List[DatasetDict]) -> DatasetDict:
5252
return DatasetDict(merged)
5353

5454

55+
def _apply_dedup(ds: DatasetLike, params: DeduplicationParams) -> DatasetLike:
56+
"""Apply fuzzy deduplication to a Dataset or each split of a DatasetDict."""
57+
from mmirage.core.postprocess.fuzzy_dedup import deduplicate
58+
59+
if isinstance(ds, DatasetDict):
60+
return DatasetDict(
61+
{split: deduplicate(split_ds, params) for split, split_ds in ds.items()}
62+
)
63+
return deduplicate(ds, params)
64+
65+
5566
def _merge_shards(shard_dsets: List[DatasetLike]) -> DatasetLike:
5667
"""Merge shard datasets into a single dataset."""
5768
if not shard_dsets:
@@ -67,12 +78,17 @@ def _merge_shards(shard_dsets: List[DatasetLike]) -> DatasetLike:
6778
)
6879

6980

70-
def merge_dataset_dir(dataset_dir: str, output_dir: str) -> MergeReport:
81+
def merge_dataset_dir(
82+
dataset_dir: str,
83+
output_dir: str,
84+
dedup_params: Optional[DeduplicationParams] = None,
85+
) -> MergeReport:
7186
"""Merge one dataset directory containing shard_* folders.
7287
7388
Args:
7489
dataset_dir: Input directory containing shard_* folders.
7590
output_dir: Destination directory for merged dataset.
91+
dedup_params: Optional fuzzy dedup config; applied before saving when enabled.
7692
7793
Returns:
7894
MergeReport with summary details.
@@ -118,6 +134,16 @@ def merge_dataset_dir(dataset_dir: str, output_dir: str) -> MergeReport:
118134
)
119135

120136
ds_merged = _merge_shards(shard_dsets)
137+
138+
if dedup_params is not None and dedup_params.enabled:
139+
rows_before = _count_rows(ds_merged)
140+
ds_merged = _apply_dedup(ds_merged, dedup_params)
141+
rows_after = _count_rows(ds_merged)
142+
logger.info(
143+
f"Fuzzy dedup: {rows_before}{rows_after} rows "
144+
f"({rows_before - rows_after} duplicates removed)."
145+
)
146+
121147
merged_rows = _count_rows(ds_merged)
122148

123149
_save_dataset_atomic(ds_merged, normalized_output_dir)
@@ -134,7 +160,11 @@ def merge_dataset_dir(dataset_dir: str, output_dir: str) -> MergeReport:
134160
)
135161

136162

137-
def merge_input_dir(input_dir: str, output_dir: str) -> List[MergeReport]:
163+
def merge_input_dir(
164+
input_dir: str,
165+
output_dir: str,
166+
dedup_params: Optional[DeduplicationParams] = None,
167+
) -> List[MergeReport]:
138168
"""Merge all shard datasets found under an input directory.
139169
140170
The input can be either:
@@ -167,7 +197,7 @@ def merge_input_dir(input_dir: str, output_dir: str) -> List[MergeReport]:
167197
dataset_name = os.path.basename(dataset_dir)
168198
ds_output_dir = os.path.join(output_dir, dataset_name)
169199

170-
reports.append(merge_dataset_dir(dataset_dir, ds_output_dir))
200+
reports.append(merge_dataset_dir(dataset_dir, ds_output_dir, dedup_params))
171201

172202
return reports
173203

@@ -210,7 +240,9 @@ def merge_from_config(
210240
folder_name = f"{dataset_name}_{index}"
211241
output_dir = os.path.join(output_root, folder_name)
212242

213-
reports.append(merge_dataset_dir(dataset_dir, output_dir))
243+
reports.append(
244+
merge_dataset_dir(dataset_dir, output_dir, cfg.deduplication_params)
245+
)
214246

215247
return reports
216248

@@ -232,6 +264,11 @@ def main():
232264
required=True,
233265
help="Directory to write merged datasets into.",
234266
)
267+
ap.add_argument(
268+
"--config",
269+
default=None,
270+
help="Optional MMIRAGE YAML config; enables fuzzy dedup if configured.",
271+
)
235272
ap.add_argument(
236273
"--log-level",
237274
default="INFO",
@@ -241,7 +278,13 @@ def main():
241278
args = ap.parse_args()
242279
_configure_logging(args.log_level)
243280

244-
reports = merge_input_dir(args.input_dir, args.output_dir)
281+
dedup_params: Optional[DeduplicationParams] = None
282+
if args.config:
283+
from mmirage.config.utils import load_mmirage_config
284+
285+
dedup_params = load_mmirage_config(args.config).deduplication_params
286+
287+
reports = merge_input_dir(args.input_dir, args.output_dir, dedup_params)
245288
for report in reports:
246289
skipped_total = report.skipped_invalid_dirs + report.skipped_zero_rows
247290
logger.info(

tests/test_dedup.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
"""Smoke test for fuzzy deduplication on TinyStories."""
2+
3+
import argparse
4+
import logging
5+
6+
from datasets import load_dataset
7+
8+
from mmirage.config.config import DeduplicationParams
9+
from mmirage.core.postprocess.fuzzy_dedup import deduplicate
10+
11+
12+
def main():
13+
ap = argparse.ArgumentParser()
14+
ap.add_argument(
15+
"--limit",
16+
type=int,
17+
default=None,
18+
help="Optional row limit (default: full dataset).",
19+
)
20+
ap.add_argument("--threshold", type=float, default=0.85)
21+
ap.add_argument("--num-perm", type=int, default=128)
22+
ap.add_argument("--shingle-size", type=int, default=5)
23+
args = ap.parse_args()
24+
25+
logging.basicConfig(level=logging.INFO)
26+
27+
ds = load_dataset("roneneldan/TinyStories", split="train")
28+
if args.limit is not None:
29+
ds = ds.select(range(min(args.limit, len(ds))))
30+
print(f"Loaded {len(ds):,} rows")
31+
32+
params = DeduplicationParams(
33+
enabled=True,
34+
text_field="text",
35+
threshold=args.threshold,
36+
num_perm=args.num_perm,
37+
shingle_size=args.shingle_size,
38+
)
39+
deduped = deduplicate(ds, params)
40+
removed = len(ds) - len(deduped)
41+
print(
42+
f"{len(ds):,}{len(deduped):,} "
43+
f"(removed {removed:,}, {removed / len(ds) * 100:.2f}%)"
44+
)
45+
46+
47+
if __name__ == "__main__":
48+
main()

0 commit comments

Comments
 (0)