44"""
55Consolidate takes a set of documents with corresponding attributes and writes
66out a subset of the documents based on various filters defined with respect to
7- the attributes. Handles three cases:
8- - Quality filtering produces attributes (e.g., fasttext-quality) with labels
9- (e.g., __label__hq), filter on threshold.
7+ the attributes. Handles two cases:
108- Span removal produces attributes (e.g., duplicate_text spans). Remove text spans.
119- Document removal via attribute produced by deduplication.
1210
1917
2018import logging
2119import os
22- from collections .abc import Callable , Iterator
23- from dataclasses import dataclass , replace
20+ from collections .abc import Callable
21+ from dataclasses import dataclass
2422from enum import StrEnum
2523from typing import Any
2624
3432
3533
3634class FilterType (StrEnum ):
37- CLASSIFY = "classify"
3835 REMOVE_SPANS = "remove_spans"
3936 REMOVE_DOC = "remove_docs"
4037
@@ -55,21 +52,6 @@ class FilterConfig:
5552 name : str
5653 """Name of attribute to use for filtering."""
5754
58- label : str | None = None
59- """The label under the attribute name."""
60-
61- lower_threshold : float | None = None
62- """Keep documents where the value is above this."""
63-
64- keep_fraction : float | None = None
65- """Keep documents where the score is in the top percentile. Calculates the threshold from the entire dataset."""
66-
67- upper_threshold : float | None = None
68- """Keep documents where the value is below this."""
69-
70- reverse : bool = False
71- """Reverse the filter."""
72-
7355 attribute_filetype : str | None = None
7456 """File extension for attribute files (e.g. 'jsonl.gz', 'vortex'). If None, uses the input filetype."""
7557
@@ -84,32 +66,6 @@ class FilterConfig:
8466}
8567
8668
87- def _is_valid (doc : dict , filt : FilterConfig , attributes : dict ) -> bool :
88- assert filt .type == FilterType .CLASSIFY
89- attribute_value = attributes [filt .name ]
90-
91- # Handle nested attributes structure if a label is specified
92- if filt .label is not None :
93- if isinstance (attribute_value , dict ) and filt .label in attribute_value :
94- value = attribute_value [filt .label ]
95- else :
96- raise ValueError (f"Label { filt .label } not found in attribute { filt .name } for document { doc } " )
97- else :
98- value = attribute_value
99-
100- # Check both lower and upper bounds if specified
101- accepted = True
102- if filt .lower_threshold is not None and value < filt .lower_threshold :
103- accepted = False
104- if filt .upper_threshold is not None and value > filt .upper_threshold :
105- accepted = False
106-
107- if filt .reverse :
108- accepted = not accepted
109-
110- return accepted
111-
112-
11369def _remove_spans_from_doc (doc : dict , filt : FilterConfig , attributes : dict ) -> dict :
11470 def _remove_spans (text : str , spans : list [list [int ]]) -> str :
11571 """Return ``text`` with ``spans`` removed.
@@ -151,93 +107,6 @@ def _make_id_extractor(corpus_type: str) -> Callable[[dict], Any]:
151107 return lambda r : extract_id (r , corpus_type )
152108
153109
154- def _compute_percentile_threshold (
155- attr_paths : list [str ], attr_name : str , attr_label : str | None , keep_fraction : float
156- ) -> float :
157- """Compute percentile threshold for a single filter using DDSketch reduction.
158-
159- Args:
160- attr_paths: Paths to attribute files
161- attr_name: Name of attribute to extract
162- attr_label: Optional label within attribute (for nested dicts)
163- keep_fraction: Fraction of documents to keep (0-1)
164-
165- Returns:
166- Threshold value at the (1 - keep_fraction) quantile
167- """
168- from ddsketch import DDSketch
169-
170- def local_reducer (rows : Iterator [dict ], attr_name : str = attr_name , attr_label : str | None = attr_label ) -> DDSketch :
171- """Build DDSketch from rows in a single shard."""
172- sketch = DDSketch ()
173- for row in rows :
174- attributes = row ["attributes" ]
175- value = attributes [attr_name ][attr_label ] if attr_label else attributes [attr_name ]
176- sketch .add (value )
177- return sketch
178-
179- def global_reducer (sketches : Iterator [DDSketch ]) -> DDSketch :
180- """Merge all shard sketches into one."""
181- combined = DDSketch ()
182- for sketch in sketches :
183- combined .merge (sketch )
184- return combined
185-
186- ctx = ZephyrContext (name = "consolidate-stats" )
187- result = ctx .execute (
188- Dataset .from_list (attr_paths )
189- .load_file ()
190- .select ("attributes" )
191- .reduce (local_reducer = local_reducer , global_reducer = global_reducer )
192- ).results
193-
194- combined_sketch = next (iter (result ))
195- threshold = combined_sketch .get_quantile_value (1 - keep_fraction )
196- return threshold
197-
198-
199- def calculate_percentile_thresholds (
200- * ,
201- input_path : str ,
202- filters : list [FilterConfig ],
203- filetype : str = "jsonl.gz" ,
204- ) -> list [FilterConfig ]:
205- """Resolve ``keep_fraction`` filters to ``lower_threshold`` via percentile calculation.
206-
207- Returns a new list of filters with percentile-based thresholds resolved.
208- """
209- updated_filters = []
210- input_paths = fsspec_glob (os .path .join (input_path , f"**/*.{ filetype } " ))
211-
212- for filt in filters :
213- # Validate threshold configuration
214- if filt .keep_fraction is not None and filt .lower_threshold is not None :
215- raise ValueError ("Cannot specify both keep_fraction and lower_threshold. Please specify only one." )
216-
217- # Skip if no percentile calculation needed
218- if filt .keep_fraction is None :
219- updated_filters .append (filt )
220- continue
221-
222- if not (0 < filt .keep_fraction < 1 ):
223- raise ValueError ("keep_fraction must be between 0 and 1" )
224-
225- # Only applies to CLASSIFY filters
226- if filt .type != FilterType .CLASSIFY :
227- logger .warning (f"keep_fraction only applies to CLASSIFY filters, ignoring for { filt .name } " )
228- updated_filters .append (filt )
229- continue
230-
231- attr_paths = _attribute_paths_for_filter (input_path , input_paths , filt , filetype )
232- attr_paths = [p for p in attr_paths if p is not None ]
233-
234- threshold = _compute_percentile_threshold (attr_paths , filt .name , filt .label , filt .keep_fraction )
235- logger .info (f"Calculated threshold { threshold } for { filt .name } to keep { filt .keep_fraction } of documents" )
236- updated_filters .append (replace (filt , lower_threshold = threshold , keep_fraction = None ))
237-
238- return updated_filters
239-
240-
241110def _resolve_attribute_path (input_base : str , input_path : str , filt : FilterConfig , filetype : str ) -> str | None :
242111 """Map an input file path to its attribute file path, with glob fallback for compression suffixes."""
243112 new_extension = f".{ filt .attribute_filetype } " if filt .attribute_filetype else f".{ filetype } "
@@ -288,8 +157,6 @@ def combine(left: dict, right: dict | None) -> dict | None:
288157 return left if filt .keep_if_missing else None
289158
290159 attrs = right ["attributes" ]
291- if filt .type == FilterType .CLASSIFY :
292- return left if _is_valid (left , filt , attrs ) else None
293160 if filt .type == FilterType .REMOVE_DOC :
294161 return left if not attrs .get (filt .name , False ) else None
295162 assert filt .type == FilterType .REMOVE_SPANS
@@ -311,8 +178,7 @@ def consolidate(
311178
312179 Joins each input file with its (co-partitioned, sorted) attribute files via
313180 chained ``sorted_merge_join`` ops — one left join per filter, with the
314- filter's keep/mutate/drop logic encoded in its combiner. No in-memory hash
315- table is materialized.
181+ filter's keep/mutate/drop logic encoded in its combiner.
316182
317183 Args:
318184 input_path: Directory (recursively) containing input documents.
@@ -321,13 +187,12 @@ def consolidate(
321187 filetype: Extension of the input documents (default: ``"jsonl.gz"``).
322188 worker_resources: Optional Zephyr worker resource config (defaults to Zephyr defaults).
323189 """
324- filters = calculate_percentile_thresholds (input_path = input_path , filters = filters , filetype = filetype )
325190 input_paths = sorted (fsspec_glob (os .path .join (input_path , f"**/*.{ filetype } " )))
326191 if not input_paths :
327192 raise ValueError (f"No input files matched { input_path } /**/*.{ filetype } " )
328193 logger .info (f"Consolidating { len (input_paths )} document files via { len (filters )} sorted_merge_join(s)" )
329194
330- # Determine id key; assume a uniform corpus across shards (matches prior per-shard behavior
195+ # Determine id key; assume a uniform corpus across shards (matches prior per-shard behavior)
331196 # since datakit inputs are all "default" — "dclm" was the only alternative.
332197 corpus_type = "dclm" if any ("dclm" in p for p in input_paths ) else "default"
333198 id_key = CORPUS_TYPE_TO_ID_GUIDE [corpus_type ]["key" ]
0 commit comments