88import dupekit
99from zephyr import Dataset , ZephyrContext , counters , write_parquet_file , ShardInfo
1010
11+ from marin .utils import fsspec_glob
12+
1113logger = logging .getLogger (__name__ )
1214
1315
def _find_last_complete_iteration(
    output_dir: str, max_iterations: int, expected_parquets: int
) -> tuple[int, list[str]] | None:
    """Locate the most recent fully-written CC iteration under ``output_dir``.

    Scans ``it_0`` through ``it_{max_iterations}`` in order and stops at the first
    directory whose parquet count differs from ``expected_parquets`` (the
    ``ctx.max_workers`` value used when the files were written). Iteration 0 files
    are named ``part-{shard:05d}.parquet`` while later iterations use
    ``part-{shard:05d}-of-{total:05d}.parquet``; a single ``it_N/*.parquet`` glob
    matches both schemes.

    Returns:
        ``(iteration, parquet_paths)`` for the last complete iteration, or
        ``None`` when not even ``it_0`` is complete (nothing to resume from).
    """
    best: tuple[int, list[str]] | None = None
    iteration = 0
    while iteration <= max_iterations:
        candidates = fsspec_glob(f"{output_dir}/it_{iteration}/*.parquet")
        # An iteration is reusable only if every expected shard file is present.
        if len(candidates) != expected_parquets:
            break
        best = (iteration, candidates)
        iteration += 1
    return best
37+
38+
1439# TODO (rav): can we have just a single id that's expected to be clean on the inputs?
1540class RecordId (TypedDict ):
1641 record_id : Any
@@ -55,6 +80,7 @@ def connected_components(
5580 output_dir : str ,
5681 max_iterations : int = 10 ,
5782 preserve_singletons : bool = True ,
83+ resume : bool = False ,
5884) -> tuple [bool , Sequence [str ]]:
5985 """
6086 Connected Components implementation using Zephyr Dataset API and Hash-to-Min algorithm (https://arxiv.org/abs/1203.5387)
@@ -65,6 +91,9 @@ def connected_components(
6591 output_dir: Directory to write intermediate and final output files
6692 max_iterations: Maximum number of iterations to run the connected components algorithm
6793 preserve_singletons: Whether to preserve single-node buckets in the output
94+ resume: If True, skip iterations whose ``it_N/`` already contains a complete set of
95+ parquet files (count == ``ctx.max_workers``). Starts from the first incomplete
96+ iteration. If no complete prior state exists, runs from scratch.
6897 """
6998
7099 def _reduce_bucket_to_links (bucket : str , items : Iterator [CCInput ]) -> Iterator [dict ]:
@@ -124,25 +153,34 @@ def _dedup_combiner(bucket: str, items: Iterator[CCInput]) -> Iterator[CCInput]:
124153 # I/O amplification.
125154 num_reduce_shards = ctx .max_workers
126155
127- curr_it = ctx .execute (
128- ds
129- # Group nodes in buckets, deduplicate, and emit pairwise links
130- .group_by (
131- lambda x : x ["bucket" ],
132- reducer = _reduce_bucket_to_links ,
133- combiner = _dedup_combiner ,
134- num_output_shards = num_reduce_shards ,
135- )
136- # Construct Node state, init with:
137- # * each node is its own component
138- # * adjacency list from links
139- .group_by (
140- lambda x : x ["source_id_norm" ],
141- reducer = _build_adjacency ,
142- num_output_shards = num_reduce_shards ,
143- ).write_parquet (f"{ output_dir } /it_0/part-{{shard:05d}}.parquet" ),
144- verbose = True ,
145- ).results
156+ start_iteration = 1
157+ curr_it : Sequence [str ]
158+ resumed = _find_last_complete_iteration (output_dir , max_iterations , num_reduce_shards ) if resume else None
159+ if resumed is not None :
160+ last_it , last_paths = resumed
161+ logger .info ("CC resume: skipping through it_%d (%d parquets present)" , last_it , len (last_paths ))
162+ curr_it = last_paths
163+ start_iteration = last_it + 1
164+ else :
165+ curr_it = ctx .execute (
166+ ds
167+ # Group nodes in buckets, deduplicate, and emit pairwise links
168+ .group_by (
169+ lambda x : x ["bucket" ],
170+ reducer = _reduce_bucket_to_links ,
171+ combiner = _dedup_combiner ,
172+ num_output_shards = num_reduce_shards ,
173+ )
174+ # Construct Node state, init with:
175+ # * each node is its own component
176+ # * adjacency list from links
177+ .group_by (
178+ lambda x : x ["source_id_norm" ],
179+ reducer = _build_adjacency ,
180+ num_output_shards = num_reduce_shards ,
181+ ).write_parquet (f"{ output_dir } /it_0/part-{{shard:05d}}.parquet" ),
182+ verbose = True ,
183+ ).results
146184
147185 def _get_write_shard_and_count_fn (iteration : int ):
148186 # NOTE: this function exists to make the iteration number closure capture explicit
@@ -167,7 +205,7 @@ def counting_iter():
167205 return _write_shard_and_count
168206
169207 converged = False
170- for i in range (1 , max_iterations + 1 ): # type: ignore[bad-assignment]
208+ for i in range (start_iteration , max_iterations + 1 ): # type: ignore[bad-assignment]
171209 logger .info (f"Connected components iteration { i } ..." )
172210
173211 shard_results = ctx .execute (
0 commit comments