Skip to content

Commit c36fd98

Browse files
committed
Fix fuzzy dedup producing zero duplicates under Ray
When _load_fuzzy_dupe_map_shard and _compute_fuzzy_dedup_stats created a ZephyrContext with client=LocalClient(), the coordinator job called current_client(), which auto-detected Ray and spawned workers as separate processes. The add_to_dup_map closure then modified serialized copies of shard_dup_map, leaving the original dict empty. Wrap both call sites in set_current_client so that the LocalClient propagates to workers via context variables.
1 parent a54876a commit c36fd98

1 file changed

Lines changed: 26 additions & 21 deletions

File tree

  • lib/marin/src/marin/processing/classification/deduplication

lib/marin/src/marin/processing/classification/deduplication/fuzzy.py

Lines changed: 26 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import logging
88
from marin.utils import rebase_file_path
99
import pyarrow as pa
10+
from fray.v2.client import set_current_client
1011
from fray.v2.local_backend import LocalClient
1112
from marin.processing.classification.deduplication.dedup_commons import (
1213
DEFAULT_FILETYPES,
@@ -31,23 +32,25 @@
3132

3233
def _compute_fuzzy_dedup_stats(shards: list[str] | Sequence[str], method: str, level: str) -> DupCounters:
3334
with log_time(f"Compute fuzzy deduplication stats from {len(shards)} shards"):
34-
ctx = ZephyrContext(client=LocalClient(), name="fuzzy-dup-counts")
35-
result: DupCounters = ctx.execute( # type: ignore[bad-assignment]
36-
Dataset.from_list(shards)
37-
.load_parquet(columns=["component_id"])
38-
# Compute the per-component statistics and then roll them up into a single counter group
39-
.group_by(
40-
key=lambda r: r["component_id"],
41-
reducer=lambda _, items: DupCounters(
42-
method=method,
43-
level=level,
44-
total=(total := sum(1 for _ in items)),
45-
dups=total if total > 1 else 0,
46-
unique=1,
47-
),
48-
)
49-
.reduce(partial(sum, start=DupCounters(method=method, level=level))),
50-
)[0]
35+
client = LocalClient()
36+
with set_current_client(client):
37+
ctx = ZephyrContext(client=client, name="fuzzy-dup-counts")
38+
result: DupCounters = ctx.execute( # type: ignore[bad-assignment]
39+
Dataset.from_list(shards)
40+
.load_parquet(columns=["component_id"])
41+
# Compute the per-component statistics and then roll them up into a single counter group
42+
.group_by(
43+
key=lambda r: r["component_id"],
44+
reducer=lambda _, items: DupCounters(
45+
method=method,
46+
level=level,
47+
total=(total := sum(1 for _ in items)),
48+
dups=total if total > 1 else 0,
49+
unique=1,
50+
),
51+
)
52+
.reduce(partial(sum, start=DupCounters(method=method, level=level))),
53+
)[0]
5154
return result
5255

5356

@@ -63,10 +66,12 @@ def add_to_dup_map(record: dict):
6366
shard_dup_map[record["id"]] = record["fuzzy_duplicate"]
6467

6568
with log_time(f"Load fuzzy duplicate map from {len(shards)} shards"):
66-
ctx = ZephyrContext(client=LocalClient(), name="fuzzy-dup-map")
67-
ctx.execute(
68-
Dataset.from_list(shards).load_parquet().map(add_to_dup_map),
69-
)
69+
client = LocalClient()
70+
with set_current_client(client):
71+
ctx = ZephyrContext(client=client, name="fuzzy-dup-map")
72+
ctx.execute(
73+
Dataset.from_list(shards).load_parquet().map(add_to_dup_map),
74+
)
7075

7176
return shard_dup_map
7277

0 commit comments

Comments
 (0)