
Commit 4693e5e

Authored by ravwojdyla-agent, ravwojdyla, and claude
fuzzy_dups: add opt-in resume for CC iterations (#5135)
connected_components(resume=True) scans output_dir for existing it_N/ parquet sets and, when it finds a complete iteration (file count matches ctx.max_workers), skips the initial scatter plus prior iterations and re-enters the Hash-to-Min loop at the next iteration. It falls back to a full restart when no complete state is found. The flag is plumbed through compute_fuzzy_dups_attrs as cc_resume so callers can opt in without touching the CC API directly.

Co-authored-by: Rafal Wojdyla <ravwojdyla@gmail.com>
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent ef51f83 commit 4693e5e

2 files changed: 64 additions & 21 deletions
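
For context, a minimal opt-in sketch from the caller's side (not part of the diff), assuming an existing bucket Dataset `bucket_ds` and a ZephyrContext `ctx`; the output path is illustrative:

    # Re-running the same invocation after an interruption: with resume=True,
    # any complete it_N/ directories under output_dir are reused rather than
    # recomputed, and the Hash-to-Min loop restarts at the next iteration.
    from marin.processing.classification.deduplication.connected_components import connected_components

    converged, cc_files = connected_components(
        bucket_ds,
        ctx,
        output_dir="gs://example-bucket/dedup/metadata/cc",  # hypothetical path
        max_iterations=10,
        resume=True,
    )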

lib/marin/src/marin/processing/classification/deduplication/connected_components.py
58 additions & 20 deletions
@@ -8,9 +8,34 @@
 import dupekit
 from zephyr import Dataset, ZephyrContext, counters, write_parquet_file, ShardInfo
 
+from marin.utils import fsspec_glob
+
 logger = logging.getLogger(__name__)
 
 
+def _find_last_complete_iteration(
+    output_dir: str, max_iterations: int, expected_parquets: int
+) -> tuple[int, list[str]] | None:
+    """Return (last_iteration, parquet_paths) from prior run outputs, or None if nothing reusable.
+
+    A CC iteration ``it_N/`` is considered complete iff its parquet file count equals
+    ``expected_parquets`` (= ``ctx.max_workers`` at write time). Iteration 0 uses the
+    ``part-{shard:05d}.parquet`` naming; iterations 1+ use ``part-{shard:05d}-of-{total:05d}.parquet``.
+    Both are detected by globbing ``it_N/*.parquet``.
+    """
+    last_complete = -1
+    last_paths: list[str] = []
+    for i in range(max_iterations + 1):
+        paths = fsspec_glob(f"{output_dir}/it_{i}/*.parquet")
+        if len(paths) != expected_parquets:
+            break
+        last_complete = i
+        last_paths = paths
+    if last_complete < 0:
+        return None
+    return last_complete, last_paths
+
+
 # TODO (rav): can we have just a single id that's expected to be clean on the inputs?
 class RecordId(TypedDict):
     record_id: Any
@@ -55,6 +80,7 @@ def connected_components(
     output_dir: str,
     max_iterations: int = 10,
     preserve_singletons: bool = True,
+    resume: bool = False,
 ) -> tuple[bool, Sequence[str]]:
     """
     Connected Components implementation using Zephyr Dataset API and Hash-to-Min algorithm (https://arxiv.org/abs/1203.5387)
@@ -65,6 +91,9 @@ def connected_components(
         output_dir: Directory to write intermediate and final output files
         max_iterations: Maximum number of iterations to run the connected components algorithm
         preserve_singletons: Whether to preserve single-node buckets in the output
+        resume: If True, skip iterations whose ``it_N/`` already contains a complete set of
+            parquet files (count == ``ctx.max_workers``). Starts from the first incomplete
+            iteration. If no complete prior state exists, runs from scratch.
     """
 
     def _reduce_bucket_to_links(bucket: str, items: Iterator[CCInput]) -> Iterator[dict]:
@@ -124,25 +153,34 @@ def _dedup_combiner(bucket: str, items: Iterator[CCInput]) -> Iterator[CCInput]:
     # I/O amplification.
     num_reduce_shards = ctx.max_workers
 
-    curr_it = ctx.execute(
-        ds
-        # Group nodes in buckets, deduplicate, and emit pairwise links
-        .group_by(
-            lambda x: x["bucket"],
-            reducer=_reduce_bucket_to_links,
-            combiner=_dedup_combiner,
-            num_output_shards=num_reduce_shards,
-        )
-        # Construct Node state, init with:
-        # * each node is its own component
-        # * adjacency list from links
-        .group_by(
-            lambda x: x["source_id_norm"],
-            reducer=_build_adjacency,
-            num_output_shards=num_reduce_shards,
-        ).write_parquet(f"{output_dir}/it_0/part-{{shard:05d}}.parquet"),
-        verbose=True,
-    ).results
+    start_iteration = 1
+    curr_it: Sequence[str]
+    resumed = _find_last_complete_iteration(output_dir, max_iterations, num_reduce_shards) if resume else None
+    if resumed is not None:
+        last_it, last_paths = resumed
+        logger.info("CC resume: skipping through it_%d (%d parquets present)", last_it, len(last_paths))
+        curr_it = last_paths
+        start_iteration = last_it + 1
+    else:
+        curr_it = ctx.execute(
+            ds
+            # Group nodes in buckets, deduplicate, and emit pairwise links
+            .group_by(
+                lambda x: x["bucket"],
+                reducer=_reduce_bucket_to_links,
+                combiner=_dedup_combiner,
+                num_output_shards=num_reduce_shards,
+            )
+            # Construct Node state, init with:
+            # * each node is its own component
+            # * adjacency list from links
+            .group_by(
+                lambda x: x["source_id_norm"],
+                reducer=_build_adjacency,
+                num_output_shards=num_reduce_shards,
+            ).write_parquet(f"{output_dir}/it_0/part-{{shard:05d}}.parquet"),
+            verbose=True,
+        ).results
 
     def _get_write_shard_and_count_fn(iteration: int):
         # NOTE: this function exists to make the iteration number closure capture explicit
@@ -167,7 +205,7 @@ def counting_iter():
         return _write_shard_and_count
 
     converged = False
-    for i in range(1, max_iterations + 1):  # type: ignore[bad-assignment]
+    for i in range(start_iteration, max_iterations + 1):  # type: ignore[bad-assignment]
         logger.info(f"Connected components iteration {i}...")
 
         shard_results = ctx.execute(
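
To make the completeness rule concrete, here is a small self-contained illustration (stdlib only; `find_last_complete` is a local stand-in for `_find_last_complete_iteration`, using glob on a temp dir instead of fsspec_glob):

    # An it_N/ directory counts as complete iff it holds exactly `expected`
    # parquet files; the scan stops at the first gap, so later complete
    # directories past a gap are never reused.
    import glob
    import os
    import tempfile

    def find_last_complete(output_dir: str, max_iterations: int, expected: int):
        last, paths = -1, []
        for i in range(max_iterations + 1):
            found = glob.glob(os.path.join(output_dir, f"it_{i}", "*.parquet"))
            if len(found) != expected:
                break  # first incomplete/missing iteration ends the scan
            last, paths = i, found
        return None if last < 0 else (last, paths)

    with tempfile.TemporaryDirectory() as d:
        # it_0 is complete (2 of 2 files); it_1 is incomplete (1 of 2).
        for it, n in [(0, 2), (1, 1)]:
            os.makedirs(os.path.join(d, f"it_{it}"))
            for s in range(n):
                open(os.path.join(d, f"it_{it}", f"part-{s:05d}.parquet"), "w").close()
        # Returns (0, [...it_0 paths...]); a resumed run restarts the loop at iteration 1.
        print(find_last_complete(d, max_iterations=10, expected=2))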

lib/marin/src/marin/processing/classification/deduplication/fuzzy_dups.py
6 additions & 1 deletion
@@ -233,6 +233,7 @@ def compute_fuzzy_dups_attrs(
     inputs: list[MinHashAttrData],
     output_path: str,
     cc_max_iterations: int = 10,
+    cc_resume: bool = False,
     max_parallelism: int,
     worker_resources: ResourceConfig | None = None,
     coordinator_resources: ResourceConfig | None = None,
@@ -299,7 +300,11 @@ def compute_fuzzy_dups_attrs(
 
     bucket_ds = Dataset.from_list(entry_groups).flat_map(_emit_bucket_records)
     converged, cc_files = connected_components(
-        bucket_ds, ctx, output_dir=f"{output_path}/metadata/cc", max_iterations=cc_max_iterations
+        bucket_ds,
+        ctx,
+        output_dir=f"{output_path}/metadata/cc",
+        max_iterations=cc_max_iterations,
+        resume=cc_resume,
     )
     if not converged:
         # TODO (rav): log the number of changed nodes?
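
And the corresponding opt-in from the pipeline entry point, as a sketch: only cc_resume is the new knob; `minhash_inputs` and the remaining argument values are assumed/illustrative:

    # cc_resume is forwarded as connected_components(resume=...); everything
    # else below is an illustrative placeholder.
    compute_fuzzy_dups_attrs(
        inputs=minhash_inputs,  # assumed: a prepared list[MinHashAttrData]
        output_path="gs://example-bucket/dedup",  # hypothetical path
        cc_max_iterations=10,
        cc_resume=True,
        max_parallelism=64,
    )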
