Skip to content

Commit 978fb79

Browse files
committed
consolidate: drop unused CLASSIFY filter and percentile machinery
CLASSIFY was only exercised by tests and a broken quickstart yaml (no CLI backs it). The two real callers — datakit_ferry and the integration test — use REMOVE_DOC and REMOVE_SPANS. Removing the dead path drops ddsketch and the percentile threshold code with it.
1 parent 73a2b3a commit 978fb79

File tree

5 files changed

+7
-288
lines changed

5 files changed

+7
-288
lines changed

lib/marin/pyproject.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@ dependencies = [
1212
"braceexpand",
1313
"cryptography>=45",
1414
"datasets<4.0.0",
15-
"ddsketch",
1615
"deepdiff",
1716
"draccus>=0.11.5",
1817
"fasteners>=0.19",

lib/marin/src/marin/processing/classification/README.md

Lines changed: 0 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -23,27 +23,5 @@ uv run python -m marin.processing.classification.decon \
2323

2424
[`consolidate.py`](./consolidate.py) consumes attribute files and filters or rewrites documents. Supported filter types:
2525

26-
- `classify`: keep or reject documents based on attribute scores
2726
- `remove_spans`: remove text spans such as duplicate paragraphs
2827
- `remove_docs`: drop whole documents when an attribute marks them as duplicates
29-
30-
Example:
31-
32-
```bash
33-
uv run python -m marin.processing.classification.consolidate \
34-
--config_path lib/marin/src/marin/processing/classification/config/quickstart_consolidate_fasttext.yaml
35-
```
36-
37-
Example `classify` filter:
38-
39-
```yaml
40-
input_path: "gs://marin-us-central2/documents/hello_world_fw/v1.0/quickstart/"
41-
output_path: "gs://marin-us-central2/documents/hello_world_fw/v1.0/quickstart_fasttext_only/"
42-
43-
filters:
44-
- type: "classify"
45-
attribute_path: "gs://marin-us-central2/attributes/hello_world_fw/v1.0/quickstart_olmo_fasttext/"
46-
name: "olmo-fasttext-quality"
47-
label: "__label__hq"
48-
lower_threshold: 0.1
49-
```

lib/marin/src/marin/processing/classification/config/quickstart_consolidate_fasttext.yaml

Lines changed: 0 additions & 9 deletions
This file was deleted.

lib/marin/src/marin/processing/classification/consolidate.py

Lines changed: 5 additions & 140 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,7 @@
44
"""
55
Consolidate takes a set of documents with corresponding attributes and writes
66
out a subset of the documents based on various filters defined with respect to
7-
the attributes. Handles three cases:
8-
- Quality filtering produces attributes (e.g., fasttext-quality) with labels
9-
(e.g., __label__hq), filter on threshold.
7+
the attributes. Handles two cases:
108
- Span removal produces attributes (e.g., duplicate_text spans). Remove text spans.
119
- Document removal via attribute produced by deduplication.
1210
@@ -19,8 +17,8 @@
1917

2018
import logging
2119
import os
22-
from collections.abc import Callable, Iterator
23-
from dataclasses import dataclass, replace
20+
from collections.abc import Callable
21+
from dataclasses import dataclass
2422
from enum import StrEnum
2523
from typing import Any
2624

@@ -34,7 +32,6 @@
3432

3533

3634
class FilterType(StrEnum):
37-
CLASSIFY = "classify"
3835
REMOVE_SPANS = "remove_spans"
3936
REMOVE_DOC = "remove_docs"
4037

@@ -55,21 +52,6 @@ class FilterConfig:
5552
name: str
5653
"""Name of attribute to use for filtering."""
5754

58-
label: str | None = None
59-
"""The label under the attribute name."""
60-
61-
lower_threshold: float | None = None
62-
"""Keep documents where the value is above this."""
63-
64-
keep_fraction: float | None = None
65-
"""Keep documents where the score is in the top percentile. Calculates the threshold from the entire dataset."""
66-
67-
upper_threshold: float | None = None
68-
"""Keep documents where the value is below this."""
69-
70-
reverse: bool = False
71-
"""Reverse the filter."""
72-
7355
attribute_filetype: str | None = None
7456
"""File extension for attribute files (e.g. 'jsonl.gz', 'vortex'). If None, uses the input filetype."""
7557

@@ -84,32 +66,6 @@ class FilterConfig:
8466
}
8567

8668

87-
def _is_valid(doc: dict, filt: FilterConfig, attributes: dict) -> bool:
88-
assert filt.type == FilterType.CLASSIFY
89-
attribute_value = attributes[filt.name]
90-
91-
# Handle nested attributes structure if a label is specified
92-
if filt.label is not None:
93-
if isinstance(attribute_value, dict) and filt.label in attribute_value:
94-
value = attribute_value[filt.label]
95-
else:
96-
raise ValueError(f"Label {filt.label} not found in attribute {filt.name} for document {doc}")
97-
else:
98-
value = attribute_value
99-
100-
# Check both lower and upper bounds if specified
101-
accepted = True
102-
if filt.lower_threshold is not None and value < filt.lower_threshold:
103-
accepted = False
104-
if filt.upper_threshold is not None and value > filt.upper_threshold:
105-
accepted = False
106-
107-
if filt.reverse:
108-
accepted = not accepted
109-
110-
return accepted
111-
112-
11369
def _remove_spans_from_doc(doc: dict, filt: FilterConfig, attributes: dict) -> dict:
11470
def _remove_spans(text: str, spans: list[list[int]]) -> str:
11571
"""Return ``text`` with ``spans`` removed.
@@ -151,93 +107,6 @@ def _make_id_extractor(corpus_type: str) -> Callable[[dict], Any]:
151107
return lambda r: extract_id(r, corpus_type)
152108

153109

154-
def _compute_percentile_threshold(
155-
attr_paths: list[str], attr_name: str, attr_label: str | None, keep_fraction: float
156-
) -> float:
157-
"""Compute percentile threshold for a single filter using DDSketch reduction.
158-
159-
Args:
160-
attr_paths: Paths to attribute files
161-
attr_name: Name of attribute to extract
162-
attr_label: Optional label within attribute (for nested dicts)
163-
keep_fraction: Fraction of documents to keep (0-1)
164-
165-
Returns:
166-
Threshold value at the (1 - keep_fraction) quantile
167-
"""
168-
from ddsketch import DDSketch
169-
170-
def local_reducer(rows: Iterator[dict], attr_name: str = attr_name, attr_label: str | None = attr_label) -> DDSketch:
171-
"""Build DDSketch from rows in a single shard."""
172-
sketch = DDSketch()
173-
for row in rows:
174-
attributes = row["attributes"]
175-
value = attributes[attr_name][attr_label] if attr_label else attributes[attr_name]
176-
sketch.add(value)
177-
return sketch
178-
179-
def global_reducer(sketches: Iterator[DDSketch]) -> DDSketch:
180-
"""Merge all shard sketches into one."""
181-
combined = DDSketch()
182-
for sketch in sketches:
183-
combined.merge(sketch)
184-
return combined
185-
186-
ctx = ZephyrContext(name="consolidate-stats")
187-
result = ctx.execute(
188-
Dataset.from_list(attr_paths)
189-
.load_file()
190-
.select("attributes")
191-
.reduce(local_reducer=local_reducer, global_reducer=global_reducer)
192-
).results
193-
194-
combined_sketch = next(iter(result))
195-
threshold = combined_sketch.get_quantile_value(1 - keep_fraction)
196-
return threshold
197-
198-
199-
def calculate_percentile_thresholds(
200-
*,
201-
input_path: str,
202-
filters: list[FilterConfig],
203-
filetype: str = "jsonl.gz",
204-
) -> list[FilterConfig]:
205-
"""Resolve ``keep_fraction`` filters to ``lower_threshold`` via percentile calculation.
206-
207-
Returns a new list of filters with percentile-based thresholds resolved.
208-
"""
209-
updated_filters = []
210-
input_paths = fsspec_glob(os.path.join(input_path, f"**/*.{filetype}"))
211-
212-
for filt in filters:
213-
# Validate threshold configuration
214-
if filt.keep_fraction is not None and filt.lower_threshold is not None:
215-
raise ValueError("Cannot specify both keep_fraction and lower_threshold. Please specify only one.")
216-
217-
# Skip if no percentile calculation needed
218-
if filt.keep_fraction is None:
219-
updated_filters.append(filt)
220-
continue
221-
222-
if not (0 < filt.keep_fraction < 1):
223-
raise ValueError("keep_fraction must be between 0 and 1")
224-
225-
# Only applies to CLASSIFY filters
226-
if filt.type != FilterType.CLASSIFY:
227-
logger.warning(f"keep_fraction only applies to CLASSIFY filters, ignoring for {filt.name}")
228-
updated_filters.append(filt)
229-
continue
230-
231-
attr_paths = _attribute_paths_for_filter(input_path, input_paths, filt, filetype)
232-
attr_paths = [p for p in attr_paths if p is not None]
233-
234-
threshold = _compute_percentile_threshold(attr_paths, filt.name, filt.label, filt.keep_fraction)
235-
logger.info(f"Calculated threshold {threshold} for {filt.name} to keep {filt.keep_fraction} of documents")
236-
updated_filters.append(replace(filt, lower_threshold=threshold, keep_fraction=None))
237-
238-
return updated_filters
239-
240-
241110
def _resolve_attribute_path(input_base: str, input_path: str, filt: FilterConfig, filetype: str) -> str | None:
242111
"""Map an input file path to its attribute file path, with glob fallback for compression suffixes."""
243112
new_extension = f".{filt.attribute_filetype}" if filt.attribute_filetype else f".{filetype}"
@@ -288,8 +157,6 @@ def combine(left: dict, right: dict | None) -> dict | None:
288157
return left if filt.keep_if_missing else None
289158

290159
attrs = right["attributes"]
291-
if filt.type == FilterType.CLASSIFY:
292-
return left if _is_valid(left, filt, attrs) else None
293160
if filt.type == FilterType.REMOVE_DOC:
294161
return left if not attrs.get(filt.name, False) else None
295162
assert filt.type == FilterType.REMOVE_SPANS
@@ -311,8 +178,7 @@ def consolidate(
311178
312179
Joins each input file with its (co-partitioned, sorted) attribute files via
313180
chained ``sorted_merge_join`` ops — one left join per filter, with the
314-
filter's keep/mutate/drop logic encoded in its combiner. No in-memory hash
315-
table is materialized.
181+
filter's keep/mutate/drop logic encoded in its combiner.
316182
317183
Args:
318184
input_path: Directory (recursively) containing input documents.
@@ -321,13 +187,12 @@ def consolidate(
321187
filetype: Extension of the input documents (default: ``"jsonl.gz"``).
322188
worker_resources: Optional Zephyr worker resource config (defaults to Zephyr defaults).
323189
"""
324-
filters = calculate_percentile_thresholds(input_path=input_path, filters=filters, filetype=filetype)
325190
input_paths = sorted(fsspec_glob(os.path.join(input_path, f"**/*.{filetype}")))
326191
if not input_paths:
327192
raise ValueError(f"No input files matched {input_path}/**/*.{filetype}")
328193
logger.info(f"Consolidating {len(input_paths)} document files via {len(filters)} sorted_merge_join(s)")
329194

330-
# Determine id key; assume a uniform corpus across shards (matches prior per-shard behavior
195+
# Determine id key; assume a uniform corpus across shards (matches prior per-shard behavior)
331196
# since datakit inputs are all "default" — "dclm" was the only alternative.
332197
corpus_type = "dclm" if any("dclm" in p for p in input_paths) else "default"
333198
id_key = CORPUS_TYPE_TO_ID_GUIDE[corpus_type]["key"]

tests/processing/classification/test_consolidate.py

Lines changed: 2 additions & 116 deletions
Original file line numberDiff line numberDiff line change
@@ -1,133 +1,19 @@
11
# Copyright The Marin Authors
22
# SPDX-License-Identifier: Apache-2.0
33

4-
import gzip
5-
import json
64
import os
75
from pathlib import Path
86

9-
from marin.processing.classification.deduplication.dedup_commons import DedupMode
10-
from marin.processing.classification.deduplication.exact import dedup_exact_paragraph
11-
import pytest
12-
from ddsketch import DDSketch
137
from marin.processing.classification.consolidate import (
148
FilterConfig,
159
FilterType,
16-
calculate_percentile_thresholds,
1710
consolidate,
1811
)
12+
from marin.processing.classification.deduplication.dedup_commons import DedupMode
13+
from marin.processing.classification.deduplication.exact import dedup_exact_paragraph
1914
from zephyr.readers import load_parquet
2015

2116

22-
def _write_jsonl(path: Path, rows: list[dict]) -> None:
23-
with path.open("w", encoding="utf-8") as handle:
24-
for row in rows:
25-
handle.write(json.dumps(row) + "\n")
26-
27-
28-
def test_calculate_percentile_threshold(tmp_path):
29-
documents_dir = tmp_path / "documents"
30-
attributes_dir = tmp_path / "attributes"
31-
documents_dir.mkdir()
32-
attributes_dir.mkdir()
33-
34-
attribute_rows = [
35-
[
36-
{"id": "doc-0", "attributes": {"quality": {"good": 0.1}}},
37-
{"id": "doc-1", "attributes": {"quality": {"good": 0.4}}},
38-
],
39-
[
40-
{"id": "doc-2", "attributes": {"quality": {"good": 0.7}}},
41-
{"id": "doc-3", "attributes": {"quality": {"good": 0.9}}},
42-
],
43-
]
44-
45-
for shard_index, rows in enumerate(attribute_rows):
46-
doc_path = documents_dir / f"part-{shard_index}.jsonl"
47-
doc_path.write_text("{}", encoding="utf-8")
48-
attr_path = attributes_dir / f"part-{shard_index}.jsonl"
49-
_write_jsonl(attr_path, rows)
50-
51-
keep_fraction = 0.5
52-
filters = [
53-
FilterConfig(
54-
type=FilterType.CLASSIFY,
55-
attribute_path=str(attributes_dir),
56-
name="quality",
57-
label="good",
58-
keep_fraction=keep_fraction,
59-
)
60-
]
61-
62-
updated_filters = calculate_percentile_thresholds(input_path=str(documents_dir), filters=filters, filetype="jsonl")
63-
threshold = updated_filters[0].lower_threshold
64-
65-
# Calculate expected threshold
66-
expected_sketch = DDSketch()
67-
for shard in attribute_rows:
68-
for row in shard:
69-
expected_sketch.add(row["attributes"]["quality"]["good"])
70-
expected_threshold = expected_sketch.get_quantile_value(1 - keep_fraction)
71-
72-
assert threshold == pytest.approx(expected_threshold, rel=1e-6)
73-
74-
75-
def _write_jsonl_gz(path: Path, rows: list[dict]) -> None:
76-
with gzip.open(path, "wt", encoding="utf-8") as handle:
77-
for row in rows:
78-
handle.write(json.dumps(row) + "\n")
79-
80-
81-
def test_consolidate_filters_and_writes_output(tmp_path):
82-
"""Test that consolidate filters documents and writes output using zephyr."""
83-
input_root = tmp_path / "input"
84-
attributes_root = tmp_path / "attributes"
85-
output_root = tmp_path / "output"
86-
input_root.mkdir()
87-
attributes_root.mkdir()
88-
output_root.mkdir()
89-
90-
input_rows = [
91-
{"id": "doc-0", "text": "first"},
92-
{"id": "doc-1", "text": "second"},
93-
{"id": "doc-2", "text": "third"},
94-
]
95-
attribute_rows = [
96-
{"id": "doc-0", "attributes": {"quality": {"good": 0.1}}},
97-
{"id": "doc-1", "attributes": {"quality": {"good": 0.6}}},
98-
{"id": "doc-2", "attributes": {"quality": {"good": 0.8}}},
99-
]
100-
101-
input_file = input_root / "part-0000.jsonl.gz"
102-
attribute_file = attributes_root / "part-0000.jsonl.gz"
103-
_write_jsonl_gz(input_file, input_rows)
104-
_write_jsonl_gz(attribute_file, attribute_rows)
105-
106-
consolidate(
107-
input_path=str(input_root),
108-
output_path=str(output_root),
109-
filters=[
110-
FilterConfig(
111-
type=FilterType.CLASSIFY,
112-
attribute_path=str(attributes_root),
113-
name="quality",
114-
label="good",
115-
lower_threshold=0.5,
116-
)
117-
],
118-
)
119-
120-
output_file = output_root / "part-0000.parquet"
121-
assert (
122-
output_file.exists()
123-
), f"Expected consolidated output file to be written. Files in {output_root}: {list(output_root.iterdir())}"
124-
125-
output_rows = load_parquet(str(output_file))
126-
127-
kept_ids = {row["id"] for row in output_rows}
128-
assert kept_ids == {"doc-1", "doc-2"}, f"Expected to keep doc-1 and doc-2, but got {kept_ids}"
129-
130-
13117
def test_dedupe_consolidate_integration(fox_corpus):
13218
"""Integration test: dedupe generates attributes, consolidate filters based on them.
13319

0 commit comments

Comments
 (0)