Skip to content

Commit b100de4

Browse files
authored
Update exact dedup to groupby via parquet (#3642)
Updates exact dedup groupby data structure to allow arrow shuffle #3482
1 parent 3c0d43b commit b100de4

File tree

1 file changed

+15
-15
lines changed
  • lib/marin/src/marin/processing/classification/deduplication

1 file changed

+15
-15
lines changed

lib/marin/src/marin/processing/classification/deduplication/exact.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -133,35 +133,35 @@ def group_by_doc_id(records: Iterator[dict]) -> Iterator[dict]:
133133
result = write_vortex_file(group_by_doc_id(counting_iter()), output_file)
134134
return {**result, "stats": stats}
135135

136-
def annotate_dups(key_hash: str, records: Iterator[tuple[int, dict[str, Any]]]) -> Iterator[dict[str, Any]]:
137-
has_dups, (_, head_record), records = _iter_has_more_than_one(records)
136+
def annotate_dups(key_hash: str, records: Iterator[dict[str, Any]]) -> Iterator[dict[str, Any]]:
137+
has_dups, head_record, records = _iter_has_more_than_one(records)
138138

139139
# NOTE: we **arbitrarily** select the 1st record as the canonical record
140140
cano_id = head_record["id"]
141141

142-
for file_idx, item in records:
142+
for item in records:
143143
is_dup = has_dups and item["id"] != cano_id
144144
yield {
145145
"id": item["id"],
146146
"is_dup": is_dup,
147-
"span": item["paragraph_span"]["span"] if is_dup else None,
148-
"file_idx": file_idx,
147+
"span": item["paragraph_span"]["span"] if is_dup else [],
148+
"file_idx": item["file_idx"],
149149
}
150150

151151
shard_results = list(
152152
ctx.execute(
153153
Dataset.from_list(input_files)
154154
.flat_map(
155155
lambda path: (
156-
(path_to_idx[path], {"id": hash_record.pop("doc_id"), **hash_record})
156+
{"file_idx": path_to_idx[path], "id": hash_record.pop("doc_id"), **hash_record}
157157
for batch in _load_batches(path)
158158
for hash_record in compute_paragraph_hashes(batch).to_pylist()
159159
)
160160
)
161161
.group_by(
162-
lambda record: record[1]["hash"],
162+
lambda record: record["hash"],
163163
# NOTE: selecting the canonical record is deterministic via this sort
164-
sort_by=lambda record: record[1]["id"],
164+
sort_by=lambda record: record["id"],
165165
reducer=annotate_dups,
166166
)
167167
.group_by(
@@ -251,34 +251,34 @@ def skip_non_dups(records: Iterator[dict[str, Any]]) -> Iterator[dict[str, Any]]
251251
result = write_vortex_file(skip_non_dups(counting_iter()), output_file)
252252
return {**result, "stats": stats}
253253

254-
def annotate_dups(key_hash: str, records: Iterator[tuple[int, dict[str, Any]]]) -> Iterator[dict[str, Any]]:
255-
has_dups, (_, head_record), records = _iter_has_more_than_one(records)
254+
def annotate_dups(key_hash: str, records: Iterator[dict[str, Any]]) -> Iterator[dict[str, Any]]:
255+
has_dups, head_record, records = _iter_has_more_than_one(records)
256256

257257
# NOTE: we **arbitrarily** select the 1st record as the canonical record
258258
cano_id = head_record["id"]
259259

260-
for file_idx, item in records:
260+
for item in records:
261261
is_dup = has_dups and item["id"] != cano_id
262262
yield {
263263
"id": item["id"],
264264
"is_dup": is_dup,
265-
"file_idx": file_idx,
265+
"file_idx": item["file_idx"],
266266
}
267267

268268
shard_results = list(
269269
ctx.execute(
270270
Dataset.from_list(input_files)
271271
.flat_map(
272272
lambda path: (
273-
(path_to_idx[path], hash_record)
273+
{"file_idx": path_to_idx[path], **hash_record}
274274
for batch in _load_batches(path)
275275
for hash_record in compute_document_hashes(batch).to_pylist()
276276
)
277277
)
278278
.group_by(
279-
lambda record: record[1]["hash"],
279+
lambda record: record["hash"],
280280
# NOTE: selecting the canonical record is deterministic via this sort
281-
sort_by=lambda record: record[1]["id"],
281+
sort_by=lambda record: record["id"],
282282
reducer=annotate_dups,
283283
)
284284
.group_by(

0 commit comments

Comments
 (0)